1 //===- TypeConverter.h - Convert builtin to LLVM dialect types --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Provides a type converter configuration for converting most builtin types to 10 // LLVM dialect types. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H 15 #define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H 16 17 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 21 namespace mlir { 22 23 class DataLayoutAnalysis; 24 class FunctionOpInterface; 25 class LowerToLLVMOptions; 26 27 namespace LLVM { 28 class LLVMDialect; 29 class LLVMPointerType; 30 class LLVMFunctionType; 31 class LLVMStructType; 32 } // namespace LLVM 33 34 /// Conversion from types to the LLVM IR dialect. 35 class LLVMTypeConverter : public TypeConverter { 36 /// Give structFuncArgTypeConverter access to memref-specific functions. 37 friend LogicalResult 38 structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, 39 SmallVectorImpl<Type> &result); 40 41 public: 42 using TypeConverter::convertType; 43 44 /// Create an LLVMTypeConverter using the default LowerToLLVMOptions. 45 /// Optionally takes a data layout analysis to use in conversions. 46 LLVMTypeConverter(MLIRContext *ctx, 47 const DataLayoutAnalysis *analysis = nullptr); 48 49 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. Optionally 50 /// takes a data layout analysis to use in conversions. 51 LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options, 52 const DataLayoutAnalysis *analysis = nullptr); 53 54 /// Convert a function type. The arguments and results are converted one by 55 /// one and results are packed into a wrapped LLVM IR structure type. `result` 56 /// is populated with argument mapping. 57 Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, 58 bool useBarePtrCallConv, 59 SignatureConversion &result) const; 60 61 /// Convert a function type. The arguments and results are converted one by 62 /// one and results are packed into a wrapped LLVM IR structure type. `result` 63 /// is populated with argument mapping. Converted types of `llvm.byval` and 64 /// `llvm.byref` function arguments which are not LLVM pointers are overridden 65 /// with LLVM pointers. Overridden arguments are returned in 66 /// `byValRefNonPtrAttrs`. 67 Type convertFunctionSignature(FunctionOpInterface funcOp, bool isVariadic, 68 bool useBarePtrCallConv, 69 LLVMTypeConverter::SignatureConversion &result, 70 SmallVectorImpl<std::optional<NamedAttribute>> 71 &byValRefNonPtrAttrs) const; 72 73 /// Convert a non-empty list of types to be returned from a function into an 74 /// LLVM-compatible type. In particular, if more than one value is returned, 75 /// create an LLVM dialect structure type with elements that correspond to 76 /// each of the types converted with `convertCallingConventionType`. 77 Type packFunctionResults(TypeRange types, 78 bool useBarePointerCallConv = false) const; 79 80 /// Convert a non-empty list of types of values produced by an operation into 81 /// an LLVM-compatible type. In particular, if more than one value is 82 /// produced, create a literal structure with elements that correspond to each 83 /// of the LLVM-compatible types converted with `convertType`. 84 Type packOperationResults(TypeRange types) const; 85 86 /// Convert a type in the context of the default or bare pointer calling 87 /// convention. Calling convention sensitive types, such as MemRefType and 88 /// UnrankedMemRefType, are converted following the specific rules for the 89 /// calling convention. Calling convention independent types are converted 90 /// following the default LLVM type conversions. 91 Type convertCallingConventionType(Type type, 92 bool useBarePointerCallConv = false) const; 93 94 /// Promote the bare pointers in 'values' that resulted from memrefs to 95 /// descriptors. 'stdTypes' holds the types of 'values' before the conversion 96 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). 97 void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, 98 Location loc, ArrayRef<Type> stdTypes, 99 SmallVectorImpl<Value> &values) const; 100 101 /// Returns the MLIR context. 102 MLIRContext &getContext() const; 103 104 /// Returns the LLVM dialect. 105 LLVM::LLVMDialect *getDialect() const { return llvmDialect; } 106 107 const LowerToLLVMOptions &getOptions() const { return options; } 108 109 /// Promote the LLVM representation of all operands including promoting MemRef 110 /// descriptors to stack and use pointers to struct to avoid the complexity 111 /// of the platform-specific C/C++ ABI lowering related to struct argument 112 /// passing. 113 SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands, 114 ValueRange operands, OpBuilder &builder, 115 bool useBarePtrCallConv = false) const; 116 117 /// Promote the LLVM struct representation of one MemRef descriptor to stack 118 /// and use pointer to struct to avoid the complexity of the platform-specific 119 /// C/C++ ABI lowering related to struct argument passing. 120 Value promoteOneMemRefDescriptor(Location loc, Value operand, 121 OpBuilder &builder) const; 122 123 /// Converts the function type to a C-compatible format, in particular using 124 /// pointers to memref descriptors for arguments. Also converts the return 125 /// type to a pointer argument if it is a struct. Returns true if this 126 /// was the case. 127 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType> 128 convertFunctionTypeCWrapper(FunctionType type) const; 129 130 /// Returns the data layout to use during and after conversion. 131 const llvm::DataLayout &getDataLayout() const { return options.dataLayout; } 132 133 /// Returns the data layout analysis to query during conversion. 134 const DataLayoutAnalysis *getDataLayoutAnalysis() const { 135 return dataLayoutAnalysis; 136 } 137 138 /// Gets the LLVM representation of the index type. The returned type is an 139 /// integer type with the size configured for this type converter. 140 Type getIndexType() const; 141 142 /// Gets the bitwidth of the index type when converted to LLVM. 143 unsigned getIndexTypeBitwidth() const { return options.getIndexBitwidth(); } 144 145 /// Gets the pointer bitwidth. 146 unsigned getPointerBitwidth(unsigned addressSpace = 0) const; 147 148 /// Returns the size of the memref descriptor object in bytes. 149 unsigned getMemRefDescriptorSize(MemRefType type, 150 const DataLayout &layout) const; 151 152 /// Returns the size of the unranked memref descriptor object in bytes. 153 unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, 154 const DataLayout &layout) const; 155 156 /// Return the LLVM address space corresponding to the memory space of the 157 /// memref type `type` or failure if the memory space cannot be converted to 158 /// an integer. 159 FailureOr<unsigned> getMemRefAddressSpace(BaseMemRefType type) const; 160 161 /// Check if a memref type can be converted to a bare pointer. 162 static bool canConvertToBarePtr(BaseMemRefType type); 163 164 /// Convert a memref type into a list of LLVM IR types that will form the 165 /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides` 166 /// arrays in the descriptors are unpacked to individual index-typed elements, 167 /// else they are kept as rank-sized arrays of index type. In particular, 168 /// the list will contain: 169 /// - two pointers to the memref element type, followed by 170 /// - an index-typed offset, followed by 171 /// - (if unpackAggregates = true) 172 /// - one index-typed size per dimension of the memref, followed by 173 /// - one index-typed stride per dimension of the memref. 174 /// - (if unpackArrregates = false) 175 /// - one rank-sized array of index-type for the size of each dimension 176 /// - one rank-sized array of index-type for the stride of each dimension 177 /// 178 /// For example, memref<?x?xf32> is converted to the following list: 179 /// - `!llvm<"float*">` (allocated pointer), 180 /// - `!llvm<"float*">` (aligned pointer), 181 /// - `i64` (offset), 182 /// - `i64`, `i64` (sizes), 183 /// - `i64`, `i64` (strides). 184 /// These types can be recomposed to a memref descriptor struct. 185 SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type, 186 bool unpackAggregates) const; 187 188 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types 189 /// that will form the unranked memref descriptor. In particular, this list 190 /// contains: 191 /// - an integer rank, followed by 192 /// - a pointer to the memref descriptor struct. 193 /// For example, memref<*xf32> is converted to the following list: 194 /// i64 (rank) 195 /// !llvm<"i8*"> (type-erased pointer). 196 /// These types can be recomposed to a unranked memref descriptor struct. 197 SmallVector<Type, 2> getUnrankedMemRefDescriptorFields() const; 198 199 protected: 200 /// Pointer to the LLVM dialect. 201 LLVM::LLVMDialect *llvmDialect; 202 203 // Recursive structure detection. 204 // We store one entry per thread here, and rely on locking. 205 DenseMap<uint64_t, std::unique_ptr<SmallVector<Type>>> conversionCallStack; 206 llvm::sys::SmartRWMutex<true> callStackMutex; 207 SmallVector<Type> &getCurrentThreadRecursiveStack(); 208 209 private: 210 /// Convert a function type. The arguments and results are converted one by 211 /// one. Additionally, if the function returns more than one value, pack the 212 /// results into an LLVM IR structure type so that the converted function type 213 /// returns at most one result. 214 Type convertFunctionType(FunctionType type) const; 215 216 /// Common implementation for `convertFunctionSignature` methods. Convert a 217 /// function type. The arguments and results are converted one by one and 218 /// results are packed into a wrapped LLVM IR structure type. `result` is 219 /// populated with argument mapping. If `byValRefNonPtrAttrs` is provided, 220 /// converted types of `llvm.byval` and `llvm.byref` function arguments which 221 /// are not LLVM pointers are overridden with LLVM pointers. `llvm.byval` and 222 /// `llvm.byref` arguments that were already converted to LLVM pointer types 223 /// are removed from 'byValRefNonPtrAttrs`. 224 Type convertFunctionSignatureImpl( 225 FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, 226 LLVMTypeConverter::SignatureConversion &result, 227 SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) 228 const; 229 230 /// Convert the index type. Uses llvmModule data layout to create an integer 231 /// of the pointer bitwidth. 232 Type convertIndexType(IndexType type) const; 233 234 /// Convert an integer type `i*` to `!llvm<"i*">`. 235 Type convertIntegerType(IntegerType type) const; 236 237 /// Convert a floating point type: `f16` to `f16`, `f32` to 238 /// `f32` and `f64` to `f64`. `bf16` is not supported 239 /// by LLVM. 8-bit float types are converted to 8-bit integers as this is how 240 /// all LLVM backends that support them currently represent them. 241 Type convertFloatType(FloatType type) const; 242 243 /// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`, 244 /// `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to 245 /// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported. 246 Type convertComplexType(ComplexType type) const; 247 248 /// Convert a memref type into an LLVM type that captures the relevant data. 249 Type convertMemRefType(MemRefType type) const; 250 251 /// Convert an unranked memref type to an LLVM type that captures the 252 /// runtime rank and a pointer to the static ranked memref desc 253 Type convertUnrankedMemRefType(UnrankedMemRefType type) const; 254 255 /// Convert a memref type to a bare pointer to the memref element type. 256 Type convertMemRefToBarePtr(BaseMemRefType type) const; 257 258 /// Convert a 1D vector type into an LLVM vector type. 259 FailureOr<Type> convertVectorType(VectorType type) const; 260 261 /// Options for customizing the llvm lowering. 262 LowerToLLVMOptions options; 263 264 /// Data layout analysis mapping scopes to layouts active in them. 265 const DataLayoutAnalysis *dataLayoutAnalysis; 266 }; 267 268 /// Callback to convert function argument types. It converts a MemRef function 269 /// argument to a list of non-aggregate types containing descriptor 270 /// information, and an UnrankedmemRef function argument to a list containing 271 /// the rank and a pointer to a descriptor struct. 272 LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, 273 Type type, 274 SmallVectorImpl<Type> &result); 275 276 /// Callback to convert function argument types. It converts MemRef function 277 /// arguments to bare pointers to the MemRef element type. 278 LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, 279 Type type, 280 SmallVectorImpl<Type> &result); 281 282 } // namespace mlir 283 284 #endif // MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H 285