//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate ROCDLIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Returns true if the given `gpu.func` can be safely called using the bare
/// pointer calling convention.
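/// This is the case iff every memref argument can itself be converted to a
/// bare pointer, i.e. it has a static shape and an identity layout.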
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
  bool canBeBare = true;
  for (Type type : func.getArgumentTypes())
    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
      canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
  return canBeBare;
}

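/// Computes the lane id of the current thread by counting, via
/// `mbcnt.lo`/`mbcnt.hi`, the lanes with a lower id than the current one.
/// The result is always an i32; the `indexBitwidth` parameter is currently
/// unused.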
static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
                       const unsigned indexBitwidth) {
  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
  Value zero = rewriter.createOrFold<arith::ConstantIntOp>(loc, 0, 32);
  Value minus1 = rewriter.createOrFold<arith::ConstantIntOp>(loc, -1, 32);
  Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type,
                                                    ValueRange{minus1, zero});
  Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type,
                                                   ValueRange{minus1, mbcntLo});
  return laneId;
}

namespace {
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    // Convert to:  %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
    // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)

    Type intTy = IntegerType::get(context, 32);
    Value zero = rewriter.createOrFold<arith::ConstantIntOp>(loc, 0, 32);
    Value minus1 = rewriter.createOrFold<arith::ConstantIntOp>(loc, -1, 32);
    Value mbcntLo =
        rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
    Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
        loc, intTy, ValueRange{minus1, mbcntLo});
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      laneId = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), laneId);
    } else if (indexBitwidth < 32) {
      laneId = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), laneId);
    }
    rewriter.replaceOp(op, {laneId});
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding ROCDL ops.
  ///
  /// The `width` argument is used to check whether the source lane
  /// participates in the shuffle; if it does not, the destination lane is the
  /// lane itself.
  ///
  ///  Shuffle with DS Bpermute:
  ///   let shflMode = [xor, up, down, idx]
  ///   let width = 32 (usually the warp size), step = [1, 2, 4, 8, 16, ..., width].
  ///   1. curLaneId = mbcnt.hi(-1, mbcnt.lo(-1, 0))
  ///   2. widthOrZeroIfOutside = (curLaneId + width) & -width
  ///   3. dstLane = shflMode(curLaneId, step)
  ///   4. isActiveSrcLane = dstLane < widthOrZeroIfOutside
  ///   5. dstLane = isActiveSrcLane ? dstLane : curLaneId
  ///   6. dwordAlignedDstLane = dstLane * 4, i.e. dstLane << 2.
  ///   7. bpermute(dwordAlignedDstLane, shfl_value).
  ///
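  ///  As a sketch (eliding the index-type extension and the f32 bitcasts), a
  ///  `gpu.shuffle xor %v, %offset, %width` thus lowers to roughly:
  ///    %lid = mbcnt.hi(-1, mbcnt.lo(-1, 0))
  ///    %dst = select (%lid ^ %offset) < ((%lid + %width) & -%width),
  ///                  %lid ^ %offset, %lid
  ///    %res = ds_bpermute(%dst << 2, %v)
  ///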
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // TODO: Add support for non-32-bit shuffle values.
    if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32)
      return failure();
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);

    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value width = adaptor.getWidth();
    Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
    Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
    Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
    Value widthOrZeroIfOutside =
        rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
    Value dstLane;
    // TODO: Add support for gpu::ShuffleMode::UP and gpu::ShuffleMode::DOWN.
    // TODO: Use ds_swizzle for XOR when step/offsets are constants for better
    // perf.
    switch (op.getMode()) {
    case gpu::ShuffleMode::XOR:
      dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
                                             adaptor.getOffset());
      break;
    case gpu::ShuffleMode::IDX:
      dstLane = adaptor.getOffset();
      break;
    default:
      return failure();
    }
    Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
        loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
    Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane,
                                                          dstLane, srcLaneId);
    Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
    Value dwordAlignedDstLane =
        rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
    Value initShflValue = adaptor.getValue();
    if (adaptor.getValue().getType().isF32()) {
      initShflValue =
          rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
    }
    Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
        loc, int32Type, dwordAlignedDstLane, initShflValue);
    if (adaptor.getValue().getType().isF32()) {
      shflValue = rewriter.create<LLVM::BitcastOp>(
          loc, adaptor.getValue().getType(), shflValue);
    }
    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    return success();
  }
};

/// Import the GPU Ops to ROCDL Patterns.
#include "GPUToROCDL.cpp.inc"

// A pass that replaces all occurrences of GPU device operations with their
// corresponding ROCDL equivalent.
//
// This pass only handles device code and is not meant to be run on GPU host
// code.
struct LowerGpuOpsToROCDLOpsPass
    : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
  LowerGpuOpsToROCDLOpsPass() = default;
  LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
                            bool useBarePtrCallConv,
                            gpu::amd::Runtime runtime) {
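    // Constructor arguments take effect only when the corresponding option
    // was not already set on the command line.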
    if (this->chipset.getNumOccurrences() == 0)
      this->chipset = chipset;
    if (this->indexBitwidth.getNumOccurrences() == 0)
      this->indexBitwidth = indexBitwidth;
    if (this->useBarePtrCallConv.getNumOccurrences() == 0)
      this->useBarePtrCallConv = useBarePtrCallConv;
    if (this->runtime.getNumOccurrences() == 0)
      this->runtime = runtime;
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();
    MLIRContext *ctx = m.getContext();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(ctx));
    }

    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "invalid chipset name: " + chipset);
      return signalPassFailure();
    }
    // Customize the bitwidth used for device-side index computations.
    LowerToLLVMOptions options(
        ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    if (useBarePtrCallConv) {
      options.useBarePtrCallConv = true;
      WalkResult canUseBarePointers =
          m.walk([](gpu::GPUFuncOp func) -> WalkResult {
            if (canBeCalledWithBarePointers(func))
              return WalkResult::advance();
            return WalkResult::interrupt();
          });
      if (canUseBarePointers.wasInterrupted()) {
        emitError(UnknownLoc::get(ctx),
                  "bare pointer calling convention requires all memrefs to "
                  "have static shape and use the identity map");
        return signalPassFailure();
      }
    }

    // Apply in-dialect lowering first. It replaces ops that must be lowered
    // in multiple steps, which a single conversion pass cannot handle.
    {
      RewritePatternSet patterns(ctx);
      populateGpuRewritePatterns(patterns);
      arith::populateExpandBFloat16Patterns(patterns);
      (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
    }

    LLVMTypeConverter converter(ctx, options);
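    // Map GPU memory spaces to AMDGPU address spaces: 1 is global, 3 is
    // workgroup (LDS), and 5 is private (scratch).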
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) {
          switch (space) {
          case gpu::AddressSpace::Global:
            return 1;
          case gpu::AddressSpace::Workgroup:
            return 3;
          case gpu::AddressSpace::Private:
            return 5;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });

    RewritePatternSet llvmPatterns(ctx);

    mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
    populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
                                            *maybeChipset);
    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
    LLVMConversionTarget target(getContext());
    configureGpuToROCDLConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();

    // Manually rewrite known block size attributes so the LLVMIR translation
    // infrastructure can pick them up.
    m.walk([ctx](LLVM::LLVMFuncOp op) {
      if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
              op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
        op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
                    blockSizes);
        // Also set up the rocdl.flat_work_group_size attribute to prevent
        // conflicting metadata.
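        // That attribute is a "min,max" string; using the product of the
        // block dimensions for both bounds pins the flat workgroup size.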
        uint32_t flatSize = 1;
        for (uint32_t size : blockSizes.asArrayRef()) {
          flatSize *= size;
        }
        StringAttr flatSizeAttr =
            StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
        op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
                    flatSizeAttr);
      }
    });
  }
};

} // namespace

void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<ROCDL::ROCDLDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
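  // These LLVM math ops are marked illegal so the conversion fails if any of
  // them survive; on AMDGPU the corresponding math-dialect ops are instead
  // lowered to OCML device-library calls (see
  // populateGpuToROCDLConversionPatterns below).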
  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
                      LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
                      LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
                      LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
}

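/// Registers two lowerings for `OpTy`: one that scalarizes vector uses of the
/// op, and one that rewrites the scalar op into a call to the given OCML
/// device-library function, selected by element type (f32 vs. f64).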
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
                               RewritePatternSet &patterns, StringRef f32Func,
                               StringRef f64Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
}

void mlir::populateGpuToROCDLConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    mlir::gpu::amd::Runtime runtime) {
  using mlir::gpu::amd::Runtime;

  populateWithGenerated(patterns);
  patterns
      .add<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
                                       ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
          converter, gpu::GPUFuncOp::getKnownBlockSizeAttrName());
  patterns.add<GPUIndexIntrinsicOpLowering<
      gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
      converter, gpu::GPUFuncOp::getKnownGridSizeAttrName());
  patterns
      .add<GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
                                       ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>,
           GPUIndexIntrinsicOpLowering<gpu::GridDimOp, ROCDL::GridDimXOp,
                                       ROCDL::GridDimYOp, ROCDL::GridDimZOp>,
           GPUReturnOpLowering>(converter);
  patterns.add<GPUFuncOpLowering>(
      converter,
      /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
      /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
      StringAttr::get(&converter.getContext(),
                      ROCDL::ROCDLDialect::getKernelFuncAttrName()));
  if (Runtime::HIP == runtime) {
    patterns.add<GPUPrintfOpToHIPLowering>(converter);
  } else if (Runtime::OpenCL == runtime) {
    // Use address space 4 to match the OpenCL definition of printf().
    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
  }
  // TODO: Add alignment for workgroup memory.
  patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);

  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);

  populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
                                   "__ocml_fabs_f64");
  populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
                                   "__ocml_atan_f64");
  populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
                                    "__ocml_atan2_f64");
  populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
                                   "__ocml_cbrt_f64");
  populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
                                   "__ocml_ceil_f64");
  populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
                                  "__ocml_cos_f64");
  populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
                                  "__ocml_exp_f64");
  populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
                                   "__ocml_exp2_f64");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
                                    "__ocml_expm1_f64");
  populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
                                    "__ocml_floor_f64");
  populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
                                    "__ocml_fmod_f64");
  populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
                                  "__ocml_log_f64");
  populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
                                    "__ocml_log10_f64");
  populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
                                    "__ocml_log1p_f64");
  populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
                                   "__ocml_log2_f64");
  populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
                                   "__ocml_pow_f64");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
                                    "__ocml_rsqrt_f64");
  populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
                                  "__ocml_sin_f64");
  populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
                                   "__ocml_sqrt_f64");
  populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
                                   "__ocml_tanh_f64");
  populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
                                  "__ocml_tan_f64");
  populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
                                  "__ocml_erf_f64");
}

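/// Creates the GPU-to-ROCDL lowering pass. A minimal usage sketch (the
/// surrounding pipeline setup and the option values shown are illustrative
/// assumptions, not requirements):
///
///   PassManager pm(ctx);
///   pm.addNestedPass<gpu::GPUModuleOp>(createLowerGpuOpsToROCDLOpsPass(
///       /*chipset=*/"gfx900", /*indexBitwidth=*/32,
///       /*useBarePtrCallConv=*/false, gpu::amd::Runtime::HIP));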
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
                                      unsigned indexBitwidth,
                                      bool useBarePtrCallConv,
                                      gpu::amd::Runtime runtime) {
  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
      chipset, indexBitwidth, useBarePtrCallConv, runtime);
}