1 //===- SPIRVConversion.h - SPIR-V Conversion Utilities ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Defines utilities to use while converting to the SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H 14 #define MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H 15 16 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" 20 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 #include "mlir/Transforms/OneToNTypeConversion.h" 24 #include "llvm/ADT/SmallSet.h" 25 #include "llvm/Support/LogicalResult.h" 26 27 namespace mlir { 28 29 //===----------------------------------------------------------------------===// 30 // Type Converter 31 //===----------------------------------------------------------------------===// 32 33 /// How sub-byte values are storaged in memory. 34 enum class SPIRVSubByteTypeStorage { 35 /// Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8. 36 Packed, 37 }; 38 39 struct SPIRVConversionOptions { 40 /// The number of bits to store a boolean value. 41 unsigned boolNumBits{8}; 42 43 /// How sub-byte values are storaged in memory. 44 SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed}; 45 46 /// Whether to emulate narrower scalar types with 32-bit scalar types if not 47 /// supported by the target. 48 /// 49 /// Non-32-bit scalar types require special hardware support that may not 50 /// exist on all GPUs. This is reflected in SPIR-V as that non-32-bit scalar 51 /// types require special capabilities or extensions. This option controls 52 /// whether to use 32-bit types to emulate < 32-bits-wide scalars, if a scalar 53 /// type of a certain bitwidth is not supported in the target environment. 54 /// This requires the runtime to also feed in data with a matched bitwidth and 55 /// layout for interface types. The runtime can do that by inspecting the 56 /// SPIR-V module. 57 /// 58 /// If the original scalar type has less than 32-bit, a multiple of its 59 /// values will be packed into one 32-bit value to be memory efficient. 60 bool emulateLT32BitScalarTypes{true}; 61 62 /// Use 64-bit integers when converting index types. 63 bool use64bitIndex{false}; 64 }; 65 66 /// Type conversion from builtin types to SPIR-V types for shader interface. 67 /// 68 /// For memref types, this converter additionally performs type wrapping to 69 /// satisfy shader interface requirements: shader interface types must be 70 /// pointers to structs. 71 class SPIRVTypeConverter : public TypeConverter { 72 public: 73 explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, 74 const SPIRVConversionOptions &options = {}); 75 76 /// Gets the SPIR-V correspondence for the standard index type. 77 Type getIndexType() const; 78 79 /// Gets the bitwidth of the index type when converted to SPIR-V. 80 unsigned getIndexTypeBitwidth() const { 81 return options.use64bitIndex ? 64 : 32; 82 } 83 84 const spirv::TargetEnv &getTargetEnv() const { return targetEnv; } 85 86 /// Returns the options controlling the SPIR-V type converter. 87 const SPIRVConversionOptions &getOptions() const { return options; } 88 89 /// Checks if the SPIR-V capability inquired is supported. 90 bool allows(spirv::Capability capability) const; 91 92 private: 93 spirv::TargetEnv targetEnv; 94 SPIRVConversionOptions options; 95 96 MLIRContext *getContext() const; 97 }; 98 99 //===----------------------------------------------------------------------===// 100 // Conversion Target 101 //===----------------------------------------------------------------------===// 102 103 // The default SPIR-V conversion target. 104 // 105 // It takes a SPIR-V target environment and controls operation legality based on 106 // the their availability in the target environment. 107 class SPIRVConversionTarget : public ConversionTarget { 108 public: 109 /// Creates a SPIR-V conversion target for the given target environment. 110 static std::unique_ptr<SPIRVConversionTarget> 111 get(spirv::TargetEnvAttr targetAttr); 112 113 private: 114 explicit SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr); 115 116 // Be explicit that instance of this class cannot be copied or moved: there 117 // are lambdas capturing fields of the instance. 118 SPIRVConversionTarget(const SPIRVConversionTarget &) = delete; 119 SPIRVConversionTarget(SPIRVConversionTarget &&) = delete; 120 SPIRVConversionTarget &operator=(const SPIRVConversionTarget &) = delete; 121 SPIRVConversionTarget &operator=(SPIRVConversionTarget &&) = delete; 122 123 /// Returns true if the given `op` is legal to use under the current target 124 /// environment. 125 bool isLegalOp(Operation *op); 126 127 spirv::TargetEnv targetEnv; 128 }; 129 130 //===----------------------------------------------------------------------===// 131 // Patterns and Utility Functions 132 //===----------------------------------------------------------------------===// 133 134 /// Appends to a pattern list additional patterns for translating the builtin 135 /// `func` op to the SPIR-V dialect. These patterns do not handle shader 136 /// interface/ABI; they convert function parameters to be of SPIR-V allowed 137 /// types. 138 void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, 139 RewritePatternSet &patterns); 140 141 void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns); 142 143 void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns); 144 145 namespace spirv { 146 class AccessChainOp; 147 148 /// Returns the value for the given `builtin` variable. This function gets or 149 /// inserts the global variable associated for the builtin within the nearest 150 /// symbol table enclosing `op`. Returns null Value on error. 151 /// 152 /// The global name being generated will be mangled using `preffix` and 153 /// `suffix`. 154 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, 155 OpBuilder &builder, 156 StringRef prefix = "__builtin__", 157 StringRef suffix = "__"); 158 159 /// Gets the value at the given `offset` of the push constant storage with a 160 /// total of `elementCount` `integerType` integers. A global variable will be 161 /// created in the nearest symbol table enclosing `op` for the push constant 162 /// storage if not existing. Load ops will be created via the given `builder` to 163 /// load values from the push constant. Returns null Value on error. 164 Value getPushConstantValue(Operation *op, unsigned elementCount, 165 unsigned offset, Type integerType, 166 OpBuilder &builder); 167 168 /// Generates IR to perform index linearization with the given `indices` and 169 /// their corresponding `strides`, adding an initial `offset`. 170 Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides, 171 int64_t offset, Type integerType, Location loc, 172 OpBuilder &builder); 173 174 /// Performs the index computation to get to the element at `indices` of the 175 /// memory pointed to by `basePtr`, using the layout map of `baseType`. 176 /// Returns null if index computation cannot be performed. 177 178 // TODO: This method assumes that the `baseType` is a MemRefType with AffineMap 179 // that has static strides. Extend to handle dynamic strides. 180 Value getElementPtr(const SPIRVTypeConverter &typeConverter, 181 MemRefType baseType, Value basePtr, ValueRange indices, 182 Location loc, OpBuilder &builder); 183 184 // GetElementPtr implementation for Kernel/OpenCL flavored SPIR-V. 185 Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, 186 MemRefType baseType, Value basePtr, 187 ValueRange indices, Location loc, OpBuilder &builder); 188 189 // GetElementPtr implementation for Vulkan/Shader flavored SPIR-V. 190 Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, 191 MemRefType baseType, Value basePtr, 192 ValueRange indices, Location loc, OpBuilder &builder); 193 194 // Find the largest factor of size among {2,3,4} for the lowest dimension of 195 // the target shape. 196 int getComputeVectorSize(int64_t size); 197 198 // GetNativeVectorShape implementation for reduction ops. 199 SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op); 200 201 // GetNativeVectorShape implementation for transpose ops. 202 SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op); 203 204 // For general ops. 205 std::optional<SmallVector<int64_t>> getNativeVectorShape(Operation *op); 206 207 // Unroll vectors in function signatures to native size. 208 LogicalResult unrollVectorsInSignatures(Operation *op); 209 210 // Unroll vectors in function bodies to native size. 211 LogicalResult unrollVectorsInFuncBodies(Operation *op); 212 213 } // namespace spirv 214 } // namespace mlir 215 216 #endif // MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H 217