1 //===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation between the MLIR NVVM dialect and 10 // LLVM IR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" 15 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 16 #include "mlir/Dialect/Utils/StaticValueUtils.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 19 20 #include "llvm/IR/IRBuilder.h" 21 #include "llvm/IR/IntrinsicsNVPTX.h" 22 23 using namespace mlir; 24 using namespace mlir::LLVM; 25 using mlir::LLVM::detail::createIntrinsicCall; 26 27 static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, 28 NVVM::ReduxKind kind) { 29 if (!resultType->isIntegerTy(32)) 30 llvm_unreachable("unsupported data type for redux"); 31 32 switch (kind) { 33 case NVVM::ReduxKind::ADD: 34 return llvm::Intrinsic::nvvm_redux_sync_add; 35 case NVVM::ReduxKind::UMAX: 36 return llvm::Intrinsic::nvvm_redux_sync_umax; 37 case NVVM::ReduxKind::UMIN: 38 return llvm::Intrinsic::nvvm_redux_sync_umin; 39 case NVVM::ReduxKind::AND: 40 return llvm::Intrinsic::nvvm_redux_sync_and; 41 case NVVM::ReduxKind::OR: 42 return llvm::Intrinsic::nvvm_redux_sync_or; 43 case NVVM::ReduxKind::XOR: 44 return llvm::Intrinsic::nvvm_redux_sync_xor; 45 case NVVM::ReduxKind::MAX: 46 return llvm::Intrinsic::nvvm_redux_sync_max; 47 case NVVM::ReduxKind::MIN: 48 return llvm::Intrinsic::nvvm_redux_sync_min; 49 } 50 llvm_unreachable("unknown redux kind"); 51 } 52 53 static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, 54 NVVM::ShflKind kind, 55 bool withPredicate) { 56 57 if (withPredicate) { 58 resultType = cast<llvm::StructType>(resultType)->getElementType(0); 59 switch (kind) { 60 case NVVM::ShflKind::bfly: 61 return resultType->isFloatTy() 62 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p 63 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p; 64 case NVVM::ShflKind::up: 65 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p 66 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p; 67 case NVVM::ShflKind::down: 68 return resultType->isFloatTy() 69 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p 70 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p; 71 case NVVM::ShflKind::idx: 72 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p 73 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p; 74 } 75 } else { 76 switch (kind) { 77 case NVVM::ShflKind::bfly: 78 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 79 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32; 80 case NVVM::ShflKind::up: 81 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32 82 : llvm::Intrinsic::nvvm_shfl_sync_up_i32; 83 case NVVM::ShflKind::down: 84 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32 85 : llvm::Intrinsic::nvvm_shfl_sync_down_i32; 86 case NVVM::ShflKind::idx: 87 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32 88 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32; 89 } 90 } 91 llvm_unreachable("unknown shuffle kind"); 92 } 93 94 /// Return the intrinsic ID associated with ldmatrix for the given paramters. 95 static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, 96 int32_t num) { 97 if (layout == NVVM::MMALayout::row) { 98 switch (num) { 99 case 1: 100 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16; 101 case 2: 102 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16; 103 case 4: 104 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16; 105 default: 106 llvm_unreachable("unsupported number of matrix"); 107 } 108 109 } else { 110 switch (num) { 111 case 1: 112 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; 113 case 2: 114 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; 115 case 4: 116 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; 117 default: 118 llvm_unreachable("unsupported number of matrix"); 119 } 120 } 121 } 122 123 static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, 124 NVVM::ProxyKind toProxy, 125 NVVM::MemScopeKind scope, 126 bool isRelease) { 127 if (fromProxy == NVVM::ProxyKind::GENERIC && 128 toProxy == NVVM::ProxyKind::TENSORMAP) { 129 switch (scope) { 130 case NVVM::MemScopeKind::CTA: { 131 if (isRelease) 132 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta; 133 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta; 134 } 135 case NVVM::MemScopeKind::CLUSTER: { 136 if (isRelease) 137 return llvm::Intrinsic:: 138 nvvm_fence_proxy_tensormap_generic_release_cluster; 139 return llvm::Intrinsic:: 140 nvvm_fence_proxy_tensormap_generic_acquire_cluster; 141 } 142 case NVVM::MemScopeKind::GPU: { 143 if (isRelease) 144 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu; 145 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu; 146 } 147 case NVVM::MemScopeKind::SYS: { 148 if (isRelease) 149 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys; 150 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys; 151 } 152 } 153 llvm_unreachable("Unknown scope for uni-directional fence.proxy operation"); 154 } 155 llvm_unreachable("Unsupported proxy kinds"); 156 } 157 158 namespace { 159 /// Implementation of the dialect interface that converts operations belonging 160 /// to the NVVM dialect to LLVM IR. 161 class NVVMDialectLLVMIRTranslationInterface 162 : public LLVMTranslationDialectInterface { 163 public: 164 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; 165 166 /// Translates the given operation to LLVM IR using the provided IR builder 167 /// and saving the state in `moduleTranslation`. 168 LogicalResult 169 convertOperation(Operation *op, llvm::IRBuilderBase &builder, 170 LLVM::ModuleTranslation &moduleTranslation) const final { 171 Operation &opInst = *op; 172 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc" 173 174 return failure(); 175 } 176 177 /// Attaches module-level metadata for functions marked as kernels. 178 LogicalResult 179 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions, 180 NamedAttribute attribute, 181 LLVM::ModuleTranslation &moduleTranslation) const final { 182 auto func = dyn_cast<LLVM::LLVMFuncOp>(op); 183 if (!func) 184 return failure(); 185 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); 186 llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName()); 187 188 auto generateMetadata = [&](int dim, StringRef name) { 189 llvm::Metadata *llvmMetadata[] = { 190 llvm::ValueAsMetadata::get(llvmFunc), 191 llvm::MDString::get(llvmContext, name), 192 llvm::ValueAsMetadata::get(llvm::ConstantInt::get( 193 llvm::Type::getInt32Ty(llvmContext), dim))}; 194 llvm::MDNode *llvmMetadataNode = 195 llvm::MDNode::get(llvmContext, llvmMetadata); 196 moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations") 197 ->addOperand(llvmMetadataNode); 198 }; 199 if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) { 200 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue())) 201 return failure(); 202 auto values = cast<DenseI32ArrayAttr>(attribute.getValue()); 203 generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName()); 204 if (values.size() > 1) 205 generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName()); 206 if (values.size() > 2) 207 generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName()); 208 } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) { 209 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue())) 210 return failure(); 211 auto values = cast<DenseI32ArrayAttr>(attribute.getValue()); 212 generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName()); 213 if (values.size() > 1) 214 generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName()); 215 if (values.size() > 2) 216 generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName()); 217 } else if (attribute.getName() == 218 NVVM::NVVMDialect::getClusterDimAttrName()) { 219 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue())) 220 return failure(); 221 auto values = cast<DenseI32ArrayAttr>(attribute.getValue()); 222 generateMetadata(values[0], NVVM::NVVMDialect::getClusterDimXName()); 223 if (values.size() > 1) 224 generateMetadata(values[1], NVVM::NVVMDialect::getClusterDimYName()); 225 if (values.size() > 2) 226 generateMetadata(values[2], NVVM::NVVMDialect::getClusterDimZName()); 227 } else if (attribute.getName() == 228 NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) { 229 auto value = dyn_cast<IntegerAttr>(attribute.getValue()); 230 generateMetadata(value.getInt(), "cluster_max_blocks"); 231 } else if (attribute.getName() == 232 NVVM::NVVMDialect::getMinctasmAttrName()) { 233 auto value = dyn_cast<IntegerAttr>(attribute.getValue()); 234 generateMetadata(value.getInt(), "minctasm"); 235 } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) { 236 auto value = dyn_cast<IntegerAttr>(attribute.getValue()); 237 generateMetadata(value.getInt(), "maxnreg"); 238 } else if (attribute.getName() == 239 NVVM::NVVMDialect::getKernelFuncAttrName()) { 240 llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel); 241 } 242 return success(); 243 } 244 245 LogicalResult 246 convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute, 247 LLVM::ModuleTranslation &moduleTranslation) const final { 248 249 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); 250 llvm::Function *llvmFunc = 251 moduleTranslation.lookupFunction(funcOp.getName()); 252 llvm::NamedMDNode *nvvmAnnotations = 253 moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations"); 254 255 if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) { 256 llvm::MDNode *gridConstantMetaData = nullptr; 257 258 // Check if a 'grid_constant' metadata node exists for the given function 259 for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) { 260 if (opnd->getNumOperands() == 3 && 261 opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) && 262 opnd->getOperand(1) == 263 llvm::MDString::get(llvmContext, "grid_constant")) { 264 gridConstantMetaData = opnd; 265 break; 266 } 267 } 268 269 // 'grid_constant' is a function-level meta data node with a list of 270 // integers, where each integer n denotes that the nth parameter has the 271 // grid_constant annotation (numbering from 1). This requires aggregating 272 // the indices of the individual parameters that have this attribute. 273 llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32); 274 if (gridConstantMetaData == nullptr) { 275 // Create a new 'grid_constant' metadata node 276 SmallVector<llvm::Metadata *> gridConstMetadata = { 277 llvm::ValueAsMetadata::getConstant( 278 llvm::ConstantInt::get(i32, argIdx + 1))}; 279 llvm::Metadata *llvmMetadata[] = { 280 llvm::ValueAsMetadata::get(llvmFunc), 281 llvm::MDString::get(llvmContext, "grid_constant"), 282 llvm::MDNode::get(llvmContext, gridConstMetadata)}; 283 llvm::MDNode *llvmMetadataNode = 284 llvm::MDNode::get(llvmContext, llvmMetadata); 285 nvvmAnnotations->addOperand(llvmMetadataNode); 286 } else { 287 // Append argIdx + 1 to the 'grid_constant' argument list 288 if (auto argList = 289 dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) { 290 llvm::TempMDTuple clonedArgList = argList->clone(); 291 clonedArgList->push_back((llvm::ValueAsMetadata::getConstant( 292 llvm::ConstantInt::get(i32, argIdx + 1)))); 293 gridConstantMetaData->replaceOperandWith( 294 2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList))); 295 } 296 } 297 } 298 return success(); 299 } 300 }; 301 } // namespace 302 303 void mlir::registerNVVMDialectTranslation(DialectRegistry ®istry) { 304 registry.insert<NVVM::NVVMDialect>(); 305 registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) { 306 dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>(); 307 }); 308 } 309 310 void mlir::registerNVVMDialectTranslation(MLIRContext &context) { 311 DialectRegistry registry; 312 registerNVVMDialectTranslation(registry); 313 context.appendDialectRegistry(registry); 314 } 315