//===- GPUToLLVMSPV.cpp - Convert GPU operations to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "mlir/Conversion/GPUCommon/AttrToSPIRVConverter.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "gpu-to-llvm-spv"

using namespace mlir;

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOLLVMSPVOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
                                              StringRef name,
                                              ArrayRef<Type> paramTypes,
                                              Type resultType, bool isMemNone,
                                              bool isConvergent) {
  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
      SymbolTable::lookupSymbolIn(symbolTable, name));
  if (!func) {
    OpBuilder b(symbolTable->getRegion(0));
    func = b.create<LLVM::LLVMFuncOp>(
        symbolTable->getLoc(), name,
        LLVM::LLVMFunctionType::get(resultType, paramTypes));
    func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
    func.setNoUnwind(true);
    func.setWillReturn(true);

    if (isMemNone) {
      // No externally observable effects.
      constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef;
      auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>(
          /*other=*/noModRef,
          /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
      func.setMemoryEffectsAttr(memAttr);
    }

    func.setConvergent(isConvergent);
  }
  return func;
}

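/// Create a call to the SPIR-V builtin `func`, propagating the calling
/// convention and the `convergent`, `nounwind`, `willreturn` and memory
/// effects attributes from the callee to the call site.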
static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
                                           ConversionPatternRewriter &rewriter,
                                           LLVM::LLVMFuncOp func,
                                           ValueRange args) {
  auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
  call.setCConv(func.getCConv());
  call.setConvergentAttr(func.getConvergentAttr());
  call.setNoUnwindAttr(func.getNoUnwindAttr());
  call.setWillReturnAttr(func.getWillReturnAttr());
  call.setMemoryEffectsAttr(func.getMemoryEffectsAttr());
  return call;
}

namespace {
//===----------------------------------------------------------------------===//
// Barriers
//===----------------------------------------------------------------------===//

/// Replace `gpu.barrier` with an `llvm.call` to `barrier` with a
/// `CLK_LOCAL_MEM_FENCE` argument, indicating work-group memory scope:
/// ```
/// // gpu.barrier
/// %c1 = llvm.mlir.constant(1: i32) : i32
/// llvm.call spir_funccc @_Z7barrierj(%c1) : (i32) -> ()
/// ```
struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    constexpr StringLiteral funcName = "_Z7barrierj";

    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type flagTy = rewriter.getI32Type();
    Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
    LLVM::LLVMFuncOp func =
        lookupOrCreateSPIRVFn(moduleOp, funcName, flagTy, voidTy,
                              /*isMemNone=*/false, /*isConvergent=*/true);

    // Value used by the SPIR-V backend to represent `CLK_LOCAL_MEM_FENCE`.
    // See `llvm/lib/Target/SPIRV/SPIRVBuiltins.td`.
    constexpr int64_t localMemFenceFlag = 1;
    Location loc = op->getLoc();
    Value flag =
        rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag);
    rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// SPIR-V Builtins
//===----------------------------------------------------------------------===//

/// Replace `gpu.*` with an `llvm.call` to the corresponding SPIR-V builtin
/// with a constant argument for the `dimension` attribute. The return type
/// depends on the index width option:
/// ```
/// // %thread_id_y = gpu.thread_id y
/// %c1 = llvm.mlir.constant(1: i32) : i32
/// %0 = llvm.call spir_funccc @_Z12get_local_idj(%c1) : (i32) -> i64
/// ```
struct LaunchConfigConversion : ConvertToLLVMPattern {
  LaunchConfigConversion(StringRef funcName, StringRef rootOpName,
                         MLIRContext *context,
                         const LLVMTypeConverter &typeConverter,
                         PatternBenefit benefit)
      : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
        funcName(funcName) {}

  virtual gpu::Dimension getDimension(Operation *op) const = 0;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type dimTy = rewriter.getI32Type();
    Type indexTy = getTypeConverter()->getIndexType();
    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(moduleOp, funcName, dimTy,
                                                  indexTy, /*isMemNone=*/true,
                                                  /*isConvergent=*/false);

    Location loc = op->getLoc();
    gpu::Dimension dim = getDimension(op);
    Value dimVal = rewriter.create<LLVM::ConstantOp>(loc, dimTy,
                                                     static_cast<int64_t>(dim));
    rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal));
    return success();
  }

  StringRef funcName;
};

template <typename SourceOp>
struct LaunchConfigOpConversion final : LaunchConfigConversion {
  static StringRef getFuncName();

  explicit LaunchConfigOpConversion(const LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit = 1)
      : LaunchConfigConversion(getFuncName(), SourceOp::getOperationName(),
                               &typeConverter.getContext(), typeConverter,
                               benefit) {}

  gpu::Dimension getDimension(Operation *op) const final {
    return cast<SourceOp>(op).getDimension();
  }
};

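// Itanium-mangled names of the OpenCL work-item query builtins; the trailing
// `j` encodes the single `unsigned int` dimension argument.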
template <>
StringRef LaunchConfigOpConversion<gpu::BlockIdOp>::getFuncName() {
  return "_Z12get_group_idj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::GridDimOp>::getFuncName() {
  return "_Z14get_num_groupsj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::BlockDimOp>::getFuncName() {
  return "_Z14get_local_sizej";
}

template <>
StringRef LaunchConfigOpConversion<gpu::ThreadIdOp>::getFuncName() {
  return "_Z12get_local_idj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::GlobalIdOp>::getFuncName() {
  return "_Z13get_global_idj";
}

//===----------------------------------------------------------------------===//
// Shuffles
//===----------------------------------------------------------------------===//

/// Replace `gpu.shuffle` with an `llvm.call` to the corresponding SPIR-V
/// builtin for `shuffleResult`, keeping the `value` and `offset` arguments,
/// and a `true` constant for the `valid` result. The conversion only takes
/// place if `width` is constant and equal to the subgroup size required by the
/// enclosing function:
/// ```
/// // %0 = gpu.shuffle idx %value, %offset, %width : f64
/// %0 = llvm.call spir_funccc @_Z17sub_group_shuffledj(%value, %offset)
///     : (f64, i32) -> f64
/// ```
struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  static StringRef getBaseName(gpu::ShuffleMode mode) {
    switch (mode) {
    case gpu::ShuffleMode::IDX:
      return "sub_group_shuffle";
    case gpu::ShuffleMode::XOR:
      return "sub_group_shuffle_xor";
    case gpu::ShuffleMode::UP:
      return "sub_group_shuffle_up";
    case gpu::ShuffleMode::DOWN:
      return "sub_group_shuffle_down";
    }
    llvm_unreachable("Unhandled shuffle mode");
  }

  static std::optional<StringRef> getTypeMangling(Type type) {
    return TypeSwitch<Type, std::optional<StringRef>>(type)
        .Case<Float16Type>([](auto) { return "Dhj"; })
        .Case<Float32Type>([](auto) { return "fj"; })
        .Case<Float64Type>([](auto) { return "dj"; })
        .Case<IntegerType>([](auto intTy) -> std::optional<StringRef> {
          switch (intTy.getWidth()) {
          case 8:
            return "cj";
          case 16:
            return "sj";
          case 32:
            return "ij";
          case 64:
            return "lj";
          }
          return std::nullopt;
        })
        .Default([](auto) { return std::nullopt; });
  }

  static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
                                                Type type) {
    StringRef baseName = getBaseName(mode);
    std::optional<StringRef> typeMangling = getTypeMangling(type);
    if (!typeMangling)
      return std::nullopt;
    return llvm::formatv("_Z{}{}{}", baseName.size(), baseName,
                         typeMangling.value());
  }

  /// Get the subgroup size required by the enclosing function, if specified.
  static std::optional<int> getSubgroupSize(Operation *op) {
    auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
    if (!parentFunc)
      return std::nullopt;
    return parentFunc.getIntelReqdSubGroupSize();
  }

  static bool hasValidWidth(gpu::ShuffleOp op) {
    llvm::APInt val;
    Value width = op.getWidth();
    return matchPattern(width, m_ConstantInt(&val)) &&
           val == getSubgroupSize(op);
  }

  static Value bitcastOrExtBeforeShuffle(Value oldVal, Location loc,
                                         ConversionPatternRewriter &rewriter) {
    return TypeSwitch<Type, Value>(oldVal.getType())
        .Case([&](BFloat16Type) {
          return rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI16Type(),
                                                  oldVal);
        })
        .Case([&](IntegerType intTy) -> Value {
          if (intTy.getWidth() == 1)
            return rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI8Type(),
                                                 oldVal);
          return oldVal;
        })
        .Default(oldVal);
  }

  static Value bitcastOrTruncAfterShuffle(Value oldVal, Type newTy,
                                          Location loc,
                                          ConversionPatternRewriter &rewriter) {
    return TypeSwitch<Type, Value>(newTy)
        .Case([&](BFloat16Type) {
          return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
        })
        .Case([&](IntegerType intTy) -> Value {
          if (intTy.getWidth() == 1)
            return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
          return oldVal;
        })
        .Default(oldVal);
  }

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (!hasValidWidth(op))
      return rewriter.notifyMatchFailure(
          op, "shuffle width and subgroup size mismatch");

    Location loc = op->getLoc();
    Value inValue =
        bitcastOrExtBeforeShuffle(adaptor.getValue(), loc, rewriter);
    std::optional<std::string> funcName =
        getFuncName(op.getMode(), inValue.getType());
    if (!funcName)
      return rewriter.notifyMatchFailure(op, "unsupported value type");

    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type valueType = inValue.getType();
    Type offsetType = adaptor.getOffset().getType();
    Type resultType = valueType;
    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
        moduleOp, funcName.value(), {valueType, offsetType}, resultType,
        /*isMemNone=*/false, /*isConvergent=*/true);

    std::array<Value, 2> args{inValue, adaptor.getOffset()};
    Value result =
        createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
    Value resultOrConversion =
        bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter);

    Value trueVal =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
    rewriter.replaceOp(op, {resultOrConversion, trueVal});
    return success();
  }
};

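/// Type converter assigning the OpenCL global (`CrossWorkgroup`) address space
/// to memref types that carry no memory space attribute; all other types are
/// left unchanged. Function types are converted by converting their input and
/// result types.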
class MemorySpaceToOpenCLMemorySpaceConverter final : public TypeConverter {
public:
  MemorySpaceToOpenCLMemorySpaceConverter(MLIRContext *ctx) {
    addConversion([](Type t) { return t; });
    addConversion([ctx](BaseMemRefType memRefType) -> std::optional<Type> {
      // Attach global addr space attribute to memrefs with no addr space attr.
      Attribute memSpaceAttr = memRefType.getMemorySpace();
      if (memSpaceAttr)
        return std::nullopt;

      unsigned globalAddrspace = storageClassToAddressSpace(
          spirv::ClientAPI::OpenCL, spirv::StorageClass::CrossWorkgroup);
      Attribute addrSpaceAttr =
          IntegerAttr::get(IntegerType::get(ctx, 64), globalAddrspace);
      if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
        return MemRefType::get(memRefType.getShape(),
                               memRefType.getElementType(),
                               rankedType.getLayout(), addrSpaceAttr);
      }
      return UnrankedMemRefType::get(memRefType.getElementType(),
                                     addrSpaceAttr);
    });
    addConversion([this](FunctionType type) {
      auto inputs = llvm::map_to_vector(
          type.getInputs(), [this](Type ty) { return convertType(ty); });
      auto results = llvm::map_to_vector(
          type.getResults(), [this](Type ty) { return convertType(ty); });
      return FunctionType::get(type.getContext(), inputs, results);
    });
  }
};

//===----------------------------------------------------------------------===//
// Subgroup query ops.
//===----------------------------------------------------------------------===//

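/// Replace `gpu.subgroup_id`, `gpu.lane_id`, `gpu.num_subgroups` and
/// `gpu.subgroup_size` with an `llvm.call` to the corresponding SPIR-V
/// builtin, zero-extending the `i32` result when the index type is wider,
/// e.g. with 64-bit indices:
/// ```
/// // %0 = gpu.subgroup_id : index
/// %1 = llvm.call spir_funccc @_Z16get_sub_group_id() : () -> i32
/// %2 = llvm.zext %1 : i32 to i64
/// ```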
template <typename SubgroupOp>
struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
  using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
  using ConvertToLLVMPattern::getTypeConverter;

  LogicalResult
  matchAndRewrite(SubgroupOp op, typename SubgroupOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    constexpr StringRef funcName = [] {
      if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
        return "_Z16get_sub_group_id";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
        return "_Z22get_sub_group_local_id";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
        return "_Z18get_num_sub_groups";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
        return "_Z18get_sub_group_size";
      }
    }();

    Operation *moduleOp =
        op->template getParentWithTrait<OpTrait::SymbolTable>();
    Type resultTy = rewriter.getI32Type();
    LLVM::LLVMFuncOp func =
        lookupOrCreateSPIRVFn(moduleOp, funcName, {}, resultTy,
                              /*isMemNone=*/false, /*isConvergent=*/false);

    Location loc = op->getLoc();
    Value result = createSPIRVBuiltinCall(loc, rewriter, func, {}).getResult();

    Type indexTy = getTypeConverter()->getIndexType();
    if (resultTy != indexTy) {
      if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) {
        return failure();
      }
      result = rewriter.create<LLVM::ZExtOp>(loc, indexTy, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// GPU To LLVM-SPV Pass.
//===----------------------------------------------------------------------===//

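/// Pass lowering `gpu` dialect operations to LLVM dialect calls to SPIR-V
/// builtin functions. Memrefs without a memory space attribute are first
/// rewritten to use the OpenCL global address space.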
struct GPUToLLVMSPVConversionPass final
    : impl::ConvertGpuOpsToLLVMSPVOpsBase<GPUToLLVMSPVConversionPass> {
  using Base::Base;

  void runOnOperation() final {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    LowerToLLVMOptions options(context);
    options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
    LLVMTypeConverter converter(context, options);
    LLVMConversionTarget target(*context);

    // Force OpenCL address spaces when they are not present.
    {
      MemorySpaceToOpenCLMemorySpaceConverter converter(context);
      AttrTypeReplacer replacer;
      replacer.addReplacement([&converter](BaseMemRefType origType)
                                  -> std::optional<BaseMemRefType> {
        return converter.convertType<BaseMemRefType>(origType);
      });

      replacer.recursivelyReplaceElementsIn(getOperation(),
                                            /*replaceAttrs=*/true,
                                            /*replaceLocs=*/false,
                                            /*replaceTypes=*/true);
    }

    target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
                        gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
                        gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
                        gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
                        gpu::ThreadIdOp>();

    populateGpuToLLVMSPVConversionPatterns(converter, patterns);
    populateGpuMemorySpaceAttributeConversions(converter);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// GPU To LLVM-SPV Patterns.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace {
static unsigned
gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
  constexpr spirv::ClientAPI clientAPI = spirv::ClientAPI::OpenCL;
  return storageClassToAddressSpace(clientAPI,
                                    addressSpaceToStorageClass(addressSpace));
}
} // namespace

void populateGpuToLLVMSPVConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<GPUBarrierConversion, GPUReturnOpLowering, GPUShuffleConversion,
               GPUSubgroupOpConversion<gpu::LaneIdOp>,
               GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
               GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
               GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
               LaunchConfigOpConversion<gpu::BlockDimOp>,
               LaunchConfigOpConversion<gpu::BlockIdOp>,
               LaunchConfigOpConversion<gpu::GlobalIdOp>,
               LaunchConfigOpConversion<gpu::GridDimOp>,
               LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);
  MLIRContext *context = &typeConverter.getContext();
  unsigned privateAddressSpace =
      gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Private);
  unsigned localAddressSpace =
      gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Workgroup);
  OperationName llvmFuncOpName(LLVM::LLVMFuncOp::getOperationName(), context);
  StringAttr kernelBlockSizeAttributeName =
      LLVM::LLVMFuncOp::getReqdWorkGroupSizeAttrName(llvmFuncOpName);
  patterns.add<GPUFuncOpLowering>(
      typeConverter,
      GPUFuncOpLoweringOptions{
          privateAddressSpace, localAddressSpace,
          /*kernelAttributeName=*/{}, kernelBlockSizeAttributeName,
          LLVM::CConv::SPIR_KERNEL, LLVM::CConv::SPIR_FUNC,
          /*encodeWorkgroupAttributionsAsArguments=*/true});
}

void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter) {
  populateGpuMemorySpaceAttributeConversions(typeConverter,
                                             gpuAddressSpaceToOCLAddressSpace);
}
} // namespace mlir