//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate ROCDLIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Returns true if the given `gpu.func` can be safely called using the bare
/// pointer calling convention.
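/// For example, an argument of type `memref<16x16xf32>` (static shape,
/// identity layout) can be passed as a bare pointer, whereas `memref<?xf32>`
/// or a memref with a non-trivial layout cannot, since the callee would then
/// need the sizes and strides carried by the full memref descriptor.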
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
  bool canBeBare = true;
  for (Type type : func.getArgumentTypes())
    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
      canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
  return canBeBare;
}

/// Returns the current lane id as an i32 value, computed with the
/// mbcnt.lo/mbcnt.hi intrinsics.
static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
                       const unsigned indexBitwidth) {
  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
  Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
  Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
  Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type,
                                                    ValueRange{minus1, zero});
  Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type,
                                                   ValueRange{minus1, mbcntLo});
  return laneId;
}

static constexpr StringLiteral amdgcnDataLayout =
    "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
    "-p7:160:256:256:32-p8:128:128-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:"
    "32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:"
    "64-S32-A5-G1-ni:7:8:9";

namespace {
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    // Convert to: %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
    // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)

    Type intTy = IntegerType::get(context, 32);
    Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
    Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
    Value mbcntLo =
        rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
    Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
        loc, intTy, ValueRange{minus1, mbcntLo});
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      laneId = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), laneId);
    } else if (indexBitwidth < 32) {
      laneId = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), laneId);
    }
    rewriter.replaceOp(op, {laneId});
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding ROCDL ops.
  ///
  /// Uses the `width` argument to check whether the source lane participates;
  /// if it does not, the shuffle reads from the lane itself.
  ///
  /// Shuffle with DS Bpermute:
  ///   let shflMode = [xor, up, down, idx]
  ///   let width = 32 (usually the warp size), step = [1, 2, 4, 8, 16, ..., width].
  ///   1. curLaneId = using mbcnt.lo + mbcnt.hi
  ///   2. widthOrZeroIfOutside = (curLaneId + width) & -width
  ///   3. dstLane = shflMode(curLaneId, step)
  ///   4. isActiveSrcLane = dstLane < widthOrZeroIfOutside
  ///   5. dstLane = isActiveSrcLane ? dstLane : curLaneId
  ///   6. dwordAlignedDstLane = dstLane * 4 or dstLane << 2.
  ///   7. bpermute(dwordAlignedDstLane, shfl_value).
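  ///
  /// As an illustrative sketch (value names are made up, not emitted by the
  /// pattern), `%res, %valid = gpu.shuffle xor %val, %offset, %width : f32`
  /// lowers to roughly:
  ///   %id    = rocdl.mbcnt.hi(-1, rocdl.mbcnt.lo(-1, 0))
  ///   %dst   = llvm.xor %id, %offset
  ///   %valid = llvm.icmp "slt" %dst, ((%id + %width) & -%width)
  ///   %sel   = llvm.select %valid, %dst, %id
  ///   %res   = llvm.bitcast(rocdl.ds_bpermute(%sel << 2, llvm.bitcast %val))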
  ///
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // TODO: Add support for non-32-bit shuffle values.
    if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32)
      return failure();
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);

    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value width = adaptor.getWidth();
    Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
    Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
    Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
    Value widthOrZeroIfOutside =
        rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
    Value dstLane;
    // TODO: Add support for gpu::ShuffleMode::UP.
    // TODO: Use ds_swizzle for XOR when step/offsets are constants for better
    // perf.
    switch (op.getMode()) {
    case gpu::ShuffleMode::DOWN:
      dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,
                                             adaptor.getOffset());
      break;
    case gpu::ShuffleMode::XOR:
      dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
                                             adaptor.getOffset());
      break;
    case gpu::ShuffleMode::IDX:
      dstLane = adaptor.getOffset();
      break;
    default:
      return failure();
    }
    Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
        loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
    Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane,
                                                          dstLane, srcLaneId);
    Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
    Value dwordAlignedDstLane =
        rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
    Value initShflValue = adaptor.getValue();
    if (adaptor.getValue().getType().isF32()) {
      initShflValue =
          rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
    }
    Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
        loc, int32Type, dwordAlignedDstLane, initShflValue);
    if (adaptor.getValue().getType().isF32()) {
      shflValue = rewriter.create<LLVM::BitcastOp>(
          loc, adaptor.getValue().getType(), shflValue);
    }
    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    return success();
  }
};

/// Import the GPU Ops to ROCDL Patterns.
#include "GPUToROCDL.cpp.inc"

// A pass that replaces all occurrences of GPU device operations with their
// corresponding ROCDL equivalent.
//
// This pass only handles device code and is not meant to be run on GPU host
// code.
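//
// As a usage sketch (option spellings should be checked against the generated
// pass documentation), the pass is typically driven from mlir-opt on a module
// containing a gpu.module, e.g.:
//   mlir-opt --convert-gpu-to-rocdl='chipset=gfx90a index-bitwidth=32' in.mlir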
struct LowerGpuOpsToROCDLOpsPass
    : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
  LowerGpuOpsToROCDLOpsPass() = default;
  LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
                            bool useBarePtrCallConv,
                            gpu::amd::Runtime runtime) {
    if (this->chipset.getNumOccurrences() == 0)
      this->chipset = chipset;
    if (this->indexBitwidth.getNumOccurrences() == 0)
      this->indexBitwidth = indexBitwidth;
    if (this->useBarePtrCallConv.getNumOccurrences() == 0)
      this->useBarePtrCallConv = useBarePtrCallConv;
    if (this->runtime.getNumOccurrences() == 0)
      this->runtime = runtime;
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();
    MLIRContext *ctx = m.getContext();

    auto llvmDataLayout = m->getAttrOfType<StringAttr>(
        LLVM::LLVMDialect::getDataLayoutAttrName());
    if (!llvmDataLayout) {
      llvmDataLayout = StringAttr::get(ctx, amdgcnDataLayout);
      m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
    }
    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(ctx));
    }

    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    // Customize the bitwidth used for the device-side index computations.
    LowerToLLVMOptions options(
        ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    if (useBarePtrCallConv) {
      options.useBarePtrCallConv = true;
      WalkResult canUseBarePointers =
          m.walk([](gpu::GPUFuncOp func) -> WalkResult {
            if (canBeCalledWithBarePointers(func))
              return WalkResult::advance();
            return WalkResult::interrupt();
          });
      if (canUseBarePointers.wasInterrupted()) {
        emitError(UnknownLoc::get(ctx),
                  "bare pointer calling convention requires all memrefs to "
                  "have static shape and use the identity map");
        return signalPassFailure();
      }
    }

    // Apply in-dialect lowering. In-dialect lowering will replace ops which
    // need to be lowered further, which is not supported by a single
    // conversion pass.
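    // For instance, populateGpuRewritePatterns expands `gpu.all_reduce` into
    // `gpu.shuffle`-based reductions within the GPU dialect, and those
    // shuffles are then converted by the LLVM/ROCDL patterns below.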
    {
      RewritePatternSet patterns(ctx);
      populateGpuRewritePatterns(patterns);
      arith::populateExpandBFloat16Patterns(patterns);
      (void)applyPatternsGreedily(m, std::move(patterns));
    }

    LLVMTypeConverter converter(ctx, options);
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) {
          // Map GPU memory spaces to the matching AMDGPU LLVM address spaces:
          // 1 = global, 3 = workgroup (LDS), 5 = private (scratch).
          switch (space) {
          case gpu::AddressSpace::Global:
            return 1;
          case gpu::AddressSpace::Workgroup:
            return 3;
          case gpu::AddressSpace::Private:
            return 5;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });

    RewritePatternSet llvmPatterns(ctx);

    mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
    populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
                                            *maybeChipset);
    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
    populateMathToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
    LLVMConversionTarget target(getContext());
    configureGpuToROCDLConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();

    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
    auto reqdWorkGroupSizeAttrHelper =
        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
    auto flatWorkGroupSizeAttrHelper =
        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
    // Manually rewrite known block size attributes so the LLVMIR translation
    // infrastructure can pick them up.
    m.walk([&](LLVM::LLVMFuncOp op) {
      if (reqdWorkGroupSizeAttrHelper.isAttrPresent(op)) {
        auto blockSizes = reqdWorkGroupSizeAttrHelper.getAttr(op);
        // Also set up the rocdl.flat_work_group_size attribute to prevent
        // conflicting metadata. The flat size is the product of the
        // per-dimension sizes, e.g. [128, 2, 1] becomes "256,256".
        uint32_t flatSize = 1;
        for (uint32_t size : blockSizes.asArrayRef()) {
          flatSize *= size;
        }
        StringAttr flatSizeAttr =
            StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
        flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
      }
    });
  }
};

} // namespace

void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<ROCDL::ROCDLDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
                      LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
                      LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
  // These ops remain legal when operating on f32.
  target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
    return any_of(op->getOperandTypes(), llvm::IsaPred<Float32Type>);
  });
  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}

void mlir::populateGpuToROCDLConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    mlir::gpu::amd::Runtime runtime) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;
  using mlir::gpu::amd::Runtime;
  auto *rocdlDialect =
      converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
  populateWithGenerated(patterns);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
                                      ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
                                      ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, ROCDL::GridDimXOp, ROCDL::GridDimYOp, ROCDL::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim);
  patterns.add<GPUReturnOpLowering>(converter);
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
          /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
          rocdlDialect->getKernelAttrHelper().getName(),
          rocdlDialect->getReqdWorkGroupSizeAttrHelper().getName()});
  if (Runtime::HIP == runtime) {
    patterns.add<GPUPrintfOpToHIPLowering>(converter);
  } else if (Runtime::OpenCL == runtime) {
    // Use address space = 4 to match the OpenCL definition of printf().
    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
  }
  // TODO: Add alignment for workgroup memory.
  patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);

  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);

  populateMathToROCDLConversionPatterns(converter, patterns);
}

std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
                                      unsigned indexBitwidth,
                                      bool useBarePtrCallConv,
                                      gpu::amd::Runtime runtime) {
  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
      chipset, indexBitwidth, useBarePtrCallConv, runtime);
}