//===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert GPU dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>

using namespace mlir;

static constexpr const char kSPIRVModule[] = "__spv__";

namespace {
/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
/// builtin variables.
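///
/// As an illustrative sketch (not verbatim pattern output), on a Vulkan
/// (Shader) target with 32-bit indices, `%0 = gpu.block_id x` lowers to
/// roughly:
///
///   %addr = spirv.mlir.addressof @__builtin_var_WorkgroupId__
///             : !spirv.ptr<vector<3xi32>, Input>
///   %vec  = spirv.Load "Input" %addr : vector<3xi32>
///   %0    = spirv.CompositeExtract %vec[0 : i32] : vector<3xi32>
///
/// With 64-bit indices, a `spirv.UConvert` to i64 is appended.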
template <typename SourceOp, spirv::BuiltIn builtin>
class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern lowering subgroup size/id to loading SPIR-V invocation
/// builtin variables.
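///
/// For example (a sketch, assuming a 32-bit index type so that no
/// `spirv.UConvert` is needed), `%0 = gpu.subgroup_id : index` becomes
/// roughly:
///
///   %addr = spirv.mlir.addressof @__builtin_var_SubgroupId__
///             : !spirv.ptr<i32, Input>
///   %0    = spirv.Load "Input" %addr : i32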
template <typename SourceOp, spirv::BuiltIn builtin>
class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// This pattern is separate because in Vulkan the workgroup size is exposed
/// to shaders via a constant with the WorkgroupSize decoration, so we cannot
/// generate a builtin variable here; instead, the information in the
/// `spirv.entry_point_abi` attribute on the surrounding FuncOp is used to
/// replace the gpu::BlockDimOp.
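///
/// For example (illustrative values), with
/// `spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>`
/// on the surrounding function, `gpu.block_dim y` folds to a constant:
///
///   %0 = spirv.Constant 4 : i32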
class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
public:
  WorkGroupSizeConversion(const TypeConverter &typeConverter,
                          MLIRContext *context)
      : OpConversionPattern(typeConverter, context, /*benefit=*/10) {}

  LogicalResult
  matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a kernel function in GPU dialect within a spirv.module.
class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
public:
  using OpConversionPattern<gpu::GPUFuncOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  SmallVector<int32_t, 3> workGroupSizeAsInt32;
};

/// Pattern to convert a gpu.module to a spirv.module.
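///
/// For example (a sketch; the addressing and memory models are deduced from
/// the `spirv.target_env` attribute), `gpu.module @kernels { ... }` becomes
/// roughly:
///
///   spirv.module @__spv__kernels Logical GLSL450 { ... }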
class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
public:
  using OpConversionPattern<gpu::GPUModuleOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.return into a SPIR-V return.
// TODO: This can go to DRR when GPU return has operands.
class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
public:
  using OpConversionPattern<gpu::ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.barrier op into a spirv.ControlBarrier op.
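///
/// For example, `gpu.barrier` lowers to roughly (the printed order of the
/// semantics flags may differ):
///
///   spirv.ControlBarrier <Workgroup>, <Workgroup>,
///                        <WorkgroupMemory|AcquireRelease>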
class GPUBarrierConversion final : public OpConversionPattern<gpu::BarrierOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.shuffle op into a spirv.GroupNonUniformShuffle op.
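///
/// For example (a sketch, assuming the constant shuffle width equals the
/// target subgroup size):
///
///   %r, %valid = gpu.shuffle xor %val, %offset, %width : f32
///
/// becomes roughly:
///
///   %true = spirv.Constant true
///   %r    = spirv.GroupNonUniformShuffleXor <Subgroup> %val, %offset
///             : f32, i32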
class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace

//===----------------------------------------------------------------------===//
// Builtins.
//===----------------------------------------------------------------------===//

template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
    SourceOp op, typename SourceOp::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
  Type indexType = typeConverter->getIndexType();

  // For Vulkan, these SPIR-V builtin variables are required by the spec to be
  // of type vector<3xi32>:
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html
  //
  // For OpenCL, their type depends on the Physical32/Physical64 addressing
  // model:
  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
  bool forShader =
      typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
  Type builtinType = forShader ? rewriter.getIntegerType(32) : indexType;

  Value vector =
      spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
  Value dim = rewriter.create<spirv::CompositeExtractOp>(
      op.getLoc(), builtinType, vector,
      rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
  if (forShader && builtinType != indexType)
    dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
  rewriter.replaceOp(op, dim);
  return success();
}

template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult
SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
    SourceOp op, typename SourceOp::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
  Type indexType = typeConverter->getIndexType();
  Type i32Type = rewriter.getIntegerType(32);

  // For Vulkan, these SPIR-V builtin variables are required by the spec to be
  // scalars of type i32:
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html
  //
  // For OpenCL, they are also required to be i32:
  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
  Value builtinValue =
      spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
  if (i32Type != indexType)
    builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
                                                      builtinValue);
  rewriter.replaceOp(op, builtinValue);
  return success();
}

LogicalResult WorkGroupSizeConversion::matchAndRewrite(
    gpu::BlockDimOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  DenseI32ArrayAttr workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
  if (!workGroupSizeAttr)
    return failure();

  int val =
      workGroupSizeAttr.asArrayRef()[static_cast<int32_t>(op.getDimension())];
  auto convertedType =
      getTypeConverter()->convertType(op.getResult().getType());
  if (!convertedType)
    return failure();
  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
      op, convertedType, IntegerAttr::get(convertedType, val));
  return success();
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

// Legalizes a GPU function as an entry SPIR-V function.
static spirv::FuncOp
lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
                     ConversionPatternRewriter &rewriter,
                     spirv::EntryPointABIAttr entryPointInfo,
                     ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
  auto fnType = funcOp.getFunctionType();
  if (fnType.getNumResults()) {
    funcOp.emitError("SPIR-V lowering only supports entry functions "
                     "with no return values right now");
    return nullptr;
  }
  if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
    funcOp.emitError(
        "lowering as entry functions requires ABI info for all arguments "
        "or none of them");
    return nullptr;
  }
  // Update the signature to valid SPIR-V types and add the ABI
  // attributes. These will be "materialized" by using the
  // LowerABIAttributesPass.
  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
  {
    for (const auto &argType :
         enumerate(funcOp.getFunctionType().getInputs())) {
      auto convertedType = typeConverter.convertType(argType.value());
      if (!convertedType)
        return nullptr;
      signatureConverter.addInputs(argType.index(), convertedType);
    }
  }
  auto newFuncOp = rewriter.create<spirv::FuncOp>(
      funcOp.getLoc(), funcOp.getName(),
      rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                               std::nullopt));
  for (const auto &namedAttr : funcOp->getAttrs()) {
    if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
        namedAttr.getName() == SymbolTable::getSymbolAttrName())
      continue;
    newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                              newFuncOp.end());
  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
                                         &signatureConverter)))
    return nullptr;
  rewriter.eraseOp(funcOp);

  // Set the attributes for the arguments and the function.
  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
  for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
    newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
  }
  newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);

  return newFuncOp;
}

/// Populates `argABI` with spirv.interface_var_abi attributes for lowering
/// gpu.func to spirv.func if no arguments have the attributes set
/// already. Returns failure if any argument has the ABI attribute set already.
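///
/// For example (illustrative, not verbatim), a scalar `f32` kernel argument
/// at index 1 gets the default attribute
///
///   spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1), StorageBuffer>
///
/// i.e. descriptor set 0, binding equal to the argument index, and an
/// explicit StorageBuffer storage class because the argument is a scalar.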
static LogicalResult
getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp,
                   SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
  if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
    return success();

  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
    if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
            argIndex, spirv::getInterfaceVarABIAttrName()))
      return failure();
    // Vulkan's interface variable requirements need scalars to be wrapped in
    // a struct. The struct is held in a storage buffer.
    std::optional<spirv::StorageClass> sc;
    if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
      sc = spirv::StorageClass::StorageBuffer;
    argABI.push_back(
        spirv::getInterfaceVarABIAttr(0, argIndex, sc, funcOp.getContext()));
  }
  return success();
}

LogicalResult GPUFuncOpConversion::matchAndRewrite(
    gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!gpu::GPUDialect::isKernel(funcOp))
    return failure();

  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
  SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
  if (failed(
          getDefaultABIAttrs(typeConverter->getTargetEnv(), funcOp, argABI))) {
    argABI.clear();
    for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
      // If the ABI is already specified, use it.
      auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
          argIndex, spirv::getInterfaceVarABIAttrName());
      if (!abiAttr) {
        funcOp.emitRemark(
            "match failure: missing 'spirv.interface_var_abi' attribute at "
            "argument ")
            << argIndex;
        return failure();
      }
      argABI.push_back(abiAttr);
    }
  }

  auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
  if (!entryPointAttr) {
    funcOp.emitRemark(
        "match failure: missing 'spirv.entry_point_abi' attribute");
    return failure();
  }
  spirv::FuncOp newFuncOp = lowerAsEntryFunction(
      funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
  if (!newFuncOp)
    return failure();
  newFuncOp->removeAttr(
      rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
  return success();
}

//===----------------------------------------------------------------------===//
// ModuleOp with gpu.module.
//===----------------------------------------------------------------------===//

LogicalResult GPUModuleConversion::matchAndRewrite(
    gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
  const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
  spirv::AddressingModel addressingModel = spirv::getAddressingModel(
      targetEnv, typeConverter->getOptions().use64bitIndex);
  FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
  if (failed(memoryModel))
    return moduleOp.emitRemark(
        "cannot deduce memory model from 'spirv.target_env'");

  // Prefix the module name to avoid symbol name conflicts.
  std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
  auto spvModule = rewriter.create<spirv::ModuleOp>(
      moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
      StringRef(spvModuleName));

  // Move the region from the module op into the SPIR-V module.
  Region &spvModuleRegion = spvModule.getRegion();
  rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
                              spvModuleRegion.begin());
  // The spirv.module build method adds a block; remove it.
  rewriter.eraseBlock(&spvModuleRegion.back());

  // Some patterns call `lookupTargetEnv` during conversion; they would fail
  // after GPUModuleConversion runs if the `TargetEnv` attribute were not
  // preserved. Copy the TargetEnvAttr only if it is attached directly to the
  // GPUModuleOp.
  if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
          spirv::getTargetEnvAttrName()))
    spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);

  rewriter.eraseOp(moduleOp);
  return success();
}

//===----------------------------------------------------------------------===//
// GPU return inside kernel functions to SPIR-V return.
//===----------------------------------------------------------------------===//

LogicalResult GPUReturnOpConversion::matchAndRewrite(
    gpu::ReturnOp returnOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!adaptor.getOperands().empty())
    return failure();

  rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
  return success();
}

//===----------------------------------------------------------------------===//
// Barrier.
//===----------------------------------------------------------------------===//

LogicalResult GPUBarrierConversion::matchAndRewrite(
    gpu::BarrierOp barrierOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  MLIRContext *context = getContext();
  // Both execution and memory scope should be workgroup.
  auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
  // Require acquire and release memory semantics for workgroup memory.
  auto memorySemantics = spirv::MemorySemanticsAttr::get(
      context, spirv::MemorySemantics::WorkgroupMemory |
                   spirv::MemorySemantics::AcquireRelease);
  rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
                                                       memorySemantics);
  return success();
}

//===----------------------------------------------------------------------===//
// Shuffle
//===----------------------------------------------------------------------===//

LogicalResult GPUShuffleConversion::matchAndRewrite(
    gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Require the shuffle width to be the same as the target's subgroup size,
  // given that for SPIR-V non-uniform subgroup ops, we cannot select
  // participating invocations.
  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
  unsigned subgroupSize =
      targetEnv.getAttr().getResourceLimits().getSubgroupSize();
  IntegerAttr widthAttr;
  if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
      widthAttr.getValue().getZExtValue() != subgroupSize)
    return rewriter.notifyMatchFailure(
        shuffleOp, "shuffle width and target subgroup size mismatch");

  Location loc = shuffleOp.getLoc();
  Value trueVal =
      spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
  Value result;

  switch (shuffleOp.getMode()) {
  case gpu::ShuffleMode::XOR:
    result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
        loc, scope, adaptor.getValue(), adaptor.getOffset());
    break;
  case gpu::ShuffleMode::IDX:
    result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
        loc, scope, adaptor.getValue(), adaptor.getOffset());
    break;
  default:
    return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
  }

  rewriter.replaceOp(shuffleOp, {result, trueVal});
  return success();
}

//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//

template <typename UniformOp, typename NonUniformOp>
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
                                     Value arg, bool isGroup, bool isUniform) {
  Type type = arg.getType();
  auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
                                           isGroup ? spirv::Scope::Workgroup
                                                   : spirv::Scope::Subgroup);
  auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
                                                spirv::GroupOperation::Reduce);
  if (isUniform) {
    return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
        .getResult();
  }
  return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
      .getResult();
}
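
// For example (a sketch), a uniform workgroup-scope float-add reduction
// produces roughly
//   %r = spirv.GroupFAdd <Workgroup> <Reduce> %arg : f32
// while the non-uniform variant produces roughly
//   %r = spirv.GroupNonUniformFAdd <Workgroup> <Reduce> %arg : f32
// (with no cluster size operand).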

static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
                                                Location loc, Value arg,
                                                gpu::AllReduceOperation opType,
                                                bool isGroup, bool isUniform) {
  enum class ElemType { Float, Boolean, Integer };
  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
  struct OpHandler {
    gpu::AllReduceOperation kind;
    ElemType elemType;
    FuncT func;
  };

  Type type = arg.getType();
  ElemType elementType;
  if (isa<FloatType>(type)) {
    elementType = ElemType::Float;
  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
    elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
                                                       : ElemType::Integer;
  } else {
    return std::nullopt;
  }

  // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
  // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
  // reduction ops. We should account for possible precision requirements in
  // this conversion.

  using ReduceType = gpu::AllReduceOperation;
  const OpHandler handlers[] = {
      {ReduceType::ADD, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupIAddOp,
                                spirv::GroupNonUniformIAddOp>},
      {ReduceType::ADD, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFAddOp,
                                spirv::GroupNonUniformFAddOp>},
      {ReduceType::MUL, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
                                spirv::GroupNonUniformIMulOp>},
      {ReduceType::MUL, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
                                spirv::GroupNonUniformFMulOp>},
      {ReduceType::MINUI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupUMinOp,
                                spirv::GroupNonUniformUMinOp>},
      {ReduceType::MINSI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupSMinOp,
                                spirv::GroupNonUniformSMinOp>},
      {ReduceType::MINNUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                spirv::GroupNonUniformFMinOp>},
      {ReduceType::MAXUI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupUMaxOp,
                                spirv::GroupNonUniformUMaxOp>},
      {ReduceType::MAXSI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupSMaxOp,
                                spirv::GroupNonUniformSMaxOp>},
      {ReduceType::MAXNUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                spirv::GroupNonUniformFMaxOp>},
      {ReduceType::MINIMUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                spirv::GroupNonUniformFMinOp>},
      {ReduceType::MAXIMUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                spirv::GroupNonUniformFMaxOp>}};

  for (const OpHandler &handler : handlers)
    if (handler.kind == opType && elementType == handler.elemType)
      return handler.func(builder, loc, arg, isGroup, isUniform);

  return std::nullopt;
}

/// Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
class GPUAllReduceConversion final
    : public OpConversionPattern<gpu::AllReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto opType = op.getOp();

    // gpu.all_reduce can have either a reduction op attribute or a reduction
    // region. Only the attribute version is supported here.
    if (!opType)
      return failure();

    auto result =
        createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
                            /*isGroup=*/true, op.getUniform());
    if (!result)
      return failure();

    rewriter.replaceOp(op, *result);
    return success();
  }
};

/// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
class GPUSubgroupReduceConversion final
    : public OpConversionPattern<gpu::SubgroupReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
      return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");

    auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
                                      adaptor.getOp(),
                                      /*isGroup=*/false, adaptor.getUniform());
    if (!result)
      return failure();

    rewriter.replaceOp(op, *result);
    return success();
  }
};

// Formulate a unique variable/constant name after searching the module for
// existing variable/constant names. This avoids name collisions with
// existing symbols. Example: printfMsg0, printfMsg1, printfMsg2, ...
static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
  std::string name;
  unsigned number = 0;

  do {
    name = (prefix + llvm::Twine(number++)).str();
  } while (moduleOp.lookupSymbol(name));

  return name;
}

/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
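///
/// As an illustrative sketch (names and exact generated IR may differ),
/// `gpu.printf "Hello"` produces: one spirv.SpecConstant per character of
/// the null-terminated format string (@printfMsg0_sc0, @printfMsg0_sc1, ...),
/// a spirv.SpecConstantComposite @printfMsg0_scc collecting them into an
/// !spirv.array<6 x i8>, a spirv.GlobalVariable @printfMsg0 initialized from
/// that composite, and finally a spirv.CL.printf op taking the global's
/// address bitcast to !spirv.ptr<i8, UniformConstant> plus the printf
/// arguments.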
LogicalResult GPUPrintfConversion::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = gpuPrintfOp.getLoc();

  auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
  if (!moduleOp)
    return failure();

  // A SPIR-V global variable is used to hold the printf format string. If
  // there are multiple printf messages, each global variable needs to be
  // created with a unique name.
  std::string globalVarName = makeVarName(moduleOp, llvm::Twine("printfMsg"));
  spirv::GlobalVariableOp globalVar;

  IntegerType i8Type = rewriter.getI8Type();
  IntegerType i32Type = rewriter.getI32Type();

  // Each character of the printf format string is stored as a spec constant.
  // We need to create a unique name for each spec constant, like
  // @printfMsg0_sc0, @printfMsg0_sc1, ..., by searching the module for
  // existing spec constant names.
  auto createSpecConstant = [&](unsigned value) {
    auto attr = rewriter.getI8IntegerAttr(value);
    std::string specCstName =
        makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");

    return rewriter.create<spirv::SpecConstantOp>(
        loc, rewriter.getStringAttr(specCstName), attr);
  };
  {
    Operation *parent =
        SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());

    ConversionPatternRewriter::InsertionGuard guard(rewriter);

    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(
        &entryBlock); // insertion point at module level

    // Create the constituents by scanning the format string: each character
    // is stored as a spec constant, and these spec constants are then used
    // to create a SpecConstantCompositeOp.
    llvm::SmallString<20> formatString(adaptor.getFormat());
    formatString.push_back('\0'); // Null terminate for C.
    SmallVector<Attribute, 4> constituents;
    for (char c : formatString) {
      spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
      constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
    }

    // Create a SpecConstantCompositeOp to initialize the global variable.
    size_t contentSize = constituents.size();
    auto globalType = spirv::ArrayType::get(i8Type, contentSize);
    spirv::SpecConstantCompositeOp specCstComposite;
    // There will be one SpecConstantCompositeOp per printf message/global
    // variable, so there is no need to look up existing ones.
    std::string specCstCompositeName =
        (llvm::Twine(globalVarName) + "_scc").str();

    specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
        loc, TypeAttr::get(globalType),
        rewriter.getStringAttr(specCstCompositeName),
        rewriter.getArrayAttr(constituents));

    auto ptrType = spirv::PointerType::get(
        globalType, spirv::StorageClass::UniformConstant);

    // Define a GlobalVariableOp, initialized from the spec constant
    // composite, that holds the printf format string to be passed to the
    // SPIR-V CLPrintfOp.
    globalVar = rewriter.create<spirv::GlobalVariableOp>(
        loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));

    globalVar->setAttr("Constant", rewriter.getUnitAttr());
  }
  // Get the SSA value of the global variable and create an i8 pointer to the
  // format string.
  Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
  Value fmtStr = rewriter.create<spirv::BitcastOp>(
      loc,
      spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
      globalPtr);

  // Get the printf arguments.
  auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());

  rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);

  // gpu.printf has no result, while spirv::CLPrintfOp returns an i32, so we
  // cannot use replaceOpWithNewOp; erase the gpu.printf op instead.
  rewriter.eraseOp(gpuPrintfOp);

  return success();
}

//===----------------------------------------------------------------------===//
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//

void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                      RewritePatternSet &patterns) {
  patterns.add<
      GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
      GPUReturnOpConversion, GPUShuffleConversion,
      LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
      LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
      LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
      LaunchConfigConversion<gpu::ThreadIdOp,
                             spirv::BuiltIn::LocalInvocationId>,
      LaunchConfigConversion<gpu::GlobalIdOp,
                             spirv::BuiltIn::GlobalInvocationId>,
      SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
                                      spirv::BuiltIn::SubgroupId>,
      SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
                                      spirv::BuiltIn::NumSubgroups>,
      SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
                                      spirv::BuiltIn::SubgroupSize>,
      WorkGroupSizeConversion, GPUAllReduceConversion,
      GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
                                                        patterns.getContext());
}
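
// A minimal driver sketch showing how these patterns are typically used (the
// real pass lives in GPUToSPIRVPass.cpp; the details below are illustrative
// assumptions, not part of this file's API):
//
//   spirv::TargetEnvAttr targetAttr =
//       spirv::lookupTargetEnvOrDefault(gpuModule);
//   std::unique_ptr<ConversionTarget> target =
//       SPIRVConversionTarget::get(targetAttr);
//   SPIRVTypeConverter typeConverter(targetAttr);
//   RewritePatternSet patterns(gpuModule.getContext());
//   populateGPUToSPIRVPatterns(typeConverter, patterns);
//   if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
//     signalPassFailure();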