xref: /llvm-project/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp (revision 9919295cfd05222159246d7448ec42392e98fbf2)
//===- SelectObjectAttr.cpp - Implements the SelectObject attribute -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the `OffloadingLLVMTranslationAttrInterface` for the
// `SelectObject` attribute.
//
//===----------------------------------------------------------------------===//
138ae074b1SFabian Mora 
14*9919295cSRenaud Kauffmann #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
158ae074b1SFabian Mora #include "mlir/Dialect/GPU/IR/GPUDialect.h"
168ae074b1SFabian Mora 
178ae074b1SFabian Mora #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
188ae074b1SFabian Mora #include "mlir/Target/LLVMIR/Export.h"
198ae074b1SFabian Mora #include "mlir/Target/LLVMIR/ModuleTranslation.h"
208ae074b1SFabian Mora 
218ae074b1SFabian Mora #include "llvm/IR/Constants.h"
228ae074b1SFabian Mora #include "llvm/IR/IRBuilder.h"
238ae074b1SFabian Mora #include "llvm/IR/LLVMContext.h"
248ae074b1SFabian Mora #include "llvm/IR/Module.h"
258ae074b1SFabian Mora #include "llvm/Support/FormatVariadic.h"
268ae074b1SFabian Mora 
278ae074b1SFabian Mora using namespace mlir;
288ae074b1SFabian Mora 
298ae074b1SFabian Mora namespace {
308ae074b1SFabian Mora // Implementation of the `OffloadingLLVMTranslationAttrInterface` model.
318ae074b1SFabian Mora class SelectObjectAttrImpl
328ae074b1SFabian Mora     : public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
338ae074b1SFabian Mora           SelectObjectAttrImpl> {
348ae074b1SFabian Mora public:
358ae074b1SFabian Mora   // Translates a `gpu.binary`, embedding the binary into a host LLVM module as
368ae074b1SFabian Mora   // global binary string.
378ae074b1SFabian Mora   LogicalResult embedBinary(Attribute attribute, Operation *operation,
388ae074b1SFabian Mora                             llvm::IRBuilderBase &builder,
398ae074b1SFabian Mora                             LLVM::ModuleTranslation &moduleTranslation) const;
408ae074b1SFabian Mora 
418ae074b1SFabian Mora   // Translates a `gpu.launch_func` to a sequence of LLVM instructions resulting
428ae074b1SFabian Mora   // in a kernel launch call.
438ae074b1SFabian Mora   LogicalResult launchKernel(Attribute attribute,
448ae074b1SFabian Mora                              Operation *launchFuncOperation,
458ae074b1SFabian Mora                              Operation *binaryOperation,
468ae074b1SFabian Mora                              llvm::IRBuilderBase &builder,
478ae074b1SFabian Mora                              LLVM::ModuleTranslation &moduleTranslation) const;
485093413aSFabian Mora 
495093413aSFabian Mora   // Returns the selected object for embedding.
505093413aSFabian Mora   gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;
518ae074b1SFabian Mora };
528ae074b1SFabian Mora // Returns an identifier for the global string holding the binary.
538ae074b1SFabian Mora std::string getBinaryIdentifier(StringRef binaryName) {
548ae074b1SFabian Mora   return binaryName.str() + "_bin_cst";
558ae074b1SFabian Mora }
568ae074b1SFabian Mora } // namespace
578ae074b1SFabian Mora 
587c4e8c6aSNicolas Vasilache void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
598ae074b1SFabian Mora     DialectRegistry &registry) {
608ae074b1SFabian Mora   registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
618ae074b1SFabian Mora     SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
628ae074b1SFabian Mora   });
638ae074b1SFabian Mora }
648ae074b1SFabian Mora 
655093413aSFabian Mora gpu::ObjectAttr
665093413aSFabian Mora SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
678ae074b1SFabian Mora   ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();
688ae074b1SFabian Mora 
698ae074b1SFabian Mora   // Obtain the index of the object to select.
708ae074b1SFabian Mora   int64_t index = -1;
715093413aSFabian Mora   if (Attribute target =
725093413aSFabian Mora           cast<gpu::SelectObjectAttr>(op.getOffloadingHandlerAttr())
735093413aSFabian Mora               .getTarget()) {
748ae074b1SFabian Mora     // If the target attribute is a number it is the index. Otherwise compare
758ae074b1SFabian Mora     // the attribute to every target inside the object array to find the index.
768ae074b1SFabian Mora     if (auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) {
778ae074b1SFabian Mora       index = indexAttr.getInt();
788ae074b1SFabian Mora     } else {
798ae074b1SFabian Mora       for (auto [i, attr] : llvm::enumerate(objects)) {
808ae074b1SFabian Mora         auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr);
818ae074b1SFabian Mora         if (obj.getTarget() == target) {
828ae074b1SFabian Mora           index = i;
838ae074b1SFabian Mora         }
848ae074b1SFabian Mora       }
858ae074b1SFabian Mora     }
868ae074b1SFabian Mora   } else {
878ae074b1SFabian Mora     // If the target attribute is null then it's selecting the first object in
888ae074b1SFabian Mora     // the object array.
898ae074b1SFabian Mora     index = 0;
908ae074b1SFabian Mora   }
918ae074b1SFabian Mora 
928ae074b1SFabian Mora   if (index < 0 || index >= static_cast<int64_t>(objects.size())) {
935093413aSFabian Mora     op->emitError("the requested target object couldn't be found");
945093413aSFabian Mora     return nullptr;
955093413aSFabian Mora   }
965093413aSFabian Mora   return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
975093413aSFabian Mora }
985093413aSFabian Mora 
995093413aSFabian Mora LogicalResult SelectObjectAttrImpl::embedBinary(
1005093413aSFabian Mora     Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
1015093413aSFabian Mora     LLVM::ModuleTranslation &moduleTranslation) const {
1025093413aSFabian Mora   assert(operation && "The binary operation must be non null.");
1035093413aSFabian Mora   if (!operation)
1045093413aSFabian Mora     return failure();
1055093413aSFabian Mora 
1065093413aSFabian Mora   auto op = mlir::dyn_cast<gpu::BinaryOp>(operation);
1075093413aSFabian Mora   if (!op) {
1085093413aSFabian Mora     operation->emitError("operation must be a GPU binary");
1098ae074b1SFabian Mora     return failure();
1108ae074b1SFabian Mora   }
1115093413aSFabian Mora 
1125093413aSFabian Mora   gpu::ObjectAttr object = getSelectedObject(op);
1135093413aSFabian Mora   if (!object)
1145093413aSFabian Mora     return failure();
1158ae074b1SFabian Mora 
1168ae074b1SFabian Mora   llvm::Module *module = moduleTranslation.getLLVMModule();
1178ae074b1SFabian Mora 
1188ae074b1SFabian Mora   // Embed the object as a global string.
1198ae074b1SFabian Mora   llvm::Constant *binary = llvm::ConstantDataArray::getString(
1208ae074b1SFabian Mora       builder.getContext(), object.getObject().getValue(), false);
1218ae074b1SFabian Mora   llvm::GlobalVariable *serializedObj =
1228ae074b1SFabian Mora       new llvm::GlobalVariable(*module, binary->getType(), true,
1238ae074b1SFabian Mora                                llvm::GlobalValue::LinkageTypes::InternalLinkage,
1248ae074b1SFabian Mora                                binary, getBinaryIdentifier(op.getName()));
1257fcc0f90SRenaud Kauffmann 
1267fcc0f90SRenaud Kauffmann   if (object.getProperties()) {
1277fcc0f90SRenaud Kauffmann     if (auto section = mlir::dyn_cast_or_null<mlir::StringAttr>(
128*9919295cSRenaud Kauffmann             object.getProperties().get(gpu::elfSectionName))) {
1297fcc0f90SRenaud Kauffmann       serializedObj->setSection(section.getValue());
1307fcc0f90SRenaud Kauffmann     }
1317fcc0f90SRenaud Kauffmann   }
1328ae074b1SFabian Mora   serializedObj->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
1338ae074b1SFabian Mora   serializedObj->setAlignment(llvm::MaybeAlign(8));
1348ae074b1SFabian Mora   serializedObj->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::None);
1358ae074b1SFabian Mora   return success();
1368ae074b1SFabian Mora }
1378ae074b1SFabian Mora 
1388ae074b1SFabian Mora namespace llvm {
1398ae074b1SFabian Mora namespace {
1408ae074b1SFabian Mora class LaunchKernel {
1418ae074b1SFabian Mora public:
1428ae074b1SFabian Mora   LaunchKernel(Module &module, IRBuilderBase &builder,
1438ae074b1SFabian Mora                mlir::LLVM::ModuleTranslation &moduleTranslation);
1448ae074b1SFabian Mora   // Get the kernel launch callee.
1458ae074b1SFabian Mora   FunctionCallee getKernelLaunchFn();
1468ae074b1SFabian Mora 
147edf5cae7SGuray Ozen   // Get the kernel launch callee.
148edf5cae7SGuray Ozen   FunctionCallee getClusterKernelLaunchFn();
149edf5cae7SGuray Ozen 
1508ae074b1SFabian Mora   // Get the module function callee.
1518ae074b1SFabian Mora   FunctionCallee getModuleFunctionFn();
1528ae074b1SFabian Mora 
1538ae074b1SFabian Mora   // Get the module load callee.
1548ae074b1SFabian Mora   FunctionCallee getModuleLoadFn();
1558ae074b1SFabian Mora 
1565093413aSFabian Mora   // Get the module load JIT callee.
1575093413aSFabian Mora   FunctionCallee getModuleLoadJITFn();
1585093413aSFabian Mora 
1598ae074b1SFabian Mora   // Get the module unload callee.
1608ae074b1SFabian Mora   FunctionCallee getModuleUnloadFn();
1618ae074b1SFabian Mora 
1628ae074b1SFabian Mora   // Get the stream create callee.
1638ae074b1SFabian Mora   FunctionCallee getStreamCreateFn();
1648ae074b1SFabian Mora 
1658ae074b1SFabian Mora   // Get the stream destroy callee.
1668ae074b1SFabian Mora   FunctionCallee getStreamDestroyFn();
1678ae074b1SFabian Mora 
1688ae074b1SFabian Mora   // Get the stream sync callee.
1698ae074b1SFabian Mora   FunctionCallee getStreamSyncFn();
1708ae074b1SFabian Mora 
1718ae074b1SFabian Mora   // Ger or create the function name global string.
1728ae074b1SFabian Mora   Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);
1738ae074b1SFabian Mora 
1748ae074b1SFabian Mora   // Create the void* kernel array for passing the arguments.
1758ae074b1SFabian Mora   Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);
1768ae074b1SFabian Mora 
1778ae074b1SFabian Mora   // Create the full kernel launch.
178db791b27SRamkumar Ramachandra   llvm::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op,
1795093413aSFabian Mora                                          mlir::gpu::ObjectAttr object);
1808ae074b1SFabian Mora 
1818ae074b1SFabian Mora private:
1828ae074b1SFabian Mora   Module &module;
1838ae074b1SFabian Mora   IRBuilderBase &builder;
1848ae074b1SFabian Mora   mlir::LLVM::ModuleTranslation &moduleTranslation;
1858ae074b1SFabian Mora   Type *i32Ty{};
1867fc792cbSSang Ik Lee   Type *i64Ty{};
1878ae074b1SFabian Mora   Type *voidTy{};
1888ae074b1SFabian Mora   Type *intPtrTy{};
1898ae074b1SFabian Mora   PointerType *ptrTy{};
1908ae074b1SFabian Mora };
1918ae074b1SFabian Mora } // namespace
1928ae074b1SFabian Mora } // namespace llvm
1938ae074b1SFabian Mora 
1948ae074b1SFabian Mora LogicalResult SelectObjectAttrImpl::launchKernel(
1958ae074b1SFabian Mora     Attribute attribute, Operation *launchFuncOperation,
1968ae074b1SFabian Mora     Operation *binaryOperation, llvm::IRBuilderBase &builder,
1978ae074b1SFabian Mora     LLVM::ModuleTranslation &moduleTranslation) const {
1988ae074b1SFabian Mora 
1998ae074b1SFabian Mora   assert(launchFuncOperation && "The launch func operation must be non null.");
2008ae074b1SFabian Mora   if (!launchFuncOperation)
2018ae074b1SFabian Mora     return failure();
2028ae074b1SFabian Mora 
2038ae074b1SFabian Mora   auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation);
2048ae074b1SFabian Mora   if (!launchFuncOp) {
2055093413aSFabian Mora     launchFuncOperation->emitError("operation must be a GPU launch func Op.");
2068ae074b1SFabian Mora     return failure();
2078ae074b1SFabian Mora   }
2088ae074b1SFabian Mora 
2095093413aSFabian Mora   auto binOp = mlir::dyn_cast<gpu::BinaryOp>(binaryOperation);
2105093413aSFabian Mora   if (!binOp) {
2115093413aSFabian Mora     binaryOperation->emitError("operation must be a GPU binary.");
2125093413aSFabian Mora     return failure();
2135093413aSFabian Mora   }
2145093413aSFabian Mora   gpu::ObjectAttr object = getSelectedObject(binOp);
2155093413aSFabian Mora   if (!object)
2165093413aSFabian Mora     return failure();
2175093413aSFabian Mora 
2188ae074b1SFabian Mora   return llvm::LaunchKernel(*moduleTranslation.getLLVMModule(), builder,
2198ae074b1SFabian Mora                             moduleTranslation)
2205093413aSFabian Mora       .createKernelLaunch(launchFuncOp, object);
2218ae074b1SFabian Mora }
2228ae074b1SFabian Mora 
2238ae074b1SFabian Mora llvm::LaunchKernel::LaunchKernel(
2248ae074b1SFabian Mora     Module &module, IRBuilderBase &builder,
2258ae074b1SFabian Mora     mlir::LLVM::ModuleTranslation &moduleTranslation)
2268ae074b1SFabian Mora     : module(module), builder(builder), moduleTranslation(moduleTranslation) {
2278ae074b1SFabian Mora   i32Ty = builder.getInt32Ty();
2287fc792cbSSang Ik Lee   i64Ty = builder.getInt64Ty();
2298ae074b1SFabian Mora   ptrTy = builder.getPtrTy(0);
2308ae074b1SFabian Mora   voidTy = builder.getVoidTy();
2318ae074b1SFabian Mora   intPtrTy = builder.getIntPtrTy(module.getDataLayout());
2328ae074b1SFabian Mora }
2338ae074b1SFabian Mora 
2348ae074b1SFabian Mora llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
2358ae074b1SFabian Mora   return module.getOrInsertFunction(
2368ae074b1SFabian Mora       "mgpuLaunchKernel",
2377fc792cbSSang Ik Lee       FunctionType::get(voidTy,
2387fc792cbSSang Ik Lee                         ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy,
2397fc792cbSSang Ik Lee                                           intPtrTy, intPtrTy, intPtrTy, i32Ty,
2407fc792cbSSang Ik Lee                                           ptrTy, ptrTy, ptrTy, i64Ty}),
2418ae074b1SFabian Mora                         false));
2428ae074b1SFabian Mora }
2438ae074b1SFabian Mora 
244edf5cae7SGuray Ozen llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
245edf5cae7SGuray Ozen   return module.getOrInsertFunction(
246edf5cae7SGuray Ozen       "mgpuLaunchClusterKernel",
247edf5cae7SGuray Ozen       FunctionType::get(
248edf5cae7SGuray Ozen           voidTy,
249edf5cae7SGuray Ozen           ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
250edf5cae7SGuray Ozen                             intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
251edf5cae7SGuray Ozen                             i32Ty, ptrTy, ptrTy, ptrTy}),
252edf5cae7SGuray Ozen           false));
253edf5cae7SGuray Ozen }
254edf5cae7SGuray Ozen 
2558ae074b1SFabian Mora llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
2568ae074b1SFabian Mora   return module.getOrInsertFunction(
2578ae074b1SFabian Mora       "mgpuModuleGetFunction",
2588ae074b1SFabian Mora       FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false));
2598ae074b1SFabian Mora }
2608ae074b1SFabian Mora 
2618ae074b1SFabian Mora llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
2628ae074b1SFabian Mora   return module.getOrInsertFunction(
2638ae074b1SFabian Mora       "mgpuModuleLoad",
2647fc792cbSSang Ik Lee       FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i64Ty}), false));
2658ae074b1SFabian Mora }
2668ae074b1SFabian Mora 
2675093413aSFabian Mora llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn() {
2685093413aSFabian Mora   return module.getOrInsertFunction(
2695093413aSFabian Mora       "mgpuModuleLoadJIT",
2705093413aSFabian Mora       FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i32Ty}), false));
2715093413aSFabian Mora }
2725093413aSFabian Mora 
2738ae074b1SFabian Mora llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() {
2748ae074b1SFabian Mora   return module.getOrInsertFunction(
2758ae074b1SFabian Mora       "mgpuModuleUnload",
2768ae074b1SFabian Mora       FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
2778ae074b1SFabian Mora }
2788ae074b1SFabian Mora 
2798ae074b1SFabian Mora llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
2808ae074b1SFabian Mora   return module.getOrInsertFunction("mgpuStreamCreate",
2818ae074b1SFabian Mora                                     FunctionType::get(ptrTy, false));
2828ae074b1SFabian Mora }
2838ae074b1SFabian Mora 
2848ae074b1SFabian Mora llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
2858ae074b1SFabian Mora   return module.getOrInsertFunction(
2868ae074b1SFabian Mora       "mgpuStreamDestroy",
2878ae074b1SFabian Mora       FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
2888ae074b1SFabian Mora }
2898ae074b1SFabian Mora 
2908ae074b1SFabian Mora llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
2918ae074b1SFabian Mora   return module.getOrInsertFunction(
2928ae074b1SFabian Mora       "mgpuStreamSynchronize",
2938ae074b1SFabian Mora       FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
2948ae074b1SFabian Mora }
2958ae074b1SFabian Mora 
2968ae074b1SFabian Mora // Generates an LLVM IR dialect global that contains the name of the given
2978ae074b1SFabian Mora // kernel function as a C string, and returns a pointer to its beginning.
2988ae074b1SFabian Mora llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
2998ae074b1SFabian Mora                                                          StringRef kernelName) {
3008ae074b1SFabian Mora   std::string globalName =
3018ae074b1SFabian Mora       std::string(formatv("{0}_{1}_kernel_name", moduleName, kernelName));
3028ae074b1SFabian Mora 
3038ae074b1SFabian Mora   if (GlobalVariable *gv = module.getGlobalVariable(globalName))
3048ae074b1SFabian Mora     return gv;
3058ae074b1SFabian Mora 
3068ae074b1SFabian Mora   return builder.CreateGlobalString(kernelName, globalName);
3078ae074b1SFabian Mora }
3088ae074b1SFabian Mora 
3098ae074b1SFabian Mora // Creates a struct containing all kernel parameters on the stack and returns
3108ae074b1SFabian Mora // an array of type-erased pointers to the fields of the struct. The array can
3118ae074b1SFabian Mora // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
3128ae074b1SFabian Mora // The generated code is essentially as follows:
3138ae074b1SFabian Mora //
3148ae074b1SFabian Mora // %struct = alloca(sizeof(struct { Parameters... }))
3158ae074b1SFabian Mora // %array = alloca(NumParameters * sizeof(void *))
3168ae074b1SFabian Mora // for (i : [0, NumParameters))
3178ae074b1SFabian Mora //   %fieldPtr = llvm.getelementptr %struct[0, i]
3188ae074b1SFabian Mora //   llvm.store parameters[i], %fieldPtr
3198ae074b1SFabian Mora //   %elementPtr = llvm.getelementptr %array[i]
3208ae074b1SFabian Mora //   llvm.store %fieldPtr, %elementPtr
3218ae074b1SFabian Mora // return %array
3228ae074b1SFabian Mora llvm::Value *
3238ae074b1SFabian Mora llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
3248ae074b1SFabian Mora   SmallVector<Value *> args =
3258ae074b1SFabian Mora       moduleTranslation.lookupValues(op.getKernelOperands());
3268ae074b1SFabian Mora   SmallVector<Type *> structTypes(args.size(), nullptr);
3278ae074b1SFabian Mora 
3288ae074b1SFabian Mora   for (auto [i, arg] : llvm::enumerate(args))
3298ae074b1SFabian Mora     structTypes[i] = arg->getType();
3308ae074b1SFabian Mora 
3318ae074b1SFabian Mora   Type *structTy = StructType::create(module.getContext(), structTypes);
3328ae074b1SFabian Mora   Value *argStruct = builder.CreateAlloca(structTy, 0u);
3338ae074b1SFabian Mora   Value *argArray = builder.CreateAlloca(
3348ae074b1SFabian Mora       ptrTy, ConstantInt::get(intPtrTy, structTypes.size()));
3358ae074b1SFabian Mora 
3368ae074b1SFabian Mora   for (auto [i, arg] : enumerate(args)) {
3378ae074b1SFabian Mora     Value *structMember = builder.CreateStructGEP(structTy, argStruct, i);
3388ae074b1SFabian Mora     builder.CreateStore(arg, structMember);
3398ae074b1SFabian Mora     Value *arrayMember = builder.CreateConstGEP1_32(ptrTy, argArray, i);
3408ae074b1SFabian Mora     builder.CreateStore(structMember, arrayMember);
3418ae074b1SFabian Mora   }
3428ae074b1SFabian Mora   return argArray;
3438ae074b1SFabian Mora }
3448ae074b1SFabian Mora 
3458ae074b1SFabian Mora // Emits LLVM IR to launch a kernel function:
3468ae074b1SFabian Mora // %0 = call %binarygetter
3478ae074b1SFabian Mora // %1 = call %moduleLoad(%0)
3488ae074b1SFabian Mora // %2 = <see generateKernelNameConstant>
3498ae074b1SFabian Mora // %3 = call %moduleGetFunction(%1, %2)
3508ae074b1SFabian Mora // %4 = call %streamCreate()
3518ae074b1SFabian Mora // %5 = <see generateParamsArray>
3528ae074b1SFabian Mora // call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
3538ae074b1SFabian Mora // call %streamSynchronize(%4)
3548ae074b1SFabian Mora // call %streamDestroy(%4)
3558ae074b1SFabian Mora // call %moduleUnload(%1)
356db791b27SRamkumar Ramachandra llvm::LogicalResult
3575093413aSFabian Mora llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
3585093413aSFabian Mora                                        mlir::gpu::ObjectAttr object) {
3598ae074b1SFabian Mora   auto llvmValue = [&](mlir::Value value) -> Value * {
3608ae074b1SFabian Mora     Value *v = moduleTranslation.lookupValue(value);
3618ae074b1SFabian Mora     assert(v && "Value has not been translated.");
3628ae074b1SFabian Mora     return v;
3638ae074b1SFabian Mora   };
3648ae074b1SFabian Mora 
3658ae074b1SFabian Mora   // Get grid dimensions.
3668ae074b1SFabian Mora   mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues();
3678ae074b1SFabian Mora   Value *gx = llvmValue(grid.x), *gy = llvmValue(grid.y),
3688ae074b1SFabian Mora         *gz = llvmValue(grid.z);
3698ae074b1SFabian Mora 
3708ae074b1SFabian Mora   // Get block dimensions.
3718ae074b1SFabian Mora   mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues();
3728ae074b1SFabian Mora   Value *bx = llvmValue(block.x), *by = llvmValue(block.y),
3738ae074b1SFabian Mora         *bz = llvmValue(block.z);
3748ae074b1SFabian Mora 
3758ae074b1SFabian Mora   // Get dynamic shared memory size.
3768ae074b1SFabian Mora   Value *dynamicMemorySize = nullptr;
3778ae074b1SFabian Mora   if (mlir::Value dynSz = op.getDynamicSharedMemorySize())
3788ae074b1SFabian Mora     dynamicMemorySize = llvmValue(dynSz);
3798ae074b1SFabian Mora   else
3808ae074b1SFabian Mora     dynamicMemorySize = ConstantInt::get(i32Ty, 0);
3818ae074b1SFabian Mora 
3828ae074b1SFabian Mora   // Create the argument array.
3838ae074b1SFabian Mora   Value *argArray = createKernelArgArray(op);
3848ae074b1SFabian Mora 
3855093413aSFabian Mora   // Default JIT optimization level.
3865093413aSFabian Mora   llvm::Constant *optV = llvm::ConstantInt::get(i32Ty, 0);
3875093413aSFabian Mora   // Check if there's an optimization level embedded in the object.
3885093413aSFabian Mora   DictionaryAttr objectProps = object.getProperties();
3895093413aSFabian Mora   mlir::Attribute optAttr;
3905093413aSFabian Mora   if (objectProps && (optAttr = objectProps.get("O"))) {
3915093413aSFabian Mora     auto optLevel = dyn_cast<IntegerAttr>(optAttr);
3925093413aSFabian Mora     if (!optLevel)
3935093413aSFabian Mora       return op.emitError("the optimization level must be an integer");
3945093413aSFabian Mora     optV = llvm::ConstantInt::get(i32Ty, optLevel.getValue());
3955093413aSFabian Mora   }
3965093413aSFabian Mora 
3978ae074b1SFabian Mora   // Load the kernel module.
3988ae074b1SFabian Mora   StringRef moduleName = op.getKernelModuleName().getValue();
3998ae074b1SFabian Mora   std::string binaryIdentifier = getBinaryIdentifier(moduleName);
4008ae074b1SFabian Mora   Value *binary = module.getGlobalVariable(binaryIdentifier, true);
4018ae074b1SFabian Mora   if (!binary)
4028ae074b1SFabian Mora     return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;
4035093413aSFabian Mora 
4047fc792cbSSang Ik Lee   auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
4057fc792cbSSang Ik Lee   if (!binaryVar)
4067fc792cbSSang Ik Lee     return op.emitError() << "Binary is not a global variable: "
4077fc792cbSSang Ik Lee                           << binaryIdentifier;
4087fc792cbSSang Ik Lee   llvm::Constant *binaryInit = binaryVar->getInitializer();
4097fc792cbSSang Ik Lee   auto binaryDataSeq =
4107fc792cbSSang Ik Lee       dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
4117fc792cbSSang Ik Lee   if (!binaryDataSeq)
4127fc792cbSSang Ik Lee     return op.emitError() << "Couldn't find binary data array: "
4137fc792cbSSang Ik Lee                           << binaryIdentifier;
4147fc792cbSSang Ik Lee   llvm::Constant *binarySize =
4157fc792cbSSang Ik Lee       llvm::ConstantInt::get(i64Ty, binaryDataSeq->getNumElements() *
4167fc792cbSSang Ik Lee                                         binaryDataSeq->getElementByteSize());
4177fc792cbSSang Ik Lee 
4185093413aSFabian Mora   Value *moduleObject =
4195093413aSFabian Mora       object.getFormat() == gpu::CompilationTarget::Assembly
4205093413aSFabian Mora           ? builder.CreateCall(getModuleLoadJITFn(), {binary, optV})
4217fc792cbSSang Ik Lee           : builder.CreateCall(getModuleLoadFn(), {binary, binarySize});
4228ae074b1SFabian Mora 
4238ae074b1SFabian Mora   // Load the kernel function.
4248ae074b1SFabian Mora   Value *moduleFunction = builder.CreateCall(
4258ae074b1SFabian Mora       getModuleFunctionFn(),
4268ae074b1SFabian Mora       {moduleObject,
4278ae074b1SFabian Mora        getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});
4288ae074b1SFabian Mora 
4298ae074b1SFabian Mora   // Get the stream to use for execution. If there's no async object then create
4308ae074b1SFabian Mora   // a stream to make a synchronous kernel launch.
4318ae074b1SFabian Mora   Value *stream = nullptr;
4328ae074b1SFabian Mora   bool handleStream = false;
4338ae074b1SFabian Mora   if (mlir::Value asyncObject = op.getAsyncObject()) {
4348ae074b1SFabian Mora     stream = llvmValue(asyncObject);
4358ae074b1SFabian Mora   } else {
4368ae074b1SFabian Mora     handleStream = true;
4378ae074b1SFabian Mora     stream = builder.CreateCall(getStreamCreateFn(), {});
4388ae074b1SFabian Mora   }
4398ae074b1SFabian Mora 
4407fc792cbSSang Ik Lee   llvm::Constant *paramsCount =
4417fc792cbSSang Ik Lee       llvm::ConstantInt::get(i64Ty, op.getNumKernelOperands());
4427fc792cbSSang Ik Lee 
4438ae074b1SFabian Mora   // Create the launch call.
4448ae074b1SFabian Mora   Value *nullPtr = ConstantPointerNull::get(ptrTy);
445edf5cae7SGuray Ozen 
446edf5cae7SGuray Ozen   // Launch kernel with clusters if cluster size is specified.
447edf5cae7SGuray Ozen   if (op.hasClusterSize()) {
448edf5cae7SGuray Ozen     mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
449edf5cae7SGuray Ozen     Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
450edf5cae7SGuray Ozen           *cz = llvmValue(cluster.z);
451edf5cae7SGuray Ozen     builder.CreateCall(
452edf5cae7SGuray Ozen         getClusterKernelLaunchFn(),
453edf5cae7SGuray Ozen         ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
454edf5cae7SGuray Ozen                            dynamicMemorySize, stream, argArray, nullPtr}));
455edf5cae7SGuray Ozen   } else {
4567fc792cbSSang Ik Lee     builder.CreateCall(getKernelLaunchFn(),
4577fc792cbSSang Ik Lee                        ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
4587fc792cbSSang Ik Lee                                           bz, dynamicMemorySize, stream,
4597fc792cbSSang Ik Lee                                           argArray, nullPtr, paramsCount}));
460edf5cae7SGuray Ozen   }
4618ae074b1SFabian Mora 
4628ae074b1SFabian Mora   // Sync & destroy the stream, for synchronous launches.
4638ae074b1SFabian Mora   if (handleStream) {
4648ae074b1SFabian Mora     builder.CreateCall(getStreamSyncFn(), {stream});
4658ae074b1SFabian Mora     builder.CreateCall(getStreamDestroyFn(), {stream});
4668ae074b1SFabian Mora   }
4678ae074b1SFabian Mora 
4688ae074b1SFabian Mora   // Unload the kernel module.
4698ae074b1SFabian Mora   builder.CreateCall(getModuleUnloadFn(), {moduleObject});
4708ae074b1SFabian Mora 
4718ae074b1SFabian Mora   return success();
4728ae074b1SFabian Mora }
473