1 //===- LLVMTypes.cpp - MLIR LLVM dialect types ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the types for the LLVM dialect in MLIR. These MLIR types 10 // correspond to the LLVM IR type system. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TypeDetail.h" 15 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/TypeSupport.h" 21 22 #include "llvm/ADT/ScopeExit.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/Support/TypeSize.h" 25 #include <optional> 26 27 using namespace mlir; 28 using namespace mlir::LLVM; 29 30 constexpr const static uint64_t kBitsInByte = 8; 31 32 //===----------------------------------------------------------------------===// 33 // custom<FunctionTypes> 34 //===----------------------------------------------------------------------===// 35 36 static ParseResult parseFunctionTypes(AsmParser &p, SmallVector<Type> ¶ms, 37 bool &isVarArg) { 38 isVarArg = false; 39 // `(` `)` 40 if (succeeded(p.parseOptionalRParen())) 41 return success(); 42 43 // `(` `...` `)` 44 if (succeeded(p.parseOptionalEllipsis())) { 45 isVarArg = true; 46 return p.parseRParen(); 47 } 48 49 // type (`,` type)* (`,` `...`)? 
50 Type type; 51 if (parsePrettyLLVMType(p, type)) 52 return failure(); 53 params.push_back(type); 54 while (succeeded(p.parseOptionalComma())) { 55 if (succeeded(p.parseOptionalEllipsis())) { 56 isVarArg = true; 57 return p.parseRParen(); 58 } 59 if (parsePrettyLLVMType(p, type)) 60 return failure(); 61 params.push_back(type); 62 } 63 return p.parseRParen(); 64 } 65 66 static void printFunctionTypes(AsmPrinter &p, ArrayRef<Type> params, 67 bool isVarArg) { 68 llvm::interleaveComma(params, p, 69 [&](Type type) { printPrettyLLVMType(p, type); }); 70 if (isVarArg) { 71 if (!params.empty()) 72 p << ", "; 73 p << "..."; 74 } 75 p << ')'; 76 } 77 78 //===----------------------------------------------------------------------===// 79 // custom<ExtTypeParams> 80 //===----------------------------------------------------------------------===// 81 82 /// Parses the parameter list for a target extension type. The parameter list 83 /// contains an optional list of type parameters, followed by an optional list 84 /// of integer parameters. Type and integer parameters cannot be interleaved in 85 /// the list. 86 /// extTypeParams ::= typeList? | intList? | (typeList "," intList) 87 /// typeList ::= type ("," type)* 88 /// intList ::= integer ("," integer)* 89 static ParseResult 90 parseExtTypeParams(AsmParser &p, SmallVectorImpl<Type> &typeParams, 91 SmallVectorImpl<unsigned int> &intParams) { 92 bool parseType = true; 93 auto typeOrIntParser = [&]() -> ParseResult { 94 unsigned int i; 95 auto intResult = p.parseOptionalInteger(i); 96 if (intResult.has_value() && !failed(*intResult)) { 97 // Successfully parsed an integer. 98 intParams.push_back(i); 99 // After the first integer was successfully parsed, no 100 // more types can be parsed. 101 parseType = false; 102 return success(); 103 } 104 if (parseType) { 105 Type t; 106 if (!parsePrettyLLVMType(p, t)) { 107 // Successfully parsed a type. 
108 typeParams.push_back(t); 109 return success(); 110 } 111 } 112 return failure(); 113 }; 114 if (p.parseCommaSeparatedList(typeOrIntParser)) { 115 p.emitError(p.getCurrentLocation(), 116 "failed to parse parameter list for target extension type"); 117 return failure(); 118 } 119 return success(); 120 } 121 122 static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams, 123 ArrayRef<unsigned int> intParams) { 124 p << typeParams; 125 if (!typeParams.empty() && !intParams.empty()) 126 p << ", "; 127 128 p << intParams; 129 } 130 131 //===----------------------------------------------------------------------===// 132 // ODS-Generated Definitions 133 //===----------------------------------------------------------------------===// 134 135 /// These are unused for now. 136 /// TODO: Move over to these once more types have been migrated to TypeDef. 137 LLVM_ATTRIBUTE_UNUSED static OptionalParseResult 138 generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); 139 LLVM_ATTRIBUTE_UNUSED static LogicalResult 140 generatedTypePrinter(Type def, AsmPrinter &printer); 141 142 #include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc" 143 144 #define GET_TYPEDEF_CLASSES 145 #include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc" 146 147 //===----------------------------------------------------------------------===// 148 // LLVMArrayType 149 //===----------------------------------------------------------------------===// 150 151 bool LLVMArrayType::isValidElementType(Type type) { 152 return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType, 153 LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>( 154 type); 155 } 156 157 LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) { 158 assert(elementType && "expected non-null subtype"); 159 return Base::get(elementType.getContext(), elementType, numElements); 160 } 161 162 LLVMArrayType 163 LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError, 164 Type 
elementType, uint64_t numElements) { 165 assert(elementType && "expected non-null subtype"); 166 return Base::getChecked(emitError, elementType.getContext(), elementType, 167 numElements); 168 } 169 170 LogicalResult 171 LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError, 172 Type elementType, uint64_t numElements) { 173 if (!isValidElementType(elementType)) 174 return emitError() << "invalid array element type: " << elementType; 175 return success(); 176 } 177 178 //===----------------------------------------------------------------------===// 179 // DataLayoutTypeInterface 180 181 llvm::TypeSize 182 LLVMArrayType::getTypeSizeInBits(const DataLayout &dataLayout, 183 DataLayoutEntryListRef params) const { 184 return llvm::TypeSize::getFixed(kBitsInByte * 185 getTypeSize(dataLayout, params)); 186 } 187 188 llvm::TypeSize LLVMArrayType::getTypeSize(const DataLayout &dataLayout, 189 DataLayoutEntryListRef params) const { 190 return llvm::alignTo(dataLayout.getTypeSize(getElementType()), 191 dataLayout.getTypeABIAlignment(getElementType())) * 192 getNumElements(); 193 } 194 195 uint64_t LLVMArrayType::getABIAlignment(const DataLayout &dataLayout, 196 DataLayoutEntryListRef params) const { 197 return dataLayout.getTypeABIAlignment(getElementType()); 198 } 199 200 uint64_t 201 LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout, 202 DataLayoutEntryListRef params) const { 203 return dataLayout.getTypePreferredAlignment(getElementType()); 204 } 205 206 //===----------------------------------------------------------------------===// 207 // Function type. 
208 //===----------------------------------------------------------------------===// 209 210 bool LLVMFunctionType::isValidArgumentType(Type type) { 211 return !llvm::isa<LLVMVoidType, LLVMFunctionType>(type); 212 } 213 214 bool LLVMFunctionType::isValidResultType(Type type) { 215 return !llvm::isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>(type); 216 } 217 218 LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments, 219 bool isVarArg) { 220 assert(result && "expected non-null result"); 221 return Base::get(result.getContext(), result, arguments, isVarArg); 222 } 223 224 LLVMFunctionType 225 LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError, 226 Type result, ArrayRef<Type> arguments, 227 bool isVarArg) { 228 assert(result && "expected non-null result"); 229 return Base::getChecked(emitError, result.getContext(), result, arguments, 230 isVarArg); 231 } 232 233 LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs, 234 TypeRange results) const { 235 assert(results.size() == 1 && "expected a single result type"); 236 return get(results[0], llvm::to_vector(inputs), isVarArg()); 237 } 238 239 ArrayRef<Type> LLVMFunctionType::getReturnTypes() const { 240 return static_cast<detail::LLVMFunctionTypeStorage *>(getImpl())->returnType; 241 } 242 243 LogicalResult 244 LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError, 245 Type result, ArrayRef<Type> arguments, bool) { 246 if (!isValidResultType(result)) 247 return emitError() << "invalid function result type: " << result; 248 249 for (Type arg : arguments) 250 if (!isValidArgumentType(arg)) 251 return emitError() << "invalid function argument type: " << arg; 252 253 return success(); 254 } 255 256 //===----------------------------------------------------------------------===// 257 // DataLayoutTypeInterface 258 259 constexpr const static uint64_t kDefaultPointerSizeBits = 64; 260 constexpr const static uint64_t kDefaultPointerAlignment = 8; 261 262 
std::optional<uint64_t> mlir::LLVM::extractPointerSpecValue(Attribute attr, 263 PtrDLEntryPos pos) { 264 auto spec = cast<DenseIntElementsAttr>(attr); 265 auto idx = static_cast<int64_t>(pos); 266 if (idx >= spec.size()) 267 return std::nullopt; 268 return spec.getValues<uint64_t>()[idx]; 269 } 270 271 /// Returns the part of the data layout entry that corresponds to `pos` for the 272 /// given `type` by interpreting the list of entries `params`. For the pointer 273 /// type in the default address space, returns the default value if the entries 274 /// do not provide a custom one, for other address spaces returns std::nullopt. 275 static std::optional<uint64_t> 276 getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type, 277 PtrDLEntryPos pos) { 278 // First, look for the entry for the pointer in the current address space. 279 Attribute currentEntry; 280 for (DataLayoutEntryInterface entry : params) { 281 if (!entry.isTypeEntry()) 282 continue; 283 if (cast<LLVMPointerType>(cast<Type>(entry.getKey())).getAddressSpace() == 284 type.getAddressSpace()) { 285 currentEntry = entry.getValue(); 286 break; 287 } 288 } 289 if (currentEntry) { 290 std::optional<uint64_t> value = extractPointerSpecValue(currentEntry, pos); 291 // If the optional `PtrDLEntryPos::Index` entry is not available, use the 292 // pointer size as the index bitwidth. 293 if (!value && pos == PtrDLEntryPos::Index) 294 value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size); 295 bool isSizeOrIndex = 296 pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index; 297 return *value / (isSizeOrIndex ? 1 : kBitsInByte); 298 } 299 300 // If not found, and this is the pointer to the default memory space, assume 301 // 64-bit pointers. 302 if (type.getAddressSpace() == 0) { 303 bool isSizeOrIndex = 304 pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index; 305 return isSizeOrIndex ? 
kDefaultPointerSizeBits : kDefaultPointerAlignment; 306 } 307 308 return std::nullopt; 309 } 310 311 llvm::TypeSize 312 LLVMPointerType::getTypeSizeInBits(const DataLayout &dataLayout, 313 DataLayoutEntryListRef params) const { 314 if (std::optional<uint64_t> size = 315 getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Size)) 316 return llvm::TypeSize::getFixed(*size); 317 318 // For other memory spaces, use the size of the pointer to the default memory 319 // space. 320 return dataLayout.getTypeSizeInBits(get(getContext())); 321 } 322 323 uint64_t LLVMPointerType::getABIAlignment(const DataLayout &dataLayout, 324 DataLayoutEntryListRef params) const { 325 if (std::optional<uint64_t> alignment = 326 getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Abi)) 327 return *alignment; 328 329 return dataLayout.getTypeABIAlignment(get(getContext())); 330 } 331 332 uint64_t 333 LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout, 334 DataLayoutEntryListRef params) const { 335 if (std::optional<uint64_t> alignment = 336 getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Preferred)) 337 return *alignment; 338 339 return dataLayout.getTypePreferredAlignment(get(getContext())); 340 } 341 342 std::optional<uint64_t> 343 LLVMPointerType::getIndexBitwidth(const DataLayout &dataLayout, 344 DataLayoutEntryListRef params) const { 345 if (std::optional<uint64_t> indexBitwidth = 346 getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Index)) 347 return *indexBitwidth; 348 349 return dataLayout.getTypeIndexBitwidth(get(getContext())); 350 } 351 352 bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout, 353 DataLayoutEntryListRef newLayout) const { 354 for (DataLayoutEntryInterface newEntry : newLayout) { 355 if (!newEntry.isTypeEntry()) 356 continue; 357 uint64_t size = kDefaultPointerSizeBits; 358 uint64_t abi = kDefaultPointerAlignment; 359 auto newType = 360 llvm::cast<LLVMPointerType>(llvm::cast<Type>(newEntry.getKey())); 361 const 
auto *it = 362 llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { 363 if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) { 364 return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 365 newType.getAddressSpace(); 366 } 367 return false; 368 }); 369 if (it == oldLayout.end()) { 370 llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { 371 if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) { 372 return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0; 373 } 374 return false; 375 }); 376 } 377 if (it != oldLayout.end()) { 378 size = *extractPointerSpecValue(*it, PtrDLEntryPos::Size); 379 abi = *extractPointerSpecValue(*it, PtrDLEntryPos::Abi); 380 } 381 382 Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue()); 383 uint64_t newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size); 384 uint64_t newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi); 385 if (size != newSize || abi < newAbi || abi % newAbi != 0) 386 return false; 387 } 388 return true; 389 } 390 391 LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries, 392 Location loc) const { 393 for (DataLayoutEntryInterface entry : entries) { 394 if (!entry.isTypeEntry()) 395 continue; 396 auto key = llvm::cast<Type>(entry.getKey()); 397 auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue()); 398 if (!values || (values.size() != 3 && values.size() != 4)) { 399 return emitError(loc) 400 << "expected layout attribute for " << key 401 << " to be a dense integer elements attribute with 3 or 4 " 402 "elements"; 403 } 404 if (!values.getElementType().isInteger(64)) 405 return emitError(loc) << "expected i64 parameters for " << key; 406 407 if (extractPointerSpecValue(values, PtrDLEntryPos::Abi) > 408 extractPointerSpecValue(values, PtrDLEntryPos::Preferred)) { 409 return emitError(loc) << "preferred alignment is expected to be at least " 410 "as large as ABI alignment"; 411 } 412 } 413 
return success(); 414 } 415 416 //===----------------------------------------------------------------------===// 417 // Struct type. 418 //===----------------------------------------------------------------------===// 419 420 bool LLVMStructType::isValidElementType(Type type) { 421 return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType, 422 LLVMFunctionType, LLVMTokenType>(type); 423 } 424 425 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context, 426 StringRef name) { 427 return Base::get(context, name, /*opaque=*/false); 428 } 429 430 LLVMStructType LLVMStructType::getIdentifiedChecked( 431 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context, 432 StringRef name) { 433 return Base::getChecked(emitError, context, name, /*opaque=*/false); 434 } 435 436 LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context, 437 StringRef name, 438 ArrayRef<Type> elements, 439 bool isPacked) { 440 std::string stringName = name.str(); 441 unsigned counter = 0; 442 do { 443 auto type = LLVMStructType::getIdentified(context, stringName); 444 if (type.isInitialized() || failed(type.setBody(elements, isPacked))) { 445 counter += 1; 446 stringName = (Twine(name) + "." 
+ std::to_string(counter)).str(); 447 continue; 448 } 449 return type; 450 } while (true); 451 } 452 453 LLVMStructType LLVMStructType::getLiteral(MLIRContext *context, 454 ArrayRef<Type> types, bool isPacked) { 455 return Base::get(context, types, isPacked); 456 } 457 458 LLVMStructType 459 LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError, 460 MLIRContext *context, ArrayRef<Type> types, 461 bool isPacked) { 462 return Base::getChecked(emitError, context, types, isPacked); 463 } 464 465 LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) { 466 return Base::get(context, name, /*opaque=*/true); 467 } 468 469 LLVMStructType 470 LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError, 471 MLIRContext *context, StringRef name) { 472 return Base::getChecked(emitError, context, name, /*opaque=*/true); 473 } 474 475 LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) { 476 assert(isIdentified() && "can only set bodies of identified structs"); 477 assert(llvm::all_of(types, LLVMStructType::isValidElementType) && 478 "expected valid body types"); 479 return Base::mutate(types, isPacked); 480 } 481 482 bool LLVMStructType::isPacked() const { return getImpl()->isPacked(); } 483 bool LLVMStructType::isIdentified() const { return getImpl()->isIdentified(); } 484 bool LLVMStructType::isOpaque() const { 485 return getImpl()->isIdentified() && 486 (getImpl()->isOpaque() || !getImpl()->isInitialized()); 487 } 488 bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); } 489 StringRef LLVMStructType::getName() const { return getImpl()->getIdentifier(); } 490 ArrayRef<Type> LLVMStructType::getBody() const { 491 return isIdentified() ? 
getImpl()->getIdentifiedStructBody() 492 : getImpl()->getTypeList(); 493 } 494 495 LogicalResult 496 LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef, 497 bool) { 498 return success(); 499 } 500 501 LogicalResult 502 LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError, 503 ArrayRef<Type> types, bool) { 504 for (Type t : types) 505 if (!isValidElementType(t)) 506 return emitError() << "invalid LLVM structure element type: " << t; 507 508 return success(); 509 } 510 511 llvm::TypeSize 512 LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout, 513 DataLayoutEntryListRef params) const { 514 auto structSize = llvm::TypeSize::getFixed(0); 515 uint64_t structAlignment = 1; 516 for (Type element : getBody()) { 517 uint64_t elementAlignment = 518 isPacked() ? 1 : dataLayout.getTypeABIAlignment(element); 519 // Add padding to the struct size to align it to the abi alignment of the 520 // element type before than adding the size of the element. 521 structSize = llvm::alignTo(structSize, elementAlignment); 522 structSize += dataLayout.getTypeSize(element); 523 524 // The alignment requirement of a struct is equal to the strictest alignment 525 // requirement of its elements. 526 structAlignment = std::max(elementAlignment, structAlignment); 527 } 528 // At the end, add padding to the struct to satisfy its own alignment 529 // requirement. Otherwise structs inside of arrays would be misaligned. 
530 structSize = llvm::alignTo(structSize, structAlignment); 531 return structSize * kBitsInByte; 532 } 533 534 namespace { 535 enum class StructDLEntryPos { Abi = 0, Preferred = 1 }; 536 } // namespace 537 538 static std::optional<uint64_t> 539 getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type, 540 StructDLEntryPos pos) { 541 const auto *currentEntry = 542 llvm::find_if(params, [](DataLayoutEntryInterface entry) { 543 return entry.isTypeEntry(); 544 }); 545 if (currentEntry == params.end()) 546 return std::nullopt; 547 548 auto attr = llvm::cast<DenseIntElementsAttr>(currentEntry->getValue()); 549 if (pos == StructDLEntryPos::Preferred && 550 attr.size() <= static_cast<int64_t>(StructDLEntryPos::Preferred)) 551 // If no preferred was specified, fall back to abi alignment 552 pos = StructDLEntryPos::Abi; 553 554 return attr.getValues<uint64_t>()[static_cast<size_t>(pos)]; 555 } 556 557 static uint64_t calculateStructAlignment(const DataLayout &dataLayout, 558 DataLayoutEntryListRef params, 559 LLVMStructType type, 560 StructDLEntryPos pos) { 561 // Packed structs always have an abi alignment of 1 562 if (pos == StructDLEntryPos::Abi && type.isPacked()) { 563 return 1; 564 } 565 566 // The alignment requirement of a struct is equal to the strictest alignment 567 // requirement of its elements. 
568 uint64_t structAlignment = 1; 569 for (Type iter : type.getBody()) { 570 structAlignment = 571 std::max(dataLayout.getTypeABIAlignment(iter), structAlignment); 572 } 573 574 // Entries are only allowed to be stricter than the required alignment 575 if (std::optional<uint64_t> entryResult = 576 getStructDataLayoutEntry(params, type, pos)) 577 return std::max(*entryResult / kBitsInByte, structAlignment); 578 579 return structAlignment; 580 } 581 582 uint64_t LLVMStructType::getABIAlignment(const DataLayout &dataLayout, 583 DataLayoutEntryListRef params) const { 584 return calculateStructAlignment(dataLayout, params, *this, 585 StructDLEntryPos::Abi); 586 } 587 588 uint64_t 589 LLVMStructType::getPreferredAlignment(const DataLayout &dataLayout, 590 DataLayoutEntryListRef params) const { 591 return calculateStructAlignment(dataLayout, params, *this, 592 StructDLEntryPos::Preferred); 593 } 594 595 static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos) { 596 return llvm::cast<DenseIntElementsAttr>(attr) 597 .getValues<uint64_t>()[static_cast<size_t>(pos)]; 598 } 599 600 bool LLVMStructType::areCompatible(DataLayoutEntryListRef oldLayout, 601 DataLayoutEntryListRef newLayout) const { 602 for (DataLayoutEntryInterface newEntry : newLayout) { 603 if (!newEntry.isTypeEntry()) 604 continue; 605 606 const auto *previousEntry = 607 llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) { 608 return entry.isTypeEntry(); 609 }); 610 if (previousEntry == oldLayout.end()) 611 continue; 612 613 uint64_t abi = extractStructSpecValue(previousEntry->getValue(), 614 StructDLEntryPos::Abi); 615 uint64_t newAbi = 616 extractStructSpecValue(newEntry.getValue(), StructDLEntryPos::Abi); 617 if (abi < newAbi || abi % newAbi != 0) 618 return false; 619 } 620 return true; 621 } 622 623 LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries, 624 Location loc) const { 625 for (DataLayoutEntryInterface entry : entries) { 626 if 
(!entry.isTypeEntry()) 627 continue; 628 629 auto key = llvm::cast<LLVMStructType>(llvm::cast<Type>(entry.getKey())); 630 auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue()); 631 if (!values || (values.size() != 2 && values.size() != 1)) { 632 return emitError(loc) 633 << "expected layout attribute for " 634 << llvm::cast<Type>(entry.getKey()) 635 << " to be a dense integer elements attribute of 1 or 2 elements"; 636 } 637 if (!values.getElementType().isInteger(64)) 638 return emitError(loc) << "expected i64 entries for " << key; 639 640 if (key.isIdentified() || !key.getBody().empty()) { 641 return emitError(loc) << "unexpected layout attribute for struct " << key; 642 } 643 644 if (values.size() == 1) 645 continue; 646 647 if (extractStructSpecValue(values, StructDLEntryPos::Abi) > 648 extractStructSpecValue(values, StructDLEntryPos::Preferred)) { 649 return emitError(loc) << "preferred alignment is expected to be at least " 650 "as large as ABI alignment"; 651 } 652 } 653 return mlir::success(); 654 } 655 656 //===----------------------------------------------------------------------===// 657 // Vector types. 658 //===----------------------------------------------------------------------===// 659 660 /// Verifies that the type about to be constructed is well-formed. 
661 template <typename VecTy> 662 static LogicalResult 663 verifyVectorConstructionInvariants(function_ref<InFlightDiagnostic()> emitError, 664 Type elementType, unsigned numElements) { 665 if (numElements == 0) 666 return emitError() << "the number of vector elements must be positive"; 667 668 if (!VecTy::isValidElementType(elementType)) 669 return emitError() << "invalid vector element type"; 670 671 return success(); 672 } 673 674 LLVMFixedVectorType LLVMFixedVectorType::get(Type elementType, 675 unsigned numElements) { 676 assert(elementType && "expected non-null subtype"); 677 return Base::get(elementType.getContext(), elementType, numElements); 678 } 679 680 LLVMFixedVectorType 681 LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError, 682 Type elementType, unsigned numElements) { 683 assert(elementType && "expected non-null subtype"); 684 return Base::getChecked(emitError, elementType.getContext(), elementType, 685 numElements); 686 } 687 688 bool LLVMFixedVectorType::isValidElementType(Type type) { 689 return llvm::isa<LLVMPointerType, LLVMPPCFP128Type>(type); 690 } 691 692 LogicalResult 693 LLVMFixedVectorType::verify(function_ref<InFlightDiagnostic()> emitError, 694 Type elementType, unsigned numElements) { 695 return verifyVectorConstructionInvariants<LLVMFixedVectorType>( 696 emitError, elementType, numElements); 697 } 698 699 //===----------------------------------------------------------------------===// 700 // LLVMScalableVectorType. 
//===----------------------------------------------------------------------===//

