//===- SelectObjectAttr.cpp - Implements the SelectObject attribute ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the `OffloadingLLVMTranslationAttrInterface` for the
// `SelectObject` attribute.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"

#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

namespace {
// Implementation of the `OffloadingLLVMTranslationAttrInterface` model.
class SelectObjectAttrImpl
    : public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
          SelectObjectAttrImpl> {
public:
  // Translates a `gpu.binary`, embedding the binary into a host LLVM module
  // as a global binary string.
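  // For example (a sketch, not verbatim compiler output), a binary such as
  //
  //   gpu.binary @kernels [#gpu.object<#nvvm.target, "BLOB">]
  //
  // is embedded in the host module roughly as
  //
  //   @kernels_bin_cst = internal constant [4 x i8] c"BLOB", align 8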
  LogicalResult embedBinary(Attribute attribute, Operation *operation,
                            llvm::IRBuilderBase &builder,
                            LLVM::ModuleTranslation &moduleTranslation) const;

  // Translates a `gpu.launch_func` to a sequence of LLVM instructions
  // resulting in a kernel launch call.
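  // As a sketch (attribute and operand details abbreviated), a launch such as
  //
  //   gpu.launch_func @kernels::@add blocks in (%gx, %gy, %gz)
  //       threads in (%bx, %by, %bz) args(%arg0 : f32)
  //
  // lowers to calls to mgpuModuleLoad (or mgpuModuleLoadJIT for assembly
  // objects), mgpuModuleGetFunction, mgpuLaunchKernel, and mgpuModuleUnload,
  // plus stream creation, synchronization, and destruction when the launch
  // has no async object.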
  LogicalResult launchKernel(Attribute attribute,
                             Operation *launchFuncOperation,
                             Operation *binaryOperation,
                             llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) const;

  // Returns the selected object for embedding.
  gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;
};
// Returns an identifier for the global string holding the binary.
std::string getBinaryIdentifier(StringRef binaryName) {
  return binaryName.str() + "_bin_cst";
}
} // namespace

void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
    SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
  });
}

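// Selection is either positional or by target. As a sketch (attribute syntax
// abbreviated), given
//
//   gpu.binary @kernels <#gpu.select_object<1>> [
//     #gpu.object<#rocdl.target, "...">,
//     #gpu.object<#nvvm.target, "...">
//   ]
//
// the second object is selected; passing a target attribute instead of an
// integer selects the object whose target matches it.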
gpu::ObjectAttr
SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
  ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();

  // Obtain the index of the object to select.
  int64_t index = -1;
  if (Attribute target =
          cast<gpu::SelectObjectAttr>(op.getOffloadingHandlerAttr())
              .getTarget()) {
    // If the target attribute is a number, it is the index. Otherwise, compare
    // the attribute to every target inside the object array to find the index.
    if (auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) {
      index = indexAttr.getInt();
    } else {
      for (auto [i, attr] : llvm::enumerate(objects)) {
        auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr);
        if (obj && obj.getTarget() == target) {
          index = i;
        }
      }
    }
  } else {
    // If the target attribute is null, select the first object in the object
    // array.
    index = 0;
  }

  if (index < 0 || index >= static_cast<int64_t>(objects.size())) {
    op->emitError("the requested target object couldn't be found");
    return nullptr;
  }
  return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
}

LogicalResult SelectObjectAttrImpl::embedBinary(
    Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  assert(operation && "The binary operation must be non null.");
  if (!operation)
    return failure();

  auto op = mlir::dyn_cast<gpu::BinaryOp>(operation);
  if (!op) {
    operation->emitError("operation must be a GPU binary");
    return failure();
  }

  gpu::ObjectAttr object = getSelectedObject(op);
  if (!object)
    return failure();

  llvm::Module *module = moduleTranslation.getLLVMModule();

  // Embed the object as a global string.
  llvm::Constant *binary = llvm::ConstantDataArray::getString(
      builder.getContext(), object.getObject().getValue(), false);
  llvm::GlobalVariable *serializedObj =
      new llvm::GlobalVariable(*module, binary->getType(), true,
                               llvm::GlobalValue::LinkageTypes::InternalLinkage,
                               binary, getBinaryIdentifier(op.getName()));

  if (object.getProperties()) {
    if (auto section = mlir::dyn_cast_or_null<mlir::StringAttr>(
            object.getProperties().get(gpu::elfSectionName))) {
      serializedObj->setSection(section.getValue());
    }
  }
  serializedObj->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
  serializedObj->setAlignment(llvm::MaybeAlign(8));
  serializedObj->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::None);
  return success();
}

namespace llvm {
namespace {
class LaunchKernel {
public:
  LaunchKernel(Module &module, IRBuilderBase &builder,
               mlir::LLVM::ModuleTranslation &moduleTranslation);
  // Get the kernel launch callee.
  FunctionCallee getKernelLaunchFn();

  // Get the cluster kernel launch callee.
  FunctionCallee getClusterKernelLaunchFn();

  // Get the module function callee.
  FunctionCallee getModuleFunctionFn();

  // Get the module load callee.
  FunctionCallee getModuleLoadFn();

  // Get the module load JIT callee.
  FunctionCallee getModuleLoadJITFn();

  // Get the module unload callee.
  FunctionCallee getModuleUnloadFn();

  // Get the stream create callee.
  FunctionCallee getStreamCreateFn();

  // Get the stream destroy callee.
  FunctionCallee getStreamDestroyFn();

  // Get the stream sync callee.
  FunctionCallee getStreamSyncFn();

  // Get or create the function name global string.
  Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);

  // Create the void* kernel array for passing the arguments.
  Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);

  // Create the full kernel launch.
  llvm::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op,
                                         mlir::gpu::ObjectAttr object);

private:
  Module &module;
  IRBuilderBase &builder;
  mlir::LLVM::ModuleTranslation &moduleTranslation;
  Type *i32Ty{};
  Type *i64Ty{};
  Type *voidTy{};
  Type *intPtrTy{};
  PointerType *ptrTy{};
};
} // namespace
} // namespace llvm

LogicalResult SelectObjectAttrImpl::launchKernel(
    Attribute attribute, Operation *launchFuncOperation,
    Operation *binaryOperation, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {

  assert(launchFuncOperation && "The launch func operation must be non null.");
  if (!launchFuncOperation)
    return failure();

  auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation);
  if (!launchFuncOp) {
    launchFuncOperation->emitError("operation must be a GPU launch func Op.");
    return failure();
  }

  auto binOp = mlir::dyn_cast<gpu::BinaryOp>(binaryOperation);
  if (!binOp) {
    binaryOperation->emitError("operation must be a GPU binary.");
    return failure();
  }
  gpu::ObjectAttr object = getSelectedObject(binOp);
  if (!object)
    return failure();

  return llvm::LaunchKernel(*moduleTranslation.getLLVMModule(), builder,
                            moduleTranslation)
      .createKernelLaunch(launchFuncOp, object);
}

llvm::LaunchKernel::LaunchKernel(
    Module &module, IRBuilderBase &builder,
    mlir::LLVM::ModuleTranslation &moduleTranslation)
    : module(module), builder(builder), moduleTranslation(moduleTranslation) {
  i32Ty = builder.getInt32Ty();
  i64Ty = builder.getInt64Ty();
  ptrTy = builder.getPtrTy(0);
  voidTy = builder.getVoidTy();
  intPtrTy = builder.getIntPtrTy(module.getDataLayout());
}

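// The mgpu* callees below are expected to line up with the GPU runtime
// wrapper functions used by the MLIR execution engine. As a sketch (assuming
// a 64-bit host, where size_t matches i64), the kernel launch entry point
// corresponds to:
//
//   void mgpuLaunchKernel(void *function, intptr_t gridX, intptr_t gridY,
//                         intptr_t gridZ, intptr_t blockX, intptr_t blockY,
//                         intptr_t blockZ, int32_t smem, void *stream,
//                         void **params, void **extra, size_t paramCount);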
llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
  return module.getOrInsertFunction(
      "mgpuLaunchKernel",
      FunctionType::get(voidTy,
                        ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy,
                                          intPtrTy, intPtrTy, intPtrTy, i32Ty,
                                          ptrTy, ptrTy, ptrTy, i64Ty}),
                        false));
}

llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
  return module.getOrInsertFunction(
      "mgpuLaunchClusterKernel",
      FunctionType::get(
          voidTy,
          ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
                            intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
                            i32Ty, ptrTy, ptrTy, ptrTy}),
          false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
  return module.getOrInsertFunction(
      "mgpuModuleGetFunction",
      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
  return module.getOrInsertFunction(
      "mgpuModuleLoad",
      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i64Ty}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn() {
  return module.getOrInsertFunction(
      "mgpuModuleLoadJIT",
      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i32Ty}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() {
  return module.getOrInsertFunction(
      "mgpuModuleUnload",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
  return module.getOrInsertFunction("mgpuStreamCreate",
                                    FunctionType::get(ptrTy, false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
  return module.getOrInsertFunction(
      "mgpuStreamDestroy",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
  return module.getOrInsertFunction(
      "mgpuStreamSynchronize",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}

// Generates an LLVM IR global that contains the name of the given kernel
// function as a C string, and returns a pointer to its beginning.
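// For example, for kernel `add` in module `kernels` this returns a pointer to
// a global named "kernels_add_kernel_name" holding the C string "add".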
llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
                                                         StringRef kernelName) {
  std::string globalName =
      std::string(formatv("{0}_{1}_kernel_name", moduleName, kernelName));

  if (GlobalVariable *gv = module.getGlobalVariable(globalName))
    return gv;

  return builder.CreateGlobalString(kernelName, globalName);
}

// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
llvm::Value *
llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
  SmallVector<Value *> args =
      moduleTranslation.lookupValues(op.getKernelOperands());
  SmallVector<Type *> structTypes(args.size(), nullptr);

  for (auto [i, arg] : llvm::enumerate(args))
    structTypes[i] = arg->getType();

  Type *structTy = StructType::create(module.getContext(), structTypes);
  Value *argStruct = builder.CreateAlloca(structTy, 0u);
  Value *argArray = builder.CreateAlloca(
      ptrTy, ConstantInt::get(intPtrTy, structTypes.size()));

  for (auto [i, arg] : enumerate(args)) {
    Value *structMember = builder.CreateStructGEP(structTy, argStruct, i);
    builder.CreateStore(arg, structMember);
    Value *arrayMember = builder.CreateConstGEP1_32(ptrTy, argArray, i);
    builder.CreateStore(structMember, arrayMember);
  }
  return argArray;
}

// Emits LLVM IR to launch a kernel function:
// %0 = <the embedded binary global, see embedBinary>
// %1 = call %moduleLoad(%0)
// %2 = <see getOrCreateFunctionName>
// %3 = call %moduleGetFunction(%1, %2)
// %4 = call %streamCreate()
// %5 = <see createKernelArgArray>
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
// call %streamSynchronize(%4)
// call %streamDestroy(%4)
// call %moduleUnload(%1)
llvm::LogicalResult
llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
                                       mlir::gpu::ObjectAttr object) {
  auto llvmValue = [&](mlir::Value value) -> Value * {
    Value *v = moduleTranslation.lookupValue(value);
    assert(v && "Value has not been translated.");
    return v;
  };

  // Get grid dimensions.
  mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues();
  Value *gx = llvmValue(grid.x), *gy = llvmValue(grid.y),
        *gz = llvmValue(grid.z);

  // Get block dimensions.
  mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues();
  Value *bx = llvmValue(block.x), *by = llvmValue(block.y),
        *bz = llvmValue(block.z);

  // Get dynamic shared memory size.
  Value *dynamicMemorySize = nullptr;
  if (mlir::Value dynSz = op.getDynamicSharedMemorySize())
    dynamicMemorySize = llvmValue(dynSz);
  else
    dynamicMemorySize = ConstantInt::get(i32Ty, 0);

  // Create the argument array.
  Value *argArray = createKernelArgArray(op);

  // Default JIT optimization level.
  llvm::Constant *optV = llvm::ConstantInt::get(i32Ty, 0);
  // Check if there's an optimization level embedded in the object.
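  // As a sketch of the property (attribute syntax abbreviated), an object
  // carrying `properties = {O = 2 : i32}` requests JIT optimization level 2.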
  DictionaryAttr objectProps = object.getProperties();
  mlir::Attribute optAttr;
  if (objectProps && (optAttr = objectProps.get("O"))) {
    auto optLevel = dyn_cast<IntegerAttr>(optAttr);
    if (!optLevel)
      return op.emitError("the optimization level must be an integer");
    optV = llvm::ConstantInt::get(i32Ty, optLevel.getValue());
  }

  // Load the kernel module.
  StringRef moduleName = op.getKernelModuleName().getValue();
  std::string binaryIdentifier = getBinaryIdentifier(moduleName);
  Value *binary = module.getGlobalVariable(binaryIdentifier, true);
  if (!binary)
    return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;

  auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
  if (!binaryVar)
    return op.emitError() << "Binary is not a global variable: "
                          << binaryIdentifier;
  llvm::Constant *binaryInit = binaryVar->getInitializer();
  auto binaryDataSeq =
      dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
  if (!binaryDataSeq)
    return op.emitError() << "Couldn't find binary data array: "
                          << binaryIdentifier;
  llvm::Constant *binarySize =
      llvm::ConstantInt::get(i64Ty, binaryDataSeq->getNumElements() *
                                        binaryDataSeq->getElementByteSize());

  Value *moduleObject =
      object.getFormat() == gpu::CompilationTarget::Assembly
          ? builder.CreateCall(getModuleLoadJITFn(), {binary, optV})
          : builder.CreateCall(getModuleLoadFn(), {binary, binarySize});

  // Load the kernel function.
  Value *moduleFunction = builder.CreateCall(
      getModuleFunctionFn(),
      {moduleObject,
       getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});

  // Get the stream to use for execution. If there's no async object, create a
  // stream to make a synchronous kernel launch.
  Value *stream = nullptr;
  bool handleStream = false;
  if (mlir::Value asyncObject = op.getAsyncObject()) {
    stream = llvmValue(asyncObject);
  } else {
    handleStream = true;
    stream = builder.CreateCall(getStreamCreateFn(), {});
  }

  llvm::Constant *paramsCount =
      llvm::ConstantInt::get(i64Ty, op.getNumKernelOperands());

  // Create the launch call.
  Value *nullPtr = ConstantPointerNull::get(ptrTy);

  // Launch kernel with clusters if cluster size is specified.
  if (op.hasClusterSize()) {
    mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
    Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
          *cz = llvmValue(cluster.z);
    builder.CreateCall(
        getClusterKernelLaunchFn(),
        ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
                           dynamicMemorySize, stream, argArray, nullPtr}));
  } else {
    builder.CreateCall(getKernelLaunchFn(),
                       ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
                                          bz, dynamicMemorySize, stream,
                                          argArray, nullPtr, paramsCount}));
  }

  // Sync and destroy the stream for synchronous launches.
  if (handleStream) {
    builder.CreateCall(getStreamSyncFn(), {stream});
    builder.CreateCall(getStreamDestroyFn(), {stream});
  }

  // Unload the kernel module.
  builder.CreateCall(getModuleUnloadFn(), {moduleObject});

  return success();
}
473