1 //===- ROCDLToLLVMIRTranslation.cpp - Translate ROCDL to LLVM IR ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation between the MLIR ROCDL dialect and 10 // LLVM IR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" 15 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" 16 #include "mlir/IR/BuiltinAttributes.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 19 20 #include "llvm/IR/ConstantRange.h" 21 #include "llvm/IR/IRBuilder.h" 22 #include "llvm/IR/IntrinsicsAMDGPU.h" 23 #include "llvm/Support/raw_ostream.h" 24 25 using namespace mlir; 26 using namespace mlir::LLVM; 27 using mlir::LLVM::detail::createIntrinsicCall; 28 29 // Create a call to ROCm-Device-Library function that returns an ID. 30 // This is intended to specifically call device functions that fetch things like 31 // block or grid dimensions, and so is limited to functions that take one 32 // integer parameter. 33 static llvm::Value *createDimGetterFunctionCall(llvm::IRBuilderBase &builder, 34 Operation *op, StringRef fnName, 35 int parameter) { 36 llvm::Module *module = builder.GetInsertBlock()->getModule(); 37 llvm::FunctionType *functionType = llvm::FunctionType::get( 38 llvm::Type::getInt64Ty(module->getContext()), // return type. 39 llvm::Type::getInt32Ty(module->getContext()), // parameter type. 40 false); // no variadic arguments. 41 llvm::Function *fn = dyn_cast<llvm::Function>( 42 module->getOrInsertFunction(fnName, functionType).getCallee()); 43 llvm::Value *fnOp0 = llvm::ConstantInt::get( 44 llvm::Type::getInt32Ty(module->getContext()), parameter); 45 auto *call = builder.CreateCall(fn, ArrayRef<llvm::Value *>(fnOp0)); 46 if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) { 47 // Zero-extend to 64 bits because the GPU dialect uses 32-bit bounds but 48 // these ockl functions are defined to be 64-bits 49 call->addRangeRetAttr(llvm::ConstantRange(rangeAttr.getLower().zext(64), 50 rangeAttr.getUpper().zext(64))); 51 } 52 return call; 53 } 54 55 namespace { 56 /// Implementation of the dialect interface that converts operations belonging 57 /// to the ROCDL dialect to LLVM IR. 58 class ROCDLDialectLLVMIRTranslationInterface 59 : public LLVMTranslationDialectInterface { 60 public: 61 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; 62 63 /// Translates the given operation to LLVM IR using the provided IR builder 64 /// and saving the state in `moduleTranslation`. 65 LogicalResult 66 convertOperation(Operation *op, llvm::IRBuilderBase &builder, 67 LLVM::ModuleTranslation &moduleTranslation) const final { 68 Operation &opInst = *op; 69 #include "mlir/Dialect/LLVMIR/ROCDLConversions.inc" 70 71 return failure(); 72 } 73 74 /// Attaches module-level metadata for functions marked as kernels. 75 LogicalResult 76 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions, 77 NamedAttribute attribute, 78 LLVM::ModuleTranslation &moduleTranslation) const final { 79 auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect()); 80 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); 81 if (dialect->getKernelAttrHelper().getName() == attribute.getName()) { 82 auto func = dyn_cast<LLVM::LLVMFuncOp>(op); 83 if (!func) 84 return op->emitOpError(Twine(attribute.getName()) + 85 " is only supported on `llvm.func` operations"); 86 ; 87 88 // For GPU kernels, 89 // 1. Insert AMDGPU_KERNEL calling convention. 90 // 2. Insert amdgpu-flat-work-group-size(1, 256) attribute unless the user 91 // has overriden this value - 256 is the default in clang 92 llvm::Function *llvmFunc = 93 moduleTranslation.lookupFunction(func.getName()); 94 llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); 95 if (!llvmFunc->hasFnAttribute("amdgpu-flat-work-group-size")) { 96 llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1,256"); 97 } 98 99 // MLIR's GPU kernel APIs all assume and produce uniformly-sized 100 // workgroups, so the lowering of the `rocdl.kernel` marker encodes this 101 // assumption. This assumption may be overridden by setting 102 // `rocdl.uniform_work_group_size` on a given function. 103 if (!llvmFunc->hasFnAttribute("uniform-work-group-size")) 104 llvmFunc->addFnAttr("uniform-work-group-size", "true"); 105 } 106 // Override flat-work-group-size 107 // TODO: update clients to rocdl.flat_work_group_size instead, 108 // then remove this half of the branch 109 if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() == 110 attribute.getName()) { 111 auto func = dyn_cast<LLVM::LLVMFuncOp>(op); 112 if (!func) 113 return op->emitOpError(Twine(attribute.getName()) + 114 " is only supported on `llvm.func` operations"); 115 auto value = dyn_cast<IntegerAttr>(attribute.getValue()); 116 if (!value) 117 return op->emitOpError(Twine(attribute.getName()) + 118 " must be an integer"); 119 120 llvm::Function *llvmFunc = 121 moduleTranslation.lookupFunction(func.getName()); 122 llvm::SmallString<8> llvmAttrValue; 123 llvm::raw_svector_ostream attrValueStream(llvmAttrValue); 124 attrValueStream << "1," << value.getInt(); 125 llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue); 126 } 127 if (dialect->getWavesPerEuAttrHelper().getName() == attribute.getName()) { 128 auto func = dyn_cast<LLVM::LLVMFuncOp>(op); 129 if (!func) 130 return op->emitOpError(Twine(attribute.getName()) + 131 " is only supported on `llvm.func` operations"); 132 auto value = dyn_cast<IntegerAttr>(attribute.getValue()); 133 if (!value) 134 return op->emitOpError(Twine(attribute.getName()) + 135 " must be an integer"); 136 137 llvm::Function *llvmFunc = 138 moduleTranslation.lookupFunction(func.getName()); 139 llvm::SmallString<8> llvmAttrValue; 140 llvm::raw_svector_ostream attrValueStream(llvmAttrValue); 141 attrValueStream << value.getInt(); 142 llvmFunc->addFnAttr("amdgpu-waves-per-eu", llvmAttrValue); 143 } 144 if (dialect->getFlatWorkGroupSizeAttrHelper().getName() == 145 attribute.getName()) { 146 auto func = dyn_cast<LLVM::LLVMFuncOp>(op); 147 if (!func) 148 return op->emitOpError(Twine(attribute.getName()) + 149 " is only supported on `llvm.func` operations"); 150 auto value = dyn_cast<StringAttr>(attribute.getValue()); 151 if (!value) 152 return op->emitOpError(Twine(attribute.getName()) + 153 " must be a string"); 154 155 llvm::Function *llvmFunc = 156 moduleTranslation.lookupFunction(func.getName()); 157 llvm::SmallString<8> llvmAttrValue; 158 llvmAttrValue.append(value.getValue()); 159 llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue); 160 } 161 if (ROCDL::ROCDLDialect::getUniformWorkGroupSizeAttrName() == 162 attribute.getName()) { 163 auto func = dyn_cast<LLVM::LLVMFuncOp>(op); 164 if (!func) 165 return op->emitOpError(Twine(attribute.getName()) + 166 " is only supported on `llvm.func` operations"); 167 auto value = dyn_cast<BoolAttr>(attribute.getValue()); 168 if (!value) 169 return op->emitOpError(Twine(attribute.getName()) + 170 " must be a boolean"); 171 llvm::Function *llvmFunc = 172 moduleTranslation.lookupFunction(func.getName()); 173 llvmFunc->addFnAttr("uniform-work-group-size", 174 value.getValue() ? "true" : "false"); 175 } 176 if (dialect->getUnsafeFpAtomicsAttrHelper().getName() == 177 attribute.getName()) { 178 auto func = dyn_cast<LLVM::LLVMFuncOp>(op); 179 if (!func) 180 return op->emitOpError(Twine(attribute.getName()) + 181 " is only supported on `llvm.func` operations"); 182 auto value = dyn_cast<BoolAttr>(attribute.getValue()); 183 if (!value) 184 return op->emitOpError(Twine(attribute.getName()) + 185 " must be a boolean"); 186 llvm::Function *llvmFunc = 187 moduleTranslation.lookupFunction(func.getName()); 188 llvmFunc->addFnAttr("amdgpu-unsafe-fp-atomics", 189 value.getValue() ? "true" : "false"); 190 } 191 // Set reqd_work_group_size metadata 192 if (dialect->getReqdWorkGroupSizeAttrHelper().getName() == 193 attribute.getName()) { 194 auto func = dyn_cast<LLVM::LLVMFuncOp>(op); 195 if (!func) 196 return op->emitOpError(Twine(attribute.getName()) + 197 " is only supported on `llvm.func` operations"); 198 auto value = dyn_cast<DenseI32ArrayAttr>(attribute.getValue()); 199 if (!value) 200 return op->emitOpError(Twine(attribute.getName()) + 201 " must be a dense i32 array attribute"); 202 SmallVector<llvm::Metadata *, 3> metadata; 203 llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32); 204 for (int32_t i : value.asArrayRef()) { 205 llvm::Constant *constant = llvm::ConstantInt::get(i32, i); 206 metadata.push_back(llvm::ConstantAsMetadata::get(constant)); 207 } 208 llvm::Function *llvmFunc = 209 moduleTranslation.lookupFunction(func.getName()); 210 llvm::MDNode *node = llvm::MDNode::get(llvmContext, metadata); 211 llvmFunc->setMetadata("reqd_work_group_size", node); 212 } 213 214 // Atomic and nontemporal metadata 215 if (dialect->getLastUseAttrHelper().getName() == attribute.getName()) { 216 for (llvm::Instruction *i : instructions) 217 i->setMetadata("amdgpu.last.use", llvm::MDNode::get(llvmContext, {})); 218 } 219 if (dialect->getNoRemoteMemoryAttrHelper().getName() == 220 attribute.getName()) { 221 for (llvm::Instruction *i : instructions) 222 i->setMetadata("amdgpu.no.remote.memory", 223 llvm::MDNode::get(llvmContext, {})); 224 } 225 if (dialect->getNoFineGrainedMemoryAttrHelper().getName() == 226 attribute.getName()) { 227 for (llvm::Instruction *i : instructions) 228 i->setMetadata("amdgpu.no.fine.grained.memory", 229 llvm::MDNode::get(llvmContext, {})); 230 } 231 if (dialect->getIgnoreDenormalModeAttrHelper().getName() == 232 attribute.getName()) { 233 for (llvm::Instruction *i : instructions) 234 i->setMetadata("amdgpu.ignore.denormal.mode", 235 llvm::MDNode::get(llvmContext, {})); 236 } 237 238 return success(); 239 } 240 }; 241 } // namespace 242 243 void mlir::registerROCDLDialectTranslation(DialectRegistry ®istry) { 244 registry.insert<ROCDL::ROCDLDialect>(); 245 registry.addExtension(+[](MLIRContext *ctx, ROCDL::ROCDLDialect *dialect) { 246 dialect->addInterfaces<ROCDLDialectLLVMIRTranslationInterface>(); 247 }); 248 } 249 250 void mlir::registerROCDLDialectTranslation(MLIRContext &context) { 251 DialectRegistry registry; 252 registerROCDLDialectTranslation(registry); 253 context.appendDialectRegistry(registry); 254 } 255