xref: /llvm-project/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (revision e504ece6c15fa5b347a4d8ff7e6fc98ee109660e)
1 //===- LLVMTypes.cpp - MLIR LLVM dialect types ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the types for the LLVM dialect in MLIR. These MLIR types
10 // correspond to the LLVM IR type system.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TypeDetail.h"
15 
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/TypeSupport.h"
21 
22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/TypeSize.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::LLVM;
29 
/// Number of bits in a byte; used in this file to convert between bit and byte
/// quantities.
static constexpr uint64_t kBitsInByte = 8;
31 
32 //===----------------------------------------------------------------------===//
33 // custom<FunctionTypes>
34 //===----------------------------------------------------------------------===//
35 
36 static ParseResult parseFunctionTypes(AsmParser &p, SmallVector<Type> &params,
37                                       bool &isVarArg) {
38   isVarArg = false;
39   // `(` `)`
40   if (succeeded(p.parseOptionalRParen()))
41     return success();
42 
43   // `(` `...` `)`
44   if (succeeded(p.parseOptionalEllipsis())) {
45     isVarArg = true;
46     return p.parseRParen();
47   }
48 
49   // type (`,` type)* (`,` `...`)?
50   Type type;
51   if (parsePrettyLLVMType(p, type))
52     return failure();
53   params.push_back(type);
54   while (succeeded(p.parseOptionalComma())) {
55     if (succeeded(p.parseOptionalEllipsis())) {
56       isVarArg = true;
57       return p.parseRParen();
58     }
59     if (parsePrettyLLVMType(p, type))
60       return failure();
61     params.push_back(type);
62   }
63   return p.parseRParen();
64 }
65 
66 static void printFunctionTypes(AsmPrinter &p, ArrayRef<Type> params,
67                                bool isVarArg) {
68   llvm::interleaveComma(params, p,
69                         [&](Type type) { printPrettyLLVMType(p, type); });
70   if (isVarArg) {
71     if (!params.empty())
72       p << ", ";
73     p << "...";
74   }
75   p << ')';
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // custom<ExtTypeParams>
80 //===----------------------------------------------------------------------===//
81 
82 /// Parses the parameter list for a target extension type. The parameter list
83 /// contains an optional list of type parameters, followed by an optional list
84 /// of integer parameters. Type and integer parameters cannot be interleaved in
85 /// the list.
86 /// extTypeParams ::= typeList? | intList? | (typeList "," intList)
87 /// typeList      ::= type ("," type)*
88 /// intList       ::= integer ("," integer)*
89 static ParseResult
90 parseExtTypeParams(AsmParser &p, SmallVectorImpl<Type> &typeParams,
91                    SmallVectorImpl<unsigned int> &intParams) {
92   bool parseType = true;
93   auto typeOrIntParser = [&]() -> ParseResult {
94     unsigned int i;
95     auto intResult = p.parseOptionalInteger(i);
96     if (intResult.has_value() && !failed(*intResult)) {
97       // Successfully parsed an integer.
98       intParams.push_back(i);
99       // After the first integer was successfully parsed, no
100       // more types can be parsed.
101       parseType = false;
102       return success();
103     }
104     if (parseType) {
105       Type t;
106       if (!parsePrettyLLVMType(p, t)) {
107         // Successfully parsed a type.
108         typeParams.push_back(t);
109         return success();
110       }
111     }
112     return failure();
113   };
114   if (p.parseCommaSeparatedList(typeOrIntParser)) {
115     p.emitError(p.getCurrentLocation(),
116                 "failed to parse parameter list for target extension type");
117     return failure();
118   }
119   return success();
120 }
121 
122 static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams,
123                                ArrayRef<unsigned int> intParams) {
124   p << typeParams;
125   if (!typeParams.empty() && !intParams.empty())
126     p << ", ";
127 
128   p << intParams;
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // ODS-Generated Definitions
133 //===----------------------------------------------------------------------===//
134 
135 /// These are unused for now.
136 /// TODO: Move over to these once more types have been migrated to TypeDef.
137 LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
138 generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
139 LLVM_ATTRIBUTE_UNUSED static LogicalResult
140 generatedTypePrinter(Type def, AsmPrinter &printer);
141 
142 #include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
143 
144 #define GET_TYPEDEF_CLASSES
145 #include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
146 
147 //===----------------------------------------------------------------------===//
148 // LLVMArrayType
149 //===----------------------------------------------------------------------===//
150 
151 bool LLVMArrayType::isValidElementType(Type type) {
152   return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
153                     LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
154       type);
155 }
156 
157 LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) {
158   assert(elementType && "expected non-null subtype");
159   return Base::get(elementType.getContext(), elementType, numElements);
160 }
161 
162 LLVMArrayType
163 LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
164                           Type elementType, uint64_t numElements) {
165   assert(elementType && "expected non-null subtype");
166   return Base::getChecked(emitError, elementType.getContext(), elementType,
167                           numElements);
168 }
169 
170 LogicalResult
171 LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
172                       Type elementType, uint64_t numElements) {
173   if (!isValidElementType(elementType))
174     return emitError() << "invalid array element type: " << elementType;
175   return success();
176 }
177 
178 //===----------------------------------------------------------------------===//
179 // DataLayoutTypeInterface
180 
181 llvm::TypeSize
182 LLVMArrayType::getTypeSizeInBits(const DataLayout &dataLayout,
183                                  DataLayoutEntryListRef params) const {
184   return llvm::TypeSize::getFixed(kBitsInByte *
185                                   getTypeSize(dataLayout, params));
186 }
187 
188 llvm::TypeSize LLVMArrayType::getTypeSize(const DataLayout &dataLayout,
189                                           DataLayoutEntryListRef params) const {
190   return llvm::alignTo(dataLayout.getTypeSize(getElementType()),
191                        dataLayout.getTypeABIAlignment(getElementType())) *
192          getNumElements();
193 }
194 
195 uint64_t LLVMArrayType::getABIAlignment(const DataLayout &dataLayout,
196                                         DataLayoutEntryListRef params) const {
197   return dataLayout.getTypeABIAlignment(getElementType());
198 }
199 
200 uint64_t
201 LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
202                                      DataLayoutEntryListRef params) const {
203   return dataLayout.getTypePreferredAlignment(getElementType());
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // Function type.
208 //===----------------------------------------------------------------------===//
209 
210 bool LLVMFunctionType::isValidArgumentType(Type type) {
211   return !llvm::isa<LLVMVoidType, LLVMFunctionType>(type);
212 }
213 
214 bool LLVMFunctionType::isValidResultType(Type type) {
215   return !llvm::isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>(type);
216 }
217 
218 LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
219                                        bool isVarArg) {
220   assert(result && "expected non-null result");
221   return Base::get(result.getContext(), result, arguments, isVarArg);
222 }
223 
224 LLVMFunctionType
225 LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
226                              Type result, ArrayRef<Type> arguments,
227                              bool isVarArg) {
228   assert(result && "expected non-null result");
229   return Base::getChecked(emitError, result.getContext(), result, arguments,
230                           isVarArg);
231 }
232 
233 LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
234                                          TypeRange results) const {
235   assert(results.size() == 1 && "expected a single result type");
236   return get(results[0], llvm::to_vector(inputs), isVarArg());
237 }
238 
239 ArrayRef<Type> LLVMFunctionType::getReturnTypes() const {
240   return static_cast<detail::LLVMFunctionTypeStorage *>(getImpl())->returnType;
241 }
242 
243 LogicalResult
244 LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
245                          Type result, ArrayRef<Type> arguments, bool) {
246   if (!isValidResultType(result))
247     return emitError() << "invalid function result type: " << result;
248 
249   for (Type arg : arguments)
250     if (!isValidArgumentType(arg))
251       return emitError() << "invalid function argument type: " << arg;
252 
253   return success();
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // DataLayoutTypeInterface
258 
/// Pointer size, in bits, assumed for the default address space when the data
/// layout provides no pointer entry.
static constexpr uint64_t kDefaultPointerSizeBits = 64;
/// Pointer alignment assumed for the default address space when the data
/// layout provides no pointer entry.
static constexpr uint64_t kDefaultPointerAlignment = 8;
261 
262 std::optional<uint64_t> mlir::LLVM::extractPointerSpecValue(Attribute attr,
263                                                             PtrDLEntryPos pos) {
264   auto spec = cast<DenseIntElementsAttr>(attr);
265   auto idx = static_cast<int64_t>(pos);
266   if (idx >= spec.size())
267     return std::nullopt;
268   return spec.getValues<uint64_t>()[idx];
269 }
270 
271 /// Returns the part of the data layout entry that corresponds to `pos` for the
272 /// given `type` by interpreting the list of entries `params`. For the pointer
273 /// type in the default address space, returns the default value if the entries
274 /// do not provide a custom one, for other address spaces returns std::nullopt.
275 static std::optional<uint64_t>
276 getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type,
277                           PtrDLEntryPos pos) {
278   // First, look for the entry for the pointer in the current address space.
279   Attribute currentEntry;
280   for (DataLayoutEntryInterface entry : params) {
281     if (!entry.isTypeEntry())
282       continue;
283     if (cast<LLVMPointerType>(cast<Type>(entry.getKey())).getAddressSpace() ==
284         type.getAddressSpace()) {
285       currentEntry = entry.getValue();
286       break;
287     }
288   }
289   if (currentEntry) {
290     std::optional<uint64_t> value = extractPointerSpecValue(currentEntry, pos);
291     // If the optional `PtrDLEntryPos::Index` entry is not available, use the
292     // pointer size as the index bitwidth.
293     if (!value && pos == PtrDLEntryPos::Index)
294       value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size);
295     bool isSizeOrIndex =
296         pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
297     return *value / (isSizeOrIndex ? 1 : kBitsInByte);
298   }
299 
300   // If not found, and this is the pointer to the default memory space, assume
301   // 64-bit pointers.
302   if (type.getAddressSpace() == 0) {
303     bool isSizeOrIndex =
304         pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
305     return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
306   }
307 
308   return std::nullopt;
309 }
310 
311 llvm::TypeSize
312 LLVMPointerType::getTypeSizeInBits(const DataLayout &dataLayout,
313                                    DataLayoutEntryListRef params) const {
314   if (std::optional<uint64_t> size =
315           getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Size))
316     return llvm::TypeSize::getFixed(*size);
317 
318   // For other memory spaces, use the size of the pointer to the default memory
319   // space.
320   return dataLayout.getTypeSizeInBits(get(getContext()));
321 }
322 
323 uint64_t LLVMPointerType::getABIAlignment(const DataLayout &dataLayout,
324                                           DataLayoutEntryListRef params) const {
325   if (std::optional<uint64_t> alignment =
326           getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Abi))
327     return *alignment;
328 
329   return dataLayout.getTypeABIAlignment(get(getContext()));
330 }
331 
332 uint64_t
333 LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout,
334                                        DataLayoutEntryListRef params) const {
335   if (std::optional<uint64_t> alignment =
336           getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Preferred))
337     return *alignment;
338 
339   return dataLayout.getTypePreferredAlignment(get(getContext()));
340 }
341 
342 std::optional<uint64_t>
343 LLVMPointerType::getIndexBitwidth(const DataLayout &dataLayout,
344                                   DataLayoutEntryListRef params) const {
345   if (std::optional<uint64_t> indexBitwidth =
346           getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Index))
347     return *indexBitwidth;
348 
349   return dataLayout.getTypeIndexBitwidth(get(getContext()));
350 }
351 
352 bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
353                                     DataLayoutEntryListRef newLayout) const {
354   for (DataLayoutEntryInterface newEntry : newLayout) {
355     if (!newEntry.isTypeEntry())
356       continue;
357     uint64_t size = kDefaultPointerSizeBits;
358     uint64_t abi = kDefaultPointerAlignment;
359     auto newType =
360         llvm::cast<LLVMPointerType>(llvm::cast<Type>(newEntry.getKey()));
361     const auto *it =
362         llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
363           if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
364             return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
365                    newType.getAddressSpace();
366           }
367           return false;
368         });
369     if (it == oldLayout.end()) {
370       llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
371         if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
372           return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
373         }
374         return false;
375       });
376     }
377     if (it != oldLayout.end()) {
378       size = *extractPointerSpecValue(*it, PtrDLEntryPos::Size);
379       abi = *extractPointerSpecValue(*it, PtrDLEntryPos::Abi);
380     }
381 
382     Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue());
383     uint64_t newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size);
384     uint64_t newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi);
385     if (size != newSize || abi < newAbi || abi % newAbi != 0)
386       return false;
387   }
388   return true;
389 }
390 
391 LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
392                                              Location loc) const {
393   for (DataLayoutEntryInterface entry : entries) {
394     if (!entry.isTypeEntry())
395       continue;
396     auto key = llvm::cast<Type>(entry.getKey());
397     auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
398     if (!values || (values.size() != 3 && values.size() != 4)) {
399       return emitError(loc)
400              << "expected layout attribute for " << key
401              << " to be a dense integer elements attribute with 3 or 4 "
402                 "elements";
403     }
404     if (!values.getElementType().isInteger(64))
405       return emitError(loc) << "expected i64 parameters for " << key;
406 
407     if (extractPointerSpecValue(values, PtrDLEntryPos::Abi) >
408         extractPointerSpecValue(values, PtrDLEntryPos::Preferred)) {
409       return emitError(loc) << "preferred alignment is expected to be at least "
410                                "as large as ABI alignment";
411     }
412   }
413   return success();
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // Struct type.
418 //===----------------------------------------------------------------------===//
419 
420 bool LLVMStructType::isValidElementType(Type type) {
421   return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
422                     LLVMFunctionType, LLVMTokenType>(type);
423 }
424 
425 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
426                                              StringRef name) {
427   return Base::get(context, name, /*opaque=*/false);
428 }
429 
430 LLVMStructType LLVMStructType::getIdentifiedChecked(
431     function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
432     StringRef name) {
433   return Base::getChecked(emitError, context, name, /*opaque=*/false);
434 }
435 
436 LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
437                                                 StringRef name,
438                                                 ArrayRef<Type> elements,
439                                                 bool isPacked) {
440   std::string stringName = name.str();
441   unsigned counter = 0;
442   do {
443     auto type = LLVMStructType::getIdentified(context, stringName);
444     if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
445       counter += 1;
446       stringName = (Twine(name) + "." + std::to_string(counter)).str();
447       continue;
448     }
449     return type;
450   } while (true);
451 }
452 
453 LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
454                                           ArrayRef<Type> types, bool isPacked) {
455   return Base::get(context, types, isPacked);
456 }
457 
458 LLVMStructType
459 LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
460                                   MLIRContext *context, ArrayRef<Type> types,
461                                   bool isPacked) {
462   return Base::getChecked(emitError, context, types, isPacked);
463 }
464 
465 LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
466   return Base::get(context, name, /*opaque=*/true);
467 }
468 
469 LLVMStructType
470 LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
471                                  MLIRContext *context, StringRef name) {
472   return Base::getChecked(emitError, context, name, /*opaque=*/true);
473 }
474 
475 LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
476   assert(isIdentified() && "can only set bodies of identified structs");
477   assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
478          "expected valid body types");
479   return Base::mutate(types, isPacked);
480 }
481 
482 bool LLVMStructType::isPacked() const { return getImpl()->isPacked(); }
483 bool LLVMStructType::isIdentified() const { return getImpl()->isIdentified(); }
484 bool LLVMStructType::isOpaque() const {
485   return getImpl()->isIdentified() &&
486          (getImpl()->isOpaque() || !getImpl()->isInitialized());
487 }
488 bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
489 StringRef LLVMStructType::getName() const { return getImpl()->getIdentifier(); }
490 ArrayRef<Type> LLVMStructType::getBody() const {
491   return isIdentified() ? getImpl()->getIdentifiedStructBody()
492                         : getImpl()->getTypeList();
493 }
494 
495 LogicalResult
496 LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
497                                  bool) {
498   return success();
499 }
500 
501 LogicalResult
502 LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
503                                  ArrayRef<Type> types, bool) {
504   for (Type t : types)
505     if (!isValidElementType(t))
506       return emitError() << "invalid LLVM structure element type: " << t;
507 
508   return success();
509 }
510 
511 llvm::TypeSize
512 LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout,
513                                   DataLayoutEntryListRef params) const {
514   auto structSize = llvm::TypeSize::getFixed(0);
515   uint64_t structAlignment = 1;
516   for (Type element : getBody()) {
517     uint64_t elementAlignment =
518         isPacked() ? 1 : dataLayout.getTypeABIAlignment(element);
519     // Add padding to the struct size to align it to the abi alignment of the
520     // element type before than adding the size of the element.
521     structSize = llvm::alignTo(structSize, elementAlignment);
522     structSize += dataLayout.getTypeSize(element);
523 
524     // The alignment requirement of a struct is equal to the strictest alignment
525     // requirement of its elements.
526     structAlignment = std::max(elementAlignment, structAlignment);
527   }
528   // At the end, add padding to the struct to satisfy its own alignment
529   // requirement. Otherwise structs inside of arrays would be misaligned.
530   structSize = llvm::alignTo(structSize, structAlignment);
531   return structSize * kBitsInByte;
532 }
533 
namespace {
/// Positions of the components inside a struct data layout specification.
enum class StructDLEntryPos {
  /// ABI alignment, always present.
  Abi = 0,
  /// Preferred alignment, may be absent.
  Preferred = 1
};
} // namespace
537 
538 static std::optional<uint64_t>
539 getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type,
540                          StructDLEntryPos pos) {
541   const auto *currentEntry =
542       llvm::find_if(params, [](DataLayoutEntryInterface entry) {
543         return entry.isTypeEntry();
544       });
545   if (currentEntry == params.end())
546     return std::nullopt;
547 
548   auto attr = llvm::cast<DenseIntElementsAttr>(currentEntry->getValue());
549   if (pos == StructDLEntryPos::Preferred &&
550       attr.size() <= static_cast<int64_t>(StructDLEntryPos::Preferred))
551     // If no preferred was specified, fall back to abi alignment
552     pos = StructDLEntryPos::Abi;
553 
554   return attr.getValues<uint64_t>()[static_cast<size_t>(pos)];
555 }
556 
557 static uint64_t calculateStructAlignment(const DataLayout &dataLayout,
558                                          DataLayoutEntryListRef params,
559                                          LLVMStructType type,
560                                          StructDLEntryPos pos) {
561   // Packed structs always have an abi alignment of 1
562   if (pos == StructDLEntryPos::Abi && type.isPacked()) {
563     return 1;
564   }
565 
566   // The alignment requirement of a struct is equal to the strictest alignment
567   // requirement of its elements.
568   uint64_t structAlignment = 1;
569   for (Type iter : type.getBody()) {
570     structAlignment =
571         std::max(dataLayout.getTypeABIAlignment(iter), structAlignment);
572   }
573 
574   // Entries are only allowed to be stricter than the required alignment
575   if (std::optional<uint64_t> entryResult =
576           getStructDataLayoutEntry(params, type, pos))
577     return std::max(*entryResult / kBitsInByte, structAlignment);
578 
579   return structAlignment;
580 }
581 
582 uint64_t LLVMStructType::getABIAlignment(const DataLayout &dataLayout,
583                                          DataLayoutEntryListRef params) const {
584   return calculateStructAlignment(dataLayout, params, *this,
585                                   StructDLEntryPos::Abi);
586 }
587 
588 uint64_t
589 LLVMStructType::getPreferredAlignment(const DataLayout &dataLayout,
590                                       DataLayoutEntryListRef params) const {
591   return calculateStructAlignment(dataLayout, params, *this,
592                                   StructDLEntryPos::Preferred);
593 }
594 
595 static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos) {
596   return llvm::cast<DenseIntElementsAttr>(attr)
597       .getValues<uint64_t>()[static_cast<size_t>(pos)];
598 }
599 
600 bool LLVMStructType::areCompatible(DataLayoutEntryListRef oldLayout,
601                                    DataLayoutEntryListRef newLayout) const {
602   for (DataLayoutEntryInterface newEntry : newLayout) {
603     if (!newEntry.isTypeEntry())
604       continue;
605 
606     const auto *previousEntry =
607         llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) {
608           return entry.isTypeEntry();
609         });
610     if (previousEntry == oldLayout.end())
611       continue;
612 
613     uint64_t abi = extractStructSpecValue(previousEntry->getValue(),
614                                           StructDLEntryPos::Abi);
615     uint64_t newAbi =
616         extractStructSpecValue(newEntry.getValue(), StructDLEntryPos::Abi);
617     if (abi < newAbi || abi % newAbi != 0)
618       return false;
619   }
620   return true;
621 }
622 
623 LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
624                                             Location loc) const {
625   for (DataLayoutEntryInterface entry : entries) {
626     if (!entry.isTypeEntry())
627       continue;
628 
629     auto key = llvm::cast<LLVMStructType>(llvm::cast<Type>(entry.getKey()));
630     auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
631     if (!values || (values.size() != 2 && values.size() != 1)) {
632       return emitError(loc)
633              << "expected layout attribute for "
634              << llvm::cast<Type>(entry.getKey())
635              << " to be a dense integer elements attribute of 1 or 2 elements";
636     }
637     if (!values.getElementType().isInteger(64))
638       return emitError(loc) << "expected i64 entries for " << key;
639 
640     if (key.isIdentified() || !key.getBody().empty()) {
641       return emitError(loc) << "unexpected layout attribute for struct " << key;
642     }
643 
644     if (values.size() == 1)
645       continue;
646 
647     if (extractStructSpecValue(values, StructDLEntryPos::Abi) >
648         extractStructSpecValue(values, StructDLEntryPos::Preferred)) {
649       return emitError(loc) << "preferred alignment is expected to be at least "
650                                "as large as ABI alignment";
651     }
652   }
653   return mlir::success();
654 }
655 
656 //===----------------------------------------------------------------------===//
657 // Vector types.
658 //===----------------------------------------------------------------------===//
659 
660 /// Verifies that the type about to be constructed is well-formed.
661 template <typename VecTy>
662 static LogicalResult
663 verifyVectorConstructionInvariants(function_ref<InFlightDiagnostic()> emitError,
664                                    Type elementType, unsigned numElements) {
665   if (numElements == 0)
666     return emitError() << "the number of vector elements must be positive";
667 
668   if (!VecTy::isValidElementType(elementType))
669     return emitError() << "invalid vector element type";
670 
671   return success();
672 }
673 
674 LLVMFixedVectorType LLVMFixedVectorType::get(Type elementType,
675                                              unsigned numElements) {
676   assert(elementType && "expected non-null subtype");
677   return Base::get(elementType.getContext(), elementType, numElements);
678 }
679 
680 LLVMFixedVectorType
681 LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
682                                 Type elementType, unsigned numElements) {
683   assert(elementType && "expected non-null subtype");
684   return Base::getChecked(emitError, elementType.getContext(), elementType,
685                           numElements);
686 }
687 
688 bool LLVMFixedVectorType::isValidElementType(Type type) {
689   return llvm::isa<LLVMPointerType, LLVMPPCFP128Type>(type);
690 }
691 
692 LogicalResult
693 LLVMFixedVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
694                             Type elementType, unsigned numElements) {
695   return verifyVectorConstructionInvariants<LLVMFixedVectorType>(
696       emitError, elementType, numElements);
697 }
698 
699 //===----------------------------------------------------------------------===//
700 // LLVMScalableVectorType.
701 //===----------------------------------------------------------------------===//
702 
703 LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType,
704                                                    unsigned minNumElements) {
705   assert(elementType && "expected non-null subtype");
706   return Base::get(elementType.getContext(), elementType, minNumElements);
707 }
708 
709 LLVMScalableVectorType
710 LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
711                                    Type elementType, unsigned minNumElements) {
712   assert(elementType && "expected non-null subtype");
713   return Base::getChecked(emitError, elementType.getContext(), elementType,
714                           minNumElements);
715 }
716 
717 bool LLVMScalableVectorType::isValidElementType(Type type) {
718   if (auto intType = llvm::dyn_cast<IntegerType>(type))
719     return intType.isSignless();
720 
721   return isCompatibleFloatingPointType(type) ||
722          llvm::isa<LLVMPointerType>(type);
723 }
724 
725 LogicalResult
726 LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
727                                Type elementType, unsigned numElements) {
728   return verifyVectorConstructionInvariants<LLVMScalableVectorType>(
729       emitError, elementType, numElements);
730 }
731 
732 //===----------------------------------------------------------------------===//
733 // LLVMTargetExtType.
734 //===----------------------------------------------------------------------===//
735 
736 static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
737 static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
738 
739 bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
740   // See llvm/lib/IR/Type.cpp for reference.
741   uint64_t properties = 0;
742 
743   if (getExtTypeName().starts_with(kSpirvPrefix))
744     properties |=
745         (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal);
746 
747   return (properties & prop) == prop;
748 }
749 
750 bool LLVM::LLVMTargetExtType::supportsMemOps() const {
751   // See llvm/lib/IR/Type.cpp for reference.
752   if (getExtTypeName().starts_with(kSpirvPrefix))
753     return true;
754 
755   if (getExtTypeName() == kArmSVCount)
756     return true;
757 
758   return false;
759 }
760 
761 //===----------------------------------------------------------------------===//
762 // Utility functions.
763 //===----------------------------------------------------------------------===//
764 
765 bool mlir::LLVM::isCompatibleOuterType(Type type) {
766   // clang-format off
767   if (llvm::isa<
768       BFloat16Type,
769       Float16Type,
770       Float32Type,
771       Float64Type,
772       Float80Type,
773       Float128Type,
774       LLVMArrayType,
775       LLVMFunctionType,
776       LLVMLabelType,
777       LLVMMetadataType,
778       LLVMPPCFP128Type,
779       LLVMPointerType,
780       LLVMStructType,
781       LLVMTokenType,
782       LLVMFixedVectorType,
783       LLVMScalableVectorType,
784       LLVMTargetExtType,
785       LLVMVoidType,
786       LLVMX86AMXType
787     >(type)) {
788     // clang-format on
789     return true;
790   }
791 
792   // Only signless integers are compatible.
793   if (auto intType = llvm::dyn_cast<IntegerType>(type))
794     return intType.isSignless();
795 
796   // 1D vector types are compatible.
797   if (auto vecType = llvm::dyn_cast<VectorType>(type))
798     return vecType.getRank() == 1;
799 
800   return false;
801 }
802 
803 static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
804   if (!compatibleTypes.insert(type).second)
805     return true;
806 
807   auto isCompatible = [&](Type type) {
808     return isCompatibleImpl(type, compatibleTypes);
809   };
810 
811   bool result =
812       llvm::TypeSwitch<Type, bool>(type)
813           .Case<LLVMStructType>([&](auto structType) {
814             return llvm::all_of(structType.getBody(), isCompatible);
815           })
816           .Case<LLVMFunctionType>([&](auto funcType) {
817             return isCompatible(funcType.getReturnType()) &&
818                    llvm::all_of(funcType.getParams(), isCompatible);
819           })
820           .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
821           .Case<VectorType>([&](auto vecType) {
822             return vecType.getRank() == 1 &&
823                    isCompatible(vecType.getElementType());
824           })
825           .Case<LLVMPointerType>([&](auto pointerType) { return true; })
826           .Case<LLVMTargetExtType>([&](auto extType) {
827             return llvm::all_of(extType.getTypeParams(), isCompatible);
828           })
829           // clang-format off
830           .Case<
831               LLVMFixedVectorType,
832               LLVMScalableVectorType,
833               LLVMArrayType
834           >([&](auto containerType) {
835             return isCompatible(containerType.getElementType());
836           })
837           .Case<
838             BFloat16Type,
839             Float16Type,
840             Float32Type,
841             Float64Type,
842             Float80Type,
843             Float128Type,
844             LLVMLabelType,
845             LLVMMetadataType,
846             LLVMPPCFP128Type,
847             LLVMTokenType,
848             LLVMVoidType,
849             LLVMX86AMXType
850           >([](Type) { return true; })
851           // clang-format on
852           .Default([](Type) { return false; });
853 
854   if (!result)
855     compatibleTypes.erase(type);
856 
857   return result;
858 }
859 
860 bool LLVMDialect::isCompatibleType(Type type) {
861   if (auto *llvmDialect =
862           type.getContext()->getLoadedDialect<LLVM::LLVMDialect>())
863     return isCompatibleImpl(type, llvmDialect->compatibleTypes.get());
864 
865   DenseSet<Type> localCompatibleTypes;
866   return isCompatibleImpl(type, localCompatibleTypes);
867 }
868 
/// Free-function form of the compatibility query; forwards to
/// `LLVMDialect::isCompatibleType`.
bool mlir::LLVM::isCompatibleType(Type type) {
  return LLVMDialect::isCompatibleType(type);
}
872 
873 bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
874   return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
875                    Float80Type, Float128Type, LLVMPPCFP128Type>(type);
876 }
877 
878 bool mlir::LLVM::isCompatibleVectorType(Type type) {
879   if (llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType>(type))
880     return true;
881 
882   if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
883     if (vecType.getRank() != 1)
884       return false;
885     Type elementType = vecType.getElementType();
886     if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
887       return intType.isSignless();
888     return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
889                      Float80Type, Float128Type>(elementType);
890   }
891   return false;
892 }
893 
894 Type mlir::LLVM::getVectorElementType(Type type) {
895   return llvm::TypeSwitch<Type, Type>(type)
896       .Case<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>(
897           [](auto ty) { return ty.getElementType(); })
898       .Default([](Type) -> Type {
899         llvm_unreachable("incompatible with LLVM vector type");
900       });
901 }
902 
903 llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
904   return llvm::TypeSwitch<Type, llvm::ElementCount>(type)
905       .Case([](VectorType ty) {
906         if (ty.isScalable())
907           return llvm::ElementCount::getScalable(ty.getNumElements());
908         return llvm::ElementCount::getFixed(ty.getNumElements());
909       })
910       .Case([](LLVMFixedVectorType ty) {
911         return llvm::ElementCount::getFixed(ty.getNumElements());
912       })
913       .Case([](LLVMScalableVectorType ty) {
914         return llvm::ElementCount::getScalable(ty.getMinNumElements());
915       })
916       .Default([](Type) -> llvm::ElementCount {
917         llvm_unreachable("incompatible with LLVM vector type");
918       });
919 }
920 
921 bool mlir::LLVM::isScalableVectorType(Type vectorType) {
922   assert((llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>(
923              vectorType)) &&
924          "expected LLVM-compatible vector type");
925   return !llvm::isa<LLVMFixedVectorType>(vectorType) &&
926          (llvm::isa<LLVMScalableVectorType>(vectorType) ||
927           llvm::cast<VectorType>(vectorType).isScalable());
928 }
929 
930 Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
931                                bool isScalable) {
932   bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
933   bool useBuiltIn = VectorType::isValidElementType(elementType);
934   (void)useBuiltIn;
935   assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
936                                    "to be either builtin or LLVM dialect type");
937   if (useLLVM) {
938     if (isScalable)
939       return LLVMScalableVectorType::get(elementType, numElements);
940     return LLVMFixedVectorType::get(elementType, numElements);
941   }
942 
943   // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
944   // scalable/non-scalable.
945   return VectorType::get(numElements, elementType, {isScalable});
946 }
947 
948 Type mlir::LLVM::getVectorType(Type elementType,
949                                const llvm::ElementCount &numElements) {
950   if (numElements.isScalable())
951     return getVectorType(elementType, numElements.getKnownMinValue(),
952                          /*isScalable=*/true);
953   return getVectorType(elementType, numElements.getFixedValue(),
954                        /*isScalable=*/false);
955 }
956 
957 Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
958   bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
959   bool useBuiltIn = VectorType::isValidElementType(elementType);
960   (void)useBuiltIn;
961   assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
962                                    "to be either builtin or LLVM dialect type");
963   if (useLLVM)
964     return LLVMFixedVectorType::get(elementType, numElements);
965   return VectorType::get(numElements, elementType);
966 }
967 
968 Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
969   bool useLLVM = LLVMScalableVectorType::isValidElementType(elementType);
970   bool useBuiltIn = VectorType::isValidElementType(elementType);
971   (void)useBuiltIn;
972   assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible scalable-vector "
973                                    "type to be either builtin or LLVM dialect "
974                                    "type");
975   if (useLLVM)
976     return LLVMScalableVectorType::get(elementType, numElements);
977 
978   // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
979   // scalable/non-scalable.
980   return VectorType::get(numElements, elementType, /*scalableDims=*/true);
981 }
982 
983 llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
984   assert(isCompatibleType(type) &&
985          "expected a type compatible with the LLVM dialect");
986 
987   return llvm::TypeSwitch<Type, llvm::TypeSize>(type)
988       .Case<BFloat16Type, Float16Type>(
989           [](Type) { return llvm::TypeSize::getFixed(16); })
990       .Case<Float32Type>([](Type) { return llvm::TypeSize::getFixed(32); })
991       .Case<Float64Type>([](Type) { return llvm::TypeSize::getFixed(64); })
992       .Case<Float80Type>([](Type) { return llvm::TypeSize::getFixed(80); })
993       .Case<Float128Type>([](Type) { return llvm::TypeSize::getFixed(128); })
994       .Case<IntegerType>([](IntegerType intTy) {
995         return llvm::TypeSize::getFixed(intTy.getWidth());
996       })
997       .Case<LLVMPPCFP128Type>(
998           [](Type) { return llvm::TypeSize::getFixed(128); })
999       .Case<LLVMFixedVectorType>([](LLVMFixedVectorType t) {
1000         llvm::TypeSize elementSize =
1001             getPrimitiveTypeSizeInBits(t.getElementType());
1002         return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(),
1003                               elementSize.isScalable());
1004       })
1005       .Case<VectorType>([](VectorType t) {
1006         assert(isCompatibleVectorType(t) &&
1007                "unexpected incompatible with LLVM vector type");
1008         llvm::TypeSize elementSize =
1009             getPrimitiveTypeSizeInBits(t.getElementType());
1010         return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(),
1011                               elementSize.isScalable());
1012       })
1013       .Default([](Type ty) {
1014         assert((llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
1015                           LLVMTokenType, LLVMStructType, LLVMArrayType,
1016                           LLVMPointerType, LLVMFunctionType, LLVMTargetExtType>(
1017                    ty)) &&
1018                "unexpected missing support for primitive type");
1019         return llvm::TypeSize::getFixed(0);
1020       });
1021 }
1022 
1023 //===----------------------------------------------------------------------===//
1024 // LLVMDialect
1025 //===----------------------------------------------------------------------===//
1026 
/// Registers with the dialect every type named in the type list that
/// `LLVMTypes.cpp.inc` expands to when `GET_TYPEDEF_LIST` is defined.
void LLVMDialect::registerTypes() {
  addTypes<
// The included file supplies the comma-separated template argument list.
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
      >();
}
1033 
/// Dialect hook for parsing a type; forwards to `detail::parseType` and
/// returns its result.
Type LLVMDialect::parseType(DialectAsmParser &parser) const {
  return detail::parseType(parser);
}
1037 
1038 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
1039   return detail::printType(type, os);
1040 }
1041