xref: /llvm-project/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1 //===- SPIRVConversion.h - SPIR-V Conversion Utilities ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines utilities to use while converting to the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H
14 #define MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H
15 
16 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
20 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "mlir/Transforms/OneToNTypeConversion.h"
24 #include "llvm/ADT/SmallSet.h"
25 #include "llvm/Support/LogicalResult.h"
26 
27 namespace mlir {
28 
29 //===----------------------------------------------------------------------===//
30 // Type Converter
31 //===----------------------------------------------------------------------===//
32 
/// Describes the in-memory layout used for values narrower than one byte.
enum class SPIRVSubByteTypeStorage {
  /// Values are placed back to back with no padding in between, e.g.,
  /// 4xi2 -> i8.
  Packed
};

/// Options controlling the SPIR-V type converter.
struct SPIRVConversionOptions {
  /// Bit count used to store a boolean value.
  unsigned boolNumBits = 8;

  /// Layout used when sub-byte values are stored in memory.
  SPIRVSubByteTypeStorage subByteTypeStorage = SPIRVSubByteTypeStorage::Packed;

  /// Whether scalar types narrower than 32 bits are emulated with 32-bit
  /// scalar types when the target does not support them.
  ///
  /// Scalar types that are not 32 bits wide need special hardware support,
  /// which not every GPU provides; SPIR-V models this by requiring special
  /// capabilities or extensions for such types. With this option enabled, a
  /// scalar narrower than 32 bits whose bitwidth is unsupported in the target
  /// environment is emulated using 32-bit types. The runtime must then feed in
  /// data with a matching bitwidth and layout for interface types, which it
  /// can determine by inspecting the SPIR-V module.
  ///
  /// For an original scalar type narrower than 32 bits, multiple values are
  /// packed into a single 32-bit value to stay memory efficient.
  bool emulateLT32BitScalarTypes = true;

  /// Use 64-bit integers when converting index types.
  bool use64bitIndex = false;
};
65 
66 /// Type conversion from builtin types to SPIR-V types for shader interface.
67 ///
68 /// For memref types, this converter additionally performs type wrapping to
69 /// satisfy shader interface requirements: shader interface types must be
70 /// pointers to structs.
71 class SPIRVTypeConverter : public TypeConverter {
72 public:
73   explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
74                               const SPIRVConversionOptions &options = {});
75 
76   /// Gets the SPIR-V correspondence for the standard index type.
77   Type getIndexType() const;
78 
79   /// Gets the bitwidth of the index type when converted to SPIR-V.
80   unsigned getIndexTypeBitwidth() const {
81     return options.use64bitIndex ? 64 : 32;
82   }
83 
84   const spirv::TargetEnv &getTargetEnv() const { return targetEnv; }
85 
86   /// Returns the options controlling the SPIR-V type converter.
87   const SPIRVConversionOptions &getOptions() const { return options; }
88 
89   /// Checks if the SPIR-V capability inquired is supported.
90   bool allows(spirv::Capability capability) const;
91 
92 private:
93   spirv::TargetEnv targetEnv;
94   SPIRVConversionOptions options;
95 
96   MLIRContext *getContext() const;
97 };
98 
99 //===----------------------------------------------------------------------===//
100 // Conversion Target
101 //===----------------------------------------------------------------------===//
102 
103 /// The default SPIR-V conversion target.
104 ///
105 /// It takes a SPIR-V target environment and controls operation legality based
106 /// on the operations' availability in that target environment.
107 class SPIRVConversionTarget : public ConversionTarget {
108 public:
109   /// Creates a SPIR-V conversion target for the given target environment.
110   static std::unique_ptr<SPIRVConversionTarget>
111   get(spirv::TargetEnvAttr targetAttr);
112 
113 private:
114   explicit SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr);
115 
116   // Be explicit that instances of this class cannot be copied or moved: there
117   // are lambdas capturing fields of the instance.
118   SPIRVConversionTarget(const SPIRVConversionTarget &) = delete;
119   SPIRVConversionTarget(SPIRVConversionTarget &&) = delete;
120   SPIRVConversionTarget &operator=(const SPIRVConversionTarget &) = delete;
121   SPIRVConversionTarget &operator=(SPIRVConversionTarget &&) = delete;
122 
123   /// Returns true if the given `op` is legal to use under the current target
124   /// environment.
125   bool isLegalOp(Operation *op);
126 
127   spirv::TargetEnv targetEnv;
128 };
129 
130 //===----------------------------------------------------------------------===//
131 // Patterns and Utility Functions
132 //===----------------------------------------------------------------------===//
133 
134 /// Appends to `patterns` additional patterns for translating the builtin
135 /// `func` op to the SPIR-V dialect. These patterns do not handle the shader
136 /// interface/ABI; they only convert function parameters to be of SPIR-V
137 /// allowed types.
138 void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
139                                         RewritePatternSet &patterns);
140 
/// NOTE(review): undocumented in this header. Presumably appends to `patterns`
/// the rewrites dealing with vector types on function ops -- confirm against
/// the implementation.
141 void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);
142 
/// NOTE(review): undocumented in this header. Presumably appends to `patterns`
/// the rewrites dealing with vector types on return ops -- confirm against the
/// implementation.
143 void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns);
144 
145 namespace spirv {
146 class AccessChainOp;
147 
148 /// Returns the value for the given `builtin` variable. This function gets or
149 /// inserts the global variable associated with the builtin within the nearest
150 /// symbol table enclosing `op`. Returns a null Value on error.
151 ///
152 /// The global name being generated will be mangled using `prefix` and
153 /// `suffix`.
154 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType,
155                               OpBuilder &builder,
156                               StringRef prefix = "__builtin__",
157                               StringRef suffix = "__");
158 
159 /// Gets the value at the given `offset` of the push constant storage with a
160 /// total of `elementCount` `integerType` integers. A global variable will be
161 /// created in the nearest symbol table enclosing `op` for the push constant
162 /// storage if not existing. Load ops will be created via the given `builder` to
163 /// load values from the push constant. Returns a null Value on error.
164 Value getPushConstantValue(Operation *op, unsigned elementCount,
165                            unsigned offset, Type integerType,
166                            OpBuilder &builder);
167 
168 /// Generates IR to perform index linearization with the given `indices` and
169 /// their corresponding `strides`, adding an initial `offset`.
170 Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
171                      int64_t offset, Type integerType, Location loc,
172                      OpBuilder &builder);
173 
174 /// Performs the index computation to get to the element at `indices` of the
175 /// memory pointed to by `basePtr`, using the layout map of `baseType`.
176 /// Returns null if index computation cannot be performed.
177 ///
178 /// TODO: This method assumes that the `baseType` is a MemRefType with AffineMap
179 /// that has static strides. Extend to handle dynamic strides.
180 Value getElementPtr(const SPIRVTypeConverter &typeConverter,
181                     MemRefType baseType, Value basePtr, ValueRange indices,
182                     Location loc, OpBuilder &builder);
183 
184 /// `getElementPtr` implementation for Kernel/OpenCL flavored SPIR-V.
185 Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
186                           MemRefType baseType, Value basePtr,
187                           ValueRange indices, Location loc, OpBuilder &builder);
188 
189 /// `getElementPtr` implementation for Vulkan/Shader flavored SPIR-V.
190 Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
191                           MemRefType baseType, Value basePtr,
192                           ValueRange indices, Location loc, OpBuilder &builder);
193 
194 /// Finds the largest factor of `size` among {2,3,4} for the lowest dimension
195 /// of the target shape.
196 int getComputeVectorSize(int64_t size);
197 
198 /// `getNativeVectorShape` implementation for reduction ops.
199 SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op);
200 
201 /// `getNativeVectorShape` implementation for transpose ops.
202 SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op);
203 
204 /// Gets the native vector shape for general ops; the result is optional.
205 std::optional<SmallVector<int64_t>> getNativeVectorShape(Operation *op);
206 
207 /// Unrolls vectors in function signatures to native size.
208 LogicalResult unrollVectorsInSignatures(Operation *op);
209 
210 /// Unrolls vectors in function bodies to native size.
211 LogicalResult unrollVectorsInFuncBodies(Operation *op);
212 
213 } // namespace spirv
214 } // namespace mlir
215 
216 #endif // MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H
217