/// Builds a scalable vector with at least `minNumElements` elements; the
/// context is taken from `elementType`, which must be non-null.
LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType,
                                                   unsigned minNumElements) {
  assert(elementType && "expected non-null subtype");
  return Base::get(elementType.getContext(), elementType, minNumElements);
}

/// Like `get`, but reports verification failures through `emitError`.
LLVMScalableVectorType
LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                                   Type elementType, unsigned minNumElements) {
  assert(elementType && "expected non-null subtype");
  return Base::getChecked(emitError, elementType.getContext(), elementType,
                          minNumElements);
}

/// Scalable vectors accept signless integers, LLVM-compatible floating-point
/// types, and LLVM pointers as elements.
bool LLVMScalableVectorType::isValidElementType(Type type) {
  if (auto intType = llvm::dyn_cast<IntegerType>(type))
    return intType.isSignless();

  return isCompatibleFloatingPointType(type) ||
         llvm::isa<LLVMPointerType>(type);
}

/// Delegates to the shared vector verifier (non-zero count, valid element).
LogicalResult
LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                               Type elementType, unsigned numElements) {
  return verifyVectorConstructionInvariants<LLVMScalableVectorType>(
      emitError, elementType, numElements);
}

//===----------------------------------------------------------------------===//
// LLVMTargetExtType.
//===----------------------------------------------------------------------===//

