//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate ROCDLIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Returns true if the given `gpu.func` can be safely called using the bare
/// pointer calling convention.
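/// Under that convention each memref argument is lowered to a single pointer
/// rather than a full descriptor, which requires the memref to have a static
/// shape and an identity layout map (see LLVMTypeConverter::canConvertToBarePtr).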
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
  bool canBeBare = true;
  for (Type type : func.getArgumentTypes())
    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
      canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
  return canBeBare;
}

/// Returns the lane id of the current thread within its wavefront, computed
/// with the ROCDL mbcnt.lo/mbcnt.hi intrinsics.
Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
                const unsigned indexBitwidth) {
  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
  Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
  Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
  Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type,
                                                    ValueRange{minus1, zero});
  Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type,
                                                   ValueRange{minus1, mbcntLo});
  return laneId;
}

static constexpr StringLiteral amdgcnDataLayout =
    "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
    "-p7:160:256:256:32-p8:128:128-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:"
    "32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:"
    "64-S32-A5-G1-ni:7:8:9";

namespace {
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    // convert to: %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
    // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)

    Type intTy = IntegerType::get(context, 32);
    Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
    Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
    Value mbcntLo =
        rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
    Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
        loc, intTy, ValueRange{minus1, mbcntLo});
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      laneId = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), laneId);
    } else if (indexBitwidth < 32) {
      laneId = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), laneId);
    }
    rewriter.replaceOp(op, {laneId});
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding ROCDL ops.
  ///
  /// Use the `width` argument to determine whether the source lane
  /// participates; if it does not, the destination lane is the lane itself.
  ///
  /// Shuffle with DS Bpermute:
  /// let shflMode = [xor, up, down, idx]
  /// let width = 32 (usually the warp size), step = [1, 2, 4, 8, 16, ..., width].
  /// 1. curLaneId = using mbcnt.lo + mbcnt.hi
  /// 2. widthOrZeroIfOutside = (curLaneId + width) & -width
  /// 3. dstLane = shflMode(curLaneId, step)
  /// 4. isActiveSrcLane = dstLane < widthOrZeroIfOutside
  /// 5. dstLane = isActiveSrcLane ? dstLane : curLaneId
  /// 6. dwordAlignedDstLane = dstLane * 4 or dstLane << 2.
  /// 7. bpermute(dwordAlignedDstLane, shfl_value).
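  ///
  /// Example (illustrative values): with mode = xor, width = 32, and
  /// offset (step) = 2, a thread with curLaneId = 5 computes
  /// widthOrZeroIfOutside = (5 + 32) & -32 = 32 and dstLane = 5 ^ 2 = 7;
  /// since 7 < 32 the source lane is active, so the thread reads lane 7's
  /// value through ds_bpermute at byte address 7 << 2 = 28.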
  ///
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // TODO: Add support for non 32-bit shuffle values.
    if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32)
      return failure();
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);

    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value width = adaptor.getWidth();
    Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
    Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
    Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
    Value widthOrZeroIfOutside =
        rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
    Value dstLane;
    // TODO: Add support for gpu::ShuffleMode::UP and gpu::ShuffleMode::DOWN.
    // TODO: Use ds_swizzle for XOR when step/offsets are constants for better
    // perf.
    switch (op.getMode()) {
    case gpu::ShuffleMode::XOR:
      dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
                                             adaptor.getOffset());
      break;
    case gpu::ShuffleMode::IDX:
      dstLane = adaptor.getOffset();
      break;
    default:
      return failure();
    }
    Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
        loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
    Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane,
                                                          dstLane, srcLaneId);
    Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
    Value dwordAlignedDstLane =
        rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
    Value initShflValue = adaptor.getValue();
    if (adaptor.getValue().getType().isF32()) {
      initShflValue =
          rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
    }
    Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
        loc, int32Type, dwordAlignedDstLane, initShflValue);
    if (adaptor.getValue().getType().isF32()) {
      shflValue = rewriter.create<LLVM::BitcastOp>(
          loc, adaptor.getValue().getType(), shflValue);
    }
    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    return success();
  }
};

/// Import the GPU Ops to ROCDL Patterns.
#include "GPUToROCDL.cpp.inc"

// A pass that replaces all occurrences of GPU device operations with their
// corresponding ROCDL equivalent.
//
// This pass only handles device code and is not meant to be run on GPU host
// code.
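//
// Explicit constructor arguments act as defaults: each one is applied only
// when the corresponding pass option (the pass is registered as
// `convert-gpu-to-rocdl`) was not set on the command line, as checked via
// getNumOccurrences() below.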
struct LowerGpuOpsToROCDLOpsPass
    : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
  LowerGpuOpsToROCDLOpsPass() = default;
  LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
                            bool useBarePtrCallConv,
                            gpu::amd::Runtime runtime) {
    if (this->chipset.getNumOccurrences() == 0)
      this->chipset = chipset;
    if (this->indexBitwidth.getNumOccurrences() == 0)
      this->indexBitwidth = indexBitwidth;
    if (this->useBarePtrCallConv.getNumOccurrences() == 0)
      this->useBarePtrCallConv = useBarePtrCallConv;
    if (this->runtime.getNumOccurrences() == 0)
      this->runtime = runtime;
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();
    MLIRContext *ctx = m.getContext();

    auto llvmDataLayout = m->getAttrOfType<StringAttr>(
        LLVM::LLVMDialect::getDataLayoutAttrName());
    if (!llvmDataLayout) {
      llvmDataLayout = StringAttr::get(ctx, amdgcnDataLayout);
      m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
    }
    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(ctx));
    }

    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    // Customize the bitwidth used for the device-side index computations.
    LowerToLLVMOptions options(
        ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    if (useBarePtrCallConv) {
      options.useBarePtrCallConv = true;
      WalkResult canUseBarePointers =
          m.walk([](gpu::GPUFuncOp func) -> WalkResult {
            if (canBeCalledWithBarePointers(func))
              return WalkResult::advance();
            return WalkResult::interrupt();
          });
      if (canUseBarePointers.wasInterrupted()) {
        emitError(UnknownLoc::get(ctx),
                  "bare pointer calling convention requires all memrefs to "
                  "have static shape and use the identity map");
        return signalPassFailure();
      }
    }

    // Apply in-dialect lowering first. These rewrites replace ops with other
    // ops that themselves still need further lowering, which a single
    // conversion pass cannot do.
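    // (For example, populateGpuRewritePatterns can decompose gpu.all_reduce
    // into shuffle-based reductions, and populateExpandBFloat16Patterns
    // expands bf16 extension/truncation into integer bit manipulation before
    // the LLVM conversion below runs.)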
    {
      RewritePatternSet patterns(ctx);
      populateGpuRewritePatterns(patterns);
      arith::populateExpandBFloat16Patterns(patterns);
      (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
    }

    LLVMTypeConverter converter(ctx, options);
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) {
          switch (space) {
          case gpu::AddressSpace::Global:
            return 1;
          case gpu::AddressSpace::Workgroup:
            return 3;
          case gpu::AddressSpace::Private:
            return 5;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });

    RewritePatternSet llvmPatterns(ctx);

    mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
    populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
                                            *maybeChipset);
    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
    populateMathToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
    LLVMConversionTarget target(getContext());
    configureGpuToROCDLConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();

    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
    auto reqdWorkGroupSizeAttrHelper =
        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
    auto flatWorkGroupSizeAttrHelper =
        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
    // Manually rewrite known block size attributes so the LLVMIR translation
    // infrastructure can pick them up.
    m.walk([&](LLVM::LLVMFuncOp op) {
      if (reqdWorkGroupSizeAttrHelper.isAttrPresent(op)) {
        auto blockSizes = reqdWorkGroupSizeAttrHelper.getAttr(op);
        // Also set up the rocdl.flat_work_group_size attribute to prevent
        // conflicting metadata.
        uint32_t flatSize = 1;
        for (uint32_t size : blockSizes.asArrayRef()) {
          flatSize *= size;
        }
        StringAttr flatSizeAttr =
            StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
        flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
      }
    });
  }
};

} // namespace

void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<ROCDL::ROCDLDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
                      LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
                      LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
  // These ops are legal only when operating on the f32 type.
  target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
    return any_of(op->getOperandTypes(), llvm::IsaPred<Float32Type>);
  });
  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}

template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
                               RewritePatternSet &patterns, StringRef f32Func,
                               StringRef f64Func, StringRef f32ApproxFunc,
                               StringRef f16Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f32ApproxFunc,
                                           f16Func);
}

void mlir::populateGpuToROCDLConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    mlir::gpu::amd::Runtime runtime) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;
  using mlir::gpu::amd::Runtime;
  auto *rocdlDialect =
      converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
  populateWithGenerated(patterns);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
                                      ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
                                      ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, ROCDL::GridDimXOp, ROCDL::GridDimYOp, ROCDL::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim);
  patterns.add<GPUReturnOpLowering>(converter);
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
          /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
          rocdlDialect->getKernelAttrHelper().getName(),
          rocdlDialect->getReqdWorkGroupSizeAttrHelper().getName()});
  if (Runtime::HIP == runtime) {
    patterns.add<GPUPrintfOpToHIPLowering>(converter);
  } else if (Runtime::OpenCL == runtime) {
    // Use address space = 4 to match the OpenCL definition of printf().
    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
  }
  // TODO: Add alignment for workgroup memory.
  patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);

  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);

  populateMathToROCDLConversionPatterns(converter, patterns);
}

std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
                                      unsigned indexBitwidth,
                                      bool useBarePtrCallConv,
                                      gpu::amd::Runtime runtime) {
  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
      chipset, indexBitwidth, useBarePtrCallConv, runtime);
}