xref: /llvm-project/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
1 //===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
10 #include "MemRefDescriptor.h"
11 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Threading.h"
16 #include <memory>
17 #include <mutex>
18 #include <optional>
19 
20 using namespace mlir;
21 
22 SmallVector<Type> &LLVMTypeConverter::getCurrentThreadRecursiveStack() {
23   {
24     // Most of the time, the entry already exists in the map.
25     std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
26                                                     std::defer_lock);
27     if (getContext().isMultithreadingEnabled())
28       lock.lock();
29     auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
30     if (recursiveStack != conversionCallStack.end())
31       return *recursiveStack->second;
32   }
33 
34   // First time this thread gets here, we have to get an exclusive access to
35   // inset in the map
36   std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
37   auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
38       llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
39   return *recursiveStackInserted.first->second;
40 }
41 
42 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
43 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
44                                      const DataLayoutAnalysis *analysis)
45     : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
46 
47 /// Helper function that checks if the given value range is a bare pointer.
48 static bool isBarePointer(ValueRange values) {
49   return values.size() == 1 &&
50          isa<LLVM::LLVMPointerType>(values.front().getType());
51 }
52 
53 /// Pack SSA values into an unranked memref descriptor struct.
54 static Value packUnrankedMemRefDesc(OpBuilder &builder,
55                                     UnrankedMemRefType resultType,
56                                     ValueRange inputs, Location loc,
57                                     const LLVMTypeConverter &converter) {
58   // Note: Bare pointers are not supported for unranked memrefs because a
59   // memref descriptor cannot be built just from a bare pointer.
60   if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
61     return Value();
62   return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
63                                         inputs);
64 }
65 
66 /// Pack SSA values into a ranked memref descriptor struct.
67 static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
68                                   ValueRange inputs, Location loc,
69                                   const LLVMTypeConverter &converter) {
70   assert(resultType && "expected non-null result type");
71   if (isBarePointer(inputs))
72     return MemRefDescriptor::fromStaticShape(builder, loc, converter,
73                                              resultType, inputs[0]);
74   if (TypeRange(inputs) ==
75       converter.getMemRefDescriptorFields(resultType,
76                                           /*unpackAggregates=*/true))
77     return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs);
78   // The inputs are neither a bare pointer nor an unpacked memref descriptor.
79   // This materialization function cannot be used.
80   return Value();
81 }
82 
83 /// MemRef descriptor elements -> UnrankedMemRefType
84 static Value unrankedMemRefMaterialization(OpBuilder &builder,
85                                            UnrankedMemRefType resultType,
86                                            ValueRange inputs, Location loc,
87                                            const LLVMTypeConverter &converter) {
88   // A source materialization must return a value of type
89   // `resultType`, so insert a cast from the memref descriptor type
90   // (!llvm.struct) to the original memref type.
91   Value packed =
92       packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
93   if (!packed)
94     return Value();
95   return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
96       .getResult(0);
97 }
98 
99 /// MemRef descriptor elements -> MemRefType
100 static Value rankedMemRefMaterialization(OpBuilder &builder,
101                                          MemRefType resultType,
102                                          ValueRange inputs, Location loc,
103                                          const LLVMTypeConverter &converter) {
104   // A source materialization must return a value of type `resultType`,
105   // so insert a cast from the memref descriptor type (!llvm.struct) to the
106   // original memref type.
107   Value packed =
108       packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
109   if (!packed)
110     return Value();
111   return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
112       .getResult(0);
113 }
114 
115 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
116 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
117                                      const LowerToLLVMOptions &options,
118                                      const DataLayoutAnalysis *analysis)
119     : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
120       dataLayoutAnalysis(analysis) {
121   assert(llvmDialect && "LLVM IR dialect is not registered");
122 
123   // Register conversions for the builtin types.
124   addConversion([&](ComplexType type) { return convertComplexType(type); });
125   addConversion([&](FloatType type) { return convertFloatType(type); });
126   addConversion([&](FunctionType type) { return convertFunctionType(type); });
127   addConversion([&](IndexType type) { return convertIndexType(type); });
128   addConversion([&](IntegerType type) { return convertIntegerType(type); });
129   addConversion([&](MemRefType type) { return convertMemRefType(type); });
130   addConversion(
131       [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
132   addConversion([&](VectorType type) -> std::optional<Type> {
133     FailureOr<Type> llvmType = convertVectorType(type);
134     if (failed(llvmType))
135       return std::nullopt;
136     return llvmType;
137   });
138 
139   // LLVM-compatible types are legal, so add a pass-through conversion. Do this
140   // before the conversions below since conversions are attempted in reverse
141   // order and those should take priority.
142   addConversion([](Type type) {
143     return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
144                                         : std::nullopt;
145   });
146 
147   addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
148                     -> std::optional<LogicalResult> {
149     // Fastpath for types that won't be converted by this callback anyway.
150     if (LLVM::isCompatibleType(type)) {
151       results.push_back(type);
152       return success();
153     }
154 
155     if (type.isIdentified()) {
156       auto convertedType = LLVM::LLVMStructType::getIdentified(
157           type.getContext(), ("_Converted." + type.getName()).str());
158 
159       SmallVectorImpl<Type> &recursiveStack = getCurrentThreadRecursiveStack();
160       if (llvm::count(recursiveStack, type)) {
161         results.push_back(convertedType);
162         return success();
163       }
164       recursiveStack.push_back(type);
165       auto popConversionCallStack = llvm::make_scope_exit(
166           [&recursiveStack]() { recursiveStack.pop_back(); });
167 
168       SmallVector<Type> convertedElemTypes;
169       convertedElemTypes.reserve(type.getBody().size());
170       if (failed(convertTypes(type.getBody(), convertedElemTypes)))
171         return std::nullopt;
172 
173       // If the converted type has not been initialized yet, just set its body
174       // to be the converted arguments and return.
175       if (!convertedType.isInitialized()) {
176         if (failed(
177                 convertedType.setBody(convertedElemTypes, type.isPacked()))) {
178           return failure();
179         }
180         results.push_back(convertedType);
181         return success();
182       }
183 
184       // If it has been initialized, has the same body and packed bit, just use
185       // it. This ensures that recursive structs keep being recursive rather
186       // than including a non-updated name.
187       if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
188           convertedType.isPacked() == type.isPacked()) {
189         results.push_back(convertedType);
190         return success();
191       }
192 
193       return failure();
194     }
195 
196     SmallVector<Type> convertedSubtypes;
197     convertedSubtypes.reserve(type.getBody().size());
198     if (failed(convertTypes(type.getBody(), convertedSubtypes)))
199       return std::nullopt;
200 
201     results.push_back(LLVM::LLVMStructType::getLiteral(
202         type.getContext(), convertedSubtypes, type.isPacked()));
203     return success();
204   });
205   addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
206     if (auto element = convertType(type.getElementType()))
207       return LLVM::LLVMArrayType::get(element, type.getNumElements());
208     return std::nullopt;
209   });
210   addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
211     Type convertedResType = convertType(type.getReturnType());
212     if (!convertedResType)
213       return std::nullopt;
214 
215     SmallVector<Type> convertedArgTypes;
216     convertedArgTypes.reserve(type.getNumParams());
217     if (failed(convertTypes(type.getParams(), convertedArgTypes)))
218       return std::nullopt;
219 
220     return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
221                                        type.isVarArg());
222   });
223 
224   // Add generic source and target materializations to handle cases where
225   // non-LLVM types persist after an LLVM conversion.
226   addSourceMaterialization([&](OpBuilder &builder, Type resultType,
227                                ValueRange inputs, Location loc) {
228     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
229         .getResult(0);
230   });
231   addTargetMaterialization([&](OpBuilder &builder, Type resultType,
232                                ValueRange inputs, Location loc) {
233     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
234         .getResult(0);
235   });
236 
237   // Source materializations convert from the new block argument types
238   // (multiple SSA values that make up a memref descriptor) back to the
239   // original block argument type.
240   addSourceMaterialization([&](OpBuilder &builder,
241                                UnrankedMemRefType resultType, ValueRange inputs,
242                                Location loc) {
243     return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
244                                          *this);
245   });
246   addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
247                                ValueRange inputs, Location loc) {
248     return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
249   });
250 
251   // Bare pointer -> Packed MemRef descriptor
252   addTargetMaterialization([&](OpBuilder &builder, Type resultType,
253                                ValueRange inputs, Location loc,
254                                Type originalType) -> Value {
255     // The original MemRef type is required to build a MemRef descriptor
256     // because the sizes/strides of the MemRef cannot be inferred from just the
257     // bare pointer.
258     if (!originalType)
259       return Value();
260     if (resultType != convertType(originalType))
261       return Value();
262     if (auto memrefType = dyn_cast<MemRefType>(originalType))
263       return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
264     if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
265       return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
266                                     *this);
267     return Value();
268   });
269 
270   // Integer memory spaces map to themselves.
271   addTypeAttributeConversion(
272       [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
273 }
274 
275 /// Returns the MLIR context.
276 MLIRContext &LLVMTypeConverter::getContext() const {
277   return *getDialect()->getContext();
278 }
279 
280 Type LLVMTypeConverter::getIndexType() const {
281   return IntegerType::get(&getContext(), getIndexTypeBitwidth());
282 }
283 
284 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
285   return options.dataLayout.getPointerSizeInBits(addressSpace);
286 }
287 
288 Type LLVMTypeConverter::convertIndexType(IndexType type) const {
289   return getIndexType();
290 }
291 
292 Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
293   return IntegerType::get(&getContext(), type.getWidth());
294 }
295 
296 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
297   // Valid LLVM float types are used directly.
298   if (LLVM::isCompatibleType(type))
299     return type;
300 
301   // F4, F6, F8 types are converted to integer types with the same bit width.
302   if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
303           Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
304           Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
305           Float8E8M0FNUType>(type))
306     return IntegerType::get(&getContext(), type.getWidth());
307 
308   // Other floating-point types: A custom type conversion rule must be
309   // specified by the user.
310   return Type();
311 }
312 
313 // Convert a `ComplexType` to an LLVM type. The result is a complex number
314 // struct with entries for the
315 //   1. real part and for the
316 //   2. imaginary part.
317 Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
318   auto elementType = convertType(type.getElementType());
319   return LLVM::LLVMStructType::getLiteral(&getContext(),
320                                           {elementType, elementType});
321 }
322 
323 // Except for signatures, MLIR function types are converted into LLVM
324 // pointer-to-function types.
325 Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
326   return LLVM::LLVMPointerType::get(type.getContext());
327 }
328 
329 /// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
330 /// function arguments. Returns an empty container if none of these attributes
331 /// are found in any of the arguments.
332 static void
333 filterByValRefArgAttrs(FunctionOpInterface funcOp,
334                        SmallVectorImpl<std::optional<NamedAttribute>> &result) {
335   assert(result.empty() && "Unexpected non-empty output");
336   result.resize(funcOp.getNumArguments(), std::nullopt);
337   bool foundByValByRefAttrs = false;
338   for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
339     for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
340       if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
341            namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
342         foundByValByRefAttrs = true;
343         result[argIdx] = namedAttr;
344         break;
345       }
346     }
347   }
348 
349   if (!foundByValByRefAttrs)
350     result.clear();
351 }
352 
353 // Function types are converted to LLVM Function types by recursively converting
354 // argument and result types. If MLIR Function has zero results, the LLVM
355 // Function has one VoidType result. If MLIR Function has more than one result,
356 // they are into an LLVM StructType in their order of appearance.
357 // If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
358 // `llvm.byref` function arguments which are not LLVM pointers are overridden
359 // with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
360 // converted to LLVM pointer types are removed from 'byValRefNonPtrAttrs`.
361 Type LLVMTypeConverter::convertFunctionSignatureImpl(
362     FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
363     LLVMTypeConverter::SignatureConversion &result,
364     SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
365   // Select the argument converter depending on the calling convention.
366   useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
367   auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
368                                              : structFuncArgTypeConverter;
369   // Convert argument types one by one and check for errors.
370   for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
371     SmallVector<Type, 8> converted;
372     if (failed(funcArgConverter(*this, type, converted)))
373       return {};
374 
375     // Rewrite converted type of `llvm.byval` or `llvm.byref` function
376     // argument that was not converted to an LLVM pointer types.
377     if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
378         converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
379       // If the argument was already converted to an LLVM pointer type, we stop
380       // tracking it as it doesn't need more processing.
381       if (isa<LLVM::LLVMPointerType>(converted[0]))
382         (*byValRefNonPtrAttrs)[idx] = std::nullopt;
383       else
384         converted[0] = LLVM::LLVMPointerType::get(&getContext());
385     }
386 
387     result.addInputs(idx, converted);
388   }
389 
390   // If function does not return anything, create the void result type,
391   // if it returns on element, convert it, otherwise pack the result types into
392   // a struct.
393   Type resultType =
394       funcTy.getNumResults() == 0
395           ? LLVM::LLVMVoidType::get(&getContext())
396           : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
397   if (!resultType)
398     return {};
399   return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
400                                      isVariadic);
401 }
402 
403 Type LLVMTypeConverter::convertFunctionSignature(
404     FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
405     LLVMTypeConverter::SignatureConversion &result) const {
406   return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
407                                       result,
408                                       /*byValRefNonPtrAttrs=*/nullptr);
409 }
410 
411 Type LLVMTypeConverter::convertFunctionSignature(
412     FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
413     LLVMTypeConverter::SignatureConversion &result,
414     SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
415   // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
416   // that were not converted to LLVM pointer types will be returned for further
417   // processing.
418   filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs);
419   auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
420   return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
421                                       result, &byValRefNonPtrAttrs);
422 }
423 
424 /// Converts the function type to a C-compatible format, in particular using
425 /// pointers to memref descriptors for arguments.
426 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
427 LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
428   SmallVector<Type, 4> inputs;
429 
430   Type resultType = type.getNumResults() == 0
431                         ? LLVM::LLVMVoidType::get(&getContext())
432                         : packFunctionResults(type.getResults());
433   if (!resultType)
434     return {};
435 
436   auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
437   auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
438   if (structType) {
439     // Struct types cannot be safely returned via C interface. Make this a
440     // pointer argument, instead.
441     inputs.push_back(ptrType);
442     resultType = LLVM::LLVMVoidType::get(&getContext());
443   }
444 
445   for (Type t : type.getInputs()) {
446     auto converted = convertType(t);
447     if (!converted || !LLVM::isCompatibleType(converted))
448       return {};
449     if (isa<MemRefType, UnrankedMemRefType>(t))
450       converted = ptrType;
451     inputs.push_back(converted);
452   }
453 
454   return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
455 }
456 
457 /// Convert a memref type into a list of LLVM IR types that will form the
458 /// memref descriptor. The result contains the following types:
459 ///  1. The pointer to the allocated data buffer, followed by
460 ///  2. The pointer to the aligned data buffer, followed by
461 ///  3. A lowered `index`-type integer containing the distance between the
462 ///  beginning of the buffer and the first element to be accessed through the
463 ///  view, followed by
464 ///  4. An array containing as many `index`-type integers as the rank of the
465 ///  MemRef: the array represents the size, in number of elements, of the memref
466 ///  along the given dimension. For constant MemRef dimensions, the
467 ///  corresponding size entry is a constant whose runtime value must match the
468 ///  static value, followed by
469 ///  5. A second array containing as many `index`-type integers as the rank of
470 ///  the MemRef: the second array represents the "stride" (in tensor abstraction
471 ///  sense), i.e. the number of consecutive elements of the underlying buffer.
472 ///  TODO: add assertions for the static cases.
473 ///
474 ///  If `unpackAggregates` is set to true, the arrays described in (4) and (5)
475 ///  are expanded into individual index-type elements.
476 ///
477 ///  template <typename Elem, typename Index, size_t Rank>
478 ///  struct {
479 ///    Elem *allocatedPtr;
480 ///    Elem *alignedPtr;
481 ///    Index offset;
482 ///    Index sizes[Rank]; // omitted when rank == 0
483 ///    Index strides[Rank]; // omitted when rank == 0
484 ///  };
485 SmallVector<Type, 5>
486 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
487                                              bool unpackAggregates) const {
488   if (!type.isStrided()) {
489     emitError(
490         UnknownLoc::get(type.getContext()),
491         "conversion to strided form failed either due to non-strided layout "
492         "maps (which should have been normalized away) or other reasons");
493     return {};
494   }
495 
496   Type elementType = convertType(type.getElementType());
497   if (!elementType)
498     return {};
499 
500   FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
501   if (failed(addressSpace)) {
502     emitError(UnknownLoc::get(type.getContext()),
503               "conversion of memref memory space ")
504         << type.getMemorySpace()
505         << " to integer address space "
506            "failed. Consider adding memory space conversions.";
507     return {};
508   }
509   auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
510 
511   auto indexTy = getIndexType();
512 
513   SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
514   auto rank = type.getRank();
515   if (rank == 0)
516     return results;
517 
518   if (unpackAggregates)
519     results.insert(results.end(), 2 * rank, indexTy);
520   else
521     results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
522   return results;
523 }
524 
525 unsigned
526 LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
527                                            const DataLayout &layout) const {
528   // Compute the descriptor size given that of its components indicated above.
529   unsigned space = *getMemRefAddressSpace(type);
530   return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
531          (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
532 }
533 
534 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
535 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
536 Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
537   // When converting a MemRefType to a struct with descriptor fields, do not
538   // unpack the `sizes` and `strides` arrays.
539   SmallVector<Type, 5> types =
540       getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
541   if (types.empty())
542     return {};
543   return LLVM::LLVMStructType::getLiteral(&getContext(), types);
544 }
545 
546 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
547 /// that will form the unranked memref descriptor. In particular, the fields
548 /// for an unranked memref descriptor are:
549 /// 1. index-typed rank, the dynamic rank of this MemRef
550 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
551 ///    stack allocated (alloca) copy of a MemRef descriptor that got casted to
552 ///    be unranked.
553 SmallVector<Type, 2>
554 LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
555   return {getIndexType(), LLVM::LLVMPointerType::get(&getContext())};
556 }
557 
558 unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize(
559     UnrankedMemRefType type, const DataLayout &layout) const {
560   // Compute the descriptor size given that of its components indicated above.
561   unsigned space = *getMemRefAddressSpace(type);
562   return layout.getTypeSize(getIndexType()) +
563          llvm::divideCeil(getPointerBitwidth(space), 8);
564 }
565 
566 Type LLVMTypeConverter::convertUnrankedMemRefType(
567     UnrankedMemRefType type) const {
568   if (!convertType(type.getElementType()))
569     return {};
570   return LLVM::LLVMStructType::getLiteral(&getContext(),
571                                           getUnrankedMemRefDescriptorFields());
572 }
573 
574 FailureOr<unsigned>
575 LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const {
576   if (!type.getMemorySpace()) // Default memory space -> 0.
577     return 0;
578   std::optional<Attribute> converted =
579       convertTypeAttribute(type, type.getMemorySpace());
580   if (!converted)
581     return failure();
582   if (!(*converted)) // Conversion to default is 0.
583     return 0;
584   if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
585     if (explicitSpace.getType().isIndex() ||
586         explicitSpace.getType().isSignlessInteger())
587       return explicitSpace.getInt();
588   }
589   return failure();
590 }
591 
592 // Check if a memref type can be converted to a bare pointer.
593 bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
594   if (isa<UnrankedMemRefType>(type))
595     // Unranked memref is not supported in the bare pointer calling convention.
596     return false;
597 
598   // Check that the memref has static shape, strides and offset. Otherwise, it
599   // cannot be lowered to a bare pointer.
600   auto memrefTy = cast<MemRefType>(type);
601   if (!memrefTy.hasStaticShape())
602     return false;
603 
604   int64_t offset = 0;
605   SmallVector<int64_t, 4> strides;
606   if (failed(memrefTy.getStridesAndOffset(strides, offset)))
607     return false;
608 
609   for (int64_t stride : strides)
610     if (ShapedType::isDynamic(stride))
611       return false;
612 
613   return !ShapedType::isDynamic(offset);
614 }
615 
616 /// Convert a memref type to a bare pointer to the memref element type.
617 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
618   if (!canConvertToBarePtr(type))
619     return {};
620   Type elementType = convertType(type.getElementType());
621   if (!elementType)
622     return {};
623   FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
624   if (failed(addressSpace))
625     return {};
626   return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
627 }
628 
629 /// Convert an n-D vector type to an LLVM vector type:
630 ///  * 0-D `vector<T>` are converted to vector<1xT>
631 ///  * 1-D `vector<axT>` remains as is while,
632 ///  * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
633 ///    `!llvm.array<ax...array<jxvector<kxT>>>`.
634 /// As LLVM supports arrays of scalable vectors, this method will also convert
635 /// n-D scalable vectors provided that only the trailing dim is scalable.
636 FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
637   auto elementType = convertType(type.getElementType());
638   if (!elementType)
639     return {};
640   if (type.getShape().empty())
641     return VectorType::get({1}, elementType);
642   Type vectorType = VectorType::get(type.getShape().back(), elementType,
643                                     type.getScalableDims().back());
644   assert(LLVM::isCompatibleVectorType(vectorType) &&
645          "expected vector type compatible with the LLVM dialect");
646   // For n-D vector types for which a _non-trailing_ dim is scalable,
647   // return a failure. Supporting such cases would require LLVM
648   // to support something akin "scalable arrays" of vectors.
649   if (llvm::is_contained(type.getScalableDims().drop_back(), true))
650     return failure();
651   auto shape = type.getShape();
652   for (int i = shape.size() - 2; i >= 0; --i)
653     vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
654   return vectorType;
655 }
656 
657 /// Convert a type in the context of the default or bare pointer calling
658 /// convention. Calling convention sensitive types, such as MemRefType and
659 /// UnrankedMemRefType, are converted following the specific rules for the
660 /// calling convention. Calling convention independent types are converted
661 /// following the default LLVM type conversions.
662 Type LLVMTypeConverter::convertCallingConventionType(
663     Type type, bool useBarePtrCallConv) const {
664   if (useBarePtrCallConv)
665     if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
666       return convertMemRefToBarePtr(memrefTy);
667 
668   return convertType(type);
669 }
670 
671 /// Promote the bare pointers in 'values' that resulted from memrefs to
672 /// descriptors. 'stdTypes' holds they types of 'values' before the conversion
673 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
674 void LLVMTypeConverter::promoteBarePtrsToDescriptors(
675     ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
676     SmallVectorImpl<Value> &values) const {
677   assert(stdTypes.size() == values.size() &&
678          "The number of types and values doesn't match");
679   for (unsigned i = 0, end = values.size(); i < end; ++i)
680     if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
681       values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
682                                                     memrefTy, values[i]);
683 }
684 
685 /// Convert a non-empty list of types of values produced by an operation into an
686 /// LLVM-compatible type. In particular, if more than one value is
687 /// produced, create a literal structure with elements that correspond to each
688 /// of the types converted with `convertType`.
689 Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
690   assert(!types.empty() && "expected non-empty list of type");
691   if (types.size() == 1)
692     return convertType(types[0]);
693 
694   SmallVector<Type> resultTypes;
695   resultTypes.reserve(types.size());
696   for (Type type : types) {
697     Type converted = convertType(type);
698     if (!converted || !LLVM::isCompatibleType(converted))
699       return {};
700     resultTypes.push_back(converted);
701   }
702 
703   return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
704 }
705 
706 /// Convert a non-empty list of types to be returned from a function into an
707 /// LLVM-compatible type. In particular, if more than one value is returned,
708 /// create an LLVM dialect structure type with elements that correspond to each
709 /// of the types converted with `convertCallingConventionType`.
710 Type LLVMTypeConverter::packFunctionResults(TypeRange types,
711                                             bool useBarePtrCallConv) const {
712   assert(!types.empty() && "expected non-empty list of type");
713 
714   useBarePtrCallConv |= options.useBarePtrCallConv;
715   if (types.size() == 1)
716     return convertCallingConventionType(types.front(), useBarePtrCallConv);
717 
718   SmallVector<Type> resultTypes;
719   resultTypes.reserve(types.size());
720   for (auto t : types) {
721     auto converted = convertCallingConventionType(t, useBarePtrCallConv);
722     if (!converted || !LLVM::isCompatibleType(converted))
723       return {};
724     resultTypes.push_back(converted);
725   }
726 
727   return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
728 }
729 
730 Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
731                                                     OpBuilder &builder) const {
732   // Alloca with proper alignment. We do not expect optimizations of this
733   // alloca op and so we omit allocating at the entry block.
734   auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
735   Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
736                                                builder.getIndexAttr(1));
737   Value allocated =
738       builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
739   // Store into the alloca'ed descriptor.
740   builder.create<LLVM::StoreOp>(loc, operand, allocated);
741   return allocated;
742 }
743 
744 SmallVector<Value, 4>
745 LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
746                                    ValueRange operands, OpBuilder &builder,
747                                    bool useBarePtrCallConv) const {
748   SmallVector<Value, 4> promotedOperands;
749   promotedOperands.reserve(operands.size());
750   useBarePtrCallConv |= options.useBarePtrCallConv;
751   for (auto it : llvm::zip(opOperands, operands)) {
752     auto operand = std::get<0>(it);
753     auto llvmOperand = std::get<1>(it);
754 
755     if (useBarePtrCallConv) {
756       // For the bare-ptr calling convention, we only have to extract the
757       // aligned pointer of a memref.
758       if (dyn_cast<MemRefType>(operand.getType())) {
759         MemRefDescriptor desc(llvmOperand);
760         llvmOperand = desc.alignedPtr(builder, loc);
761       } else if (isa<UnrankedMemRefType>(operand.getType())) {
762         llvm_unreachable("Unranked memrefs are not supported");
763       }
764     } else {
765       if (isa<UnrankedMemRefType>(operand.getType())) {
766         UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
767                                          promotedOperands);
768         continue;
769       }
770       if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
771         MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
772                                  promotedOperands);
773         continue;
774       }
775     }
776 
777     promotedOperands.push_back(llvmOperand);
778   }
779   return promotedOperands;
780 }
781 
782 /// Callback to convert function argument types. It converts a MemRef function
783 /// argument to a list of non-aggregate types containing descriptor
784 /// information, and an UnrankedmemRef function argument to a list containing
785 /// the rank and a pointer to a descriptor struct.
786 LogicalResult
787 mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
788                                  SmallVectorImpl<Type> &result) {
789   if (auto memref = dyn_cast<MemRefType>(type)) {
790     // In signatures, Memref descriptors are expanded into lists of
791     // non-aggregate values.
792     auto converted =
793         converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
794     if (converted.empty())
795       return failure();
796     result.append(converted.begin(), converted.end());
797     return success();
798   }
799   if (isa<UnrankedMemRefType>(type)) {
800     auto converted = converter.getUnrankedMemRefDescriptorFields();
801     if (converted.empty())
802       return failure();
803     result.append(converted.begin(), converted.end());
804     return success();
805   }
806   auto converted = converter.convertType(type);
807   if (!converted)
808     return failure();
809   result.push_back(converted);
810   return success();
811 }
812 
813 /// Callback to convert function argument types. It converts MemRef function
814 /// arguments to bare pointers to the MemRef element type.
815 LogicalResult
816 mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
817                                   SmallVectorImpl<Type> &result) {
818   auto llvmTy = converter.convertCallingConventionType(
819       type, /*useBarePointerCallConv=*/true);
820   if (!llvmTy)
821     return failure();
822 
823   result.push_back(llvmTy);
824   return success();
825 }
826