// Target extension type names that the queries below treat specially.
static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";

/// Returns true if all bits of `prop` are set for this type. Only `spirv.*`
/// types currently report properties (HasZeroInit and CanBeGlobal).
bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
  // See llvm/lib/IR/Type.cpp for reference.
  uint64_t properties = 0;

  if (getExtTypeName().starts_with(kSpirvPrefix))
    properties |=
        (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal);

  return (properties & prop) == prop;
}

/// Returns true for `spirv.*` types and for `aarch64.svcount`.
bool LLVM::LLVMTargetExtType::supportsMemOps() const {
  // See llvm/lib/IR/Type.cpp for reference.
  if (getExtTypeName().starts_with(kSpirvPrefix))
    return true;

  if (getExtTypeName() == kArmSVCount)
    return true;

  return false;
}

//===----------------------------------------------------------------------===//
// Utility functions.
//===----------------------------------------------------------------------===//

/// Shallow compatibility check: only inspects `type` itself, not the types
/// nested inside it (see `isCompatibleImpl` for the recursive check).
bool mlir::LLVM::isCompatibleOuterType(Type type) {
  // clang-format off
  if (llvm::isa<
      BFloat16Type,
      Float16Type,
      Float32Type,
      Float64Type,
      Float80Type,
      Float128Type,
      LLVMArrayType,
      LLVMFunctionType,
      LLVMLabelType,
      LLVMMetadataType,
      LLVMPPCFP128Type,
      LLVMPointerType,
      LLVMStructType,
      LLVMTokenType,
      LLVMFixedVectorType,
      LLVMScalableVectorType,
      LLVMTargetExtType,
      LLVMVoidType,
      LLVMX86AMXType
    >(type)) {
    // clang-format on
    return true;
  }

  // Only signless integers are compatible.
  if (auto intType = llvm::dyn_cast<IntegerType>(type))
    return intType.isSignless();

  // 1D vector types are compatible.
  if (auto vecType = llvm::dyn_cast<VectorType>(type))
    return vecType.getRank() == 1;

  return false;
}

