//===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert GPU dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>

using namespace mlir;

/// Prefix prepended to a gpu.module's name when creating the corresponding
/// spirv.module, to avoid a symbol collision with the original module.
static constexpr const char kSPIRVModule[] = "__spv__";

namespace {
/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
/// builtin variables (the builtin to load is the `builtin` template
/// parameter; the GPU op it replaces is `SourceOp`).
template <typename SourceOp, spirv::BuiltIn builtin>
class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern lowering subgroup size/id to loading SPIR-V invocation
/// builtin variables. Unlike LaunchConfigConversion, these builtins are
/// scalar i32 values rather than <3xi32> vectors.
template <typename SourceOp, spirv::BuiltIn builtin>
class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// This is separate because in Vulkan workgroup size is exposed to shaders via
/// a constant with WorkgroupSize decoration. So here we cannot generate a
/// builtin variable; instead the information in the `spirv.entry_point_abi`
/// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
public:
  // Higher benefit than the generic LaunchConfigConversion for gpu.block_dim
  // so this pattern is tried first when the workgroup size is statically known.
  WorkGroupSizeConversion(const TypeConverter &typeConverter,
                          MLIRContext *context)
      : OpConversionPattern(typeConverter, context, /*benefit*/ 10) {}

  LogicalResult
  matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a kernel function in GPU dialect within a spirv.module.
class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
public:
  using OpConversionPattern<gpu::GPUFuncOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  SmallVector<int32_t, 3> workGroupSizeAsInt32;
};

/// Pattern to convert a gpu.module to a spirv.module.
class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
public:
  using OpConversionPattern<gpu::GPUModuleOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.return into a SPIR-V return.
// TODO: This can go to DRR when GPU return has operands.
class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
public:
  using OpConversionPattern<gpu::ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.barrier op into a spirv.ControlBarrier op.
class GPUBarrierConversion final : public OpConversionPattern<gpu::BarrierOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.shuffle op into a spirv.GroupNonUniformShuffle op.
class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.printf op into a SPIR-V CL printf op.
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace

