1 //===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to generate ROCDLIR operations for higher-level 10 // GPU operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 15 #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" 16 #include "mlir/Dialect/Arith/Transforms/Passes.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Pass/PassManager.h" 19 #include "mlir/Transforms/Passes.h" 20 21 #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" 22 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" 23 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" 24 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" 25 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 26 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" 27 #include "mlir/Conversion/LLVMCommon/Pattern.h" 28 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 29 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 30 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 31 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" 32 #include "mlir/Dialect/Func/IR/FuncOps.h" 33 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 34 #include "mlir/Dialect/GPU/Transforms/Passes.h" 35 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 36 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" 37 #include "mlir/Dialect/Math/IR/Math.h" 38 #include "mlir/Dialect/MemRef/IR/MemRef.h" 39 #include "mlir/Dialect/Vector/IR/VectorOps.h" 40 #include "mlir/IR/BuiltinAttributes.h" 41 #include "mlir/Pass/Pass.h" 42 #include "mlir/Transforms/DialectConversion.h" 43 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 44 #include "llvm/Support/FormatVariadic.h" 45 46 #include "../GPUCommon/GPUOpsLowering.h" 47 #include "../GPUCommon/IndexIntrinsicsOpLowering.h" 48 #include "../GPUCommon/OpToFuncCallLowering.h" 49 50 namespace mlir { 51 #define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS 52 #include "mlir/Conversion/Passes.h.inc" 53 } // namespace mlir 54 55 using namespace mlir; 56 57 /// Returns true if the given `gpu.func` can be safely called using the bare 58 /// pointer calling convention. 59 static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) { 60 bool canBeBare = true; 61 for (Type type : func.getArgumentTypes()) 62 if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) 63 canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy); 64 return canBeBare; 65 } 66 67 Value getLaneId(ConversionPatternRewriter &rewriter, Location loc, 68 const unsigned indexBitwidth) { 69 auto int32Type = IntegerType::get(rewriter.getContext(), 32); 70 Value zero = rewriter.createOrFold<arith::ConstantIntOp>(loc, 0, 32); 71 Value minus1 = rewriter.createOrFold<arith::ConstantIntOp>(loc, -1, 32); 72 Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type, 73 ValueRange{minus1, zero}); 74 Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type, 75 ValueRange{minus1, mbcntLo}); 76 return laneId; 77 } 78 79 namespace { 80 struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> { 81 using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern; 82 83 LogicalResult 84 matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor, 85 ConversionPatternRewriter &rewriter) const override { 86 auto loc = op->getLoc(); 87 MLIRContext *context = rewriter.getContext(); 88 // convert to: %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0) 89 // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo) 90 91 Type intTy = IntegerType::get(context, 32); 92 Value zero = rewriter.createOrFold<arith::ConstantIntOp>(loc, 0, 32); 93 Value minus1 = rewriter.createOrFold<arith::ConstantIntOp>(loc, -1, 32); 94 Value mbcntLo = 95 rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero}); 96 Value laneId = rewriter.create<ROCDL::MbcntHiOp>( 97 loc, intTy, ValueRange{minus1, mbcntLo}); 98 // Truncate or extend the result depending on the index bitwidth specified 99 // by the LLVMTypeConverter options. 100 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); 101 if (indexBitwidth > 32) { 102 laneId = rewriter.create<LLVM::SExtOp>( 103 loc, IntegerType::get(context, indexBitwidth), laneId); 104 } else if (indexBitwidth < 32) { 105 laneId = rewriter.create<LLVM::TruncOp>( 106 loc, IntegerType::get(context, indexBitwidth), laneId); 107 } 108 rewriter.replaceOp(op, {laneId}); 109 return success(); 110 } 111 }; 112 113 struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> { 114 using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern; 115 116 /// Lowers a shuffle to the corresponding ROCDL ops. 117 /// 118 /// Use the `width` argument to see if src lane is participating. 119 /// If not the dstLane would be itself. 120 /// 121 /// Shuffle with DS Bpermute: 122 /// let shflMode = [xor, up, down, idx] 123 /// let width = 32(usually warpsize), step = [1, 2, 4, 8, 16, ... , width]. 124 /// 1. curLaneId = using mbcnt.lo + mbcnt.hi 125 /// 2. widthOrZeroIfOutside = (curLaneId + width) & -width 126 /// 3. dstLane = shflMode(curLaneId, step) 127 /// 4. isActiveSrcLane = dstLane < isActiveSrcLane 128 /// 5. dstLane = isActiveSrcLane ? dstLane : curLaneId 129 /// 6. dwordAlignedDstLane = dstLane * 4 or dstLane << 2. 130 /// 7. bpermute(dwordAlignedDstLane, shfl_value). 131 /// 132 LogicalResult 133 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor, 134 ConversionPatternRewriter &rewriter) const override { 135 Location loc = op->getLoc(); 136 // TODO: Add support for non 32-bit shuffle values. 137 if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32) 138 return failure(); 139 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); 140 Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth); 141 142 auto int32Type = IntegerType::get(rewriter.getContext(), 32); 143 Value width = adaptor.getWidth(); 144 Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0); 145 Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width); 146 Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width); 147 Value widthOrZeroIfOutside = 148 rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth); 149 Value dstLane; 150 // TODO: Add support for gpu::ShuffleMode::UP and gpu::ShuffleMode::DOWN. 151 // TODO: Use ds_swizzle for XOR when step/offsets are constants for better 152 // perf. 153 switch (op.getMode()) { 154 case gpu::ShuffleMode::XOR: 155 dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId, 156 adaptor.getOffset()); 157 break; 158 case gpu::ShuffleMode::IDX: 159 dstLane = adaptor.getOffset(); 160 break; 161 default: 162 return failure(); 163 } 164 Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>( 165 loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside); 166 Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane, 167 dstLane, srcLaneId); 168 Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2); 169 Value dwordAlignedDstLane = 170 rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two); 171 Value initShflValue = adaptor.getValue(); 172 if (adaptor.getValue().getType().isF32()) { 173 initShflValue = 174 rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue); 175 } 176 Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>( 177 loc, int32Type, dwordAlignedDstLane, initShflValue); 178 if (adaptor.getValue().getType().isF32()) { 179 shflValue = rewriter.create<LLVM::BitcastOp>( 180 loc, adaptor.getValue().getType(), shflValue); 181 } 182 rewriter.replaceOp(op, {shflValue, isActiveSrcLane}); 183 return success(); 184 } 185 }; 186 187 /// Import the GPU Ops to ROCDL Patterns. 188 #include "GPUToROCDL.cpp.inc" 189 190 // A pass that replaces all occurrences of GPU device operations with their 191 // corresponding ROCDL equivalent. 192 // 193 // This pass only handles device code and is not meant to be run on GPU host 194 // code. 195 struct LowerGpuOpsToROCDLOpsPass 196 : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> { 197 LowerGpuOpsToROCDLOpsPass() = default; 198 LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth, 199 bool useBarePtrCallConv, 200 gpu::amd::Runtime runtime) { 201 if (this->chipset.getNumOccurrences() == 0) 202 this->chipset = chipset; 203 if (this->indexBitwidth.getNumOccurrences() == 0) 204 this->indexBitwidth = indexBitwidth; 205 if (this->useBarePtrCallConv.getNumOccurrences() == 0) 206 this->useBarePtrCallConv = useBarePtrCallConv; 207 if (this->runtime.getNumOccurrences() == 0) 208 this->runtime = runtime; 209 } 210 211 void runOnOperation() override { 212 gpu::GPUModuleOp m = getOperation(); 213 MLIRContext *ctx = m.getContext(); 214 215 // Request C wrapper emission. 216 for (auto func : m.getOps<func::FuncOp>()) { 217 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), 218 UnitAttr::get(ctx)); 219 } 220 221 FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset); 222 if (failed(maybeChipset)) { 223 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset); 224 return signalPassFailure(); 225 } 226 227 /// Customize the bitwidth used for the device side index computations. 228 LowerToLLVMOptions options( 229 ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation()))); 230 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 231 options.overrideIndexBitwidth(indexBitwidth); 232 233 if (useBarePtrCallConv) { 234 options.useBarePtrCallConv = true; 235 WalkResult canUseBarePointers = 236 m.walk([](gpu::GPUFuncOp func) -> WalkResult { 237 if (canBeCalledWithBarePointers(func)) 238 return WalkResult::advance(); 239 return WalkResult::interrupt(); 240 }); 241 if (canUseBarePointers.wasInterrupted()) { 242 emitError(UnknownLoc::get(ctx), 243 "bare pointer calling convention requires all memrefs to " 244 "have static shape and use the identity map"); 245 return signalPassFailure(); 246 } 247 } 248 249 // Apply in-dialect lowering. In-dialect lowering will replace 250 // ops which need to be lowered further, which is not supported by a 251 // single conversion pass. 252 { 253 RewritePatternSet patterns(ctx); 254 populateGpuRewritePatterns(patterns); 255 arith::populateExpandBFloat16Patterns(patterns); 256 (void)applyPatternsAndFoldGreedily(m, std::move(patterns)); 257 } 258 259 LLVMTypeConverter converter(ctx, options); 260 populateGpuMemorySpaceAttributeConversions( 261 converter, [](gpu::AddressSpace space) { 262 switch (space) { 263 case gpu::AddressSpace::Global: 264 return 1; 265 case gpu::AddressSpace::Workgroup: 266 return 3; 267 case gpu::AddressSpace::Private: 268 return 5; 269 } 270 llvm_unreachable("unknown address space enum value"); 271 return 0; 272 }); 273 274 RewritePatternSet llvmPatterns(ctx); 275 276 mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns); 277 populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns, 278 *maybeChipset); 279 populateVectorToLLVMConversionPatterns(converter, llvmPatterns); 280 cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns); 281 populateFuncToLLVMConversionPatterns(converter, llvmPatterns); 282 populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns); 283 populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime); 284 LLVMConversionTarget target(getContext()); 285 configureGpuToROCDLConversionLegality(target); 286 if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) 287 signalPassFailure(); 288 289 // Manually rewrite known block size attributes so the LLVMIR translation 290 // infrastructure can pick them up. 291 m.walk([ctx](LLVM::LLVMFuncOp op) { 292 if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>( 293 op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) { 294 op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(), 295 blockSizes); 296 // Also set up the rocdl.flat_work_group_size attribute to prevent 297 // conflicting metadata. 298 uint32_t flatSize = 1; 299 for (uint32_t size : blockSizes.asArrayRef()) { 300 flatSize *= size; 301 } 302 StringAttr flatSizeAttr = 303 StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize)); 304 op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(), 305 flatSizeAttr); 306 } 307 }); 308 } 309 }; 310 311 } // namespace 312 313 void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) { 314 target.addIllegalOp<func::FuncOp>(); 315 target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); 316 target.addLegalDialect<ROCDL::ROCDLDialect>(); 317 target.addIllegalDialect<gpu::GPUDialect>(); 318 target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp, 319 LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, 320 LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp, 321 LLVM::SqrtOp>(); 322 323 // TODO: Remove once we support replacing non-root ops. 324 target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>(); 325 } 326 327 template <typename OpTy> 328 static void populateOpPatterns(LLVMTypeConverter &converter, 329 RewritePatternSet &patterns, StringRef f32Func, 330 StringRef f64Func) { 331 patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter); 332 patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func); 333 } 334 335 void mlir::populateGpuToROCDLConversionPatterns( 336 LLVMTypeConverter &converter, RewritePatternSet &patterns, 337 mlir::gpu::amd::Runtime runtime) { 338 using mlir::gpu::amd::Runtime; 339 340 populateWithGenerated(patterns); 341 patterns 342 .add<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp, 343 ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>( 344 converter, gpu::GPUFuncOp::getKnownBlockSizeAttrName()); 345 patterns.add<GPUIndexIntrinsicOpLowering< 346 gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>( 347 converter, gpu::GPUFuncOp::getKnownGridSizeAttrName()); 348 patterns 349 .add<GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp, 350 ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>, 351 GPUIndexIntrinsicOpLowering<gpu::GridDimOp, ROCDL::GridDimXOp, 352 ROCDL::GridDimYOp, ROCDL::GridDimZOp>, 353 GPUReturnOpLowering>(converter); 354 patterns.add<GPUFuncOpLowering>( 355 converter, 356 /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace, 357 /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace, 358 StringAttr::get(&converter.getContext(), 359 ROCDL::ROCDLDialect::getKernelFuncAttrName())); 360 if (Runtime::HIP == runtime) { 361 patterns.add<GPUPrintfOpToHIPLowering>(converter); 362 } else if (Runtime::OpenCL == runtime) { 363 // Use address space = 4 to match the OpenCL definition of printf() 364 patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4); 365 } 366 // TODO: Add alignment for workgroup memory 367 patterns.add<GPUDynamicSharedMemoryOpLowering>(converter); 368 369 patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter); 370 371 populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32", 372 "__ocml_fabs_f64"); 373 populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32", 374 "__ocml_atan_f64"); 375 populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32", 376 "__ocml_atan2_f64"); 377 populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32", 378 "__ocml_cbrt_f64"); 379 populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32", 380 "__ocml_ceil_f64"); 381 populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32", 382 "__ocml_cos_f64"); 383 populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32", 384 "__ocml_exp_f64"); 385 populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32", 386 "__ocml_exp2_f64"); 387 populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32", 388 "__ocml_expm1_f64"); 389 populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32", 390 "__ocml_floor_f64"); 391 populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32", 392 "__ocml_fmod_f64"); 393 populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32", 394 "__ocml_log_f64"); 395 populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32", 396 "__ocml_log10_f64"); 397 populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32", 398 "__ocml_log1p_f64"); 399 populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32", 400 "__ocml_log2_f64"); 401 populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32", 402 "__ocml_pow_f64"); 403 populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32", 404 "__ocml_rsqrt_f64"); 405 populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32", 406 "__ocml_sin_f64"); 407 populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32", 408 "__ocml_sqrt_f64"); 409 populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32", 410 "__ocml_tanh_f64"); 411 populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32", 412 "__ocml_tan_f64"); 413 populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32", 414 "__ocml_erf_f64"); 415 } 416 417 std::unique_ptr<OperationPass<gpu::GPUModuleOp>> 418 mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset, 419 unsigned indexBitwidth, 420 bool useBarePtrCallConv, 421 gpu::amd::Runtime runtime) { 422 return std::make_unique<LowerGpuOpsToROCDLOpsPass>( 423 chipset, indexBitwidth, useBarePtrCallConv, runtime); 424 } 425