/// Recursively checks that `type` and every type nested in it are
/// LLVM-compatible. `compatibleTypes` caches types known to be compatible. A
/// type is inserted before its nested types are visited, so re-entering it
/// (presumably through recursive identified structs) returns true instead of
/// recursing forever. The type is erased again if it turns out incompatible.
static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
  if (!compatibleTypes.insert(type).second)
    return true;

  auto isCompatible = [&](Type type) {
    return isCompatibleImpl(type, compatibleTypes);
  };

  bool result =
      llvm::TypeSwitch<Type, bool>(type)
          .Case<LLVMStructType>([&](auto structType) {
            return llvm::all_of(structType.getBody(), isCompatible);
          })
          .Case<LLVMFunctionType>([&](auto funcType) {
            return isCompatible(funcType.getReturnType()) &&
                   llvm::all_of(funcType.getParams(), isCompatible);
          })
          .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
          .Case<VectorType>([&](auto vecType) {
            return vecType.getRank() == 1 &&
                   isCompatible(vecType.getElementType());
          })
          .Case<LLVMPointerType>([&](auto pointerType) { return true; })
          .Case<LLVMTargetExtType>([&](auto extType) {
            return llvm::all_of(extType.getTypeParams(), isCompatible);
          })
          // clang-format off
          .Case<
              LLVMFixedVectorType,
              LLVMScalableVectorType,
              LLVMArrayType
          >([&](auto containerType) {
            return isCompatible(containerType.getElementType());
          })
          .Case<
            BFloat16Type,
            Float16Type,
            Float32Type,
            Float64Type,
            Float80Type,
            Float128Type,
            LLVMLabelType,
            LLVMMetadataType,
            LLVMPPCFP128Type,
            LLVMTokenType,
            LLVMVoidType,
            LLVMX86AMXType
          >([](Type) { return true; })
          // clang-format on
          .Default([](Type) { return false; });

  // Do not leave an incompatible type in the cache.
  if (!result)
    compatibleTypes.erase(type);

  return result;
}