//===----------------------------------------------------------------------===//
// Builtins.
138 //===----------------------------------------------------------------------===// 139 140 template <typename SourceOp, spirv::BuiltIn builtin> 141 LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite( 142 SourceOp op, typename SourceOp::Adaptor adaptor, 143 ConversionPatternRewriter &rewriter) const { 144 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>(); 145 Type indexType = typeConverter->getIndexType(); 146 147 // For Vulkan, these SPIR-V builtin variables are required to be a vector of 148 // type <3xi32> by the spec: 149 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html 150 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html 151 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html 152 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html 153 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html 154 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html 155 // 156 // For OpenCL, it depends on the Physical32/Physical64 addressing model: 157 // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables 158 bool forShader = 159 typeConverter->getTargetEnv().allows(spirv::Capability::Shader); 160 Type builtinType = forShader ? 
rewriter.getIntegerType(32) : indexType; 161 162 Value vector = 163 spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter); 164 Value dim = rewriter.create<spirv::CompositeExtractOp>( 165 op.getLoc(), builtinType, vector, 166 rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())})); 167 if (forShader && builtinType != indexType) 168 dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim); 169 rewriter.replaceOp(op, dim); 170 return success(); 171 } 172 173 template <typename SourceOp, spirv::BuiltIn builtin> 174 LogicalResult 175 SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite( 176 SourceOp op, typename SourceOp::Adaptor adaptor, 177 ConversionPatternRewriter &rewriter) const { 178 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>(); 179 Type indexType = typeConverter->getIndexType(); 180 Type i32Type = rewriter.getIntegerType(32); 181 182 // For Vulkan, these SPIR-V builtin variables are required to be a vector of 183 // type i32 by the spec: 184 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html 185 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html 186 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html 187 // 188 // For OpenCL, they are also required to be i32: 189 // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables 190 Value builtinValue = 191 spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter); 192 if (i32Type != indexType) 193 builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, 194 builtinValue); 195 rewriter.replaceOp(op, builtinValue); 196 return success(); 197 } 198 199 LogicalResult WorkGroupSizeConversion::matchAndRewrite( 200 gpu::BlockDimOp op, OpAdaptor adaptor, 201 ConversionPatternRewriter &rewriter) const { 202 DenseI32ArrayAttr workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op); 
203 if (!workGroupSizeAttr) 204 return failure(); 205 206 int val = 207 workGroupSizeAttr.asArrayRef()[static_cast<int32_t>(op.getDimension())]; 208 auto convertedType = 209 getTypeConverter()->convertType(op.getResult().getType()); 210 if (!convertedType) 211 return failure(); 212 rewriter.replaceOpWithNewOp<spirv::ConstantOp>( 213 op, convertedType, IntegerAttr::get(convertedType, val)); 214 return success(); 215 } 216 217 //===----------------------------------------------------------------------===// 218 // GPUFuncOp 219 //===----------------------------------------------------------------------===// 220 221 // Legalizes a GPU function as an entry SPIR-V function. 222 static spirv::FuncOp 223 lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, 224 ConversionPatternRewriter &rewriter, 225 spirv::EntryPointABIAttr entryPointInfo, 226 ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) { 227 auto fnType = funcOp.getFunctionType(); 228 if (fnType.getNumResults()) { 229 funcOp.emitError("SPIR-V lowering only supports entry functions" 230 "with no return values right now"); 231 return nullptr; 232 } 233 if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) { 234 funcOp.emitError( 235 "lowering as entry functions requires ABI info for all arguments " 236 "or none of them"); 237 return nullptr; 238 } 239 // Update the signature to valid SPIR-V types and add the ABI 240 // attributes. These will be "materialized" by using the 241 // LowerABIAttributesPass. 
242 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); 243 { 244 for (const auto &argType : 245 enumerate(funcOp.getFunctionType().getInputs())) { 246 auto convertedType = typeConverter.convertType(argType.value()); 247 if (!convertedType) 248 return nullptr; 249 signatureConverter.addInputs(argType.index(), convertedType); 250 } 251 } 252 auto newFuncOp = rewriter.create<spirv::FuncOp>( 253 funcOp.getLoc(), funcOp.getName(), 254 rewriter.getFunctionType(signatureConverter.getConvertedTypes(), 255 std::nullopt)); 256 for (const auto &namedAttr : funcOp->getAttrs()) { 257 if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() || 258 namedAttr.getName() == SymbolTable::getSymbolAttrName()) 259 continue; 260 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); 261 } 262 263 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 264 newFuncOp.end()); 265 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, 266 &signatureConverter))) 267 return nullptr; 268 rewriter.eraseOp(funcOp); 269 270 // Set the attributes for argument and the function. 271 StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName(); 272 for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) { 273 newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]); 274 } 275 newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo); 276 277 return newFuncOp; 278 } 279 280 /// Populates `argABI` with spirv.interface_var_abi attributes for lowering 281 /// gpu.func to spirv.func if no arguments have the attributes set 282 /// already. Returns failure if any argument has the ABI attribute set already. 
static LogicalResult
getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp,
                   SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
  // Targets that use kernel capabilities (e.g. OpenCL) need no interface
  // variable ABI attributes at all.
  if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
    return success();

  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
    // Do not synthesize defaults if the user attached an ABI attribute to any
    // argument; the caller falls back to using the user-provided attributes.
    if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
            argIndex, spirv::getInterfaceVarABIAttrName()))
      return failure();
    // Vulkan's interface variable requirements needs scalars to be wrapped in a
    // struct. The struct held in storage buffer.
    std::optional<spirv::StorageClass> sc;
    if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
      sc = spirv::StorageClass::StorageBuffer;
    argABI.push_back(
        spirv::getInterfaceVarABIAttr(0, argIndex, sc, funcOp.getContext()));
  }
  return success();
}

