xref: /llvm-project/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (revision 599c73990532333e62edf8ba19a5302b543f976f)
1 //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "GPUOpsLowering.h"
10 
11 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "llvm/ADT/SmallVectorExtras.h"
17 #include "llvm/ADT/StringSet.h"
18 #include "llvm/Support/FormatVariadic.h"
19 
20 using namespace mlir;
21 
22 LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
23                                            Location loc, OpBuilder &b,
24                                            StringRef name,
25                                            LLVM::LLVMFunctionType type) {
26   LLVM::LLVMFuncOp ret;
27   if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
28     OpBuilder::InsertionGuard guard(b);
29     b.setInsertionPointToStart(moduleOp.getBody());
30     ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
31   }
32   return ret;
33 }
34 
35 static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
36                                            StringRef prefix) {
37   // Get a unique global name.
38   unsigned stringNumber = 0;
39   SmallString<16> stringConstName;
40   do {
41     stringConstName.clear();
42     (prefix + Twine(stringNumber++)).toStringRef(stringConstName);
43   } while (moduleOp.lookupSymbol(stringConstName));
44   return stringConstName;
45 }
46 
47 LLVM::GlobalOp
48 mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
49                                 gpu::GPUModuleOp moduleOp, Type llvmI8,
50                                 StringRef namePrefix, StringRef str,
51                                 uint64_t alignment, unsigned addrSpace) {
52   llvm::SmallString<20> nullTermStr(str);
53   nullTermStr.push_back('\0'); // Null terminate for C
54   auto globalType =
55       LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes());
56   StringAttr attr = b.getStringAttr(nullTermStr);
57 
58   // Try to find existing global.
59   for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
60     if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
61         globalOp.getValueAttr() == attr &&
62         globalOp.getAlignment().value_or(0) == alignment &&
63         globalOp.getAddrSpace() == addrSpace)
64       return globalOp;
65 
66   // Not found: create new global.
67   OpBuilder::InsertionGuard guard(b);
68   b.setInsertionPointToStart(moduleOp.getBody());
69   SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
70   return b.create<LLVM::GlobalOp>(loc, globalType,
71                                   /*isConstant=*/true, LLVM::Linkage::Internal,
72                                   name, attr, alignment, addrSpace);
73 }
74 
/// Lowers a `gpu.func` to an `llvm.func`: converts workgroup attributions
/// either to extra `llvm.ptr` arguments or to module-level globals, converts
/// private attributions to `alloca`s, remaps the signature through the type
/// converter, and copies argument/function attributes that have LLVM
/// counterparts. The original `gpu.func` is erased on success.
LogicalResult
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  Location loc = gpuFuncOp.getLoc();

  // Only populated in the "globals" mode (the `else` branch below).
  SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
  if (encodeWorkgroupAttributionsAsArguments) {
    // Append an `llvm.ptr` argument to the function signature to encode
    // workgroup attributions.

    ArrayRef<BlockArgument> workgroupAttributions =
        gpuFuncOp.getWorkgroupAttributions();
    size_t numAttributions = workgroupAttributions.size();

    // Insert all arguments at the end.
    unsigned index = gpuFuncOp.getNumArguments();
    SmallVector<unsigned> argIndices(numAttributions, index);

    // New arguments will simply be `llvm.ptr` with the correct address space
    Type workgroupPtrType =
        rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
    SmallVector<Type> argTypes(numAttributions, workgroupPtrType);

    // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
    // The second entry is a placeholder; its value is overwritten per
    // attribution in the loop below.
    std::array attrs{
        rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(),
                              rewriter.getUnitAttr()),
        rewriter.getNamedAttr(
            getDialect().getWorkgroupAttributionAttrHelper().getName(),
            rewriter.getUnitAttr()),
    };
    SmallVector<DictionaryAttr> argAttrs;
    for (BlockArgument attribution : workgroupAttributions) {
      auto attributionType = cast<MemRefType>(attribution.getType());
      IntegerAttr numElements =
          rewriter.getI64IntegerAttr(attributionType.getNumElements());
      Type llvmElementType =
          getTypeConverter()->convertType(attributionType.getElementType());
      if (!llvmElementType)
        return failure();
      TypeAttr type = TypeAttr::get(llvmElementType);
      // Record element count and type so the attribution can be reconstructed.
      attrs.back().setValue(
          rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(numElements, type));
      argAttrs.push_back(rewriter.getDictionaryAttr(attrs));
    }

    // Location match function location
    SmallVector<Location> argLocs(numAttributions, gpuFuncOp.getLoc());

    // Perform signature modification
    rewriter.modifyOpInPlace(
        gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() {
          static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments(
              argIndices, argTypes, argAttrs, argLocs);
        });
  } else {
    // Lower each workgroup attribution to an internal module-level global
    // array in the workgroup address space.
    workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
    for (auto [idx, attribution] :
         llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
      auto type = dyn_cast<MemRefType>(attribution.getType());
      assert(type && type.hasStaticShape() && "unexpected type in attribution");

      uint64_t numElements = type.getNumElements();

      auto elementType =
          cast<Type>(typeConverter->convertType(type.getElementType()));
      auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
      // Global name is derived from the function name and attribution index,
      // e.g. `__wg_<func>_<idx>`.
      std::string name =
          std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
      uint64_t alignment = 0;
      if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(
              gpuFuncOp.getWorkgroupAttributionAttr(
                  idx, LLVM::LLVMDialect::getAlignAttrName())))
        alignment = alignAttr.getInt();
      auto globalOp = rewriter.create<LLVM::GlobalOp>(
          gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
          LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
          workgroupAddrSpace);
      workgroupBuffers.push_back(globalOp);
    }
  }

  // Remap proper input types.
  TypeConverter::SignatureConversion signatureConversion(
      gpuFuncOp.front().getNumArguments());

  Type funcType = getTypeConverter()->convertFunctionSignature(
      gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
      getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
  if (!funcType) {
    return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
      diag << "failed to convert function signature type for: "
           << gpuFuncOp.getFunctionType();
    });
  }

  // Create the new function operation. Only copy those attributes that are
  // not specific to function modeling.
  SmallVector<NamedAttribute, 4> attributes;
  ArrayAttr argAttrs;
  for (const auto &attr : gpuFuncOp->getAttrs()) {
    if (attr.getName() == SymbolTable::getSymbolAttrName() ||
        attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
        attr.getName() ==
            gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() ||
        attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() ||
        attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() ||
        attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() ||
        attr.getName() == gpuFuncOp.getKnownGridSizeAttrName())
      continue;
    // Argument attributes are handled separately after signature conversion.
    if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) {
      argAttrs = gpuFuncOp.getArgAttrsAttr();
      continue;
    }
    attributes.push_back(attr);
  }

  DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr();
  DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
  // Ensure we don't lose information if the function is lowered before its
  // surrounding context.
  auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
  if (knownBlockSize)
    attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
                            knownBlockSize);
  if (knownGridSize)
    attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(),
                            knownGridSize);

  // Add a dialect specific kernel attribute in addition to GPU kernel
  // attribute. The former is necessary for further translation while the
  // latter is expected by gpu.launch_func.
  if (gpuFuncOp.isKernel()) {
    if (kernelAttributeName)
      attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
    // Set the dialect-specific block size attribute if there is one.
    if (kernelBlockSizeAttributeName && knownBlockSize) {
      attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize);
    }
  }
  // Kernels and non-kernels may use different calling conventions (configured
  // on the pattern).
  LLVM::CConv callingConvention = gpuFuncOp.isKernel()
                                      ? kernelCallingConvention
                                      : nonKernelCallingConvention;
  auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
      gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
      LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
      /*comdat=*/nullptr, attributes);

  {
    // Insert operations that correspond to converted workgroup and private
    // memory attributions to the body of the function. This must operate on
    // the original function, before the body region is inlined in the new
    // function to maintain the relation between block arguments and the
    // parent operation that assigns their semantics.
    OpBuilder::InsertionGuard guard(rewriter);

    // Rewrite workgroup memory attributions to addresses of global buffers.
    rewriter.setInsertionPointToStart(&gpuFuncOp.front());
    unsigned numProperArguments = gpuFuncOp.getNumArguments();

    if (encodeWorkgroupAttributionsAsArguments) {
      // Build a MemRefDescriptor with each of the arguments added above.

      unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions();
      assert(numProperArguments >= numAttributions &&
             "Expecting attributions to be encoded as arguments already");

      // Arguments encoding workgroup attributions will be in positions
      // [numProperArguments, numProperArguments+numAttributions)
      ArrayRef<BlockArgument> attributionArguments =
          gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
                                         numAttributions);
      for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal(
               gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) {
        auto [attribution, arg] = vals;
        auto type = cast<MemRefType>(attribution.getType());

        // Arguments are of llvm.ptr type and attributions are of memref type:
        // we need to wrap them in memref descriptors.
        Value descr = MemRefDescriptor::fromStaticShape(
            rewriter, loc, *getTypeConverter(), type, arg);

        // And remap the arguments
        signatureConversion.remapInput(numProperArguments + idx, descr);
      }
    } else {
      for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
        // Take the address of the global created above and GEP to its first
        // element so the descriptor points at the buffer contents.
        auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
                                                  global.getAddrSpace());
        Value address = rewriter.create<LLVM::AddressOfOp>(
            loc, ptrType, global.getSymNameAttr());
        Value memory =
            rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(),
                                         address, ArrayRef<LLVM::GEPArg>{0, 0});

        // Build a memref descriptor pointing to the buffer to plug with the
        // existing memref infrastructure. This may use more registers than
        // otherwise necessary given that memref sizes are fixed, but we can try
        // and canonicalize that away later.
        Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
        auto type = cast<MemRefType>(attribution.getType());
        auto descr = MemRefDescriptor::fromStaticShape(
            rewriter, loc, *getTypeConverter(), type, memory);
        signatureConversion.remapInput(numProperArguments + idx, descr);
      }
    }

    // Rewrite private memory attributions to alloca'ed buffers.
    unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
    for (const auto [idx, attribution] :
         llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
      auto type = cast<MemRefType>(attribution.getType());
      assert(type && type.hasStaticShape() && "unexpected type in attribution");

      // Explicitly drop memory space when lowering private memory
      // attributions since NVVM models it as `alloca`s in the default
      // memory space and does not support `alloca`s with addrspace(5).
      Type elementType = typeConverter->convertType(type.getElementType());
      auto ptrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
      Value numElements = rewriter.create<LLVM::ConstantOp>(
          gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
      uint64_t alignment = 0;
      if (auto alignAttr =
              dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
                  idx, LLVM::LLVMDialect::getAlignAttrName())))
        alignment = alignAttr.getInt();
      Value allocated = rewriter.create<LLVM::AllocaOp>(
          gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
      auto descr = MemRefDescriptor::fromStaticShape(
          rewriter, loc, *getTypeConverter(), type, allocated);
      // Private attributions follow the workgroup attributions in the
      // original block argument list.
      signatureConversion.remapInput(
          numProperArguments + numWorkgroupAttributions + idx, descr);
    }
  }

  // Move the region to the new function, update the entry block signature.
  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
                              llvmFuncOp.end());
  if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
                                         &signatureConversion)))
    return failure();

  // Get memref type from function arguments and set the noalias to
  // pointer arguments.
  for (const auto [idx, argTy] :
       llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
    auto remapping = signatureConversion.getInputMapping(idx);
    NamedAttrList argAttr =
        argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList();
    // Copies `attrName` from the original argument to every converted
    // argument it was expanded into, removing it from the working list.
    auto copyAttribute = [&](StringRef attrName) {
      Attribute attr = argAttr.erase(attrName);
      if (!attr)
        return;
      for (size_t i = 0, e = remapping->size; i < e; ++i)
        llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
    };
    // Same as above, but only targets converted arguments of LLVM pointer
    // type; `noalias` is refused for multi-value (non-bare-pointer)
    // expansions because its semantics would not carry over.
    auto copyPointerAttribute = [&](StringRef attrName) {
      Attribute attr = argAttr.erase(attrName);

      if (!attr)
        return;
      if (remapping->size > 1 &&
          attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
        emitWarning(llvmFuncOp.getLoc(),
                    "Cannot copy noalias with non-bare pointers.\n");
        return;
      }
      for (size_t i = 0, e = remapping->size; i < e; ++i) {
        if (isa<LLVM::LLVMPointerType>(
                llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
          llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
        }
      }
    };

    if (argAttr.empty())
      continue;

    copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
    copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
    copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
    bool lowersToPointer = false;
    for (size_t i = 0, e = remapping->size; i < e; ++i) {
      lowersToPointer |= isa<LLVM::LLVMPointerType>(
          llvmFuncOp.getArgument(remapping->inputNo + i).getType());
    }

    // Pointer-specific attributes are only meaningful if at least one of the
    // converted arguments is an LLVM pointer.
    if (lowersToPointer) {
      copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
      copyPointerAttribute(
          LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
      copyPointerAttribute(
          LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr());
    }
  }
  rewriter.eraseOp(gpuFuncOp);
  return success();
}
383 
/// Lowers `gpu.printf` to the AMD device-library hostcall sequence:
/// `__ockl_printf_begin`, then `__ockl_printf_append_string_n` for the format
/// string, then `__ockl_printf_append_args` calls carrying up to 7 values
/// each. Each call threads a 64-bit descriptor through to the next.
LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = gpuPrintfOp->getLoc();

  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
  auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
  mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
  mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
  // Note: this is the GPUModule op, not the ModuleOp that surrounds it
  // This ensures that global constants and declarations are placed within
  // the device code, not the host code
  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();

  auto ocklBegin =
      getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
                          LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
  // Only declare the argument-append helper when there are args to append.
  LLVM::LLVMFuncOp ocklAppendArgs;
  if (!adaptor.getArgs().empty()) {
    ocklAppendArgs = getOrDefineFunction(
        moduleOp, loc, rewriter, "__ockl_printf_append_args",
        LLVM::LLVMFunctionType::get(
            llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
                      llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
  }
  auto ocklAppendStringN = getOrDefineFunction(
      moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
      LLVM::LLVMFunctionType::get(
          llvmI64,
          {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));

  /// Start the printf hostcall
  Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
  auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
  Value printfDesc = printfBeginCall.getResult();

  // Create the global op or find an existing one.
  LLVM::GlobalOp global = getOrCreateStringConstant(
      rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());

  // Get a pointer to the format string's first element and pass it to printf()
  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
      loc,
      LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
      global.getSymNameAttr());
  Value stringStart =
      rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
                                   globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
  // Length of the stored attribute, which includes the trailing '\0' added by
  // getOrCreateStringConstant.
  Value stringLen = rewriter.create<LLVM::ConstantOp>(
      loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());

  Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
  Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);

  // Append the format string; mark it as the last piece only if there are no
  // further arguments to append.
  auto appendFormatCall = rewriter.create<LLVM::CallOp>(
      loc, ocklAppendStringN,
      ValueRange{printfDesc, stringStart, stringLen,
                 adaptor.getArgs().empty() ? oneI32 : zeroI32});
  printfDesc = appendFormatCall.getResult();

  // __ockl_printf_append_args takes 7 values per append call
  constexpr size_t argsPerAppend = 7;
  size_t nArgs = adaptor.getArgs().size();
  for (size_t group = 0; group < nArgs; group += argsPerAppend) {
    size_t bound = std::min(group + argsPerAppend, nArgs);
    size_t numArgsThisCall = bound - group;

    SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
    arguments.push_back(printfDesc);
    arguments.push_back(
        rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
    for (size_t i = group; i < bound; ++i) {
      Value arg = adaptor.getArgs()[i];
      // Widen every value to 64 bits: sub-f64 floats are extended to f64 and
      // bitcast to i64; sub-i64 integers are zero-extended.
      if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
        if (!floatType.isF64())
          arg = rewriter.create<LLVM::FPExtOp>(
              loc, typeConverter->convertType(rewriter.getF64Type()), arg);
        arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
      }
      if (arg.getType().getIntOrFloatBitWidth() != 64)
        arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);

      arguments.push_back(arg);
    }
    // Pad out to 7 arguments since the hostcall always needs 7
    for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
      arguments.push_back(zeroI64);
    }

    // The final group carries isLast = 1 to terminate the hostcall sequence.
    auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
    arguments.push_back(isLast);
    auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
    printfDesc = call.getResult();
  }
  rewriter.eraseOp(gpuPrintfOp);
  return success();
}
481 
482 LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
483     gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
484     ConversionPatternRewriter &rewriter) const {
485   Location loc = gpuPrintfOp->getLoc();
486 
487   mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
488   mlir::Type ptrType =
489       LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
490 
491   // Note: this is the GPUModule op, not the ModuleOp that surrounds it
492   // This ensures that global constants and declarations are placed within
493   // the device code, not the host code
494   auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
495 
496   auto printfType =
497       LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
498                                   /*isVarArg=*/true);
499   LLVM::LLVMFuncOp printfDecl =
500       getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
501 
502   // Create the global op or find an existing one.
503   LLVM::GlobalOp global = getOrCreateStringConstant(
504       rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
505       /*alignment=*/0, addressSpace);
506 
507   // Get a pointer to the format string's first element
508   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
509       loc,
510       LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
511       global.getSymNameAttr());
512   Value stringStart =
513       rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
514                                    globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
515 
516   // Construct arguments and function call
517   auto argsRange = adaptor.getArgs();
518   SmallVector<Value, 4> printfArgs;
519   printfArgs.reserve(argsRange.size() + 1);
520   printfArgs.push_back(stringStart);
521   printfArgs.append(argsRange.begin(), argsRange.end());
522 
523   rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
524   rewriter.eraseOp(gpuPrintfOp);
525   return success();
526 }
527 
/// Lowers `gpu.printf` to a `vprintf(format, packedArgs)` call: the operands
/// are promoted (floats to f64) and stored into a stack-allocated struct
/// whose address is passed as the second argument.
LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = gpuPrintfOp->getLoc();

  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
  mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());

  // Note: this is the GPUModule op, not the ModuleOp that surrounds it
  // This ensures that global constants and declarations are placed within
  // the device code, not the host code
  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();

  // vprintf takes (format pointer, pointer to packed argument buffer).
  auto vprintfType =
      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
  LLVM::LLVMFuncOp vprintfDecl =
      getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);

  // Create the global op or find an existing one.
  LLVM::GlobalOp global = getOrCreateStringConstant(
      rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());

  // Get a pointer to the format string's first element
  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
  Value stringStart =
      rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
                                   globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
  SmallVector<Type> types;
  SmallVector<Value> args;
  // Promote and pack the arguments into a stack allocation.
  for (Value arg : adaptor.getArgs()) {
    Type type = arg.getType();
    Value promotedArg = arg;
    // Only scalar int/float operands are supported here.
    assert(type.isIntOrFloat());
    // Mirror C varargs promotion: all floats are widened to double.
    if (isa<FloatType>(type)) {
      type = rewriter.getF64Type();
      promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
    }
    types.push_back(type);
    args.push_back(promotedArg);
  }
  // Allocate one instance of a literal struct of the promoted types and store
  // each promoted operand into its field.
  Type structType =
      LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
  Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
                                                rewriter.getIndexAttr(1));
  Value tempAlloc =
      rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
                                      /*alignment=*/0);
  for (auto [index, arg] : llvm::enumerate(args)) {
    Value ptr = rewriter.create<LLVM::GEPOp>(
        loc, ptrType, structType, tempAlloc,
        ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)});
    rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
  }
  std::array<Value, 2> printfArgs = {stringStart, tempAlloc};

  rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
  rewriter.eraseOp(gpuPrintfOp);
  return success();
}
588 
589 /// Unrolls op if it's operating on vectors.
590 LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
591                                       ConversionPatternRewriter &rewriter,
592                                       const LLVMTypeConverter &converter) {
593   TypeRange operandTypes(operands);
594   if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
595     return rewriter.notifyMatchFailure(op, "expected vector operand");
596   }
597   if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
598     return rewriter.notifyMatchFailure(op, "expected no region/successor");
599   if (op->getNumResults() != 1)
600     return rewriter.notifyMatchFailure(op, "expected single result");
601   VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
602   if (!vectorType)
603     return rewriter.notifyMatchFailure(op, "expected vector result");
604 
605   Location loc = op->getLoc();
606   Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType);
607   Type indexType = converter.convertType(rewriter.getIndexType());
608   StringAttr name = op->getName().getIdentifier();
609   Type elementType = vectorType.getElementType();
610 
611   for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
612     Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
613     auto extractElement = [&](Value operand) -> Value {
614       if (!isa<VectorType>(operand.getType()))
615         return operand;
616       return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
617     };
618     auto scalarOperands = llvm::map_to_vector(operands, extractElement);
619     Operation *scalarOp =
620         rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
621     result = rewriter.create<LLVM::InsertElementOp>(
622         loc, result, scalarOp->getResult(0), index);
623   }
624 
625   rewriter.replaceOp(op, result);
626   return success();
627 }
628 
629 static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
630   return IntegerAttr::get(IntegerType::get(ctx, 64), space);
631 }
632 
633 /// Generates a symbol with 0-sized array type for dynamic shared memory usage,
634 /// or uses existing symbol.
635 LLVM::GlobalOp getDynamicSharedMemorySymbol(
636     ConversionPatternRewriter &rewriter, gpu::GPUModuleOp moduleOp,
637     gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter,
638     MemRefType memrefType, unsigned alignmentBit) {
639   uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
640 
641   FailureOr<unsigned> addressSpace =
642       typeConverter->getMemRefAddressSpace(memrefType);
643   if (failed(addressSpace)) {
644     op->emitError() << "conversion of memref memory space "
645                     << memrefType.getMemorySpace()
646                     << " to integer address space "
647                        "failed. Consider adding memory space conversions.";
648   }
649 
650   // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
651   // LLVM::GlobalOp is suitable for shared memory, return it.
652   llvm::StringSet<> existingGlobalNames;
653   for (auto globalOp : moduleOp.getBody()->getOps<LLVM::GlobalOp>()) {
654     existingGlobalNames.insert(globalOp.getSymName());
655     if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
656       if (globalOp.getAddrSpace() == addressSpace.value() &&
657           arrayType.getNumElements() == 0 &&
658           globalOp.getAlignment().value_or(0) == alignmentByte) {
659         return globalOp;
660       }
661     }
662   }
663 
664   // Step 2. Find a unique symbol name
665   unsigned uniquingCounter = 0;
666   SmallString<128> symName = SymbolTable::generateSymbolName<128>(
667       "__dynamic_shmem_",
668       [&](StringRef candidate) {
669         return existingGlobalNames.contains(candidate);
670       },
671       uniquingCounter);
672 
673   // Step 3. Generate a global op
674   OpBuilder::InsertionGuard guard(rewriter);
675   rewriter.setInsertionPointToStart(moduleOp.getBody());
676 
677   auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
678       typeConverter->convertType(memrefType.getElementType()), 0);
679 
680   return rewriter.create<LLVM::GlobalOp>(
681       op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
682       LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
683       addressSpace.value());
684 }
685 
686 LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
687     gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
688     ConversionPatternRewriter &rewriter) const {
689   Location loc = op.getLoc();
690   MemRefType memrefType = op.getResultMemref().getType();
691   Type elementType = typeConverter->convertType(memrefType.getElementType());
692 
693   // Step 1: Generate a memref<0xi8> type
694   MemRefLayoutAttrInterface layout = {};
695   auto memrefType0sz =
696       MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
697 
698   // Step 2: Generate a global symbol or existing for the dynamic shared
699   // memory with memref<0xi8> type
700   auto moduleOp = op->getParentOfType<gpu::GPUModuleOp>();
701   LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
702       rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);
703 
704   // Step 3. Get address of the global symbol
705   OpBuilder::InsertionGuard guard(rewriter);
706   rewriter.setInsertionPoint(op);
707   auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
708   Type baseType = basePtr->getResultTypes().front();
709 
710   // Step 4. Generate GEP using offsets
711   SmallVector<LLVM::GEPArg> gepArgs = {0};
712   Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
713                                                 basePtr, gepArgs);
714   // Step 5. Create a memref descriptor
715   SmallVector<Value> shape, strides;
716   Value sizeBytes;
717   getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
718                            sizeBytes);
719   auto memRefDescriptor = this->createMemRefDescriptor(
720       loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
721 
722   // Step 5. Replace the op with memref descriptor
723   rewriter.replaceOp(op, {memRefDescriptor});
724   return success();
725 }
726 
727 LogicalResult GPUReturnOpLowering::matchAndRewrite(
728     gpu::ReturnOp op, OpAdaptor adaptor,
729     ConversionPatternRewriter &rewriter) const {
730   Location loc = op.getLoc();
731   unsigned numArguments = op.getNumOperands();
732   SmallVector<Value, 4> updatedOperands;
733 
734   bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv;
735   if (useBarePtrCallConv) {
736     // For the bare-ptr calling convention, extract the aligned pointer to
737     // be returned from the memref descriptor.
738     for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
739       Type oldTy = std::get<0>(it).getType();
740       Value newOperand = std::get<1>(it);
741       if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
742                                         cast<BaseMemRefType>(oldTy))) {
743         MemRefDescriptor memrefDesc(newOperand);
744         newOperand = memrefDesc.allocatedPtr(rewriter, loc);
745       } else if (isa<UnrankedMemRefType>(oldTy)) {
746         // Unranked memref is not supported in the bare pointer calling
747         // convention.
748         return failure();
749       }
750       updatedOperands.push_back(newOperand);
751     }
752   } else {
753     updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
754     (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
755                                   updatedOperands,
756                                   /*toDynamic=*/true);
757   }
758 
759   // If ReturnOp has 0 or 1 operand, create it and return immediately.
760   if (numArguments <= 1) {
761     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
762         op, TypeRange(), updatedOperands, op->getAttrs());
763     return success();
764   }
765 
766   // Otherwise, we need to pack the arguments into an LLVM struct type before
767   // returning.
768   auto packedType = getTypeConverter()->packFunctionResults(
769       op.getOperandTypes(), useBarePtrCallConv);
770   if (!packedType) {
771     return rewriter.notifyMatchFailure(op, "could not convert result types");
772   }
773 
774   Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
775   for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
776     packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
777   }
778   rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
779                                               op->getAttrs());
780   return success();
781 }
782 
783 void mlir::populateGpuMemorySpaceAttributeConversions(
784     TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
785   typeConverter.addTypeAttributeConversion(
786       [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) {
787         gpu::AddressSpace memorySpace = memorySpaceAttr.getValue();
788         unsigned addressSpace = mapping(memorySpace);
789         return wrapNumericMemorySpace(memorySpaceAttr.getContext(),
790                                       addressSpace);
791       });
792 }
793