/// Uses the loaded LLVM dialect's shared cache when the dialect is loaded,
/// otherwise a temporary local set.
bool LLVMDialect::isCompatibleType(Type type) {
  if (auto *llvmDialect =
          type.getContext()->getLoadedDialect<LLVM::LLVMDialect>())
    return isCompatibleImpl(type, llvmDialect->compatibleTypes.get());

  DenseSet<Type> localCompatibleTypes;
  return isCompatibleImpl(type, localCompatibleTypes);
}

bool mlir::LLVM::isCompatibleType(Type type) {
  return LLVMDialect::isCompatibleType(type);
}

/// Returns true for the builtin float types listed below and LLVM's PPC 128-bit
/// float.
bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
  return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
                   Float80Type, Float128Type, LLVMPPCFP128Type>(type);
}

/// Accepts LLVM dialect vectors, and 1-D builtin vectors of signless integers
/// or builtin floats.
bool mlir::LLVM::isCompatibleVectorType(Type type) {
  if (llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType>(type))
    return true;

  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
    if (vecType.getRank() != 1)
      return false;
    Type elementType = vecType.getElementType();
    if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
      return intType.isSignless();
    return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
                     Float80Type, Float128Type>(elementType);
  }
  return false;
}

/// Returns the element type of an LLVM-compatible vector; any other type hits
/// llvm_unreachable.
Type mlir::LLVM::getVectorElementType(Type type) {
  return llvm::TypeSwitch<Type, Type>(type)
      .Case<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>(
          [](auto ty) { return ty.getElementType(); })
      .Default([](Type) -> Type {
        llvm_unreachable("incompatible with LLVM vector type");
      });
}