LogicalResult GPUFuncOpConversion::matchAndRewrite(
    gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Only kernel (entry point) functions are handled here.
  if (!gpu::GPUDialect::isKernel(funcOp))
    return failure();

  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
  SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
  if (failed(
          getDefaultABIAttrs(typeConverter->getTargetEnv(), funcOp, argABI))) {
    // Default synthesis was declined because some argument already carries an
    // ABI attribute: require every argument to have one and collect them.
    argABI.clear();
    for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
      // If the ABI is already specified, use it.
      auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
          argIndex, spirv::getInterfaceVarABIAttrName());
      if (!abiAttr) {
        funcOp.emitRemark(
            "match failure: missing 'spirv.interface_var_abi' attribute at "
            "argument ")
            << argIndex;
        return failure();
      }
      argABI.push_back(abiAttr);
    }
  }

  auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
  if (!entryPointAttr) {
    funcOp.emitRemark(
        "match failure: missing 'spirv.entry_point_abi' attribute");
    return failure();
  }
  spirv::FuncOp newFuncOp = lowerAsEntryFunction(
      funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
  if (!newFuncOp)
    return failure();
  // The gpu.kernel unit attribute has no meaning on the spirv.func.
  newFuncOp->removeAttr(
      rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
  return success();
}

//===----------------------------------------------------------------------===//
// ModuleOp with gpu.module.
//===----------------------------------------------------------------------===//

LogicalResult GPUModuleConversion::matchAndRewrite(
    gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
  const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
  // Deduce addressing/memory models from the target environment and the
  // index-width option.
  spirv::AddressingModel addressingModel = spirv::getAddressingModel(
      targetEnv, typeConverter->getOptions().use64bitIndex);
  FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
  if (failed(memoryModel))
    return moduleOp.emitRemark(
        "cannot deduce memory model from 'spirv.target_env'");

  // Add a keyword to the module name to avoid symbolic conflict.
  std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
  auto spvModule = rewriter.create<spirv::ModuleOp>(
      moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
      StringRef(spvModuleName));

  // Move the region from the module op into the SPIR-V module.
  Region &spvModuleRegion = spvModule.getRegion();
  rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
                              spvModuleRegion.begin());
  // The spirv.module build method adds a block. Remove that.
  rewriter.eraseBlock(&spvModuleRegion.back());

  // Some of the patterns call `lookupTargetEnv` during conversion and they
  // will fail if called after GPUModuleConversion and we don't preserve
  // `TargetEnv` attribute.
  // Copy TargetEnvAttr only if it is attached directly to the GPUModuleOp.
  if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
          spirv::getTargetEnvAttrName()))
    spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);

  rewriter.eraseOp(moduleOp);
  return success();
}

//===----------------------------------------------------------------------===//
// GPU return inside kernel functions to SPIR-V return.
//===----------------------------------------------------------------------===//

LogicalResult GPUReturnOpConversion::matchAndRewrite(
    gpu::ReturnOp returnOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // spirv.Return carries no operands; bail on any operand-carrying return.
  if (!adaptor.getOperands().empty())
    return failure();

  rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
  return success();
}

//===----------------------------------------------------------------------===//
// Barrier.
//===----------------------------------------------------------------------===//

LogicalResult GPUBarrierConversion::matchAndRewrite(
    gpu::BarrierOp barrierOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  MLIRContext *context = getContext();
  // Both execution and memory scope should be workgroup.
  auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
  // Require acquire and release memory semantics for workgroup memory.
  auto memorySemantics = spirv::MemorySemanticsAttr::get(
      context, spirv::MemorySemantics::WorkgroupMemory |
                   spirv::MemorySemantics::AcquireRelease);
  rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
                                                       memorySemantics);
  return success();
}

//===----------------------------------------------------------------------===//
// Shuffle
//===----------------------------------------------------------------------===//

LogicalResult GPUShuffleConversion::matchAndRewrite(
    gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Require the shuffle width to be the same as the target's subgroup size,
  // given that for SPIR-V non-uniform subgroup ops, we cannot select
  // participating invocations.
  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
  unsigned subgroupSize =
      targetEnv.getAttr().getResourceLimits().getSubgroupSize();
  // The width operand must be a compile-time constant equal to that size.
  IntegerAttr widthAttr;
  if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
      widthAttr.getValue().getZExtValue() != subgroupSize)
    return rewriter.notifyMatchFailure(
        shuffleOp, "shuffle width and target subgroup size mismatch");

  Location loc = shuffleOp.getLoc();
  // gpu.shuffle's second result ("valid") is always true here since every
  // invocation in the subgroup participates.
  Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
                                            shuffleOp.getLoc(), rewriter);
  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
  Value result;

  switch (shuffleOp.getMode()) {
  case gpu::ShuffleMode::XOR:
    result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
        loc, scope, adaptor.getValue(), adaptor.getOffset());
    break;
  case gpu::ShuffleMode::IDX:
    result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
        loc, scope, adaptor.getValue(), adaptor.getOffset());
    break;
  default:
    // Other modes (e.g. up/down) are not lowered yet.
    return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
  }

  rewriter.replaceOp(shuffleOp, {result, trueVal});
  return success();
}

//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//

/// Creates a Reduce group operation of type `UniformOp` (for uniform input)
/// or `NonUniformOp` otherwise, at workgroup scope when `isGroup` is set and
/// subgroup scope otherwise.
template <typename UniformOp, typename NonUniformOp>
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
                                     Value arg, bool isGroup, bool isUniform) {
  Type type = arg.getType();
  auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
                                           isGroup ? spirv::Scope::Workgroup
                                                   : spirv::Scope::Subgroup);
  auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
                                                spirv::GroupOperation::Reduce);
  if (isUniform) {
    return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
        .getResult();
  }
  // Non-uniform ops take an optional cluster size; none is passed here.
  return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
      .getResult();
}

