//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate NVVM IR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}

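/// Maps a gpu reduction operation to the matching NVVM redux kind, or returns
/// std::nullopt for operations redux cannot express (mul, unsigned min/max,
/// and the NaN-propagating minimumf/maximumf).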
static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}

/// Lowers the gpu.subgroup_reduce op to the nvvm.redux op. The op must be run
/// by the entire subgroup, otherwise the behavior is undefined.
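///
/// For example (a sketch; the exact printed assembly of nvvm.redux.sync may
/// differ):
///
///     %sum = gpu.subgroup_reduce add %x uniform : (i32) -> i32
///
/// becomes roughly
///
///     %mask = llvm.mlir.constant(-1 : i32) : i32
///     %sum = nvvm.redux.sync add %x, %mask : i32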
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);

    auto reduxOp = rewriter.create<NVVM::ReduxOp>(loc, int32Type, op.getValue(),
                                                  mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding NVVM op.
  ///
  /// Converts the `width` argument into an activeMask (a bitmask which
  /// specifies which threads participate in the shuffle) and a maskAndClamp
  /// (specifying the highest lane which participates in the shuffle).
  ///
  ///     %one = llvm.mlir.constant(1 : i32) : i32
  ///     %minus_one = llvm.mlir.constant(-1 : i32) : i32
  ///     %thirty_two = llvm.mlir.constant(32 : i32) : i32
  ///     %num_lanes = llvm.sub %thirty_two, %width : i32
  ///     %active_mask = llvm.lshr %minus_one, %num_lanes : i32
  ///     %mask_and_clamp = llvm.sub %width, %one : i32
  ///     %shfl = nvvm.shfl.sync bfly %active_mask, %value, %offset,
  ///         %mask_and_clamp : !llvm.struct<(f32, i1)>
  ///     %shfl_value = llvm.extractvalue %shfl[0] : !llvm.struct<(f32, i1)>
  ///     %shfl_pred = llvm.extractvalue %shfl[1] : !llvm.struct<(f32, i1)>
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
    Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
    Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
    Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
        loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
                                                     numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`
      maskAndClamp =
          rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = rewriter.create<NVVM::ShflOp>(
        loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
        maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
      Value isActiveSrcLane =
          rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};

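/// Lowers gpu.lane_id to NVVM's laneid special-register read, attaching a
/// result range of [0, kWarpSize) (or the op's own upper bound, if set) and
/// then resizing the i32 result to the configured index bitwidth.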
struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    LLVM::ConstantRangeAttr bounds = nullptr;
    if (std::optional<APInt> upperBound = op.getUpperBound())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
    else
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
    Value newOp =
        rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};

/// Lowering of cf.assert into a conditional __assertfail.
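///
/// CUDA's device-side __assertfail is assumed here to have the signature
///
///   void __assertfail(const char *message, const char *file, unsigned line,
///                     const char *function, size_t charSize);
///
/// which is what the {ptr, ptr, i32, ptr, i64} function type below encodes.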
struct AssertOpToAssertfailLowering
    : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = rewriter.getContext();
    Location loc = assertOp.getLoc();
    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
    Type ptrType = LLVM::LLVMPointerType::get(ctx);
    Type voidType = LLVM::LLVMVoidType::get(ctx);

    // Find or create __assertfail function declaration.
    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
    auto assertfailType = LLVM::LLVMFunctionType::get(
        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
        moduleOp, loc, rewriter, "__assertfail", assertfailType);
    assertfailDecl.setPassthroughAttr(
        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));

    // Split blocks and insert conditional branch.
    // ^before:
    //   ...
    //   cf.cond_br %condition, ^after, ^assert
    // ^assert:
    //   cf.assert
    //   cf.br ^after
    // ^after:
    //   ...
    Block *beforeBlock = assertOp->getBlock();
    Block *assertBlock =
        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
    Block *afterBlock =
        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
    rewriter.setInsertionPointToEnd(beforeBlock);
    rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
                                      assertBlock);
    rewriter.setInsertionPointToEnd(assertBlock);
    rewriter.create<cf::BranchOp>(loc, afterBlock);

    // Continue cf.assert lowering.
    rewriter.setInsertionPoint(assertOp);

    // Populate the file name, line number, and function name from the
    // location of the AssertOp.
    StringRef fileName = "(unknown)";
    StringRef funcName = "(unknown)";
    int32_t fileLine = 0;
    while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
      loc = callSiteLoc.getCallee();
    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
      fileName = fileLineColLoc.getFilename().strref();
      fileLine = fileLineColLoc.getStartLine();
    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
      funcName = nameLoc.getName().strref();
      if (auto fileLineColLoc =
              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
        fileName = fileLineColLoc.getFilename().strref();
        fileLine = fileLineColLoc.getStartLine();
      }
    }

    // Create constants.
    auto getGlobal = [&](LLVM::GlobalOp global) {
      // Get a pointer to the global string's first element.
      Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
          loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
          global.getSymNameAttr());
      Value start =
          rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
                                       globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
      return start;
    };
    Value assertMessage = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
    Value assertFile = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
    Value assertFunc = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
    Value assertLine =
        rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
    Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);

    // Insert function call to __assertfail.
    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
                                 assertFunc, c1};
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
                                              arguments);
    return success();
  }
};

/// Import the generated GPU to NVVM patterns.
#include "GPUToNVVM.cpp.inc"

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
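///
/// Typically invoked as `mlir-opt --convert-gpu-to-nvvm`; the pass anchors on
/// gpu.module ops.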
struct LowerGpuOpsToNVVMOpsPass
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering first. It replaces ops with other ops that
    // themselves still require lowering, something a single conversion pass
    // does not support.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      if (failed(applyPatternsGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }

    LLVMTypeConverter converter(m.getContext(), options);
    configureGpuToNVVMTypeConverter(converter);
    RewritePatternSet llvmPatterns(m.getContext());

    arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
    if (this->hasRedux)
      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
    LLVMConversionTarget target(getContext());
    configureGpuToNVVMConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
  }
};

} // namespace

void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
                      LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
                      LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
                      LLVM::SinOp, LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}

void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  populateGpuMemorySpaceAttributeConversions(
      converter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
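  // After this, a workgroup memref such as
  //   memref<16xf32, #gpu.address_space<workgroup>>
  // lowers to a memref descriptor whose pointers are !llvm.ptr<3>.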
  // Lowering for MMAMatrixType.
  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
    return convertMMAToLLVMType(type);
  });
}

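/// Appends patterns that lower OpTy to a libdevice call: vector operands are
/// scalarized first, then the scalar op is rewritten into a call to
/// `f32Func`/`f64Func` (with optional fast-math f32 and f16 variants).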
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
                               RewritePatternSet &patterns, StringRef f32Func,
                               StringRef f64Func, StringRef f32ApproxFunc = "",
                               StringRef f16Func = "") {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
                                           f32ApproxFunc, f16Func);
}

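/// Variant of the above for integer ops that map to a single i32 libdevice
/// function.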
template <typename OpTy>
static void populateIntOpPatterns(const LLVMTypeConverter &converter,
                                  RewritePatternSet &patterns,
                                  StringRef i32Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func);
}

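/// Variant for ops that mix a float operand with an integer operand
/// (e.g. math.fpowi), mapping to f32/f64 libdevice functions.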
template <typename OpTy>
static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
                                       RewritePatternSet &patterns,
                                       StringRef f32Func, StringRef f64Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "");
}

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter);
}

void mlir::populateGpuToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;
  populateWithGenerated(patterns);
  patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
      converter);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
      converter, IndexKind::Other, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim);
  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
      converter);

  patterns.add<GPUDynamicSharedMemoryOpLowering>(
      converter, NVVM::kSharedMemoryAlignmentBit);

  // Explicitly drop memory space when lowering private memory
  // attributions since NVVM models it as `alloca`s in the default
  // memory space and does not support `alloca`s with addrspace(5).
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/
          static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName())});

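  // Map the remaining math/arith ops onto libdevice calls. For example (a
  // sketch):
  //   %y = math.exp %x : f32
  // becomes roughly
  //   %y = llvm.call @__nv_expf(%x) : (f32) -> f32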
  populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
                                    "__nv_fmod");
  populateIntOpPatterns<math::AbsIOp>(converter, patterns, "__nv_abs");
  populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
                                   "__nv_acos");
  populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf",
                                    "__nv_acosh");
  populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf",
                                   "__nv_asin");
  populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf",
                                    "__nv_asinh");
  populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf",
                                    "__nv_atanh");
  populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf",
                                       "__nv_copysign");
  populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos",
                                  "__nv_fast_cosf");
  populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf",
                                   "__nv_cosh");
  populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
  populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp",
                                  "__nv_fast_expf");
  populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
  populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
                                  "__nv_fast_logf");
  populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
                                    "__nv_log10", "__nv_fast_log10f");
  populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
                                   "__nv_log2", "__nv_fast_log2f");
  populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
                                   "__nv_fast_powf");
  populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, "__nv_powif",
                                            "__nv_powi");
  populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
                                    "__nv_round");
  populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
                                        "__nv_rint");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin",
                                  "__nv_fast_sinf");
  populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf",
                                   "__nv_sinh");
  populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan",
                                  "__nv_fast_tanf");
  populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
                                   "__nv_tanh");
}

//===----------------------------------------------------------------------===//
// NVVMTargetAttr convert to LLVM attr interface
//===----------------------------------------------------------------------===//

namespace {
struct NVVMTargetConvertToLLVMAttrInterface
    : public ConvertToLLVMAttrInterface::ExternalModel<
          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
  /// Configure GPU to NVVM.
  void populateConvertToLLVMConversionPatterns(
      Attribute attr, ConversionTarget &target,
      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
};
} // namespace

void NVVMTargetConvertToLLVMAttrInterface::
    populateConvertToLLVMConversionPatterns(Attribute attr,
                                            ConversionTarget &target,
                                            LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns) const {
  configureGpuToNVVMConversionLegality(target);
  configureGpuToNVVMTypeConverter(typeConverter);
  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
}

void mlir::NVVM::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
    NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
  });
}