1 //===- ObjectHandler.cpp - Implements base ObjectManager attributes -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the `OffloadingLLVMTranslationAttrInterface` for the 10 // `SelectObject` attribute. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" 15 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 16 17 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" 18 #include "mlir/Target/LLVMIR/Export.h" 19 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 20 21 #include "llvm/IR/Constants.h" 22 #include "llvm/IR/IRBuilder.h" 23 #include "llvm/IR/LLVMContext.h" 24 #include "llvm/IR/Module.h" 25 #include "llvm/Support/FormatVariadic.h" 26 27 using namespace mlir; 28 29 namespace { 30 // Implementation of the `OffloadingLLVMTranslationAttrInterface` model. 31 class SelectObjectAttrImpl 32 : public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel< 33 SelectObjectAttrImpl> { 34 public: 35 // Translates a `gpu.binary`, embedding the binary into a host LLVM module as 36 // global binary string. 37 LogicalResult embedBinary(Attribute attribute, Operation *operation, 38 llvm::IRBuilderBase &builder, 39 LLVM::ModuleTranslation &moduleTranslation) const; 40 41 // Translates a `gpu.launch_func` to a sequence of LLVM instructions resulting 42 // in a kernel launch call. 43 LogicalResult launchKernel(Attribute attribute, 44 Operation *launchFuncOperation, 45 Operation *binaryOperation, 46 llvm::IRBuilderBase &builder, 47 LLVM::ModuleTranslation &moduleTranslation) const; 48 49 // Returns the selected object for embedding. 50 gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const; 51 }; 52 // Returns an identifier for the global string holding the binary. 53 std::string getBinaryIdentifier(StringRef binaryName) { 54 return binaryName.str() + "_bin_cst"; 55 } 56 } // namespace 57 58 void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels( 59 DialectRegistry ®istry) { 60 registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) { 61 SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx); 62 }); 63 } 64 65 gpu::ObjectAttr 66 SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const { 67 ArrayRef<Attribute> objects = op.getObjectsAttr().getValue(); 68 69 // Obtain the index of the object to select. 70 int64_t index = -1; 71 if (Attribute target = 72 cast<gpu::SelectObjectAttr>(op.getOffloadingHandlerAttr()) 73 .getTarget()) { 74 // If the target attribute is a number it is the index. Otherwise compare 75 // the attribute to every target inside the object array to find the index. 76 if (auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) { 77 index = indexAttr.getInt(); 78 } else { 79 for (auto [i, attr] : llvm::enumerate(objects)) { 80 auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr); 81 if (obj.getTarget() == target) { 82 index = i; 83 } 84 } 85 } 86 } else { 87 // If the target attribute is null then it's selecting the first object in 88 // the object array. 89 index = 0; 90 } 91 92 if (index < 0 || index >= static_cast<int64_t>(objects.size())) { 93 op->emitError("the requested target object couldn't be found"); 94 return nullptr; 95 } 96 return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]); 97 } 98 99 LogicalResult SelectObjectAttrImpl::embedBinary( 100 Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder, 101 LLVM::ModuleTranslation &moduleTranslation) const { 102 assert(operation && "The binary operation must be non null."); 103 if (!operation) 104 return failure(); 105 106 auto op = mlir::dyn_cast<gpu::BinaryOp>(operation); 107 if (!op) { 108 operation->emitError("operation must be a GPU binary"); 109 return failure(); 110 } 111 112 gpu::ObjectAttr object = getSelectedObject(op); 113 if (!object) 114 return failure(); 115 116 llvm::Module *module = moduleTranslation.getLLVMModule(); 117 118 // Embed the object as a global string. 119 llvm::Constant *binary = llvm::ConstantDataArray::getString( 120 builder.getContext(), object.getObject().getValue(), false); 121 llvm::GlobalVariable *serializedObj = 122 new llvm::GlobalVariable(*module, binary->getType(), true, 123 llvm::GlobalValue::LinkageTypes::InternalLinkage, 124 binary, getBinaryIdentifier(op.getName())); 125 126 if (object.getProperties()) { 127 if (auto section = mlir::dyn_cast_or_null<mlir::StringAttr>( 128 object.getProperties().get(gpu::elfSectionName))) { 129 serializedObj->setSection(section.getValue()); 130 } 131 } 132 serializedObj->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage); 133 serializedObj->setAlignment(llvm::MaybeAlign(8)); 134 serializedObj->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::None); 135 return success(); 136 } 137 138 namespace llvm { 139 namespace { 140 class LaunchKernel { 141 public: 142 LaunchKernel(Module &module, IRBuilderBase &builder, 143 mlir::LLVM::ModuleTranslation &moduleTranslation); 144 // Get the kernel launch callee. 145 FunctionCallee getKernelLaunchFn(); 146 147 // Get the kernel launch callee. 148 FunctionCallee getClusterKernelLaunchFn(); 149 150 // Get the module function callee. 151 FunctionCallee getModuleFunctionFn(); 152 153 // Get the module load callee. 154 FunctionCallee getModuleLoadFn(); 155 156 // Get the module load JIT callee. 157 FunctionCallee getModuleLoadJITFn(); 158 159 // Get the module unload callee. 160 FunctionCallee getModuleUnloadFn(); 161 162 // Get the stream create callee. 163 FunctionCallee getStreamCreateFn(); 164 165 // Get the stream destroy callee. 166 FunctionCallee getStreamDestroyFn(); 167 168 // Get the stream sync callee. 169 FunctionCallee getStreamSyncFn(); 170 171 // Ger or create the function name global string. 172 Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName); 173 174 // Create the void* kernel array for passing the arguments. 175 Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op); 176 177 // Create the full kernel launch. 178 llvm::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op, 179 mlir::gpu::ObjectAttr object); 180 181 private: 182 Module &module; 183 IRBuilderBase &builder; 184 mlir::LLVM::ModuleTranslation &moduleTranslation; 185 Type *i32Ty{}; 186 Type *i64Ty{}; 187 Type *voidTy{}; 188 Type *intPtrTy{}; 189 PointerType *ptrTy{}; 190 }; 191 } // namespace 192 } // namespace llvm 193 194 LogicalResult SelectObjectAttrImpl::launchKernel( 195 Attribute attribute, Operation *launchFuncOperation, 196 Operation *binaryOperation, llvm::IRBuilderBase &builder, 197 LLVM::ModuleTranslation &moduleTranslation) const { 198 199 assert(launchFuncOperation && "The launch func operation must be non null."); 200 if (!launchFuncOperation) 201 return failure(); 202 203 auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation); 204 if (!launchFuncOp) { 205 launchFuncOperation->emitError("operation must be a GPU launch func Op."); 206 return failure(); 207 } 208 209 auto binOp = mlir::dyn_cast<gpu::BinaryOp>(binaryOperation); 210 if (!binOp) { 211 binaryOperation->emitError("operation must be a GPU binary."); 212 return failure(); 213 } 214 gpu::ObjectAttr object = getSelectedObject(binOp); 215 if (!object) 216 return failure(); 217 218 return llvm::LaunchKernel(*moduleTranslation.getLLVMModule(), builder, 219 moduleTranslation) 220 .createKernelLaunch(launchFuncOp, object); 221 } 222 223 llvm::LaunchKernel::LaunchKernel( 224 Module &module, IRBuilderBase &builder, 225 mlir::LLVM::ModuleTranslation &moduleTranslation) 226 : module(module), builder(builder), moduleTranslation(moduleTranslation) { 227 i32Ty = builder.getInt32Ty(); 228 i64Ty = builder.getInt64Ty(); 229 ptrTy = builder.getPtrTy(0); 230 voidTy = builder.getVoidTy(); 231 intPtrTy = builder.getIntPtrTy(module.getDataLayout()); 232 } 233 234 llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() { 235 return module.getOrInsertFunction( 236 "mgpuLaunchKernel", 237 FunctionType::get(voidTy, 238 ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, 239 intPtrTy, intPtrTy, intPtrTy, i32Ty, 240 ptrTy, ptrTy, ptrTy, i64Ty}), 241 false)); 242 } 243 244 llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() { 245 return module.getOrInsertFunction( 246 "mgpuLaunchClusterKernel", 247 FunctionType::get( 248 voidTy, 249 ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy, 250 intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy, 251 i32Ty, ptrTy, ptrTy, ptrTy}), 252 false)); 253 } 254 255 llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() { 256 return module.getOrInsertFunction( 257 "mgpuModuleGetFunction", 258 FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false)); 259 } 260 261 llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() { 262 return module.getOrInsertFunction( 263 "mgpuModuleLoad", 264 FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i64Ty}), false)); 265 } 266 267 llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn() { 268 return module.getOrInsertFunction( 269 "mgpuModuleLoadJIT", 270 FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i32Ty}), false)); 271 } 272 273 llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() { 274 return module.getOrInsertFunction( 275 "mgpuModuleUnload", 276 FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false)); 277 } 278 279 llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() { 280 return module.getOrInsertFunction("mgpuStreamCreate", 281 FunctionType::get(ptrTy, false)); 282 } 283 284 llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() { 285 return module.getOrInsertFunction( 286 "mgpuStreamDestroy", 287 FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false)); 288 } 289 290 llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() { 291 return module.getOrInsertFunction( 292 "mgpuStreamSynchronize", 293 FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false)); 294 } 295 296 // Generates an LLVM IR dialect global that contains the name of the given 297 // kernel function as a C string, and returns a pointer to its beginning. 298 llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName, 299 StringRef kernelName) { 300 std::string globalName = 301 std::string(formatv("{0}_{1}_kernel_name", moduleName, kernelName)); 302 303 if (GlobalVariable *gv = module.getGlobalVariable(globalName)) 304 return gv; 305 306 return builder.CreateGlobalString(kernelName, globalName); 307 } 308 309 // Creates a struct containing all kernel parameters on the stack and returns 310 // an array of type-erased pointers to the fields of the struct. The array can 311 // then be passed to the CUDA / ROCm (HIP) kernel launch calls. 312 // The generated code is essentially as follows: 313 // 314 // %struct = alloca(sizeof(struct { Parameters... })) 315 // %array = alloca(NumParameters * sizeof(void *)) 316 // for (i : [0, NumParameters)) 317 // %fieldPtr = llvm.getelementptr %struct[0, i] 318 // llvm.store parameters[i], %fieldPtr 319 // %elementPtr = llvm.getelementptr %array[i] 320 // llvm.store %fieldPtr, %elementPtr 321 // return %array 322 llvm::Value * 323 llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) { 324 SmallVector<Value *> args = 325 moduleTranslation.lookupValues(op.getKernelOperands()); 326 SmallVector<Type *> structTypes(args.size(), nullptr); 327 328 for (auto [i, arg] : llvm::enumerate(args)) 329 structTypes[i] = arg->getType(); 330 331 Type *structTy = StructType::create(module.getContext(), structTypes); 332 Value *argStruct = builder.CreateAlloca(structTy, 0u); 333 Value *argArray = builder.CreateAlloca( 334 ptrTy, ConstantInt::get(intPtrTy, structTypes.size())); 335 336 for (auto [i, arg] : enumerate(args)) { 337 Value *structMember = builder.CreateStructGEP(structTy, argStruct, i); 338 builder.CreateStore(arg, structMember); 339 Value *arrayMember = builder.CreateConstGEP1_32(ptrTy, argArray, i); 340 builder.CreateStore(structMember, arrayMember); 341 } 342 return argArray; 343 } 344 345 // Emits LLVM IR to launch a kernel function: 346 // %0 = call %binarygetter 347 // %1 = call %moduleLoad(%0) 348 // %2 = <see generateKernelNameConstant> 349 // %3 = call %moduleGetFunction(%1, %2) 350 // %4 = call %streamCreate() 351 // %5 = <see generateParamsArray> 352 // call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr) 353 // call %streamSynchronize(%4) 354 // call %streamDestroy(%4) 355 // call %moduleUnload(%1) 356 llvm::LogicalResult 357 llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op, 358 mlir::gpu::ObjectAttr object) { 359 auto llvmValue = [&](mlir::Value value) -> Value * { 360 Value *v = moduleTranslation.lookupValue(value); 361 assert(v && "Value has not been translated."); 362 return v; 363 }; 364 365 // Get grid dimensions. 366 mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues(); 367 Value *gx = llvmValue(grid.x), *gy = llvmValue(grid.y), 368 *gz = llvmValue(grid.z); 369 370 // Get block dimensions. 371 mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues(); 372 Value *bx = llvmValue(block.x), *by = llvmValue(block.y), 373 *bz = llvmValue(block.z); 374 375 // Get dynamic shared memory size. 376 Value *dynamicMemorySize = nullptr; 377 if (mlir::Value dynSz = op.getDynamicSharedMemorySize()) 378 dynamicMemorySize = llvmValue(dynSz); 379 else 380 dynamicMemorySize = ConstantInt::get(i32Ty, 0); 381 382 // Create the argument array. 383 Value *argArray = createKernelArgArray(op); 384 385 // Default JIT optimization level. 386 llvm::Constant *optV = llvm::ConstantInt::get(i32Ty, 0); 387 // Check if there's an optimization level embedded in the object. 388 DictionaryAttr objectProps = object.getProperties(); 389 mlir::Attribute optAttr; 390 if (objectProps && (optAttr = objectProps.get("O"))) { 391 auto optLevel = dyn_cast<IntegerAttr>(optAttr); 392 if (!optLevel) 393 return op.emitError("the optimization level must be an integer"); 394 optV = llvm::ConstantInt::get(i32Ty, optLevel.getValue()); 395 } 396 397 // Load the kernel module. 398 StringRef moduleName = op.getKernelModuleName().getValue(); 399 std::string binaryIdentifier = getBinaryIdentifier(moduleName); 400 Value *binary = module.getGlobalVariable(binaryIdentifier, true); 401 if (!binary) 402 return op.emitError() << "Couldn't find the binary: " << binaryIdentifier; 403 404 auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary); 405 if (!binaryVar) 406 return op.emitError() << "Binary is not a global variable: " 407 << binaryIdentifier; 408 llvm::Constant *binaryInit = binaryVar->getInitializer(); 409 auto binaryDataSeq = 410 dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit); 411 if (!binaryDataSeq) 412 return op.emitError() << "Couldn't find binary data array: " 413 << binaryIdentifier; 414 llvm::Constant *binarySize = 415 llvm::ConstantInt::get(i64Ty, binaryDataSeq->getNumElements() * 416 binaryDataSeq->getElementByteSize()); 417 418 Value *moduleObject = 419 object.getFormat() == gpu::CompilationTarget::Assembly 420 ? builder.CreateCall(getModuleLoadJITFn(), {binary, optV}) 421 : builder.CreateCall(getModuleLoadFn(), {binary, binarySize}); 422 423 // Load the kernel function. 424 Value *moduleFunction = builder.CreateCall( 425 getModuleFunctionFn(), 426 {moduleObject, 427 getOrCreateFunctionName(moduleName, op.getKernelName().getValue())}); 428 429 // Get the stream to use for execution. If there's no async object then create 430 // a stream to make a synchronous kernel launch. 431 Value *stream = nullptr; 432 bool handleStream = false; 433 if (mlir::Value asyncObject = op.getAsyncObject()) { 434 stream = llvmValue(asyncObject); 435 } else { 436 handleStream = true; 437 stream = builder.CreateCall(getStreamCreateFn(), {}); 438 } 439 440 llvm::Constant *paramsCount = 441 llvm::ConstantInt::get(i64Ty, op.getNumKernelOperands()); 442 443 // Create the launch call. 444 Value *nullPtr = ConstantPointerNull::get(ptrTy); 445 446 // Launch kernel with clusters if cluster size is specified. 447 if (op.hasClusterSize()) { 448 mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues(); 449 Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y), 450 *cz = llvmValue(cluster.z); 451 builder.CreateCall( 452 getClusterKernelLaunchFn(), 453 ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz, 454 dynamicMemorySize, stream, argArray, nullPtr})); 455 } else { 456 builder.CreateCall(getKernelLaunchFn(), 457 ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by, 458 bz, dynamicMemorySize, stream, 459 argArray, nullPtr, paramsCount})); 460 } 461 462 // Sync & destroy the stream, for synchronous launches. 463 if (handleStream) { 464 builder.CreateCall(getStreamSyncFn(), {stream}); 465 builder.CreateCall(getStreamDestroyFn(), {stream}); 466 } 467 468 // Unload the kernel module. 469 builder.CreateCall(getModuleUnloadFn(), {moduleObject}); 470 471 return success(); 472 } 473