/// Dispatches a gpu reduction of kind `opType` on `arg` to the matching
/// SPIR-V group op via a (reduction kind, element type) handler table.
/// Returns std::nullopt when the combination is unsupported.
static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
                                                Location loc, Value arg,
                                                gpu::AllReduceOperation opType,
                                                bool isGroup, bool isUniform) {
  enum class ElemType { Float, Boolean, Integer };
  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
  struct OpHandler {
    gpu::AllReduceOperation kind;
    ElemType elemType;
    FuncT func;
  };

  Type type = arg.getType();
  ElemType elementType;
  if (isa<FloatType>(type)) {
    elementType = ElemType::Float;
  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
    elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
                                                       : ElemType::Integer;
  } else {
    return std::nullopt;
  }

  // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
  // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
  // reduction ops. We should account possible precision requirements in this
  // conversion.

  using ReduceType = gpu::AllReduceOperation;
  const OpHandler handlers[] = {
      {ReduceType::ADD, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupIAddOp,
                                spirv::GroupNonUniformIAddOp>},
      {ReduceType::ADD, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFAddOp,
                                spirv::GroupNonUniformFAddOp>},
      {ReduceType::MUL, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
                                spirv::GroupNonUniformIMulOp>},
      {ReduceType::MUL, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
                                spirv::GroupNonUniformFMulOp>},
      {ReduceType::MINUI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupUMinOp,
                                spirv::GroupNonUniformUMinOp>},
      {ReduceType::MINSI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupSMinOp,
                                spirv::GroupNonUniformSMinOp>},
      {ReduceType::MINNUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                spirv::GroupNonUniformFMinOp>},
      {ReduceType::MAXUI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupUMaxOp,
                                spirv::GroupNonUniformUMaxOp>},
      {ReduceType::MAXSI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupSMaxOp,
                                spirv::GroupNonUniformSMaxOp>},
      {ReduceType::MAXNUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                spirv::GroupNonUniformFMaxOp>},
      {ReduceType::MINIMUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                spirv::GroupNonUniformFMinOp>},
      {ReduceType::MAXIMUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                spirv::GroupNonUniformFMaxOp>}};

  for (const OpHandler &handler : handlers)
    if (handler.kind == opType && elementType == handler.elemType)
      return handler.func(builder, loc, arg, isGroup, isUniform);

  return std::nullopt;
}

/// Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
557 class GPUAllReduceConversion final 558 : public OpConversionPattern<gpu::AllReduceOp> { 559 public: 560 using OpConversionPattern::OpConversionPattern; 561 562 LogicalResult 563 matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor, 564 ConversionPatternRewriter &rewriter) const override { 565 auto opType = op.getOp(); 566 567 // gpu.all_reduce can have either reduction op attribute or reduction 568 // region. Only attribute version is supported. 569 if (!opType) 570 return failure(); 571 572 auto result = 573 createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType, 574 /*isGroup*/ true, op.getUniform()); 575 if (!result) 576 return failure(); 577 578 rewriter.replaceOp(op, *result); 579 return success(); 580 } 581 }; 582 583 /// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op. 584 class GPUSubgroupReduceConversion final 585 : public OpConversionPattern<gpu::SubgroupReduceOp> { 586 public: 587 using OpConversionPattern::OpConversionPattern; 588 589 LogicalResult 590 matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, 591 ConversionPatternRewriter &rewriter) const override { 592 if (op.getClusterSize()) 593 return rewriter.notifyMatchFailure( 594 op, "lowering for clustered reduce not implemented"); 595 596 if (!isa<spirv::ScalarType>(adaptor.getValue().getType())) 597 return rewriter.notifyMatchFailure(op, "reduction type is not a scalar"); 598 599 auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), 600 adaptor.getOp(), 601 /*isGroup=*/false, adaptor.getUniform()); 602 if (!result) 603 return failure(); 604 605 rewriter.replaceOp(op, *result); 606 return success(); 607 } 608 }; 609 610 // Formulate a unique variable/constant name after 611 // searching in the module for existing variable/constant names. 612 // This is to avoid name collision with existing variables. 613 // Example: printfMsg0, printfMsg1, printfMsg2, ... 
static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
  std::string name;
  unsigned number = 0;

  // Keep appending an increasing counter until the symbol is unused.
  do {
    name.clear();
    name = (prefix + llvm::Twine(number++)).str();
  } while (moduleOp.lookupSymbol(name));

  return name;
}

