1 //===-- FIRType.cpp -------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Optimizer/Dialect/FIRType.h" 14 #include "flang/ISO_Fortran_binding_wrapper.h" 15 #include "flang/Optimizer/Builder/Todo.h" 16 #include "flang/Optimizer/Dialect/FIRDialect.h" 17 #include "flang/Optimizer/Dialect/Support/KindMapping.h" 18 #include "flang/Tools/PointerModels.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/BuiltinDialect.h" 21 #include "mlir/IR/Diagnostics.h" 22 #include "mlir/IR/DialectImplementation.h" 23 #include "llvm/ADT/SmallPtrSet.h" 24 #include "llvm/ADT/StringSet.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include "llvm/Support/ErrorHandling.h" 27 28 #define GET_TYPEDEF_CLASSES 29 #include "flang/Optimizer/Dialect/FIROpsTypes.cpp.inc" 30 31 using namespace fir; 32 33 namespace { 34 35 template <typename TYPE> 36 TYPE parseIntSingleton(mlir::AsmParser &parser) { 37 int kind = 0; 38 if (parser.parseLess() || parser.parseInteger(kind) || parser.parseGreater()) 39 return {}; 40 return TYPE::get(parser.getContext(), kind); 41 } 42 43 template <typename TYPE> 44 TYPE parseKindSingleton(mlir::AsmParser &parser) { 45 return parseIntSingleton<TYPE>(parser); 46 } 47 48 template <typename TYPE> 49 TYPE parseRankSingleton(mlir::AsmParser &parser) { 50 return parseIntSingleton<TYPE>(parser); 51 } 52 53 template <typename TYPE> 54 TYPE parseTypeSingleton(mlir::AsmParser &parser) { 55 mlir::Type ty; 56 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) 57 return {}; 58 return TYPE::get(ty); 59 } 60 61 /// Is `ty` a standard or FIR integer type? 62 static bool isaIntegerType(mlir::Type ty) { 63 // TODO: why aren't we using isa_integer? investigatation required. 64 return mlir::isa<mlir::IntegerType, fir::IntegerType>(ty); 65 } 66 67 bool verifyRecordMemberType(mlir::Type ty) { 68 return !mlir::isa<BoxCharType, ShapeType, ShapeShiftType, ShiftType, 69 SliceType, FieldType, LenType, ReferenceType, TypeDescType>( 70 ty); 71 } 72 73 bool verifySameLists(llvm::ArrayRef<RecordType::TypePair> a1, 74 llvm::ArrayRef<RecordType::TypePair> a2) { 75 // FIXME: do we need to allow for any variance here? 76 return a1 == a2; 77 } 78 79 RecordType verifyDerived(mlir::AsmParser &parser, RecordType derivedTy, 80 llvm::ArrayRef<RecordType::TypePair> lenPList, 81 llvm::ArrayRef<RecordType::TypePair> typeList) { 82 auto loc = parser.getNameLoc(); 83 if (!verifySameLists(derivedTy.getLenParamList(), lenPList) || 84 !verifySameLists(derivedTy.getTypeList(), typeList)) { 85 parser.emitError(loc, "cannot redefine record type members"); 86 return {}; 87 } 88 for (auto &p : lenPList) 89 if (!isaIntegerType(p.second)) { 90 parser.emitError(loc, "LEN parameter must be integral type"); 91 return {}; 92 } 93 for (auto &p : typeList) 94 if (!verifyRecordMemberType(p.second)) { 95 parser.emitError(loc, "field parameter has invalid type"); 96 return {}; 97 } 98 llvm::StringSet<> uniq; 99 for (auto &p : lenPList) 100 if (!uniq.insert(p.first).second) { 101 parser.emitError(loc, "LEN parameter cannot have duplicate name"); 102 return {}; 103 } 104 for (auto &p : typeList) 105 if (!uniq.insert(p.first).second) { 106 parser.emitError(loc, "field cannot have duplicate name"); 107 return {}; 108 } 109 return derivedTy; 110 } 111 112 } // namespace 113 114 // Implementation of the thin interface from dialect to type parser 115 116 mlir::Type fir::parseFirType(FIROpsDialect *dialect, 117 mlir::DialectAsmParser &parser) { 118 mlir::StringRef typeTag; 119 mlir::Type genType; 120 auto parseResult = generatedTypeParser(parser, &typeTag, genType); 121 if (parseResult.has_value()) 122 return genType; 123 parser.emitError(parser.getNameLoc(), "unknown fir type: ") << typeTag; 124 return {}; 125 } 126 127 namespace fir { 128 namespace detail { 129 130 // Type storage classes 131 132 /// Derived type storage 133 struct RecordTypeStorage : public mlir::TypeStorage { 134 using KeyTy = llvm::StringRef; 135 136 static unsigned hashKey(const KeyTy &key) { 137 return llvm::hash_combine(key.str()); 138 } 139 140 bool operator==(const KeyTy &key) const { return key == getName(); } 141 142 static RecordTypeStorage *construct(mlir::TypeStorageAllocator &allocator, 143 const KeyTy &key) { 144 auto *storage = allocator.allocate<RecordTypeStorage>(); 145 return new (storage) RecordTypeStorage{key}; 146 } 147 148 llvm::StringRef getName() const { return name; } 149 150 void setLenParamList(llvm::ArrayRef<RecordType::TypePair> list) { 151 lens = list; 152 } 153 llvm::ArrayRef<RecordType::TypePair> getLenParamList() const { return lens; } 154 155 void setTypeList(llvm::ArrayRef<RecordType::TypePair> list) { types = list; } 156 llvm::ArrayRef<RecordType::TypePair> getTypeList() const { return types; } 157 158 bool isFinalized() const { return finalized; } 159 void finalize(llvm::ArrayRef<RecordType::TypePair> lenParamList, 160 llvm::ArrayRef<RecordType::TypePair> typeList) { 161 if (finalized) 162 return; 163 finalized = true; 164 setLenParamList(lenParamList); 165 setTypeList(typeList); 166 } 167 168 bool isPacked() const { return packed; } 169 void pack(bool p) { packed = p; } 170 171 protected: 172 std::string name; 173 bool finalized; 174 bool packed; 175 std::vector<RecordType::TypePair> lens; 176 std::vector<RecordType::TypePair> types; 177 178 private: 179 RecordTypeStorage() = delete; 180 explicit RecordTypeStorage(llvm::StringRef name) 181 : name{name}, finalized{false}, packed{false} {} 182 }; 183 184 } // namespace detail 185 186 template <typename A, typename B> 187 bool inbounds(A v, B lb, B ub) { 188 return v >= lb && v < ub; 189 } 190 191 bool isa_fir_type(mlir::Type t) { 192 return llvm::isa<FIROpsDialect>(t.getDialect()); 193 } 194 195 bool isa_std_type(mlir::Type t) { 196 return llvm::isa<mlir::BuiltinDialect>(t.getDialect()); 197 } 198 199 bool isa_fir_or_std_type(mlir::Type t) { 200 if (auto funcType = mlir::dyn_cast<mlir::FunctionType>(t)) 201 return llvm::all_of(funcType.getInputs(), isa_fir_or_std_type) && 202 llvm::all_of(funcType.getResults(), isa_fir_or_std_type); 203 return isa_fir_type(t) || isa_std_type(t); 204 } 205 206 mlir::Type getDerivedType(mlir::Type ty) { 207 return llvm::TypeSwitch<mlir::Type, mlir::Type>(ty) 208 .Case<fir::PointerType, fir::HeapType, fir::SequenceType>([](auto p) { 209 if (auto seq = mlir::dyn_cast<fir::SequenceType>(p.getEleTy())) 210 return seq.getEleTy(); 211 return p.getEleTy(); 212 }) 213 .Case<fir::BoxType>([](auto p) { return getDerivedType(p.getEleTy()); }) 214 .Default([](mlir::Type t) { return t; }); 215 } 216 217 mlir::Type dyn_cast_ptrEleTy(mlir::Type t) { 218 return llvm::TypeSwitch<mlir::Type, mlir::Type>(t) 219 .Case<fir::ReferenceType, fir::PointerType, fir::HeapType, 220 fir::LLVMPointerType>([](auto p) { return p.getEleTy(); }) 221 .Default([](mlir::Type) { return mlir::Type{}; }); 222 } 223 224 mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) { 225 return llvm::TypeSwitch<mlir::Type, mlir::Type>(t) 226 .Case<fir::ReferenceType, fir::PointerType, fir::HeapType, 227 fir::LLVMPointerType>([](auto p) { return p.getEleTy(); }) 228 .Case<fir::BaseBoxType>( 229 [](auto p) { return unwrapRefType(p.getEleTy()); }) 230 .Default([](mlir::Type) { return mlir::Type{}; }); 231 } 232 233 static bool hasDynamicSize(fir::RecordType recTy) { 234 for (auto field : recTy.getTypeList()) { 235 if (auto arr = mlir::dyn_cast<fir::SequenceType>(field.second)) { 236 if (sequenceWithNonConstantShape(arr)) 237 return true; 238 } else if (characterWithDynamicLen(field.second)) { 239 return true; 240 } else if (auto rec = mlir::dyn_cast<fir::RecordType>(field.second)) { 241 if (hasDynamicSize(rec)) 242 return true; 243 } 244 } 245 return false; 246 } 247 248 bool hasDynamicSize(mlir::Type t) { 249 if (auto arr = mlir::dyn_cast<fir::SequenceType>(t)) { 250 if (sequenceWithNonConstantShape(arr)) 251 return true; 252 t = arr.getEleTy(); 253 } 254 if (characterWithDynamicLen(t)) 255 return true; 256 if (auto rec = mlir::dyn_cast<fir::RecordType>(t)) 257 return hasDynamicSize(rec); 258 return false; 259 } 260 261 mlir::Type extractSequenceType(mlir::Type ty) { 262 if (mlir::isa<fir::SequenceType>(ty)) 263 return ty; 264 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) 265 return extractSequenceType(boxTy.getEleTy()); 266 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty)) 267 return extractSequenceType(heapTy.getEleTy()); 268 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty)) 269 return extractSequenceType(ptrTy.getEleTy()); 270 return mlir::Type{}; 271 } 272 273 bool isPointerType(mlir::Type ty) { 274 if (auto refTy = fir::dyn_cast_ptrEleTy(ty)) 275 ty = refTy; 276 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) 277 return mlir::isa<fir::PointerType>(boxTy.getEleTy()); 278 return false; 279 } 280 281 bool isAllocatableType(mlir::Type ty) { 282 if (auto refTy = fir::dyn_cast_ptrEleTy(ty)) 283 ty = refTy; 284 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) 285 return mlir::isa<fir::HeapType>(boxTy.getEleTy()); 286 return false; 287 } 288 289 bool isBoxNone(mlir::Type ty) { 290 if (auto box = mlir::dyn_cast<fir::BoxType>(ty)) 291 return mlir::isa<mlir::NoneType>(box.getEleTy()); 292 return false; 293 } 294 295 bool isBoxedRecordType(mlir::Type ty) { 296 if (auto refTy = fir::dyn_cast_ptrEleTy(ty)) 297 ty = refTy; 298 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty)) { 299 if (mlir::isa<fir::RecordType>(boxTy.getEleTy())) 300 return true; 301 mlir::Type innerType = boxTy.unwrapInnerType(); 302 return innerType && mlir::isa<fir::RecordType>(innerType); 303 } 304 return false; 305 } 306 307 bool isScalarBoxedRecordType(mlir::Type ty) { 308 if (auto refTy = fir::dyn_cast_ptrEleTy(ty)) 309 ty = refTy; 310 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { 311 if (mlir::isa<fir::RecordType>(boxTy.getEleTy())) 312 return true; 313 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(boxTy.getEleTy())) 314 return mlir::isa<fir::RecordType>(heapTy.getEleTy()); 315 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(boxTy.getEleTy())) 316 return mlir::isa<fir::RecordType>(ptrTy.getEleTy()); 317 } 318 return false; 319 } 320 321 bool isAssumedType(mlir::Type ty) { 322 // Rule out CLASS(*) which are `fir.class<[fir.array] none>`. 323 if (mlir::isa<fir::ClassType>(ty)) 324 return false; 325 mlir::Type valueType = fir::unwrapPassByRefType(fir::unwrapRefType(ty)); 326 // Refuse raw `none` or `fir.array<none>` since assumed type 327 // should be in memory variables. 328 if (valueType == ty) 329 return false; 330 mlir::Type inner = fir::unwrapSequenceType(valueType); 331 return mlir::isa<mlir::NoneType>(inner); 332 } 333 334 bool isAssumedShape(mlir::Type ty) { 335 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty)) 336 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy())) 337 return seqTy.hasDynamicExtents(); 338 return false; 339 } 340 341 bool isAllocatableOrPointerArray(mlir::Type ty) { 342 if (auto refTy = fir::dyn_cast_ptrEleTy(ty)) 343 ty = refTy; 344 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty)) { 345 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(boxTy.getEleTy())) 346 return mlir::isa<fir::SequenceType>(heapTy.getEleTy()); 347 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(boxTy.getEleTy())) 348 return mlir::isa<fir::SequenceType>(ptrTy.getEleTy()); 349 } 350 return false; 351 } 352 353 bool isTypeWithDescriptor(mlir::Type ty) { 354 if (mlir::isa<fir::BaseBoxType>(unwrapRefType(ty))) 355 return true; 356 return false; 357 } 358 359 bool isPolymorphicType(mlir::Type ty) { 360 // CLASS(T) or CLASS(*) 361 if (mlir::isa<fir::ClassType>(fir::unwrapRefType(ty))) 362 return true; 363 // assumed type are polymorphic. 364 return isAssumedType(ty); 365 } 366 367 bool isUnlimitedPolymorphicType(mlir::Type ty) { 368 // CLASS(*) 369 if (auto clTy = mlir::dyn_cast<fir::ClassType>(fir::unwrapRefType(ty))) { 370 if (mlir::isa<mlir::NoneType>(clTy.getEleTy())) 371 return true; 372 mlir::Type innerType = clTy.unwrapInnerType(); 373 return innerType && mlir::isa<mlir::NoneType>(innerType); 374 } 375 // TYPE(*) 376 return isAssumedType(ty); 377 } 378 379 mlir::Type unwrapInnerType(mlir::Type ty) { 380 return llvm::TypeSwitch<mlir::Type, mlir::Type>(ty) 381 .Case<fir::PointerType, fir::HeapType, fir::SequenceType>([](auto t) { 382 mlir::Type eleTy = t.getEleTy(); 383 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) 384 return seqTy.getEleTy(); 385 return eleTy; 386 }) 387 .Case<fir::RecordType>([](auto t) { return t; }) 388 .Default([](mlir::Type) { return mlir::Type{}; }); 389 } 390 391 bool isRecordWithAllocatableMember(mlir::Type ty) { 392 if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) 393 for (auto [field, memTy] : recTy.getTypeList()) { 394 if (fir::isAllocatableType(memTy)) 395 return true; 396 // A record type cannot recursively include itself as a direct member. 397 // There must be an intervening `ptr` type, so recursion is safe here. 398 if (mlir::isa<fir::RecordType>(memTy) && 399 isRecordWithAllocatableMember(memTy)) 400 return true; 401 } 402 return false; 403 } 404 405 bool isRecordWithDescriptorMember(mlir::Type ty) { 406 ty = unwrapSequenceType(ty); 407 if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) 408 for (auto [field, memTy] : recTy.getTypeList()) { 409 if (mlir::isa<fir::BaseBoxType>(memTy)) 410 return true; 411 if (mlir::isa<fir::RecordType>(memTy) && 412 isRecordWithDescriptorMember(memTy)) 413 return true; 414 } 415 return false; 416 } 417 418 mlir::Type unwrapAllRefAndSeqType(mlir::Type ty) { 419 while (true) { 420 mlir::Type nt = unwrapSequenceType(unwrapRefType(ty)); 421 if (auto vecTy = mlir::dyn_cast<fir::VectorType>(nt)) 422 nt = vecTy.getEleTy(); 423 if (nt == ty) 424 return ty; 425 ty = nt; 426 } 427 } 428 429 mlir::Type unwrapSeqOrBoxedSeqType(mlir::Type ty) { 430 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) 431 return seqTy.getEleTy(); 432 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { 433 auto eleTy = unwrapRefType(boxTy.getEleTy()); 434 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) 435 return seqTy.getEleTy(); 436 } 437 return ty; 438 } 439 440 unsigned getBoxRank(mlir::Type boxTy) { 441 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(boxTy); 442 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) 443 return seqTy.getDimension(); 444 return 0; 445 } 446 447 /// Return the ISO_C_BINDING intrinsic module value of type \p ty. 448 int getTypeCode(mlir::Type ty, const fir::KindMapping &kindMap) { 449 if (mlir::IntegerType intTy = mlir::dyn_cast<mlir::IntegerType>(ty)) { 450 if (intTy.isUnsigned()) { 451 switch (intTy.getWidth()) { 452 case 8: 453 return CFI_type_uint8_t; 454 case 16: 455 return CFI_type_uint16_t; 456 case 32: 457 return CFI_type_uint32_t; 458 case 64: 459 return CFI_type_uint64_t; 460 case 128: 461 return CFI_type_uint128_t; 462 } 463 llvm_unreachable("unsupported integer type"); 464 } else { 465 switch (intTy.getWidth()) { 466 case 8: 467 return CFI_type_int8_t; 468 case 16: 469 return CFI_type_int16_t; 470 case 32: 471 return CFI_type_int32_t; 472 case 64: 473 return CFI_type_int64_t; 474 case 128: 475 return CFI_type_int128_t; 476 } 477 llvm_unreachable("unsupported integer type"); 478 } 479 } 480 if (fir::LogicalType logicalTy = mlir::dyn_cast<fir::LogicalType>(ty)) { 481 switch (kindMap.getLogicalBitsize(logicalTy.getFKind())) { 482 case 8: 483 return CFI_type_Bool; 484 case 16: 485 return CFI_type_int_least16_t; 486 case 32: 487 return CFI_type_int_least32_t; 488 case 64: 489 return CFI_type_int_least64_t; 490 } 491 llvm_unreachable("unsupported logical type"); 492 } 493 if (mlir::FloatType floatTy = mlir::dyn_cast<mlir::FloatType>(ty)) { 494 switch (floatTy.getWidth()) { 495 case 16: 496 return floatTy.isBF16() ? CFI_type_bfloat : CFI_type_half_float; 497 case 32: 498 return CFI_type_float; 499 case 64: 500 return CFI_type_double; 501 case 80: 502 return CFI_type_extended_double; 503 case 128: 504 return CFI_type_float128; 505 } 506 llvm_unreachable("unsupported real type"); 507 } 508 if (mlir::ComplexType complexTy = mlir::dyn_cast<mlir::ComplexType>(ty)) { 509 mlir::FloatType floatTy = 510 mlir::cast<mlir::FloatType>(complexTy.getElementType()); 511 if (floatTy.isBF16()) 512 return CFI_type_bfloat_Complex; 513 switch (floatTy.getWidth()) { 514 case 16: 515 return CFI_type_half_float_Complex; 516 case 32: 517 return CFI_type_float_Complex; 518 case 64: 519 return CFI_type_double_Complex; 520 case 80: 521 return CFI_type_extended_double_Complex; 522 case 128: 523 return CFI_type_float128_Complex; 524 } 525 llvm_unreachable("unsupported complex size"); 526 } 527 if (fir::CharacterType charTy = mlir::dyn_cast<fir::CharacterType>(ty)) { 528 switch (kindMap.getCharacterBitsize(charTy.getFKind())) { 529 case 8: 530 return CFI_type_char; 531 case 16: 532 return CFI_type_char16_t; 533 case 32: 534 return CFI_type_char32_t; 535 } 536 llvm_unreachable("unsupported character type"); 537 } 538 if (fir::isa_ref_type(ty)) 539 return CFI_type_cptr; 540 if (mlir::isa<fir::RecordType>(ty)) 541 return CFI_type_struct; 542 llvm_unreachable("unsupported type"); 543 } 544 545 std::string getTypeAsString(mlir::Type ty, const fir::KindMapping &kindMap, 546 llvm::StringRef prefix) { 547 std::string buf = prefix.str(); 548 llvm::raw_string_ostream name{buf}; 549 if (!prefix.empty()) 550 name << "_"; 551 while (ty) { 552 if (fir::isa_trivial(ty)) { 553 if (mlir::isa<mlir::IndexType>(ty)) { 554 name << "idx"; 555 } else if (ty.isIntOrIndex()) { 556 name << 'i' << ty.getIntOrFloatBitWidth(); 557 } else if (mlir::isa<mlir::FloatType>(ty)) { 558 name << 'f' << ty.getIntOrFloatBitWidth(); 559 } else if (auto cplxTy = mlir::dyn_cast_or_null<mlir::ComplexType>(ty)) { 560 name << 'z'; 561 auto floatTy = mlir::cast<mlir::FloatType>(cplxTy.getElementType()); 562 name << floatTy.getWidth(); 563 } else if (auto logTy = mlir::dyn_cast_or_null<fir::LogicalType>(ty)) { 564 name << 'l' << kindMap.getLogicalBitsize(logTy.getFKind()); 565 } else { 566 llvm::report_fatal_error("unsupported type"); 567 } 568 break; 569 } else if (mlir::isa<mlir::NoneType>(ty)) { 570 name << "none"; 571 break; 572 } else if (auto charTy = mlir::dyn_cast_or_null<fir::CharacterType>(ty)) { 573 name << 'c' << kindMap.getCharacterBitsize(charTy.getFKind()); 574 if (charTy.getLen() == fir::CharacterType::unknownLen()) 575 name << "xU"; 576 else if (charTy.getLen() != fir::CharacterType::singleton()) 577 name << "x" << charTy.getLen(); 578 break; 579 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) { 580 for (auto extent : seqTy.getShape()) { 581 if (extent == fir::SequenceType::getUnknownExtent()) 582 name << "Ux"; 583 else 584 name << extent << 'x'; 585 } 586 ty = seqTy.getEleTy(); 587 } else if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(ty)) { 588 name << "ref_"; 589 ty = refTy.getEleTy(); 590 } else if (auto ptrTy = mlir::dyn_cast_or_null<fir::PointerType>(ty)) { 591 name << "ptr_"; 592 ty = ptrTy.getEleTy(); 593 } else if (auto ptrTy = mlir::dyn_cast_or_null<fir::LLVMPointerType>(ty)) { 594 name << "llvmptr_"; 595 ty = ptrTy.getEleTy(); 596 } else if (auto heapTy = mlir::dyn_cast_or_null<fir::HeapType>(ty)) { 597 name << "heap_"; 598 ty = heapTy.getEleTy(); 599 } else if (auto classTy = mlir::dyn_cast_or_null<fir::ClassType>(ty)) { 600 name << "class_"; 601 ty = classTy.getEleTy(); 602 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BoxType>(ty)) { 603 name << "box_"; 604 ty = boxTy.getEleTy(); 605 } else if (auto boxcharTy = mlir::dyn_cast_or_null<fir::BoxCharType>(ty)) { 606 name << "boxchar_"; 607 ty = boxcharTy.getEleTy(); 608 } else if (auto recTy = mlir::dyn_cast_or_null<fir::RecordType>(ty)) { 609 name << "rec_" << recTy.getName(); 610 break; 611 } else { 612 llvm::report_fatal_error("unsupported type"); 613 } 614 } 615 return buf; 616 } 617 618 mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType, 619 bool turnBoxIntoClass) { 620 return llvm::TypeSwitch<mlir::Type, mlir::Type>(type) 621 .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type { 622 return fir::SequenceType::get(seqTy.getShape(), newElementType); 623 }) 624 .Case<fir::PointerType, fir::HeapType, fir::ReferenceType, 625 fir::ClassType>([&](auto t) -> mlir::Type { 626 using FIRT = decltype(t); 627 return FIRT::get( 628 changeElementType(t.getEleTy(), newElementType, turnBoxIntoClass)); 629 }) 630 .Case<fir::BoxType>([&](fir::BoxType t) -> mlir::Type { 631 mlir::Type newInnerType = 632 changeElementType(t.getEleTy(), newElementType, false); 633 if (turnBoxIntoClass) 634 return fir::ClassType::get(newInnerType); 635 return fir::BoxType::get(newInnerType); 636 }) 637 .Default([&](mlir::Type t) -> mlir::Type { 638 assert((fir::isa_trivial(t) || llvm::isa<fir::RecordType>(t) || 639 llvm::isa<mlir::NoneType>(t)) && 640 "unexpected FIR leaf type"); 641 return newElementType; 642 }); 643 } 644 645 } // namespace fir 646 647 namespace { 648 649 static llvm::SmallPtrSet<detail::RecordTypeStorage const *, 4> 650 recordTypeVisited; 651 652 } // namespace 653 654 void fir::verifyIntegralType(mlir::Type type) { 655 if (isaIntegerType(type) || mlir::isa<mlir::IndexType>(type)) 656 return; 657 llvm::report_fatal_error("expected integral type"); 658 } 659 660 void fir::printFirType(FIROpsDialect *, mlir::Type ty, 661 mlir::DialectAsmPrinter &p) { 662 if (mlir::failed(generatedTypePrinter(ty, p))) 663 llvm::report_fatal_error("unknown type to print"); 664 } 665 666 bool fir::isa_unknown_size_box(mlir::Type t) { 667 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(t)) { 668 auto valueType = fir::unwrapPassByRefType(boxTy); 669 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(valueType)) 670 if (seqTy.hasUnknownShape()) 671 return true; 672 } 673 return false; 674 } 675 676 //===----------------------------------------------------------------------===// 677 // BoxProcType 678 //===----------------------------------------------------------------------===// 679 680 // `boxproc` `<` return-type `>` 681 mlir::Type BoxProcType::parse(mlir::AsmParser &parser) { 682 mlir::Type ty; 683 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) 684 return {}; 685 return get(parser.getContext(), ty); 686 } 687 688 void fir::BoxProcType::print(mlir::AsmPrinter &printer) const { 689 printer << "<" << getEleTy() << '>'; 690 } 691 692 llvm::LogicalResult 693 BoxProcType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, 694 mlir::Type eleTy) { 695 if (mlir::isa<mlir::FunctionType>(eleTy)) 696 return mlir::success(); 697 if (auto refTy = mlir::dyn_cast<ReferenceType>(eleTy)) 698 if (mlir::isa<mlir::FunctionType>(refTy)) 699 return mlir::success(); 700 return emitError() << "invalid type for boxproc" << eleTy << '\n'; 701 } 702 703 static bool cannotBePointerOrHeapElementType(mlir::Type eleTy) { 704 return mlir::isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType, 705 SliceType, FieldType, LenType, HeapType, PointerType, 706 ReferenceType, TypeDescType>(eleTy); 707 } 708 709 //===----------------------------------------------------------------------===// 710 // BoxType 711 //===----------------------------------------------------------------------===// 712 713 llvm::LogicalResult 714 fir::BoxType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, 715 mlir::Type eleTy) { 716 if (mlir::isa<fir::BaseBoxType>(eleTy)) 717 return emitError() << "invalid element type\n"; 718 // TODO 719 return mlir::success(); 720 } 721 722 //===----------------------------------------------------------------------===// 723 // BoxCharType 724 //===----------------------------------------------------------------------===// 725 726 mlir::Type fir::BoxCharType::parse(mlir::AsmParser &parser) { 727 return parseKindSingleton<fir::BoxCharType>(parser); 728 } 729 730 void fir::BoxCharType::print(mlir::AsmPrinter &printer) const { 731 printer << "<" << getKind() << ">"; 732 } 733 734 CharacterType 735 fir::BoxCharType::getElementType(mlir::MLIRContext *context) const { 736 return CharacterType::getUnknownLen(context, getKind()); 737 } 738 739 CharacterType fir::BoxCharType::getEleTy() const { 740 return getElementType(getContext()); 741 } 742 743 //===----------------------------------------------------------------------===// 744 // CharacterType 745 //===----------------------------------------------------------------------===// 746 747 // `char` `<` kind [`,` `len`] `>` 748 mlir::Type fir::CharacterType::parse(mlir::AsmParser &parser) { 749 int kind = 0; 750 if (parser.parseLess() || parser.parseInteger(kind)) 751 return {}; 752 CharacterType::LenType len = 1; 753 if (mlir::succeeded(parser.parseOptionalComma())) { 754 if (mlir::succeeded(parser.parseOptionalQuestion())) { 755 len = fir::CharacterType::unknownLen(); 756 } else if (!mlir::succeeded(parser.parseInteger(len))) { 757 return {}; 758 } 759 } 760 if (parser.parseGreater()) 761 return {}; 762 return get(parser.getContext(), kind, len); 763 } 764 765 void fir::CharacterType::print(mlir::AsmPrinter &printer) const { 766 printer << "<" << getFKind(); 767 auto len = getLen(); 768 if (len != fir::CharacterType::singleton()) { 769 printer << ','; 770 if (len == fir::CharacterType::unknownLen()) 771 printer << '?'; 772 else 773 printer << len; 774 } 775 printer << '>'; 776 } 777 778 //===----------------------------------------------------------------------===// 779 // ClassType 780 //===----------------------------------------------------------------------===// 781 782 llvm::LogicalResult 783 fir::ClassType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, 784 mlir::Type eleTy) { 785 if (mlir::isa<fir::RecordType, fir::SequenceType, fir::HeapType, 786 fir::PointerType, mlir::NoneType, mlir::IntegerType, 787 mlir::FloatType, fir::CharacterType, fir::LogicalType, 788 mlir::ComplexType>(eleTy)) 789 return mlir::success(); 790 return emitError() << "invalid element type\n"; 791 } 792 793 //===----------------------------------------------------------------------===// 794 // HeapType 795 //===----------------------------------------------------------------------===// 796 797 // `heap` `<` type `>` 798 mlir::Type fir::HeapType::parse(mlir::AsmParser &parser) { 799 return parseTypeSingleton<HeapType>(parser); 800 } 801 802 void fir::HeapType::print(mlir::AsmPrinter &printer) const { 803 printer << "<" << getEleTy() << '>'; 804 } 805 806 llvm::LogicalResult 807 fir::HeapType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, 808 mlir::Type eleTy) { 809 if (cannotBePointerOrHeapElementType(eleTy)) 810 return emitError() << "cannot build a heap pointer to type: " << eleTy 811 << '\n'; 812 return mlir::success(); 813 } 814 815 //===----------------------------------------------------------------------===// 816 // IntegerType 817 //===----------------------------------------------------------------------===// 818 819 // `int` `<` kind `>` 820 mlir::Type fir::IntegerType::parse(mlir::AsmParser &parser) { 821 return parseKindSingleton<fir::IntegerType>(parser); 822 } 823 824 void fir::IntegerType::print(mlir::AsmPrinter &printer) const { 825 printer << "<" << getFKind() << '>'; 826 } 827 828 //===----------------------------------------------------------------------===// 829 // UnsignedType 830 //===----------------------------------------------------------------------===// 831 832 // `unsigned` `<` kind `>` 833 mlir::Type fir::UnsignedType::parse(mlir::AsmParser &parser) { 834 return parseKindSingleton<fir::UnsignedType>(parser); 835 } 836 837 void fir::UnsignedType::print(mlir::AsmPrinter &printer) const { 838 printer << "<" << getFKind() << '>'; 839 } 840 841 //===----------------------------------------------------------------------===// 842 // LogicalType 843 //===----------------------------------------------------------------------===// 844 845 // `logical` `<` kind `>` 846 mlir::Type fir::LogicalType::parse(mlir::AsmParser &parser) { 847 return parseKindSingleton<fir::LogicalType>(parser); 848 } 849 850 void fir::LogicalType::print(mlir::AsmPrinter &printer) const { 851 printer << "<" << getFKind() << '>'; 852 } 853 854 //===----------------------------------------------------------------------===// 855 // PointerType 856 //===----------------------------------------------------------------------===// 857 858 // `ptr` `<` type `>` 859 mlir::Type fir::PointerType::parse(mlir::AsmParser &parser) { 860 return parseTypeSingleton<fir::PointerType>(parser); 861 } 862 863 void fir::PointerType::print(mlir::AsmPrinter &printer) const { 864 printer << "<" << getEleTy() << '>'; 865 } 866 867 llvm::LogicalResult fir::PointerType::verify( 868 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, 869 mlir::Type eleTy) { 870 if (cannotBePointerOrHeapElementType(eleTy)) 871 return emitError() << "cannot build a pointer to type: " << eleTy << '\n'; 872 return mlir::success(); 873 } 874 875 //===----------------------------------------------------------------------===// 876 // RecordType 877 //===----------------------------------------------------------------------===// 878 879 // Fortran derived type 880 // unpacked: 881 // `type` `<` name 882 // (`(` id `:` type (`,` id `:` type)* `)`)? 883 // (`{` id `:` type (`,` id `:` type)* `}`)? '>' 884 // packed: 885 // `type` `<` name 886 // (`(` id `:` type (`,` id `:` type)* `)`)? 887 // (`<{` id `:` type (`,` id `:` type)* `}>`)? '>' 888 mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) { 889 llvm::StringRef name; 890 if (parser.parseLess() || parser.parseKeyword(&name)) 891 return {}; 892 RecordType result = RecordType::get(parser.getContext(), name); 893 894 RecordType::TypeList lenParamList; 895 if (!parser.parseOptionalLParen()) { 896 while (true) { 897 llvm::StringRef lenparam; 898 mlir::Type intTy; 899 if (parser.parseKeyword(&lenparam) || parser.parseColon() || 900 parser.parseType(intTy)) { 901 parser.emitError(parser.getNameLoc(), "expected LEN parameter list"); 902 return {}; 903 } 904 lenParamList.emplace_back(lenparam, intTy); 905 if (parser.parseOptionalComma()) 906 break; 907 } 908 if (parser.parseRParen()) 909 return {}; 910 } 911 912 RecordType::TypeList typeList; 913 if (!parser.parseOptionalLess()) { 914 result.pack(true); 915 } 916 917 if (!parser.parseOptionalLBrace()) { 918 while (true) { 919 llvm::StringRef field; 920 mlir::Type fldTy; 921 if (parser.parseKeyword(&field) || parser.parseColon() || 922 parser.parseType(fldTy)) { 923 parser.emitError(parser.getNameLoc(), "expected field type list"); 924 return {}; 925 } 926 typeList.emplace_back(field, fldTy); 927 if (parser.parseOptionalComma()) 928 break; 929 } 930 if (parser.parseOptionalGreater()) { 931 if (parser.parseRBrace()) 932 return {}; 933 } 934 } 935 936 if (parser.parseGreater()) 937 return {}; 938 939 if (lenParamList.empty() && typeList.empty()) 940 return result; 941 942 result.finalize(lenParamList, typeList); 943 return verifyDerived(parser, result, lenParamList, typeList); 944 } 945 946 void fir::RecordType::print(mlir::AsmPrinter &printer) const { 947 printer << "<" << getName(); 948 if (!recordTypeVisited.count(uniqueKey())) { 949 recordTypeVisited.insert(uniqueKey()); 950 if (getLenParamList().size()) { 951 char ch = '('; 952 for (auto p : getLenParamList()) { 953 printer << ch << p.first << ':'; 954 p.second.print(printer.getStream()); 955 ch = ','; 956 } 957 printer << ')'; 958 } 959 if (getTypeList().size()) { 960 if (isPacked()) { 961 printer << '<'; 962 } 963 char ch = '{'; 964 for (auto p : getTypeList()) { 965 printer << ch << p.first << ':'; 966 p.second.print(printer.getStream()); 967 ch = ','; 968 } 969 printer << '}'; 970 if (isPacked()) { 971 printer << '>'; 972 } 973 } 974 recordTypeVisited.erase(uniqueKey()); 975 } 976 printer << '>'; 977 } 978 979 void fir::RecordType::finalize(llvm::ArrayRef<TypePair> lenPList, 980 llvm::ArrayRef<TypePair> typeList) { 981 getImpl()->finalize(lenPList, typeList); 982 } 983 984 llvm::StringRef fir::RecordType::getName() const { 985 return getImpl()->getName(); 986 } 987 988 RecordType::TypeList fir::RecordType::getTypeList() const { 989 return getImpl()->getTypeList(); 990 } 991 992 RecordType::TypeList fir::RecordType::getLenParamList() const { 993 return getImpl()->getLenParamList(); 994 } 995 996 bool fir::RecordType::isFinalized() const { return getImpl()->isFinalized(); } 997 998 void fir::RecordType::pack(bool p) { getImpl()->pack(p); } 999 1000 bool fir::RecordType::isPacked() const { return getImpl()->isPacked(); } 1001 1002 detail::RecordTypeStorage const *fir::RecordType::uniqueKey() const { 1003 return getImpl(); 1004 } 1005 1006 llvm::LogicalResult fir::RecordType::verify( 1007 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, 1008 llvm::StringRef name) { 1009 if (name.size() == 0) 1010 return emitError() << "record types must have a name"; 1011 return mlir::success(); 1012 } 1013 1014 mlir::Type fir::RecordType::getType(llvm::StringRef ident) { 1015 for (auto f : getTypeList()) 1016 if (ident == f.first) 1017 return f.second; 1018 return {}; 1019 } 1020 1021 unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) { 1022 for (auto f : llvm::enumerate(getTypeList())) 1023 if (ident == f.value().first) 1024 return f.index(); 1025 return std::numeric_limits<unsigned>::max(); 1026 } 1027 1028 //===----------------------------------------------------------------------===// 1029 // ReferenceType 1030 //===----------------------------------------------------------------------===// 1031 1032 // `ref` `<` type `>` 1033 mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) { 1034 return parseTypeSingleton<fir::ReferenceType>(parser); 1035 } 1036 1037 void fir::ReferenceType::print(mlir::AsmPrinter &printer) const { 1038 printer << "<" << getEleTy() << '>'; 1039 } 1040 1041 llvm::LogicalResult fir::ReferenceType::verify( 1042 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, 1043 mlir::Type eleTy) { 1044 if (mlir::isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType, 1045 ReferenceType, TypeDescType>(eleTy)) 1046 return emitError() << "cannot build a reference to type: " << eleTy << '\n'; 1047 return mlir::success(); 1048 } 1049 1050 //===----------------------------------------------------------------------===// 1051 // SequenceType 1052 //===----------------------------------------------------------------------===// 1053 1054 // `array` `<` `*` | bounds (`x` bounds)* `:` type (',' affine-map)? `>` 1055 // bounds ::= `?` | int-lit 1056 mlir::Type fir::SequenceType::parse(mlir::AsmParser &parser) { 1057 if (parser.parseLess()) 1058 return {}; 1059 SequenceType::Shape shape; 1060 if (parser.parseOptionalStar()) { 1061 if (parser.parseDimensionList(shape, /*allowDynamic=*/true)) 1062 return {}; 1063 } else if (parser.parseColon()) { 1064 return {}; 1065 } 1066 mlir::Type eleTy; 1067 if (parser.parseType(eleTy)) 1068 return {}; 1069 mlir::AffineMapAttr map; 1070 if (!parser.parseOptionalComma()) { 1071 if (parser.parseAttribute(map)) { 1072 parser.emitError(parser.getNameLoc(), "expecting affine map"); 1073 return {}; 1074 } 1075 } 1076 if (parser.parseGreater()) 1077 return {}; 1078 return SequenceType::get(parser.getContext(), shape, eleTy, map); 1079 } 1080 1081 void fir::SequenceType::print(mlir::AsmPrinter &printer) const { 1082 auto shape = getShape(); 1083 if (shape.size()) { 1084 printer << '<'; 1085 for (const auto &b : shape) { 1086 if (b >= 0) 1087 printer << b << 'x'; 1088 else 1089 printer << "?x"; 1090 } 1091 } else { 1092 printer << "<*:"; 1093 } 1094 printer << getEleTy(); 1095 if (auto map = getLayoutMap()) { 1096 printer << ", "; 1097 map.print(printer.getStream()); 1098 } 1099 printer << '>'; 1100 } 1101 1102 unsigned fir::SequenceType::getConstantRows() const { 1103 if (hasDynamicSize(getEleTy())) 1104 return 0; 1105 auto shape = getShape(); 1106 unsigned count = 0; 1107 for (auto d : shape) { 1108 if (d == getUnknownExtent()) 1109 break; 1110 ++count; 1111 } 1112 return count; 1113 } 1114 1115 llvm::LogicalResult fir::SequenceType::verify( 1116 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, 1117 llvm::ArrayRef<int64_t> shape, mlir::Type eleTy, 1118 mlir::AffineMapAttr layoutMap) { 1119 // DIMENSION attribute can only be applied to an intrinsic or record type 1120 if (mlir::isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType, 1121 ShiftType, SliceType, FieldType, LenType, HeapType, PointerType, 1122 ReferenceType, TypeDescType, SequenceType>(eleTy)) 1123 return emitError() << "cannot build an array of this element type: " 1124 << eleTy << '\n'; 1125 return mlir::success(); 1126 } 1127 1128 //===----------------------------------------------------------------------===// 1129 // ShapeType 1130 //===----------------------------------------------------------------------===// 1131 1132 mlir::Type fir::ShapeType::parse(mlir::AsmParser &parser) { 1133 return parseRankSingleton<fir::ShapeType>(parser); 1134 } 1135 1136 void fir::ShapeType::print(mlir::AsmPrinter &printer) const { 1137 printer << "<" << getImpl()->rank << ">"; 1138 } 1139 1140 //===----------------------------------------------------------------------===// 1141 // ShapeShiftType 1142 //===----------------------------------------------------------------------===// 1143 1144 mlir::Type fir::ShapeShiftType::parse(mlir::AsmParser &parser) { 1145 return parseRankSingleton<fir::ShapeShiftType>(parser); 1146 } 1147 1148 void fir::ShapeShiftType::print(mlir::AsmPrinter &printer) const { 1149 printer << "<" << getRank() << ">"; 1150 } 1151 1152 //===----------------------------------------------------------------------===// 1153 // ShiftType 1154 //===----------------------------------------------------------------------===// 1155 1156 mlir::Type fir::ShiftType::parse(mlir::AsmParser &parser) { 1157 return parseRankSingleton<fir::ShiftType>(parser); 1158 } 1159 1160 void fir::ShiftType::print(mlir::AsmPrinter &printer) const { 1161 printer << "<" << getRank() << ">"; 1162 } 1163 1164 //===----------------------------------------------------------------------===// 1165 // SliceType 1166 //===----------------------------------------------------------------------===// 1167 1168 // `slice` `<` rank `>` 1169 mlir::Type fir::SliceType::parse(mlir::AsmParser &parser) { 1170 return parseRankSingleton<fir::SliceType>(parser); 1171 } 1172 1173 void fir::SliceType::print(mlir::AsmPrinter &printer) const { 1174 printer << "<" << getRank() << '>'; 1175 } 1176 1177 //===----------------------------------------------------------------------===// 1178 // TypeDescType 1179 //===----------------------------------------------------------------------===// 1180 1181 // `tdesc` `<` type `>` 1182 mlir::Type fir::TypeDescType::parse(mlir::AsmParser &parser) { 1183 return parseTypeSingleton<fir::TypeDescType>(parser); 1184 } 1185 1186 void fir::TypeDescType::print(mlir::AsmPrinter &printer) const { 1187 printer << "<" << getOfTy() << '>'; 1188 } 1189 1190 llvm::LogicalResult fir::TypeDescType::verify( 1191 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, 1192 mlir::Type eleTy) { 1193 if (mlir::isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType, 1194 ShiftType, SliceType, FieldType, LenType, ReferenceType, 1195 TypeDescType>(eleTy)) 1196 return emitError() << "cannot build a type descriptor of type: " << eleTy 1197 << '\n'; 1198 return mlir::success(); 1199 } 1200 1201 //===----------------------------------------------------------------------===// 1202 // VectorType 1203 //===----------------------------------------------------------------------===// 1204 1205 // `vector` `<` len `:` type `>` 1206 mlir::Type fir::VectorType::parse(mlir::AsmParser &parser) { 1207 int64_t len = 0; 1208 mlir::Type eleTy; 1209 if (parser.parseLess() || parser.parseInteger(len) || parser.parseColon() || 1210 parser.parseType(eleTy) || parser.parseGreater()) 1211 return {}; 1212 return fir::VectorType::get(len, eleTy); 1213 } 1214 1215 void fir::VectorType::print(mlir::AsmPrinter &printer) const { 1216 printer << "<" << getLen() << ':' << getEleTy() << '>'; 1217 } 1218 1219 llvm::LogicalResult fir::VectorType::verify( 1220 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, uint64_t len, 1221 mlir::Type eleTy) { 1222 if (!(fir::isa_real(eleTy) || fir::isa_integer(eleTy))) 1223 return emitError() << "cannot build a vector of type " << eleTy << '\n'; 1224 return mlir::success(); 1225 } 1226 1227 bool fir::VectorType::isValidElementType(mlir::Type t) { 1228 return isa_real(t) || isa_integer(t); 1229 } 1230 1231 bool fir::isCharacterProcedureTuple(mlir::Type ty, bool acceptRawFunc) { 1232 mlir::TupleType tuple = mlir::dyn_cast<mlir::TupleType>(ty); 1233 return tuple && tuple.size() == 2 && 1234 (mlir::isa<fir::BoxProcType>(tuple.getType(0)) || 1235 (acceptRawFunc && mlir::isa<mlir::FunctionType>(tuple.getType(0)))) && 1236 fir::isa_integer(tuple.getType(1)); 1237 } 1238 1239 bool fir::hasAbstractResult(mlir::FunctionType ty) { 1240 if (ty.getNumResults() == 0) 1241 return false; 1242 auto resultType = ty.getResult(0); 1243 return mlir::isa<fir::SequenceType, fir::BaseBoxType, fir::RecordType>( 1244 resultType); 1245 } 1246 1247 /// Convert llvm::Type::TypeID to mlir::Type. \p kind is provided for error 1248 /// messages only. 1249 mlir::Type fir::fromRealTypeID(mlir::MLIRContext *context, 1250 llvm::Type::TypeID typeID, fir::KindTy kind) { 1251 switch (typeID) { 1252 case llvm::Type::TypeID::HalfTyID: 1253 return mlir::Float16Type::get(context); 1254 case llvm::Type::TypeID::BFloatTyID: 1255 return mlir::BFloat16Type::get(context); 1256 case llvm::Type::TypeID::FloatTyID: 1257 return mlir::Float32Type::get(context); 1258 case llvm::Type::TypeID::DoubleTyID: 1259 return mlir::Float64Type::get(context); 1260 case llvm::Type::TypeID::X86_FP80TyID: 1261 return mlir::Float80Type::get(context); 1262 case llvm::Type::TypeID::FP128TyID: 1263 return mlir::Float128Type::get(context); 1264 default: 1265 mlir::emitError(mlir::UnknownLoc::get(context)) 1266 << "unsupported type: !fir.real<" << kind << ">"; 1267 return {}; 1268 } 1269 } 1270 1271 //===----------------------------------------------------------------------===// 1272 // BaseBoxType 1273 //===----------------------------------------------------------------------===// 1274 1275 mlir::Type BaseBoxType::getEleTy() const { 1276 return llvm::TypeSwitch<fir::BaseBoxType, mlir::Type>(*this) 1277 .Case<fir::BoxType, fir::ClassType>( 1278 [](auto type) { return type.getEleTy(); }); 1279 } 1280 1281 mlir::Type BaseBoxType::unwrapInnerType() const { 1282 return fir::unwrapInnerType(getEleTy()); 1283 } 1284 1285 static mlir::Type 1286 changeTypeShape(mlir::Type type, 1287 std::optional<fir::SequenceType::ShapeRef> newShape) { 1288 return llvm::TypeSwitch<mlir::Type, mlir::Type>(type) 1289 .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type { 1290 if (newShape) 1291 return fir::SequenceType::get(*newShape, seqTy.getEleTy()); 1292 return seqTy.getEleTy(); 1293 }) 1294 .Case<fir::PointerType, fir::HeapType, fir::ReferenceType, fir::BoxType, 1295 fir::ClassType>([&](auto t) -> mlir::Type { 1296 using FIRT = decltype(t); 1297 return FIRT::get(changeTypeShape(t.getEleTy(), newShape)); 1298 }) 1299 .Default([&](mlir::Type t) -> mlir::Type { 1300 assert((fir::isa_trivial(t) || llvm::isa<fir::RecordType>(t) || 1301 llvm::isa<mlir::NoneType>(t)) && 1302 "unexpected FIR leaf type"); 1303 if (newShape) 1304 return fir::SequenceType::get(*newShape, t); 1305 return t; 1306 }); 1307 } 1308 1309 fir::BaseBoxType 1310 fir::BaseBoxType::getBoxTypeWithNewShape(mlir::Type shapeMold) const { 1311 fir::SequenceType seqTy = fir::unwrapUntilSeqType(shapeMold); 1312 std::optional<fir::SequenceType::ShapeRef> newShape; 1313 if (seqTy) 1314 newShape = seqTy.getShape(); 1315 return mlir::cast<fir::BaseBoxType>(changeTypeShape(*this, newShape)); 1316 } 1317 1318 fir::BaseBoxType fir::BaseBoxType::getBoxTypeWithNewShape(int rank) const { 1319 std::optional<fir::SequenceType::ShapeRef> newShape; 1320 fir::SequenceType::Shape shapeVector; 1321 if (rank > 0) { 1322 shapeVector = 1323 fir::SequenceType::Shape(rank, fir::SequenceType::getUnknownExtent()); 1324 newShape = shapeVector; 1325 } 1326 return mlir::cast<fir::BaseBoxType>(changeTypeShape(*this, newShape)); 1327 } 1328 1329 fir::BaseBoxType fir::BaseBoxType::getBoxTypeWithNewAttr( 1330 fir::BaseBoxType::Attribute attr) const { 1331 mlir::Type baseType = fir::unwrapRefType(getEleTy()); 1332 switch (attr) { 1333 case fir::BaseBoxType::Attribute::None: 1334 break; 1335 case fir::BaseBoxType::Attribute::Allocatable: 1336 baseType = fir::HeapType::get(baseType); 1337 break; 1338 case fir::BaseBoxType::Attribute::Pointer: 1339 baseType = fir::PointerType::get(baseType); 1340 break; 1341 } 1342 return llvm::TypeSwitch<fir::BaseBoxType, fir::BaseBoxType>(*this) 1343 .Case<fir::BoxType>( 1344 [baseType](auto) { return fir::BoxType::get(baseType); }) 1345 .Case<fir::ClassType>( 1346 [baseType](auto) { return fir::ClassType::get(baseType); }); 1347 } 1348 1349 bool fir::BaseBoxType::isAssumedRank() const { 1350 if (auto seqTy = 1351 mlir::dyn_cast<fir::SequenceType>(fir::unwrapRefType(getEleTy()))) 1352 return seqTy.hasUnknownShape(); 1353 return false; 1354 } 1355 1356 //===----------------------------------------------------------------------===// 1357 // FIROpsDialect 1358 //===----------------------------------------------------------------------===// 1359 1360 void FIROpsDialect::registerTypes() { 1361 addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, ClassType, 1362 FieldType, HeapType, fir::IntegerType, LenType, LogicalType, 1363 LLVMPointerType, PointerType, RecordType, ReferenceType, 1364 SequenceType, ShapeType, ShapeShiftType, ShiftType, SliceType, 1365 TypeDescType, fir::VectorType, fir::DummyScopeType>(); 1366 fir::ReferenceType::attachInterface< 1367 OpenMPPointerLikeModel<fir::ReferenceType>>(*getContext()); 1368 fir::ReferenceType::attachInterface< 1369 OpenACCPointerLikeModel<fir::ReferenceType>>(*getContext()); 1370 1371 fir::PointerType::attachInterface<OpenMPPointerLikeModel<fir::PointerType>>( 1372 *getContext()); 1373 fir::PointerType::attachInterface<OpenACCPointerLikeModel<fir::PointerType>>( 1374 *getContext()); 1375 1376 fir::HeapType::attachInterface<OpenMPPointerLikeModel<fir::HeapType>>( 1377 *getContext()); 1378 fir::HeapType::attachInterface<OpenACCPointerLikeModel<fir::HeapType>>( 1379 *getContext()); 1380 1381 fir::LLVMPointerType::attachInterface< 1382 OpenMPPointerLikeModel<fir::LLVMPointerType>>(*getContext()); 1383 fir::LLVMPointerType::attachInterface< 1384 OpenACCPointerLikeModel<fir::LLVMPointerType>>(*getContext()); 1385 } 1386 1387 std::optional<std::pair<uint64_t, unsigned short>> 1388 fir::getTypeSizeAndAlignment(mlir::Location loc, mlir::Type ty, 1389 const mlir::DataLayout &dl, 1390 const fir::KindMapping &kindMap) { 1391 if (mlir::isa<mlir::IntegerType, mlir::FloatType, mlir::ComplexType>(ty)) { 1392 llvm::TypeSize size = dl.getTypeSize(ty); 1393 unsigned short alignment = dl.getTypeABIAlignment(ty); 1394 return std::pair{size, alignment}; 1395 } 1396 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) { 1397 auto result = getTypeSizeAndAlignment(loc, seqTy.getEleTy(), dl, kindMap); 1398 if (!result) 1399 return result; 1400 auto [eleSize, eleAlign] = *result; 1401 std::uint64_t size = 1402 llvm::alignTo(eleSize, eleAlign) * seqTy.getConstantArraySize(); 1403 return std::pair{size, eleAlign}; 1404 } 1405 if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) { 1406 std::uint64_t size = 0; 1407 unsigned short align = 1; 1408 for (auto component : recTy.getTypeList()) { 1409 auto result = getTypeSizeAndAlignment(loc, component.second, dl, kindMap); 1410 if (!result) 1411 return result; 1412 auto [compSize, compAlign] = *result; 1413 size = 1414 llvm::alignTo(size, compAlign) + llvm::alignTo(compSize, compAlign); 1415 align = std::max(align, compAlign); 1416 } 1417 return std::pair{size, align}; 1418 } 1419 if (auto logical = mlir::dyn_cast<fir::LogicalType>(ty)) { 1420 mlir::Type intTy = mlir::IntegerType::get( 1421 logical.getContext(), kindMap.getLogicalBitsize(logical.getFKind())); 1422 return getTypeSizeAndAlignment(loc, intTy, dl, kindMap); 1423 } 1424 if (auto character = mlir::dyn_cast<fir::CharacterType>(ty)) { 1425 mlir::Type intTy = mlir::IntegerType::get( 1426 character.getContext(), 1427 kindMap.getCharacterBitsize(character.getFKind())); 1428 auto result = getTypeSizeAndAlignment(loc, intTy, dl, kindMap); 1429 if (!result) 1430 return result; 1431 auto [compSize, compAlign] = *result; 1432 if (character.hasConstantLen()) 1433 compSize *= character.getLen(); 1434 return std::pair{compSize, compAlign}; 1435 } 1436 return std::nullopt; 1437 } 1438 1439 std::pair<std::uint64_t, unsigned short> 1440 fir::getTypeSizeAndAlignmentOrCrash(mlir::Location loc, mlir::Type ty, 1441 const mlir::DataLayout &dl, 1442 const fir::KindMapping &kindMap) { 1443 std::optional<std::pair<uint64_t, unsigned short>> result = 1444 getTypeSizeAndAlignment(loc, ty, dl, kindMap); 1445 if (result) 1446 return *result; 1447 TODO(loc, "computing size of a component"); 1448 } 1449