xref: /llvm-project/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (revision 4adeb6cf556df10da668916b22eb39d3f1313e8a)
1930c74f1SLei Zhang //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2930c74f1SLei Zhang //
3930c74f1SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4930c74f1SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
5930c74f1SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6930c74f1SLei Zhang //
7930c74f1SLei Zhang //===----------------------------------------------------------------------===//
8930c74f1SLei Zhang //
9930c74f1SLei Zhang // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10930c74f1SLei Zhang //
11930c74f1SLei Zhang //===----------------------------------------------------------------------===//
12930c74f1SLei Zhang 
13930c74f1SLei Zhang #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
1475e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
1575e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16d45de800SVictor Perez #include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h"
17930c74f1SLei Zhang #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18930c74f1SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19a29fffc4SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
20930c74f1SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
21930c74f1SLei Zhang #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
22930c74f1SLei Zhang #include "mlir/IR/BuiltinOps.h"
23930c74f1SLei Zhang #include "mlir/IR/PatternMatch.h"
24930c74f1SLei Zhang #include "mlir/Transforms/DialectConversion.h"
256ade03d7SLukas Sommer #include "llvm/ADT/TypeSwitch.h"
26930c74f1SLei Zhang #include "llvm/Support/Debug.h"
27930c74f1SLei Zhang #include "llvm/Support/FormatVariadic.h"
28930c74f1SLei Zhang 
29930c74f1SLei Zhang #define DEBUG_TYPE "spirv-to-llvm-pattern"
30930c74f1SLei Zhang 
31930c74f1SLei Zhang using namespace mlir;
32930c74f1SLei Zhang 
33930c74f1SLei Zhang //===----------------------------------------------------------------------===//
34930c74f1SLei Zhang // Utility functions
35930c74f1SLei Zhang //===----------------------------------------------------------------------===//
36930c74f1SLei Zhang 
37930c74f1SLei Zhang /// Returns true if the given type is a signed integer or vector type.
38930c74f1SLei Zhang static bool isSignedIntegerOrVector(Type type) {
39930c74f1SLei Zhang   if (type.isSignedInteger())
40930c74f1SLei Zhang     return true;
415550c821STres Popp   if (auto vecType = dyn_cast<VectorType>(type))
42930c74f1SLei Zhang     return vecType.getElementType().isSignedInteger();
43930c74f1SLei Zhang   return false;
44930c74f1SLei Zhang }
45930c74f1SLei Zhang 
46930c74f1SLei Zhang /// Returns true if the given type is an unsigned integer or vector type
47930c74f1SLei Zhang static bool isUnsignedIntegerOrVector(Type type) {
48930c74f1SLei Zhang   if (type.isUnsignedInteger())
49930c74f1SLei Zhang     return true;
505550c821STres Popp   if (auto vecType = dyn_cast<VectorType>(type))
51930c74f1SLei Zhang     return vecType.getElementType().isUnsignedInteger();
52930c74f1SLei Zhang   return false;
53930c74f1SLei Zhang }
54930c74f1SLei Zhang 
5500b12a94SThéo Degioanni /// Returns the width of an integer or of the element type of an integer vector,
5600b12a94SThéo Degioanni /// if applicable.
5700b12a94SThéo Degioanni static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
5800b12a94SThéo Degioanni   if (auto intType = dyn_cast<IntegerType>(type))
5900b12a94SThéo Degioanni     return intType.getWidth();
6000b12a94SThéo Degioanni   if (auto vecType = dyn_cast<VectorType>(type))
6100b12a94SThéo Degioanni     if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
6200b12a94SThéo Degioanni       return intType.getWidth();
6300b12a94SThéo Degioanni   return std::nullopt;
6400b12a94SThéo Degioanni }
6500b12a94SThéo Degioanni 
66930c74f1SLei Zhang /// Returns the bit width of integer, float or vector of float or integer values
67930c74f1SLei Zhang static unsigned getBitWidth(Type type) {
685550c821STres Popp   assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
69930c74f1SLei Zhang          "bitwidth is not supported for this type");
70930c74f1SLei Zhang   if (type.isIntOrFloat())
71930c74f1SLei Zhang     return type.getIntOrFloatBitWidth();
725550c821STres Popp   auto vecType = dyn_cast<VectorType>(type);
73930c74f1SLei Zhang   auto elementType = vecType.getElementType();
74930c74f1SLei Zhang   assert(elementType.isIntOrFloat() &&
75930c74f1SLei Zhang          "only integers and floats have a bitwidth");
76930c74f1SLei Zhang   return elementType.getIntOrFloatBitWidth();
77930c74f1SLei Zhang }
78930c74f1SLei Zhang 
79930c74f1SLei Zhang /// Returns the bit width of LLVMType integer or vector.
80c69c9e0fSAlex Zinenko static unsigned getLLVMTypeBitWidth(Type type) {
815550c821STres Popp   return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
825550c821STres Popp                                 ? LLVM::getVectorElementType(type)
835550c821STres Popp                                 : type))
842230bf99SAlex Zinenko       .getWidth();
85930c74f1SLei Zhang }
86930c74f1SLei Zhang 
87930c74f1SLei Zhang /// Creates `IntegerAttribute` with all bits set for given type
88930c74f1SLei Zhang static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
895550c821STres Popp   if (auto vecType = dyn_cast<VectorType>(type)) {
905550c821STres Popp     auto integerType = cast<IntegerType>(vecType.getElementType());
91930c74f1SLei Zhang     return builder.getIntegerAttr(integerType, -1);
92930c74f1SLei Zhang   }
935550c821STres Popp   auto integerType = cast<IntegerType>(type);
94930c74f1SLei Zhang   return builder.getIntegerAttr(integerType, -1);
95930c74f1SLei Zhang }
96930c74f1SLei Zhang 
97930c74f1SLei Zhang /// Creates `llvm.mlir.constant` with all bits set for the given type.
98930c74f1SLei Zhang static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
99930c74f1SLei Zhang                                       PatternRewriter &rewriter) {
1005550c821STres Popp   if (isa<VectorType>(srcType)) {
101930c74f1SLei Zhang     return rewriter.create<LLVM::ConstantOp>(
102930c74f1SLei Zhang         loc, dstType,
1035550c821STres Popp         SplatElementsAttr::get(cast<ShapedType>(srcType),
104930c74f1SLei Zhang                                minusOneIntegerAttribute(srcType, rewriter)));
105930c74f1SLei Zhang   }
106930c74f1SLei Zhang   return rewriter.create<LLVM::ConstantOp>(
107930c74f1SLei Zhang       loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
108930c74f1SLei Zhang }
109930c74f1SLei Zhang 
110930c74f1SLei Zhang /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
111930c74f1SLei Zhang static Value createFPConstant(Location loc, Type srcType, Type dstType,
112930c74f1SLei Zhang                               PatternRewriter &rewriter, double value) {
1135550c821STres Popp   if (auto vecType = dyn_cast<VectorType>(srcType)) {
1145550c821STres Popp     auto floatType = cast<FloatType>(vecType.getElementType());
115930c74f1SLei Zhang     return rewriter.create<LLVM::ConstantOp>(
116930c74f1SLei Zhang         loc, dstType,
117930c74f1SLei Zhang         SplatElementsAttr::get(vecType,
118930c74f1SLei Zhang                                rewriter.getFloatAttr(floatType, value)));
119930c74f1SLei Zhang   }
1205550c821STres Popp   auto floatType = cast<FloatType>(srcType);
121930c74f1SLei Zhang   return rewriter.create<LLVM::ConstantOp>(
122930c74f1SLei Zhang       loc, dstType, rewriter.getFloatAttr(floatType, value));
123930c74f1SLei Zhang }
124930c74f1SLei Zhang 
125930c74f1SLei Zhang /// Utility function for bitfield ops:
126930c74f1SLei Zhang ///   - `BitFieldInsert`
127930c74f1SLei Zhang ///   - `BitFieldSExtract`
128930c74f1SLei Zhang ///   - `BitFieldUExtract`
129930c74f1SLei Zhang /// Truncates or extends the value. If the bitwidth of the value is the same as
130c69c9e0fSAlex Zinenko /// `llvmType` bitwidth, the value remains unchanged.
131c69c9e0fSAlex Zinenko static Value optionallyTruncateOrExtend(Location loc, Value value,
132c69c9e0fSAlex Zinenko                                         Type llvmType,
133930c74f1SLei Zhang                                         PatternRewriter &rewriter) {
134930c74f1SLei Zhang   auto srcType = value.getType();
135930c74f1SLei Zhang   unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
136c69c9e0fSAlex Zinenko   unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
137c69c9e0fSAlex Zinenko                                ? getLLVMTypeBitWidth(srcType)
138930c74f1SLei Zhang                                : getBitWidth(srcType);
139930c74f1SLei Zhang 
140930c74f1SLei Zhang   if (valueBitWidth < targetBitWidth)
141930c74f1SLei Zhang     return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
142930c74f1SLei Zhang   // If the bit widths of `Count` and `Offset` are greater than the bit width
143930c74f1SLei Zhang   // of the target type, they are truncated. Truncation is safe since `Count`
144930c74f1SLei Zhang   // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
145930c74f1SLei Zhang   // both values can be expressed in 8 bits.
146930c74f1SLei Zhang   if (valueBitWidth > targetBitWidth)
147930c74f1SLei Zhang     return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
148930c74f1SLei Zhang   return value;
149930c74f1SLei Zhang }
150930c74f1SLei Zhang 
151930c74f1SLei Zhang /// Broadcasts the value to vector with `numElements` number of elements.
152930c74f1SLei Zhang static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
15398723e65SMatthias Springer                        const TypeConverter &typeConverter,
154930c74f1SLei Zhang                        ConversionPatternRewriter &rewriter) {
155930c74f1SLei Zhang   auto vectorType = VectorType::get(numElements, toBroadcast.getType());
156930c74f1SLei Zhang   auto llvmVectorType = typeConverter.convertType(vectorType);
157930c74f1SLei Zhang   auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
158930c74f1SLei Zhang   Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
159930c74f1SLei Zhang   for (unsigned i = 0; i < numElements; ++i) {
160930c74f1SLei Zhang     auto index = rewriter.create<LLVM::ConstantOp>(
161930c74f1SLei Zhang         loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
162930c74f1SLei Zhang     broadcasted = rewriter.create<LLVM::InsertElementOp>(
163930c74f1SLei Zhang         loc, llvmVectorType, broadcasted, toBroadcast, index);
164930c74f1SLei Zhang   }
165930c74f1SLei Zhang   return broadcasted;
166930c74f1SLei Zhang }
167930c74f1SLei Zhang 
168930c74f1SLei Zhang /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
169930c74f1SLei Zhang static Value optionallyBroadcast(Location loc, Value value, Type srcType,
17098723e65SMatthias Springer                                  const TypeConverter &typeConverter,
171930c74f1SLei Zhang                                  ConversionPatternRewriter &rewriter) {
1725550c821STres Popp   if (auto vectorType = dyn_cast<VectorType>(srcType)) {
173930c74f1SLei Zhang     unsigned numElements = vectorType.getNumElements();
174930c74f1SLei Zhang     return broadcast(loc, value, numElements, typeConverter, rewriter);
175930c74f1SLei Zhang   }
176930c74f1SLei Zhang   return value;
177930c74f1SLei Zhang }
178930c74f1SLei Zhang 
179930c74f1SLei Zhang /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
180930c74f1SLei Zhang /// `BitFieldUExtract`.
181930c74f1SLei Zhang /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
182930c74f1SLei Zhang /// a vector type, construct a vector that has:
183930c74f1SLei Zhang ///  - same number of elements as `Base`
184930c74f1SLei Zhang ///  - each element has the type that is the same as the type of `Offset` or
185930c74f1SLei Zhang ///    `Count`
186930c74f1SLei Zhang ///  - each element has the same value as `Offset` or `Count`
187930c74f1SLei Zhang /// Then cast `Offset` and `Count` if their bit width is different
188930c74f1SLei Zhang /// from `Base` bit width.
189930c74f1SLei Zhang static Value processCountOrOffset(Location loc, Value value, Type srcType,
19098723e65SMatthias Springer                                   Type dstType, const TypeConverter &converter,
191930c74f1SLei Zhang                                   ConversionPatternRewriter &rewriter) {
192930c74f1SLei Zhang   Value broadcasted =
193930c74f1SLei Zhang       optionallyBroadcast(loc, value, srcType, converter, rewriter);
194930c74f1SLei Zhang   return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
195930c74f1SLei Zhang }
196930c74f1SLei Zhang 
197930c74f1SLei Zhang /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
198930c74f1SLei Zhang /// offset to LLVM struct. Otherwise, the conversion is not supported.
19930ca16ecSPierre van Houtryve static Type convertStructTypeWithOffset(spirv::StructType type,
20098723e65SMatthias Springer                                         const TypeConverter &converter) {
201930c74f1SLei Zhang   if (type != VulkanLayoutUtils::decorateType(type))
20230ca16ecSPierre van Houtryve     return nullptr;
203930c74f1SLei Zhang 
20430ca16ecSPierre van Houtryve   SmallVector<Type> elementsVector;
20530ca16ecSPierre van Houtryve   if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
20630ca16ecSPierre van Houtryve     return nullptr;
207930c74f1SLei Zhang   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
208930c74f1SLei Zhang                                           /*isPacked=*/false);
209930c74f1SLei Zhang }
210930c74f1SLei Zhang 
211930c74f1SLei Zhang /// Converts SPIR-V struct with no offset to packed LLVM struct.
212930c74f1SLei Zhang static Type convertStructTypePacked(spirv::StructType type,
21398723e65SMatthias Springer                                     const TypeConverter &converter) {
21430ca16ecSPierre van Houtryve   SmallVector<Type> elementsVector;
21530ca16ecSPierre van Houtryve   if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
21630ca16ecSPierre van Houtryve     return nullptr;
217930c74f1SLei Zhang   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
218930c74f1SLei Zhang                                           /*isPacked=*/true);
219930c74f1SLei Zhang }
220930c74f1SLei Zhang 
221930c74f1SLei Zhang /// Creates LLVM dialect constant with the given value.
222930c74f1SLei Zhang static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
223930c74f1SLei Zhang                                  unsigned value) {
224930c74f1SLei Zhang   return rewriter.create<LLVM::ConstantOp>(
2252230bf99SAlex Zinenko       loc, IntegerType::get(rewriter.getContext(), 32),
226930c74f1SLei Zhang       rewriter.getIntegerAttr(rewriter.getI32Type(), value));
227930c74f1SLei Zhang }
228930c74f1SLei Zhang 
2295ab6ef75SJakub Kuderski /// Utility for `spirv.Load` and `spirv.Store` conversion.
23098723e65SMatthias Springer static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
23198723e65SMatthias Springer                                             ConversionPatternRewriter &rewriter,
23298723e65SMatthias Springer                                             const TypeConverter &typeConverter,
23398723e65SMatthias Springer                                             unsigned alignment, bool isVolatile,
234930c74f1SLei Zhang                                             bool isNonTemporal) {
235930c74f1SLei Zhang   if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
236930c74f1SLei Zhang     auto dstType = typeConverter.convertType(loadOp.getType());
237930c74f1SLei Zhang     if (!dstType)
238c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
239930c74f1SLei Zhang     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
24090a1632dSJakub Kuderski         loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
241015192c6SRiver Riddle         isVolatile, isNonTemporal);
242930c74f1SLei Zhang     return success();
243930c74f1SLei Zhang   }
244930c74f1SLei Zhang   auto storeOp = cast<spirv::StoreOp>(op);
245015192c6SRiver Riddle   spirv::StoreOpAdaptor adaptor(operands);
24690a1632dSJakub Kuderski   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
24790a1632dSJakub Kuderski                                              adaptor.getPtr(), alignment,
248930c74f1SLei Zhang                                              isVolatile, isNonTemporal);
249930c74f1SLei Zhang   return success();
250930c74f1SLei Zhang }
251930c74f1SLei Zhang 
252930c74f1SLei Zhang //===----------------------------------------------------------------------===//
253930c74f1SLei Zhang // Type conversion
254930c74f1SLei Zhang //===----------------------------------------------------------------------===//
255930c74f1SLei Zhang 
256930c74f1SLei Zhang /// Converts SPIR-V array type to LLVM array. Natural stride (according to
257930c74f1SLei Zhang /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
258930c74f1SLei Zhang /// when converting ops that manipulate array types.
2590de16fafSRamkumar Ramachandra static std::optional<Type> convertArrayType(spirv::ArrayType type,
260930c74f1SLei Zhang                                             TypeConverter &converter) {
261930c74f1SLei Zhang   unsigned stride = type.getArrayStride();
262930c74f1SLei Zhang   Type elementType = type.getElementType();
2635550c821STres Popp   auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
2646ed4310eSMehdi Amini   if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
2651a36588eSKazu Hirata     return std::nullopt;
266930c74f1SLei Zhang 
267c69c9e0fSAlex Zinenko   auto llvmElementType = converter.convertType(elementType);
268930c74f1SLei Zhang   unsigned numElements = type.getNumElements();
269930c74f1SLei Zhang   return LLVM::LLVMArrayType::get(llvmElementType, numElements);
270930c74f1SLei Zhang }
271930c74f1SLei Zhang 
272930c74f1SLei Zhang /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
273930c74f1SLei Zhang /// modelled at the moment.
274930c74f1SLei Zhang static Type convertPointerType(spirv::PointerType type,
27598723e65SMatthias Springer                                const TypeConverter &converter,
27612ce9fd1SVictor Perez                                spirv::ClientAPI clientAPI) {
277d45de800SVictor Perez   unsigned addressSpace =
278d45de800SVictor Perez       storageClassToAddressSpace(clientAPI, type.getStorageClass());
27997a238e8SChristian Ulmann   return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
280930c74f1SLei Zhang }
281930c74f1SLei Zhang 
282930c74f1SLei Zhang /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
283930c74f1SLei Zhang /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
284930c74f1SLei Zhang /// no modelling of array stride at the moment.
2850de16fafSRamkumar Ramachandra static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
286930c74f1SLei Zhang                                                    TypeConverter &converter) {
287930c74f1SLei Zhang   if (type.getArrayStride() != 0)
2881a36588eSKazu Hirata     return std::nullopt;
289c69c9e0fSAlex Zinenko   auto elementType = converter.convertType(type.getElementType());
290930c74f1SLei Zhang   return LLVM::LLVMArrayType::get(elementType, 0);
291930c74f1SLei Zhang }
292930c74f1SLei Zhang 
293930c74f1SLei Zhang /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
294930c74f1SLei Zhang /// member decorations. Also, only natural offset is supported.
29530ca16ecSPierre van Houtryve static Type convertStructType(spirv::StructType type,
29698723e65SMatthias Springer                               const TypeConverter &converter) {
297930c74f1SLei Zhang   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
298930c74f1SLei Zhang   type.getMemberDecorations(memberDecorations);
299930c74f1SLei Zhang   if (!memberDecorations.empty())
30030ca16ecSPierre van Houtryve     return nullptr;
301930c74f1SLei Zhang   if (type.hasOffset())
302930c74f1SLei Zhang     return convertStructTypeWithOffset(type, converter);
303930c74f1SLei Zhang   return convertStructTypePacked(type, converter);
304930c74f1SLei Zhang }
305930c74f1SLei Zhang 
306930c74f1SLei Zhang //===----------------------------------------------------------------------===//
307930c74f1SLei Zhang // Operation conversion
308930c74f1SLei Zhang //===----------------------------------------------------------------------===//
309930c74f1SLei Zhang 
310930c74f1SLei Zhang namespace {
311930c74f1SLei Zhang 
312930c74f1SLei Zhang class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
313930c74f1SLei Zhang public:
314930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
315930c74f1SLei Zhang 
316930c74f1SLei Zhang   LogicalResult
317b54c724bSRiver Riddle   matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
318930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
319206fad0eSMatthias Springer     auto dstType =
320206fad0eSMatthias Springer         getTypeConverter()->convertType(op.getComponentPtr().getType());
321930c74f1SLei Zhang     if (!dstType)
322c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
323930c74f1SLei Zhang     // To use GEP we need to add a first 0 index to go through the pointer.
32490a1632dSJakub Kuderski     auto indices = llvm::to_vector<4>(adaptor.getIndices());
32590a1632dSJakub Kuderski     Type indexType = op.getIndices().front().getType();
326206fad0eSMatthias Springer     auto llvmIndexType = getTypeConverter()->convertType(indexType);
327930c74f1SLei Zhang     if (!llvmIndexType)
328c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
329930c74f1SLei Zhang     Value zero = rewriter.create<LLVM::ConstantOp>(
330930c74f1SLei Zhang         op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
331930c74f1SLei Zhang     indices.insert(indices.begin(), zero);
332c19436eeSKohei Yamaguchi 
333206fad0eSMatthias Springer     auto elementType = getTypeConverter()->convertType(
334c19436eeSKohei Yamaguchi         cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
335c19436eeSKohei Yamaguchi     if (!elementType)
336c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
337c19436eeSKohei Yamaguchi     rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
33819b1e27fSMarkus Böck                                              adaptor.getBasePtr(), indices);
339930c74f1SLei Zhang     return success();
340930c74f1SLei Zhang   }
341930c74f1SLei Zhang };
342930c74f1SLei Zhang 
343930c74f1SLei Zhang class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
344930c74f1SLei Zhang public:
345930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
346930c74f1SLei Zhang 
347930c74f1SLei Zhang   LogicalResult
348b54c724bSRiver Riddle   matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
349930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
350206fad0eSMatthias Springer     auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
351930c74f1SLei Zhang     if (!dstType)
352c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
353162f7572SMahesh Ravishankar     rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
354162f7572SMahesh Ravishankar                                                    op.getVariable());
355930c74f1SLei Zhang     return success();
356930c74f1SLei Zhang   }
357930c74f1SLei Zhang };
358930c74f1SLei Zhang 
359930c74f1SLei Zhang class BitFieldInsertPattern
360930c74f1SLei Zhang     : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
361930c74f1SLei Zhang public:
362930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
363930c74f1SLei Zhang 
364930c74f1SLei Zhang   LogicalResult
365b54c724bSRiver Riddle   matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
366930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
367930c74f1SLei Zhang     auto srcType = op.getType();
368206fad0eSMatthias Springer     auto dstType = getTypeConverter()->convertType(srcType);
369930c74f1SLei Zhang     if (!dstType)
370c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
371930c74f1SLei Zhang     Location loc = op.getLoc();
372930c74f1SLei Zhang 
373930c74f1SLei Zhang     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
37490a1632dSJakub Kuderski     Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
375206fad0eSMatthias Springer                                         *getTypeConverter(), rewriter);
37690a1632dSJakub Kuderski     Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
377206fad0eSMatthias Springer                                        *getTypeConverter(), rewriter);
378930c74f1SLei Zhang 
379930c74f1SLei Zhang     // Create a mask with bits set outside [Offset, Offset + Count - 1].
380930c74f1SLei Zhang     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
381930c74f1SLei Zhang     Value maskShiftedByCount =
382930c74f1SLei Zhang         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
383930c74f1SLei Zhang     Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
384930c74f1SLei Zhang                                                  maskShiftedByCount, minusOne);
385930c74f1SLei Zhang     Value maskShiftedByCountAndOffset =
386930c74f1SLei Zhang         rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
387930c74f1SLei Zhang     Value mask = rewriter.create<LLVM::XOrOp>(
388930c74f1SLei Zhang         loc, dstType, maskShiftedByCountAndOffset, minusOne);
389930c74f1SLei Zhang 
390930c74f1SLei Zhang     // Extract unchanged bits from the `Base`  that are outside of
391930c74f1SLei Zhang     // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
392930c74f1SLei Zhang     Value baseAndMask =
39390a1632dSJakub Kuderski         rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
394930c74f1SLei Zhang     Value insertShiftedByOffset =
39590a1632dSJakub Kuderski         rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
396930c74f1SLei Zhang     rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
397930c74f1SLei Zhang                                             insertShiftedByOffset);
398930c74f1SLei Zhang     return success();
399930c74f1SLei Zhang   }
400930c74f1SLei Zhang };
401930c74f1SLei Zhang 
402930c74f1SLei Zhang /// Converts SPIR-V ConstantOp with scalar or vector type.
403930c74f1SLei Zhang class ConstantScalarAndVectorPattern
404930c74f1SLei Zhang     : public SPIRVToLLVMConversion<spirv::ConstantOp> {
405930c74f1SLei Zhang public:
406930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
407930c74f1SLei Zhang 
408930c74f1SLei Zhang   LogicalResult
409b54c724bSRiver Riddle   matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
410930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
411930c74f1SLei Zhang     auto srcType = constOp.getType();
4125550c821STres Popp     if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
413930c74f1SLei Zhang       return failure();
414930c74f1SLei Zhang 
415206fad0eSMatthias Springer     auto dstType = getTypeConverter()->convertType(srcType);
416930c74f1SLei Zhang     if (!dstType)
417c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(constOp, "type conversion failed");
418930c74f1SLei Zhang 
419930c74f1SLei Zhang     // SPIR-V constant can be a signed/unsigned integer, which has to be
420930c74f1SLei Zhang     // casted to signless integer when converting to LLVM dialect. Removing the
421930c74f1SLei Zhang     // sign bit may have unexpected behaviour. However, it is better to handle
422930c74f1SLei Zhang     // it case-by-case, given that the purpose of the conversion is not to
423930c74f1SLei Zhang     // cover all possible corner cases.
424930c74f1SLei Zhang     if (isSignedIntegerOrVector(srcType) ||
425930c74f1SLei Zhang         isUnsignedIntegerOrVector(srcType)) {
426930c74f1SLei Zhang       auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
427930c74f1SLei Zhang 
4285550c821STres Popp       if (isa<VectorType>(srcType)) {
4295550c821STres Popp         auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
430930c74f1SLei Zhang         rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
431930c74f1SLei Zhang             constOp, dstType,
432930c74f1SLei Zhang             dstElementsAttr.mapValues(
433930c74f1SLei Zhang                 signlessType, [&](const APInt &value) { return value; }));
434930c74f1SLei Zhang         return success();
435930c74f1SLei Zhang       }
4365550c821STres Popp       auto srcAttr = cast<IntegerAttr>(constOp.getValue());
437930c74f1SLei Zhang       auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
438930c74f1SLei Zhang       rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
439930c74f1SLei Zhang       return success();
440930c74f1SLei Zhang     }
441b54c724bSRiver Riddle     rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
442b54c724bSRiver Riddle         constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
443930c74f1SLei Zhang     return success();
444930c74f1SLei Zhang   }
445930c74f1SLei Zhang };
446930c74f1SLei Zhang 
447930c74f1SLei Zhang class BitFieldSExtractPattern
448930c74f1SLei Zhang     : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
449930c74f1SLei Zhang public:
450930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
451930c74f1SLei Zhang 
452930c74f1SLei Zhang   LogicalResult
453b54c724bSRiver Riddle   matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
454930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
455930c74f1SLei Zhang     auto srcType = op.getType();
456206fad0eSMatthias Springer     auto dstType = getTypeConverter()->convertType(srcType);
457930c74f1SLei Zhang     if (!dstType)
458c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
459930c74f1SLei Zhang     Location loc = op.getLoc();
460930c74f1SLei Zhang 
461930c74f1SLei Zhang     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
46290a1632dSJakub Kuderski     Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
463206fad0eSMatthias Springer                                         *getTypeConverter(), rewriter);
46490a1632dSJakub Kuderski     Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
465206fad0eSMatthias Springer                                        *getTypeConverter(), rewriter);
466930c74f1SLei Zhang 
467930c74f1SLei Zhang     // Create a constant that holds the size of the `Base`.
468930c74f1SLei Zhang     IntegerType integerType;
4695550c821STres Popp     if (auto vecType = dyn_cast<VectorType>(srcType))
4705550c821STres Popp       integerType = cast<IntegerType>(vecType.getElementType());
471930c74f1SLei Zhang     else
4725550c821STres Popp       integerType = cast<IntegerType>(srcType);
473930c74f1SLei Zhang 
474930c74f1SLei Zhang     auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
475930c74f1SLei Zhang     Value size =
4765550c821STres Popp         isa<VectorType>(srcType)
477930c74f1SLei Zhang             ? rewriter.create<LLVM::ConstantOp>(
478930c74f1SLei Zhang                   loc, dstType,
4795550c821STres Popp                   SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
480930c74f1SLei Zhang             : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
481930c74f1SLei Zhang 
482930c74f1SLei Zhang     // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
483930c74f1SLei Zhang     // at Offset + Count - 1 is the most significant bit now.
484930c74f1SLei Zhang     Value countPlusOffset =
485930c74f1SLei Zhang         rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
486930c74f1SLei Zhang     Value amountToShiftLeft =
487930c74f1SLei Zhang         rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
488930c74f1SLei Zhang     Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
48990a1632dSJakub Kuderski         loc, dstType, op.getBase(), amountToShiftLeft);
490930c74f1SLei Zhang 
491930c74f1SLei Zhang     // Shift the result right, filling the bits with the sign bit.
492930c74f1SLei Zhang     Value amountToShiftRight =
493930c74f1SLei Zhang         rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
494930c74f1SLei Zhang     rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
495930c74f1SLei Zhang                                               amountToShiftRight);
496930c74f1SLei Zhang     return success();
497930c74f1SLei Zhang   }
498930c74f1SLei Zhang };
499930c74f1SLei Zhang 
500930c74f1SLei Zhang class BitFieldUExtractPattern
501930c74f1SLei Zhang     : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
502930c74f1SLei Zhang public:
503930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
504930c74f1SLei Zhang 
505930c74f1SLei Zhang   LogicalResult
506b54c724bSRiver Riddle   matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
507930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
508930c74f1SLei Zhang     auto srcType = op.getType();
509206fad0eSMatthias Springer     auto dstType = getTypeConverter()->convertType(srcType);
510930c74f1SLei Zhang     if (!dstType)
511c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
512930c74f1SLei Zhang     Location loc = op.getLoc();
513930c74f1SLei Zhang 
514930c74f1SLei Zhang     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
51590a1632dSJakub Kuderski     Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
516206fad0eSMatthias Springer                                         *getTypeConverter(), rewriter);
51790a1632dSJakub Kuderski     Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
518206fad0eSMatthias Springer                                        *getTypeConverter(), rewriter);
519930c74f1SLei Zhang 
520930c74f1SLei Zhang     // Create a mask with bits set at [0, Count - 1].
521930c74f1SLei Zhang     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
522930c74f1SLei Zhang     Value maskShiftedByCount =
523930c74f1SLei Zhang         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
524930c74f1SLei Zhang     Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
525930c74f1SLei Zhang                                               minusOne);
526930c74f1SLei Zhang 
527930c74f1SLei Zhang     // Shift `Base` by `Offset` and apply the mask on it.
528930c74f1SLei Zhang     Value shiftedBase =
52990a1632dSJakub Kuderski         rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
530930c74f1SLei Zhang     rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
531930c74f1SLei Zhang     return success();
532930c74f1SLei Zhang   }
533930c74f1SLei Zhang };
534930c74f1SLei Zhang 
535930c74f1SLei Zhang class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
536930c74f1SLei Zhang public:
537930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
538930c74f1SLei Zhang 
539930c74f1SLei Zhang   LogicalResult
540b54c724bSRiver Riddle   matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
541930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
542b54c724bSRiver Riddle     rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
543930c74f1SLei Zhang                                             branchOp.getTarget());
544930c74f1SLei Zhang     return success();
545930c74f1SLei Zhang   }
546930c74f1SLei Zhang };
547930c74f1SLei Zhang 
548930c74f1SLei Zhang class BranchConditionalConversionPattern
549930c74f1SLei Zhang     : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
550930c74f1SLei Zhang public:
551930c74f1SLei Zhang   using SPIRVToLLVMConversion<
552930c74f1SLei Zhang       spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
553930c74f1SLei Zhang 
554930c74f1SLei Zhang   LogicalResult
555b54c724bSRiver Riddle   matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
556930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
557930c74f1SLei Zhang     // If branch weights exist, map them to 32-bit integer vector.
55810fa2770STobias Gysi     DenseI32ArrayAttr branchWeights = nullptr;
55990a1632dSJakub Kuderski     if (auto weights = op.getBranchWeights()) {
56010fa2770STobias Gysi       SmallVector<int32_t> weightValues;
56110fa2770STobias Gysi       for (auto weight : weights->getAsRange<IntegerAttr>())
56210fa2770STobias Gysi         weightValues.push_back(weight.getInt());
56310fa2770STobias Gysi       branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
564930c74f1SLei Zhang     }
565930c74f1SLei Zhang 
566930c74f1SLei Zhang     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
56790a1632dSJakub Kuderski         op, op.getCondition(), op.getTrueBlockArguments(),
568930c74f1SLei Zhang         op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
569930c74f1SLei Zhang         op.getFalseBlock());
570930c74f1SLei Zhang     return success();
571930c74f1SLei Zhang   }
572930c74f1SLei Zhang };
573930c74f1SLei Zhang 
5745ab6ef75SJakub Kuderski /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
5755ab6ef75SJakub Kuderski /// type is an aggregate type (struct or array). Otherwise, converts to
576930c74f1SLei Zhang /// `llvm.extractelement` that operates on vectors.
577930c74f1SLei Zhang class CompositeExtractPattern
578930c74f1SLei Zhang     : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
579930c74f1SLei Zhang public:
580930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
581930c74f1SLei Zhang 
582930c74f1SLei Zhang   LogicalResult
583b54c724bSRiver Riddle   matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
584930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
585206fad0eSMatthias Springer     auto dstType = this->getTypeConverter()->convertType(op.getType());
586930c74f1SLei Zhang     if (!dstType)
587c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
588930c74f1SLei Zhang 
58990a1632dSJakub Kuderski     Type containerType = op.getComposite().getType();
5905550c821STres Popp     if (isa<VectorType>(containerType)) {
591930c74f1SLei Zhang       Location loc = op.getLoc();
5925550c821STres Popp       IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
593930c74f1SLei Zhang       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
594930c74f1SLei Zhang       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
59590a1632dSJakub Kuderski           op, dstType, adaptor.getComposite(), index);
596930c74f1SLei Zhang       return success();
597930c74f1SLei Zhang     }
5985c5af910SJeff Niu 
599930c74f1SLei Zhang     rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
600162f7572SMahesh Ravishankar         op, adaptor.getComposite(),
601162f7572SMahesh Ravishankar         LLVM::convertArrayToIndices(op.getIndices()));
602930c74f1SLei Zhang     return success();
603930c74f1SLei Zhang   }
604930c74f1SLei Zhang };
605930c74f1SLei Zhang 
6065ab6ef75SJakub Kuderski /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
6075ab6ef75SJakub Kuderski /// type is an aggregate type (struct or array). Otherwise, converts to
608930c74f1SLei Zhang /// `llvm.insertelement` that operates on vectors.
609930c74f1SLei Zhang class CompositeInsertPattern
610930c74f1SLei Zhang     : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
611930c74f1SLei Zhang public:
612930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
613930c74f1SLei Zhang 
614930c74f1SLei Zhang   LogicalResult
615b54c724bSRiver Riddle   matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
616930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
617206fad0eSMatthias Springer     auto dstType = this->getTypeConverter()->convertType(op.getType());
618930c74f1SLei Zhang     if (!dstType)
619c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
620930c74f1SLei Zhang 
62190a1632dSJakub Kuderski     Type containerType = op.getComposite().getType();
6225550c821STres Popp     if (isa<VectorType>(containerType)) {
623930c74f1SLei Zhang       Location loc = op.getLoc();
6245550c821STres Popp       IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
625930c74f1SLei Zhang       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
626930c74f1SLei Zhang       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
62790a1632dSJakub Kuderski           op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
628930c74f1SLei Zhang       return success();
629930c74f1SLei Zhang     }
6305c5af910SJeff Niu 
631930c74f1SLei Zhang     rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
63290a1632dSJakub Kuderski         op, adaptor.getComposite(), adaptor.getObject(),
63390a1632dSJakub Kuderski         LLVM::convertArrayToIndices(op.getIndices()));
634930c74f1SLei Zhang     return success();
635930c74f1SLei Zhang   }
636930c74f1SLei Zhang };
637930c74f1SLei Zhang 
638930c74f1SLei Zhang /// Converts SPIR-V operations that have straightforward LLVM equivalent
639930c74f1SLei Zhang /// into LLVM dialect operations.
640930c74f1SLei Zhang template <typename SPIRVOp, typename LLVMOp>
641930c74f1SLei Zhang class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
642930c74f1SLei Zhang public:
643930c74f1SLei Zhang   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
644930c74f1SLei Zhang 
645930c74f1SLei Zhang   LogicalResult
646c19436eeSKohei Yamaguchi   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
647930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
648206fad0eSMatthias Springer     auto dstType = this->getTypeConverter()->convertType(op.getType());
649930c74f1SLei Zhang     if (!dstType)
650c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
651b54c724bSRiver Riddle     rewriter.template replaceOpWithNewOp<LLVMOp>(
652c19436eeSKohei Yamaguchi         op, dstType, adaptor.getOperands(), op->getAttrs());
653930c74f1SLei Zhang     return success();
654930c74f1SLei Zhang   }
655930c74f1SLei Zhang };
656930c74f1SLei Zhang 
6575ab6ef75SJakub Kuderski /// Converts `spirv.ExecutionMode` into a global struct constant that holds
658930c74f1SLei Zhang /// execution mode information.
659930c74f1SLei Zhang class ExecutionModePattern
660930c74f1SLei Zhang     : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
661930c74f1SLei Zhang public:
662930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
663930c74f1SLei Zhang 
664930c74f1SLei Zhang   LogicalResult
665b54c724bSRiver Riddle   matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
666930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
667930c74f1SLei Zhang     // First, create the global struct's name that would be associated with
668930c74f1SLei Zhang     // this entry point's execution mode. We set it to be:
669c0a6381eSWeiwei Li     //   __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
670930c74f1SLei Zhang     ModuleOp module = op->getParentOfType<ModuleOp>();
67190a1632dSJakub Kuderski     spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
672930c74f1SLei Zhang     std::string moduleName;
673491d2701SKazu Hirata     if (module.getName().has_value())
67490a1632dSJakub Kuderski       moduleName = "_" + module.getName()->str();
675930c74f1SLei Zhang     else
676930c74f1SLei Zhang       moduleName = "";
677a29fffc4SLei Zhang     std::string executionModeInfoName = llvm::formatv(
67890a1632dSJakub Kuderski         "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
679a29fffc4SLei Zhang         static_cast<uint32_t>(executionModeAttr.getValue()));
680930c74f1SLei Zhang 
681930c74f1SLei Zhang     MLIRContext *context = rewriter.getContext();
682930c74f1SLei Zhang     OpBuilder::InsertionGuard guard(rewriter);
683930c74f1SLei Zhang     rewriter.setInsertionPointToStart(module.getBody());
684930c74f1SLei Zhang 
685930c74f1SLei Zhang     // Create a struct type, corresponding to the C struct below.
686930c74f1SLei Zhang     // struct {
687930c74f1SLei Zhang     //   int32_t executionMode;
688930c74f1SLei Zhang     //   int32_t values[];          // optional values
689930c74f1SLei Zhang     // };
6902230bf99SAlex Zinenko     auto llvmI32Type = IntegerType::get(context, 32);
691c69c9e0fSAlex Zinenko     SmallVector<Type, 2> fields;
692930c74f1SLei Zhang     fields.push_back(llvmI32Type);
69390a1632dSJakub Kuderski     ArrayAttr values = op.getValues();
694930c74f1SLei Zhang     if (!values.empty()) {
695930c74f1SLei Zhang       auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
696930c74f1SLei Zhang       fields.push_back(arrayType);
697930c74f1SLei Zhang     }
698930c74f1SLei Zhang     auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
699930c74f1SLei Zhang 
700930c74f1SLei Zhang     // Create `llvm.mlir.global` with initializer region containing one block.
701930c74f1SLei Zhang     auto global = rewriter.create<LLVM::GlobalOp>(
702930c74f1SLei Zhang         UnknownLoc::get(context), structType, /*isConstant=*/true,
7039a0ea599SDumitru Potop         LLVM::Linkage::External, executionModeInfoName, Attribute(),
7049a0ea599SDumitru Potop         /*alignment=*/0);
705930c74f1SLei Zhang     Location loc = global.getLoc();
706930c74f1SLei Zhang     Region &region = global.getInitializerRegion();
707930c74f1SLei Zhang     Block *block = rewriter.createBlock(&region);
708930c74f1SLei Zhang 
709930c74f1SLei Zhang     // Initialize the struct and set the execution mode value.
710b613a540SMatthias Springer     rewriter.setInsertionPointToStart(block);
711930c74f1SLei Zhang     Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
712a29fffc4SLei Zhang     Value executionMode = rewriter.create<LLVM::ConstantOp>(
713a29fffc4SLei Zhang         loc, llvmI32Type,
714a29fffc4SLei Zhang         rewriter.getI32IntegerAttr(
715a29fffc4SLei Zhang             static_cast<uint32_t>(executionModeAttr.getValue())));
7165c5af910SJeff Niu     structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
7175c5af910SJeff Niu                                                        executionMode, 0);
718930c74f1SLei Zhang 
719930c74f1SLei Zhang     // Insert extra operands if they exist into execution mode info struct.
720930c74f1SLei Zhang     for (unsigned i = 0, e = values.size(); i < e; ++i) {
721930c74f1SLei Zhang       auto attr = values.getValue()[i];
722930c74f1SLei Zhang       Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
723930c74f1SLei Zhang       structValue = rewriter.create<LLVM::InsertValueOp>(
72482973067SUday Bondhugula           loc, structValue, entry, ArrayRef<int64_t>({1, i}));
725930c74f1SLei Zhang     }
726930c74f1SLei Zhang     rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
727930c74f1SLei Zhang     rewriter.eraseOp(op);
728930c74f1SLei Zhang     return success();
729930c74f1SLei Zhang   }
730930c74f1SLei Zhang };
731930c74f1SLei Zhang 
7325ab6ef75SJakub Kuderski /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
7335ab6ef75SJakub Kuderski /// global returns a pointer, whereas in LLVM dialect the global holds an actual
7345ab6ef75SJakub Kuderski /// value. This difference is handled by `spirv.mlir.addressof` and
735930c74f1SLei Zhang /// `llvm.mlir.addressof`ops that both return a pointer.
736930c74f1SLei Zhang class GlobalVariablePattern
737930c74f1SLei Zhang     : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
738930c74f1SLei Zhang public:
73912ce9fd1SVictor Perez   template <typename... Args>
74012ce9fd1SVictor Perez   GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
74112ce9fd1SVictor Perez       : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
74212ce9fd1SVictor Perez             std::forward<Args>(args)...),
74312ce9fd1SVictor Perez         clientAPI(clientAPI) {}
744930c74f1SLei Zhang 
745930c74f1SLei Zhang   LogicalResult
746b54c724bSRiver Riddle   matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
747930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
748930c74f1SLei Zhang     // Currently, there is no support of initialization with a constant value in
749930c74f1SLei Zhang     // SPIR-V dialect. Specialization constants are not considered as well.
75090a1632dSJakub Kuderski     if (op.getInitializer())
751930c74f1SLei Zhang       return failure();
752930c74f1SLei Zhang 
7535550c821STres Popp     auto srcType = cast<spirv::PointerType>(op.getType());
754206fad0eSMatthias Springer     auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
755930c74f1SLei Zhang     if (!dstType)
756c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
757930c74f1SLei Zhang 
758930c74f1SLei Zhang     // Limit conversion to the current invocation only or `StorageBuffer`
759930c74f1SLei Zhang     // required by SPIR-V runner.
760930c74f1SLei Zhang     // This is okay because multiple invocations are not supported yet.
761930c74f1SLei Zhang     auto storageClass = srcType.getStorageClass();
7621e4cfe5eSWeiwei Li     switch (storageClass) {
7631e4cfe5eSWeiwei Li     case spirv::StorageClass::Input:
7641e4cfe5eSWeiwei Li     case spirv::StorageClass::Private:
7651e4cfe5eSWeiwei Li     case spirv::StorageClass::Output:
7661e4cfe5eSWeiwei Li     case spirv::StorageClass::StorageBuffer:
7671e4cfe5eSWeiwei Li     case spirv::StorageClass::UniformConstant:
7681e4cfe5eSWeiwei Li       break;
7691e4cfe5eSWeiwei Li     default:
770930c74f1SLei Zhang       return failure();
771930c74f1SLei Zhang     }
772930c74f1SLei Zhang 
773930c74f1SLei Zhang     // LLVM dialect spec: "If the global value is a constant, storing into it is
7741e4cfe5eSWeiwei Li     // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
7751e4cfe5eSWeiwei Li     // storage class that is read-only.
7761e4cfe5eSWeiwei Li     bool isConstant = (storageClass == spirv::StorageClass::Input) ||
7771e4cfe5eSWeiwei Li                       (storageClass == spirv::StorageClass::UniformConstant);
778930c74f1SLei Zhang     // SPIR-V spec: "By default, functions and global variables are private to a
779930c74f1SLei Zhang     // module and cannot be accessed by other modules. However, a module may be
780930c74f1SLei Zhang     // written to export or import functions and global (module scope)
781930c74f1SLei Zhang     // variables.". Therefore, map 'Private' storage class to private linkage,
782930c74f1SLei Zhang     // 'Input' and 'Output' to external linkage.
783930c74f1SLei Zhang     auto linkage = storageClass == spirv::StorageClass::Private
784930c74f1SLei Zhang                        ? LLVM::Linkage::Private
785930c74f1SLei Zhang                        : LLVM::Linkage::External;
7861e4cfe5eSWeiwei Li     auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
78790a1632dSJakub Kuderski         op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
788d45de800SVictor Perez         /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass));
7891e4cfe5eSWeiwei Li 
7901e4cfe5eSWeiwei Li     // Attach location attribute if applicable
79190a1632dSJakub Kuderski     if (op.getLocationAttr())
79290a1632dSJakub Kuderski       newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
7931e4cfe5eSWeiwei Li 
794930c74f1SLei Zhang     return success();
795930c74f1SLei Zhang   }
79612ce9fd1SVictor Perez 
79712ce9fd1SVictor Perez private:
79812ce9fd1SVictor Perez   spirv::ClientAPI clientAPI;
799930c74f1SLei Zhang };
800930c74f1SLei Zhang 
801930c74f1SLei Zhang /// Converts SPIR-V cast ops that do not have straightforward LLVM
802930c74f1SLei Zhang /// equivalent in LLVM dialect.
803930c74f1SLei Zhang template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
804930c74f1SLei Zhang class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
805930c74f1SLei Zhang public:
806930c74f1SLei Zhang   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
807930c74f1SLei Zhang 
808930c74f1SLei Zhang   LogicalResult
809c19436eeSKohei Yamaguchi   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
810930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
811930c74f1SLei Zhang 
812c19436eeSKohei Yamaguchi     Type fromType = op.getOperand().getType();
813c19436eeSKohei Yamaguchi     Type toType = op.getType();
814930c74f1SLei Zhang 
815206fad0eSMatthias Springer     auto dstType = this->getTypeConverter()->convertType(toType);
816930c74f1SLei Zhang     if (!dstType)
817c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
818930c74f1SLei Zhang 
819930c74f1SLei Zhang     if (getBitWidth(fromType) < getBitWidth(toType)) {
820c19436eeSKohei Yamaguchi       rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
821b54c724bSRiver Riddle                                                       adaptor.getOperands());
822930c74f1SLei Zhang       return success();
823930c74f1SLei Zhang     }
824930c74f1SLei Zhang     if (getBitWidth(fromType) > getBitWidth(toType)) {
825c19436eeSKohei Yamaguchi       rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
826b54c724bSRiver Riddle                                                         adaptor.getOperands());
827930c74f1SLei Zhang       return success();
828930c74f1SLei Zhang     }
829930c74f1SLei Zhang     return failure();
830930c74f1SLei Zhang   }
831930c74f1SLei Zhang };
832930c74f1SLei Zhang 
833930c74f1SLei Zhang class FunctionCallPattern
834930c74f1SLei Zhang     : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
835930c74f1SLei Zhang public:
836930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
837930c74f1SLei Zhang 
838930c74f1SLei Zhang   LogicalResult
839b54c724bSRiver Riddle   matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
840930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
841930c74f1SLei Zhang     if (callOp.getNumResults() == 0) {
842fde3c16aSSirui Mu       auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
8431a36588eSKazu Hirata           callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
844fde3c16aSSirui Mu       newOp.getProperties().operandSegmentSizes = {
845fde3c16aSSirui Mu           static_cast<int32_t>(adaptor.getOperands().size()), 0};
846fde3c16aSSirui Mu       newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
847930c74f1SLei Zhang       return success();
848930c74f1SLei Zhang     }
849930c74f1SLei Zhang 
850930c74f1SLei Zhang     // Function returns a single result.
851206fad0eSMatthias Springer     auto dstType = getTypeConverter()->convertType(callOp.getType(0));
852c19436eeSKohei Yamaguchi     if (!dstType)
853c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(callOp, "type conversion failed");
854fde3c16aSSirui Mu     auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
855b54c724bSRiver Riddle         callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
856fde3c16aSSirui Mu     newOp.getProperties().operandSegmentSizes = {
857fde3c16aSSirui Mu         static_cast<int32_t>(adaptor.getOperands().size()), 0};
858fde3c16aSSirui Mu     newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
859930c74f1SLei Zhang     return success();
860930c74f1SLei Zhang   }
861930c74f1SLei Zhang };
862930c74f1SLei Zhang 
863930c74f1SLei Zhang /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
864930c74f1SLei Zhang template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
865930c74f1SLei Zhang class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
866930c74f1SLei Zhang public:
867930c74f1SLei Zhang   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
868930c74f1SLei Zhang 
869930c74f1SLei Zhang   LogicalResult
870c19436eeSKohei Yamaguchi   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
871930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
872930c74f1SLei Zhang 
873206fad0eSMatthias Springer     auto dstType = this->getTypeConverter()->convertType(op.getType());
874930c74f1SLei Zhang     if (!dstType)
875c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
876930c74f1SLei Zhang 
877930c74f1SLei Zhang     rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
878c19436eeSKohei Yamaguchi         op, dstType, predicate, op.getOperand1(), op.getOperand2());
879930c74f1SLei Zhang     return success();
880930c74f1SLei Zhang   }
881930c74f1SLei Zhang };
882930c74f1SLei Zhang 
883930c74f1SLei Zhang /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
884930c74f1SLei Zhang template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
885930c74f1SLei Zhang class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
886930c74f1SLei Zhang public:
887930c74f1SLei Zhang   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
888930c74f1SLei Zhang 
889930c74f1SLei Zhang   LogicalResult
890c19436eeSKohei Yamaguchi   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
891930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
892930c74f1SLei Zhang 
893206fad0eSMatthias Springer     auto dstType = this->getTypeConverter()->convertType(op.getType());
894930c74f1SLei Zhang     if (!dstType)
895c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
896930c74f1SLei Zhang 
897930c74f1SLei Zhang     rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
898c19436eeSKohei Yamaguchi         op, dstType, predicate, op.getOperand1(), op.getOperand2());
899930c74f1SLei Zhang     return success();
900930c74f1SLei Zhang   }
901930c74f1SLei Zhang };
902930c74f1SLei Zhang 
903930c74f1SLei Zhang class InverseSqrtPattern
90452b630daSJakub Kuderski     : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
905930c74f1SLei Zhang public:
90652b630daSJakub Kuderski   using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion;
907930c74f1SLei Zhang 
908930c74f1SLei Zhang   LogicalResult
90952b630daSJakub Kuderski   matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
910930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
911930c74f1SLei Zhang     auto srcType = op.getType();
912206fad0eSMatthias Springer     auto dstType = getTypeConverter()->convertType(srcType);
913930c74f1SLei Zhang     if (!dstType)
914c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
915930c74f1SLei Zhang 
916930c74f1SLei Zhang     Location loc = op.getLoc();
917930c74f1SLei Zhang     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
91890a1632dSJakub Kuderski     Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
919930c74f1SLei Zhang     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
920930c74f1SLei Zhang     return success();
921930c74f1SLei Zhang   }
922930c74f1SLei Zhang };
923930c74f1SLei Zhang 
9245ab6ef75SJakub Kuderski /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
925b54c724bSRiver Riddle template <typename SPIRVOp>
926b54c724bSRiver Riddle class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
927930c74f1SLei Zhang public:
928b54c724bSRiver Riddle   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
929930c74f1SLei Zhang 
930930c74f1SLei Zhang   LogicalResult
931b54c724bSRiver Riddle   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
932930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
93390a1632dSJakub Kuderski     if (!op.getMemoryAccess()) {
934015192c6SRiver Riddle       return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
935206fad0eSMatthias Springer                                     *this->getTypeConverter(), /*alignment=*/0,
936015192c6SRiver Riddle                                     /*isVolatile=*/false,
937015192c6SRiver Riddle                                     /*isNonTemporal=*/false);
938930c74f1SLei Zhang     }
93990a1632dSJakub Kuderski     auto memoryAccess = *op.getMemoryAccess();
940930c74f1SLei Zhang     switch (memoryAccess) {
941930c74f1SLei Zhang     case spirv::MemoryAccess::Aligned:
942930c74f1SLei Zhang     case spirv::MemoryAccess::None:
943930c74f1SLei Zhang     case spirv::MemoryAccess::Nontemporal:
944930c74f1SLei Zhang     case spirv::MemoryAccess::Volatile: {
945930c74f1SLei Zhang       unsigned alignment =
94690a1632dSJakub Kuderski           memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
947930c74f1SLei Zhang       bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
948930c74f1SLei Zhang       bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
949015192c6SRiver Riddle       return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
950206fad0eSMatthias Springer                                     *this->getTypeConverter(), alignment,
951206fad0eSMatthias Springer                                     isVolatile, isNonTemporal);
952930c74f1SLei Zhang     }
953930c74f1SLei Zhang     default:
954930c74f1SLei Zhang       // There is no support of other memory access attributes.
955930c74f1SLei Zhang       return failure();
956930c74f1SLei Zhang     }
957930c74f1SLei Zhang   }
958930c74f1SLei Zhang };
959930c74f1SLei Zhang 
9605ab6ef75SJakub Kuderski /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
961930c74f1SLei Zhang template <typename SPIRVOp>
962930c74f1SLei Zhang class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
963930c74f1SLei Zhang public:
964930c74f1SLei Zhang   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
965930c74f1SLei Zhang 
966930c74f1SLei Zhang   LogicalResult
967b54c724bSRiver Riddle   matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
968930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
969930c74f1SLei Zhang     auto srcType = notOp.getType();
970206fad0eSMatthias Springer     auto dstType = this->getTypeConverter()->convertType(srcType);
971930c74f1SLei Zhang     if (!dstType)
972c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(notOp, "type conversion failed");
973930c74f1SLei Zhang 
974930c74f1SLei Zhang     Location loc = notOp.getLoc();
975930c74f1SLei Zhang     IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
9765550c821STres Popp     auto mask =
9775550c821STres Popp         isa<VectorType>(srcType)
978930c74f1SLei Zhang             ? rewriter.create<LLVM::ConstantOp>(
979930c74f1SLei Zhang                   loc, dstType,
9805550c821STres Popp                   SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
981930c74f1SLei Zhang             : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
982930c74f1SLei Zhang     rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
98390a1632dSJakub Kuderski                                                       notOp.getOperand(), mask);
984930c74f1SLei Zhang     return success();
985930c74f1SLei Zhang   }
986930c74f1SLei Zhang };
987930c74f1SLei Zhang 
988930c74f1SLei Zhang /// A template pattern that erases the given `SPIRVOp`.
989930c74f1SLei Zhang template <typename SPIRVOp>
990930c74f1SLei Zhang class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
991930c74f1SLei Zhang public:
992930c74f1SLei Zhang   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
993930c74f1SLei Zhang 
994930c74f1SLei Zhang   LogicalResult
995b54c724bSRiver Riddle   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
996930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
997930c74f1SLei Zhang     rewriter.eraseOp(op);
998930c74f1SLei Zhang     return success();
999930c74f1SLei Zhang   }
1000930c74f1SLei Zhang };
1001930c74f1SLei Zhang 
1002930c74f1SLei Zhang class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1003930c74f1SLei Zhang public:
1004930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
1005930c74f1SLei Zhang 
1006930c74f1SLei Zhang   LogicalResult
1007b54c724bSRiver Riddle   matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1008930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
1009930c74f1SLei Zhang     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1010930c74f1SLei Zhang                                                 ArrayRef<Value>());
1011930c74f1SLei Zhang     return success();
1012930c74f1SLei Zhang   }
1013930c74f1SLei Zhang };
1014930c74f1SLei Zhang 
1015930c74f1SLei Zhang class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1016930c74f1SLei Zhang public:
1017930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
1018930c74f1SLei Zhang 
1019930c74f1SLei Zhang   LogicalResult
1020b54c724bSRiver Riddle   matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1021930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
1022930c74f1SLei Zhang     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1023b54c724bSRiver Riddle                                                 adaptor.getOperands());
1024930c74f1SLei Zhang     return success();
1025930c74f1SLei Zhang   }
1026930c74f1SLei Zhang };
1027930c74f1SLei Zhang 
10281775b98dSFinlay static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
10291775b98dSFinlay                                               StringRef name,
10301775b98dSFinlay                                               ArrayRef<Type> paramTypes,
10316ade03d7SLukas Sommer                                               Type resultType,
10326ade03d7SLukas Sommer                                               bool convergent = true) {
10331775b98dSFinlay   auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
10341775b98dSFinlay       SymbolTable::lookupSymbolIn(symbolTable, name));
10351775b98dSFinlay   if (func)
10361775b98dSFinlay     return func;
10371775b98dSFinlay 
10381775b98dSFinlay   OpBuilder b(symbolTable->getRegion(0));
10391775b98dSFinlay   func = b.create<LLVM::LLVMFuncOp>(
10401775b98dSFinlay       symbolTable->getLoc(), name,
10411775b98dSFinlay       LLVM::LLVMFunctionType::get(resultType, paramTypes));
10421775b98dSFinlay   func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
10436ade03d7SLukas Sommer   func.setConvergent(convergent);
10441775b98dSFinlay   func.setNoUnwind(true);
10451775b98dSFinlay   func.setWillReturn(true);
10461775b98dSFinlay   return func;
10471775b98dSFinlay }
10481775b98dSFinlay 
10491775b98dSFinlay static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
10501775b98dSFinlay                                            LLVM::LLVMFuncOp func,
10511775b98dSFinlay                                            ValueRange args) {
10521775b98dSFinlay   auto call = builder.create<LLVM::CallOp>(loc, func, args);
10531775b98dSFinlay   call.setCConv(func.getCConv());
10541775b98dSFinlay   call.setConvergentAttr(func.getConvergentAttr());
10551775b98dSFinlay   call.setNoUnwindAttr(func.getNoUnwindAttr());
10561775b98dSFinlay   call.setWillReturnAttr(func.getWillReturnAttr());
10571775b98dSFinlay   return call;
10581775b98dSFinlay }
10591775b98dSFinlay 
106005fcdd55SVictor Perez template <typename BarrierOpTy>
106105fcdd55SVictor Perez class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> {
10621775b98dSFinlay public:
106305fcdd55SVictor Perez   using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
106405fcdd55SVictor Perez 
106505fcdd55SVictor Perez   using SPIRVToLLVMConversion<BarrierOpTy>::SPIRVToLLVMConversion;
106605fcdd55SVictor Perez 
106705fcdd55SVictor Perez   static constexpr StringRef getFuncName();
10681775b98dSFinlay 
10691775b98dSFinlay   LogicalResult
107005fcdd55SVictor Perez   matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
10711775b98dSFinlay                   ConversionPatternRewriter &rewriter) const override {
107205fcdd55SVictor Perez     constexpr StringRef funcName = getFuncName();
10731775b98dSFinlay     Operation *symbolTable =
107405fcdd55SVictor Perez         controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
10751775b98dSFinlay 
10761775b98dSFinlay     Type i32 = rewriter.getI32Type();
10771775b98dSFinlay 
10781775b98dSFinlay     Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
10791775b98dSFinlay     LLVM::LLVMFuncOp func =
10801775b98dSFinlay         lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
10811775b98dSFinlay 
10821775b98dSFinlay     Location loc = controlBarrierOp->getLoc();
10831775b98dSFinlay     Value execution = rewriter.create<LLVM::ConstantOp>(
10841775b98dSFinlay         loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
10851775b98dSFinlay     Value memory = rewriter.create<LLVM::ConstantOp>(
10861775b98dSFinlay         loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
10871775b98dSFinlay     Value semantics = rewriter.create<LLVM::ConstantOp>(
10881775b98dSFinlay         loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
10891775b98dSFinlay 
10901775b98dSFinlay     auto call = createSPIRVBuiltinCall(loc, rewriter, func,
10911775b98dSFinlay                                        {execution, memory, semantics});
10921775b98dSFinlay 
10931775b98dSFinlay     rewriter.replaceOp(controlBarrierOp, call);
10941775b98dSFinlay     return success();
10951775b98dSFinlay   }
10961775b98dSFinlay };
10971775b98dSFinlay 
10986ade03d7SLukas Sommer namespace {
10996ade03d7SLukas Sommer 
11006ade03d7SLukas Sommer StringRef getTypeMangling(Type type, bool isSigned) {
11016ade03d7SLukas Sommer   return llvm::TypeSwitch<Type, StringRef>(type)
11026ade03d7SLukas Sommer       .Case<Float16Type>([](auto) { return "Dh"; })
11036ade03d7SLukas Sommer       .Case<Float32Type>([](auto) { return "f"; })
11046ade03d7SLukas Sommer       .Case<Float64Type>([](auto) { return "d"; })
11056ade03d7SLukas Sommer       .Case<IntegerType>([isSigned](IntegerType intTy) {
11066ade03d7SLukas Sommer         switch (intTy.getWidth()) {
11076ade03d7SLukas Sommer         case 1:
11086ade03d7SLukas Sommer           return "b";
11096ade03d7SLukas Sommer         case 8:
11106ade03d7SLukas Sommer           return (isSigned) ? "a" : "c";
11116ade03d7SLukas Sommer         case 16:
11126ade03d7SLukas Sommer           return (isSigned) ? "s" : "t";
11136ade03d7SLukas Sommer         case 32:
11146ade03d7SLukas Sommer           return (isSigned) ? "i" : "j";
11156ade03d7SLukas Sommer         case 64:
11166ade03d7SLukas Sommer           return (isSigned) ? "l" : "m";
11176ade03d7SLukas Sommer         default:
11186ade03d7SLukas Sommer           llvm_unreachable("Unsupported integer width");
11196ade03d7SLukas Sommer         }
11206ade03d7SLukas Sommer       })
11216ade03d7SLukas Sommer       .Default([](auto) {
11226ade03d7SLukas Sommer         llvm_unreachable("No mangling defined");
11236ade03d7SLukas Sommer         return "";
11246ade03d7SLukas Sommer       });
11256ade03d7SLukas Sommer }
11266ade03d7SLukas Sommer 
11276ade03d7SLukas Sommer template <typename ReduceOp>
11286ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName();
11296ade03d7SLukas Sommer 
11306ade03d7SLukas Sommer template <>
11316ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
11326ade03d7SLukas Sommer   return "_Z17__spirv_GroupIAddii";
11336ade03d7SLukas Sommer }
11346ade03d7SLukas Sommer template <>
11356ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
11366ade03d7SLukas Sommer   return "_Z17__spirv_GroupFAddii";
11376ade03d7SLukas Sommer }
11386ade03d7SLukas Sommer template <>
11396ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
11406ade03d7SLukas Sommer   return "_Z17__spirv_GroupSMinii";
11416ade03d7SLukas Sommer }
11426ade03d7SLukas Sommer template <>
11436ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
11446ade03d7SLukas Sommer   return "_Z17__spirv_GroupUMinii";
11456ade03d7SLukas Sommer }
11466ade03d7SLukas Sommer template <>
11476ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
11486ade03d7SLukas Sommer   return "_Z17__spirv_GroupFMinii";
11496ade03d7SLukas Sommer }
11506ade03d7SLukas Sommer template <>
11516ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
11526ade03d7SLukas Sommer   return "_Z17__spirv_GroupSMaxii";
11536ade03d7SLukas Sommer }
11546ade03d7SLukas Sommer template <>
11556ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
11566ade03d7SLukas Sommer   return "_Z17__spirv_GroupUMaxii";
11576ade03d7SLukas Sommer }
11586ade03d7SLukas Sommer template <>
11596ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
11606ade03d7SLukas Sommer   return "_Z17__spirv_GroupFMaxii";
11616ade03d7SLukas Sommer }
11626ade03d7SLukas Sommer template <>
11636ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
11646ade03d7SLukas Sommer   return "_Z27__spirv_GroupNonUniformIAddii";
11656ade03d7SLukas Sommer }
11666ade03d7SLukas Sommer template <>
11676ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
11686ade03d7SLukas Sommer   return "_Z27__spirv_GroupNonUniformFAddii";
11696ade03d7SLukas Sommer }
11706ade03d7SLukas Sommer template <>
11716ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
11726ade03d7SLukas Sommer   return "_Z27__spirv_GroupNonUniformIMulii";
11736ade03d7SLukas Sommer }
11746ade03d7SLukas Sommer template <>
11756ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
11766ade03d7SLukas Sommer   return "_Z27__spirv_GroupNonUniformFMulii";
11776ade03d7SLukas Sommer }
11786ade03d7SLukas Sommer template <>
11796ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
11806ade03d7SLukas Sommer   return "_Z27__spirv_GroupNonUniformSMinii";
11816ade03d7SLukas Sommer }
11826ade03d7SLukas Sommer template <>
11836ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
11846ade03d7SLukas Sommer   return "_Z27__spirv_GroupNonUniformUMinii";
11856ade03d7SLukas Sommer }
11866ade03d7SLukas Sommer template <>
11876ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
11886ade03d7SLukas Sommer   return "_Z27__spirv_GroupNonUniformFMinii";
11896ade03d7SLukas Sommer }
11906ade03d7SLukas Sommer template <>
11916ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
11926ade03d7SLukas Sommer   return "_Z27__spirv_GroupNonUniformSMaxii";
11936ade03d7SLukas Sommer }
11946ade03d7SLukas Sommer template <>
11956ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
11966ade03d7SLukas Sommer   return "_Z27__spirv_GroupNonUniformUMaxii";
11976ade03d7SLukas Sommer }
11986ade03d7SLukas Sommer template <>
11996ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
12006ade03d7SLukas Sommer   return "_Z27__spirv_GroupNonUniformFMaxii";
12016ade03d7SLukas Sommer }
12026ade03d7SLukas Sommer template <>
12036ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
12046ade03d7SLukas Sommer   return "_Z33__spirv_GroupNonUniformBitwiseAndii";
12056ade03d7SLukas Sommer }
12066ade03d7SLukas Sommer template <>
12076ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
12086ade03d7SLukas Sommer   return "_Z32__spirv_GroupNonUniformBitwiseOrii";
12096ade03d7SLukas Sommer }
12106ade03d7SLukas Sommer template <>
12116ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
12126ade03d7SLukas Sommer   return "_Z33__spirv_GroupNonUniformBitwiseXorii";
12136ade03d7SLukas Sommer }
12146ade03d7SLukas Sommer template <>
12156ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
12166ade03d7SLukas Sommer   return "_Z33__spirv_GroupNonUniformLogicalAndii";
12176ade03d7SLukas Sommer }
12186ade03d7SLukas Sommer template <>
12196ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
12206ade03d7SLukas Sommer   return "_Z32__spirv_GroupNonUniformLogicalOrii";
12216ade03d7SLukas Sommer }
12226ade03d7SLukas Sommer template <>
12236ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
12246ade03d7SLukas Sommer   return "_Z33__spirv_GroupNonUniformLogicalXorii";
12256ade03d7SLukas Sommer }
12266ade03d7SLukas Sommer } // namespace
12276ade03d7SLukas Sommer 
12286ade03d7SLukas Sommer template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
12296ade03d7SLukas Sommer class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
12306ade03d7SLukas Sommer public:
12316ade03d7SLukas Sommer   using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
12326ade03d7SLukas Sommer 
12336ade03d7SLukas Sommer   LogicalResult
12346ade03d7SLukas Sommer   matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
12356ade03d7SLukas Sommer                   ConversionPatternRewriter &rewriter) const override {
12366ade03d7SLukas Sommer 
12376ade03d7SLukas Sommer     Type retTy = op.getResult().getType();
12386ade03d7SLukas Sommer     if (!retTy.isIntOrFloat()) {
12396ade03d7SLukas Sommer       return failure();
12406ade03d7SLukas Sommer     }
12416ade03d7SLukas Sommer     SmallString<36> funcName = getGroupFuncName<ReduceOp>();
12426ade03d7SLukas Sommer     funcName += getTypeMangling(retTy, false);
12436ade03d7SLukas Sommer 
12446ade03d7SLukas Sommer     Type i32Ty = rewriter.getI32Type();
12456ade03d7SLukas Sommer     SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
12466ade03d7SLukas Sommer     if constexpr (NonUniform) {
12476ade03d7SLukas Sommer       if (adaptor.getClusterSize()) {
12486ade03d7SLukas Sommer         funcName += "j";
12496ade03d7SLukas Sommer         paramTypes.push_back(i32Ty);
12506ade03d7SLukas Sommer       }
12516ade03d7SLukas Sommer     }
12526ade03d7SLukas Sommer 
12536ade03d7SLukas Sommer     Operation *symbolTable =
12546ade03d7SLukas Sommer         op->template getParentWithTrait<OpTrait::SymbolTable>();
12556ade03d7SLukas Sommer 
1256*4adeb6cfSLukas Sommer     LLVM::LLVMFuncOp func =
1257*4adeb6cfSLukas Sommer         lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy);
12586ade03d7SLukas Sommer 
12596ade03d7SLukas Sommer     Location loc = op.getLoc();
12606ade03d7SLukas Sommer     Value scope = rewriter.create<LLVM::ConstantOp>(
12616ade03d7SLukas Sommer         loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
12626ade03d7SLukas Sommer     Value groupOp = rewriter.create<LLVM::ConstantOp>(
12636ade03d7SLukas Sommer         loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
12646ade03d7SLukas Sommer     SmallVector<Value> operands{scope, groupOp};
12656ade03d7SLukas Sommer     operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
12666ade03d7SLukas Sommer 
12676ade03d7SLukas Sommer     auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands);
12686ade03d7SLukas Sommer     rewriter.replaceOp(op, call);
12696ade03d7SLukas Sommer     return success();
12706ade03d7SLukas Sommer   }
12716ade03d7SLukas Sommer };
12726ade03d7SLukas Sommer 
127305fcdd55SVictor Perez template <>
127405fcdd55SVictor Perez constexpr StringRef
127505fcdd55SVictor Perez ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
127605fcdd55SVictor Perez   return "_Z22__spirv_ControlBarrieriii";
127705fcdd55SVictor Perez }
127805fcdd55SVictor Perez 
127905fcdd55SVictor Perez template <>
128005fcdd55SVictor Perez constexpr StringRef
128105fcdd55SVictor Perez ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
128205fcdd55SVictor Perez   return "_Z33__spirv_ControlBarrierArriveINTELiii";
128305fcdd55SVictor Perez }
128405fcdd55SVictor Perez 
128505fcdd55SVictor Perez template <>
128605fcdd55SVictor Perez constexpr StringRef
128705fcdd55SVictor Perez ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
128805fcdd55SVictor Perez   return "_Z31__spirv_ControlBarrierWaitINTELiii";
128905fcdd55SVictor Perez }
129005fcdd55SVictor Perez 
12915ab6ef75SJakub Kuderski /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
12925ab6ef75SJakub Kuderski /// should be reachable for conversion to succeed. The structure of the loop in
12935ab6ef75SJakub Kuderski /// LLVM dialect will be the following:
1294930c74f1SLei Zhang ///
1295930c74f1SLei Zhang ///      +------------------------------------+
12965ab6ef75SJakub Kuderski ///      | <code before spirv.mlir.loop>        |
1297930c74f1SLei Zhang ///      | llvm.br ^header                    |
1298930c74f1SLei Zhang ///      +------------------------------------+
1299930c74f1SLei Zhang ///                           |
1300930c74f1SLei Zhang ///   +----------------+      |
1301930c74f1SLei Zhang ///   |                |      |
1302930c74f1SLei Zhang ///   |                V      V
1303930c74f1SLei Zhang ///   |  +------------------------------------+
1304930c74f1SLei Zhang ///   |  | ^header:                           |
1305930c74f1SLei Zhang ///   |  |   <header code>                    |
1306930c74f1SLei Zhang ///   |  |   llvm.cond_br %cond, ^body, ^exit |
1307930c74f1SLei Zhang ///   |  +------------------------------------+
1308930c74f1SLei Zhang ///   |                    |
1309930c74f1SLei Zhang ///   |                    |----------------------+
1310930c74f1SLei Zhang ///   |                    |                      |
1311930c74f1SLei Zhang ///   |                    V                      |
1312930c74f1SLei Zhang ///   |  +------------------------------------+   |
1313930c74f1SLei Zhang ///   |  | ^body:                             |   |
1314930c74f1SLei Zhang ///   |  |   <body code>                      |   |
1315930c74f1SLei Zhang ///   |  |   llvm.br ^continue                |   |
1316930c74f1SLei Zhang ///   |  +------------------------------------+   |
1317930c74f1SLei Zhang ///   |                    |                      |
1318930c74f1SLei Zhang ///   |                    V                      |
1319930c74f1SLei Zhang ///   |  +------------------------------------+   |
1320930c74f1SLei Zhang ///   |  | ^continue:                         |   |
1321930c74f1SLei Zhang ///   |  |   <continue code>                  |   |
1322930c74f1SLei Zhang ///   |  |   llvm.br ^header                  |   |
1323930c74f1SLei Zhang ///   |  +------------------------------------+   |
1324930c74f1SLei Zhang ///   |               |                           |
1325930c74f1SLei Zhang ///   +---------------+    +----------------------+
1326930c74f1SLei Zhang ///                        |
1327930c74f1SLei Zhang ///                        V
1328930c74f1SLei Zhang ///      +------------------------------------+
1329930c74f1SLei Zhang ///      | ^exit:                             |
1330930c74f1SLei Zhang ///      |   llvm.br ^remaining               |
1331930c74f1SLei Zhang ///      +------------------------------------+
1332930c74f1SLei Zhang ///                        |
1333930c74f1SLei Zhang ///                        V
1334930c74f1SLei Zhang ///      +------------------------------------+
1335930c74f1SLei Zhang ///      | ^remaining:                        |
13365ab6ef75SJakub Kuderski ///      |   <code after spirv.mlir.loop>       |
1337930c74f1SLei Zhang ///      +------------------------------------+
1338930c74f1SLei Zhang ///
1339930c74f1SLei Zhang class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1340930c74f1SLei Zhang public:
1341930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1342930c74f1SLei Zhang 
1343930c74f1SLei Zhang   LogicalResult
1344b54c724bSRiver Riddle   matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1345930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
1346930c74f1SLei Zhang     // There is no support of loop control at the moment.
134790a1632dSJakub Kuderski     if (loopOp.getLoopControl() != spirv::LoopControl::None)
1348930c74f1SLei Zhang       return failure();
1349930c74f1SLei Zhang 
13505aa741d7SLongsheng Mou     // `spirv.mlir.loop` with empty region is redundant and should be erased.
13515aa741d7SLongsheng Mou     if (loopOp.getBody().empty()) {
13525aa741d7SLongsheng Mou       rewriter.eraseOp(loopOp);
13535aa741d7SLongsheng Mou       return success();
13545aa741d7SLongsheng Mou     }
13555aa741d7SLongsheng Mou 
1356930c74f1SLei Zhang     Location loc = loopOp.getLoc();
1357930c74f1SLei Zhang 
13585ab6ef75SJakub Kuderski     // Split the current block after `spirv.mlir.loop`. The remaining ops will
13595ab6ef75SJakub Kuderski     // be used in `endBlock`.
1360930c74f1SLei Zhang     Block *currentBlock = rewriter.getBlock();
1361930c74f1SLei Zhang     auto position = Block::iterator(loopOp);
1362930c74f1SLei Zhang     Block *endBlock = rewriter.splitBlock(currentBlock, position);
1363930c74f1SLei Zhang 
1364930c74f1SLei Zhang     // Remove entry block and create a branch in the current block going to the
1365930c74f1SLei Zhang     // header block.
1366930c74f1SLei Zhang     Block *entryBlock = loopOp.getEntryBlock();
1367930c74f1SLei Zhang     assert(entryBlock->getOperations().size() == 1);
1368930c74f1SLei Zhang     auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1369930c74f1SLei Zhang     if (!brOp)
1370930c74f1SLei Zhang       return failure();
1371930c74f1SLei Zhang     Block *headerBlock = loopOp.getHeaderBlock();
1372930c74f1SLei Zhang     rewriter.setInsertionPointToEnd(currentBlock);
1373930c74f1SLei Zhang     rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1374930c74f1SLei Zhang     rewriter.eraseBlock(entryBlock);
1375930c74f1SLei Zhang 
1376930c74f1SLei Zhang     // Branch from merge block to end block.
1377930c74f1SLei Zhang     Block *mergeBlock = loopOp.getMergeBlock();
1378930c74f1SLei Zhang     Operation *terminator = mergeBlock->getTerminator();
1379930c74f1SLei Zhang     ValueRange terminatorOperands = terminator->getOperands();
1380930c74f1SLei Zhang     rewriter.setInsertionPointToEnd(mergeBlock);
1381930c74f1SLei Zhang     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1382930c74f1SLei Zhang 
138390a1632dSJakub Kuderski     rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1384930c74f1SLei Zhang     rewriter.replaceOp(loopOp, endBlock->getArguments());
1385930c74f1SLei Zhang     return success();
1386930c74f1SLei Zhang   }
1387930c74f1SLei Zhang };
1388930c74f1SLei Zhang 
13895ab6ef75SJakub Kuderski /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
13903fb384d5SKareemErgawy-TomTom /// block. All blocks within selection should be reachable for conversion to
13913fb384d5SKareemErgawy-TomTom /// succeed.
1392930c74f1SLei Zhang class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1393930c74f1SLei Zhang public:
1394930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1395930c74f1SLei Zhang 
1396930c74f1SLei Zhang   LogicalResult
1397b54c724bSRiver Riddle   matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1398930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
1399930c74f1SLei Zhang     // There is no support for `Flatten` or `DontFlatten` selection control at
1400930c74f1SLei Zhang     // the moment. This are just compiler hints and can be performed during the
1401930c74f1SLei Zhang     // optimization passes.
140290a1632dSJakub Kuderski     if (op.getSelectionControl() != spirv::SelectionControl::None)
1403930c74f1SLei Zhang       return failure();
1404930c74f1SLei Zhang 
14055ab6ef75SJakub Kuderski     // `spirv.mlir.selection` should have at least two blocks: one selection
14063fb384d5SKareemErgawy-TomTom     // header block and one merge block. If no blocks are present, or control
14073fb384d5SKareemErgawy-TomTom     // flow branches straight to merge block (two blocks are present), the op is
1408930c74f1SLei Zhang     // redundant and it is erased.
140990a1632dSJakub Kuderski     if (op.getBody().getBlocks().size() <= 2) {
1410930c74f1SLei Zhang       rewriter.eraseOp(op);
1411930c74f1SLei Zhang       return success();
1412930c74f1SLei Zhang     }
1413930c74f1SLei Zhang 
1414930c74f1SLei Zhang     Location loc = op.getLoc();
1415930c74f1SLei Zhang 
14165ab6ef75SJakub Kuderski     // Split the current block after `spirv.mlir.selection`. The remaining ops
14173fb384d5SKareemErgawy-TomTom     // will be used in `continueBlock`.
1418930c74f1SLei Zhang     auto *currentBlock = rewriter.getInsertionBlock();
1419930c74f1SLei Zhang     rewriter.setInsertionPointAfter(op);
1420930c74f1SLei Zhang     auto position = rewriter.getInsertionPoint();
1421930c74f1SLei Zhang     auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1422930c74f1SLei Zhang 
1423930c74f1SLei Zhang     // Extract conditional branch information from the header block. By SPIR-V
14245ab6ef75SJakub Kuderski     // dialect spec, it should contain `spirv.BranchConditional` or
14255ab6ef75SJakub Kuderski     // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
14265ab6ef75SJakub Kuderski     // moment in the SPIR-V dialect. Remove this block when finished.
1427930c74f1SLei Zhang     auto *headerBlock = op.getHeaderBlock();
1428930c74f1SLei Zhang     assert(headerBlock->getOperations().size() == 1);
1429930c74f1SLei Zhang     auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1430930c74f1SLei Zhang         headerBlock->getOperations().front());
1431930c74f1SLei Zhang     if (!condBrOp)
1432930c74f1SLei Zhang       return failure();
1433930c74f1SLei Zhang     rewriter.eraseBlock(headerBlock);
1434930c74f1SLei Zhang 
1435930c74f1SLei Zhang     // Branch from merge block to continue block.
1436930c74f1SLei Zhang     auto *mergeBlock = op.getMergeBlock();
1437930c74f1SLei Zhang     Operation *terminator = mergeBlock->getTerminator();
1438930c74f1SLei Zhang     ValueRange terminatorOperands = terminator->getOperands();
1439930c74f1SLei Zhang     rewriter.setInsertionPointToEnd(mergeBlock);
1440930c74f1SLei Zhang     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1441930c74f1SLei Zhang 
1442930c74f1SLei Zhang     // Link current block to `true` and `false` blocks within the selection.
1443930c74f1SLei Zhang     Block *trueBlock = condBrOp.getTrueBlock();
1444930c74f1SLei Zhang     Block *falseBlock = condBrOp.getFalseBlock();
1445930c74f1SLei Zhang     rewriter.setInsertionPointToEnd(currentBlock);
144690a1632dSJakub Kuderski     rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1447162f7572SMahesh Ravishankar                                     condBrOp.getTrueTargetOperands(),
1448162f7572SMahesh Ravishankar                                     falseBlock,
144990a1632dSJakub Kuderski                                     condBrOp.getFalseTargetOperands());
1450930c74f1SLei Zhang 
145190a1632dSJakub Kuderski     rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1452930c74f1SLei Zhang     rewriter.replaceOp(op, continueBlock->getArguments());
1453930c74f1SLei Zhang     return success();
1454930c74f1SLei Zhang   }
1455930c74f1SLei Zhang };
1456930c74f1SLei Zhang 
1457930c74f1SLei Zhang /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1458930c74f1SLei Zhang /// puts a restriction on `Shift` and `Base` to have the same bit width,
1459930c74f1SLei Zhang /// `Shift` is zero or sign extended to match this specification. Cases when
1460930c74f1SLei Zhang /// `Shift` bit width > `Base` bit width are considered to be illegal.
1461930c74f1SLei Zhang template <typename SPIRVOp, typename LLVMOp>
1462930c74f1SLei Zhang class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1463930c74f1SLei Zhang public:
1464930c74f1SLei Zhang   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1465930c74f1SLei Zhang 
1466930c74f1SLei Zhang   LogicalResult
1467c19436eeSKohei Yamaguchi   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1468930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
1469930c74f1SLei Zhang 
1470206fad0eSMatthias Springer     auto dstType = this->getTypeConverter()->convertType(op.getType());
1471930c74f1SLei Zhang     if (!dstType)
1472c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
1473930c74f1SLei Zhang 
1474c19436eeSKohei Yamaguchi     Type op1Type = op.getOperand1().getType();
1475c19436eeSKohei Yamaguchi     Type op2Type = op.getOperand2().getType();
1476930c74f1SLei Zhang 
1477930c74f1SLei Zhang     if (op1Type == op2Type) {
1478c19436eeSKohei Yamaguchi       rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1479b54c724bSRiver Riddle                                                    adaptor.getOperands());
1480930c74f1SLei Zhang       return success();
1481930c74f1SLei Zhang     }
1482930c74f1SLei Zhang 
148300b12a94SThéo Degioanni     std::optional<uint64_t> dstTypeWidth =
148400b12a94SThéo Degioanni         getIntegerOrVectorElementWidth(dstType);
148500b12a94SThéo Degioanni     std::optional<uint64_t> op2TypeWidth =
148600b12a94SThéo Degioanni         getIntegerOrVectorElementWidth(op2Type);
148700b12a94SThéo Degioanni 
148800b12a94SThéo Degioanni     if (!dstTypeWidth || !op2TypeWidth)
148900b12a94SThéo Degioanni       return failure();
149000b12a94SThéo Degioanni 
1491c19436eeSKohei Yamaguchi     Location loc = op.getLoc();
1492930c74f1SLei Zhang     Value extended;
149300b12a94SThéo Degioanni     if (op2TypeWidth < dstTypeWidth) {
1494930c74f1SLei Zhang       if (isUnsignedIntegerOrVector(op2Type)) {
149500b12a94SThéo Degioanni         extended = rewriter.template create<LLVM::ZExtOp>(
149600b12a94SThéo Degioanni             loc, dstType, adaptor.getOperand2());
1497930c74f1SLei Zhang       } else {
149800b12a94SThéo Degioanni         extended = rewriter.template create<LLVM::SExtOp>(
149900b12a94SThéo Degioanni             loc, dstType, adaptor.getOperand2());
1500930c74f1SLei Zhang       }
150100b12a94SThéo Degioanni     } else if (op2TypeWidth == dstTypeWidth) {
150200b12a94SThéo Degioanni       extended = adaptor.getOperand2();
150300b12a94SThéo Degioanni     } else {
150400b12a94SThéo Degioanni       return failure();
150500b12a94SThéo Degioanni     }
150600b12a94SThéo Degioanni 
1507930c74f1SLei Zhang     Value result = rewriter.template create<LLVMOp>(
150890a1632dSJakub Kuderski         loc, dstType, adaptor.getOperand1(), extended);
1509c19436eeSKohei Yamaguchi     rewriter.replaceOp(op, result);
1510930c74f1SLei Zhang     return success();
1511930c74f1SLei Zhang   }
1512930c74f1SLei Zhang };
1513930c74f1SLei Zhang 
151452b630daSJakub Kuderski class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1515930c74f1SLei Zhang public:
151652b630daSJakub Kuderski   using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion;
1517930c74f1SLei Zhang 
1518930c74f1SLei Zhang   LogicalResult
151952b630daSJakub Kuderski   matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1520930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
1521206fad0eSMatthias Springer     auto dstType = getTypeConverter()->convertType(tanOp.getType());
1522930c74f1SLei Zhang     if (!dstType)
1523c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1524930c74f1SLei Zhang 
1525930c74f1SLei Zhang     Location loc = tanOp.getLoc();
152690a1632dSJakub Kuderski     Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
152790a1632dSJakub Kuderski     Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1528930c74f1SLei Zhang     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1529930c74f1SLei Zhang     return success();
1530930c74f1SLei Zhang   }
1531930c74f1SLei Zhang };
1532930c74f1SLei Zhang 
15335ab6ef75SJakub Kuderski /// Convert `spirv.Tanh` to
1534930c74f1SLei Zhang ///
1535930c74f1SLei Zhang ///   exp(2x) - 1
1536930c74f1SLei Zhang ///   -----------
1537930c74f1SLei Zhang ///   exp(2x) + 1
1538930c74f1SLei Zhang ///
153952b630daSJakub Kuderski class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1540930c74f1SLei Zhang public:
154152b630daSJakub Kuderski   using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;
1542930c74f1SLei Zhang 
1543930c74f1SLei Zhang   LogicalResult
154452b630daSJakub Kuderski   matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1545930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
1546930c74f1SLei Zhang     auto srcType = tanhOp.getType();
1547206fad0eSMatthias Springer     auto dstType = getTypeConverter()->convertType(srcType);
1548930c74f1SLei Zhang     if (!dstType)
1549c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1550930c74f1SLei Zhang 
1551930c74f1SLei Zhang     Location loc = tanhOp.getLoc();
1552930c74f1SLei Zhang     Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1553930c74f1SLei Zhang     Value multiplied =
155490a1632dSJakub Kuderski         rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1555930c74f1SLei Zhang     Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1556930c74f1SLei Zhang     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1557930c74f1SLei Zhang     Value numerator =
1558930c74f1SLei Zhang         rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1559930c74f1SLei Zhang     Value denominator =
1560930c74f1SLei Zhang         rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1561930c74f1SLei Zhang     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1562930c74f1SLei Zhang                                               denominator);
1563930c74f1SLei Zhang     return success();
1564930c74f1SLei Zhang   }
1565930c74f1SLei Zhang };
1566930c74f1SLei Zhang 
1567930c74f1SLei Zhang class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1568930c74f1SLei Zhang public:
1569930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1570930c74f1SLei Zhang 
1571930c74f1SLei Zhang   LogicalResult
1572b54c724bSRiver Riddle   matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1573930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
1574930c74f1SLei Zhang     auto srcType = varOp.getType();
1575930c74f1SLei Zhang     // Initialization is supported for scalars and vectors only.
15765550c821STres Popp     auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
157790a1632dSJakub Kuderski     auto init = varOp.getInitializer();
15785550c821STres Popp     if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1579930c74f1SLei Zhang       return failure();
1580930c74f1SLei Zhang 
1581206fad0eSMatthias Springer     auto dstType = getTypeConverter()->convertType(srcType);
1582930c74f1SLei Zhang     if (!dstType)
1583c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1584930c74f1SLei Zhang 
1585930c74f1SLei Zhang     Location loc = varOp.getLoc();
1586930c74f1SLei Zhang     Value size = createI32ConstantOf(loc, rewriter, 1);
1587930c74f1SLei Zhang     if (!init) {
1588206fad0eSMatthias Springer       auto elementType = getTypeConverter()->convertType(pointerTo);
1589c19436eeSKohei Yamaguchi       if (!elementType)
1590c19436eeSKohei Yamaguchi         return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1591c19436eeSKohei Yamaguchi       rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1592c19436eeSKohei Yamaguchi                                                   size);
1593930c74f1SLei Zhang       return success();
1594930c74f1SLei Zhang     }
1595206fad0eSMatthias Springer     auto elementType = getTypeConverter()->convertType(pointerTo);
1596c19436eeSKohei Yamaguchi     if (!elementType)
1597c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1598c19436eeSKohei Yamaguchi     Value allocated =
1599c19436eeSKohei Yamaguchi         rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
160090a1632dSJakub Kuderski     rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1601930c74f1SLei Zhang     rewriter.replaceOp(varOp, allocated);
1602930c74f1SLei Zhang     return success();
1603930c74f1SLei Zhang   }
1604930c74f1SLei Zhang };
1605930c74f1SLei Zhang 
1606930c74f1SLei Zhang //===----------------------------------------------------------------------===//
160719b1e27fSMarkus Böck // BitcastOp conversion
160819b1e27fSMarkus Böck //===----------------------------------------------------------------------===//
160919b1e27fSMarkus Böck 
161019b1e27fSMarkus Böck class BitcastConversionPattern
161119b1e27fSMarkus Böck     : public SPIRVToLLVMConversion<spirv::BitcastOp> {
161219b1e27fSMarkus Böck public:
161319b1e27fSMarkus Böck   using SPIRVToLLVMConversion<spirv::BitcastOp>::SPIRVToLLVMConversion;
161419b1e27fSMarkus Böck 
161519b1e27fSMarkus Böck   LogicalResult
161619b1e27fSMarkus Böck   matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
161719b1e27fSMarkus Böck                   ConversionPatternRewriter &rewriter) const override {
1618206fad0eSMatthias Springer     auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
161919b1e27fSMarkus Böck     if (!dstType)
1620c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
162119b1e27fSMarkus Böck 
162241f3b83fSChristian Ulmann     // LLVM's opaque pointers do not require bitcasts.
162341f3b83fSChristian Ulmann     if (isa<LLVM::LLVMPointerType>(dstType)) {
162419b1e27fSMarkus Böck       rewriter.replaceOp(bitcastOp, adaptor.getOperand());
162519b1e27fSMarkus Böck       return success();
162619b1e27fSMarkus Böck     }
162719b1e27fSMarkus Böck 
162819b1e27fSMarkus Böck     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
162919b1e27fSMarkus Böck         bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
163019b1e27fSMarkus Böck     return success();
163119b1e27fSMarkus Böck   }
163219b1e27fSMarkus Böck };
163319b1e27fSMarkus Böck 
163419b1e27fSMarkus Böck //===----------------------------------------------------------------------===//
1635930c74f1SLei Zhang // FuncOp conversion
1636930c74f1SLei Zhang //===----------------------------------------------------------------------===//
1637930c74f1SLei Zhang 
1638930c74f1SLei Zhang class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1639930c74f1SLei Zhang public:
1640930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1641930c74f1SLei Zhang 
1642930c74f1SLei Zhang   LogicalResult
1643b54c724bSRiver Riddle   matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1644930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
1645930c74f1SLei Zhang 
1646930c74f1SLei Zhang     // Convert function signature. At the moment LLVMType converter is enough
1647930c74f1SLei Zhang     // for currently supported types.
16484a3460a7SRiver Riddle     auto funcType = funcOp.getFunctionType();
1649930c74f1SLei Zhang     TypeConverter::SignatureConversion signatureConverter(
1650930c74f1SLei Zhang         funcType.getNumInputs());
165198723e65SMatthias Springer     auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
165298723e65SMatthias Springer                         ->convertFunctionSignature(
165398723e65SMatthias Springer                             funcType, /*isVariadic=*/false,
165498723e65SMatthias Springer                             /*useBarePtrCallConv=*/false, signatureConverter);
1655930c74f1SLei Zhang     if (!llvmType)
1656930c74f1SLei Zhang       return failure();
1657930c74f1SLei Zhang 
1658930c74f1SLei Zhang     // Create a new `LLVMFuncOp`
1659930c74f1SLei Zhang     Location loc = funcOp.getLoc();
1660930c74f1SLei Zhang     StringRef name = funcOp.getName();
1661930c74f1SLei Zhang     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1662930c74f1SLei Zhang 
1663930c74f1SLei Zhang     // Convert SPIR-V Function Control to equivalent LLVM function attribute
1664930c74f1SLei Zhang     MLIRContext *context = funcOp.getContext();
166590a1632dSJakub Kuderski     switch (funcOp.getFunctionControl()) {
1666c012e487SJohannes de Fine Licht     case spirv::FunctionControl::Inline:
1667c012e487SJohannes de Fine Licht       newFuncOp.setAlwaysInline(true);
1668c012e487SJohannes de Fine Licht       break;
1669c012e487SJohannes de Fine Licht     case spirv::FunctionControl::DontInline:
1670c012e487SJohannes de Fine Licht       newFuncOp.setNoInline(true);
1671c012e487SJohannes de Fine Licht       break;
1672c012e487SJohannes de Fine Licht 
1673930c74f1SLei Zhang #define DISPATCH(functionControl, llvmAttr)                                    \
1674930c74f1SLei Zhang   case functionControl:                                                        \
1675c2c83e97STres Popp     newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr}));    \
1676930c74f1SLei Zhang     break;
1677930c74f1SLei Zhang 
1678930c74f1SLei Zhang       DISPATCH(spirv::FunctionControl::Pure,
1679c2c83e97STres Popp                StringAttr::get(context, "readonly"));
1680930c74f1SLei Zhang       DISPATCH(spirv::FunctionControl::Const,
1681c2c83e97STres Popp                StringAttr::get(context, "readnone"));
1682930c74f1SLei Zhang 
1683930c74f1SLei Zhang #undef DISPATCH
1684930c74f1SLei Zhang 
1685930c74f1SLei Zhang     // Default: if `spirv::FunctionControl::None`, then no attributes are
1686930c74f1SLei Zhang     // needed.
1687930c74f1SLei Zhang     default:
1688930c74f1SLei Zhang       break;
1689930c74f1SLei Zhang     }
1690930c74f1SLei Zhang 
1691930c74f1SLei Zhang     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1692930c74f1SLei Zhang                                 newFuncOp.end());
1693206fad0eSMatthias Springer     if (failed(rewriter.convertRegionTypes(
1694206fad0eSMatthias Springer             &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1695930c74f1SLei Zhang       return failure();
1696930c74f1SLei Zhang     }
1697930c74f1SLei Zhang     rewriter.eraseOp(funcOp);
1698930c74f1SLei Zhang     return success();
1699930c74f1SLei Zhang   }
1700930c74f1SLei Zhang };
1701930c74f1SLei Zhang 
1702930c74f1SLei Zhang //===----------------------------------------------------------------------===//
1703930c74f1SLei Zhang // ModuleOp conversion
1704930c74f1SLei Zhang //===----------------------------------------------------------------------===//
1705930c74f1SLei Zhang 
1706930c74f1SLei Zhang class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1707930c74f1SLei Zhang public:
1708930c74f1SLei Zhang   using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1709930c74f1SLei Zhang 
1710930c74f1SLei Zhang   LogicalResult
1711b54c724bSRiver Riddle   matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1712930c74f1SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
1713930c74f1SLei Zhang 
1714930c74f1SLei Zhang     auto newModuleOp =
1715930c74f1SLei Zhang         rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
171656f60a1cSLei Zhang     rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1717930c74f1SLei Zhang 
1718930c74f1SLei Zhang     // Remove the terminator block that was automatically added by builder
1719930c74f1SLei Zhang     rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1720930c74f1SLei Zhang     rewriter.eraseOp(spvModuleOp);
1721930c74f1SLei Zhang     return success();
1722930c74f1SLei Zhang   }
1723930c74f1SLei Zhang };
1724930c74f1SLei Zhang 
17253483fc5aSWeiwei Li //===----------------------------------------------------------------------===//
17263483fc5aSWeiwei Li // VectorShuffleOp conversion
17273483fc5aSWeiwei Li //===----------------------------------------------------------------------===//
17283483fc5aSWeiwei Li 
17293483fc5aSWeiwei Li class VectorShufflePattern
17303483fc5aSWeiwei Li     : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
17313483fc5aSWeiwei Li public:
17323483fc5aSWeiwei Li   using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
17333483fc5aSWeiwei Li   LogicalResult
17343483fc5aSWeiwei Li   matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
17353483fc5aSWeiwei Li                   ConversionPatternRewriter &rewriter) const override {
17363483fc5aSWeiwei Li     Location loc = op.getLoc();
173790a1632dSJakub Kuderski     auto components = adaptor.getComponents();
173890a1632dSJakub Kuderski     auto vector1 = adaptor.getVector1();
173990a1632dSJakub Kuderski     auto vector2 = adaptor.getVector2();
17405550c821STres Popp     int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
17415550c821STres Popp     int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
17423483fc5aSWeiwei Li     if (vector1Size == vector2Size) {
1743b2ccfb4dSJeff Niu       rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1744b2ccfb4dSJeff Niu           op, vector1, vector2,
1745b2ccfb4dSJeff Niu           LLVM::convertArrayToIndices<int32_t>(components));
17463483fc5aSWeiwei Li       return success();
17473483fc5aSWeiwei Li     }
17483483fc5aSWeiwei Li 
1749206fad0eSMatthias Springer     auto dstType = getTypeConverter()->convertType(op.getType());
1750c19436eeSKohei Yamaguchi     if (!dstType)
1751c19436eeSKohei Yamaguchi       return rewriter.notifyMatchFailure(op, "type conversion failed");
17525550c821STres Popp     auto scalarType = cast<VectorType>(dstType).getElementType();
17533483fc5aSWeiwei Li     auto componentsArray = components.getValue();
175402b6fb21SMehdi Amini     auto *context = rewriter.getContext();
17553483fc5aSWeiwei Li     auto llvmI32Type = IntegerType::get(context, 32);
17563483fc5aSWeiwei Li     Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
17573483fc5aSWeiwei Li     for (unsigned i = 0; i < componentsArray.size(); i++) {
17585550c821STres Popp       if (!isa<IntegerAttr>(componentsArray[i]))
1759a29fffc4SLei Zhang         return op.emitError("unable to support non-constant component");
17603483fc5aSWeiwei Li 
17615550c821STres Popp       int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
17623483fc5aSWeiwei Li       if (indexVal == -1)
17633483fc5aSWeiwei Li         continue;
17643483fc5aSWeiwei Li 
17653483fc5aSWeiwei Li       int offsetVal = 0;
17663483fc5aSWeiwei Li       Value baseVector = vector1;
17673483fc5aSWeiwei Li       if (indexVal >= vector1Size) {
17683483fc5aSWeiwei Li         offsetVal = vector1Size;
17693483fc5aSWeiwei Li         baseVector = vector2;
17703483fc5aSWeiwei Li       }
17713483fc5aSWeiwei Li 
17723483fc5aSWeiwei Li       Value dstIndex = rewriter.create<LLVM::ConstantOp>(
17733483fc5aSWeiwei Li           loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
17743483fc5aSWeiwei Li       Value index = rewriter.create<LLVM::ConstantOp>(
17753483fc5aSWeiwei Li           loc, llvmI32Type,
17763483fc5aSWeiwei Li           rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
17773483fc5aSWeiwei Li 
17783483fc5aSWeiwei Li       auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
17793483fc5aSWeiwei Li           loc, scalarType, baseVector, index);
17803483fc5aSWeiwei Li       targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
17813483fc5aSWeiwei Li                                                         extractOp, dstIndex);
17823483fc5aSWeiwei Li     }
17833483fc5aSWeiwei Li     rewriter.replaceOp(op, targetOp);
17843483fc5aSWeiwei Li     return success();
17853483fc5aSWeiwei Li   }
17863483fc5aSWeiwei Li };
1787930c74f1SLei Zhang } // namespace
1788930c74f1SLei Zhang 
1789930c74f1SLei Zhang //===----------------------------------------------------------------------===//
1790930c74f1SLei Zhang // Pattern population
1791930c74f1SLei Zhang //===----------------------------------------------------------------------===//
1792930c74f1SLei Zhang 
179312ce9fd1SVictor Perez void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter,
179412ce9fd1SVictor Perez                                              spirv::ClientAPI clientAPI) {
1795930c74f1SLei Zhang   typeConverter.addConversion([&](spirv::ArrayType type) {
1796930c74f1SLei Zhang     return convertArrayType(type, typeConverter);
1797930c74f1SLei Zhang   });
179812ce9fd1SVictor Perez   typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
179912ce9fd1SVictor Perez     return convertPointerType(type, typeConverter, clientAPI);
1800930c74f1SLei Zhang   });
1801930c74f1SLei Zhang   typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1802930c74f1SLei Zhang     return convertRuntimeArrayType(type, typeConverter);
1803930c74f1SLei Zhang   });
1804930c74f1SLei Zhang   typeConverter.addConversion([&](spirv::StructType type) {
1805930c74f1SLei Zhang     return convertStructType(type, typeConverter);
1806930c74f1SLei Zhang   });
1807930c74f1SLei Zhang }
1808930c74f1SLei Zhang 
1809930c74f1SLei Zhang void mlir::populateSPIRVToLLVMConversionPatterns(
1810206fad0eSMatthias Springer     const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
181112ce9fd1SVictor Perez     spirv::ClientAPI clientAPI) {
1812dc4e913bSChris Lattner   patterns.add<
1813930c74f1SLei Zhang       // Arithmetic ops
1814930c74f1SLei Zhang       DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1815930c74f1SLei Zhang       DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1816930c74f1SLei Zhang       DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1817930c74f1SLei Zhang       DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1818930c74f1SLei Zhang       DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1819930c74f1SLei Zhang       DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1820930c74f1SLei Zhang       DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1821930c74f1SLei Zhang       DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1822930c74f1SLei Zhang       DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1823930c74f1SLei Zhang       DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1824930c74f1SLei Zhang       DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1825930c74f1SLei Zhang       DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1826930c74f1SLei Zhang       DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1827930c74f1SLei Zhang 
1828930c74f1SLei Zhang       // Bitwise ops
1829930c74f1SLei Zhang       BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1830930c74f1SLei Zhang       DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1831930c74f1SLei Zhang       DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1832930c74f1SLei Zhang       DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1833930c74f1SLei Zhang       DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1834930c74f1SLei Zhang       DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1835930c74f1SLei Zhang       NotPattern<spirv::NotOp>,
1836930c74f1SLei Zhang 
1837930c74f1SLei Zhang       // Cast ops
183819b1e27fSMarkus Böck       BitcastConversionPattern,
1839930c74f1SLei Zhang       DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1840930c74f1SLei Zhang       DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1841930c74f1SLei Zhang       DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1842930c74f1SLei Zhang       DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1843930c74f1SLei Zhang       IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1844930c74f1SLei Zhang       IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1845930c74f1SLei Zhang       IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1846930c74f1SLei Zhang 
1847930c74f1SLei Zhang       // Comparison ops
1848930c74f1SLei Zhang       IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1849930c74f1SLei Zhang       IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1850930c74f1SLei Zhang       FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1851930c74f1SLei Zhang       FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1852930c74f1SLei Zhang       FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1853930c74f1SLei Zhang       FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1854930c74f1SLei Zhang       FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1855930c74f1SLei Zhang       FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1856930c74f1SLei Zhang       FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1857930c74f1SLei Zhang       FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1858930c74f1SLei Zhang       FComparePattern<spirv::FUnordGreaterThanEqualOp,
1859930c74f1SLei Zhang                       LLVM::FCmpPredicate::uge>,
1860930c74f1SLei Zhang       FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1861930c74f1SLei Zhang       FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1862930c74f1SLei Zhang       FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1863930c74f1SLei Zhang       IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1864930c74f1SLei Zhang       IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1865930c74f1SLei Zhang       IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1866930c74f1SLei Zhang       IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1867930c74f1SLei Zhang       IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1868930c74f1SLei Zhang       IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1869930c74f1SLei Zhang       IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1870930c74f1SLei Zhang       IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1871930c74f1SLei Zhang 
1872930c74f1SLei Zhang       // Constant op
1873930c74f1SLei Zhang       ConstantScalarAndVectorPattern,
1874930c74f1SLei Zhang 
1875930c74f1SLei Zhang       // Control Flow ops
1876930c74f1SLei Zhang       BranchConversionPattern, BranchConditionalConversionPattern,
1877930c74f1SLei Zhang       FunctionCallPattern, LoopPattern, SelectionPattern,
1878930c74f1SLei Zhang       ErasePattern<spirv::MergeOp>,
1879930c74f1SLei Zhang 
1880930c74f1SLei Zhang       // Entry points and execution mode are handled separately.
1881930c74f1SLei Zhang       ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1882930c74f1SLei Zhang 
1883930c74f1SLei Zhang       // GLSL extended instruction set ops
188452b630daSJakub Kuderski       DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
188552b630daSJakub Kuderski       DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
188652b630daSJakub Kuderski       DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
188752b630daSJakub Kuderski       DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
188852b630daSJakub Kuderski       DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
188952b630daSJakub Kuderski       DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
189052b630daSJakub Kuderski       DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
189152b630daSJakub Kuderski       DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
189252b630daSJakub Kuderski       DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
189352b630daSJakub Kuderski       DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
189452b630daSJakub Kuderski       DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
189552b630daSJakub Kuderski       DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1896930c74f1SLei Zhang       InverseSqrtPattern, TanPattern, TanhPattern,
1897930c74f1SLei Zhang 
1898930c74f1SLei Zhang       // Logical ops
1899930c74f1SLei Zhang       DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1900930c74f1SLei Zhang       DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1901930c74f1SLei Zhang       IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1902930c74f1SLei Zhang       IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1903930c74f1SLei Zhang       NotPattern<spirv::LogicalNotOp>,
1904930c74f1SLei Zhang 
1905930c74f1SLei Zhang       // Memory ops
190612ce9fd1SVictor Perez       AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
190712ce9fd1SVictor Perez       LoadStorePattern<spirv::StoreOp>, VariablePattern,
1908930c74f1SLei Zhang 
1909930c74f1SLei Zhang       // Miscellaneous ops
1910930c74f1SLei Zhang       CompositeExtractPattern, CompositeInsertPattern,
1911930c74f1SLei Zhang       DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1912930c74f1SLei Zhang       DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
19133483fc5aSWeiwei Li       VectorShufflePattern,
1914930c74f1SLei Zhang 
1915930c74f1SLei Zhang       // Shift ops
1916930c74f1SLei Zhang       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1917930c74f1SLei Zhang       ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1918930c74f1SLei Zhang       ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1919930c74f1SLei Zhang 
1920930c74f1SLei Zhang       // Return ops
19211775b98dSFinlay       ReturnPattern, ReturnValuePattern,
19221775b98dSFinlay 
19231775b98dSFinlay       // Barrier ops
192405fcdd55SVictor Perez       ControlBarrierPattern<spirv::ControlBarrierOp>,
192505fcdd55SVictor Perez       ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
192605fcdd55SVictor Perez       ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
19276ade03d7SLukas Sommer 
19286ade03d7SLukas Sommer       // Group reduction operations
19296ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupIAddOp>,
19306ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupFAddOp>,
19316ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupFMinOp>,
19326ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupUMinOp>,
19336ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>,
19346ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupFMaxOp>,
19356ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupUMaxOp>,
19366ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>,
19376ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false,
19386ade03d7SLukas Sommer                          /*NonUniform=*/true>,
19396ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false,
19406ade03d7SLukas Sommer                          /*NonUniform=*/true>,
19416ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false,
19426ade03d7SLukas Sommer                          /*NonUniform=*/true>,
19436ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false,
19446ade03d7SLukas Sommer                          /*NonUniform=*/true>,
19456ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true,
19466ade03d7SLukas Sommer                          /*NonUniform=*/true>,
19476ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false,
19486ade03d7SLukas Sommer                          /*NonUniform=*/true>,
19496ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false,
19506ade03d7SLukas Sommer                          /*NonUniform=*/true>,
19516ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true,
19526ade03d7SLukas Sommer                          /*NonUniform=*/true>,
19536ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false,
19546ade03d7SLukas Sommer                          /*NonUniform=*/true>,
19556ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false,
19566ade03d7SLukas Sommer                          /*NonUniform=*/true>,
19576ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false,
19586ade03d7SLukas Sommer                          /*NonUniform=*/true>,
19596ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false,
19606ade03d7SLukas Sommer                          /*NonUniform=*/true>,
19616ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false,
19626ade03d7SLukas Sommer                          /*NonUniform=*/true>,
19636ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false,
19646ade03d7SLukas Sommer                          /*NonUniform=*/true>,
19656ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false,
19666ade03d7SLukas Sommer                          /*NonUniform=*/true>,
19676ade03d7SLukas Sommer       GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false,
19686ade03d7SLukas Sommer                          /*NonUniform=*/true>>(patterns.getContext(),
19696ade03d7SLukas Sommer                                                typeConverter);
197012ce9fd1SVictor Perez 
197112ce9fd1SVictor Perez   patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
197212ce9fd1SVictor Perez                                       typeConverter);
1973930c74f1SLei Zhang }
1974930c74f1SLei Zhang 
1975930c74f1SLei Zhang void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1976206fad0eSMatthias Springer     const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1977dc4e913bSChris Lattner   patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1978930c74f1SLei Zhang }
1979930c74f1SLei Zhang 
1980930c74f1SLei Zhang void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1981206fad0eSMatthias Springer     const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
198256f60a1cSLei Zhang   patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1983930c74f1SLei Zhang }
1984930c74f1SLei Zhang 
1985930c74f1SLei Zhang //===----------------------------------------------------------------------===//
1986930c74f1SLei Zhang // Pre-conversion hooks
1987930c74f1SLei Zhang //===----------------------------------------------------------------------===//
1988930c74f1SLei Zhang 
1989930c74f1SLei Zhang /// Hook for descriptor set and binding number encoding.
1990930c74f1SLei Zhang static constexpr StringRef kBinding = "binding";
1991930c74f1SLei Zhang static constexpr StringRef kDescriptorSet = "descriptor_set";
1992930c74f1SLei Zhang void mlir::encodeBindAttribute(ModuleOp module) {
1993930c74f1SLei Zhang   auto spvModules = module.getOps<spirv::ModuleOp>();
1994930c74f1SLei Zhang   for (auto spvModule : spvModules) {
1995930c74f1SLei Zhang     spvModule.walk([&](spirv::GlobalVariableOp op) {
1996930c74f1SLei Zhang       IntegerAttr descriptorSet =
1997930c74f1SLei Zhang           op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1998930c74f1SLei Zhang       IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1999930c74f1SLei Zhang       // For every global variable in the module, get the ones with descriptor
2000930c74f1SLei Zhang       // set and binding numbers.
2001930c74f1SLei Zhang       if (descriptorSet && binding) {
2002930c74f1SLei Zhang         // Encode these numbers into the variable's symbolic name. If the
2003930c74f1SLei Zhang         // SPIR-V module has a name, add it at the beginning.
2004c27d8152SKazu Hirata         auto moduleAndName =
2005c27d8152SKazu Hirata             spvModule.getName().has_value()
200690a1632dSJakub Kuderski                 ? spvModule.getName()->str() + "_" + op.getSymName().str()
200790a1632dSJakub Kuderski                 : op.getSymName().str();
2008930c74f1SLei Zhang         std::string name =
2009930c74f1SLei Zhang             llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
2010930c74f1SLei Zhang                           std::to_string(descriptorSet.getInt()),
2011930c74f1SLei Zhang                           std::to_string(binding.getInt()));
201241d4aa7dSChris Lattner         auto nameAttr = StringAttr::get(op->getContext(), name);
2013930c74f1SLei Zhang 
2014930c74f1SLei Zhang         // Replace all symbol uses and set the new symbol name. Finally, remove
2015930c74f1SLei Zhang         // descriptor set and binding attributes.
201641d4aa7dSChris Lattner         if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
2017930c74f1SLei Zhang           op.emitError("unable to replace all symbol uses for ") << name;
201841d4aa7dSChris Lattner         SymbolTable::setSymbolName(op, nameAttr);
2019dffc487bSChristian Sigg         op->removeAttr(kDescriptorSet);
2020dffc487bSChristian Sigg         op->removeAttr(kBinding);
2021930c74f1SLei Zhang       }
2022930c74f1SLei Zhang     });
2023930c74f1SLei Zhang   }
2024930c74f1SLei Zhang }
2025