xref: /llvm-project/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h (revision df31fd8a3648a487aaec31bef059187b0d316d1f)
1 //===- TypeConverter.h - Convert builtin to LLVM dialect types --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Provides a type converter configuration for converting most builtin types to
10 // LLVM dialect types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
15 #define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
16 
17 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 
21 namespace mlir {
22 
// Forward declarations of MLIR entities named in the declarations below.
class DataLayoutAnalysis;
class FunctionOpInterface;
class LowerToLLVMOptions;

// Forward declarations of LLVM dialect entities named in this header.
namespace LLVM {
class LLVMDialect;
class LLVMFunctionType;
class LLVMPointerType;
class LLVMStructType;
} // namespace LLVM
33 
34 /// Conversion from types to the LLVM IR dialect.
35 class LLVMTypeConverter : public TypeConverter {
36   /// Give structFuncArgTypeConverter access to memref-specific functions.
37   friend LogicalResult
38   structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
39                              SmallVectorImpl<Type> &result);
40 
41 public:
42   using TypeConverter::convertType;
43 
44   /// Create an LLVMTypeConverter using the default LowerToLLVMOptions.
45   /// Optionally takes a data layout analysis to use in conversions.
46   LLVMTypeConverter(MLIRContext *ctx,
47                     const DataLayoutAnalysis *analysis = nullptr);
48 
49   /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. Optionally
50   /// takes a data layout analysis to use in conversions.
51   LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options,
52                     const DataLayoutAnalysis *analysis = nullptr);
53 
54   /// Convert a function type. The arguments and results are converted one by
55   /// one and results are packed into a wrapped LLVM IR structure type. `result`
56   /// is populated with argument mapping.
57   Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
58                                 bool useBarePtrCallConv,
59                                 SignatureConversion &result) const;
60 
61   /// Convert a function type. The arguments and results are converted one by
62   /// one and results are packed into a wrapped LLVM IR structure type. `result`
63   /// is populated with argument mapping. Converted types of `llvm.byval` and
64   /// `llvm.byref` function arguments which are not LLVM pointers are overridden
65   /// with LLVM pointers. Overridden arguments are returned in
66   /// `byValRefNonPtrAttrs`.
67   Type convertFunctionSignature(FunctionOpInterface funcOp, bool isVariadic,
68                                 bool useBarePtrCallConv,
69                                 LLVMTypeConverter::SignatureConversion &result,
70                                 SmallVectorImpl<std::optional<NamedAttribute>>
71                                     &byValRefNonPtrAttrs) const;
72 
73   /// Convert a non-empty list of types to be returned from a function into an
74   /// LLVM-compatible type. In particular, if more than one value is returned,
75   /// create an LLVM dialect structure type with elements that correspond to
76   /// each of the types converted with `convertCallingConventionType`.
77   Type packFunctionResults(TypeRange types,
78                            bool useBarePointerCallConv = false) const;
79 
80   /// Convert a non-empty list of types of values produced by an operation into
81   /// an LLVM-compatible type. In particular, if more than one value is
82   /// produced, create a literal structure with elements that correspond to each
83   /// of the LLVM-compatible types converted with `convertType`.
84   Type packOperationResults(TypeRange types) const;
85 
86   /// Convert a type in the context of the default or bare pointer calling
87   /// convention. Calling convention sensitive types, such as MemRefType and
88   /// UnrankedMemRefType, are converted following the specific rules for the
89   /// calling convention. Calling convention independent types are converted
90   /// following the default LLVM type conversions.
91   Type convertCallingConventionType(Type type,
92                                     bool useBarePointerCallConv = false) const;
93 
94   /// Promote the bare pointers in 'values' that resulted from memrefs to
95   /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
96   /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
97   void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
98                                     Location loc, ArrayRef<Type> stdTypes,
99                                     SmallVectorImpl<Value> &values) const;
100 
101   /// Returns the MLIR context.
102   MLIRContext &getContext() const;
103 
104   /// Returns the LLVM dialect.
105   LLVM::LLVMDialect *getDialect() const { return llvmDialect; }
106 
107   const LowerToLLVMOptions &getOptions() const { return options; }
108 
109   /// Promote the LLVM representation of all operands including promoting MemRef
110   /// descriptors to stack and use pointers to struct to avoid the complexity
111   /// of the platform-specific C/C++ ABI lowering related to struct argument
112   /// passing.
113   SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
114                                         ValueRange operands, OpBuilder &builder,
115                                         bool useBarePtrCallConv = false) const;
116 
117   /// Promote the LLVM struct representation of one MemRef descriptor to stack
118   /// and use pointer to struct to avoid the complexity of the platform-specific
119   /// C/C++ ABI lowering related to struct argument passing.
120   Value promoteOneMemRefDescriptor(Location loc, Value operand,
121                                    OpBuilder &builder) const;
122 
123   /// Converts the function type to a C-compatible format, in particular using
124   /// pointers to memref descriptors for arguments. Also converts the return
125   /// type to a pointer argument if it is a struct. Returns true if this
126   /// was the case.
127   std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
128   convertFunctionTypeCWrapper(FunctionType type) const;
129 
130   /// Returns the data layout to use during and after conversion.
131   const llvm::DataLayout &getDataLayout() const { return options.dataLayout; }
132 
133   /// Returns the data layout analysis to query during conversion.
134   const DataLayoutAnalysis *getDataLayoutAnalysis() const {
135     return dataLayoutAnalysis;
136   }
137 
138   /// Gets the LLVM representation of the index type. The returned type is an
139   /// integer type with the size configured for this type converter.
140   Type getIndexType() const;
141 
142   /// Gets the bitwidth of the index type when converted to LLVM.
143   unsigned getIndexTypeBitwidth() const { return options.getIndexBitwidth(); }
144 
145   /// Gets the pointer bitwidth.
146   unsigned getPointerBitwidth(unsigned addressSpace = 0) const;
147 
148   /// Returns the size of the memref descriptor object in bytes.
149   unsigned getMemRefDescriptorSize(MemRefType type,
150                                    const DataLayout &layout) const;
151 
152   /// Returns the size of the unranked memref descriptor object in bytes.
153   unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
154                                            const DataLayout &layout) const;
155 
156   /// Return the LLVM address space corresponding to the memory space of the
157   /// memref type `type` or failure if the memory space cannot be converted to
158   /// an integer.
159   FailureOr<unsigned> getMemRefAddressSpace(BaseMemRefType type) const;
160 
161   /// Check if a memref type can be converted to a bare pointer.
162   static bool canConvertToBarePtr(BaseMemRefType type);
163 
164   /// Convert a memref type into a list of LLVM IR types that will form the
165   /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
166   /// arrays in the descriptors are unpacked to individual index-typed elements,
167   /// else they are kept as rank-sized arrays of index type. In particular,
168   /// the list will contain:
169   /// - two pointers to the memref element type, followed by
170   /// - an index-typed offset, followed by
171   /// - (if unpackAggregates = true)
172   ///    - one index-typed size per dimension of the memref, followed by
173   ///    - one index-typed stride per dimension of the memref.
174   /// - (if unpackArrregates = false)
175   ///   - one rank-sized array of index-type for the size of each dimension
176   ///   - one rank-sized array of index-type for the stride of each dimension
177   ///
178   /// For example, memref<?x?xf32> is converted to the following list:
179   /// - `!llvm<"float*">` (allocated pointer),
180   /// - `!llvm<"float*">` (aligned pointer),
181   /// - `i64` (offset),
182   /// - `i64`, `i64` (sizes),
183   /// - `i64`, `i64` (strides).
184   /// These types can be recomposed to a memref descriptor struct.
185   SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
186                                                  bool unpackAggregates) const;
187 
188   /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
189   /// that will form the unranked memref descriptor. In particular, this list
190   /// contains:
191   /// - an integer rank, followed by
192   /// - a pointer to the memref descriptor struct.
193   /// For example, memref<*xf32> is converted to the following list:
194   /// i64 (rank)
195   /// !llvm<"i8*"> (type-erased pointer).
196   /// These types can be recomposed to a unranked memref descriptor struct.
197   SmallVector<Type, 2> getUnrankedMemRefDescriptorFields() const;
198 
199 protected:
200   /// Pointer to the LLVM dialect.
201   LLVM::LLVMDialect *llvmDialect;
202 
203   // Recursive structure detection.
204   // We store one entry per thread here, and rely on locking.
205   DenseMap<uint64_t, std::unique_ptr<SmallVector<Type>>> conversionCallStack;
206   llvm::sys::SmartRWMutex<true> callStackMutex;
207   SmallVector<Type> &getCurrentThreadRecursiveStack();
208 
209 private:
210   /// Convert a function type. The arguments and results are converted one by
211   /// one. Additionally, if the function returns more than one value, pack the
212   /// results into an LLVM IR structure type so that the converted function type
213   /// returns at most one result.
214   Type convertFunctionType(FunctionType type) const;
215 
216   /// Common implementation for `convertFunctionSignature` methods. Convert a
217   /// function type. The arguments and results are converted one by one and
218   /// results are packed into a wrapped LLVM IR structure type. `result` is
219   /// populated with argument mapping. If `byValRefNonPtrAttrs` is provided,
220   /// converted types of `llvm.byval` and `llvm.byref` function arguments which
221   /// are not LLVM pointers are overridden with LLVM pointers. `llvm.byval` and
222   /// `llvm.byref` arguments that were already converted to LLVM pointer types
223   /// are removed from 'byValRefNonPtrAttrs`.
224   Type convertFunctionSignatureImpl(
225       FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
226       LLVMTypeConverter::SignatureConversion &result,
227       SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs)
228       const;
229 
230   /// Convert the index type.  Uses llvmModule data layout to create an integer
231   /// of the pointer bitwidth.
232   Type convertIndexType(IndexType type) const;
233 
234   /// Convert an integer type `i*` to `!llvm<"i*">`.
235   Type convertIntegerType(IntegerType type) const;
236 
237   /// Convert a floating point type: `f16` to `f16`, `f32` to
238   /// `f32` and `f64` to `f64`.  `bf16` is not supported
239   /// by LLVM. 8-bit float types are converted to 8-bit integers as this is how
240   /// all LLVM backends that support them currently represent them.
241   Type convertFloatType(FloatType type) const;
242 
243   /// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`,
244   /// `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to
245   /// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported.
246   Type convertComplexType(ComplexType type) const;
247 
248   /// Convert a memref type into an LLVM type that captures the relevant data.
249   Type convertMemRefType(MemRefType type) const;
250 
251   /// Convert an unranked memref type to an LLVM type that captures the
252   /// runtime rank and a pointer to the static ranked memref desc
253   Type convertUnrankedMemRefType(UnrankedMemRefType type) const;
254 
255   /// Convert a memref type to a bare pointer to the memref element type.
256   Type convertMemRefToBarePtr(BaseMemRefType type) const;
257 
258   /// Convert a 1D vector type into an LLVM vector type.
259   FailureOr<Type> convertVectorType(VectorType type) const;
260 
261   /// Options for customizing the llvm lowering.
262   LowerToLLVMOptions options;
263 
264   /// Data layout analysis mapping scopes to layouts active in them.
265   const DataLayoutAnalysis *dataLayoutAnalysis;
266 };
267 
268 /// Callback to convert function argument types. It converts a MemRef function
269 /// argument to a list of non-aggregate types containing descriptor
270 /// information, and an UnrankedmemRef function argument to a list containing
271 /// the rank and a pointer to a descriptor struct.
272 LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter,
273                                          Type type,
274                                          SmallVectorImpl<Type> &result);
275 
276 /// Callback to convert function argument types. It converts MemRef function
277 /// arguments to bare pointers to the MemRef element type.
278 LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter,
279                                           Type type,
280                                           SmallVectorImpl<Type> &result);
281 
282 } // namespace mlir
283 
284 #endif // MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
285