//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate ROCDLIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
42 #include "mlir/IR/BuiltinAttributes.h"
43 #include "mlir/Pass/Pass.h"
44 #include "mlir/Transforms/DialectConversion.h"
45 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
46 #include "llvm/Support/FormatVariadic.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Returns true if the given `gpu.func` can be safely called using the bare
/// pointer calling convention.
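/// For example, a `memref<16x16xf32>` argument can be passed as a bare
/// pointer, whereas a dynamically shaped `memref<?xf32>` cannot.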
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
  bool canBeBare = true;
  for (Type type : func.getArgumentTypes())
    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
      canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
  return canBeBare;
}

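/// Returns the lane id of the current workitem: with an all-ones mask, the
/// chained `llvm.amdgcn.mbcnt.lo` / `llvm.amdgcn.mbcnt.hi` intrinsics count
/// the lanes below the current one, which is exactly the lane index within
/// the wavefront.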
static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc) {
  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
  Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
  Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
  Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type,
                                                    ValueRange{minus1, zero});
  Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type,
                                                   ValueRange{minus1, mbcntLo});
  return laneId;
}

static constexpr StringLiteral amdgcnDataLayout =
    "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
    "-p7:160:256:256:32-p8:128:128-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:"
    "32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:"
    "64-S32-A5-G1-ni:7:8:9";

namespace {
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    // Convert to:  %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
    // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)
    Value laneId = getLaneId(rewriter, loc);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      laneId = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), laneId);
    } else if (indexBitwidth < 32) {
      laneId = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), laneId);
    }
    rewriter.replaceOp(op, {laneId});
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding ROCDL ops.
  ///
  /// Use the `width` argument to determine whether the source lane
  /// participates; if it does not, the destination lane is the lane itself.
  ///
  ///  Shuffle with DS Bpermute:
  ///   let shflMode = [xor, up, down, idx]
  ///   let width = 32 (usually the warp size), step = [1, 2, 4, 8, 16, ..., width].
  ///   1. curLaneId = using mbcnt.lo + mbcnt.hi
  ///   2. widthOrZeroIfOutside = (curLaneId + width) & -width
  ///   3. dstLane = shflMode(curLaneId, step)
  ///   4. isActiveSrcLane = dstLane < widthOrZeroIfOutside
  ///   5. dstLane = isActiveSrcLane ? dstLane : curLaneId
  ///   6. dwordAlignedDstLane = dstLane * 4 or dstLane << 2.
  ///   7. bpermute(dwordAlignedDstLane, shfl_value).
  ///
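  /// A rough sketch of the IR this produces for `gpu.shuffle xor` (SSA names
  /// are illustrative only):
  ///
  ///   %lid  = rocdl.mbcnt.hi %minus1, %mlo : i32
  ///   %dst  = llvm.xor %lid, %offset : i32
  ///   %sel  = llvm.select %active, %dst, %lid : i1, i32
  ///   %dst4 = llvm.shl %sel, %two : i32      // dword-align the lane id
  ///   %res  = rocdl.ds_bpermute %dst4, %val : (i32, i32) -> i32
  ///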
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // TODO: Add support for non-32-bit shuffle values.
    if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32)
      return failure();
    Value srcLaneId = getLaneId(rewriter, loc);

    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value width = adaptor.getWidth();
    Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
    Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
    Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
    Value widthOrZeroIfOutside =
        rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
    Value dstLane;
    // TODO: Add support for gpu::ShuffleMode::UP.
    // TODO: Use ds_swizzle for XOR when step/offsets are constants for better
    // perf.
    switch (op.getMode()) {
    case gpu::ShuffleMode::DOWN:
      dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,
                                             adaptor.getOffset());
      break;
    case gpu::ShuffleMode::XOR:
      dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
                                             adaptor.getOffset());
      break;
    case gpu::ShuffleMode::IDX:
      dstLane = adaptor.getOffset();
      break;
    default:
      return failure();
    }
    Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
        loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
    Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane,
                                                          dstLane, srcLaneId);
    Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
    Value dwordAlignedDstLane =
        rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
    Value initShflValue = adaptor.getValue();
    if (adaptor.getValue().getType().isF32()) {
      initShflValue =
          rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
    }
    Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
        loc, int32Type, dwordAlignedDstLane, initShflValue);
    if (adaptor.getValue().getType().isF32()) {
      shflValue = rewriter.create<LLVM::BitcastOp>(
          loc, adaptor.getValue().getType(), shflValue);
    }
    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    return success();
  }
};

/// Import the GPU Ops to ROCDL Patterns.
#include "GPUToROCDL.cpp.inc"

// A pass that replaces all occurrences of GPU device operations with their
// corresponding ROCDL equivalent.
//
// This pass only handles device code and is not meant to be run on GPU host
// code.
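//
// Typically run on a gpu.module, e.g. via
// `mlir-opt --convert-gpu-to-rocdl='chipset=gfx908'` (see the pass
// declaration in Passes.td for the full option list).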
struct LowerGpuOpsToROCDLOpsPass
    : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
  LowerGpuOpsToROCDLOpsPass() = default;
  LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
                            bool useBarePtrCallConv,
                            gpu::amd::Runtime runtime) {
    if (this->chipset.getNumOccurrences() == 0)
      this->chipset = chipset;
    if (this->indexBitwidth.getNumOccurrences() == 0)
      this->indexBitwidth = indexBitwidth;
    if (this->useBarePtrCallConv.getNumOccurrences() == 0)
      this->useBarePtrCallConv = useBarePtrCallConv;
    if (this->runtime.getNumOccurrences() == 0)
      this->runtime = runtime;
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();
    MLIRContext *ctx = m.getContext();

    auto llvmDataLayout = m->getAttrOfType<StringAttr>(
        LLVM::LLVMDialect::getDataLayoutAttrName());
    if (!llvmDataLayout) {
      llvmDataLayout = StringAttr::get(ctx, amdgcnDataLayout);
      m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
    }
    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(ctx));
    }

    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    // Customize the bitwidth used for the device-side index computations.
    LowerToLLVMOptions options(
        ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    if (useBarePtrCallConv) {
      options.useBarePtrCallConv = true;
      WalkResult canUseBarePointers =
          m.walk([](gpu::GPUFuncOp func) -> WalkResult {
            if (canBeCalledWithBarePointers(func))
              return WalkResult::advance();
            return WalkResult::interrupt();
          });
      if (canUseBarePointers.wasInterrupted()) {
        emitError(UnknownLoc::get(ctx),
                  "bare pointer calling convention requires all memrefs to "
                  "have static shape and use the identity map");
        return signalPassFailure();
      }
    }

    // Apply in-dialect lowering; this replaces ops that need further lowering
    // with sequences of other GPU-dialect ops, which a single conversion pass
    // cannot do.
    {
      RewritePatternSet patterns(ctx);
      populateGpuRewritePatterns(patterns);
      arith::populateExpandBFloat16Patterns(patterns);
      (void)applyPatternsGreedily(m, std::move(patterns));
    }

    LLVMTypeConverter converter(ctx, options);
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) {
          switch (space) {
          case gpu::AddressSpace::Global:
            return 1;
          case gpu::AddressSpace::Workgroup:
            return 3;
          case gpu::AddressSpace::Private:
            return 5;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });
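    // Under this mapping, e.g. a memref in #gpu.address_space<workgroup>
    // lowers to an LLVM pointer in address space 3 (LDS).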

    RewritePatternSet llvmPatterns(ctx);

    arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
    populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
                                            *maybeChipset);
    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
    populateMathToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
    LLVMConversionTarget target(getContext());
    configureGpuToROCDLConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
    auto reqdWorkGroupSizeAttrHelper =
        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
    auto flatWorkGroupSizeAttrHelper =
        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
    // Manually rewrite known block size attributes so the LLVMIR translation
    // infrastructure can pick them up.
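    // For example, rocdl.reqd_work_group_size = [64, 2, 1] yields
    // rocdl.flat_work_group_size = "128,128".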
    m.walk([&](LLVM::LLVMFuncOp op) {
      if (reqdWorkGroupSizeAttrHelper.isAttrPresent(op)) {
        auto blockSizes = reqdWorkGroupSizeAttrHelper.getAttr(op);
        // Also set up the rocdl.flat_work_group_size attribute to prevent
        // conflicting metadata.
        uint32_t flatSize = 1;
        for (uint32_t size : blockSizes.asArrayRef()) {
          flatSize *= size;
        }
        StringAttr flatSizeAttr =
            StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
        flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
      }
    });
  }
};

} // namespace

void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<ROCDL::ROCDLDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
                      LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
                      LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
  // llvm.exp and llvm.log remain legal for f32 operands, which the AMDGPU
  // backend can select directly.
  target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
    return any_of(op->getOperandTypes(), llvm::IsaPred<Float32Type>);
  });
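  // For other element types, the MathToROCDL patterns lower the math-dialect
  // forms to ROCm device-library (OCML) calls instead.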
  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}

void mlir::populateGpuToROCDLConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    mlir::gpu::amd::Runtime runtime) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;
  using mlir::gpu::amd::Runtime;
  auto *rocdlDialect =
      converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
  populateWithGenerated(patterns);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
                                      ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
                                      ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, ROCDL::GridDimXOp, ROCDL::GridDimYOp, ROCDL::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim);
  patterns.add<GPUReturnOpLowering>(converter);
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
          /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
          rocdlDialect->getKernelAttrHelper().getName(),
          rocdlDialect->getReqdWorkGroupSizeAttrHelper().getName()});
  if (Runtime::HIP == runtime) {
    patterns.add<GPUPrintfOpToHIPLowering>(converter);
  } else if (Runtime::OpenCL == runtime) {
    // Use address space = 4 to match the OpenCL definition of printf().
    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
  }
  // TODO: Add alignment for workgroup memory.
  patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);

  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);

  populateMathToROCDLConversionPatterns(converter, patterns);
}

std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
                                      unsigned indexBitwidth,
                                      bool useBarePtrCallConv,
                                      gpu::amd::Runtime runtime) {
  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
      chipset, indexBitwidth, useBarePtrCallConv, runtime);
}