/// Returns the element count, scalable for scalable vectors (minimum count)
/// and fixed otherwise; any other type hits llvm_unreachable.
llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
  return llvm::TypeSwitch<Type, llvm::ElementCount>(type)
      .Case([](VectorType ty) {
        if (ty.isScalable())
          return llvm::ElementCount::getScalable(ty.getNumElements());
        return llvm::ElementCount::getFixed(ty.getNumElements());
      })
      .Case([](LLVMFixedVectorType ty) {
        return llvm::ElementCount::getFixed(ty.getNumElements());
      })
      .Case([](LLVMScalableVectorType ty) {
        return llvm::ElementCount::getScalable(ty.getMinNumElements());
      })
      .Default([](Type) -> llvm::ElementCount {
        llvm_unreachable("incompatible with LLVM vector type");
      });
}

/// Returns true for LLVM scalable vectors and scalable builtin vectors. The
/// `cast<VectorType>` is reached only when `vectorType` is neither LLVM
/// vector kind, which the assertion restricts to a builtin vector.
bool mlir::LLVM::isScalableVectorType(Type vectorType) {
  assert((llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>(
             vectorType)) &&
         "expected LLVM-compatible vector type");
  return !llvm::isa<LLVMFixedVectorType>(vectorType) &&
         (llvm::isa<LLVMScalableVectorType>(vectorType) ||
          llvm::cast<VectorType>(vectorType).isScalable());
}

