xref: /llvm-project/flang/lib/Optimizer/Dialect/FIRType.cpp (revision af91372b75613d5654e68d393477e8621cb93da7)
1 //===-- FIRType.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/ISO_Fortran_binding_wrapper.h"
15 #include "flang/Optimizer/Builder/Todo.h"
16 #include "flang/Optimizer/Dialect/FIRDialect.h"
17 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
18 #include "flang/Tools/PointerModels.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinDialect.h"
21 #include "mlir/IR/Diagnostics.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "llvm/ADT/SmallPtrSet.h"
24 #include "llvm/ADT/StringSet.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/ErrorHandling.h"
27 
28 #define GET_TYPEDEF_CLASSES
29 #include "flang/Optimizer/Dialect/FIROpsTypes.cpp.inc"
30 
31 using namespace fir;
32 
33 namespace {
34 
35 template <typename TYPE>
36 TYPE parseIntSingleton(mlir::AsmParser &parser) {
37   int kind = 0;
38   if (parser.parseLess() || parser.parseInteger(kind) || parser.parseGreater())
39     return {};
40   return TYPE::get(parser.getContext(), kind);
41 }
42 
43 template <typename TYPE>
44 TYPE parseKindSingleton(mlir::AsmParser &parser) {
45   return parseIntSingleton<TYPE>(parser);
46 }
47 
48 template <typename TYPE>
49 TYPE parseRankSingleton(mlir::AsmParser &parser) {
50   return parseIntSingleton<TYPE>(parser);
51 }
52 
53 template <typename TYPE>
54 TYPE parseTypeSingleton(mlir::AsmParser &parser) {
55   mlir::Type ty;
56   if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater())
57     return {};
58   return TYPE::get(ty);
59 }
60 
61 /// Is `ty` a standard or FIR integer type?
62 static bool isaIntegerType(mlir::Type ty) {
63   // TODO: why aren't we using isa_integer? investigatation required.
64   return mlir::isa<mlir::IntegerType, fir::IntegerType>(ty);
65 }
66 
67 bool verifyRecordMemberType(mlir::Type ty) {
68   return !mlir::isa<BoxCharType, ShapeType, ShapeShiftType, ShiftType,
69                     SliceType, FieldType, LenType, ReferenceType, TypeDescType>(
70       ty);
71 }
72 
73 bool verifySameLists(llvm::ArrayRef<RecordType::TypePair> a1,
74                      llvm::ArrayRef<RecordType::TypePair> a2) {
75   // FIXME: do we need to allow for any variance here?
76   return a1 == a2;
77 }
78 
79 RecordType verifyDerived(mlir::AsmParser &parser, RecordType derivedTy,
80                          llvm::ArrayRef<RecordType::TypePair> lenPList,
81                          llvm::ArrayRef<RecordType::TypePair> typeList) {
82   auto loc = parser.getNameLoc();
83   if (!verifySameLists(derivedTy.getLenParamList(), lenPList) ||
84       !verifySameLists(derivedTy.getTypeList(), typeList)) {
85     parser.emitError(loc, "cannot redefine record type members");
86     return {};
87   }
88   for (auto &p : lenPList)
89     if (!isaIntegerType(p.second)) {
90       parser.emitError(loc, "LEN parameter must be integral type");
91       return {};
92     }
93   for (auto &p : typeList)
94     if (!verifyRecordMemberType(p.second)) {
95       parser.emitError(loc, "field parameter has invalid type");
96       return {};
97     }
98   llvm::StringSet<> uniq;
99   for (auto &p : lenPList)
100     if (!uniq.insert(p.first).second) {
101       parser.emitError(loc, "LEN parameter cannot have duplicate name");
102       return {};
103     }
104   for (auto &p : typeList)
105     if (!uniq.insert(p.first).second) {
106       parser.emitError(loc, "field cannot have duplicate name");
107       return {};
108     }
109   return derivedTy;
110 }
111 
112 } // namespace
113 
114 // Implementation of the thin interface from dialect to type parser
115 
116 mlir::Type fir::parseFirType(FIROpsDialect *dialect,
117                              mlir::DialectAsmParser &parser) {
118   mlir::StringRef typeTag;
119   mlir::Type genType;
120   auto parseResult = generatedTypeParser(parser, &typeTag, genType);
121   if (parseResult.has_value())
122     return genType;
123   parser.emitError(parser.getNameLoc(), "unknown fir type: ") << typeTag;
124   return {};
125 }
126 
127 namespace fir {
128 namespace detail {
129 
130 // Type storage classes
131 
132 /// Derived type storage
133 struct RecordTypeStorage : public mlir::TypeStorage {
134   using KeyTy = llvm::StringRef;
135 
136   static unsigned hashKey(const KeyTy &key) {
137     return llvm::hash_combine(key.str());
138   }
139 
140   bool operator==(const KeyTy &key) const { return key == getName(); }
141 
142   static RecordTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
143                                       const KeyTy &key) {
144     auto *storage = allocator.allocate<RecordTypeStorage>();
145     return new (storage) RecordTypeStorage{key};
146   }
147 
148   llvm::StringRef getName() const { return name; }
149 
150   void setLenParamList(llvm::ArrayRef<RecordType::TypePair> list) {
151     lens = list;
152   }
153   llvm::ArrayRef<RecordType::TypePair> getLenParamList() const { return lens; }
154 
155   void setTypeList(llvm::ArrayRef<RecordType::TypePair> list) { types = list; }
156   llvm::ArrayRef<RecordType::TypePair> getTypeList() const { return types; }
157 
158   bool isFinalized() const { return finalized; }
159   void finalize(llvm::ArrayRef<RecordType::TypePair> lenParamList,
160                 llvm::ArrayRef<RecordType::TypePair> typeList) {
161     if (finalized)
162       return;
163     finalized = true;
164     setLenParamList(lenParamList);
165     setTypeList(typeList);
166   }
167 
168   bool isPacked() const { return packed; }
169   void pack(bool p) { packed = p; }
170 
171 protected:
172   std::string name;
173   bool finalized;
174   bool packed;
175   std::vector<RecordType::TypePair> lens;
176   std::vector<RecordType::TypePair> types;
177 
178 private:
179   RecordTypeStorage() = delete;
180   explicit RecordTypeStorage(llvm::StringRef name)
181       : name{name}, finalized{false}, packed{false} {}
182 };
183 
184 } // namespace detail
185 
/// Returns true when `v` lies in the half-open interval [lb, ub).
template <typename A, typename B>
bool inbounds(A v, B lb, B ub) {
  return (v >= lb) ? (v < ub) : false;
}
190 
191 bool isa_fir_type(mlir::Type t) {
192   return llvm::isa<FIROpsDialect>(t.getDialect());
193 }
194 
195 bool isa_std_type(mlir::Type t) {
196   return llvm::isa<mlir::BuiltinDialect>(t.getDialect());
197 }
198 
199 bool isa_fir_or_std_type(mlir::Type t) {
200   if (auto funcType = mlir::dyn_cast<mlir::FunctionType>(t))
201     return llvm::all_of(funcType.getInputs(), isa_fir_or_std_type) &&
202            llvm::all_of(funcType.getResults(), isa_fir_or_std_type);
203   return isa_fir_type(t) || isa_std_type(t);
204 }
205 
206 mlir::Type getDerivedType(mlir::Type ty) {
207   return llvm::TypeSwitch<mlir::Type, mlir::Type>(ty)
208       .Case<fir::PointerType, fir::HeapType, fir::SequenceType>([](auto p) {
209         if (auto seq = mlir::dyn_cast<fir::SequenceType>(p.getEleTy()))
210           return seq.getEleTy();
211         return p.getEleTy();
212       })
213       .Case<fir::BoxType>([](auto p) { return getDerivedType(p.getEleTy()); })
214       .Default([](mlir::Type t) { return t; });
215 }
216 
217 mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
218   return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
219       .Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
220             fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
221       .Default([](mlir::Type) { return mlir::Type{}; });
222 }
223 
224 mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) {
225   return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
226       .Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
227             fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
228       .Case<fir::BaseBoxType>(
229           [](auto p) { return unwrapRefType(p.getEleTy()); })
230       .Default([](mlir::Type) { return mlir::Type{}; });
231 }
232 
233 static bool hasDynamicSize(fir::RecordType recTy) {
234   for (auto field : recTy.getTypeList()) {
235     if (auto arr = mlir::dyn_cast<fir::SequenceType>(field.second)) {
236       if (sequenceWithNonConstantShape(arr))
237         return true;
238     } else if (characterWithDynamicLen(field.second)) {
239       return true;
240     } else if (auto rec = mlir::dyn_cast<fir::RecordType>(field.second)) {
241       if (hasDynamicSize(rec))
242         return true;
243     }
244   }
245   return false;
246 }
247 
248 bool hasDynamicSize(mlir::Type t) {
249   if (auto arr = mlir::dyn_cast<fir::SequenceType>(t)) {
250     if (sequenceWithNonConstantShape(arr))
251       return true;
252     t = arr.getEleTy();
253   }
254   if (characterWithDynamicLen(t))
255     return true;
256   if (auto rec = mlir::dyn_cast<fir::RecordType>(t))
257     return hasDynamicSize(rec);
258   return false;
259 }
260 
261 mlir::Type extractSequenceType(mlir::Type ty) {
262   if (mlir::isa<fir::SequenceType>(ty))
263     return ty;
264   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
265     return extractSequenceType(boxTy.getEleTy());
266   if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
267     return extractSequenceType(heapTy.getEleTy());
268   if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
269     return extractSequenceType(ptrTy.getEleTy());
270   return mlir::Type{};
271 }
272 
273 bool isPointerType(mlir::Type ty) {
274   if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
275     ty = refTy;
276   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
277     return mlir::isa<fir::PointerType>(boxTy.getEleTy());
278   return false;
279 }
280 
281 bool isAllocatableType(mlir::Type ty) {
282   if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
283     ty = refTy;
284   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
285     return mlir::isa<fir::HeapType>(boxTy.getEleTy());
286   return false;
287 }
288 
289 bool isBoxNone(mlir::Type ty) {
290   if (auto box = mlir::dyn_cast<fir::BoxType>(ty))
291     return mlir::isa<mlir::NoneType>(box.getEleTy());
292   return false;
293 }
294 
295 bool isBoxedRecordType(mlir::Type ty) {
296   if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
297     ty = refTy;
298   if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty)) {
299     if (mlir::isa<fir::RecordType>(boxTy.getEleTy()))
300       return true;
301     mlir::Type innerType = boxTy.unwrapInnerType();
302     return innerType && mlir::isa<fir::RecordType>(innerType);
303   }
304   return false;
305 }
306 
307 bool isScalarBoxedRecordType(mlir::Type ty) {
308   if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
309     ty = refTy;
310   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
311     if (mlir::isa<fir::RecordType>(boxTy.getEleTy()))
312       return true;
313     if (auto heapTy = mlir::dyn_cast<fir::HeapType>(boxTy.getEleTy()))
314       return mlir::isa<fir::RecordType>(heapTy.getEleTy());
315     if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(boxTy.getEleTy()))
316       return mlir::isa<fir::RecordType>(ptrTy.getEleTy());
317   }
318   return false;
319 }
320 
321 bool isAssumedType(mlir::Type ty) {
322   // Rule out CLASS(*) which are `fir.class<[fir.array] none>`.
323   if (mlir::isa<fir::ClassType>(ty))
324     return false;
325   mlir::Type valueType = fir::unwrapPassByRefType(fir::unwrapRefType(ty));
326   // Refuse raw `none` or `fir.array<none>` since assumed type
327   // should be in memory variables.
328   if (valueType == ty)
329     return false;
330   mlir::Type inner = fir::unwrapSequenceType(valueType);
331   return mlir::isa<mlir::NoneType>(inner);
332 }
333 
334 bool isAssumedShape(mlir::Type ty) {
335   if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty))
336     if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy()))
337       return seqTy.hasDynamicExtents();
338   return false;
339 }
340 
341 bool isAllocatableOrPointerArray(mlir::Type ty) {
342   if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
343     ty = refTy;
344   if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty)) {
345     if (auto heapTy = mlir::dyn_cast<fir::HeapType>(boxTy.getEleTy()))
346       return mlir::isa<fir::SequenceType>(heapTy.getEleTy());
347     if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(boxTy.getEleTy()))
348       return mlir::isa<fir::SequenceType>(ptrTy.getEleTy());
349   }
350   return false;
351 }
352 
353 bool isTypeWithDescriptor(mlir::Type ty) {
354   if (mlir::isa<fir::BaseBoxType>(unwrapRefType(ty)))
355     return true;
356   return false;
357 }
358 
359 bool isPolymorphicType(mlir::Type ty) {
360   // CLASS(T) or CLASS(*)
361   if (mlir::isa<fir::ClassType>(fir::unwrapRefType(ty)))
362     return true;
363   // assumed type are polymorphic.
364   return isAssumedType(ty);
365 }
366 
367 bool isUnlimitedPolymorphicType(mlir::Type ty) {
368   // CLASS(*)
369   if (auto clTy = mlir::dyn_cast<fir::ClassType>(fir::unwrapRefType(ty))) {
370     if (mlir::isa<mlir::NoneType>(clTy.getEleTy()))
371       return true;
372     mlir::Type innerType = clTy.unwrapInnerType();
373     return innerType && mlir::isa<mlir::NoneType>(innerType);
374   }
375   // TYPE(*)
376   return isAssumedType(ty);
377 }
378 
379 mlir::Type unwrapInnerType(mlir::Type ty) {
380   return llvm::TypeSwitch<mlir::Type, mlir::Type>(ty)
381       .Case<fir::PointerType, fir::HeapType, fir::SequenceType>([](auto t) {
382         mlir::Type eleTy = t.getEleTy();
383         if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
384           return seqTy.getEleTy();
385         return eleTy;
386       })
387       .Case<fir::RecordType>([](auto t) { return t; })
388       .Default([](mlir::Type) { return mlir::Type{}; });
389 }
390 
391 bool isRecordWithAllocatableMember(mlir::Type ty) {
392   if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty))
393     for (auto [field, memTy] : recTy.getTypeList()) {
394       if (fir::isAllocatableType(memTy))
395         return true;
396       // A record type cannot recursively include itself as a direct member.
397       // There must be an intervening `ptr` type, so recursion is safe here.
398       if (mlir::isa<fir::RecordType>(memTy) &&
399           isRecordWithAllocatableMember(memTy))
400         return true;
401     }
402   return false;
403 }
404 
405 bool isRecordWithDescriptorMember(mlir::Type ty) {
406   ty = unwrapSequenceType(ty);
407   if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty))
408     for (auto [field, memTy] : recTy.getTypeList()) {
409       if (mlir::isa<fir::BaseBoxType>(memTy))
410         return true;
411       if (mlir::isa<fir::RecordType>(memTy) &&
412           isRecordWithDescriptorMember(memTy))
413         return true;
414     }
415   return false;
416 }
417 
418 mlir::Type unwrapAllRefAndSeqType(mlir::Type ty) {
419   while (true) {
420     mlir::Type nt = unwrapSequenceType(unwrapRefType(ty));
421     if (auto vecTy = mlir::dyn_cast<fir::VectorType>(nt))
422       nt = vecTy.getEleTy();
423     if (nt == ty)
424       return ty;
425     ty = nt;
426   }
427 }
428 
429 mlir::Type unwrapSeqOrBoxedSeqType(mlir::Type ty) {
430   if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
431     return seqTy.getEleTy();
432   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
433     auto eleTy = unwrapRefType(boxTy.getEleTy());
434     if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
435       return seqTy.getEleTy();
436   }
437   return ty;
438 }
439 
440 unsigned getBoxRank(mlir::Type boxTy) {
441   auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(boxTy);
442   if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
443     return seqTy.getDimension();
444   return 0;
445 }
446 
/// Return the ISO_Fortran_binding `CFI_type_*` code that describes \p ty.
///
/// \p kindMap supplies the bit sizes of LOGICAL and CHARACTER kinds.
/// Reference-like types map to CFI_type_cptr and record types to
/// CFI_type_struct. Any other type, or a width not listed in the switches
/// below, hits llvm_unreachable.
int getTypeCode(mlir::Type ty, const fir::KindMapping &kindMap) {
  if (mlir::IntegerType intTy = mlir::dyn_cast<mlir::IntegerType>(ty)) {
    if (intTy.isUnsigned()) {
      switch (intTy.getWidth()) {
      case 8:
        return CFI_type_uint8_t;
      case 16:
        return CFI_type_uint16_t;
      case 32:
        return CFI_type_uint32_t;
      case 64:
        return CFI_type_uint64_t;
      case 128:
        return CFI_type_uint128_t;
      }
      llvm_unreachable("unsupported integer type");
    } else {
      // Signed and signless builtin integers both use the signed codes.
      switch (intTy.getWidth()) {
      case 8:
        return CFI_type_int8_t;
      case 16:
        return CFI_type_int16_t;
      case 32:
        return CFI_type_int32_t;
      case 64:
        return CFI_type_int64_t;
      case 128:
        return CFI_type_int128_t;
      }
      llvm_unreachable("unsupported integer type");
    }
  }
  if (fir::LogicalType logicalTy = mlir::dyn_cast<fir::LogicalType>(ty)) {
    // The storage width of a LOGICAL kind comes from the kind mapping; only
    // the 8-bit kind maps to CFI_type_Bool.
    switch (kindMap.getLogicalBitsize(logicalTy.getFKind())) {
    case 8:
      return CFI_type_Bool;
    case 16:
      return CFI_type_int_least16_t;
    case 32:
      return CFI_type_int_least32_t;
    case 64:
      return CFI_type_int_least64_t;
    }
    llvm_unreachable("unsupported logical type");
  }
  if (mlir::FloatType floatTy = mlir::dyn_cast<mlir::FloatType>(ty)) {
    switch (floatTy.getWidth()) {
    case 16:
      // Two distinct 16-bit formats: bfloat16 and IEEE half.
      return floatTy.isBF16() ? CFI_type_bfloat : CFI_type_half_float;
    case 32:
      return CFI_type_float;
    case 64:
      return CFI_type_double;
    case 80:
      return CFI_type_extended_double;
    case 128:
      return CFI_type_float128;
    }
    llvm_unreachable("unsupported real type");
  }
  if (mlir::ComplexType complexTy = mlir::dyn_cast<mlir::ComplexType>(ty)) {
    mlir::FloatType floatTy =
        mlir::cast<mlir::FloatType>(complexTy.getElementType());
    // bf16 is checked before the width switch because it shares width 16 with
    // IEEE half.
    if (floatTy.isBF16())
      return CFI_type_bfloat_Complex;
    switch (floatTy.getWidth()) {
    case 16:
      return CFI_type_half_float_Complex;
    case 32:
      return CFI_type_float_Complex;
    case 64:
      return CFI_type_double_Complex;
    case 80:
      return CFI_type_extended_double_Complex;
    case 128:
      return CFI_type_float128_Complex;
    }
    llvm_unreachable("unsupported complex size");
  }
  if (fir::CharacterType charTy = mlir::dyn_cast<fir::CharacterType>(ty)) {
    // The character width comes from the kind mapping, not from the length.
    switch (kindMap.getCharacterBitsize(charTy.getFKind())) {
    case 8:
      return CFI_type_char;
    case 16:
      return CFI_type_char16_t;
    case 32:
      return CFI_type_char32_t;
    }
    llvm_unreachable("unsupported character type");
  }
  if (fir::isa_ref_type(ty))
    return CFI_type_cptr;
  if (mlir::isa<fir::RecordType>(ty))
    return CFI_type_struct;
  llvm_unreachable("unsupported type");
}
544 
/// Build a compact, name-like string spelling of \p ty, optionally preceded by
/// \p prefix and an underscore.
///
/// Wrappers (sequence, ref, ptr, llvmptr, heap, class, box, boxchar) each
/// append a marker and the loop continues with their element type; leaf types
/// (trivial, none, character, record) append their spelling and end the loop.
/// Examples of the scheme as coded: `i32`, `f64`, `z32`, `l8`, `c8x10`,
/// `Ux3xi32` for an array, `ref_box_heap_...`, `rec_<name>`.
/// Any type not handled below triggers llvm::report_fatal_error.
std::string getTypeAsString(mlir::Type ty, const fir::KindMapping &kindMap,
                            llvm::StringRef prefix) {
  std::string buf = prefix.str();
  // NOTE(review): returning `buf` below assumes raw_string_ostream writes
  // straight into `buf` (unbuffered, as in current LLVM); older LLVM versions
  // would require an explicit flush.
  llvm::raw_string_ostream name{buf};
  if (!prefix.empty())
    name << "_";
  while (ty) {
    if (fir::isa_trivial(ty)) {
      // Index must be tested before isIntOrIndex(), which also accepts it.
      if (mlir::isa<mlir::IndexType>(ty)) {
        name << "idx";
      } else if (ty.isIntOrIndex()) {
        name << 'i' << ty.getIntOrFloatBitWidth();
      } else if (mlir::isa<mlir::FloatType>(ty)) {
        name << 'f' << ty.getIntOrFloatBitWidth();
      } else if (auto cplxTy = mlir::dyn_cast_or_null<mlir::ComplexType>(ty)) {
        // Complex is spelled with the width of one component.
        name << 'z';
        auto floatTy = mlir::cast<mlir::FloatType>(cplxTy.getElementType());
        name << floatTy.getWidth();
      } else if (auto logTy = mlir::dyn_cast_or_null<fir::LogicalType>(ty)) {
        name << 'l' << kindMap.getLogicalBitsize(logTy.getFKind());
      } else {
        // NOTE(review): any other type accepted by fir::isa_trivial (possibly
        // FIR-specific integer or vector types — confirm) ends up here.
        llvm::report_fatal_error("unsupported type");
      }
      break;
    } else if (mlir::isa<mlir::NoneType>(ty)) {
      name << "none";
      break;
    } else if (auto charTy = mlir::dyn_cast_or_null<fir::CharacterType>(ty)) {
      // Character: bit size of the kind, then `xU` for unknown length or
      // `x<len>` for any length other than the singleton length.
      name << 'c' << kindMap.getCharacterBitsize(charTy.getFKind());
      if (charTy.getLen() == fir::CharacterType::unknownLen())
        name << "xU";
      else if (charTy.getLen() != fir::CharacterType::singleton())
        name << "x" << charTy.getLen();
      break;
    } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) {
      // Each extent is followed by 'x'; unknown extents are spelled `U`.
      for (auto extent : seqTy.getShape()) {
        if (extent == fir::SequenceType::getUnknownExtent())
          name << "Ux";
        else
          name << extent << 'x';
      }
      ty = seqTy.getEleTy();
    } else if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(ty)) {
      name << "ref_";
      ty = refTy.getEleTy();
    } else if (auto ptrTy = mlir::dyn_cast_or_null<fir::PointerType>(ty)) {
      name << "ptr_";
      ty = ptrTy.getEleTy();
    } else if (auto ptrTy = mlir::dyn_cast_or_null<fir::LLVMPointerType>(ty)) {
      name << "llvmptr_";
      ty = ptrTy.getEleTy();
    } else if (auto heapTy = mlir::dyn_cast_or_null<fir::HeapType>(ty)) {
      name << "heap_";
      ty = heapTy.getEleTy();
    } else if (auto classTy = mlir::dyn_cast_or_null<fir::ClassType>(ty)) {
      name << "class_";
      ty = classTy.getEleTy();
    } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BoxType>(ty)) {
      name << "box_";
      ty = boxTy.getEleTy();
    } else if (auto boxcharTy = mlir::dyn_cast_or_null<fir::BoxCharType>(ty)) {
      name << "boxchar_";
      ty = boxcharTy.getEleTy();
    } else if (auto recTy = mlir::dyn_cast_or_null<fir::RecordType>(ty)) {
      // Records are spelled by name only; components are not visited.
      name << "rec_" << recTy.getName();
      break;
    } else {
      llvm::report_fatal_error("unsupported type");
    }
  }
  return buf;
}
617 
618 mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType,
619                              bool turnBoxIntoClass) {
620   return llvm::TypeSwitch<mlir::Type, mlir::Type>(type)
621       .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
622         return fir::SequenceType::get(seqTy.getShape(), newElementType);
623       })
624       .Case<fir::PointerType, fir::HeapType, fir::ReferenceType,
625             fir::ClassType>([&](auto t) -> mlir::Type {
626         using FIRT = decltype(t);
627         return FIRT::get(
628             changeElementType(t.getEleTy(), newElementType, turnBoxIntoClass));
629       })
630       .Case<fir::BoxType>([&](fir::BoxType t) -> mlir::Type {
631         mlir::Type newInnerType =
632             changeElementType(t.getEleTy(), newElementType, false);
633         if (turnBoxIntoClass)
634           return fir::ClassType::get(newInnerType);
635         return fir::BoxType::get(newInnerType);
636       })
637       .Default([&](mlir::Type t) -> mlir::Type {
638         assert((fir::isa_trivial(t) || llvm::isa<fir::RecordType>(t) ||
639                 llvm::isa<mlir::NoneType>(t)) &&
640                "unexpected FIR leaf type");
641         return newElementType;
642       });
643 }
644 
645 } // namespace fir
646 
647 namespace {
648 
649 static llvm::SmallPtrSet<detail::RecordTypeStorage const *, 4>
650     recordTypeVisited;
651 
652 } // namespace
653 
654 void fir::verifyIntegralType(mlir::Type type) {
655   if (isaIntegerType(type) || mlir::isa<mlir::IndexType>(type))
656     return;
657   llvm::report_fatal_error("expected integral type");
658 }
659 
660 void fir::printFirType(FIROpsDialect *, mlir::Type ty,
661                        mlir::DialectAsmPrinter &p) {
662   if (mlir::failed(generatedTypePrinter(ty, p)))
663     llvm::report_fatal_error("unknown type to print");
664 }
665 
666 bool fir::isa_unknown_size_box(mlir::Type t) {
667   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(t)) {
668     auto valueType = fir::unwrapPassByRefType(boxTy);
669     if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(valueType))
670       if (seqTy.hasUnknownShape())
671         return true;
672   }
673   return false;
674 }
675 
676 //===----------------------------------------------------------------------===//
677 // BoxProcType
678 //===----------------------------------------------------------------------===//
679 
680 // `boxproc` `<` return-type `>`
681 mlir::Type BoxProcType::parse(mlir::AsmParser &parser) {
682   mlir::Type ty;
683   if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater())
684     return {};
685   return get(parser.getContext(), ty);
686 }
687 
688 void fir::BoxProcType::print(mlir::AsmPrinter &printer) const {
689   printer << "<" << getEleTy() << '>';
690 }
691 
692 llvm::LogicalResult
693 BoxProcType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
694                     mlir::Type eleTy) {
695   if (mlir::isa<mlir::FunctionType>(eleTy))
696     return mlir::success();
697   if (auto refTy = mlir::dyn_cast<ReferenceType>(eleTy))
698     if (mlir::isa<mlir::FunctionType>(refTy))
699       return mlir::success();
700   return emitError() << "invalid type for boxproc" << eleTy << '\n';
701 }
702 
703 static bool cannotBePointerOrHeapElementType(mlir::Type eleTy) {
704   return mlir::isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
705                    SliceType, FieldType, LenType, HeapType, PointerType,
706                    ReferenceType, TypeDescType>(eleTy);
707 }
708 
709 //===----------------------------------------------------------------------===//
710 // BoxType
711 //===----------------------------------------------------------------------===//
712 
713 llvm::LogicalResult
714 fir::BoxType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
715                      mlir::Type eleTy) {
716   if (mlir::isa<fir::BaseBoxType>(eleTy))
717     return emitError() << "invalid element type\n";
718   // TODO
719   return mlir::success();
720 }
721 
722 //===----------------------------------------------------------------------===//
723 // BoxCharType
724 //===----------------------------------------------------------------------===//
725 
726 mlir::Type fir::BoxCharType::parse(mlir::AsmParser &parser) {
727   return parseKindSingleton<fir::BoxCharType>(parser);
728 }
729 
730 void fir::BoxCharType::print(mlir::AsmPrinter &printer) const {
731   printer << "<" << getKind() << ">";
732 }
733 
734 CharacterType
735 fir::BoxCharType::getElementType(mlir::MLIRContext *context) const {
736   return CharacterType::getUnknownLen(context, getKind());
737 }
738 
739 CharacterType fir::BoxCharType::getEleTy() const {
740   return getElementType(getContext());
741 }
742 
743 //===----------------------------------------------------------------------===//
744 // CharacterType
745 //===----------------------------------------------------------------------===//
746 
747 // `char` `<` kind [`,` `len`] `>`
748 mlir::Type fir::CharacterType::parse(mlir::AsmParser &parser) {
749   int kind = 0;
750   if (parser.parseLess() || parser.parseInteger(kind))
751     return {};
752   CharacterType::LenType len = 1;
753   if (mlir::succeeded(parser.parseOptionalComma())) {
754     if (mlir::succeeded(parser.parseOptionalQuestion())) {
755       len = fir::CharacterType::unknownLen();
756     } else if (!mlir::succeeded(parser.parseInteger(len))) {
757       return {};
758     }
759   }
760   if (parser.parseGreater())
761     return {};
762   return get(parser.getContext(), kind, len);
763 }
764 
765 void fir::CharacterType::print(mlir::AsmPrinter &printer) const {
766   printer << "<" << getFKind();
767   auto len = getLen();
768   if (len != fir::CharacterType::singleton()) {
769     printer << ',';
770     if (len == fir::CharacterType::unknownLen())
771       printer << '?';
772     else
773       printer << len;
774   }
775   printer << '>';
776 }
777 
778 //===----------------------------------------------------------------------===//
779 // ClassType
780 //===----------------------------------------------------------------------===//
781 
782 llvm::LogicalResult
783 fir::ClassType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
784                        mlir::Type eleTy) {
785   if (mlir::isa<fir::RecordType, fir::SequenceType, fir::HeapType,
786                 fir::PointerType, mlir::NoneType, mlir::IntegerType,
787                 mlir::FloatType, fir::CharacterType, fir::LogicalType,
788                 mlir::ComplexType>(eleTy))
789     return mlir::success();
790   return emitError() << "invalid element type\n";
791 }
792 
793 //===----------------------------------------------------------------------===//
794 // HeapType
795 //===----------------------------------------------------------------------===//
796 
797 // `heap` `<` type `>`
798 mlir::Type fir::HeapType::parse(mlir::AsmParser &parser) {
799   return parseTypeSingleton<HeapType>(parser);
800 }
801 
802 void fir::HeapType::print(mlir::AsmPrinter &printer) const {
803   printer << "<" << getEleTy() << '>';
804 }
805 
806 llvm::LogicalResult
807 fir::HeapType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
808                       mlir::Type eleTy) {
809   if (cannotBePointerOrHeapElementType(eleTy))
810     return emitError() << "cannot build a heap pointer to type: " << eleTy
811                        << '\n';
812   return mlir::success();
813 }
814 
815 //===----------------------------------------------------------------------===//
816 // IntegerType
817 //===----------------------------------------------------------------------===//
818 
819 // `int` `<` kind `>`
820 mlir::Type fir::IntegerType::parse(mlir::AsmParser &parser) {
821   return parseKindSingleton<fir::IntegerType>(parser);
822 }
823 
824 void fir::IntegerType::print(mlir::AsmPrinter &printer) const {
825   printer << "<" << getFKind() << '>';
826 }
827 
828 //===----------------------------------------------------------------------===//
829 // UnsignedType
830 //===----------------------------------------------------------------------===//
831 
832 // `unsigned` `<` kind `>`
833 mlir::Type fir::UnsignedType::parse(mlir::AsmParser &parser) {
834   return parseKindSingleton<fir::UnsignedType>(parser);
835 }
836 
837 void fir::UnsignedType::print(mlir::AsmPrinter &printer) const {
838   printer << "<" << getFKind() << '>';
839 }
840 
841 //===----------------------------------------------------------------------===//
842 // LogicalType
843 //===----------------------------------------------------------------------===//
844 
845 // `logical` `<` kind `>`
846 mlir::Type fir::LogicalType::parse(mlir::AsmParser &parser) {
847   return parseKindSingleton<fir::LogicalType>(parser);
848 }
849 
850 void fir::LogicalType::print(mlir::AsmPrinter &printer) const {
851   printer << "<" << getFKind() << '>';
852 }
853 
854 //===----------------------------------------------------------------------===//
855 // PointerType
856 //===----------------------------------------------------------------------===//
857 
858 // `ptr` `<` type `>`
859 mlir::Type fir::PointerType::parse(mlir::AsmParser &parser) {
860   return parseTypeSingleton<fir::PointerType>(parser);
861 }
862 
863 void fir::PointerType::print(mlir::AsmPrinter &printer) const {
864   printer << "<" << getEleTy() << '>';
865 }
866 
867 llvm::LogicalResult fir::PointerType::verify(
868     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
869     mlir::Type eleTy) {
870   if (cannotBePointerOrHeapElementType(eleTy))
871     return emitError() << "cannot build a pointer to type: " << eleTy << '\n';
872   return mlir::success();
873 }
874 
875 //===----------------------------------------------------------------------===//
876 // RecordType
877 //===----------------------------------------------------------------------===//
878 
879 // Fortran derived type
880 // unpacked:
881 // `type` `<` name
882 //           (`(` id `:` type (`,` id `:` type)* `)`)?
883 //           (`{` id `:` type (`,` id `:` type)* `}`)? '>'
884 // packed:
885 // `type` `<` name
886 //           (`(` id `:` type (`,` id `:` type)* `)`)?
887 //           (`<{` id `:` type (`,` id `:` type)* `}>`)? '>'
888 mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) {
889   llvm::StringRef name;
890   if (parser.parseLess() || parser.parseKeyword(&name))
891     return {};
892   RecordType result = RecordType::get(parser.getContext(), name);
893 
894   RecordType::TypeList lenParamList;
895   if (!parser.parseOptionalLParen()) {
896     while (true) {
897       llvm::StringRef lenparam;
898       mlir::Type intTy;
899       if (parser.parseKeyword(&lenparam) || parser.parseColon() ||
900           parser.parseType(intTy)) {
901         parser.emitError(parser.getNameLoc(), "expected LEN parameter list");
902         return {};
903       }
904       lenParamList.emplace_back(lenparam, intTy);
905       if (parser.parseOptionalComma())
906         break;
907     }
908     if (parser.parseRParen())
909       return {};
910   }
911 
912   RecordType::TypeList typeList;
913   if (!parser.parseOptionalLess()) {
914     result.pack(true);
915   }
916 
917   if (!parser.parseOptionalLBrace()) {
918     while (true) {
919       llvm::StringRef field;
920       mlir::Type fldTy;
921       if (parser.parseKeyword(&field) || parser.parseColon() ||
922           parser.parseType(fldTy)) {
923         parser.emitError(parser.getNameLoc(), "expected field type list");
924         return {};
925       }
926       typeList.emplace_back(field, fldTy);
927       if (parser.parseOptionalComma())
928         break;
929     }
930     if (parser.parseOptionalGreater()) {
931       if (parser.parseRBrace())
932         return {};
933     }
934   }
935 
936   if (parser.parseGreater())
937     return {};
938 
939   if (lenParamList.empty() && typeList.empty())
940     return result;
941 
942   result.finalize(lenParamList, typeList);
943   return verifyDerived(parser, result, lenParamList, typeList);
944 }
945 
946 void fir::RecordType::print(mlir::AsmPrinter &printer) const {
947   printer << "<" << getName();
948   if (!recordTypeVisited.count(uniqueKey())) {
949     recordTypeVisited.insert(uniqueKey());
950     if (getLenParamList().size()) {
951       char ch = '(';
952       for (auto p : getLenParamList()) {
953         printer << ch << p.first << ':';
954         p.second.print(printer.getStream());
955         ch = ',';
956       }
957       printer << ')';
958     }
959     if (getTypeList().size()) {
960       if (isPacked()) {
961         printer << '<';
962       }
963       char ch = '{';
964       for (auto p : getTypeList()) {
965         printer << ch << p.first << ':';
966         p.second.print(printer.getStream());
967         ch = ',';
968       }
969       printer << '}';
970       if (isPacked()) {
971         printer << '>';
972       }
973     }
974     recordTypeVisited.erase(uniqueKey());
975   }
976   printer << '>';
977 }
978 
979 void fir::RecordType::finalize(llvm::ArrayRef<TypePair> lenPList,
980                                llvm::ArrayRef<TypePair> typeList) {
981   getImpl()->finalize(lenPList, typeList);
982 }
983 
984 llvm::StringRef fir::RecordType::getName() const {
985   return getImpl()->getName();
986 }
987 
988 RecordType::TypeList fir::RecordType::getTypeList() const {
989   return getImpl()->getTypeList();
990 }
991 
992 RecordType::TypeList fir::RecordType::getLenParamList() const {
993   return getImpl()->getLenParamList();
994 }
995 
996 bool fir::RecordType::isFinalized() const { return getImpl()->isFinalized(); }
997 
998 void fir::RecordType::pack(bool p) { getImpl()->pack(p); }
999 
1000 bool fir::RecordType::isPacked() const { return getImpl()->isPacked(); }
1001 
1002 detail::RecordTypeStorage const *fir::RecordType::uniqueKey() const {
1003   return getImpl();
1004 }
1005 
1006 llvm::LogicalResult fir::RecordType::verify(
1007     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1008     llvm::StringRef name) {
1009   if (name.size() == 0)
1010     return emitError() << "record types must have a name";
1011   return mlir::success();
1012 }
1013 
1014 mlir::Type fir::RecordType::getType(llvm::StringRef ident) {
1015   for (auto f : getTypeList())
1016     if (ident == f.first)
1017       return f.second;
1018   return {};
1019 }
1020 
1021 unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
1022   for (auto f : llvm::enumerate(getTypeList()))
1023     if (ident == f.value().first)
1024       return f.index();
1025   return std::numeric_limits<unsigned>::max();
1026 }
1027 
1028 //===----------------------------------------------------------------------===//
1029 // ReferenceType
1030 //===----------------------------------------------------------------------===//
1031 
1032 // `ref` `<` type `>`
1033 mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) {
1034   return parseTypeSingleton<fir::ReferenceType>(parser);
1035 }
1036 
1037 void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
1038   printer << "<" << getEleTy() << '>';
1039 }
1040 
1041 llvm::LogicalResult fir::ReferenceType::verify(
1042     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1043     mlir::Type eleTy) {
1044   if (mlir::isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
1045                 ReferenceType, TypeDescType>(eleTy))
1046     return emitError() << "cannot build a reference to type: " << eleTy << '\n';
1047   return mlir::success();
1048 }
1049 
1050 //===----------------------------------------------------------------------===//
1051 // SequenceType
1052 //===----------------------------------------------------------------------===//
1053 
1054 // `array` `<` `*` | bounds (`x` bounds)* `:` type (',' affine-map)? `>`
1055 // bounds ::= `?` | int-lit
1056 mlir::Type fir::SequenceType::parse(mlir::AsmParser &parser) {
1057   if (parser.parseLess())
1058     return {};
1059   SequenceType::Shape shape;
1060   if (parser.parseOptionalStar()) {
1061     if (parser.parseDimensionList(shape, /*allowDynamic=*/true))
1062       return {};
1063   } else if (parser.parseColon()) {
1064     return {};
1065   }
1066   mlir::Type eleTy;
1067   if (parser.parseType(eleTy))
1068     return {};
1069   mlir::AffineMapAttr map;
1070   if (!parser.parseOptionalComma()) {
1071     if (parser.parseAttribute(map)) {
1072       parser.emitError(parser.getNameLoc(), "expecting affine map");
1073       return {};
1074     }
1075   }
1076   if (parser.parseGreater())
1077     return {};
1078   return SequenceType::get(parser.getContext(), shape, eleTy, map);
1079 }
1080 
1081 void fir::SequenceType::print(mlir::AsmPrinter &printer) const {
1082   auto shape = getShape();
1083   if (shape.size()) {
1084     printer << '<';
1085     for (const auto &b : shape) {
1086       if (b >= 0)
1087         printer << b << 'x';
1088       else
1089         printer << "?x";
1090     }
1091   } else {
1092     printer << "<*:";
1093   }
1094   printer << getEleTy();
1095   if (auto map = getLayoutMap()) {
1096     printer << ", ";
1097     map.print(printer.getStream());
1098   }
1099   printer << '>';
1100 }
1101 
1102 unsigned fir::SequenceType::getConstantRows() const {
1103   if (hasDynamicSize(getEleTy()))
1104     return 0;
1105   auto shape = getShape();
1106   unsigned count = 0;
1107   for (auto d : shape) {
1108     if (d == getUnknownExtent())
1109       break;
1110     ++count;
1111   }
1112   return count;
1113 }
1114 
1115 llvm::LogicalResult fir::SequenceType::verify(
1116     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1117     llvm::ArrayRef<int64_t> shape, mlir::Type eleTy,
1118     mlir::AffineMapAttr layoutMap) {
1119   // DIMENSION attribute can only be applied to an intrinsic or record type
1120   if (mlir::isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
1121                 ShiftType, SliceType, FieldType, LenType, HeapType, PointerType,
1122                 ReferenceType, TypeDescType, SequenceType>(eleTy))
1123     return emitError() << "cannot build an array of this element type: "
1124                        << eleTy << '\n';
1125   return mlir::success();
1126 }
1127 
1128 //===----------------------------------------------------------------------===//
1129 // ShapeType
1130 //===----------------------------------------------------------------------===//
1131 
1132 mlir::Type fir::ShapeType::parse(mlir::AsmParser &parser) {
1133   return parseRankSingleton<fir::ShapeType>(parser);
1134 }
1135 
1136 void fir::ShapeType::print(mlir::AsmPrinter &printer) const {
1137   printer << "<" << getImpl()->rank << ">";
1138 }
1139 
1140 //===----------------------------------------------------------------------===//
1141 // ShapeShiftType
1142 //===----------------------------------------------------------------------===//
1143 
1144 mlir::Type fir::ShapeShiftType::parse(mlir::AsmParser &parser) {
1145   return parseRankSingleton<fir::ShapeShiftType>(parser);
1146 }
1147 
1148 void fir::ShapeShiftType::print(mlir::AsmPrinter &printer) const {
1149   printer << "<" << getRank() << ">";
1150 }
1151 
1152 //===----------------------------------------------------------------------===//
1153 // ShiftType
1154 //===----------------------------------------------------------------------===//
1155 
1156 mlir::Type fir::ShiftType::parse(mlir::AsmParser &parser) {
1157   return parseRankSingleton<fir::ShiftType>(parser);
1158 }
1159 
1160 void fir::ShiftType::print(mlir::AsmPrinter &printer) const {
1161   printer << "<" << getRank() << ">";
1162 }
1163 
1164 //===----------------------------------------------------------------------===//
1165 // SliceType
1166 //===----------------------------------------------------------------------===//
1167 
1168 // `slice` `<` rank `>`
1169 mlir::Type fir::SliceType::parse(mlir::AsmParser &parser) {
1170   return parseRankSingleton<fir::SliceType>(parser);
1171 }
1172 
1173 void fir::SliceType::print(mlir::AsmPrinter &printer) const {
1174   printer << "<" << getRank() << '>';
1175 }
1176 
1177 //===----------------------------------------------------------------------===//
1178 // TypeDescType
1179 //===----------------------------------------------------------------------===//
1180 
1181 // `tdesc` `<` type `>`
1182 mlir::Type fir::TypeDescType::parse(mlir::AsmParser &parser) {
1183   return parseTypeSingleton<fir::TypeDescType>(parser);
1184 }
1185 
1186 void fir::TypeDescType::print(mlir::AsmPrinter &printer) const {
1187   printer << "<" << getOfTy() << '>';
1188 }
1189 
1190 llvm::LogicalResult fir::TypeDescType::verify(
1191     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1192     mlir::Type eleTy) {
1193   if (mlir::isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
1194                 ShiftType, SliceType, FieldType, LenType, ReferenceType,
1195                 TypeDescType>(eleTy))
1196     return emitError() << "cannot build a type descriptor of type: " << eleTy
1197                        << '\n';
1198   return mlir::success();
1199 }
1200 
1201 //===----------------------------------------------------------------------===//
1202 // VectorType
1203 //===----------------------------------------------------------------------===//
1204 
1205 // `vector` `<` len `:` type `>`
1206 mlir::Type fir::VectorType::parse(mlir::AsmParser &parser) {
1207   int64_t len = 0;
1208   mlir::Type eleTy;
1209   if (parser.parseLess() || parser.parseInteger(len) || parser.parseColon() ||
1210       parser.parseType(eleTy) || parser.parseGreater())
1211     return {};
1212   return fir::VectorType::get(len, eleTy);
1213 }
1214 
1215 void fir::VectorType::print(mlir::AsmPrinter &printer) const {
1216   printer << "<" << getLen() << ':' << getEleTy() << '>';
1217 }
1218 
1219 llvm::LogicalResult fir::VectorType::verify(
1220     llvm::function_ref<mlir::InFlightDiagnostic()> emitError, uint64_t len,
1221     mlir::Type eleTy) {
1222   if (!(fir::isa_real(eleTy) || fir::isa_integer(eleTy)))
1223     return emitError() << "cannot build a vector of type " << eleTy << '\n';
1224   return mlir::success();
1225 }
1226 
1227 bool fir::VectorType::isValidElementType(mlir::Type t) {
1228   return isa_real(t) || isa_integer(t);
1229 }
1230 
1231 bool fir::isCharacterProcedureTuple(mlir::Type ty, bool acceptRawFunc) {
1232   mlir::TupleType tuple = mlir::dyn_cast<mlir::TupleType>(ty);
1233   return tuple && tuple.size() == 2 &&
1234          (mlir::isa<fir::BoxProcType>(tuple.getType(0)) ||
1235           (acceptRawFunc && mlir::isa<mlir::FunctionType>(tuple.getType(0)))) &&
1236          fir::isa_integer(tuple.getType(1));
1237 }
1238 
1239 bool fir::hasAbstractResult(mlir::FunctionType ty) {
1240   if (ty.getNumResults() == 0)
1241     return false;
1242   auto resultType = ty.getResult(0);
1243   return mlir::isa<fir::SequenceType, fir::BaseBoxType, fir::RecordType>(
1244       resultType);
1245 }
1246 
1247 /// Convert llvm::Type::TypeID to mlir::Type. \p kind is provided for error
1248 /// messages only.
1249 mlir::Type fir::fromRealTypeID(mlir::MLIRContext *context,
1250                                llvm::Type::TypeID typeID, fir::KindTy kind) {
1251   switch (typeID) {
1252   case llvm::Type::TypeID::HalfTyID:
1253     return mlir::Float16Type::get(context);
1254   case llvm::Type::TypeID::BFloatTyID:
1255     return mlir::BFloat16Type::get(context);
1256   case llvm::Type::TypeID::FloatTyID:
1257     return mlir::Float32Type::get(context);
1258   case llvm::Type::TypeID::DoubleTyID:
1259     return mlir::Float64Type::get(context);
1260   case llvm::Type::TypeID::X86_FP80TyID:
1261     return mlir::Float80Type::get(context);
1262   case llvm::Type::TypeID::FP128TyID:
1263     return mlir::Float128Type::get(context);
1264   default:
1265     mlir::emitError(mlir::UnknownLoc::get(context))
1266         << "unsupported type: !fir.real<" << kind << ">";
1267     return {};
1268   }
1269 }
1270 
1271 //===----------------------------------------------------------------------===//
1272 // BaseBoxType
1273 //===----------------------------------------------------------------------===//
1274 
1275 mlir::Type BaseBoxType::getEleTy() const {
1276   return llvm::TypeSwitch<fir::BaseBoxType, mlir::Type>(*this)
1277       .Case<fir::BoxType, fir::ClassType>(
1278           [](auto type) { return type.getEleTy(); });
1279 }
1280 
1281 mlir::Type BaseBoxType::unwrapInnerType() const {
1282   return fir::unwrapInnerType(getEleTy());
1283 }
1284 
1285 static mlir::Type
1286 changeTypeShape(mlir::Type type,
1287                 std::optional<fir::SequenceType::ShapeRef> newShape) {
1288   return llvm::TypeSwitch<mlir::Type, mlir::Type>(type)
1289       .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
1290         if (newShape)
1291           return fir::SequenceType::get(*newShape, seqTy.getEleTy());
1292         return seqTy.getEleTy();
1293       })
1294       .Case<fir::PointerType, fir::HeapType, fir::ReferenceType, fir::BoxType,
1295             fir::ClassType>([&](auto t) -> mlir::Type {
1296         using FIRT = decltype(t);
1297         return FIRT::get(changeTypeShape(t.getEleTy(), newShape));
1298       })
1299       .Default([&](mlir::Type t) -> mlir::Type {
1300         assert((fir::isa_trivial(t) || llvm::isa<fir::RecordType>(t) ||
1301                 llvm::isa<mlir::NoneType>(t)) &&
1302                "unexpected FIR leaf type");
1303         if (newShape)
1304           return fir::SequenceType::get(*newShape, t);
1305         return t;
1306       });
1307 }
1308 
1309 fir::BaseBoxType
1310 fir::BaseBoxType::getBoxTypeWithNewShape(mlir::Type shapeMold) const {
1311   fir::SequenceType seqTy = fir::unwrapUntilSeqType(shapeMold);
1312   std::optional<fir::SequenceType::ShapeRef> newShape;
1313   if (seqTy)
1314     newShape = seqTy.getShape();
1315   return mlir::cast<fir::BaseBoxType>(changeTypeShape(*this, newShape));
1316 }
1317 
1318 fir::BaseBoxType fir::BaseBoxType::getBoxTypeWithNewShape(int rank) const {
1319   std::optional<fir::SequenceType::ShapeRef> newShape;
1320   fir::SequenceType::Shape shapeVector;
1321   if (rank > 0) {
1322     shapeVector =
1323         fir::SequenceType::Shape(rank, fir::SequenceType::getUnknownExtent());
1324     newShape = shapeVector;
1325   }
1326   return mlir::cast<fir::BaseBoxType>(changeTypeShape(*this, newShape));
1327 }
1328 
1329 fir::BaseBoxType fir::BaseBoxType::getBoxTypeWithNewAttr(
1330     fir::BaseBoxType::Attribute attr) const {
1331   mlir::Type baseType = fir::unwrapRefType(getEleTy());
1332   switch (attr) {
1333   case fir::BaseBoxType::Attribute::None:
1334     break;
1335   case fir::BaseBoxType::Attribute::Allocatable:
1336     baseType = fir::HeapType::get(baseType);
1337     break;
1338   case fir::BaseBoxType::Attribute::Pointer:
1339     baseType = fir::PointerType::get(baseType);
1340     break;
1341   }
1342   return llvm::TypeSwitch<fir::BaseBoxType, fir::BaseBoxType>(*this)
1343       .Case<fir::BoxType>(
1344           [baseType](auto) { return fir::BoxType::get(baseType); })
1345       .Case<fir::ClassType>(
1346           [baseType](auto) { return fir::ClassType::get(baseType); });
1347 }
1348 
1349 bool fir::BaseBoxType::isAssumedRank() const {
1350   if (auto seqTy =
1351           mlir::dyn_cast<fir::SequenceType>(fir::unwrapRefType(getEleTy())))
1352     return seqTy.hasUnknownShape();
1353   return false;
1354 }
1355 
1356 //===----------------------------------------------------------------------===//
1357 // FIROpsDialect
1358 //===----------------------------------------------------------------------===//
1359 
1360 void FIROpsDialect::registerTypes() {
1361   addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, ClassType,
1362            FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
1363            LLVMPointerType, PointerType, RecordType, ReferenceType,
1364            SequenceType, ShapeType, ShapeShiftType, ShiftType, SliceType,
1365            TypeDescType, fir::VectorType, fir::DummyScopeType>();
1366   fir::ReferenceType::attachInterface<
1367       OpenMPPointerLikeModel<fir::ReferenceType>>(*getContext());
1368   fir::ReferenceType::attachInterface<
1369       OpenACCPointerLikeModel<fir::ReferenceType>>(*getContext());
1370 
1371   fir::PointerType::attachInterface<OpenMPPointerLikeModel<fir::PointerType>>(
1372       *getContext());
1373   fir::PointerType::attachInterface<OpenACCPointerLikeModel<fir::PointerType>>(
1374       *getContext());
1375 
1376   fir::HeapType::attachInterface<OpenMPPointerLikeModel<fir::HeapType>>(
1377       *getContext());
1378   fir::HeapType::attachInterface<OpenACCPointerLikeModel<fir::HeapType>>(
1379       *getContext());
1380 
1381   fir::LLVMPointerType::attachInterface<
1382       OpenMPPointerLikeModel<fir::LLVMPointerType>>(*getContext());
1383   fir::LLVMPointerType::attachInterface<
1384       OpenACCPointerLikeModel<fir::LLVMPointerType>>(*getContext());
1385 }
1386 
1387 std::optional<std::pair<uint64_t, unsigned short>>
1388 fir::getTypeSizeAndAlignment(mlir::Location loc, mlir::Type ty,
1389                              const mlir::DataLayout &dl,
1390                              const fir::KindMapping &kindMap) {
1391   if (mlir::isa<mlir::IntegerType, mlir::FloatType, mlir::ComplexType>(ty)) {
1392     llvm::TypeSize size = dl.getTypeSize(ty);
1393     unsigned short alignment = dl.getTypeABIAlignment(ty);
1394     return std::pair{size, alignment};
1395   }
1396   if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
1397     auto result = getTypeSizeAndAlignment(loc, seqTy.getEleTy(), dl, kindMap);
1398     if (!result)
1399       return result;
1400     auto [eleSize, eleAlign] = *result;
1401     std::uint64_t size =
1402         llvm::alignTo(eleSize, eleAlign) * seqTy.getConstantArraySize();
1403     return std::pair{size, eleAlign};
1404   }
1405   if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) {
1406     std::uint64_t size = 0;
1407     unsigned short align = 1;
1408     for (auto component : recTy.getTypeList()) {
1409       auto result = getTypeSizeAndAlignment(loc, component.second, dl, kindMap);
1410       if (!result)
1411         return result;
1412       auto [compSize, compAlign] = *result;
1413       size =
1414           llvm::alignTo(size, compAlign) + llvm::alignTo(compSize, compAlign);
1415       align = std::max(align, compAlign);
1416     }
1417     return std::pair{size, align};
1418   }
1419   if (auto logical = mlir::dyn_cast<fir::LogicalType>(ty)) {
1420     mlir::Type intTy = mlir::IntegerType::get(
1421         logical.getContext(), kindMap.getLogicalBitsize(logical.getFKind()));
1422     return getTypeSizeAndAlignment(loc, intTy, dl, kindMap);
1423   }
1424   if (auto character = mlir::dyn_cast<fir::CharacterType>(ty)) {
1425     mlir::Type intTy = mlir::IntegerType::get(
1426         character.getContext(),
1427         kindMap.getCharacterBitsize(character.getFKind()));
1428     auto result = getTypeSizeAndAlignment(loc, intTy, dl, kindMap);
1429     if (!result)
1430       return result;
1431     auto [compSize, compAlign] = *result;
1432     if (character.hasConstantLen())
1433       compSize *= character.getLen();
1434     return std::pair{compSize, compAlign};
1435   }
1436   return std::nullopt;
1437 }
1438 
1439 std::pair<std::uint64_t, unsigned short>
1440 fir::getTypeSizeAndAlignmentOrCrash(mlir::Location loc, mlir::Type ty,
1441                                     const mlir::DataLayout &dl,
1442                                     const fir::KindMapping &kindMap) {
1443   std::optional<std::pair<uint64_t, unsigned short>> result =
1444       getTypeSizeAndAlignment(loc, ty, dl, kindMap);
1445   if (result)
1446     return *result;
1447   TODO(loc, "computing size of a component");
1448 }
1449