//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate ROCDLIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Returns true if the given `gpu.func` can be safely called using the bare
/// pointer calling convention.
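/// For instance, a `memref<16x16xf32>` argument (static shape, identity
/// layout) can be passed as a bare pointer, whereas a `memref<?xf32>` cannot;
/// see LLVMTypeConverter::canConvertToBarePtr for the precise criteria.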
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
  bool canBeBare = true;
  for (Type type : func.getArgumentTypes())
    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
      canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
  return canBeBare;
}

/// Returns the current lane id as an i32 value, computed with the
/// `mbcnt.lo`/`mbcnt.hi` intrinsics. The `indexBitwidth` parameter is
/// currently unused.
static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
                       const unsigned indexBitwidth) {
  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
  Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
  Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
  Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type,
                                                    ValueRange{minus1, zero});
  Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type,
                                                   ValueRange{minus1, mbcntLo});
  return laneId;
}
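
// Data layout string for the amdgcn target. In brief: "e" means
// little-endian; generic and global pointers (p, p1) are 64-bit while
// workgroup (p3) and private (p5) pointers are 32-bit; "A5" places allocas in
// address space 5; "G1" makes address space 1 the default for globals; and
// "ni:7:8:9" marks address spaces 7-9 as non-integral.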
static constexpr StringLiteral amdgcnDataLayout =
    "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
    "-p7:160:256:256:32-p8:128:128-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:"
    "32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:"
    "64-S32-A5-G1-ni:7:8:9";

namespace {
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    // Convert to:  %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
    // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)

    Type intTy = IntegerType::get(context, 32);
    Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
    Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
    Value mbcntLo =
        rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
    Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
        loc, intTy, ValueRange{minus1, mbcntLo});
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      laneId = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), laneId);
    } else if (indexBitwidth < 32) {
      laneId = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), laneId);
    }
    rewriter.replaceOp(op, {laneId});
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding ROCDL ops.
  ///
  /// The `width` argument is used to check whether the source lane is
  /// participating; if it is not, the destination lane is the lane itself.
  ///
  ///  Shuffle with DS Bpermute:
  ///   let shflMode = [xor, up, down, idx]
  ///   let width = 32 (usually the warp size), step = [1, 2, 4, 8, 16, ..., width].
  ///   1. curLaneId = using mbcnt.lo + mbcnt.hi
  ///   2. widthOrZeroIfOutside = (curLaneId + width) & -width
  ///   3. dstLane = shflMode(curLaneId, step)
  ///   4. isActiveSrcLane = dstLane < widthOrZeroIfOutside
  ///   5. dstLane = isActiveSrcLane ? dstLane : curLaneId
  ///   6. dwordAlignedDstLane = dstLane * 4, i.e. dstLane << 2
  ///   7. bpermute(dwordAlignedDstLane, shfl_value)
  ///
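  /// For illustration, a shuffle such as
  ///   %shfl, %pred = gpu.shuffle xor %val, %offset, %width : f32
  /// becomes a lane-id computation via mbcnt, the active-lane check above, and
  /// a `rocdl.ds_bpermute` of the dword-aligned destination lane, with f32
  /// values bitcast to i32 around the bpermute.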
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // TODO: Add support for non 32-bit shuffle values.
    if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32)
      return failure();
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);

    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value width = adaptor.getWidth();
    Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
    Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
    Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
    Value widthOrZeroIfOutside =
        rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
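    // widthOrZeroIfOutside is one past the last lane of the current lane's
    // group of `width` lanes (step 2 above); the source lane participates
    // only when the destination lane falls below this bound.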
    Value dstLane;
    // TODO: Add support for gpu::ShuffleMode::UP and gpu::ShuffleMode::DOWN.
    // TODO: Use ds_swizzle for XOR when step/offsets are constants for better
    // perf.
    switch (op.getMode()) {
    case gpu::ShuffleMode::XOR:
      dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
                                             adaptor.getOffset());
      break;
    case gpu::ShuffleMode::IDX:
      dstLane = adaptor.getOffset();
      break;
    default:
      return failure();
    }
    Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
        loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
    Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane,
                                                          dstLane, srcLaneId);
    Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
    Value dwordAlignedDstLane =
        rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
    Value initShflValue = adaptor.getValue();
    if (adaptor.getValue().getType().isF32()) {
      initShflValue =
          rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
    }
    Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
        loc, int32Type, dwordAlignedDstLane, initShflValue);
    if (adaptor.getValue().getType().isF32()) {
      shflValue = rewriter.create<LLVM::BitcastOp>(
          loc, adaptor.getValue().getType(), shflValue);
    }
    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    return success();
  }
};

/// Import the GPU Ops to ROCDL Patterns.
#include "GPUToROCDL.cpp.inc"

// A pass that replaces all occurrences of GPU device operations with their
// corresponding ROCDL equivalents.
//
// This pass only handles device code and is not meant to be run on GPU host
// code.
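//
// Typical invocation (option names as registered for this conversion pass):
//   mlir-opt --convert-gpu-to-rocdl='chipset=gfx908 index-bitwidth=32'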
struct LowerGpuOpsToROCDLOpsPass
    : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
  LowerGpuOpsToROCDLOpsPass() = default;
  LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
                            bool useBarePtrCallConv,
                            gpu::amd::Runtime runtime) {
    if (this->chipset.getNumOccurrences() == 0)
      this->chipset = chipset;
    if (this->indexBitwidth.getNumOccurrences() == 0)
      this->indexBitwidth = indexBitwidth;
    if (this->useBarePtrCallConv.getNumOccurrences() == 0)
      this->useBarePtrCallConv = useBarePtrCallConv;
    if (this->runtime.getNumOccurrences() == 0)
      this->runtime = runtime;
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();
    MLIRContext *ctx = m.getContext();

    auto llvmDataLayout = m->getAttrOfType<StringAttr>(
        LLVM::LLVMDialect::getDataLayoutAttrName());
    if (!llvmDataLayout) {
      llvmDataLayout = StringAttr::get(ctx, amdgcnDataLayout);
      m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
    }
    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(ctx));
    }

    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
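    // Keep the lowering options' data layout in sync with the LLVM data
    // layout attribute set on the module above.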
    options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    if (useBarePtrCallConv) {
      options.useBarePtrCallConv = true;
      WalkResult canUseBarePointers =
          m.walk([](gpu::GPUFuncOp func) -> WalkResult {
            if (canBeCalledWithBarePointers(func))
              return WalkResult::advance();
            return WalkResult::interrupt();
          });
      if (canUseBarePointers.wasInterrupted()) {
        emitError(UnknownLoc::get(ctx),
                  "bare pointer calling convention requires all memrefs to "
                  "have static shape and use the identity map");
        return signalPassFailure();
      }
    }

    // Apply in-dialect lowering first. It rewrites ops into other ops that
    // themselves require further lowering, a multi-step flow that a single
    // conversion pass cannot express.
    {
      RewritePatternSet patterns(ctx);
      populateGpuRewritePatterns(patterns);
      arith::populateExpandBFloat16Patterns(patterns);
      (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
    }

    LLVMTypeConverter converter(ctx, options);
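    // Map GPU memory spaces to the AMDGPU address-space numbering: 1 for
    // global, 3 for workgroup (LDS), and 5 for private (scratch) memory.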
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) {
          switch (space) {
          case gpu::AddressSpace::Global:
            return 1;
          case gpu::AddressSpace::Workgroup:
            return 3;
          case gpu::AddressSpace::Private:
            return 5;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });

    RewritePatternSet llvmPatterns(ctx);

    mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
    populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
                                            *maybeChipset);
    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
    populateMathToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
    LLVMConversionTarget target(getContext());
    configureGpuToROCDLConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
    auto reqdWorkGroupSizeAttrHelper =
        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
    auto flatWorkGroupSizeAttrHelper =
        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
    // Manually rewrite known block size attributes so the LLVMIR translation
    // infrastructure can pick them up.
    m.walk([&](LLVM::LLVMFuncOp op) {
      if (reqdWorkGroupSizeAttrHelper.isAttrPresent(op)) {
        auto blockSizes = reqdWorkGroupSizeAttrHelper.getAttr(op);
        // Also set up the rocdl.flat_work_group_size attribute to prevent
        // conflicting metadata.
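        // For example, a reqd_work_group_size of [64, 2, 1] yields the
        // flat_work_group_size string "128,128".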
        uint32_t flatSize = 1;
        for (uint32_t size : blockSizes.asArrayRef()) {
          flatSize *= size;
        }
        StringAttr flatSizeAttr =
            StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
        flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
      }
    });
  }
};

} // namespace

void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<ROCDL::ROCDLDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
                      LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
                      LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
  // Exp and Log remain legal when they operate on f32; other element types
  // must be rewritten into function calls.
  target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
    return any_of(op->getOperandTypes(), llvm::IsaPred<Float32Type>);
  });
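  // Keeping these intrinsic-style ops illegal steers math lowering toward the
  // ROCDL/device-library patterns rather than LLVM intrinsics the AMDGPU
  // backend may not implement; the f32 exp/log exception above reflects
  // native hardware support.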
  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}

template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
                               RewritePatternSet &patterns, StringRef f32Func,
                               StringRef f64Func, StringRef f32ApproxFunc,
                               StringRef f16Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f32ApproxFunc,
                                           f16Func);
}

void mlir::populateGpuToROCDLConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    mlir::gpu::amd::Runtime runtime) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;
  using mlir::gpu::amd::Runtime;
  auto *rocdlDialect =
      converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
  populateWithGenerated(patterns);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
                                      ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
                                      ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, ROCDL::GridDimXOp, ROCDL::GridDimYOp, ROCDL::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim);
  patterns.add<GPUReturnOpLowering>(converter);
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
          /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
          rocdlDialect->getKernelAttrHelper().getName(),
          rocdlDialect->getReqdWorkGroupSizeAttrHelper().getName()});
  if (Runtime::HIP == runtime) {
    patterns.add<GPUPrintfOpToHIPLowering>(converter);
  } else if (Runtime::OpenCL == runtime) {
    // Use address space = 4 to match the OpenCL definition of printf().
    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
  }
  // TODO: Add alignment for workgroup memory.
  patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);

  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);

  populateMathToROCDLConversionPatterns(converter, patterns);
}

std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
                                      unsigned indexBitwidth,
                                      bool useBarePtrCallConv,
                                      gpu::amd::Runtime runtime) {
  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
      chipset, indexBitwidth, useBarePtrCallConv, runtime);
}