xref: /llvm-project/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (revision 4adeb6cf556df10da668916b22eb39d3f1313e8a)
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16 #include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
20 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
21 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/FormatVariadic.h"
28 
29 #define DEBUG_TYPE "spirv-to-llvm-pattern"
30 
31 using namespace mlir;
32 
33 //===----------------------------------------------------------------------===//
34 // Utility functions
35 //===----------------------------------------------------------------------===//
36 
37 /// Returns true if the given type is a signed integer or vector type.
38 static bool isSignedIntegerOrVector(Type type) {
39   if (type.isSignedInteger())
40     return true;
41   if (auto vecType = dyn_cast<VectorType>(type))
42     return vecType.getElementType().isSignedInteger();
43   return false;
44 }
45 
46 /// Returns true if the given type is an unsigned integer or vector type
47 static bool isUnsignedIntegerOrVector(Type type) {
48   if (type.isUnsignedInteger())
49     return true;
50   if (auto vecType = dyn_cast<VectorType>(type))
51     return vecType.getElementType().isUnsignedInteger();
52   return false;
53 }
54 
55 /// Returns the width of an integer or of the element type of an integer vector,
56 /// if applicable.
57 static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
58   if (auto intType = dyn_cast<IntegerType>(type))
59     return intType.getWidth();
60   if (auto vecType = dyn_cast<VectorType>(type))
61     if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
62       return intType.getWidth();
63   return std::nullopt;
64 }
65 
66 /// Returns the bit width of integer, float or vector of float or integer values
67 static unsigned getBitWidth(Type type) {
68   assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
69          "bitwidth is not supported for this type");
70   if (type.isIntOrFloat())
71     return type.getIntOrFloatBitWidth();
72   auto vecType = dyn_cast<VectorType>(type);
73   auto elementType = vecType.getElementType();
74   assert(elementType.isIntOrFloat() &&
75          "only integers and floats have a bitwidth");
76   return elementType.getIntOrFloatBitWidth();
77 }
78 
79 /// Returns the bit width of LLVMType integer or vector.
80 static unsigned getLLVMTypeBitWidth(Type type) {
81   return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
82                                 ? LLVM::getVectorElementType(type)
83                                 : type))
84       .getWidth();
85 }
86 
87 /// Creates `IntegerAttribute` with all bits set for given type
88 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
89   if (auto vecType = dyn_cast<VectorType>(type)) {
90     auto integerType = cast<IntegerType>(vecType.getElementType());
91     return builder.getIntegerAttr(integerType, -1);
92   }
93   auto integerType = cast<IntegerType>(type);
94   return builder.getIntegerAttr(integerType, -1);
95 }
96 
97 /// Creates `llvm.mlir.constant` with all bits set for the given type.
98 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
99                                       PatternRewriter &rewriter) {
100   if (isa<VectorType>(srcType)) {
101     return rewriter.create<LLVM::ConstantOp>(
102         loc, dstType,
103         SplatElementsAttr::get(cast<ShapedType>(srcType),
104                                minusOneIntegerAttribute(srcType, rewriter)));
105   }
106   return rewriter.create<LLVM::ConstantOp>(
107       loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
108 }
109 
110 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
111 static Value createFPConstant(Location loc, Type srcType, Type dstType,
112                               PatternRewriter &rewriter, double value) {
113   if (auto vecType = dyn_cast<VectorType>(srcType)) {
114     auto floatType = cast<FloatType>(vecType.getElementType());
115     return rewriter.create<LLVM::ConstantOp>(
116         loc, dstType,
117         SplatElementsAttr::get(vecType,
118                                rewriter.getFloatAttr(floatType, value)));
119   }
120   auto floatType = cast<FloatType>(srcType);
121   return rewriter.create<LLVM::ConstantOp>(
122       loc, dstType, rewriter.getFloatAttr(floatType, value));
123 }
124 
125 /// Utility function for bitfield ops:
126 ///   - `BitFieldInsert`
127 ///   - `BitFieldSExtract`
128 ///   - `BitFieldUExtract`
129 /// Truncates or extends the value. If the bitwidth of the value is the same as
130 /// `llvmType` bitwidth, the value remains unchanged.
131 static Value optionallyTruncateOrExtend(Location loc, Value value,
132                                         Type llvmType,
133                                         PatternRewriter &rewriter) {
134   auto srcType = value.getType();
135   unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
136   unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
137                                ? getLLVMTypeBitWidth(srcType)
138                                : getBitWidth(srcType);
139 
140   if (valueBitWidth < targetBitWidth)
141     return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
142   // If the bit widths of `Count` and `Offset` are greater than the bit width
143   // of the target type, they are truncated. Truncation is safe since `Count`
144   // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
145   // both values can be expressed in 8 bits.
146   if (valueBitWidth > targetBitWidth)
147     return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
148   return value;
149 }
150 
151 /// Broadcasts the value to vector with `numElements` number of elements.
152 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
153                        const TypeConverter &typeConverter,
154                        ConversionPatternRewriter &rewriter) {
155   auto vectorType = VectorType::get(numElements, toBroadcast.getType());
156   auto llvmVectorType = typeConverter.convertType(vectorType);
157   auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
158   Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
159   for (unsigned i = 0; i < numElements; ++i) {
160     auto index = rewriter.create<LLVM::ConstantOp>(
161         loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
162     broadcasted = rewriter.create<LLVM::InsertElementOp>(
163         loc, llvmVectorType, broadcasted, toBroadcast, index);
164   }
165   return broadcasted;
166 }
167 
168 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
169 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
170                                  const TypeConverter &typeConverter,
171                                  ConversionPatternRewriter &rewriter) {
172   if (auto vectorType = dyn_cast<VectorType>(srcType)) {
173     unsigned numElements = vectorType.getNumElements();
174     return broadcast(loc, value, numElements, typeConverter, rewriter);
175   }
176   return value;
177 }
178 
179 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
180 /// `BitFieldUExtract`.
181 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
182 /// a vector type, construct a vector that has:
183 ///  - same number of elements as `Base`
184 ///  - each element has the type that is the same as the type of `Offset` or
185 ///    `Count`
186 ///  - each element has the same value as `Offset` or `Count`
187 /// Then cast `Offset` and `Count` if their bit width is different
188 /// from `Base` bit width.
189 static Value processCountOrOffset(Location loc, Value value, Type srcType,
190                                   Type dstType, const TypeConverter &converter,
191                                   ConversionPatternRewriter &rewriter) {
192   Value broadcasted =
193       optionallyBroadcast(loc, value, srcType, converter, rewriter);
194   return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
195 }
196 
197 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
198 /// offset to LLVM struct. Otherwise, the conversion is not supported.
199 static Type convertStructTypeWithOffset(spirv::StructType type,
200                                         const TypeConverter &converter) {
201   if (type != VulkanLayoutUtils::decorateType(type))
202     return nullptr;
203 
204   SmallVector<Type> elementsVector;
205   if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
206     return nullptr;
207   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
208                                           /*isPacked=*/false);
209 }
210 
211 /// Converts SPIR-V struct with no offset to packed LLVM struct.
212 static Type convertStructTypePacked(spirv::StructType type,
213                                     const TypeConverter &converter) {
214   SmallVector<Type> elementsVector;
215   if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
216     return nullptr;
217   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
218                                           /*isPacked=*/true);
219 }
220 
221 /// Creates LLVM dialect constant with the given value.
222 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
223                                  unsigned value) {
224   return rewriter.create<LLVM::ConstantOp>(
225       loc, IntegerType::get(rewriter.getContext(), 32),
226       rewriter.getIntegerAttr(rewriter.getI32Type(), value));
227 }
228 
229 /// Utility for `spirv.Load` and `spirv.Store` conversion.
230 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
231                                             ConversionPatternRewriter &rewriter,
232                                             const TypeConverter &typeConverter,
233                                             unsigned alignment, bool isVolatile,
234                                             bool isNonTemporal) {
235   if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
236     auto dstType = typeConverter.convertType(loadOp.getType());
237     if (!dstType)
238       return rewriter.notifyMatchFailure(op, "type conversion failed");
239     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
240         loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
241         isVolatile, isNonTemporal);
242     return success();
243   }
244   auto storeOp = cast<spirv::StoreOp>(op);
245   spirv::StoreOpAdaptor adaptor(operands);
246   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
247                                              adaptor.getPtr(), alignment,
248                                              isVolatile, isNonTemporal);
249   return success();
250 }
251 
252 //===----------------------------------------------------------------------===//
253 // Type conversion
254 //===----------------------------------------------------------------------===//
255 
256 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
257 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
258 /// when converting ops that manipulate array types.
259 static std::optional<Type> convertArrayType(spirv::ArrayType type,
260                                             TypeConverter &converter) {
261   unsigned stride = type.getArrayStride();
262   Type elementType = type.getElementType();
263   auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
264   if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
265     return std::nullopt;
266 
267   auto llvmElementType = converter.convertType(elementType);
268   unsigned numElements = type.getNumElements();
269   return LLVM::LLVMArrayType::get(llvmElementType, numElements);
270 }
271 
272 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
273 /// modelled at the moment.
274 static Type convertPointerType(spirv::PointerType type,
275                                const TypeConverter &converter,
276                                spirv::ClientAPI clientAPI) {
277   unsigned addressSpace =
278       storageClassToAddressSpace(clientAPI, type.getStorageClass());
279   return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
280 }
281 
282 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
283 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
284 /// no modelling of array stride at the moment.
285 static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
286                                                    TypeConverter &converter) {
287   if (type.getArrayStride() != 0)
288     return std::nullopt;
289   auto elementType = converter.convertType(type.getElementType());
290   return LLVM::LLVMArrayType::get(elementType, 0);
291 }
292 
293 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
294 /// member decorations. Also, only natural offset is supported.
295 static Type convertStructType(spirv::StructType type,
296                               const TypeConverter &converter) {
297   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
298   type.getMemberDecorations(memberDecorations);
299   if (!memberDecorations.empty())
300     return nullptr;
301   if (type.hasOffset())
302     return convertStructTypeWithOffset(type, converter);
303   return convertStructTypePacked(type, converter);
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // Operation conversion
308 //===----------------------------------------------------------------------===//
309 
310 namespace {
311 
312 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
313 public:
314   using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
315 
316   LogicalResult
317   matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
318                   ConversionPatternRewriter &rewriter) const override {
319     auto dstType =
320         getTypeConverter()->convertType(op.getComponentPtr().getType());
321     if (!dstType)
322       return rewriter.notifyMatchFailure(op, "type conversion failed");
323     // To use GEP we need to add a first 0 index to go through the pointer.
324     auto indices = llvm::to_vector<4>(adaptor.getIndices());
325     Type indexType = op.getIndices().front().getType();
326     auto llvmIndexType = getTypeConverter()->convertType(indexType);
327     if (!llvmIndexType)
328       return rewriter.notifyMatchFailure(op, "type conversion failed");
329     Value zero = rewriter.create<LLVM::ConstantOp>(
330         op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
331     indices.insert(indices.begin(), zero);
332 
333     auto elementType = getTypeConverter()->convertType(
334         cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
335     if (!elementType)
336       return rewriter.notifyMatchFailure(op, "type conversion failed");
337     rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
338                                              adaptor.getBasePtr(), indices);
339     return success();
340   }
341 };
342 
343 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
344 public:
345   using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
346 
347   LogicalResult
348   matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
349                   ConversionPatternRewriter &rewriter) const override {
350     auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
351     if (!dstType)
352       return rewriter.notifyMatchFailure(op, "type conversion failed");
353     rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
354                                                    op.getVariable());
355     return success();
356   }
357 };
358 
359 class BitFieldInsertPattern
360     : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
361 public:
362   using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
363 
364   LogicalResult
365   matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
366                   ConversionPatternRewriter &rewriter) const override {
367     auto srcType = op.getType();
368     auto dstType = getTypeConverter()->convertType(srcType);
369     if (!dstType)
370       return rewriter.notifyMatchFailure(op, "type conversion failed");
371     Location loc = op.getLoc();
372 
373     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
374     Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
375                                         *getTypeConverter(), rewriter);
376     Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
377                                        *getTypeConverter(), rewriter);
378 
379     // Create a mask with bits set outside [Offset, Offset + Count - 1].
380     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
381     Value maskShiftedByCount =
382         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
383     Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
384                                                  maskShiftedByCount, minusOne);
385     Value maskShiftedByCountAndOffset =
386         rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
387     Value mask = rewriter.create<LLVM::XOrOp>(
388         loc, dstType, maskShiftedByCountAndOffset, minusOne);
389 
390     // Extract unchanged bits from the `Base`  that are outside of
391     // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
392     Value baseAndMask =
393         rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
394     Value insertShiftedByOffset =
395         rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
396     rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
397                                             insertShiftedByOffset);
398     return success();
399   }
400 };
401 
402 /// Converts SPIR-V ConstantOp with scalar or vector type.
403 class ConstantScalarAndVectorPattern
404     : public SPIRVToLLVMConversion<spirv::ConstantOp> {
405 public:
406   using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
407 
408   LogicalResult
409   matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
410                   ConversionPatternRewriter &rewriter) const override {
411     auto srcType = constOp.getType();
412     if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
413       return failure();
414 
415     auto dstType = getTypeConverter()->convertType(srcType);
416     if (!dstType)
417       return rewriter.notifyMatchFailure(constOp, "type conversion failed");
418 
419     // SPIR-V constant can be a signed/unsigned integer, which has to be
420     // casted to signless integer when converting to LLVM dialect. Removing the
421     // sign bit may have unexpected behaviour. However, it is better to handle
422     // it case-by-case, given that the purpose of the conversion is not to
423     // cover all possible corner cases.
424     if (isSignedIntegerOrVector(srcType) ||
425         isUnsignedIntegerOrVector(srcType)) {
426       auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
427 
428       if (isa<VectorType>(srcType)) {
429         auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
430         rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
431             constOp, dstType,
432             dstElementsAttr.mapValues(
433                 signlessType, [&](const APInt &value) { return value; }));
434         return success();
435       }
436       auto srcAttr = cast<IntegerAttr>(constOp.getValue());
437       auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
438       rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
439       return success();
440     }
441     rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
442         constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
443     return success();
444   }
445 };
446 
447 class BitFieldSExtractPattern
448     : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
449 public:
450   using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
451 
452   LogicalResult
453   matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
454                   ConversionPatternRewriter &rewriter) const override {
455     auto srcType = op.getType();
456     auto dstType = getTypeConverter()->convertType(srcType);
457     if (!dstType)
458       return rewriter.notifyMatchFailure(op, "type conversion failed");
459     Location loc = op.getLoc();
460 
461     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
462     Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
463                                         *getTypeConverter(), rewriter);
464     Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
465                                        *getTypeConverter(), rewriter);
466 
467     // Create a constant that holds the size of the `Base`.
468     IntegerType integerType;
469     if (auto vecType = dyn_cast<VectorType>(srcType))
470       integerType = cast<IntegerType>(vecType.getElementType());
471     else
472       integerType = cast<IntegerType>(srcType);
473 
474     auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
475     Value size =
476         isa<VectorType>(srcType)
477             ? rewriter.create<LLVM::ConstantOp>(
478                   loc, dstType,
479                   SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
480             : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
481 
482     // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
483     // at Offset + Count - 1 is the most significant bit now.
484     Value countPlusOffset =
485         rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
486     Value amountToShiftLeft =
487         rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
488     Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
489         loc, dstType, op.getBase(), amountToShiftLeft);
490 
491     // Shift the result right, filling the bits with the sign bit.
492     Value amountToShiftRight =
493         rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
494     rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
495                                               amountToShiftRight);
496     return success();
497   }
498 };
499 
500 class BitFieldUExtractPattern
501     : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
502 public:
503   using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
504 
505   LogicalResult
506   matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
507                   ConversionPatternRewriter &rewriter) const override {
508     auto srcType = op.getType();
509     auto dstType = getTypeConverter()->convertType(srcType);
510     if (!dstType)
511       return rewriter.notifyMatchFailure(op, "type conversion failed");
512     Location loc = op.getLoc();
513 
514     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
515     Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
516                                         *getTypeConverter(), rewriter);
517     Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
518                                        *getTypeConverter(), rewriter);
519 
520     // Create a mask with bits set at [0, Count - 1].
521     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
522     Value maskShiftedByCount =
523         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
524     Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
525                                               minusOne);
526 
527     // Shift `Base` by `Offset` and apply the mask on it.
528     Value shiftedBase =
529         rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
530     rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
531     return success();
532   }
533 };
534 
535 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
536 public:
537   using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
538 
539   LogicalResult
540   matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
541                   ConversionPatternRewriter &rewriter) const override {
542     rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
543                                             branchOp.getTarget());
544     return success();
545   }
546 };
547 
548 class BranchConditionalConversionPattern
549     : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
550 public:
551   using SPIRVToLLVMConversion<
552       spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
553 
554   LogicalResult
555   matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
556                   ConversionPatternRewriter &rewriter) const override {
557     // If branch weights exist, map them to 32-bit integer vector.
558     DenseI32ArrayAttr branchWeights = nullptr;
559     if (auto weights = op.getBranchWeights()) {
560       SmallVector<int32_t> weightValues;
561       for (auto weight : weights->getAsRange<IntegerAttr>())
562         weightValues.push_back(weight.getInt());
563       branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
564     }
565 
566     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
567         op, op.getCondition(), op.getTrueBlockArguments(),
568         op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
569         op.getFalseBlock());
570     return success();
571   }
572 };
573 
574 /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
575 /// type is an aggregate type (struct or array). Otherwise, converts to
576 /// `llvm.extractelement` that operates on vectors.
577 class CompositeExtractPattern
578     : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
579 public:
580   using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
581 
582   LogicalResult
583   matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
584                   ConversionPatternRewriter &rewriter) const override {
585     auto dstType = this->getTypeConverter()->convertType(op.getType());
586     if (!dstType)
587       return rewriter.notifyMatchFailure(op, "type conversion failed");
588 
589     Type containerType = op.getComposite().getType();
590     if (isa<VectorType>(containerType)) {
591       Location loc = op.getLoc();
592       IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
593       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
594       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
595           op, dstType, adaptor.getComposite(), index);
596       return success();
597     }
598 
599     rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
600         op, adaptor.getComposite(),
601         LLVM::convertArrayToIndices(op.getIndices()));
602     return success();
603   }
604 };
605 
606 /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
607 /// type is an aggregate type (struct or array). Otherwise, converts to
608 /// `llvm.insertelement` that operates on vectors.
609 class CompositeInsertPattern
610     : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
611 public:
612   using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
613 
614   LogicalResult
615   matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
616                   ConversionPatternRewriter &rewriter) const override {
617     auto dstType = this->getTypeConverter()->convertType(op.getType());
618     if (!dstType)
619       return rewriter.notifyMatchFailure(op, "type conversion failed");
620 
621     Type containerType = op.getComposite().getType();
622     if (isa<VectorType>(containerType)) {
623       Location loc = op.getLoc();
624       IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
625       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
626       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
627           op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
628       return success();
629     }
630 
631     rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
632         op, adaptor.getComposite(), adaptor.getObject(),
633         LLVM::convertArrayToIndices(op.getIndices()));
634     return success();
635   }
636 };
637 
638 /// Converts SPIR-V operations that have straightforward LLVM equivalent
639 /// into LLVM dialect operations.
640 template <typename SPIRVOp, typename LLVMOp>
641 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
642 public:
643   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
644 
645   LogicalResult
646   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
647                   ConversionPatternRewriter &rewriter) const override {
648     auto dstType = this->getTypeConverter()->convertType(op.getType());
649     if (!dstType)
650       return rewriter.notifyMatchFailure(op, "type conversion failed");
651     rewriter.template replaceOpWithNewOp<LLVMOp>(
652         op, dstType, adaptor.getOperands(), op->getAttrs());
653     return success();
654   }
655 };
656 
657 /// Converts `spirv.ExecutionMode` into a global struct constant that holds
658 /// execution mode information.
659 class ExecutionModePattern
660     : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
661 public:
662   using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
663 
664   LogicalResult
665   matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
666                   ConversionPatternRewriter &rewriter) const override {
667     // First, create the global struct's name that would be associated with
668     // this entry point's execution mode. We set it to be:
669     //   __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
670     ModuleOp module = op->getParentOfType<ModuleOp>();
671     spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
672     std::string moduleName;
673     if (module.getName().has_value())
674       moduleName = "_" + module.getName()->str();
675     else
676       moduleName = "";
677     std::string executionModeInfoName = llvm::formatv(
678         "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
679         static_cast<uint32_t>(executionModeAttr.getValue()));
680 
681     MLIRContext *context = rewriter.getContext();
682     OpBuilder::InsertionGuard guard(rewriter);
683     rewriter.setInsertionPointToStart(module.getBody());
684 
685     // Create a struct type, corresponding to the C struct below.
686     // struct {
687     //   int32_t executionMode;
688     //   int32_t values[];          // optional values
689     // };
690     auto llvmI32Type = IntegerType::get(context, 32);
691     SmallVector<Type, 2> fields;
692     fields.push_back(llvmI32Type);
693     ArrayAttr values = op.getValues();
694     if (!values.empty()) {
695       auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
696       fields.push_back(arrayType);
697     }
698     auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
699 
700     // Create `llvm.mlir.global` with initializer region containing one block.
701     auto global = rewriter.create<LLVM::GlobalOp>(
702         UnknownLoc::get(context), structType, /*isConstant=*/true,
703         LLVM::Linkage::External, executionModeInfoName, Attribute(),
704         /*alignment=*/0);
705     Location loc = global.getLoc();
706     Region &region = global.getInitializerRegion();
707     Block *block = rewriter.createBlock(&region);
708 
709     // Initialize the struct and set the execution mode value.
710     rewriter.setInsertionPointToStart(block);
711     Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
712     Value executionMode = rewriter.create<LLVM::ConstantOp>(
713         loc, llvmI32Type,
714         rewriter.getI32IntegerAttr(
715             static_cast<uint32_t>(executionModeAttr.getValue())));
716     structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
717                                                        executionMode, 0);
718 
719     // Insert extra operands if they exist into execution mode info struct.
720     for (unsigned i = 0, e = values.size(); i < e; ++i) {
721       auto attr = values.getValue()[i];
722       Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
723       structValue = rewriter.create<LLVM::InsertValueOp>(
724           loc, structValue, entry, ArrayRef<int64_t>({1, i}));
725     }
726     rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
727     rewriter.eraseOp(op);
728     return success();
729   }
730 };
731 
732 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
733 /// global returns a pointer, whereas in LLVM dialect the global holds an actual
734 /// value. This difference is handled by `spirv.mlir.addressof` and
735 /// `llvm.mlir.addressof`ops that both return a pointer.
736 class GlobalVariablePattern
737     : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
738 public:
739   template <typename... Args>
740   GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
741       : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
742             std::forward<Args>(args)...),
743         clientAPI(clientAPI) {}
744 
745   LogicalResult
746   matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
747                   ConversionPatternRewriter &rewriter) const override {
748     // Currently, there is no support of initialization with a constant value in
749     // SPIR-V dialect. Specialization constants are not considered as well.
750     if (op.getInitializer())
751       return failure();
752 
753     auto srcType = cast<spirv::PointerType>(op.getType());
754     auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
755     if (!dstType)
756       return rewriter.notifyMatchFailure(op, "type conversion failed");
757 
758     // Limit conversion to the current invocation only or `StorageBuffer`
759     // required by SPIR-V runner.
760     // This is okay because multiple invocations are not supported yet.
761     auto storageClass = srcType.getStorageClass();
762     switch (storageClass) {
763     case spirv::StorageClass::Input:
764     case spirv::StorageClass::Private:
765     case spirv::StorageClass::Output:
766     case spirv::StorageClass::StorageBuffer:
767     case spirv::StorageClass::UniformConstant:
768       break;
769     default:
770       return failure();
771     }
772 
773     // LLVM dialect spec: "If the global value is a constant, storing into it is
774     // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
775     // storage class that is read-only.
776     bool isConstant = (storageClass == spirv::StorageClass::Input) ||
777                       (storageClass == spirv::StorageClass::UniformConstant);
778     // SPIR-V spec: "By default, functions and global variables are private to a
779     // module and cannot be accessed by other modules. However, a module may be
780     // written to export or import functions and global (module scope)
781     // variables.". Therefore, map 'Private' storage class to private linkage,
782     // 'Input' and 'Output' to external linkage.
783     auto linkage = storageClass == spirv::StorageClass::Private
784                        ? LLVM::Linkage::Private
785                        : LLVM::Linkage::External;
786     auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
787         op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
788         /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass));
789 
790     // Attach location attribute if applicable
791     if (op.getLocationAttr())
792       newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
793 
794     return success();
795   }
796 
797 private:
798   spirv::ClientAPI clientAPI;
799 };
800 
801 /// Converts SPIR-V cast ops that do not have straightforward LLVM
802 /// equivalent in LLVM dialect.
803 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
804 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
805 public:
806   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
807 
808   LogicalResult
809   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
810                   ConversionPatternRewriter &rewriter) const override {
811 
812     Type fromType = op.getOperand().getType();
813     Type toType = op.getType();
814 
815     auto dstType = this->getTypeConverter()->convertType(toType);
816     if (!dstType)
817       return rewriter.notifyMatchFailure(op, "type conversion failed");
818 
819     if (getBitWidth(fromType) < getBitWidth(toType)) {
820       rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
821                                                       adaptor.getOperands());
822       return success();
823     }
824     if (getBitWidth(fromType) > getBitWidth(toType)) {
825       rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
826                                                         adaptor.getOperands());
827       return success();
828     }
829     return failure();
830   }
831 };
832 
833 class FunctionCallPattern
834     : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
835 public:
836   using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
837 
838   LogicalResult
839   matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
840                   ConversionPatternRewriter &rewriter) const override {
841     if (callOp.getNumResults() == 0) {
842       auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
843           callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
844       newOp.getProperties().operandSegmentSizes = {
845           static_cast<int32_t>(adaptor.getOperands().size()), 0};
846       newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
847       return success();
848     }
849 
850     // Function returns a single result.
851     auto dstType = getTypeConverter()->convertType(callOp.getType(0));
852     if (!dstType)
853       return rewriter.notifyMatchFailure(callOp, "type conversion failed");
854     auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
855         callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
856     newOp.getProperties().operandSegmentSizes = {
857         static_cast<int32_t>(adaptor.getOperands().size()), 0};
858     newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
859     return success();
860   }
861 };
862 
863 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
864 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
865 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
866 public:
867   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
868 
869   LogicalResult
870   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
871                   ConversionPatternRewriter &rewriter) const override {
872 
873     auto dstType = this->getTypeConverter()->convertType(op.getType());
874     if (!dstType)
875       return rewriter.notifyMatchFailure(op, "type conversion failed");
876 
877     rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
878         op, dstType, predicate, op.getOperand1(), op.getOperand2());
879     return success();
880   }
881 };
882 
883 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
884 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
885 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
886 public:
887   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
888 
889   LogicalResult
890   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
891                   ConversionPatternRewriter &rewriter) const override {
892 
893     auto dstType = this->getTypeConverter()->convertType(op.getType());
894     if (!dstType)
895       return rewriter.notifyMatchFailure(op, "type conversion failed");
896 
897     rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
898         op, dstType, predicate, op.getOperand1(), op.getOperand2());
899     return success();
900   }
901 };
902 
903 class InverseSqrtPattern
904     : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
905 public:
906   using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion;
907 
908   LogicalResult
909   matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
910                   ConversionPatternRewriter &rewriter) const override {
911     auto srcType = op.getType();
912     auto dstType = getTypeConverter()->convertType(srcType);
913     if (!dstType)
914       return rewriter.notifyMatchFailure(op, "type conversion failed");
915 
916     Location loc = op.getLoc();
917     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
918     Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
919     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
920     return success();
921   }
922 };
923 
924 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
925 template <typename SPIRVOp>
926 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
927 public:
928   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
929 
930   LogicalResult
931   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
932                   ConversionPatternRewriter &rewriter) const override {
933     if (!op.getMemoryAccess()) {
934       return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
935                                     *this->getTypeConverter(), /*alignment=*/0,
936                                     /*isVolatile=*/false,
937                                     /*isNonTemporal=*/false);
938     }
939     auto memoryAccess = *op.getMemoryAccess();
940     switch (memoryAccess) {
941     case spirv::MemoryAccess::Aligned:
942     case spirv::MemoryAccess::None:
943     case spirv::MemoryAccess::Nontemporal:
944     case spirv::MemoryAccess::Volatile: {
945       unsigned alignment =
946           memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
947       bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
948       bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
949       return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
950                                     *this->getTypeConverter(), alignment,
951                                     isVolatile, isNonTemporal);
952     }
953     default:
954       // There is no support of other memory access attributes.
955       return failure();
956     }
957   }
958 };
959 
960 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
961 template <typename SPIRVOp>
962 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
963 public:
964   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
965 
966   LogicalResult
967   matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
968                   ConversionPatternRewriter &rewriter) const override {
969     auto srcType = notOp.getType();
970     auto dstType = this->getTypeConverter()->convertType(srcType);
971     if (!dstType)
972       return rewriter.notifyMatchFailure(notOp, "type conversion failed");
973 
974     Location loc = notOp.getLoc();
975     IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
976     auto mask =
977         isa<VectorType>(srcType)
978             ? rewriter.create<LLVM::ConstantOp>(
979                   loc, dstType,
980                   SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
981             : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
982     rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
983                                                       notOp.getOperand(), mask);
984     return success();
985   }
986 };
987 
988 /// A template pattern that erases the given `SPIRVOp`.
989 template <typename SPIRVOp>
990 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
991 public:
992   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
993 
994   LogicalResult
995   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
996                   ConversionPatternRewriter &rewriter) const override {
997     rewriter.eraseOp(op);
998     return success();
999   }
1000 };
1001 
1002 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1003 public:
1004   using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
1005 
1006   LogicalResult
1007   matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1008                   ConversionPatternRewriter &rewriter) const override {
1009     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1010                                                 ArrayRef<Value>());
1011     return success();
1012   }
1013 };
1014 
1015 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1016 public:
1017   using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
1018 
1019   LogicalResult
1020   matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1021                   ConversionPatternRewriter &rewriter) const override {
1022     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1023                                                 adaptor.getOperands());
1024     return success();
1025   }
1026 };
1027 
1028 static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
1029                                               StringRef name,
1030                                               ArrayRef<Type> paramTypes,
1031                                               Type resultType,
1032                                               bool convergent = true) {
1033   auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1034       SymbolTable::lookupSymbolIn(symbolTable, name));
1035   if (func)
1036     return func;
1037 
1038   OpBuilder b(symbolTable->getRegion(0));
1039   func = b.create<LLVM::LLVMFuncOp>(
1040       symbolTable->getLoc(), name,
1041       LLVM::LLVMFunctionType::get(resultType, paramTypes));
1042   func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1043   func.setConvergent(convergent);
1044   func.setNoUnwind(true);
1045   func.setWillReturn(true);
1046   return func;
1047 }
1048 
1049 static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
1050                                            LLVM::LLVMFuncOp func,
1051                                            ValueRange args) {
1052   auto call = builder.create<LLVM::CallOp>(loc, func, args);
1053   call.setCConv(func.getCConv());
1054   call.setConvergentAttr(func.getConvergentAttr());
1055   call.setNoUnwindAttr(func.getNoUnwindAttr());
1056   call.setWillReturnAttr(func.getWillReturnAttr());
1057   return call;
1058 }
1059 
1060 template <typename BarrierOpTy>
1061 class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> {
1062 public:
1063   using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
1064 
1065   using SPIRVToLLVMConversion<BarrierOpTy>::SPIRVToLLVMConversion;
1066 
1067   static constexpr StringRef getFuncName();
1068 
1069   LogicalResult
1070   matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1071                   ConversionPatternRewriter &rewriter) const override {
1072     constexpr StringRef funcName = getFuncName();
1073     Operation *symbolTable =
1074         controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1075 
1076     Type i32 = rewriter.getI32Type();
1077 
1078     Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1079     LLVM::LLVMFuncOp func =
1080         lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
1081 
1082     Location loc = controlBarrierOp->getLoc();
1083     Value execution = rewriter.create<LLVM::ConstantOp>(
1084         loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
1085     Value memory = rewriter.create<LLVM::ConstantOp>(
1086         loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
1087     Value semantics = rewriter.create<LLVM::ConstantOp>(
1088         loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
1089 
1090     auto call = createSPIRVBuiltinCall(loc, rewriter, func,
1091                                        {execution, memory, semantics});
1092 
1093     rewriter.replaceOp(controlBarrierOp, call);
1094     return success();
1095   }
1096 };
1097 
namespace {

/// Returns the Itanium C++ ABI builtin-type code used to mangle `type` into a
/// SPIR-V builtin name (used by `GroupReducePattern`).
///
/// - `f16`/`f32`/`f64` map to "Dh"/"f"/"d".
/// - `i1` maps to "b" (`bool`) regardless of `isSigned`.
/// - For 16/32/64-bit integers, `isSigned` selects the signed ("s"/"i"/"l")
///   or the unsigned ("t"/"j"/"m") code.
/// - For 8-bit integers, `isSigned` selects "a" (`signed char`). Otherwise
///   "c" (plain `char`) is returned, not "h" (`unsigned char`).
///   NOTE(review): presumably chosen to match the builtin library's naming
///   for signless `i8`; confirm.
///
/// Any other type or integer width is a programming error and reaches
/// `llvm_unreachable`.
StringRef getTypeMangling(Type type, bool isSigned) {
  return llvm::TypeSwitch<Type, StringRef>(type)
      .Case<Float16Type>([](auto) { return "Dh"; })
      .Case<Float32Type>([](auto) { return "f"; })
      .Case<Float64Type>([](auto) { return "d"; })
      .Case<IntegerType>([isSigned](IntegerType intTy) {
        switch (intTy.getWidth()) {
        case 1:
          return "b";
        case 8:
          return (isSigned) ? "a" : "c";
        case 16:
          return (isSigned) ? "s" : "t";
        case 32:
          return (isSigned) ? "i" : "j";
        case 64:
          return (isSigned) ? "l" : "m";
        default:
          llvm_unreachable("Unsupported integer width");
        }
      })
      .Default([](auto) {
        llvm_unreachable("No mangling defined");
        return "";
      });
}

/// Returns the mangled-name prefix of the SPIR-V group reduction builtin that
/// implements `ReduceOp`. Each prefix has the form `_Z<N>__spirv_<Name>ii`:
/// - `<N>` is the length of `__spirv_<Name>`.
/// - The trailing `ii` covers the two `i32` parameters, execution scope and
///   group operation, that `GroupReducePattern` passes first.
///
/// The caller appends the mangling of the value type and, if present, of the
/// cluster size.
///
/// There is no generic definition, so only ops with an explicit
/// specialization below can be lowered. Each string is an ABI-visible symbol
/// name, and its length prefix must stay in sync with the name.
template <typename ReduceOp>
constexpr StringLiteral getGroupFuncName();

// `spirv.Group*` reductions.
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
  return "_Z17__spirv_GroupIAddii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
  return "_Z17__spirv_GroupFAddii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
  return "_Z17__spirv_GroupSMinii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
  return "_Z17__spirv_GroupUMinii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
  return "_Z17__spirv_GroupFMinii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
  return "_Z17__spirv_GroupSMaxii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
  return "_Z17__spirv_GroupUMaxii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
  return "_Z17__spirv_GroupFMaxii";
}
// `spirv.GroupNonUniform*` reductions.
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
  return "_Z27__spirv_GroupNonUniformIAddii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
  return "_Z27__spirv_GroupNonUniformFAddii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
  return "_Z27__spirv_GroupNonUniformIMulii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
  return "_Z27__spirv_GroupNonUniformFMulii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
  return "_Z27__spirv_GroupNonUniformSMinii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
  return "_Z27__spirv_GroupNonUniformUMinii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
  return "_Z27__spirv_GroupNonUniformFMinii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
  return "_Z27__spirv_GroupNonUniformSMaxii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
  return "_Z27__spirv_GroupNonUniformUMaxii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
  return "_Z27__spirv_GroupNonUniformFMaxii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
  return "_Z33__spirv_GroupNonUniformBitwiseAndii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
  return "_Z32__spirv_GroupNonUniformBitwiseOrii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
  return "_Z33__spirv_GroupNonUniformBitwiseXorii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
  return "_Z33__spirv_GroupNonUniformLogicalAndii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
  return "_Z32__spirv_GroupNonUniformLogicalOrii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
  return "_Z33__spirv_GroupNonUniformLogicalXorii";
}
} // namespace
1227 
1228 template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
1229 class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
1230 public:
1231   using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
1232 
1233   LogicalResult
1234   matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
1235                   ConversionPatternRewriter &rewriter) const override {
1236 
1237     Type retTy = op.getResult().getType();
1238     if (!retTy.isIntOrFloat()) {
1239       return failure();
1240     }
1241     SmallString<36> funcName = getGroupFuncName<ReduceOp>();
1242     funcName += getTypeMangling(retTy, false);
1243 
1244     Type i32Ty = rewriter.getI32Type();
1245     SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1246     if constexpr (NonUniform) {
1247       if (adaptor.getClusterSize()) {
1248         funcName += "j";
1249         paramTypes.push_back(i32Ty);
1250       }
1251     }
1252 
1253     Operation *symbolTable =
1254         op->template getParentWithTrait<OpTrait::SymbolTable>();
1255 
1256     LLVM::LLVMFuncOp func =
1257         lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy);
1258 
1259     Location loc = op.getLoc();
1260     Value scope = rewriter.create<LLVM::ConstantOp>(
1261         loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
1262     Value groupOp = rewriter.create<LLVM::ConstantOp>(
1263         loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
1264     SmallVector<Value> operands{scope, groupOp};
1265     operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1266 
1267     auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands);
1268     rewriter.replaceOp(op, call);
1269     return success();
1270   }
1271 };
1272 
/// Mangled name of the `spirv.ControlBarrier` builtin. `_Z<len>__spirv_<Name>`
/// is followed by `iii` for the three `i32` parameters (execution scope,
/// memory scope, memory semantics) that `ControlBarrierPattern` passes.
template <>
constexpr StringRef
ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
  return "_Z22__spirv_ControlBarrieriii";
}

