14c4876c3SAlex Zinenko //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===// 24c4876c3SAlex Zinenko // 34c4876c3SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 44c4876c3SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 54c4876c3SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 64c4876c3SAlex Zinenko // 74c4876c3SAlex Zinenko //===----------------------------------------------------------------------===// 84c4876c3SAlex Zinenko 94c4876c3SAlex Zinenko #include "GPUOpsLowering.h" 10888717e8SNicolas Vasilache 11888717e8SNicolas Vasilache #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" 12e1da6291SKrzysztof Drewniak #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 13499abb24SKrzysztof Drewniak #include "mlir/IR/Attributes.h" 144c4876c3SAlex Zinenko #include "mlir/IR/Builders.h" 15b251b608SChristian Sigg #include "mlir/IR/BuiltinTypes.h" 1617faae95SLaszlo Kindrat #include "llvm/ADT/SmallVectorExtras.h" 17ea84897bSGuray Ozen #include "llvm/ADT/StringSet.h" 184c4876c3SAlex Zinenko #include "llvm/Support/FormatVariadic.h" 194c4876c3SAlex Zinenko 204c4876c3SAlex Zinenko using namespace mlir; 214c4876c3SAlex Zinenko 22*599c7399SMatthias Springer LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp, 23*599c7399SMatthias Springer Location loc, OpBuilder &b, 24*599c7399SMatthias Springer StringRef name, 25*599c7399SMatthias Springer LLVM::LLVMFunctionType type) { 26*599c7399SMatthias Springer LLVM::LLVMFuncOp ret; 27*599c7399SMatthias Springer if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { 28*599c7399SMatthias Springer OpBuilder::InsertionGuard guard(b); 29*599c7399SMatthias Springer b.setInsertionPointToStart(moduleOp.getBody()); 30*599c7399SMatthias Springer ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External); 31*599c7399SMatthias Springer } 32*599c7399SMatthias Springer return ret; 33*599c7399SMatthias Springer } 34*599c7399SMatthias Springer 35*599c7399SMatthias Springer static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp, 36*599c7399SMatthias Springer StringRef prefix) { 37*599c7399SMatthias Springer // Get a unique global name. 38*599c7399SMatthias Springer unsigned stringNumber = 0; 39*599c7399SMatthias Springer SmallString<16> stringConstName; 40*599c7399SMatthias Springer do { 41*599c7399SMatthias Springer stringConstName.clear(); 42*599c7399SMatthias Springer (prefix + Twine(stringNumber++)).toStringRef(stringConstName); 43*599c7399SMatthias Springer } while (moduleOp.lookupSymbol(stringConstName)); 44*599c7399SMatthias Springer return stringConstName; 45*599c7399SMatthias Springer } 46*599c7399SMatthias Springer 47*599c7399SMatthias Springer LLVM::GlobalOp 48*599c7399SMatthias Springer mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, 49*599c7399SMatthias Springer gpu::GPUModuleOp moduleOp, Type llvmI8, 50*599c7399SMatthias Springer StringRef namePrefix, StringRef str, 51*599c7399SMatthias Springer uint64_t alignment, unsigned addrSpace) { 52*599c7399SMatthias Springer llvm::SmallString<20> nullTermStr(str); 53*599c7399SMatthias Springer nullTermStr.push_back('\0'); // Null terminate for C 54*599c7399SMatthias Springer auto globalType = 55*599c7399SMatthias Springer LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes()); 56*599c7399SMatthias Springer StringAttr attr = b.getStringAttr(nullTermStr); 57*599c7399SMatthias Springer 58*599c7399SMatthias Springer // Try to find existing global. 59*599c7399SMatthias Springer for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>()) 60*599c7399SMatthias Springer if (globalOp.getGlobalType() == globalType && globalOp.getConstant() && 61*599c7399SMatthias Springer globalOp.getValueAttr() == attr && 62*599c7399SMatthias Springer globalOp.getAlignment().value_or(0) == alignment && 63*599c7399SMatthias Springer globalOp.getAddrSpace() == addrSpace) 64*599c7399SMatthias Springer return globalOp; 65*599c7399SMatthias Springer 66*599c7399SMatthias Springer // Not found: create new global. 67*599c7399SMatthias Springer OpBuilder::InsertionGuard guard(b); 68*599c7399SMatthias Springer b.setInsertionPointToStart(moduleOp.getBody()); 69*599c7399SMatthias Springer SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix); 70*599c7399SMatthias Springer return b.create<LLVM::GlobalOp>(loc, globalType, 71*599c7399SMatthias Springer /*isConstant=*/true, LLVM::Linkage::Internal, 72*599c7399SMatthias Springer name, attr, alignment, addrSpace); 73*599c7399SMatthias Springer } 74*599c7399SMatthias Springer 754c4876c3SAlex Zinenko LogicalResult 76ef976337SRiver Riddle GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, 774c4876c3SAlex Zinenko ConversionPatternRewriter &rewriter) const { 784c4876c3SAlex Zinenko Location loc = gpuFuncOp.getLoc(); 794c4876c3SAlex Zinenko 804c4876c3SAlex Zinenko SmallVector<LLVM::GlobalOp, 3> workgroupBuffers; 81d45de800SVictor Perez if (encodeWorkgroupAttributionsAsArguments) { 82d45de800SVictor Perez // Append an `llvm.ptr` argument to the function signature to encode 83d45de800SVictor Perez // workgroup attributions. 84d45de800SVictor Perez 85d45de800SVictor Perez ArrayRef<BlockArgument> workgroupAttributions = 86d45de800SVictor Perez gpuFuncOp.getWorkgroupAttributions(); 87d45de800SVictor Perez size_t numAttributions = workgroupAttributions.size(); 88d45de800SVictor Perez 89d45de800SVictor Perez // Insert all arguments at the end. 90d45de800SVictor Perez unsigned index = gpuFuncOp.getNumArguments(); 91d45de800SVictor Perez SmallVector<unsigned> argIndices(numAttributions, index); 92d45de800SVictor Perez 93d45de800SVictor Perez // New arguments will simply be `llvm.ptr` with the correct address space 94d45de800SVictor Perez Type workgroupPtrType = 95d45de800SVictor Perez rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace); 96d45de800SVictor Perez SmallVector<Type> argTypes(numAttributions, workgroupPtrType); 97d45de800SVictor Perez 98d45de800SVictor Perez // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>) 99d45de800SVictor Perez std::array attrs{ 100d45de800SVictor Perez rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(), 101d45de800SVictor Perez rewriter.getUnitAttr()), 102d45de800SVictor Perez rewriter.getNamedAttr( 103d45de800SVictor Perez getDialect().getWorkgroupAttributionAttrHelper().getName(), 104d45de800SVictor Perez rewriter.getUnitAttr()), 105d45de800SVictor Perez }; 106d45de800SVictor Perez SmallVector<DictionaryAttr> argAttrs; 107d45de800SVictor Perez for (BlockArgument attribution : workgroupAttributions) { 108d45de800SVictor Perez auto attributionType = cast<MemRefType>(attribution.getType()); 109d45de800SVictor Perez IntegerAttr numElements = 110d45de800SVictor Perez rewriter.getI64IntegerAttr(attributionType.getNumElements()); 111d45de800SVictor Perez Type llvmElementType = 112d45de800SVictor Perez getTypeConverter()->convertType(attributionType.getElementType()); 113d45de800SVictor Perez if (!llvmElementType) 114d45de800SVictor Perez return failure(); 115d45de800SVictor Perez TypeAttr type = TypeAttr::get(llvmElementType); 116d45de800SVictor Perez attrs.back().setValue( 117d45de800SVictor Perez rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(numElements, type)); 118d45de800SVictor Perez argAttrs.push_back(rewriter.getDictionaryAttr(attrs)); 119d45de800SVictor Perez } 120d45de800SVictor Perez 121d45de800SVictor Perez // Location match function location 122d45de800SVictor Perez SmallVector<Location> argLocs(numAttributions, gpuFuncOp.getLoc()); 123d45de800SVictor Perez 124d45de800SVictor Perez // Perform signature modification 125d45de800SVictor Perez rewriter.modifyOpInPlace( 126d45de800SVictor Perez gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() { 127d45de800SVictor Perez static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments( 128d45de800SVictor Perez argIndices, argTypes, argAttrs, argLocs); 129d45de800SVictor Perez }); 130d45de800SVictor Perez } else { 1314c4876c3SAlex Zinenko workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions()); 132d45de800SVictor Perez for (auto [idx, attribution] : 133ddd6acd7SKrzysztof Drewniak llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { 1345550c821STres Popp auto type = dyn_cast<MemRefType>(attribution.getType()); 1354c4876c3SAlex Zinenko assert(type && type.hasStaticShape() && "unexpected type in attribution"); 1364c4876c3SAlex Zinenko 1374c4876c3SAlex Zinenko uint64_t numElements = type.getNumElements(); 1384c4876c3SAlex Zinenko 1394c4876c3SAlex Zinenko auto elementType = 1405550c821STres Popp cast<Type>(typeConverter->convertType(type.getElementType())); 1414c4876c3SAlex Zinenko auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); 142ddd6acd7SKrzysztof Drewniak std::string name = 143ddd6acd7SKrzysztof Drewniak std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx)); 14494058c41SKrzysztof Drewniak uint64_t alignment = 0; 145d45de800SVictor Perez if (auto alignAttr = dyn_cast_or_null<IntegerAttr>( 146d45de800SVictor Perez gpuFuncOp.getWorkgroupAttributionAttr( 147ddd6acd7SKrzysztof Drewniak idx, LLVM::LLVMDialect::getAlignAttrName()))) 14894058c41SKrzysztof Drewniak alignment = alignAttr.getInt(); 1494c4876c3SAlex Zinenko auto globalOp = rewriter.create<LLVM::GlobalOp>( 1504c4876c3SAlex Zinenko gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, 15194058c41SKrzysztof Drewniak LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment, 15294058c41SKrzysztof Drewniak workgroupAddrSpace); 1534c4876c3SAlex Zinenko workgroupBuffers.push_back(globalOp); 1544c4876c3SAlex Zinenko } 155d45de800SVictor Perez } 1564c4876c3SAlex Zinenko 1574c4876c3SAlex Zinenko // Remap proper input types. 1584c4876c3SAlex Zinenko TypeConverter::SignatureConversion signatureConversion( 1594c4876c3SAlex Zinenko gpuFuncOp.front().getNumArguments()); 16067754a9dSNicolas Vasilache 1610e5aeae6SMarkus Böck Type funcType = getTypeConverter()->convertFunctionSignature( 162162f7572SMahesh Ravishankar gpuFuncOp.getFunctionType(), /*isVariadic=*/false, 163162f7572SMahesh Ravishankar getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion); 16467754a9dSNicolas Vasilache if (!funcType) { 16567754a9dSNicolas Vasilache return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) { 16667754a9dSNicolas Vasilache diag << "failed to convert function signature type for: " 16767754a9dSNicolas Vasilache << gpuFuncOp.getFunctionType(); 16867754a9dSNicolas Vasilache }); 16967754a9dSNicolas Vasilache } 1704c4876c3SAlex Zinenko 1714c4876c3SAlex Zinenko // Create the new function operation. Only copy those attributes that are 1724c4876c3SAlex Zinenko // not specific to function modeling. 1734c4876c3SAlex Zinenko SmallVector<NamedAttribute, 4> attributes; 174fbf67bfaSstefankoncarevic ArrayAttr argAttrs; 17556774bddSMarius Brehler for (const auto &attr : gpuFuncOp->getAttrs()) { 1760c7890c8SRiver Riddle if (attr.getName() == SymbolTable::getSymbolAttrName() || 17753406427SJeff Niu attr.getName() == gpuFuncOp.getFunctionTypeAttrName() || 17894058c41SKrzysztof Drewniak attr.getName() == 17994058c41SKrzysztof Drewniak gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() || 18094058c41SKrzysztof Drewniak attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() || 18143fd4c49SKrzysztof Drewniak attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() || 18243fd4c49SKrzysztof Drewniak attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() || 18343fd4c49SKrzysztof Drewniak attr.getName() == gpuFuncOp.getKnownGridSizeAttrName()) 1844c4876c3SAlex Zinenko continue; 185fbf67bfaSstefankoncarevic if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) { 186fbf67bfaSstefankoncarevic argAttrs = gpuFuncOp.getArgAttrsAttr(); 187fbf67bfaSstefankoncarevic continue; 188fbf67bfaSstefankoncarevic } 1894c4876c3SAlex Zinenko attributes.push_back(attr); 1904c4876c3SAlex Zinenko } 19143fd4c49SKrzysztof Drewniak 19243fd4c49SKrzysztof Drewniak DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr(); 19343fd4c49SKrzysztof Drewniak DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr(); 19443fd4c49SKrzysztof Drewniak // Ensure we don't lose information if the function is lowered before its 19543fd4c49SKrzysztof Drewniak // surrounding context. 19643fd4c49SKrzysztof Drewniak auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect()); 19743fd4c49SKrzysztof Drewniak if (knownBlockSize) 19843fd4c49SKrzysztof Drewniak attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(), 19943fd4c49SKrzysztof Drewniak knownBlockSize); 20043fd4c49SKrzysztof Drewniak if (knownGridSize) 20143fd4c49SKrzysztof Drewniak attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(), 20243fd4c49SKrzysztof Drewniak knownGridSize); 20343fd4c49SKrzysztof Drewniak 2044c4876c3SAlex Zinenko // Add a dialect specific kernel attribute in addition to GPU kernel 2054c4876c3SAlex Zinenko // attribute. The former is necessary for further translation while the 2064c4876c3SAlex Zinenko // latter is expected by gpu.launch_func. 207763109e3SGuray Ozen if (gpuFuncOp.isKernel()) { 208d45de800SVictor Perez if (kernelAttributeName) 2094c4876c3SAlex Zinenko attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr()); 21043fd4c49SKrzysztof Drewniak // Set the dialect-specific block size attribute if there is one. 211d45de800SVictor Perez if (kernelBlockSizeAttributeName && knownBlockSize) { 212d45de800SVictor Perez attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize); 213763109e3SGuray Ozen } 214763109e3SGuray Ozen } 215d45de800SVictor Perez LLVM::CConv callingConvention = gpuFuncOp.isKernel() 216d45de800SVictor Perez ? kernelCallingConvention 217d45de800SVictor Perez : nonKernelCallingConvention; 2184c4876c3SAlex Zinenko auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>( 2194c4876c3SAlex Zinenko gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, 220d45de800SVictor Perez LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention, 221b126ee65STobias Gysi /*comdat=*/nullptr, attributes); 2224c4876c3SAlex Zinenko 2234c4876c3SAlex Zinenko { 2244c4876c3SAlex Zinenko // Insert operations that correspond to converted workgroup and private 2254c4876c3SAlex Zinenko // memory attributions to the body of the function. This must operate on 2264c4876c3SAlex Zinenko // the original function, before the body region is inlined in the new 2274c4876c3SAlex Zinenko // function to maintain the relation between block arguments and the 2284c4876c3SAlex Zinenko // parent operation that assigns their semantics. 2294c4876c3SAlex Zinenko OpBuilder::InsertionGuard guard(rewriter); 2304c4876c3SAlex Zinenko 2314c4876c3SAlex Zinenko // Rewrite workgroup memory attributions to addresses of global buffers. 2324c4876c3SAlex Zinenko rewriter.setInsertionPointToStart(&gpuFuncOp.front()); 2334c4876c3SAlex Zinenko unsigned numProperArguments = gpuFuncOp.getNumArguments(); 2344c4876c3SAlex Zinenko 235d45de800SVictor Perez if (encodeWorkgroupAttributionsAsArguments) { 236d45de800SVictor Perez // Build a MemRefDescriptor with each of the arguments added above. 237d45de800SVictor Perez 238d45de800SVictor Perez unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions(); 239d45de800SVictor Perez assert(numProperArguments >= numAttributions && 240d45de800SVictor Perez "Expecting attributions to be encoded as arguments already"); 241d45de800SVictor Perez 242d45de800SVictor Perez // Arguments encoding workgroup attributions will be in positions 243d45de800SVictor Perez // [numProperArguments, numProperArguments+numAttributions) 244d45de800SVictor Perez ArrayRef<BlockArgument> attributionArguments = 245d45de800SVictor Perez gpuFuncOp.getArguments().slice(numProperArguments - numAttributions, 246d45de800SVictor Perez numAttributions); 247d45de800SVictor Perez for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal( 248d45de800SVictor Perez gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) { 249d45de800SVictor Perez auto [attribution, arg] = vals; 250d45de800SVictor Perez auto type = cast<MemRefType>(attribution.getType()); 251d45de800SVictor Perez 252d45de800SVictor Perez // Arguments are of llvm.ptr type and attributions are of memref type: 253d45de800SVictor Perez // we need to wrap them in memref descriptors. 254d45de800SVictor Perez Value descr = MemRefDescriptor::fromStaticShape( 255d45de800SVictor Perez rewriter, loc, *getTypeConverter(), type, arg); 256d45de800SVictor Perez 257d45de800SVictor Perez // And remap the arguments 258d45de800SVictor Perez signatureConversion.remapInput(numProperArguments + idx, descr); 259d45de800SVictor Perez } 260d45de800SVictor Perez } else { 261ddd6acd7SKrzysztof Drewniak for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) { 26297a238e8SChristian Ulmann auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 26397a238e8SChristian Ulmann global.getAddrSpace()); 2640e5aeae6SMarkus Böck Value address = rewriter.create<LLVM::AddressOfOp>( 26597a238e8SChristian Ulmann loc, ptrType, global.getSymNameAttr()); 26697a238e8SChristian Ulmann Value memory = 267d45de800SVictor Perez rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(), 268d45de800SVictor Perez address, ArrayRef<LLVM::GEPArg>{0, 0}); 2694c4876c3SAlex Zinenko 2704c4876c3SAlex Zinenko // Build a memref descriptor pointing to the buffer to plug with the 2714c4876c3SAlex Zinenko // existing memref infrastructure. This may use more registers than 2724c4876c3SAlex Zinenko // otherwise necessary given that memref sizes are fixed, but we can try 2734c4876c3SAlex Zinenko // and canonicalize that away later. 274ddd6acd7SKrzysztof Drewniak Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx]; 2755550c821STres Popp auto type = cast<MemRefType>(attribution.getType()); 2764c4876c3SAlex Zinenko auto descr = MemRefDescriptor::fromStaticShape( 2774c4876c3SAlex Zinenko rewriter, loc, *getTypeConverter(), type, memory); 278ddd6acd7SKrzysztof Drewniak signatureConversion.remapInput(numProperArguments + idx, descr); 2794c4876c3SAlex Zinenko } 280d45de800SVictor Perez } 2814c4876c3SAlex Zinenko 2824c4876c3SAlex Zinenko // Rewrite private memory attributions to alloca'ed buffers. 2834c4876c3SAlex Zinenko unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions(); 2844c4876c3SAlex Zinenko auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 285ddd6acd7SKrzysztof Drewniak for (const auto [idx, attribution] : 286ddd6acd7SKrzysztof Drewniak llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { 2875550c821STres Popp auto type = cast<MemRefType>(attribution.getType()); 2884c4876c3SAlex Zinenko assert(type && type.hasStaticShape() && "unexpected type in attribution"); 2894c4876c3SAlex Zinenko 2904c4876c3SAlex Zinenko // Explicitly drop memory space when lowering private memory 2914c4876c3SAlex Zinenko // attributions since NVVM models it as `alloca`s in the default 2924c4876c3SAlex Zinenko // memory space and does not support `alloca`s with addrspace(5). 2930e5aeae6SMarkus Böck Type elementType = typeConverter->convertType(type.getElementType()); 2940e5aeae6SMarkus Böck auto ptrType = 29597a238e8SChristian Ulmann LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace); 2964c4876c3SAlex Zinenko Value numElements = rewriter.create<LLVM::ConstantOp>( 2970af643f3SJeff Niu gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); 29894058c41SKrzysztof Drewniak uint64_t alignment = 0; 29994058c41SKrzysztof Drewniak if (auto alignAttr = 3005550c821STres Popp dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr( 301ddd6acd7SKrzysztof Drewniak idx, LLVM::LLVMDialect::getAlignAttrName()))) 30294058c41SKrzysztof Drewniak alignment = alignAttr.getInt(); 3034c4876c3SAlex Zinenko Value allocated = rewriter.create<LLVM::AllocaOp>( 30494058c41SKrzysztof Drewniak gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment); 3054c4876c3SAlex Zinenko auto descr = MemRefDescriptor::fromStaticShape( 3064c4876c3SAlex Zinenko rewriter, loc, *getTypeConverter(), type, allocated); 3074c4876c3SAlex Zinenko signatureConversion.remapInput( 308ddd6acd7SKrzysztof Drewniak numProperArguments + numWorkgroupAttributions + idx, descr); 3094c4876c3SAlex Zinenko } 3104c4876c3SAlex Zinenko } 3114c4876c3SAlex Zinenko 3124c4876c3SAlex Zinenko // Move the region to the new function, update the entry block signature. 3134c4876c3SAlex Zinenko rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(), 3144c4876c3SAlex Zinenko llvmFuncOp.end()); 3154c4876c3SAlex Zinenko if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter, 3164c4876c3SAlex Zinenko &signatureConversion))) 3174c4876c3SAlex Zinenko return failure(); 3184c4876c3SAlex Zinenko 319fbf67bfaSstefankoncarevic // Get memref type from function arguments and set the noalias to 320fbf67bfaSstefankoncarevic // pointer arguments. 321ddd6acd7SKrzysztof Drewniak for (const auto [idx, argTy] : 322ddd6acd7SKrzysztof Drewniak llvm::enumerate(gpuFuncOp.getArgumentTypes())) { 323ddd6acd7SKrzysztof Drewniak auto remapping = signatureConversion.getInputMapping(idx); 324ddd6acd7SKrzysztof Drewniak NamedAttrList argAttr = 325a5757c5bSChristian Sigg argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList(); 326ddd6acd7SKrzysztof Drewniak auto copyAttribute = [&](StringRef attrName) { 327ddd6acd7SKrzysztof Drewniak Attribute attr = argAttr.erase(attrName); 328ddd6acd7SKrzysztof Drewniak if (!attr) 329ddd6acd7SKrzysztof Drewniak return; 330ddd6acd7SKrzysztof Drewniak for (size_t i = 0, e = remapping->size; i < e; ++i) 331ddd6acd7SKrzysztof Drewniak llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr); 332ddd6acd7SKrzysztof Drewniak }; 333fbf67bfaSstefankoncarevic auto copyPointerAttribute = [&](StringRef attrName) { 334fbf67bfaSstefankoncarevic Attribute attr = argAttr.erase(attrName); 335fbf67bfaSstefankoncarevic 336fbf67bfaSstefankoncarevic if (!attr) 337fbf67bfaSstefankoncarevic return; 338fbf67bfaSstefankoncarevic if (remapping->size > 1 && 339fbf67bfaSstefankoncarevic attrName == LLVM::LLVMDialect::getNoAliasAttrName()) { 340fbf67bfaSstefankoncarevic emitWarning(llvmFuncOp.getLoc(), 341fbf67bfaSstefankoncarevic "Cannot copy noalias with non-bare pointers.\n"); 342fbf67bfaSstefankoncarevic return; 343fbf67bfaSstefankoncarevic } 344fbf67bfaSstefankoncarevic for (size_t i = 0, e = remapping->size; i < e; ++i) { 345a5757c5bSChristian Sigg if (isa<LLVM::LLVMPointerType>( 346a5757c5bSChristian Sigg llvmFuncOp.getArgument(remapping->inputNo + i).getType())) { 347fbf67bfaSstefankoncarevic llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr); 348fbf67bfaSstefankoncarevic } 349fbf67bfaSstefankoncarevic } 350fbf67bfaSstefankoncarevic }; 351fbf67bfaSstefankoncarevic 352fbf67bfaSstefankoncarevic if (argAttr.empty()) 353fbf67bfaSstefankoncarevic continue; 354fbf67bfaSstefankoncarevic 355ddd6acd7SKrzysztof Drewniak copyAttribute(LLVM::LLVMDialect::getReturnedAttrName()); 356ddd6acd7SKrzysztof Drewniak copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName()); 357ddd6acd7SKrzysztof Drewniak copyAttribute(LLVM::LLVMDialect::getInRegAttrName()); 358ddd6acd7SKrzysztof Drewniak bool lowersToPointer = false; 359ddd6acd7SKrzysztof Drewniak for (size_t i = 0, e = remapping->size; i < e; ++i) { 360ddd6acd7SKrzysztof Drewniak lowersToPointer |= isa<LLVM::LLVMPointerType>( 361ddd6acd7SKrzysztof Drewniak llvmFuncOp.getArgument(remapping->inputNo + i).getType()); 362ddd6acd7SKrzysztof Drewniak } 363ddd6acd7SKrzysztof Drewniak 364ddd6acd7SKrzysztof Drewniak if (lowersToPointer) { 365fbf67bfaSstefankoncarevic copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName()); 366ddd6acd7SKrzysztof Drewniak copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName()); 367ddd6acd7SKrzysztof Drewniak copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName()); 368ddd6acd7SKrzysztof Drewniak copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName()); 369fbf67bfaSstefankoncarevic copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName()); 370fbf67bfaSstefankoncarevic copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName()); 371ddd6acd7SKrzysztof Drewniak copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName()); 372fbf67bfaSstefankoncarevic copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName()); 373fbf67bfaSstefankoncarevic copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName()); 374fbf67bfaSstefankoncarevic copyPointerAttribute( 375fbf67bfaSstefankoncarevic LLVM::LLVMDialect::getDereferenceableOrNullAttrName()); 376d45de800SVictor Perez copyPointerAttribute( 377d45de800SVictor Perez LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr()); 378fbf67bfaSstefankoncarevic } 379fbf67bfaSstefankoncarevic } 3804c4876c3SAlex Zinenko rewriter.eraseOp(gpuFuncOp); 3814c4876c3SAlex Zinenko return success(); 3824c4876c3SAlex Zinenko } 383e1da6291SKrzysztof Drewniak 384e1da6291SKrzysztof Drewniak LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( 385e1da6291SKrzysztof Drewniak gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, 386e1da6291SKrzysztof Drewniak ConversionPatternRewriter &rewriter) const { 387e1da6291SKrzysztof Drewniak Location loc = gpuPrintfOp->getLoc(); 388e1da6291SKrzysztof Drewniak 389e1da6291SKrzysztof Drewniak mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type()); 39097a238e8SChristian Ulmann auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); 391e1da6291SKrzysztof Drewniak mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type()); 392e1da6291SKrzysztof Drewniak mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type()); 393e1da6291SKrzysztof Drewniak // Note: this is the GPUModule op, not the ModuleOp that surrounds it 394e1da6291SKrzysztof Drewniak // This ensures that global constants and declarations are placed within 395e1da6291SKrzysztof Drewniak // the device code, not the host code 396e1da6291SKrzysztof Drewniak auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); 397e1da6291SKrzysztof Drewniak 398e1da6291SKrzysztof Drewniak auto ocklBegin = 399e1da6291SKrzysztof Drewniak getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin", 400e1da6291SKrzysztof Drewniak LLVM::LLVMFunctionType::get(llvmI64, {llvmI64})); 401e1da6291SKrzysztof Drewniak LLVM::LLVMFuncOp ocklAppendArgs; 40210c04f46SRiver Riddle if (!adaptor.getArgs().empty()) { 403e1da6291SKrzysztof Drewniak ocklAppendArgs = getOrDefineFunction( 404e1da6291SKrzysztof Drewniak moduleOp, loc, rewriter, "__ockl_printf_append_args", 405e1da6291SKrzysztof Drewniak LLVM::LLVMFunctionType::get( 406e1da6291SKrzysztof Drewniak llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64, 407e1da6291SKrzysztof Drewniak llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32})); 408e1da6291SKrzysztof Drewniak } 409e1da6291SKrzysztof Drewniak auto ocklAppendStringN = getOrDefineFunction( 410e1da6291SKrzysztof Drewniak moduleOp, loc, rewriter, "__ockl_printf_append_string_n", 411e1da6291SKrzysztof Drewniak LLVM::LLVMFunctionType::get( 412e1da6291SKrzysztof Drewniak llvmI64, 41397a238e8SChristian Ulmann {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32})); 414e1da6291SKrzysztof Drewniak 415e1da6291SKrzysztof Drewniak /// Start the printf hostcall 4160af643f3SJeff Niu Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0); 417e1da6291SKrzysztof Drewniak auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64); 4185e0c3b43SJeff Niu Value printfDesc = printfBeginCall.getResult(); 419e1da6291SKrzysztof Drewniak 4202da417e7SMatthias Springer // Create the global op or find an existing one. 421*599c7399SMatthias Springer LLVM::GlobalOp global = getOrCreateStringConstant( 422*599c7399SMatthias Springer rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); 423e1da6291SKrzysztof Drewniak 424e1da6291SKrzysztof Drewniak // Get a pointer to the format string's first element and pass it to printf() 4250e5aeae6SMarkus Böck Value globalPtr = rewriter.create<LLVM::AddressOfOp>( 4260e5aeae6SMarkus Böck loc, 42797a238e8SChristian Ulmann LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), 4280e5aeae6SMarkus Böck global.getSymNameAttr()); 4292da417e7SMatthias Springer Value stringStart = 4302da417e7SMatthias Springer rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), 4312da417e7SMatthias Springer globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); 4322da417e7SMatthias Springer Value stringLen = rewriter.create<LLVM::ConstantOp>( 4332da417e7SMatthias Springer loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size()); 434e1da6291SKrzysztof Drewniak 4350af643f3SJeff Niu Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1); 4360af643f3SJeff Niu Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0); 437e1da6291SKrzysztof Drewniak 43879a0330aSMehdi Amini auto appendFormatCall = rewriter.create<LLVM::CallOp>( 43979a0330aSMehdi Amini loc, ocklAppendStringN, 44079a0330aSMehdi Amini ValueRange{printfDesc, stringStart, stringLen, 44110c04f46SRiver Riddle adaptor.getArgs().empty() ? oneI32 : zeroI32}); 4425e0c3b43SJeff Niu printfDesc = appendFormatCall.getResult(); 443e1da6291SKrzysztof Drewniak 444e1da6291SKrzysztof Drewniak // __ockl_printf_append_args takes 7 values per append call 445e1da6291SKrzysztof Drewniak constexpr size_t argsPerAppend = 7; 44610c04f46SRiver Riddle size_t nArgs = adaptor.getArgs().size(); 447e1da6291SKrzysztof Drewniak for (size_t group = 0; group < nArgs; group += argsPerAppend) { 448e1da6291SKrzysztof Drewniak size_t bound = std::min(group + argsPerAppend, nArgs); 449e1da6291SKrzysztof Drewniak size_t numArgsThisCall = bound - group; 450e1da6291SKrzysztof Drewniak 451e1da6291SKrzysztof Drewniak SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments; 452e1da6291SKrzysztof Drewniak arguments.push_back(printfDesc); 4530af643f3SJeff Niu arguments.push_back( 4540af643f3SJeff Niu rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall)); 455e1da6291SKrzysztof Drewniak for (size_t i = group; i < bound; ++i) { 45610c04f46SRiver Riddle Value arg = adaptor.getArgs()[i]; 4575550c821STres Popp if (auto floatType = dyn_cast<FloatType>(arg.getType())) { 458e1da6291SKrzysztof Drewniak if (!floatType.isF64()) 459e1da6291SKrzysztof Drewniak arg = rewriter.create<LLVM::FPExtOp>( 460e1da6291SKrzysztof Drewniak loc, typeConverter->convertType(rewriter.getF64Type()), arg); 461e1da6291SKrzysztof Drewniak arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg); 462e1da6291SKrzysztof Drewniak } 463e1da6291SKrzysztof Drewniak if (arg.getType().getIntOrFloatBitWidth() != 64) 464e1da6291SKrzysztof Drewniak arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg); 465e1da6291SKrzysztof Drewniak 466e1da6291SKrzysztof Drewniak arguments.push_back(arg); 467e1da6291SKrzysztof Drewniak } 468e1da6291SKrzysztof Drewniak // Pad out to 7 arguments since the hostcall always needs 7 469e1da6291SKrzysztof Drewniak for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) { 470e1da6291SKrzysztof Drewniak arguments.push_back(zeroI64); 471e1da6291SKrzysztof Drewniak } 472e1da6291SKrzysztof Drewniak 473e1da6291SKrzysztof Drewniak auto isLast = (bound == nArgs) ? oneI32 : zeroI32; 474e1da6291SKrzysztof Drewniak arguments.push_back(isLast); 475e1da6291SKrzysztof Drewniak auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments); 4765e0c3b43SJeff Niu printfDesc = call.getResult(); 477e1da6291SKrzysztof Drewniak } 478e1da6291SKrzysztof Drewniak rewriter.eraseOp(gpuPrintfOp); 479e1da6291SKrzysztof Drewniak return success(); 480e1da6291SKrzysztof Drewniak } 481e1da6291SKrzysztof Drewniak 482e1da6291SKrzysztof Drewniak LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( 483e1da6291SKrzysztof Drewniak gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, 484e1da6291SKrzysztof Drewniak ConversionPatternRewriter &rewriter) const { 485e1da6291SKrzysztof Drewniak Location loc = gpuPrintfOp->getLoc(); 486e1da6291SKrzysztof Drewniak 487e1da6291SKrzysztof Drewniak mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); 48897a238e8SChristian Ulmann mlir::Type ptrType = 48997a238e8SChristian Ulmann LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); 490e1da6291SKrzysztof Drewniak 491e1da6291SKrzysztof Drewniak // Note: this is the GPUModule op, not the ModuleOp that surrounds it 492e1da6291SKrzysztof Drewniak // This ensures that global constants and declarations are placed within 493e1da6291SKrzysztof Drewniak // the device code, not the host code 494e1da6291SKrzysztof Drewniak auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); 495e1da6291SKrzysztof Drewniak 49697a238e8SChristian Ulmann auto printfType = 49797a238e8SChristian Ulmann LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType}, 498e1da6291SKrzysztof Drewniak /*isVarArg=*/true); 499e1da6291SKrzysztof Drewniak LLVM::LLVMFuncOp printfDecl = 500e1da6291SKrzysztof Drewniak getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType); 501e1da6291SKrzysztof Drewniak 5022da417e7SMatthias Springer // Create the global op or find an existing one. 503*599c7399SMatthias Springer LLVM::GlobalOp global = getOrCreateStringConstant( 504*599c7399SMatthias Springer rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(), 505*599c7399SMatthias Springer /*alignment=*/0, addressSpace); 506e1da6291SKrzysztof Drewniak 507e1da6291SKrzysztof Drewniak // Get a pointer to the format string's first element 5080e5aeae6SMarkus Böck Value globalPtr = rewriter.create<LLVM::AddressOfOp>( 5090e5aeae6SMarkus Böck loc, 51097a238e8SChristian Ulmann LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), 5110e5aeae6SMarkus Böck global.getSymNameAttr()); 5122da417e7SMatthias Springer Value stringStart = 5132da417e7SMatthias Springer rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), 5142da417e7SMatthias Springer globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); 515e1da6291SKrzysztof Drewniak 516e1da6291SKrzysztof Drewniak // Construct arguments and function call 51710c04f46SRiver Riddle auto argsRange = adaptor.getArgs(); 518e1da6291SKrzysztof Drewniak SmallVector<Value, 4> printfArgs; 519e1da6291SKrzysztof Drewniak printfArgs.reserve(argsRange.size() + 1); 520e1da6291SKrzysztof Drewniak printfArgs.push_back(stringStart); 521e1da6291SKrzysztof Drewniak printfArgs.append(argsRange.begin(), argsRange.end()); 522e1da6291SKrzysztof Drewniak 523e1da6291SKrzysztof Drewniak rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs); 524e1da6291SKrzysztof Drewniak rewriter.eraseOp(gpuPrintfOp); 525e1da6291SKrzysztof Drewniak return success(); 526e1da6291SKrzysztof Drewniak } 527b251b608SChristian Sigg 5287efdc117SThomas Raoux LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( 5297efdc117SThomas Raoux gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, 5307efdc117SThomas Raoux ConversionPatternRewriter &rewriter) const { 5317efdc117SThomas Raoux Location loc = gpuPrintfOp->getLoc(); 5327efdc117SThomas Raoux 5337efdc117SThomas Raoux mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); 534484668c7SChristian Ulmann mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); 5357efdc117SThomas Raoux 5367efdc117SThomas Raoux // Note: this is the GPUModule op, not the ModuleOp that surrounds it 5377efdc117SThomas Raoux // This ensures that global constants and declarations are placed within 5387efdc117SThomas Raoux // the device code, not the host code 5397efdc117SThomas Raoux auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); 5407efdc117SThomas Raoux 5417efdc117SThomas Raoux auto vprintfType = 542484668c7SChristian Ulmann LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType}); 5437efdc117SThomas Raoux LLVM::LLVMFuncOp vprintfDecl = 5447efdc117SThomas Raoux getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType); 5457efdc117SThomas Raoux 5462da417e7SMatthias Springer // Create the global op or find an existing one. 547*599c7399SMatthias Springer LLVM::GlobalOp global = getOrCreateStringConstant( 548*599c7399SMatthias Springer rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); 5497efdc117SThomas Raoux 5507efdc117SThomas Raoux // Get a pointer to the format string's first element 5517efdc117SThomas Raoux Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global); 5522da417e7SMatthias Springer Value stringStart = 5532da417e7SMatthias Springer rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), 5542da417e7SMatthias Springer globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); 5557efdc117SThomas Raoux SmallVector<Type> types; 5567efdc117SThomas Raoux SmallVector<Value> args; 5577efdc117SThomas Raoux // Promote and pack the arguments into a stack allocation. 5587efdc117SThomas Raoux for (Value arg : adaptor.getArgs()) { 5597efdc117SThomas Raoux Type type = arg.getType(); 5607efdc117SThomas Raoux Value promotedArg = arg; 5617efdc117SThomas Raoux assert(type.isIntOrFloat()); 5625550c821STres Popp if (isa<FloatType>(type)) { 5637efdc117SThomas Raoux type = rewriter.getF64Type(); 5647efdc117SThomas Raoux promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg); 5657efdc117SThomas Raoux } 5667efdc117SThomas Raoux types.push_back(type); 5677efdc117SThomas Raoux args.push_back(promotedArg); 5687efdc117SThomas Raoux } 5697efdc117SThomas Raoux Type structType = 5707efdc117SThomas Raoux LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types); 5717efdc117SThomas Raoux Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 5727efdc117SThomas Raoux rewriter.getIndexAttr(1)); 573484668c7SChristian Ulmann Value tempAlloc = 574484668c7SChristian Ulmann rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one, 5757efdc117SThomas Raoux /*alignment=*/0); 5767efdc117SThomas Raoux for (auto [index, arg] : llvm::enumerate(args)) { 5779397e5f5SChristian Ulmann Value ptr = rewriter.create<LLVM::GEPOp>( 57889cd3456SAndrei Golubev loc, ptrType, structType, tempAlloc, 57989cd3456SAndrei Golubev ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)}); 5807efdc117SThomas Raoux rewriter.create<LLVM::StoreOp>(loc, arg, ptr); 5817efdc117SThomas Raoux } 5827efdc117SThomas Raoux std::array<Value, 2> printfArgs = {stringStart, tempAlloc}; 5837efdc117SThomas Raoux 5847efdc117SThomas Raoux rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs); 5857efdc117SThomas Raoux rewriter.eraseOp(gpuPrintfOp); 5867efdc117SThomas Raoux return success(); 5877efdc117SThomas Raoux } 5887efdc117SThomas Raoux 589b251b608SChristian Sigg /// Unrolls op if it's operating on vectors. 590b251b608SChristian Sigg LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands, 591b251b608SChristian Sigg ConversionPatternRewriter &rewriter, 592ce254598SMatthias Springer const LLVMTypeConverter &converter) { 593b251b608SChristian Sigg TypeRange operandTypes(operands); 594971b8525SJakub Kuderski if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) { 595b251b608SChristian Sigg return rewriter.notifyMatchFailure(op, "expected vector operand"); 596b251b608SChristian Sigg } 597b251b608SChristian Sigg if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0) 598b251b608SChristian Sigg return rewriter.notifyMatchFailure(op, "expected no region/successor"); 599b251b608SChristian Sigg if (op->getNumResults() != 1) 600b251b608SChristian Sigg return rewriter.notifyMatchFailure(op, "expected single result"); 6015550c821STres Popp VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType()); 602b251b608SChristian Sigg if (!vectorType) 603b251b608SChristian Sigg return rewriter.notifyMatchFailure(op, "expected vector result"); 604b251b608SChristian Sigg 605b251b608SChristian Sigg Location loc = op->getLoc(); 606b251b608SChristian Sigg Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType); 607b251b608SChristian Sigg Type indexType = converter.convertType(rewriter.getIndexType()); 608b251b608SChristian Sigg StringAttr name = op->getName().getIdentifier(); 609b251b608SChristian Sigg Type elementType = vectorType.getElementType(); 610b251b608SChristian Sigg 611b251b608SChristian Sigg for (int64_t i = 0; i < vectorType.getNumElements(); ++i) { 612b251b608SChristian Sigg Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i); 613b251b608SChristian Sigg auto extractElement = [&](Value operand) -> Value { 6145550c821STres Popp if (!isa<VectorType>(operand.getType())) 615b251b608SChristian Sigg return operand; 616b251b608SChristian Sigg return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index); 617b251b608SChristian Sigg }; 61817faae95SLaszlo Kindrat auto scalarOperands = llvm::map_to_vector(operands, extractElement); 619b251b608SChristian Sigg Operation *scalarOp = 620b251b608SChristian Sigg rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs()); 62114858cf0SChristopher Bate result = rewriter.create<LLVM::InsertElementOp>( 62214858cf0SChristopher Bate loc, result, scalarOp->getResult(0), index); 623b251b608SChristian Sigg } 624b251b608SChristian Sigg 625b251b608SChristian Sigg rewriter.replaceOp(op, result); 626b251b608SChristian Sigg return success(); 627b251b608SChristian Sigg } 628499abb24SKrzysztof Drewniak 629499abb24SKrzysztof Drewniak static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) { 630499abb24SKrzysztof Drewniak return IntegerAttr::get(IntegerType::get(ctx, 64), space); 631499abb24SKrzysztof Drewniak } 632499abb24SKrzysztof Drewniak 633ea84897bSGuray Ozen /// Generates a symbol with 0-sized array type for dynamic shared memory usage, 634ea84897bSGuray Ozen /// or uses existing symbol. 63549df12c0SMatthias Springer LLVM::GlobalOp getDynamicSharedMemorySymbol( 63649df12c0SMatthias Springer ConversionPatternRewriter &rewriter, gpu::GPUModuleOp moduleOp, 63749df12c0SMatthias Springer gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter, 638ea84897bSGuray Ozen MemRefType memrefType, unsigned alignmentBit) { 639ea84897bSGuray Ozen uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth(); 640ea84897bSGuray Ozen 641ea84897bSGuray Ozen FailureOr<unsigned> addressSpace = 642ea84897bSGuray Ozen typeConverter->getMemRefAddressSpace(memrefType); 643ea84897bSGuray Ozen if (failed(addressSpace)) { 644ea84897bSGuray Ozen op->emitError() << "conversion of memref memory space " 645ea84897bSGuray Ozen << memrefType.getMemorySpace() 646ea84897bSGuray Ozen << " to integer address space " 647ea84897bSGuray Ozen "failed. Consider adding memory space conversions."; 648ea84897bSGuray Ozen } 649ea84897bSGuray Ozen 650ea84897bSGuray Ozen // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of 651ea84897bSGuray Ozen // LLVM::GlobalOp is suitable for shared memory, return it. 652ea84897bSGuray Ozen llvm::StringSet<> existingGlobalNames; 65349df12c0SMatthias Springer for (auto globalOp : moduleOp.getBody()->getOps<LLVM::GlobalOp>()) { 654ea84897bSGuray Ozen existingGlobalNames.insert(globalOp.getSymName()); 655ea84897bSGuray Ozen if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) { 656ea84897bSGuray Ozen if (globalOp.getAddrSpace() == addressSpace.value() && 657ea84897bSGuray Ozen arrayType.getNumElements() == 0 && 658ea84897bSGuray Ozen globalOp.getAlignment().value_or(0) == alignmentByte) { 659ea84897bSGuray Ozen return globalOp; 660ea84897bSGuray Ozen } 661ea84897bSGuray Ozen } 662ea84897bSGuray Ozen } 663ea84897bSGuray Ozen 664ea84897bSGuray Ozen // Step 2. Find a unique symbol name 665ea84897bSGuray Ozen unsigned uniquingCounter = 0; 666ea84897bSGuray Ozen SmallString<128> symName = SymbolTable::generateSymbolName<128>( 667ea84897bSGuray Ozen "__dynamic_shmem_", 668ea84897bSGuray Ozen [&](StringRef candidate) { 669ea84897bSGuray Ozen return existingGlobalNames.contains(candidate); 670ea84897bSGuray Ozen }, 671ea84897bSGuray Ozen uniquingCounter); 672ea84897bSGuray Ozen 673ea84897bSGuray Ozen // Step 3. Generate a global op 674ea84897bSGuray Ozen OpBuilder::InsertionGuard guard(rewriter); 67549df12c0SMatthias Springer rewriter.setInsertionPointToStart(moduleOp.getBody()); 676ea84897bSGuray Ozen 677ea84897bSGuray Ozen auto zeroSizedArrayType = LLVM::LLVMArrayType::get( 678ea84897bSGuray Ozen typeConverter->convertType(memrefType.getElementType()), 0); 679ea84897bSGuray Ozen 680ea84897bSGuray Ozen return rewriter.create<LLVM::GlobalOp>( 681ea84897bSGuray Ozen op->getLoc(), zeroSizedArrayType, /*isConstant=*/false, 682ea84897bSGuray Ozen LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte, 683ea84897bSGuray Ozen addressSpace.value()); 684ea84897bSGuray Ozen } 685ea84897bSGuray Ozen 686ea84897bSGuray Ozen LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite( 687ea84897bSGuray Ozen gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor, 688ea84897bSGuray Ozen ConversionPatternRewriter &rewriter) const { 689ea84897bSGuray Ozen Location loc = op.getLoc(); 690ea84897bSGuray Ozen MemRefType memrefType = op.getResultMemref().getType(); 691ea84897bSGuray Ozen Type elementType = typeConverter->convertType(memrefType.getElementType()); 692ea84897bSGuray Ozen 693ea84897bSGuray Ozen // Step 1: Generate a memref<0xi8> type 694ea84897bSGuray Ozen MemRefLayoutAttrInterface layout = {}; 695ea84897bSGuray Ozen auto memrefType0sz = 696ea84897bSGuray Ozen MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace()); 697ea84897bSGuray Ozen 698ea84897bSGuray Ozen // Step 2: Generate a global symbol or existing for the dynamic shared 699ea84897bSGuray Ozen // memory with memref<0xi8> type 70049df12c0SMatthias Springer auto moduleOp = op->getParentOfType<gpu::GPUModuleOp>(); 70149df12c0SMatthias Springer LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol( 702ea84897bSGuray Ozen rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit); 703ea84897bSGuray Ozen 704ea84897bSGuray Ozen // Step 3. Get address of the global symbol 705ea84897bSGuray Ozen OpBuilder::InsertionGuard guard(rewriter); 706ea84897bSGuray Ozen rewriter.setInsertionPoint(op); 707ea84897bSGuray Ozen auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp); 708ea84897bSGuray Ozen Type baseType = basePtr->getResultTypes().front(); 709ea84897bSGuray Ozen 710ea84897bSGuray Ozen // Step 4. Generate GEP using offsets 711ea84897bSGuray Ozen SmallVector<LLVM::GEPArg> gepArgs = {0}; 712ea84897bSGuray Ozen Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType, 713ea84897bSGuray Ozen basePtr, gepArgs); 714ea84897bSGuray Ozen // Step 5. Create a memref descriptor 715ea84897bSGuray Ozen SmallVector<Value> shape, strides; 716ea84897bSGuray Ozen Value sizeBytes; 717ea84897bSGuray Ozen getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides, 718ea84897bSGuray Ozen sizeBytes); 719ea84897bSGuray Ozen auto memRefDescriptor = this->createMemRefDescriptor( 720ea84897bSGuray Ozen loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter); 721ea84897bSGuray Ozen 722ea84897bSGuray Ozen // Step 5. Replace the op with memref descriptor 723ea84897bSGuray Ozen rewriter.replaceOp(op, {memRefDescriptor}); 724ea84897bSGuray Ozen return success(); 725ea84897bSGuray Ozen } 726ea84897bSGuray Ozen 7273f33d2f3SMatthias Springer LogicalResult GPUReturnOpLowering::matchAndRewrite( 7283f33d2f3SMatthias Springer gpu::ReturnOp op, OpAdaptor adaptor, 7293f33d2f3SMatthias Springer ConversionPatternRewriter &rewriter) const { 7303f33d2f3SMatthias Springer Location loc = op.getLoc(); 7313f33d2f3SMatthias Springer unsigned numArguments = op.getNumOperands(); 7323f33d2f3SMatthias Springer SmallVector<Value, 4> updatedOperands; 7333f33d2f3SMatthias Springer 7343f33d2f3SMatthias Springer bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv; 7353f33d2f3SMatthias Springer if (useBarePtrCallConv) { 7363f33d2f3SMatthias Springer // For the bare-ptr calling convention, extract the aligned pointer to 7373f33d2f3SMatthias Springer // be returned from the memref descriptor. 7383f33d2f3SMatthias Springer for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { 7393f33d2f3SMatthias Springer Type oldTy = std::get<0>(it).getType(); 7403f33d2f3SMatthias Springer Value newOperand = std::get<1>(it); 7413f33d2f3SMatthias Springer if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr( 7423f33d2f3SMatthias Springer cast<BaseMemRefType>(oldTy))) { 7433f33d2f3SMatthias Springer MemRefDescriptor memrefDesc(newOperand); 7443f33d2f3SMatthias Springer newOperand = memrefDesc.allocatedPtr(rewriter, loc); 7453f33d2f3SMatthias Springer } else if (isa<UnrankedMemRefType>(oldTy)) { 7463f33d2f3SMatthias Springer // Unranked memref is not supported in the bare pointer calling 7473f33d2f3SMatthias Springer // convention. 7483f33d2f3SMatthias Springer return failure(); 7493f33d2f3SMatthias Springer } 7503f33d2f3SMatthias Springer updatedOperands.push_back(newOperand); 7513f33d2f3SMatthias Springer } 7523f33d2f3SMatthias Springer } else { 7533f33d2f3SMatthias Springer updatedOperands = llvm::to_vector<4>(adaptor.getOperands()); 7543f33d2f3SMatthias Springer (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(), 7553f33d2f3SMatthias Springer updatedOperands, 7563f33d2f3SMatthias Springer /*toDynamic=*/true); 7573f33d2f3SMatthias Springer } 7583f33d2f3SMatthias Springer 7593f33d2f3SMatthias Springer // If ReturnOp has 0 or 1 operand, create it and return immediately. 7603f33d2f3SMatthias Springer if (numArguments <= 1) { 7613f33d2f3SMatthias Springer rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( 7623f33d2f3SMatthias Springer op, TypeRange(), updatedOperands, op->getAttrs()); 7633f33d2f3SMatthias Springer return success(); 7643f33d2f3SMatthias Springer } 7653f33d2f3SMatthias Springer 7663f33d2f3SMatthias Springer // Otherwise, we need to pack the arguments into an LLVM struct type before 7673f33d2f3SMatthias Springer // returning. 7683f33d2f3SMatthias Springer auto packedType = getTypeConverter()->packFunctionResults( 7693f33d2f3SMatthias Springer op.getOperandTypes(), useBarePtrCallConv); 7703f33d2f3SMatthias Springer if (!packedType) { 7713f33d2f3SMatthias Springer return rewriter.notifyMatchFailure(op, "could not convert result types"); 7723f33d2f3SMatthias Springer } 7733f33d2f3SMatthias Springer 7743f33d2f3SMatthias Springer Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType); 7753f33d2f3SMatthias Springer for (auto [idx, operand] : llvm::enumerate(updatedOperands)) { 7763f33d2f3SMatthias Springer packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx); 7773f33d2f3SMatthias Springer } 7783f33d2f3SMatthias Springer rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed, 7793f33d2f3SMatthias Springer op->getAttrs()); 7803f33d2f3SMatthias Springer return success(); 7813f33d2f3SMatthias Springer } 7823f33d2f3SMatthias Springer 783499abb24SKrzysztof Drewniak void mlir::populateGpuMemorySpaceAttributeConversions( 784499abb24SKrzysztof Drewniak TypeConverter &typeConverter, const MemorySpaceMapping &mapping) { 785499abb24SKrzysztof Drewniak typeConverter.addTypeAttributeConversion( 786499abb24SKrzysztof Drewniak [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) { 787499abb24SKrzysztof Drewniak gpu::AddressSpace memorySpace = memorySpaceAttr.getValue(); 788499abb24SKrzysztof Drewniak unsigned addressSpace = mapping(memorySpace); 789499abb24SKrzysztof Drewniak return wrapNumericMemorySpace(memorySpaceAttr.getContext(), 790499abb24SKrzysztof Drewniak addressSpace); 791499abb24SKrzysztof Drewniak }); 792499abb24SKrzysztof Drewniak } 793