xref: /llvm-project/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (revision 599c73990532333e62edf8ba19a5302b543f976f)
14c4876c3SAlex Zinenko //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
24c4876c3SAlex Zinenko //
34c4876c3SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44c4876c3SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
54c4876c3SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64c4876c3SAlex Zinenko //
74c4876c3SAlex Zinenko //===----------------------------------------------------------------------===//
84c4876c3SAlex Zinenko 
94c4876c3SAlex Zinenko #include "GPUOpsLowering.h"
10888717e8SNicolas Vasilache 
11888717e8SNicolas Vasilache #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12e1da6291SKrzysztof Drewniak #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13499abb24SKrzysztof Drewniak #include "mlir/IR/Attributes.h"
144c4876c3SAlex Zinenko #include "mlir/IR/Builders.h"
15b251b608SChristian Sigg #include "mlir/IR/BuiltinTypes.h"
1617faae95SLaszlo Kindrat #include "llvm/ADT/SmallVectorExtras.h"
17ea84897bSGuray Ozen #include "llvm/ADT/StringSet.h"
184c4876c3SAlex Zinenko #include "llvm/Support/FormatVariadic.h"
194c4876c3SAlex Zinenko 
204c4876c3SAlex Zinenko using namespace mlir;
214c4876c3SAlex Zinenko 
22*599c7399SMatthias Springer LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
23*599c7399SMatthias Springer                                            Location loc, OpBuilder &b,
24*599c7399SMatthias Springer                                            StringRef name,
25*599c7399SMatthias Springer                                            LLVM::LLVMFunctionType type) {
26*599c7399SMatthias Springer   LLVM::LLVMFuncOp ret;
27*599c7399SMatthias Springer   if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
28*599c7399SMatthias Springer     OpBuilder::InsertionGuard guard(b);
29*599c7399SMatthias Springer     b.setInsertionPointToStart(moduleOp.getBody());
30*599c7399SMatthias Springer     ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
31*599c7399SMatthias Springer   }
32*599c7399SMatthias Springer   return ret;
33*599c7399SMatthias Springer }
34*599c7399SMatthias Springer 
35*599c7399SMatthias Springer static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
36*599c7399SMatthias Springer                                            StringRef prefix) {
37*599c7399SMatthias Springer   // Get a unique global name.
38*599c7399SMatthias Springer   unsigned stringNumber = 0;
39*599c7399SMatthias Springer   SmallString<16> stringConstName;
40*599c7399SMatthias Springer   do {
41*599c7399SMatthias Springer     stringConstName.clear();
42*599c7399SMatthias Springer     (prefix + Twine(stringNumber++)).toStringRef(stringConstName);
43*599c7399SMatthias Springer   } while (moduleOp.lookupSymbol(stringConstName));
44*599c7399SMatthias Springer   return stringConstName;
45*599c7399SMatthias Springer }
46*599c7399SMatthias Springer 
47*599c7399SMatthias Springer LLVM::GlobalOp
48*599c7399SMatthias Springer mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
49*599c7399SMatthias Springer                                 gpu::GPUModuleOp moduleOp, Type llvmI8,
50*599c7399SMatthias Springer                                 StringRef namePrefix, StringRef str,
51*599c7399SMatthias Springer                                 uint64_t alignment, unsigned addrSpace) {
52*599c7399SMatthias Springer   llvm::SmallString<20> nullTermStr(str);
53*599c7399SMatthias Springer   nullTermStr.push_back('\0'); // Null terminate for C
54*599c7399SMatthias Springer   auto globalType =
55*599c7399SMatthias Springer       LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes());
56*599c7399SMatthias Springer   StringAttr attr = b.getStringAttr(nullTermStr);
57*599c7399SMatthias Springer 
58*599c7399SMatthias Springer   // Try to find existing global.
59*599c7399SMatthias Springer   for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
60*599c7399SMatthias Springer     if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
61*599c7399SMatthias Springer         globalOp.getValueAttr() == attr &&
62*599c7399SMatthias Springer         globalOp.getAlignment().value_or(0) == alignment &&
63*599c7399SMatthias Springer         globalOp.getAddrSpace() == addrSpace)
64*599c7399SMatthias Springer       return globalOp;
65*599c7399SMatthias Springer 
66*599c7399SMatthias Springer   // Not found: create new global.
67*599c7399SMatthias Springer   OpBuilder::InsertionGuard guard(b);
68*599c7399SMatthias Springer   b.setInsertionPointToStart(moduleOp.getBody());
69*599c7399SMatthias Springer   SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
70*599c7399SMatthias Springer   return b.create<LLVM::GlobalOp>(loc, globalType,
71*599c7399SMatthias Springer                                   /*isConstant=*/true, LLVM::Linkage::Internal,
72*599c7399SMatthias Springer                                   name, attr, alignment, addrSpace);
73*599c7399SMatthias Springer }
74*599c7399SMatthias Springer 
754c4876c3SAlex Zinenko LogicalResult
76ef976337SRiver Riddle GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
774c4876c3SAlex Zinenko                                    ConversionPatternRewriter &rewriter) const {
784c4876c3SAlex Zinenko   Location loc = gpuFuncOp.getLoc();
794c4876c3SAlex Zinenko 
804c4876c3SAlex Zinenko   SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
81d45de800SVictor Perez   if (encodeWorkgroupAttributionsAsArguments) {
82d45de800SVictor Perez     // Append an `llvm.ptr` argument to the function signature to encode
83d45de800SVictor Perez     // workgroup attributions.
84d45de800SVictor Perez 
85d45de800SVictor Perez     ArrayRef<BlockArgument> workgroupAttributions =
86d45de800SVictor Perez         gpuFuncOp.getWorkgroupAttributions();
87d45de800SVictor Perez     size_t numAttributions = workgroupAttributions.size();
88d45de800SVictor Perez 
89d45de800SVictor Perez     // Insert all arguments at the end.
90d45de800SVictor Perez     unsigned index = gpuFuncOp.getNumArguments();
91d45de800SVictor Perez     SmallVector<unsigned> argIndices(numAttributions, index);
92d45de800SVictor Perez 
93d45de800SVictor Perez     // New arguments will simply be `llvm.ptr` with the correct address space
94d45de800SVictor Perez     Type workgroupPtrType =
95d45de800SVictor Perez         rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
96d45de800SVictor Perez     SmallVector<Type> argTypes(numAttributions, workgroupPtrType);
97d45de800SVictor Perez 
98d45de800SVictor Perez     // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
99d45de800SVictor Perez     std::array attrs{
100d45de800SVictor Perez         rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(),
101d45de800SVictor Perez                               rewriter.getUnitAttr()),
102d45de800SVictor Perez         rewriter.getNamedAttr(
103d45de800SVictor Perez             getDialect().getWorkgroupAttributionAttrHelper().getName(),
104d45de800SVictor Perez             rewriter.getUnitAttr()),
105d45de800SVictor Perez     };
106d45de800SVictor Perez     SmallVector<DictionaryAttr> argAttrs;
107d45de800SVictor Perez     for (BlockArgument attribution : workgroupAttributions) {
108d45de800SVictor Perez       auto attributionType = cast<MemRefType>(attribution.getType());
109d45de800SVictor Perez       IntegerAttr numElements =
110d45de800SVictor Perez           rewriter.getI64IntegerAttr(attributionType.getNumElements());
111d45de800SVictor Perez       Type llvmElementType =
112d45de800SVictor Perez           getTypeConverter()->convertType(attributionType.getElementType());
113d45de800SVictor Perez       if (!llvmElementType)
114d45de800SVictor Perez         return failure();
115d45de800SVictor Perez       TypeAttr type = TypeAttr::get(llvmElementType);
116d45de800SVictor Perez       attrs.back().setValue(
117d45de800SVictor Perez           rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(numElements, type));
118d45de800SVictor Perez       argAttrs.push_back(rewriter.getDictionaryAttr(attrs));
119d45de800SVictor Perez     }
120d45de800SVictor Perez 
121d45de800SVictor Perez     // Location match function location
122d45de800SVictor Perez     SmallVector<Location> argLocs(numAttributions, gpuFuncOp.getLoc());
123d45de800SVictor Perez 
124d45de800SVictor Perez     // Perform signature modification
125d45de800SVictor Perez     rewriter.modifyOpInPlace(
126d45de800SVictor Perez         gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() {
127d45de800SVictor Perez           static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments(
128d45de800SVictor Perez               argIndices, argTypes, argAttrs, argLocs);
129d45de800SVictor Perez         });
130d45de800SVictor Perez   } else {
1314c4876c3SAlex Zinenko     workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
132d45de800SVictor Perez     for (auto [idx, attribution] :
133ddd6acd7SKrzysztof Drewniak          llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
1345550c821STres Popp       auto type = dyn_cast<MemRefType>(attribution.getType());
1354c4876c3SAlex Zinenko       assert(type && type.hasStaticShape() && "unexpected type in attribution");
1364c4876c3SAlex Zinenko 
1374c4876c3SAlex Zinenko       uint64_t numElements = type.getNumElements();
1384c4876c3SAlex Zinenko 
1394c4876c3SAlex Zinenko       auto elementType =
1405550c821STres Popp           cast<Type>(typeConverter->convertType(type.getElementType()));
1414c4876c3SAlex Zinenko       auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
142ddd6acd7SKrzysztof Drewniak       std::string name =
143ddd6acd7SKrzysztof Drewniak           std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
14494058c41SKrzysztof Drewniak       uint64_t alignment = 0;
145d45de800SVictor Perez       if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(
146d45de800SVictor Perez               gpuFuncOp.getWorkgroupAttributionAttr(
147ddd6acd7SKrzysztof Drewniak                   idx, LLVM::LLVMDialect::getAlignAttrName())))
14894058c41SKrzysztof Drewniak         alignment = alignAttr.getInt();
1494c4876c3SAlex Zinenko       auto globalOp = rewriter.create<LLVM::GlobalOp>(
1504c4876c3SAlex Zinenko           gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
15194058c41SKrzysztof Drewniak           LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
15294058c41SKrzysztof Drewniak           workgroupAddrSpace);
1534c4876c3SAlex Zinenko       workgroupBuffers.push_back(globalOp);
1544c4876c3SAlex Zinenko     }
155d45de800SVictor Perez   }
1564c4876c3SAlex Zinenko 
1574c4876c3SAlex Zinenko   // Remap proper input types.
1584c4876c3SAlex Zinenko   TypeConverter::SignatureConversion signatureConversion(
1594c4876c3SAlex Zinenko       gpuFuncOp.front().getNumArguments());
16067754a9dSNicolas Vasilache 
1610e5aeae6SMarkus Böck   Type funcType = getTypeConverter()->convertFunctionSignature(
162162f7572SMahesh Ravishankar       gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
163162f7572SMahesh Ravishankar       getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
16467754a9dSNicolas Vasilache   if (!funcType) {
16567754a9dSNicolas Vasilache     return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
16667754a9dSNicolas Vasilache       diag << "failed to convert function signature type for: "
16767754a9dSNicolas Vasilache            << gpuFuncOp.getFunctionType();
16867754a9dSNicolas Vasilache     });
16967754a9dSNicolas Vasilache   }
1704c4876c3SAlex Zinenko 
1714c4876c3SAlex Zinenko   // Create the new function operation. Only copy those attributes that are
1724c4876c3SAlex Zinenko   // not specific to function modeling.
1734c4876c3SAlex Zinenko   SmallVector<NamedAttribute, 4> attributes;
174fbf67bfaSstefankoncarevic   ArrayAttr argAttrs;
17556774bddSMarius Brehler   for (const auto &attr : gpuFuncOp->getAttrs()) {
1760c7890c8SRiver Riddle     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
17753406427SJeff Niu         attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
17894058c41SKrzysztof Drewniak         attr.getName() ==
17994058c41SKrzysztof Drewniak             gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() ||
18094058c41SKrzysztof Drewniak         attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() ||
18143fd4c49SKrzysztof Drewniak         attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() ||
18243fd4c49SKrzysztof Drewniak         attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() ||
18343fd4c49SKrzysztof Drewniak         attr.getName() == gpuFuncOp.getKnownGridSizeAttrName())
1844c4876c3SAlex Zinenko       continue;
185fbf67bfaSstefankoncarevic     if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) {
186fbf67bfaSstefankoncarevic       argAttrs = gpuFuncOp.getArgAttrsAttr();
187fbf67bfaSstefankoncarevic       continue;
188fbf67bfaSstefankoncarevic     }
1894c4876c3SAlex Zinenko     attributes.push_back(attr);
1904c4876c3SAlex Zinenko   }
19143fd4c49SKrzysztof Drewniak 
19243fd4c49SKrzysztof Drewniak   DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr();
19343fd4c49SKrzysztof Drewniak   DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
19443fd4c49SKrzysztof Drewniak   // Ensure we don't lose information if the function is lowered before its
19543fd4c49SKrzysztof Drewniak   // surrounding context.
19643fd4c49SKrzysztof Drewniak   auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
19743fd4c49SKrzysztof Drewniak   if (knownBlockSize)
19843fd4c49SKrzysztof Drewniak     attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
19943fd4c49SKrzysztof Drewniak                             knownBlockSize);
20043fd4c49SKrzysztof Drewniak   if (knownGridSize)
20143fd4c49SKrzysztof Drewniak     attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(),
20243fd4c49SKrzysztof Drewniak                             knownGridSize);
20343fd4c49SKrzysztof Drewniak 
2044c4876c3SAlex Zinenko   // Add a dialect specific kernel attribute in addition to GPU kernel
2054c4876c3SAlex Zinenko   // attribute. The former is necessary for further translation while the
2064c4876c3SAlex Zinenko   // latter is expected by gpu.launch_func.
207763109e3SGuray Ozen   if (gpuFuncOp.isKernel()) {
208d45de800SVictor Perez     if (kernelAttributeName)
2094c4876c3SAlex Zinenko       attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
21043fd4c49SKrzysztof Drewniak     // Set the dialect-specific block size attribute if there is one.
211d45de800SVictor Perez     if (kernelBlockSizeAttributeName && knownBlockSize) {
212d45de800SVictor Perez       attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize);
213763109e3SGuray Ozen     }
214763109e3SGuray Ozen   }
215d45de800SVictor Perez   LLVM::CConv callingConvention = gpuFuncOp.isKernel()
216d45de800SVictor Perez                                       ? kernelCallingConvention
217d45de800SVictor Perez                                       : nonKernelCallingConvention;
2184c4876c3SAlex Zinenko   auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
2194c4876c3SAlex Zinenko       gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
220d45de800SVictor Perez       LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
221b126ee65STobias Gysi       /*comdat=*/nullptr, attributes);
2224c4876c3SAlex Zinenko 
2234c4876c3SAlex Zinenko   {
2244c4876c3SAlex Zinenko     // Insert operations that correspond to converted workgroup and private
2254c4876c3SAlex Zinenko     // memory attributions to the body of the function. This must operate on
2264c4876c3SAlex Zinenko     // the original function, before the body region is inlined in the new
2274c4876c3SAlex Zinenko     // function to maintain the relation between block arguments and the
2284c4876c3SAlex Zinenko     // parent operation that assigns their semantics.
2294c4876c3SAlex Zinenko     OpBuilder::InsertionGuard guard(rewriter);
2304c4876c3SAlex Zinenko 
2314c4876c3SAlex Zinenko     // Rewrite workgroup memory attributions to addresses of global buffers.
2324c4876c3SAlex Zinenko     rewriter.setInsertionPointToStart(&gpuFuncOp.front());
2334c4876c3SAlex Zinenko     unsigned numProperArguments = gpuFuncOp.getNumArguments();
2344c4876c3SAlex Zinenko 
235d45de800SVictor Perez     if (encodeWorkgroupAttributionsAsArguments) {
236d45de800SVictor Perez       // Build a MemRefDescriptor with each of the arguments added above.
237d45de800SVictor Perez 
238d45de800SVictor Perez       unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions();
239d45de800SVictor Perez       assert(numProperArguments >= numAttributions &&
240d45de800SVictor Perez              "Expecting attributions to be encoded as arguments already");
241d45de800SVictor Perez 
242d45de800SVictor Perez       // Arguments encoding workgroup attributions will be in positions
243d45de800SVictor Perez       // [numProperArguments, numProperArguments+numAttributions)
244d45de800SVictor Perez       ArrayRef<BlockArgument> attributionArguments =
245d45de800SVictor Perez           gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
246d45de800SVictor Perez                                          numAttributions);
247d45de800SVictor Perez       for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal(
248d45de800SVictor Perez                gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) {
249d45de800SVictor Perez         auto [attribution, arg] = vals;
250d45de800SVictor Perez         auto type = cast<MemRefType>(attribution.getType());
251d45de800SVictor Perez 
252d45de800SVictor Perez         // Arguments are of llvm.ptr type and attributions are of memref type:
253d45de800SVictor Perez         // we need to wrap them in memref descriptors.
254d45de800SVictor Perez         Value descr = MemRefDescriptor::fromStaticShape(
255d45de800SVictor Perez             rewriter, loc, *getTypeConverter(), type, arg);
256d45de800SVictor Perez 
257d45de800SVictor Perez         // And remap the arguments
258d45de800SVictor Perez         signatureConversion.remapInput(numProperArguments + idx, descr);
259d45de800SVictor Perez       }
260d45de800SVictor Perez     } else {
261ddd6acd7SKrzysztof Drewniak       for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
26297a238e8SChristian Ulmann         auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
26397a238e8SChristian Ulmann                                                   global.getAddrSpace());
2640e5aeae6SMarkus Böck         Value address = rewriter.create<LLVM::AddressOfOp>(
26597a238e8SChristian Ulmann             loc, ptrType, global.getSymNameAttr());
26697a238e8SChristian Ulmann         Value memory =
267d45de800SVictor Perez             rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(),
268d45de800SVictor Perez                                          address, ArrayRef<LLVM::GEPArg>{0, 0});
2694c4876c3SAlex Zinenko 
2704c4876c3SAlex Zinenko         // Build a memref descriptor pointing to the buffer to plug with the
2714c4876c3SAlex Zinenko         // existing memref infrastructure. This may use more registers than
2724c4876c3SAlex Zinenko         // otherwise necessary given that memref sizes are fixed, but we can try
2734c4876c3SAlex Zinenko         // and canonicalize that away later.
274ddd6acd7SKrzysztof Drewniak         Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
2755550c821STres Popp         auto type = cast<MemRefType>(attribution.getType());
2764c4876c3SAlex Zinenko         auto descr = MemRefDescriptor::fromStaticShape(
2774c4876c3SAlex Zinenko             rewriter, loc, *getTypeConverter(), type, memory);
278ddd6acd7SKrzysztof Drewniak         signatureConversion.remapInput(numProperArguments + idx, descr);
2794c4876c3SAlex Zinenko       }
280d45de800SVictor Perez     }
2814c4876c3SAlex Zinenko 
2824c4876c3SAlex Zinenko     // Rewrite private memory attributions to alloca'ed buffers.
2834c4876c3SAlex Zinenko     unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
2844c4876c3SAlex Zinenko     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
285ddd6acd7SKrzysztof Drewniak     for (const auto [idx, attribution] :
286ddd6acd7SKrzysztof Drewniak          llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
2875550c821STres Popp       auto type = cast<MemRefType>(attribution.getType());
2884c4876c3SAlex Zinenko       assert(type && type.hasStaticShape() && "unexpected type in attribution");
2894c4876c3SAlex Zinenko 
2904c4876c3SAlex Zinenko       // Explicitly drop memory space when lowering private memory
2914c4876c3SAlex Zinenko       // attributions since NVVM models it as `alloca`s in the default
2924c4876c3SAlex Zinenko       // memory space and does not support `alloca`s with addrspace(5).
2930e5aeae6SMarkus Böck       Type elementType = typeConverter->convertType(type.getElementType());
2940e5aeae6SMarkus Böck       auto ptrType =
29597a238e8SChristian Ulmann           LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
2964c4876c3SAlex Zinenko       Value numElements = rewriter.create<LLVM::ConstantOp>(
2970af643f3SJeff Niu           gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
29894058c41SKrzysztof Drewniak       uint64_t alignment = 0;
29994058c41SKrzysztof Drewniak       if (auto alignAttr =
3005550c821STres Popp               dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
301ddd6acd7SKrzysztof Drewniak                   idx, LLVM::LLVMDialect::getAlignAttrName())))
30294058c41SKrzysztof Drewniak         alignment = alignAttr.getInt();
3034c4876c3SAlex Zinenko       Value allocated = rewriter.create<LLVM::AllocaOp>(
30494058c41SKrzysztof Drewniak           gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
3054c4876c3SAlex Zinenko       auto descr = MemRefDescriptor::fromStaticShape(
3064c4876c3SAlex Zinenko           rewriter, loc, *getTypeConverter(), type, allocated);
3074c4876c3SAlex Zinenko       signatureConversion.remapInput(
308ddd6acd7SKrzysztof Drewniak           numProperArguments + numWorkgroupAttributions + idx, descr);
3094c4876c3SAlex Zinenko     }
3104c4876c3SAlex Zinenko   }
3114c4876c3SAlex Zinenko 
3124c4876c3SAlex Zinenko   // Move the region to the new function, update the entry block signature.
3134c4876c3SAlex Zinenko   rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
3144c4876c3SAlex Zinenko                               llvmFuncOp.end());
3154c4876c3SAlex Zinenko   if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
3164c4876c3SAlex Zinenko                                          &signatureConversion)))
3174c4876c3SAlex Zinenko     return failure();
3184c4876c3SAlex Zinenko 
319fbf67bfaSstefankoncarevic   // Get memref type from function arguments and set the noalias to
320fbf67bfaSstefankoncarevic   // pointer arguments.
321ddd6acd7SKrzysztof Drewniak   for (const auto [idx, argTy] :
322ddd6acd7SKrzysztof Drewniak        llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
323ddd6acd7SKrzysztof Drewniak     auto remapping = signatureConversion.getInputMapping(idx);
324ddd6acd7SKrzysztof Drewniak     NamedAttrList argAttr =
325a5757c5bSChristian Sigg         argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList();
326ddd6acd7SKrzysztof Drewniak     auto copyAttribute = [&](StringRef attrName) {
327ddd6acd7SKrzysztof Drewniak       Attribute attr = argAttr.erase(attrName);
328ddd6acd7SKrzysztof Drewniak       if (!attr)
329ddd6acd7SKrzysztof Drewniak         return;
330ddd6acd7SKrzysztof Drewniak       for (size_t i = 0, e = remapping->size; i < e; ++i)
331ddd6acd7SKrzysztof Drewniak         llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
332ddd6acd7SKrzysztof Drewniak     };
333fbf67bfaSstefankoncarevic     auto copyPointerAttribute = [&](StringRef attrName) {
334fbf67bfaSstefankoncarevic       Attribute attr = argAttr.erase(attrName);
335fbf67bfaSstefankoncarevic 
336fbf67bfaSstefankoncarevic       if (!attr)
337fbf67bfaSstefankoncarevic         return;
338fbf67bfaSstefankoncarevic       if (remapping->size > 1 &&
339fbf67bfaSstefankoncarevic           attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
340fbf67bfaSstefankoncarevic         emitWarning(llvmFuncOp.getLoc(),
341fbf67bfaSstefankoncarevic                     "Cannot copy noalias with non-bare pointers.\n");
342fbf67bfaSstefankoncarevic         return;
343fbf67bfaSstefankoncarevic       }
344fbf67bfaSstefankoncarevic       for (size_t i = 0, e = remapping->size; i < e; ++i) {
345a5757c5bSChristian Sigg         if (isa<LLVM::LLVMPointerType>(
346a5757c5bSChristian Sigg                 llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
347fbf67bfaSstefankoncarevic           llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
348fbf67bfaSstefankoncarevic         }
349fbf67bfaSstefankoncarevic       }
350fbf67bfaSstefankoncarevic     };
351fbf67bfaSstefankoncarevic 
352fbf67bfaSstefankoncarevic     if (argAttr.empty())
353fbf67bfaSstefankoncarevic       continue;
354fbf67bfaSstefankoncarevic 
355ddd6acd7SKrzysztof Drewniak     copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
356ddd6acd7SKrzysztof Drewniak     copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
357ddd6acd7SKrzysztof Drewniak     copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
358ddd6acd7SKrzysztof Drewniak     bool lowersToPointer = false;
359ddd6acd7SKrzysztof Drewniak     for (size_t i = 0, e = remapping->size; i < e; ++i) {
360ddd6acd7SKrzysztof Drewniak       lowersToPointer |= isa<LLVM::LLVMPointerType>(
361ddd6acd7SKrzysztof Drewniak           llvmFuncOp.getArgument(remapping->inputNo + i).getType());
362ddd6acd7SKrzysztof Drewniak     }
363ddd6acd7SKrzysztof Drewniak 
364ddd6acd7SKrzysztof Drewniak     if (lowersToPointer) {
365fbf67bfaSstefankoncarevic       copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
366ddd6acd7SKrzysztof Drewniak       copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
367ddd6acd7SKrzysztof Drewniak       copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
368ddd6acd7SKrzysztof Drewniak       copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
369fbf67bfaSstefankoncarevic       copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
370fbf67bfaSstefankoncarevic       copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
371ddd6acd7SKrzysztof Drewniak       copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
372fbf67bfaSstefankoncarevic       copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
373fbf67bfaSstefankoncarevic       copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
374fbf67bfaSstefankoncarevic       copyPointerAttribute(
375fbf67bfaSstefankoncarevic           LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
376d45de800SVictor Perez       copyPointerAttribute(
377d45de800SVictor Perez           LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr());
378fbf67bfaSstefankoncarevic     }
379fbf67bfaSstefankoncarevic   }
3804c4876c3SAlex Zinenko   rewriter.eraseOp(gpuFuncOp);
3814c4876c3SAlex Zinenko   return success();
3824c4876c3SAlex Zinenko }
383e1da6291SKrzysztof Drewniak 
384e1da6291SKrzysztof Drewniak LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
385e1da6291SKrzysztof Drewniak     gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
386e1da6291SKrzysztof Drewniak     ConversionPatternRewriter &rewriter) const {
387e1da6291SKrzysztof Drewniak   Location loc = gpuPrintfOp->getLoc();
388e1da6291SKrzysztof Drewniak 
389e1da6291SKrzysztof Drewniak   mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
39097a238e8SChristian Ulmann   auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
391e1da6291SKrzysztof Drewniak   mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
392e1da6291SKrzysztof Drewniak   mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
393e1da6291SKrzysztof Drewniak   // Note: this is the GPUModule op, not the ModuleOp that surrounds it
394e1da6291SKrzysztof Drewniak   // This ensures that global constants and declarations are placed within
395e1da6291SKrzysztof Drewniak   // the device code, not the host code
396e1da6291SKrzysztof Drewniak   auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
397e1da6291SKrzysztof Drewniak 
398e1da6291SKrzysztof Drewniak   auto ocklBegin =
399e1da6291SKrzysztof Drewniak       getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
400e1da6291SKrzysztof Drewniak                           LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
401e1da6291SKrzysztof Drewniak   LLVM::LLVMFuncOp ocklAppendArgs;
40210c04f46SRiver Riddle   if (!adaptor.getArgs().empty()) {
403e1da6291SKrzysztof Drewniak     ocklAppendArgs = getOrDefineFunction(
404e1da6291SKrzysztof Drewniak         moduleOp, loc, rewriter, "__ockl_printf_append_args",
405e1da6291SKrzysztof Drewniak         LLVM::LLVMFunctionType::get(
406e1da6291SKrzysztof Drewniak             llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
407e1da6291SKrzysztof Drewniak                       llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
408e1da6291SKrzysztof Drewniak   }
409e1da6291SKrzysztof Drewniak   auto ocklAppendStringN = getOrDefineFunction(
410e1da6291SKrzysztof Drewniak       moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
411e1da6291SKrzysztof Drewniak       LLVM::LLVMFunctionType::get(
412e1da6291SKrzysztof Drewniak           llvmI64,
41397a238e8SChristian Ulmann           {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
414e1da6291SKrzysztof Drewniak 
415e1da6291SKrzysztof Drewniak   /// Start the printf hostcall
4160af643f3SJeff Niu   Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
417e1da6291SKrzysztof Drewniak   auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
4185e0c3b43SJeff Niu   Value printfDesc = printfBeginCall.getResult();
419e1da6291SKrzysztof Drewniak 
4202da417e7SMatthias Springer   // Create the global op or find an existing one.
421*599c7399SMatthias Springer   LLVM::GlobalOp global = getOrCreateStringConstant(
422*599c7399SMatthias Springer       rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
423e1da6291SKrzysztof Drewniak 
424e1da6291SKrzysztof Drewniak   // Get a pointer to the format string's first element and pass it to printf()
4250e5aeae6SMarkus Böck   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
4260e5aeae6SMarkus Böck       loc,
42797a238e8SChristian Ulmann       LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
4280e5aeae6SMarkus Böck       global.getSymNameAttr());
4292da417e7SMatthias Springer   Value stringStart =
4302da417e7SMatthias Springer       rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
4312da417e7SMatthias Springer                                    globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
4322da417e7SMatthias Springer   Value stringLen = rewriter.create<LLVM::ConstantOp>(
4332da417e7SMatthias Springer       loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
434e1da6291SKrzysztof Drewniak 
4350af643f3SJeff Niu   Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
4360af643f3SJeff Niu   Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
437e1da6291SKrzysztof Drewniak 
43879a0330aSMehdi Amini   auto appendFormatCall = rewriter.create<LLVM::CallOp>(
43979a0330aSMehdi Amini       loc, ocklAppendStringN,
44079a0330aSMehdi Amini       ValueRange{printfDesc, stringStart, stringLen,
44110c04f46SRiver Riddle                  adaptor.getArgs().empty() ? oneI32 : zeroI32});
4425e0c3b43SJeff Niu   printfDesc = appendFormatCall.getResult();
443e1da6291SKrzysztof Drewniak 
444e1da6291SKrzysztof Drewniak   // __ockl_printf_append_args takes 7 values per append call
445e1da6291SKrzysztof Drewniak   constexpr size_t argsPerAppend = 7;
44610c04f46SRiver Riddle   size_t nArgs = adaptor.getArgs().size();
447e1da6291SKrzysztof Drewniak   for (size_t group = 0; group < nArgs; group += argsPerAppend) {
448e1da6291SKrzysztof Drewniak     size_t bound = std::min(group + argsPerAppend, nArgs);
449e1da6291SKrzysztof Drewniak     size_t numArgsThisCall = bound - group;
450e1da6291SKrzysztof Drewniak 
451e1da6291SKrzysztof Drewniak     SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
452e1da6291SKrzysztof Drewniak     arguments.push_back(printfDesc);
4530af643f3SJeff Niu     arguments.push_back(
4540af643f3SJeff Niu         rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
455e1da6291SKrzysztof Drewniak     for (size_t i = group; i < bound; ++i) {
45610c04f46SRiver Riddle       Value arg = adaptor.getArgs()[i];
4575550c821STres Popp       if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
458e1da6291SKrzysztof Drewniak         if (!floatType.isF64())
459e1da6291SKrzysztof Drewniak           arg = rewriter.create<LLVM::FPExtOp>(
460e1da6291SKrzysztof Drewniak               loc, typeConverter->convertType(rewriter.getF64Type()), arg);
461e1da6291SKrzysztof Drewniak         arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
462e1da6291SKrzysztof Drewniak       }
463e1da6291SKrzysztof Drewniak       if (arg.getType().getIntOrFloatBitWidth() != 64)
464e1da6291SKrzysztof Drewniak         arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
465e1da6291SKrzysztof Drewniak 
466e1da6291SKrzysztof Drewniak       arguments.push_back(arg);
467e1da6291SKrzysztof Drewniak     }
468e1da6291SKrzysztof Drewniak     // Pad out to 7 arguments since the hostcall always needs 7
469e1da6291SKrzysztof Drewniak     for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
470e1da6291SKrzysztof Drewniak       arguments.push_back(zeroI64);
471e1da6291SKrzysztof Drewniak     }
472e1da6291SKrzysztof Drewniak 
473e1da6291SKrzysztof Drewniak     auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
474e1da6291SKrzysztof Drewniak     arguments.push_back(isLast);
475e1da6291SKrzysztof Drewniak     auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
4765e0c3b43SJeff Niu     printfDesc = call.getResult();
477e1da6291SKrzysztof Drewniak   }
478e1da6291SKrzysztof Drewniak   rewriter.eraseOp(gpuPrintfOp);
479e1da6291SKrzysztof Drewniak   return success();
480e1da6291SKrzysztof Drewniak }
481e1da6291SKrzysztof Drewniak 
482e1da6291SKrzysztof Drewniak LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
483e1da6291SKrzysztof Drewniak     gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
484e1da6291SKrzysztof Drewniak     ConversionPatternRewriter &rewriter) const {
485e1da6291SKrzysztof Drewniak   Location loc = gpuPrintfOp->getLoc();
486e1da6291SKrzysztof Drewniak 
487e1da6291SKrzysztof Drewniak   mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
48897a238e8SChristian Ulmann   mlir::Type ptrType =
48997a238e8SChristian Ulmann       LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
490e1da6291SKrzysztof Drewniak 
491e1da6291SKrzysztof Drewniak   // Note: this is the GPUModule op, not the ModuleOp that surrounds it
492e1da6291SKrzysztof Drewniak   // This ensures that global constants and declarations are placed within
493e1da6291SKrzysztof Drewniak   // the device code, not the host code
494e1da6291SKrzysztof Drewniak   auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
495e1da6291SKrzysztof Drewniak 
49697a238e8SChristian Ulmann   auto printfType =
49797a238e8SChristian Ulmann       LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
498e1da6291SKrzysztof Drewniak                                   /*isVarArg=*/true);
499e1da6291SKrzysztof Drewniak   LLVM::LLVMFuncOp printfDecl =
500e1da6291SKrzysztof Drewniak       getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
501e1da6291SKrzysztof Drewniak 
5022da417e7SMatthias Springer   // Create the global op or find an existing one.
503*599c7399SMatthias Springer   LLVM::GlobalOp global = getOrCreateStringConstant(
504*599c7399SMatthias Springer       rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
505*599c7399SMatthias Springer       /*alignment=*/0, addressSpace);
506e1da6291SKrzysztof Drewniak 
507e1da6291SKrzysztof Drewniak   // Get a pointer to the format string's first element
5080e5aeae6SMarkus Böck   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
5090e5aeae6SMarkus Böck       loc,
51097a238e8SChristian Ulmann       LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
5110e5aeae6SMarkus Böck       global.getSymNameAttr());
5122da417e7SMatthias Springer   Value stringStart =
5132da417e7SMatthias Springer       rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
5142da417e7SMatthias Springer                                    globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
515e1da6291SKrzysztof Drewniak 
516e1da6291SKrzysztof Drewniak   // Construct arguments and function call
51710c04f46SRiver Riddle   auto argsRange = adaptor.getArgs();
518e1da6291SKrzysztof Drewniak   SmallVector<Value, 4> printfArgs;
519e1da6291SKrzysztof Drewniak   printfArgs.reserve(argsRange.size() + 1);
520e1da6291SKrzysztof Drewniak   printfArgs.push_back(stringStart);
521e1da6291SKrzysztof Drewniak   printfArgs.append(argsRange.begin(), argsRange.end());
522e1da6291SKrzysztof Drewniak 
523e1da6291SKrzysztof Drewniak   rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
524e1da6291SKrzysztof Drewniak   rewriter.eraseOp(gpuPrintfOp);
525e1da6291SKrzysztof Drewniak   return success();
526e1da6291SKrzysztof Drewniak }
527b251b608SChristian Sigg 
5287efdc117SThomas Raoux LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
5297efdc117SThomas Raoux     gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
5307efdc117SThomas Raoux     ConversionPatternRewriter &rewriter) const {
5317efdc117SThomas Raoux   Location loc = gpuPrintfOp->getLoc();
5327efdc117SThomas Raoux 
5337efdc117SThomas Raoux   mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
534484668c7SChristian Ulmann   mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
5357efdc117SThomas Raoux 
5367efdc117SThomas Raoux   // Note: this is the GPUModule op, not the ModuleOp that surrounds it
5377efdc117SThomas Raoux   // This ensures that global constants and declarations are placed within
5387efdc117SThomas Raoux   // the device code, not the host code
5397efdc117SThomas Raoux   auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
5407efdc117SThomas Raoux 
5417efdc117SThomas Raoux   auto vprintfType =
542484668c7SChristian Ulmann       LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
5437efdc117SThomas Raoux   LLVM::LLVMFuncOp vprintfDecl =
5447efdc117SThomas Raoux       getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
5457efdc117SThomas Raoux 
5462da417e7SMatthias Springer   // Create the global op or find an existing one.
547*599c7399SMatthias Springer   LLVM::GlobalOp global = getOrCreateStringConstant(
548*599c7399SMatthias Springer       rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
5497efdc117SThomas Raoux 
5507efdc117SThomas Raoux   // Get a pointer to the format string's first element
5517efdc117SThomas Raoux   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
5522da417e7SMatthias Springer   Value stringStart =
5532da417e7SMatthias Springer       rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
5542da417e7SMatthias Springer                                    globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
5557efdc117SThomas Raoux   SmallVector<Type> types;
5567efdc117SThomas Raoux   SmallVector<Value> args;
5577efdc117SThomas Raoux   // Promote and pack the arguments into a stack allocation.
5587efdc117SThomas Raoux   for (Value arg : adaptor.getArgs()) {
5597efdc117SThomas Raoux     Type type = arg.getType();
5607efdc117SThomas Raoux     Value promotedArg = arg;
5617efdc117SThomas Raoux     assert(type.isIntOrFloat());
5625550c821STres Popp     if (isa<FloatType>(type)) {
5637efdc117SThomas Raoux       type = rewriter.getF64Type();
5647efdc117SThomas Raoux       promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
5657efdc117SThomas Raoux     }
5667efdc117SThomas Raoux     types.push_back(type);
5677efdc117SThomas Raoux     args.push_back(promotedArg);
5687efdc117SThomas Raoux   }
5697efdc117SThomas Raoux   Type structType =
5707efdc117SThomas Raoux       LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
5717efdc117SThomas Raoux   Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
5727efdc117SThomas Raoux                                                 rewriter.getIndexAttr(1));
573484668c7SChristian Ulmann   Value tempAlloc =
574484668c7SChristian Ulmann       rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
5757efdc117SThomas Raoux                                       /*alignment=*/0);
5767efdc117SThomas Raoux   for (auto [index, arg] : llvm::enumerate(args)) {
5779397e5f5SChristian Ulmann     Value ptr = rewriter.create<LLVM::GEPOp>(
57889cd3456SAndrei Golubev         loc, ptrType, structType, tempAlloc,
57989cd3456SAndrei Golubev         ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)});
5807efdc117SThomas Raoux     rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
5817efdc117SThomas Raoux   }
5827efdc117SThomas Raoux   std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
5837efdc117SThomas Raoux 
5847efdc117SThomas Raoux   rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
5857efdc117SThomas Raoux   rewriter.eraseOp(gpuPrintfOp);
5867efdc117SThomas Raoux   return success();
5877efdc117SThomas Raoux }
5887efdc117SThomas Raoux 
589b251b608SChristian Sigg /// Unrolls op if it's operating on vectors.
590b251b608SChristian Sigg LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
591b251b608SChristian Sigg                                       ConversionPatternRewriter &rewriter,
592ce254598SMatthias Springer                                       const LLVMTypeConverter &converter) {
593b251b608SChristian Sigg   TypeRange operandTypes(operands);
594971b8525SJakub Kuderski   if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
595b251b608SChristian Sigg     return rewriter.notifyMatchFailure(op, "expected vector operand");
596b251b608SChristian Sigg   }
597b251b608SChristian Sigg   if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
598b251b608SChristian Sigg     return rewriter.notifyMatchFailure(op, "expected no region/successor");
599b251b608SChristian Sigg   if (op->getNumResults() != 1)
600b251b608SChristian Sigg     return rewriter.notifyMatchFailure(op, "expected single result");
6015550c821STres Popp   VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
602b251b608SChristian Sigg   if (!vectorType)
603b251b608SChristian Sigg     return rewriter.notifyMatchFailure(op, "expected vector result");
604b251b608SChristian Sigg 
605b251b608SChristian Sigg   Location loc = op->getLoc();
606b251b608SChristian Sigg   Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType);
607b251b608SChristian Sigg   Type indexType = converter.convertType(rewriter.getIndexType());
608b251b608SChristian Sigg   StringAttr name = op->getName().getIdentifier();
609b251b608SChristian Sigg   Type elementType = vectorType.getElementType();
610b251b608SChristian Sigg 
611b251b608SChristian Sigg   for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
612b251b608SChristian Sigg     Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
613b251b608SChristian Sigg     auto extractElement = [&](Value operand) -> Value {
6145550c821STres Popp       if (!isa<VectorType>(operand.getType()))
615b251b608SChristian Sigg         return operand;
616b251b608SChristian Sigg       return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
617b251b608SChristian Sigg     };
61817faae95SLaszlo Kindrat     auto scalarOperands = llvm::map_to_vector(operands, extractElement);
619b251b608SChristian Sigg     Operation *scalarOp =
620b251b608SChristian Sigg         rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
62114858cf0SChristopher Bate     result = rewriter.create<LLVM::InsertElementOp>(
62214858cf0SChristopher Bate         loc, result, scalarOp->getResult(0), index);
623b251b608SChristian Sigg   }
624b251b608SChristian Sigg 
625b251b608SChristian Sigg   rewriter.replaceOp(op, result);
626b251b608SChristian Sigg   return success();
627b251b608SChristian Sigg }
628499abb24SKrzysztof Drewniak 
629499abb24SKrzysztof Drewniak static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
630499abb24SKrzysztof Drewniak   return IntegerAttr::get(IntegerType::get(ctx, 64), space);
631499abb24SKrzysztof Drewniak }
632499abb24SKrzysztof Drewniak 
633ea84897bSGuray Ozen /// Generates a symbol with 0-sized array type for dynamic shared memory usage,
634ea84897bSGuray Ozen /// or uses existing symbol.
63549df12c0SMatthias Springer LLVM::GlobalOp getDynamicSharedMemorySymbol(
63649df12c0SMatthias Springer     ConversionPatternRewriter &rewriter, gpu::GPUModuleOp moduleOp,
63749df12c0SMatthias Springer     gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter,
638ea84897bSGuray Ozen     MemRefType memrefType, unsigned alignmentBit) {
639ea84897bSGuray Ozen   uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
640ea84897bSGuray Ozen 
641ea84897bSGuray Ozen   FailureOr<unsigned> addressSpace =
642ea84897bSGuray Ozen       typeConverter->getMemRefAddressSpace(memrefType);
643ea84897bSGuray Ozen   if (failed(addressSpace)) {
644ea84897bSGuray Ozen     op->emitError() << "conversion of memref memory space "
645ea84897bSGuray Ozen                     << memrefType.getMemorySpace()
646ea84897bSGuray Ozen                     << " to integer address space "
647ea84897bSGuray Ozen                        "failed. Consider adding memory space conversions.";
648ea84897bSGuray Ozen   }
649ea84897bSGuray Ozen 
650ea84897bSGuray Ozen   // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
651ea84897bSGuray Ozen   // LLVM::GlobalOp is suitable for shared memory, return it.
652ea84897bSGuray Ozen   llvm::StringSet<> existingGlobalNames;
65349df12c0SMatthias Springer   for (auto globalOp : moduleOp.getBody()->getOps<LLVM::GlobalOp>()) {
654ea84897bSGuray Ozen     existingGlobalNames.insert(globalOp.getSymName());
655ea84897bSGuray Ozen     if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
656ea84897bSGuray Ozen       if (globalOp.getAddrSpace() == addressSpace.value() &&
657ea84897bSGuray Ozen           arrayType.getNumElements() == 0 &&
658ea84897bSGuray Ozen           globalOp.getAlignment().value_or(0) == alignmentByte) {
659ea84897bSGuray Ozen         return globalOp;
660ea84897bSGuray Ozen       }
661ea84897bSGuray Ozen     }
662ea84897bSGuray Ozen   }
663ea84897bSGuray Ozen 
664ea84897bSGuray Ozen   // Step 2. Find a unique symbol name
665ea84897bSGuray Ozen   unsigned uniquingCounter = 0;
666ea84897bSGuray Ozen   SmallString<128> symName = SymbolTable::generateSymbolName<128>(
667ea84897bSGuray Ozen       "__dynamic_shmem_",
668ea84897bSGuray Ozen       [&](StringRef candidate) {
669ea84897bSGuray Ozen         return existingGlobalNames.contains(candidate);
670ea84897bSGuray Ozen       },
671ea84897bSGuray Ozen       uniquingCounter);
672ea84897bSGuray Ozen 
673ea84897bSGuray Ozen   // Step 3. Generate a global op
674ea84897bSGuray Ozen   OpBuilder::InsertionGuard guard(rewriter);
67549df12c0SMatthias Springer   rewriter.setInsertionPointToStart(moduleOp.getBody());
676ea84897bSGuray Ozen 
677ea84897bSGuray Ozen   auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
678ea84897bSGuray Ozen       typeConverter->convertType(memrefType.getElementType()), 0);
679ea84897bSGuray Ozen 
680ea84897bSGuray Ozen   return rewriter.create<LLVM::GlobalOp>(
681ea84897bSGuray Ozen       op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
682ea84897bSGuray Ozen       LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
683ea84897bSGuray Ozen       addressSpace.value());
684ea84897bSGuray Ozen }
685ea84897bSGuray Ozen 
686ea84897bSGuray Ozen LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
687ea84897bSGuray Ozen     gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
688ea84897bSGuray Ozen     ConversionPatternRewriter &rewriter) const {
689ea84897bSGuray Ozen   Location loc = op.getLoc();
690ea84897bSGuray Ozen   MemRefType memrefType = op.getResultMemref().getType();
691ea84897bSGuray Ozen   Type elementType = typeConverter->convertType(memrefType.getElementType());
692ea84897bSGuray Ozen 
693ea84897bSGuray Ozen   // Step 1: Generate a memref<0xi8> type
694ea84897bSGuray Ozen   MemRefLayoutAttrInterface layout = {};
695ea84897bSGuray Ozen   auto memrefType0sz =
696ea84897bSGuray Ozen       MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
697ea84897bSGuray Ozen 
698ea84897bSGuray Ozen   // Step 2: Generate a global symbol or existing for the dynamic shared
699ea84897bSGuray Ozen   // memory with memref<0xi8> type
70049df12c0SMatthias Springer   auto moduleOp = op->getParentOfType<gpu::GPUModuleOp>();
70149df12c0SMatthias Springer   LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
702ea84897bSGuray Ozen       rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);
703ea84897bSGuray Ozen 
704ea84897bSGuray Ozen   // Step 3. Get address of the global symbol
705ea84897bSGuray Ozen   OpBuilder::InsertionGuard guard(rewriter);
706ea84897bSGuray Ozen   rewriter.setInsertionPoint(op);
707ea84897bSGuray Ozen   auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
708ea84897bSGuray Ozen   Type baseType = basePtr->getResultTypes().front();
709ea84897bSGuray Ozen 
710ea84897bSGuray Ozen   // Step 4. Generate GEP using offsets
711ea84897bSGuray Ozen   SmallVector<LLVM::GEPArg> gepArgs = {0};
712ea84897bSGuray Ozen   Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
713ea84897bSGuray Ozen                                                 basePtr, gepArgs);
714ea84897bSGuray Ozen   // Step 5. Create a memref descriptor
715ea84897bSGuray Ozen   SmallVector<Value> shape, strides;
716ea84897bSGuray Ozen   Value sizeBytes;
717ea84897bSGuray Ozen   getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
718ea84897bSGuray Ozen                            sizeBytes);
719ea84897bSGuray Ozen   auto memRefDescriptor = this->createMemRefDescriptor(
720ea84897bSGuray Ozen       loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
721ea84897bSGuray Ozen 
722ea84897bSGuray Ozen   // Step 5. Replace the op with memref descriptor
723ea84897bSGuray Ozen   rewriter.replaceOp(op, {memRefDescriptor});
724ea84897bSGuray Ozen   return success();
725ea84897bSGuray Ozen }
726ea84897bSGuray Ozen 
7273f33d2f3SMatthias Springer LogicalResult GPUReturnOpLowering::matchAndRewrite(
7283f33d2f3SMatthias Springer     gpu::ReturnOp op, OpAdaptor adaptor,
7293f33d2f3SMatthias Springer     ConversionPatternRewriter &rewriter) const {
7303f33d2f3SMatthias Springer   Location loc = op.getLoc();
7313f33d2f3SMatthias Springer   unsigned numArguments = op.getNumOperands();
7323f33d2f3SMatthias Springer   SmallVector<Value, 4> updatedOperands;
7333f33d2f3SMatthias Springer 
7343f33d2f3SMatthias Springer   bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv;
7353f33d2f3SMatthias Springer   if (useBarePtrCallConv) {
7363f33d2f3SMatthias Springer     // For the bare-ptr calling convention, extract the aligned pointer to
7373f33d2f3SMatthias Springer     // be returned from the memref descriptor.
7383f33d2f3SMatthias Springer     for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
7393f33d2f3SMatthias Springer       Type oldTy = std::get<0>(it).getType();
7403f33d2f3SMatthias Springer       Value newOperand = std::get<1>(it);
7413f33d2f3SMatthias Springer       if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
7423f33d2f3SMatthias Springer                                         cast<BaseMemRefType>(oldTy))) {
7433f33d2f3SMatthias Springer         MemRefDescriptor memrefDesc(newOperand);
7443f33d2f3SMatthias Springer         newOperand = memrefDesc.allocatedPtr(rewriter, loc);
7453f33d2f3SMatthias Springer       } else if (isa<UnrankedMemRefType>(oldTy)) {
7463f33d2f3SMatthias Springer         // Unranked memref is not supported in the bare pointer calling
7473f33d2f3SMatthias Springer         // convention.
7483f33d2f3SMatthias Springer         return failure();
7493f33d2f3SMatthias Springer       }
7503f33d2f3SMatthias Springer       updatedOperands.push_back(newOperand);
7513f33d2f3SMatthias Springer     }
7523f33d2f3SMatthias Springer   } else {
7533f33d2f3SMatthias Springer     updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
7543f33d2f3SMatthias Springer     (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
7553f33d2f3SMatthias Springer                                   updatedOperands,
7563f33d2f3SMatthias Springer                                   /*toDynamic=*/true);
7573f33d2f3SMatthias Springer   }
7583f33d2f3SMatthias Springer 
7593f33d2f3SMatthias Springer   // If ReturnOp has 0 or 1 operand, create it and return immediately.
7603f33d2f3SMatthias Springer   if (numArguments <= 1) {
7613f33d2f3SMatthias Springer     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
7623f33d2f3SMatthias Springer         op, TypeRange(), updatedOperands, op->getAttrs());
7633f33d2f3SMatthias Springer     return success();
7643f33d2f3SMatthias Springer   }
7653f33d2f3SMatthias Springer 
7663f33d2f3SMatthias Springer   // Otherwise, we need to pack the arguments into an LLVM struct type before
7673f33d2f3SMatthias Springer   // returning.
7683f33d2f3SMatthias Springer   auto packedType = getTypeConverter()->packFunctionResults(
7693f33d2f3SMatthias Springer       op.getOperandTypes(), useBarePtrCallConv);
7703f33d2f3SMatthias Springer   if (!packedType) {
7713f33d2f3SMatthias Springer     return rewriter.notifyMatchFailure(op, "could not convert result types");
7723f33d2f3SMatthias Springer   }
7733f33d2f3SMatthias Springer 
7743f33d2f3SMatthias Springer   Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
7753f33d2f3SMatthias Springer   for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
7763f33d2f3SMatthias Springer     packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
7773f33d2f3SMatthias Springer   }
7783f33d2f3SMatthias Springer   rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
7793f33d2f3SMatthias Springer                                               op->getAttrs());
7803f33d2f3SMatthias Springer   return success();
7813f33d2f3SMatthias Springer }
7823f33d2f3SMatthias Springer 
783499abb24SKrzysztof Drewniak void mlir::populateGpuMemorySpaceAttributeConversions(
784499abb24SKrzysztof Drewniak     TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
785499abb24SKrzysztof Drewniak   typeConverter.addTypeAttributeConversion(
786499abb24SKrzysztof Drewniak       [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) {
787499abb24SKrzysztof Drewniak         gpu::AddressSpace memorySpace = memorySpaceAttr.getValue();
788499abb24SKrzysztof Drewniak         unsigned addressSpace = mapping(memorySpace);
789499abb24SKrzysztof Drewniak         return wrapNumericMemorySpace(memorySpaceAttr.getContext(),
790499abb24SKrzysztof Drewniak                                       addressSpace);
791499abb24SKrzysztof Drewniak       });
792499abb24SKrzysztof Drewniak }
793