/// Mangled name of the INTEL "arrive" half of a split control barrier. It has
/// the same three `i32` parameters as above.
template <>
constexpr StringRef
ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
  return "_Z33__spirv_ControlBarrierArriveINTELiii";
}

/// Mangled name of the INTEL "wait" half of a split control barrier. It has
/// the same three `i32` parameters as above.
template <>
constexpr StringRef
ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
  return "_Z31__spirv_ControlBarrierWaitINTELiii";
}
1290 
1291 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
1292 /// should be reachable for conversion to succeed. The structure of the loop in
1293 /// LLVM dialect will be the following:
1294 ///
1295 ///      +------------------------------------+
1296 ///      | <code before spirv.mlir.loop>        |
1297 ///      | llvm.br ^header                    |
1298 ///      +------------------------------------+
1299 ///                           |
1300 ///   +----------------+      |
1301 ///   |                |      |
1302 ///   |                V      V
1303 ///   |  +------------------------------------+
1304 ///   |  | ^header:                           |
1305 ///   |  |   <header code>                    |
1306 ///   |  |   llvm.cond_br %cond, ^body, ^exit |
1307 ///   |  +------------------------------------+
1308 ///   |                    |
1309 ///   |                    |----------------------+
1310 ///   |                    |                      |
1311 ///   |                    V                      |
1312 ///   |  +------------------------------------+   |
1313 ///   |  | ^body:                             |   |
1314 ///   |  |   <body code>                      |   |
1315 ///   |  |   llvm.br ^continue                |   |
1316 ///   |  +------------------------------------+   |
1317 ///   |                    |                      |
1318 ///   |                    V                      |
1319 ///   |  +------------------------------------+   |
1320 ///   |  | ^continue:                         |   |
1321 ///   |  |   <continue code>                  |   |
1322 ///   |  |   llvm.br ^header                  |   |
1323 ///   |  +------------------------------------+   |
1324 ///   |               |                           |
1325 ///   +---------------+    +----------------------+
1326 ///                        |
1327 ///                        V
1328 ///      +------------------------------------+
1329 ///      | ^exit:                             |
1330 ///      |   llvm.br ^remaining               |
1331 ///      +------------------------------------+
1332 ///                        |
1333 ///                        V
1334 ///      +------------------------------------+
1335 ///      | ^remaining:                        |
1336 ///      |   <code after spirv.mlir.loop>       |
1337 ///      +------------------------------------+
1338 ///
1339 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1340 public:
1341   using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1342 
1343   LogicalResult
1344   matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1345                   ConversionPatternRewriter &rewriter) const override {
1346     // There is no support of loop control at the moment.
1347     if (loopOp.getLoopControl() != spirv::LoopControl::None)
1348       return failure();
1349 
1350     // `spirv.mlir.loop` with empty region is redundant and should be erased.
1351     if (loopOp.getBody().empty()) {
1352       rewriter.eraseOp(loopOp);
1353       return success();
1354     }
1355 
1356     Location loc = loopOp.getLoc();
1357 
1358     // Split the current block after `spirv.mlir.loop`. The remaining ops will
1359     // be used in `endBlock`.
1360     Block *currentBlock = rewriter.getBlock();
1361     auto position = Block::iterator(loopOp);
1362     Block *endBlock = rewriter.splitBlock(currentBlock, position);
1363 
1364     // Remove entry block and create a branch in the current block going to the
1365     // header block.
1366     Block *entryBlock = loopOp.getEntryBlock();
1367     assert(entryBlock->getOperations().size() == 1);
1368     auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1369     if (!brOp)
1370       return failure();
1371     Block *headerBlock = loopOp.getHeaderBlock();
1372     rewriter.setInsertionPointToEnd(currentBlock);
1373     rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1374     rewriter.eraseBlock(entryBlock);
1375 
1376     // Branch from merge block to end block.
1377     Block *mergeBlock = loopOp.getMergeBlock();
1378     Operation *terminator = mergeBlock->getTerminator();
1379     ValueRange terminatorOperands = terminator->getOperands();
1380     rewriter.setInsertionPointToEnd(mergeBlock);
1381     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1382 
1383     rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1384     rewriter.replaceOp(loopOp, endBlock->getArguments());
1385     return success();
1386   }
1387 };
1388 
1389 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1390 /// block. All blocks within selection should be reachable for conversion to
1391 /// succeed.
1392 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1393 public:
1394   using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1395 
1396   LogicalResult
1397   matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1398                   ConversionPatternRewriter &rewriter) const override {
1399     // There is no support for `Flatten` or `DontFlatten` selection control at
1400     // the moment. This are just compiler hints and can be performed during the
1401     // optimization passes.
1402     if (op.getSelectionControl() != spirv::SelectionControl::None)
1403       return failure();
1404 
1405     // `spirv.mlir.selection` should have at least two blocks: one selection
1406     // header block and one merge block. If no blocks are present, or control
1407     // flow branches straight to merge block (two blocks are present), the op is
1408     // redundant and it is erased.
1409     if (op.getBody().getBlocks().size() <= 2) {
1410       rewriter.eraseOp(op);
1411       return success();
1412     }
1413 
1414     Location loc = op.getLoc();
1415 
1416     // Split the current block after `spirv.mlir.selection`. The remaining ops
1417     // will be used in `continueBlock`.
1418     auto *currentBlock = rewriter.getInsertionBlock();
1419     rewriter.setInsertionPointAfter(op);
1420     auto position = rewriter.getInsertionPoint();
1421     auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1422 
1423     // Extract conditional branch information from the header block. By SPIR-V
1424     // dialect spec, it should contain `spirv.BranchConditional` or
1425     // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1426     // moment in the SPIR-V dialect. Remove this block when finished.
1427     auto *headerBlock = op.getHeaderBlock();
1428     assert(headerBlock->getOperations().size() == 1);
1429     auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1430         headerBlock->getOperations().front());
1431     if (!condBrOp)
1432       return failure();
1433     rewriter.eraseBlock(headerBlock);
1434 
1435     // Branch from merge block to continue block.
1436     auto *mergeBlock = op.getMergeBlock();
1437     Operation *terminator = mergeBlock->getTerminator();
1438     ValueRange terminatorOperands = terminator->getOperands();
1439     rewriter.setInsertionPointToEnd(mergeBlock);
1440     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1441 
1442     // Link current block to `true` and `false` blocks within the selection.
1443     Block *trueBlock = condBrOp.getTrueBlock();
1444     Block *falseBlock = condBrOp.getFalseBlock();
1445     rewriter.setInsertionPointToEnd(currentBlock);
1446     rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1447                                     condBrOp.getTrueTargetOperands(),
1448                                     falseBlock,
1449                                     condBrOp.getFalseTargetOperands());
1450 
1451     rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1452     rewriter.replaceOp(op, continueBlock->getArguments());
1453     return success();
1454   }
1455 };
1456 
1457 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1458 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1459 /// `Shift` is zero or sign extended to match this specification. Cases when
1460 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1461 template <typename SPIRVOp, typename LLVMOp>
1462 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1463 public:
1464   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1465 
1466   LogicalResult
1467   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1468                   ConversionPatternRewriter &rewriter) const override {
1469 
1470     auto dstType = this->getTypeConverter()->convertType(op.getType());
1471     if (!dstType)
1472       return rewriter.notifyMatchFailure(op, "type conversion failed");
1473 
1474     Type op1Type = op.getOperand1().getType();
1475     Type op2Type = op.getOperand2().getType();
1476 
1477     if (op1Type == op2Type) {
1478       rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1479                                                    adaptor.getOperands());
1480       return success();
1481     }
1482 
1483     std::optional<uint64_t> dstTypeWidth =
1484         getIntegerOrVectorElementWidth(dstType);
1485     std::optional<uint64_t> op2TypeWidth =
1486         getIntegerOrVectorElementWidth(op2Type);
1487 
1488     if (!dstTypeWidth || !op2TypeWidth)
1489       return failure();
1490 
1491     Location loc = op.getLoc();
1492     Value extended;
1493     if (op2TypeWidth < dstTypeWidth) {
1494       if (isUnsignedIntegerOrVector(op2Type)) {
1495         extended = rewriter.template create<LLVM::ZExtOp>(
1496             loc, dstType, adaptor.getOperand2());
1497       } else {
1498         extended = rewriter.template create<LLVM::SExtOp>(
1499             loc, dstType, adaptor.getOperand2());
1500       }
1501     } else if (op2TypeWidth == dstTypeWidth) {
1502       extended = adaptor.getOperand2();
1503     } else {
1504       return failure();
1505     }
1506 
1507     Value result = rewriter.template create<LLVMOp>(
1508         loc, dstType, adaptor.getOperand1(), extended);
1509     rewriter.replaceOp(op, result);
1510     return success();
1511   }
1512 };
1513 
1514 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1515 public:
1516   using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion;
1517 
1518   LogicalResult
1519   matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1520                   ConversionPatternRewriter &rewriter) const override {
1521     auto dstType = getTypeConverter()->convertType(tanOp.getType());
1522     if (!dstType)
1523       return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1524 
1525     Location loc = tanOp.getLoc();
1526     Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1527     Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1528     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1529     return success();
1530   }
1531 };
1532 
1533 /// Convert `spirv.Tanh` to
1534 ///
1535 ///   exp(2x) - 1
1536 ///   -----------
1537 ///   exp(2x) + 1
1538 ///
1539 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1540 public:
1541   using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;
1542 
1543   LogicalResult
1544   matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1545                   ConversionPatternRewriter &rewriter) const override {
1546     auto srcType = tanhOp.getType();
1547     auto dstType = getTypeConverter()->convertType(srcType);
1548     if (!dstType)
1549       return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1550 
1551     Location loc = tanhOp.getLoc();
1552     Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1553     Value multiplied =
1554         rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1555     Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1556     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1557     Value numerator =
1558         rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1559     Value denominator =
1560         rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1561     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1562                                               denominator);
1563     return success();
1564   }
1565 };
1566 
1567 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1568 public:
1569   using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1570 
1571   LogicalResult
1572   matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1573                   ConversionPatternRewriter &rewriter) const override {
1574     auto srcType = varOp.getType();
1575     // Initialization is supported for scalars and vectors only.
1576     auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1577     auto init = varOp.getInitializer();
1578     if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1579       return failure();
1580 
1581     auto dstType = getTypeConverter()->convertType(srcType);
1582     if (!dstType)
1583       return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1584 
1585     Location loc = varOp.getLoc();
1586     Value size = createI32ConstantOf(loc, rewriter, 1);
1587     if (!init) {
1588       auto elementType = getTypeConverter()->convertType(pointerTo);
1589       if (!elementType)
1590         return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1591       rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1592                                                   size);
1593       return success();
1594     }
1595     auto elementType = getTypeConverter()->convertType(pointerTo);
1596     if (!elementType)
1597       return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1598     Value allocated =
1599         rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1600     rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1601     rewriter.replaceOp(varOp, allocated);
1602     return success();
1603   }
1604 };
1605 
1606 //===----------------------------------------------------------------------===//
1607 // BitcastOp conversion
1608 //===----------------------------------------------------------------------===//
1609 
1610 class BitcastConversionPattern
1611     : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1612 public:
1613   using SPIRVToLLVMConversion<spirv::BitcastOp>::SPIRVToLLVMConversion;
1614 
1615   LogicalResult
1616   matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1617                   ConversionPatternRewriter &rewriter) const override {
1618     auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1619     if (!dstType)
1620       return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
1621 
1622     // LLVM's opaque pointers do not require bitcasts.
1623     if (isa<LLVM::LLVMPointerType>(dstType)) {
1624       rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1625       return success();
1626     }
1627 
1628     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1629         bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1630     return success();
1631   }
1632 };
1633 
1634 //===----------------------------------------------------------------------===//
1635 // FuncOp conversion
1636 //===----------------------------------------------------------------------===//
1637 
1638 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1639 public:
1640   using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1641 
1642   LogicalResult
1643   matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1644                   ConversionPatternRewriter &rewriter) const override {
1645 
1646     // Convert function signature. At the moment LLVMType converter is enough
1647     // for currently supported types.
1648     auto funcType = funcOp.getFunctionType();
1649     TypeConverter::SignatureConversion signatureConverter(
1650         funcType.getNumInputs());
1651     auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
1652                         ->convertFunctionSignature(
1653                             funcType, /*isVariadic=*/false,
1654                             /*useBarePtrCallConv=*/false, signatureConverter);
1655     if (!llvmType)
1656       return failure();
1657 
1658     // Create a new `LLVMFuncOp`
1659     Location loc = funcOp.getLoc();
1660     StringRef name = funcOp.getName();
1661     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1662 
1663     // Convert SPIR-V Function Control to equivalent LLVM function attribute
1664     MLIRContext *context = funcOp.getContext();
1665     switch (funcOp.getFunctionControl()) {
1666     case spirv::FunctionControl::Inline:
1667       newFuncOp.setAlwaysInline(true);
1668       break;
1669     case spirv::FunctionControl::DontInline:
1670       newFuncOp.setNoInline(true);
1671       break;
1672 
1673 #define DISPATCH(functionControl, llvmAttr)                                    \
1674   case functionControl:                                                        \
1675     newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr}));    \
1676     break;
1677 
1678       DISPATCH(spirv::FunctionControl::Pure,
1679                StringAttr::get(context, "readonly"));
1680       DISPATCH(spirv::FunctionControl::Const,
1681                StringAttr::get(context, "readnone"));
1682 
1683 #undef DISPATCH
1684 
1685     // Default: if `spirv::FunctionControl::None`, then no attributes are
1686     // needed.
1687     default:
1688       break;
1689     }
1690 
1691     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1692                                 newFuncOp.end());
1693     if (failed(rewriter.convertRegionTypes(
1694             &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1695       return failure();
1696     }
1697     rewriter.eraseOp(funcOp);
1698     return success();
1699   }
1700 };
1701 
1702 //===----------------------------------------------------------------------===//
1703 // ModuleOp conversion
1704 //===----------------------------------------------------------------------===//
1705 
1706 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1707 public:
1708   using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1709 
1710   LogicalResult
1711   matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1712                   ConversionPatternRewriter &rewriter) const override {
1713 
1714     auto newModuleOp =
1715         rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1716     rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1717 
1718     // Remove the terminator block that was automatically added by builder
1719     rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1720     rewriter.eraseOp(spvModuleOp);
1721     return success();
1722   }
1723 };
1724 
1725 //===----------------------------------------------------------------------===//
1726 // VectorShuffleOp conversion
1727 //===----------------------------------------------------------------------===//
1728 
1729 class VectorShufflePattern
1730     : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1731 public:
1732   using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
1733   LogicalResult
1734   matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1735                   ConversionPatternRewriter &rewriter) const override {
1736     Location loc = op.getLoc();
1737     auto components = adaptor.getComponents();
1738     auto vector1 = adaptor.getVector1();
1739     auto vector2 = adaptor.getVector2();
1740     int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1741     int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1742     if (vector1Size == vector2Size) {
1743       rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1744           op, vector1, vector2,
1745           LLVM::convertArrayToIndices<int32_t>(components));
1746       return success();
1747     }
1748 
1749     auto dstType = getTypeConverter()->convertType(op.getType());
1750     if (!dstType)
1751       return rewriter.notifyMatchFailure(op, "type conversion failed");
1752     auto scalarType = cast<VectorType>(dstType).getElementType();
1753     auto componentsArray = components.getValue();
1754     auto *context = rewriter.getContext();
1755     auto llvmI32Type = IntegerType::get(context, 32);
1756     Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
1757     for (unsigned i = 0; i < componentsArray.size(); i++) {
1758       if (!isa<IntegerAttr>(componentsArray[i]))
1759         return op.emitError("unable to support non-constant component");
1760 
1761       int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1762       if (indexVal == -1)
1763         continue;
1764 
1765       int offsetVal = 0;
1766       Value baseVector = vector1;
1767       if (indexVal >= vector1Size) {
1768         offsetVal = vector1Size;
1769         baseVector = vector2;
1770       }
1771 
1772       Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1773           loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1774       Value index = rewriter.create<LLVM::ConstantOp>(
1775           loc, llvmI32Type,
1776           rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1777 
1778       auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1779           loc, scalarType, baseVector, index);
1780       targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1781                                                         extractOp, dstIndex);
1782     }
1783     rewriter.replaceOp(op, targetOp);
1784     return success();
1785   }
1786 };
1787 } // namespace
1788 
1789 //===----------------------------------------------------------------------===//
1790 // Pattern population
1791 //===----------------------------------------------------------------------===//
1792 
1793 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter,
1794                                              spirv::ClientAPI clientAPI) {
1795   typeConverter.addConversion([&](spirv::ArrayType type) {
1796     return convertArrayType(type, typeConverter);
1797   });
1798   typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
1799     return convertPointerType(type, typeConverter, clientAPI);
1800   });
1801   typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1802     return convertRuntimeArrayType(type, typeConverter);
1803   });
1804   typeConverter.addConversion([&](spirv::StructType type) {
1805     return convertStructType(type, typeConverter);
1806   });
1807 }
1808 
1809 void mlir::populateSPIRVToLLVMConversionPatterns(
1810     const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1811     spirv::ClientAPI clientAPI) {
1812   patterns.add<
1813       // Arithmetic ops
1814       DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1815       DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1816       DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1817       DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1818       DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1819       DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1820       DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1821       DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1822       DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1823       DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1824       DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1825       DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1826       DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1827 
1828       // Bitwise ops
1829       BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1830       DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1831       DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1832       DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1833       DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1834       DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1835       NotPattern<spirv::NotOp>,
1836 
1837       // Cast ops
1838       BitcastConversionPattern,
1839       DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1840       DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1841       DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1842       DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1843       IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1844       IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1845       IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1846 
1847       // Comparison ops
1848       IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1849       IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1850       FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1851       FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1852       FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1853       FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1854       FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1855       FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1856       FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1857       FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1858       FComparePattern<spirv::FUnordGreaterThanEqualOp,
1859                       LLVM::FCmpPredicate::uge>,
1860       FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1861       FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1862       FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1863       IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1864       IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1865       IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1866       IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1867       IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1868       IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1869       IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1870       IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1871 
1872       // Constant op
1873       ConstantScalarAndVectorPattern,
1874 
1875       // Control Flow ops
1876       BranchConversionPattern, BranchConditionalConversionPattern,
1877       FunctionCallPattern, LoopPattern, SelectionPattern,
1878       ErasePattern<spirv::MergeOp>,
1879 
1880       // Entry points and execution mode are handled separately.
1881       ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1882 
1883       // GLSL extended instruction set ops
1884       DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1885       DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1886       DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1887       DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1888       DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1889       DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1890       DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1891       DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1892       DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1893       DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1894       DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1895       DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1896       InverseSqrtPattern, TanPattern, TanhPattern,
1897 
1898       // Logical ops
1899       DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1900       DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1901       IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1902       IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1903       NotPattern<spirv::LogicalNotOp>,
1904 
1905       // Memory ops
1906       AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1907       LoadStorePattern<spirv::StoreOp>, VariablePattern,
1908 
1909       // Miscellaneous ops
1910       CompositeExtractPattern, CompositeInsertPattern,
1911       DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1912       DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1913       VectorShufflePattern,
1914 
1915       // Shift ops
1916       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1917       ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1918       ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1919 
1920       // Return ops
1921       ReturnPattern, ReturnValuePattern,
1922 
1923       // Barrier ops
1924       ControlBarrierPattern<spirv::ControlBarrierOp>,
1925       ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1926       ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1927 
1928       // Group reduction operations
1929       GroupReducePattern<spirv::GroupIAddOp>,
1930       GroupReducePattern<spirv::GroupFAddOp>,
1931       GroupReducePattern<spirv::GroupFMinOp>,
1932       GroupReducePattern<spirv::GroupUMinOp>,
1933       GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>,
1934       GroupReducePattern<spirv::GroupFMaxOp>,
1935       GroupReducePattern<spirv::GroupUMaxOp>,
1936       GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>,
1937       GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false,
1938                          /*NonUniform=*/true>,
1939       GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false,
1940                          /*NonUniform=*/true>,
1941       GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false,
1942                          /*NonUniform=*/true>,
1943       GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false,
1944                          /*NonUniform=*/true>,
1945       GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true,
1946                          /*NonUniform=*/true>,
1947       GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false,
1948                          /*NonUniform=*/true>,
1949       GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false,
1950                          /*NonUniform=*/true>,
1951       GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true,
1952                          /*NonUniform=*/true>,
1953       GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false,
1954                          /*NonUniform=*/true>,
1955       GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false,
1956                          /*NonUniform=*/true>,
1957       GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false,
1958                          /*NonUniform=*/true>,
1959       GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false,
1960                          /*NonUniform=*/true>,
1961       GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false,
1962                          /*NonUniform=*/true>,
1963       GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false,
1964                          /*NonUniform=*/true>,
1965       GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false,
1966                          /*NonUniform=*/true>,
1967       GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false,
1968                          /*NonUniform=*/true>>(patterns.getContext(),
1969                                                typeConverter);
1970 
1971   patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
1972                                       typeConverter);
1973 }
1974 
1975 void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1976     const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1977   patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1978 }
1979 
1980 void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1981     const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1982   patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1983 }
1984 
1985 //===----------------------------------------------------------------------===//
1986 // Pre-conversion hooks
1987 //===----------------------------------------------------------------------===//
1988 
1989 /// Hook for descriptor set and binding number encoding.
1990 static constexpr StringRef kBinding = "binding";
1991 static constexpr StringRef kDescriptorSet = "descriptor_set";
1992 void mlir::encodeBindAttribute(ModuleOp module) {
1993   auto spvModules = module.getOps<spirv::ModuleOp>();
1994   for (auto spvModule : spvModules) {
1995     spvModule.walk([&](spirv::GlobalVariableOp op) {
1996       IntegerAttr descriptorSet =
1997           op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1998       IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1999       // For every global variable in the module, get the ones with descriptor
2000       // set and binding numbers.
2001       if (descriptorSet && binding) {
2002         // Encode these numbers into the variable's symbolic name. If the
2003         // SPIR-V module has a name, add it at the beginning.
2004         auto moduleAndName =
2005             spvModule.getName().has_value()
2006                 ? spvModule.getName()->str() + "_" + op.getSymName().str()
2007                 : op.getSymName().str();
2008         std::string name =
2009             llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
2010                           std::to_string(descriptorSet.getInt()),
2011                           std::to_string(binding.getInt()));
2012         auto nameAttr = StringAttr::get(op->getContext(), name);
2013 
2014         // Replace all symbol uses and set the new symbol name. Finally, remove
2015         // descriptor set and binding attributes.
2016         if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
2017           op.emitError("unable to replace all symbol uses for ") << name;
2018         SymbolTable::setSymbolName(op, nameAttr);
2019         op->removeAttr(kDescriptorSet);
2020         op->removeAttr(kBinding);
2021       }
2022     });
2023   }
2024 }
2025