/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.

LogicalResult GPUPrintfConversion::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  Location loc = gpuPrintfOp.getLoc();

  // This lowering only applies inside a spirv.module (CL flavor).
  auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
  if (!moduleOp)
    return failure();

  // SPIR-V global variable is used to initialize printf
  // format string value, if there are multiple printf messages,
  // each global var needs to be created with a unique name.
  std::string globalVarName = makeVarName(moduleOp, llvm::Twine("printfMsg"));
  spirv::GlobalVariableOp globalVar;

  IntegerType i8Type = rewriter.getI8Type();
  IntegerType i32Type = rewriter.getI32Type();

  // Each character of printf format string is
  // stored as a spec constant. We need to create
  // unique name for this spec constant like
  // @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
  // for existing spec constant names.
  auto createSpecConstant = [&](unsigned value) {
    auto attr = rewriter.getI8IntegerAttr(value);
    std::string specCstName =
        makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");

    return rewriter.create<spirv::SpecConstantOp>(
        loc, rewriter.getStringAttr(specCstName), attr);
  };
  {
    // Spec constants and the global variable must live at module level, so
    // temporarily move the insertion point to the nearest symbol table's
    // entry block.
    Operation *parent =
        SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());

    ConversionPatternRewriter::InsertionGuard guard(rewriter);

    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(
        &entryBlock); // insertion point at module level

    // Create Constituents with SpecConstant by scanning format string
    // Each character of format string is stored as a spec constant
    // and then these spec constants are used to create a
    // SpecConstantCompositeOp.
    llvm::SmallString<20> formatString(adaptor.getFormat());
    formatString.push_back('\0'); // Null terminate for C.
    SmallVector<Attribute, 4> constituents;
    for (char c : formatString) {
      spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
      constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
    }

    // Create SpecConstantCompositeOp to initialize the global variable
    size_t contentSize = constituents.size();
    auto globalType = spirv::ArrayType::get(i8Type, contentSize);
    spirv::SpecConstantCompositeOp specCstComposite;
    // There will be one SpecConstantCompositeOp per printf message/global var,
    // so no need do lookup for existing ones.
    std::string specCstCompositeName =
        (llvm::Twine(globalVarName) + "_scc").str();

    specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
        loc, TypeAttr::get(globalType),
        rewriter.getStringAttr(specCstCompositeName),
        rewriter.getArrayAttr(constituents));

    auto ptrType = spirv::PointerType::get(
        globalType, spirv::StorageClass::UniformConstant);

    // Define a GlobalVarOp initialized using specialized constants
    // that is used to specify the printf format string
    // to be passed to the SPIRV CLPrintfOp.
    globalVar = rewriter.create<spirv::GlobalVariableOp>(
        loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));

    globalVar->setAttr("Constant", rewriter.getUnitAttr());
  }
  // Get SSA value of Global variable and create pointer to i8 to point to
  // the format string.
  Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
  Value fmtStr = rewriter.create<spirv::BitcastOp>(
      loc,
      spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
      globalPtr);

  // Get printf arguments.
  auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());

  rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);

  // Need to erase the gpu.printf op as gpu.printf does not use result vs
  // spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V
  // printf op.
  rewriter.eraseOp(gpuPrintfOp);

  return success();
}

//===----------------------------------------------------------------------===//
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//

/// Registers all GPU-to-SPIR-V conversion patterns defined in this file.
/// WorkGroupSizeConversion carries a higher benefit than the generic
/// gpu.block_dim lowering, so it wins whenever the static workgroup size is
/// available from the entry-point ABI.
void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                      RewritePatternSet &patterns) {
  patterns.add<
      GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
      GPUReturnOpConversion, GPUShuffleConversion,
      LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
      LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
      LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
      LaunchConfigConversion<gpu::ThreadIdOp,
                             spirv::BuiltIn::LocalInvocationId>,
      LaunchConfigConversion<gpu::GlobalIdOp,
                             spirv::BuiltIn::GlobalInvocationId>,
      SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
                                      spirv::BuiltIn::SubgroupId>,
      SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
                                      spirv::BuiltIn::NumSubgroups>,
      SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
                                      spirv::BuiltIn::SubgroupSize>,
      WorkGroupSizeConversion, GPUAllReduceConversion,
      GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
                                                        patterns.getContext());
}