xref: /llvm-project/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp (revision 4583f6d3443c8dc6605c868724e3743161954210)
1 //===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR NVVM dialect and
10 // LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
15 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
16 #include "mlir/Dialect/Utils/StaticValueUtils.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
19 
20 #include "llvm/IR/IRBuilder.h"
21 #include "llvm/IR/IntrinsicsNVPTX.h"
22 
23 using namespace mlir;
24 using namespace mlir::LLVM;
25 using mlir::LLVM::detail::createIntrinsicCall;
26 
27 static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
28                                                NVVM::ReduxKind kind) {
29   if (!resultType->isIntegerTy(32))
30     llvm_unreachable("unsupported data type for redux");
31 
32   switch (kind) {
33   case NVVM::ReduxKind::ADD:
34     return llvm::Intrinsic::nvvm_redux_sync_add;
35   case NVVM::ReduxKind::UMAX:
36     return llvm::Intrinsic::nvvm_redux_sync_umax;
37   case NVVM::ReduxKind::UMIN:
38     return llvm::Intrinsic::nvvm_redux_sync_umin;
39   case NVVM::ReduxKind::AND:
40     return llvm::Intrinsic::nvvm_redux_sync_and;
41   case NVVM::ReduxKind::OR:
42     return llvm::Intrinsic::nvvm_redux_sync_or;
43   case NVVM::ReduxKind::XOR:
44     return llvm::Intrinsic::nvvm_redux_sync_xor;
45   case NVVM::ReduxKind::MAX:
46     return llvm::Intrinsic::nvvm_redux_sync_max;
47   case NVVM::ReduxKind::MIN:
48     return llvm::Intrinsic::nvvm_redux_sync_min;
49   }
50   llvm_unreachable("unknown redux kind");
51 }
52 
53 static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
54                                               NVVM::ShflKind kind,
55                                               bool withPredicate) {
56 
57   if (withPredicate) {
58     resultType = cast<llvm::StructType>(resultType)->getElementType(0);
59     switch (kind) {
60     case NVVM::ShflKind::bfly:
61       return resultType->isFloatTy()
62                  ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
63                  : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
64     case NVVM::ShflKind::up:
65       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
66                                      : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
67     case NVVM::ShflKind::down:
68       return resultType->isFloatTy()
69                  ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
70                  : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
71     case NVVM::ShflKind::idx:
72       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
73                                      : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
74     }
75   } else {
76     switch (kind) {
77     case NVVM::ShflKind::bfly:
78       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
79                                      : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
80     case NVVM::ShflKind::up:
81       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
82                                      : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
83     case NVVM::ShflKind::down:
84       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
85                                      : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
86     case NVVM::ShflKind::idx:
87       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
88                                      : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
89     }
90   }
91   llvm_unreachable("unknown shuffle kind");
92 }
93 
94 /// Return the intrinsic ID associated with ldmatrix for the given paramters.
95 static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
96                                                   int32_t num) {
97   if (layout == NVVM::MMALayout::row) {
98     switch (num) {
99     case 1:
100       return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
101     case 2:
102       return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
103     case 4:
104       return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
105     default:
106       llvm_unreachable("unsupported number of matrix");
107     }
108 
109   } else {
110     switch (num) {
111     case 1:
112       return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
113     case 2:
114       return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
115     case 4:
116       return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
117     default:
118       llvm_unreachable("unsupported number of matrix");
119     }
120   }
121 }
122 
123 static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
124                                               NVVM::ProxyKind toProxy,
125                                               NVVM::MemScopeKind scope,
126                                               bool isRelease) {
127   if (fromProxy == NVVM::ProxyKind::GENERIC &&
128       toProxy == NVVM::ProxyKind::TENSORMAP) {
129     switch (scope) {
130     case NVVM::MemScopeKind::CTA: {
131       if (isRelease)
132         return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
133       return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
134     }
135     case NVVM::MemScopeKind::CLUSTER: {
136       if (isRelease)
137         return llvm::Intrinsic::
138             nvvm_fence_proxy_tensormap_generic_release_cluster;
139       return llvm::Intrinsic::
140           nvvm_fence_proxy_tensormap_generic_acquire_cluster;
141     }
142     case NVVM::MemScopeKind::GPU: {
143       if (isRelease)
144         return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
145       return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
146     }
147     case NVVM::MemScopeKind::SYS: {
148       if (isRelease)
149         return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
150       return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
151     }
152     }
153     llvm_unreachable("Unknown scope for uni-directional fence.proxy operation");
154   }
155   llvm_unreachable("Unsupported proxy kinds");
156 }
157 
158 namespace {
159 /// Implementation of the dialect interface that converts operations belonging
160 /// to the NVVM dialect to LLVM IR.
161 class NVVMDialectLLVMIRTranslationInterface
162     : public LLVMTranslationDialectInterface {
163 public:
164   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
165 
166   /// Translates the given operation to LLVM IR using the provided IR builder
167   /// and saving the state in `moduleTranslation`.
168   LogicalResult
169   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
170                    LLVM::ModuleTranslation &moduleTranslation) const final {
171     Operation &opInst = *op;
172 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
173 
174     return failure();
175   }
176 
177   /// Attaches module-level metadata for functions marked as kernels.
178   LogicalResult
179   amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
180                  NamedAttribute attribute,
181                  LLVM::ModuleTranslation &moduleTranslation) const final {
182     auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
183     if (!func)
184       return failure();
185     llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
186     llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
187 
188     auto generateMetadata = [&](int dim, StringRef name) {
189       llvm::Metadata *llvmMetadata[] = {
190           llvm::ValueAsMetadata::get(llvmFunc),
191           llvm::MDString::get(llvmContext, name),
192           llvm::ValueAsMetadata::get(llvm::ConstantInt::get(
193               llvm::Type::getInt32Ty(llvmContext), dim))};
194       llvm::MDNode *llvmMetadataNode =
195           llvm::MDNode::get(llvmContext, llvmMetadata);
196       moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
197           ->addOperand(llvmMetadataNode);
198     };
199     if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
200       if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
201         return failure();
202       auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
203       generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
204       if (values.size() > 1)
205         generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
206       if (values.size() > 2)
207         generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
208     } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
209       if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
210         return failure();
211       auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
212       generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
213       if (values.size() > 1)
214         generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
215       if (values.size() > 2)
216         generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
217     } else if (attribute.getName() ==
218                NVVM::NVVMDialect::getClusterDimAttrName()) {
219       if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
220         return failure();
221       auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
222       generateMetadata(values[0], NVVM::NVVMDialect::getClusterDimXName());
223       if (values.size() > 1)
224         generateMetadata(values[1], NVVM::NVVMDialect::getClusterDimYName());
225       if (values.size() > 2)
226         generateMetadata(values[2], NVVM::NVVMDialect::getClusterDimZName());
227     } else if (attribute.getName() ==
228                NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
229       auto value = dyn_cast<IntegerAttr>(attribute.getValue());
230       generateMetadata(value.getInt(), "cluster_max_blocks");
231     } else if (attribute.getName() ==
232                NVVM::NVVMDialect::getMinctasmAttrName()) {
233       auto value = dyn_cast<IntegerAttr>(attribute.getValue());
234       generateMetadata(value.getInt(), "minctasm");
235     } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
236       auto value = dyn_cast<IntegerAttr>(attribute.getValue());
237       generateMetadata(value.getInt(), "maxnreg");
238     } else if (attribute.getName() ==
239                NVVM::NVVMDialect::getKernelFuncAttrName()) {
240       llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
241     }
242     return success();
243   }
244 
245   LogicalResult
246   convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
247                        LLVM::ModuleTranslation &moduleTranslation) const final {
248 
249     llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
250     llvm::Function *llvmFunc =
251         moduleTranslation.lookupFunction(funcOp.getName());
252     llvm::NamedMDNode *nvvmAnnotations =
253         moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
254 
255     if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
256       llvm::MDNode *gridConstantMetaData = nullptr;
257 
258       // Check if a 'grid_constant' metadata node exists for the given function
259       for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
260         if (opnd->getNumOperands() == 3 &&
261             opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
262             opnd->getOperand(1) ==
263                 llvm::MDString::get(llvmContext, "grid_constant")) {
264           gridConstantMetaData = opnd;
265           break;
266         }
267       }
268 
269       // 'grid_constant' is a function-level meta data node with a list of
270       // integers, where each integer n denotes that the nth parameter has the
271       // grid_constant annotation (numbering from 1). This requires aggregating
272       // the indices of the individual parameters that have this attribute.
273       llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
274       if (gridConstantMetaData == nullptr) {
275         // Create a new 'grid_constant' metadata node
276         SmallVector<llvm::Metadata *> gridConstMetadata = {
277             llvm::ValueAsMetadata::getConstant(
278                 llvm::ConstantInt::get(i32, argIdx + 1))};
279         llvm::Metadata *llvmMetadata[] = {
280             llvm::ValueAsMetadata::get(llvmFunc),
281             llvm::MDString::get(llvmContext, "grid_constant"),
282             llvm::MDNode::get(llvmContext, gridConstMetadata)};
283         llvm::MDNode *llvmMetadataNode =
284             llvm::MDNode::get(llvmContext, llvmMetadata);
285         nvvmAnnotations->addOperand(llvmMetadataNode);
286       } else {
287         // Append argIdx + 1 to the 'grid_constant' argument list
288         if (auto argList =
289                 dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
290           llvm::TempMDTuple clonedArgList = argList->clone();
291           clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
292               llvm::ConstantInt::get(i32, argIdx + 1))));
293           gridConstantMetaData->replaceOperandWith(
294               2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
295         }
296       }
297     }
298     return success();
299   }
300 };
301 } // namespace
302 
303 void mlir::registerNVVMDialectTranslation(DialectRegistry &registry) {
304   registry.insert<NVVM::NVVMDialect>();
305   registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
306     dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
307   });
308 }
309 
310 void mlir::registerNVVMDialectTranslation(MLIRContext &context) {
311   DialectRegistry registry;
312   registerNVVMDialectTranslation(registry);
313   context.appendDialectRegistry(registry);
314 }
315