/// Builds a 1-D vector of `elementType`. Uses an LLVM dialect vector when
/// `LLVMFixedVectorType` accepts the element type, otherwise a builtin vector.
/// The element type is asserted to be valid for exactly one of the two.
Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
                               bool isScalable) {
  bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
  bool useBuiltIn = VectorType::isValidElementType(elementType);
  (void)useBuiltIn;
  assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
                                   "to be either builtin or LLVM dialect type");
  if (useLLVM) {
    if (isScalable)
      return LLVMScalableVectorType::get(elementType, numElements);
    return LLVMFixedVectorType::get(elementType, numElements);
  }

  // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
  // scalable/non-scalable.
  return VectorType::get(numElements, elementType, {isScalable});
}

/// Overload taking an `llvm::ElementCount`; for scalable counts the known
/// minimum value is used.
Type mlir::LLVM::getVectorType(Type elementType,
                               const llvm::ElementCount &numElements) {
  if (numElements.isScalable())
    return getVectorType(elementType, numElements.getKnownMinValue(),
                         /*isScalable=*/true);
  return getVectorType(elementType, numElements.getFixedValue(),
                       /*isScalable=*/false);
}

/// Builds a fixed-length vector, choosing between the LLVM dialect and builtin
/// vector types in the same way as `getVectorType`.
Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
  bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
  bool useBuiltIn = VectorType::isValidElementType(elementType);
  (void)useBuiltIn;
  assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
                                   "to be either builtin or LLVM dialect type");
  if (useLLVM)
    return LLVMFixedVectorType::get(elementType, numElements);
  return VectorType::get(numElements, elementType);
}

Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned
numElements) { 969 bool useLLVM = LLVMScalableVectorType::isValidElementType(elementType); 970 bool useBuiltIn = VectorType::isValidElementType(elementType); 971 (void)useBuiltIn; 972 assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible scalable-vector " 973 "type to be either builtin or LLVM dialect " 974 "type"); 975 if (useLLVM) 976 return LLVMScalableVectorType::get(elementType, numElements); 977 978 // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as 979 // scalable/non-scalable. 980 return VectorType::get(numElements, elementType, /*scalableDims=*/true); 981 } 982 983 llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) { 984 assert(isCompatibleType(type) && 985 "expected a type compatible with the LLVM dialect"); 986 987 return llvm::TypeSwitch<Type, llvm::TypeSize>(type) 988 .Case<BFloat16Type, Float16Type>( 989 [](Type) { return llvm::TypeSize::getFixed(16); }) 990 .Case<Float32Type>([](Type) { return llvm::TypeSize::getFixed(32); }) 991 .Case<Float64Type>([](Type) { return llvm::TypeSize::getFixed(64); }) 992 .Case<Float80Type>([](Type) { return llvm::TypeSize::getFixed(80); }) 993 .Case<Float128Type>([](Type) { return llvm::TypeSize::getFixed(128); }) 994 .Case<IntegerType>([](IntegerType intTy) { 995 return llvm::TypeSize::getFixed(intTy.getWidth()); 996 }) 997 .Case<LLVMPPCFP128Type>( 998 [](Type) { return llvm::TypeSize::getFixed(128); }) 999 .Case<LLVMFixedVectorType>([](LLVMFixedVectorType t) { 1000 llvm::TypeSize elementSize = 1001 getPrimitiveTypeSizeInBits(t.getElementType()); 1002 return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(), 1003 elementSize.isScalable()); 1004 }) 1005 .Case<VectorType>([](VectorType t) { 1006 assert(isCompatibleVectorType(t) && 1007 "unexpected incompatible with LLVM vector type"); 1008 llvm::TypeSize elementSize = 1009 getPrimitiveTypeSizeInBits(t.getElementType()); 1010 return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(), 1011 
elementSize.isScalable()); 1012 }) 1013 .Default([](Type ty) { 1014 assert((llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType, 1015 LLVMTokenType, LLVMStructType, LLVMArrayType, 1016 LLVMPointerType, LLVMFunctionType, LLVMTargetExtType>( 1017 ty)) && 1018 "unexpected missing support for primitive type"); 1019 return llvm::TypeSize::getFixed(0); 1020 }); 1021 } 1022 1023 //===----------------------------------------------------------------------===// 1024 // LLVMDialect 1025 //===----------------------------------------------------------------------===// 1026 1027 void LLVMDialect::registerTypes() { 1028 addTypes< 1029 #define GET_TYPEDEF_LIST 1030 #include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc" 1031 >(); 1032 } 1033 1034 Type LLVMDialect::parseType(DialectAsmParser &parser) const { 1035 return detail::parseType(parser); 1036 } 1037 1038 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { 1039 return detail::printType(type, os); 1040 } 1041