xref: /llvm-project/flang/lib/Lower/HostAssociations.cpp (revision 4abbf99579633d70bdecb9876cbed319ce9f546a)
1 //===-- HostAssociations.cpp ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Lower/HostAssociations.h"
10 #include "flang/Evaluate/check-expression.h"
11 #include "flang/Lower/AbstractConverter.h"
12 #include "flang/Lower/Allocatable.h"
13 #include "flang/Lower/BoxAnalyzer.h"
14 #include "flang/Lower/CallInterface.h"
15 #include "flang/Lower/ConvertType.h"
16 #include "flang/Lower/ConvertVariable.h"
17 #include "flang/Lower/OpenMP.h"
18 #include "flang/Lower/PFTBuilder.h"
19 #include "flang/Lower/SymbolMap.h"
20 #include "flang/Optimizer/Builder/Character.h"
21 #include "flang/Optimizer/Builder/FIRBuilder.h"
22 #include "flang/Optimizer/Builder/Todo.h"
23 #include "flang/Optimizer/Support/FatalError.h"
24 #include "flang/Semantics/tools.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
27 #include <optional>
28 
29 #define DEBUG_TYPE "flang-host-assoc"
30 
// Host association inside internal procedures is implemented by allocating an
// MLIR tuple (a struct) inside the host containing the addresses and properties
// of variables that are accessed by internal procedures. The address of this
// tuple is passed as an argument by the host when calling internal procedures.
// Internal procedures propagate a reference to this tuple when calling other
// internal procedures of the host.
//
// This file defines how the type of the host tuple is built, how the tuple
// value is created inside the host, and how the host associated variables are
// instantiated inside the internal procedures from the tuple value. The
// CapturedXXX classes define each of these three actions for a specific
// kind of variable by providing a `getType`, an `instantiateHostTuple`, and a
// `getFromTuple` method. These classes are structured as follows:
//
//   class CapturedKindOfVar : public CapturedSymbols<CapturedKindOfVar> {
//     // Return the type of the tuple element for a host associated
//     // variable given its symbol inside the host. This is called when
//     // building function interfaces.
//     static mlir::Type getType();
//     // Build the tuple element value for a host associated variable given its
//     // value inside the host. This is called when lowering the host body.
//     static void instantiateHostTuple();
//     // Instantiate a host variable inside an internal procedure given its
//     // tuple element value. This is called when lowering internal procedure
//     // bodies.
//     static void getFromTuple();
//   };
//
// If a new kind of variable requires ad-hoc handling, a new CapturedXXX class
// should be added to handle it, and `walkCaptureCategories` should be updated
// to dispatch this new kind of variable to this new class.
62 
63 /// Is \p sym a derived type entity with length parameters ?
isDerivedWithLenParameters(const Fortran::semantics::Symbol & sym)64 static bool isDerivedWithLenParameters(const Fortran::semantics::Symbol &sym) {
65   if (const auto *declTy = sym.GetType())
66     if (const auto *derived = declTy->AsDerived())
67       return Fortran::semantics::CountLenParameters(*derived) != 0;
68   return false;
69 }
70 
71 /// Map the extracted fir::ExtendedValue for a host associated variable inside
72 /// and internal procedure to its symbol. Generates an hlfir.declare in HLFIR.
bindCapturedSymbol(const Fortran::semantics::Symbol & sym,fir::ExtendedValue val,Fortran::lower::AbstractConverter & converter,Fortran::lower::SymMap & symMap)73 static void bindCapturedSymbol(const Fortran::semantics::Symbol &sym,
74                                fir::ExtendedValue val,
75                                Fortran::lower::AbstractConverter &converter,
76                                Fortran::lower::SymMap &symMap) {
77   if (converter.getLoweringOptions().getLowerToHighLevelFIR())
78     Fortran::lower::genDeclareSymbol(converter, symMap, sym, val,
79                                      fir::FortranVariableFlagsEnum::host_assoc);
80   else
81     symMap.addSymbol(sym, val);
82 }
83 
84 namespace {
85 /// Struct to be used as argument in walkCaptureCategories when building the
86 /// tuple element type for a host associated variable.
87 struct GetTypeInTuple {
88   /// walkCaptureCategories must return a type.
89   using Result = mlir::Type;
90 };
91 
92 /// Struct to be used as argument in walkCaptureCategories when building the
93 /// tuple element value for a host associated variable.
94 struct InstantiateHostTuple {
95   /// walkCaptureCategories returns nothing.
96   using Result = void;
97   /// Value of the variable inside the host procedure.
98   fir::ExtendedValue hostValue;
99   /// Address of the tuple element of the variable.
100   mlir::Value addrInTuple;
101   mlir::Location loc;
102 };
103 
104 /// Struct to be used as argument in walkCaptureCategories when instantiating a
105 /// host associated variables from its tuple element value.
106 struct GetFromTuple {
107   /// walkCaptureCategories returns nothing.
108   using Result = void;
109   /// Symbol map inside the internal procedure.
110   Fortran::lower::SymMap &symMap;
111   /// Value of the tuple element for the host associated variable.
112   mlir::Value valueInTuple;
113   mlir::Location loc;
114 };
115 
116 /// Base class that must be inherited with CRTP by classes defining
117 /// how host association is implemented for a type of symbol.
118 /// It simply dispatches visit() calls to the implementations according
119 /// to the argument type.
120 template <typename SymbolCategory>
121 class CapturedSymbols {
122 public:
123   template <typename T>
visit(const T &,Fortran::lower::AbstractConverter &,const Fortran::semantics::Symbol &,const Fortran::lower::BoxAnalyzer &)124   static void visit(const T &, Fortran::lower::AbstractConverter &,
125                     const Fortran::semantics::Symbol &,
126                     const Fortran::lower::BoxAnalyzer &) {
127     static_assert(!std::is_same_v<T, T> &&
128                   "default visit must not be instantiated");
129   }
visit(const GetTypeInTuple &,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym,const Fortran::lower::BoxAnalyzer &)130   static mlir::Type visit(const GetTypeInTuple &,
131                           Fortran::lower::AbstractConverter &converter,
132                           const Fortran::semantics::Symbol &sym,
133                           const Fortran::lower::BoxAnalyzer &) {
134     return SymbolCategory::getType(converter, sym);
135   }
visit(const InstantiateHostTuple & args,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym,const Fortran::lower::BoxAnalyzer &)136   static void visit(const InstantiateHostTuple &args,
137                     Fortran::lower::AbstractConverter &converter,
138                     const Fortran::semantics::Symbol &sym,
139                     const Fortran::lower::BoxAnalyzer &) {
140     return SymbolCategory::instantiateHostTuple(args, converter, sym);
141   }
visit(const GetFromTuple & args,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym,const Fortran::lower::BoxAnalyzer & ba)142   static void visit(const GetFromTuple &args,
143                     Fortran::lower::AbstractConverter &converter,
144                     const Fortran::semantics::Symbol &sym,
145                     const Fortran::lower::BoxAnalyzer &ba) {
146     return SymbolCategory::getFromTuple(args, converter, sym, ba);
147   }
148 };
149 
150 /// Class defining simple scalars are captured in internal procedures.
151 /// Simple scalars are non character intrinsic scalars. They are captured
152 /// as `!fir.ref<T>`, for example `!fir.ref<i32>` for `INTEGER*4`.
153 class CapturedSimpleScalars : public CapturedSymbols<CapturedSimpleScalars> {
154 public:
getType(Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym)155   static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
156                             const Fortran::semantics::Symbol &sym) {
157     return fir::ReferenceType::get(converter.genType(sym));
158   }
159 
instantiateHostTuple(const InstantiateHostTuple & args,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol &)160   static void instantiateHostTuple(const InstantiateHostTuple &args,
161                                    Fortran::lower::AbstractConverter &converter,
162                                    const Fortran::semantics::Symbol &) {
163     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
164     mlir::Type typeInTuple = fir::dyn_cast_ptrEleTy(args.addrInTuple.getType());
165     assert(typeInTuple && "addrInTuple must be an address");
166     mlir::Value castBox = builder.createConvert(args.loc, typeInTuple,
167                                                 fir::getBase(args.hostValue));
168     builder.create<fir::StoreOp>(args.loc, castBox, args.addrInTuple);
169   }
170 
getFromTuple(const GetFromTuple & args,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym,const Fortran::lower::BoxAnalyzer &)171   static void getFromTuple(const GetFromTuple &args,
172                            Fortran::lower::AbstractConverter &converter,
173                            const Fortran::semantics::Symbol &sym,
174                            const Fortran::lower::BoxAnalyzer &) {
175     bindCapturedSymbol(sym, args.valueInTuple, converter, args.symMap);
176   }
177 };
178 
179 /// Class defining how dummy procedures and procedure pointers
180 /// are captured in internal procedures.
181 class CapturedProcedure : public CapturedSymbols<CapturedProcedure> {
182 public:
getType(Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym)183   static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
184                             const Fortran::semantics::Symbol &sym) {
185     mlir::Type funTy = Fortran::lower::getDummyProcedureType(sym, converter);
186     if (Fortran::semantics::IsPointer(sym))
187       return fir::ReferenceType::get(funTy);
188     return funTy;
189   }
190 
instantiateHostTuple(const InstantiateHostTuple & args,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol &)191   static void instantiateHostTuple(const InstantiateHostTuple &args,
192                                    Fortran::lower::AbstractConverter &converter,
193                                    const Fortran::semantics::Symbol &) {
194     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
195     mlir::Type typeInTuple = fir::dyn_cast_ptrEleTy(args.addrInTuple.getType());
196     assert(typeInTuple && "addrInTuple must be an address");
197     mlir::Value castBox = builder.createConvert(args.loc, typeInTuple,
198                                                 fir::getBase(args.hostValue));
199     builder.create<fir::StoreOp>(args.loc, castBox, args.addrInTuple);
200   }
201 
getFromTuple(const GetFromTuple & args,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym,const Fortran::lower::BoxAnalyzer &)202   static void getFromTuple(const GetFromTuple &args,
203                            Fortran::lower::AbstractConverter &converter,
204                            const Fortran::semantics::Symbol &sym,
205                            const Fortran::lower::BoxAnalyzer &) {
206     bindCapturedSymbol(sym, args.valueInTuple, converter, args.symMap);
207   }
208 };
209 
210 /// Class defining how character scalars are captured in internal procedures.
211 /// Character scalars are passed as !fir.boxchar<kind> in the tuple.
212 class CapturedCharacterScalars
213     : public CapturedSymbols<CapturedCharacterScalars> {
214 public:
215   // Note: so far, do not specialize constant length characters. They can be
216   // implemented by only passing the address. This could be done later in
217   // lowering or a CapturedStaticLenCharacterScalars class could be added here.
218 
getType(Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym)219   static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
220                             const Fortran::semantics::Symbol &sym) {
221     fir::KindTy kind =
222         mlir::cast<fir::CharacterType>(converter.genType(sym)).getFKind();
223     return fir::BoxCharType::get(&converter.getMLIRContext(), kind);
224   }
225 
instantiateHostTuple(const InstantiateHostTuple & args,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol &)226   static void instantiateHostTuple(const InstantiateHostTuple &args,
227                                    Fortran::lower::AbstractConverter &converter,
228                                    const Fortran::semantics::Symbol &) {
229     const fir::CharBoxValue *charBox = args.hostValue.getCharBox();
230     assert(charBox && "host value must be a fir::CharBoxValue");
231     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
232     mlir::Value boxchar = fir::factory::CharacterExprHelper(builder, args.loc)
233                               .createEmbox(*charBox);
234     builder.create<fir::StoreOp>(args.loc, boxchar, args.addrInTuple);
235   }
236 
getFromTuple(const GetFromTuple & args,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym,const Fortran::lower::BoxAnalyzer &)237   static void getFromTuple(const GetFromTuple &args,
238                            Fortran::lower::AbstractConverter &converter,
239                            const Fortran::semantics::Symbol &sym,
240                            const Fortran::lower::BoxAnalyzer &) {
241     fir::factory::CharacterExprHelper charHelp(converter.getFirOpBuilder(),
242                                                args.loc);
243     std::pair<mlir::Value, mlir::Value> unboxchar =
244         charHelp.createUnboxChar(args.valueInTuple);
245     bindCapturedSymbol(sym,
246                        fir::CharBoxValue{unboxchar.first, unboxchar.second},
247                        converter, args.symMap);
248   }
249 };
250 
251 /// Class defining how polymorphic scalar entities are captured in internal
252 /// procedures. Polymorphic entities are always boxed as a fir.class box.
253 /// Polymorphic array can be handled in CapturedArrays directly
254 class CapturedPolymorphicScalar
255     : public CapturedSymbols<CapturedPolymorphicScalar> {
256 public:
getType(Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym)257   static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
258                             const Fortran::semantics::Symbol &sym) {
259     return converter.genType(sym);
260   }
instantiateHostTuple(const InstantiateHostTuple & args,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym)261   static void instantiateHostTuple(const InstantiateHostTuple &args,
262                                    Fortran::lower::AbstractConverter &converter,
263                                    const Fortran::semantics::Symbol &sym) {
264     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
265     mlir::Location loc = args.loc;
266     mlir::Type typeInTuple = fir::dyn_cast_ptrEleTy(args.addrInTuple.getType());
267     assert(typeInTuple && "addrInTuple must be an address");
268     mlir::Value castBox = builder.createConvert(args.loc, typeInTuple,
269                                                 fir::getBase(args.hostValue));
270     if (Fortran::semantics::IsOptional(sym)) {
271       auto isPresent =
272           builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), castBox);
273       builder.genIfThenElse(loc, isPresent)
274           .genThen([&]() {
275             builder.create<fir::StoreOp>(loc, castBox, args.addrInTuple);
276           })
277           .genElse([&]() {
278             mlir::Value null = fir::factory::createUnallocatedBox(
279                 builder, loc, typeInTuple,
280                 /*nonDeferredParams=*/mlir::ValueRange{});
281             builder.create<fir::StoreOp>(loc, null, args.addrInTuple);
282           })
283           .end();
284     } else {
285       builder.create<fir::StoreOp>(loc, castBox, args.addrInTuple);
286     }
287   }
getFromTuple(const GetFromTuple & args,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym,const Fortran::lower::BoxAnalyzer & ba)288   static void getFromTuple(const GetFromTuple &args,
289                            Fortran::lower::AbstractConverter &converter,
290                            const Fortran::semantics::Symbol &sym,
291                            const Fortran::lower::BoxAnalyzer &ba) {
292     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
293     mlir::Location loc = args.loc;
294     mlir::Value box = args.valueInTuple;
295     if (Fortran::semantics::IsOptional(sym)) {
296       auto boxTy = mlir::cast<fir::BaseBoxType>(box.getType());
297       auto eleTy = boxTy.getEleTy();
298       if (!fir::isa_ref_type(eleTy))
299         eleTy = builder.getRefType(eleTy);
300       auto addr = builder.create<fir::BoxAddrOp>(loc, eleTy, box);
301       mlir::Value isPresent = builder.genIsNotNullAddr(loc, addr);
302       auto absentBox = builder.create<fir::AbsentOp>(loc, boxTy);
303       box =
304           builder.create<mlir::arith::SelectOp>(loc, isPresent, box, absentBox);
305     }
306     bindCapturedSymbol(sym, box, converter, args.symMap);
307   }
308 };
309 
310 /// Class defining how allocatable and pointers entities are captured in
311 /// internal procedures. Allocatable and pointers are simply captured by placing
312 /// their !fir.ref<fir.box<>> address in the host tuple.
313 class CapturedAllocatableAndPointer
314     : public CapturedSymbols<CapturedAllocatableAndPointer> {
315 public:
getType(Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym)316   static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
317                             const Fortran::semantics::Symbol &sym) {
318     mlir::Type baseType = converter.genType(sym);
319     if (sym.GetUltimate().test(Fortran::semantics::Symbol::Flag::CrayPointee))
320       return fir::ReferenceType::get(
321           Fortran::lower::getCrayPointeeBoxType(baseType));
322     return fir::ReferenceType::get(baseType);
323   }
instantiateHostTuple(const InstantiateHostTuple & args,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol &)324   static void instantiateHostTuple(const InstantiateHostTuple &args,
325                                    Fortran::lower::AbstractConverter &converter,
326                                    const Fortran::semantics::Symbol &) {
327     assert(args.hostValue.getBoxOf<fir::MutableBoxValue>() &&
328            "host value must be a fir::MutableBoxValue");
329     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
330     mlir::Type typeInTuple = fir::dyn_cast_ptrEleTy(args.addrInTuple.getType());
331     assert(typeInTuple && "addrInTuple must be an address");
332     mlir::Value castBox = builder.createConvert(args.loc, typeInTuple,
333                                                 fir::getBase(args.hostValue));
334     builder.create<fir::StoreOp>(args.loc, castBox, args.addrInTuple);
335   }
getFromTuple(const GetFromTuple & args,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym,const Fortran::lower::BoxAnalyzer & ba)336   static void getFromTuple(const GetFromTuple &args,
337                            Fortran::lower::AbstractConverter &converter,
338                            const Fortran::semantics::Symbol &sym,
339                            const Fortran::lower::BoxAnalyzer &ba) {
340     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
341     mlir::Location loc = args.loc;
342     // Non deferred type parameters impact the semantics of some statements
343     // where allocatables/pointer can appear. For instance, assignment to a
344     // scalar character allocatable with has a different semantics in F2003 and
345     // later if the length is non deferred vs when it is deferred. So it is
346     // important to keep track of the non deferred parameters here.
347     llvm::SmallVector<mlir::Value> nonDeferredLenParams;
348     if (ba.isChar()) {
349       mlir::IndexType idxTy = builder.getIndexType();
350       if (std::optional<int64_t> len = ba.getCharLenConst()) {
351         nonDeferredLenParams.push_back(
352             builder.createIntegerConstant(loc, idxTy, *len));
353       } else if (Fortran::semantics::IsAssumedLengthCharacter(sym) ||
354                  ba.getCharLenExpr()) {
355         nonDeferredLenParams.push_back(
356             Fortran::lower::getAssumedCharAllocatableOrPointerLen(
357                 builder, loc, sym, args.valueInTuple));
358       }
359     } else if (isDerivedWithLenParameters(sym)) {
360       TODO(loc, "host associated derived type allocatable or pointer with "
361                 "length parameters");
362     }
363     bindCapturedSymbol(
364         sym, fir::MutableBoxValue(args.valueInTuple, nonDeferredLenParams, {}),
365         converter, args.symMap);
366   }
367 };
368 
369 /// Class defining how arrays, including assumed-ranks, are captured inside
370 /// internal procedures.
371 /// Array are captured via a `fir.box<fir.array<T>>` descriptor that belongs to
372 /// the host tuple. This allows capturing lower bounds, which can be done by
373 /// providing a ShapeShiftOp argument to the EmboxOp.
374 class CapturedArrays : public CapturedSymbols<CapturedArrays> {
375 
376   // Note: Constant shape arrays are not specialized (their base address would
377   // be sufficient information inside the tuple). They could be specialized in
378   // a later FIR pass, or a CapturedStaticShapeArrays could be added to deal
379   // with them here.
380 public:
getType(Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym)381   static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
382                             const Fortran::semantics::Symbol &sym) {
383     mlir::Type type = converter.genType(sym);
384     bool isPolymorphic = Fortran::semantics::IsPolymorphic(sym);
385     assert((mlir::isa<fir::SequenceType>(type) ||
386             (isPolymorphic && mlir::isa<fir::ClassType>(type))) &&
387            "must be a sequence type");
388     if (isPolymorphic)
389       return type;
390     return fir::BoxType::get(type);
391   }
392 
instantiateHostTuple(const InstantiateHostTuple & args,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym)393   static void instantiateHostTuple(const InstantiateHostTuple &args,
394                                    Fortran::lower::AbstractConverter &converter,
395                                    const Fortran::semantics::Symbol &sym) {
396     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
397     mlir::Location loc = args.loc;
398     fir::MutableBoxValue boxInTuple(args.addrInTuple, {}, {});
399     if (args.hostValue.getBoxOf<fir::BoxValue>() &&
400         Fortran::semantics::IsOptional(sym)) {
401       // The assumed shape optional case need some care because it is illegal to
402       // read the incoming box if it is absent (this would cause segfaults).
403       // Pointer association requires reading the target box, so it can only be
404       // done on present optional. For absent optionals, simply create a
405       // disassociated pointer (it is illegal to inquire about lower bounds or
406       // lengths of optional according to 15.5.2.12 3 (9) and 10.1.11 2 (7)b).
407       auto isPresent = builder.create<fir::IsPresentOp>(
408           loc, builder.getI1Type(), fir::getBase(args.hostValue));
409       builder.genIfThenElse(loc, isPresent)
410           .genThen([&]() {
411             fir::factory::associateMutableBox(builder, loc, boxInTuple,
412                                               args.hostValue,
413                                               /*lbounds=*/std::nullopt);
414           })
415           .genElse([&]() {
416             fir::factory::disassociateMutableBox(builder, loc, boxInTuple);
417           })
418           .end();
419     } else {
420       fir::factory::associateMutableBox(
421           builder, loc, boxInTuple, args.hostValue, /*lbounds=*/std::nullopt);
422     }
423   }
424 
getFromTuple(const GetFromTuple & args,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym,const Fortran::lower::BoxAnalyzer & ba)425   static void getFromTuple(const GetFromTuple &args,
426                            Fortran::lower::AbstractConverter &converter,
427                            const Fortran::semantics::Symbol &sym,
428                            const Fortran::lower::BoxAnalyzer &ba) {
429     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
430     mlir::Location loc = args.loc;
431     mlir::Value box = args.valueInTuple;
432     mlir::IndexType idxTy = builder.getIndexType();
433     llvm::SmallVector<mlir::Value> lbounds;
434     if (!ba.lboundIsAllOnes() && !Fortran::evaluate::IsAssumedRank(sym)) {
435       if (ba.isStaticArray()) {
436         for (std::int64_t lb : ba.staticLBound())
437           lbounds.emplace_back(builder.createIntegerConstant(loc, idxTy, lb));
438       } else {
439         // Cannot re-evaluate specification expressions here.
440         // Operands values may have changed. Get value from fir.box
441         const unsigned rank = sym.Rank();
442         for (unsigned dim = 0; dim < rank; ++dim) {
443           mlir::Value dimVal = builder.createIntegerConstant(loc, idxTy, dim);
444           auto dims = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy,
445                                                      box, dimVal);
446           lbounds.emplace_back(dims.getResult(0));
447         }
448       }
449     }
450 
451     if (canReadCapturedBoxValue(converter, sym)) {
452       fir::BoxValue boxValue(box, lbounds, /*explicitParams=*/std::nullopt);
453       bindCapturedSymbol(sym,
454                          fir::factory::readBoxValue(builder, loc, boxValue),
455                          converter, args.symMap);
456     } else {
457       // Keep variable as a fir.box/fir.class.
458       // If this is an optional that is absent, the fir.box needs to be an
459       // AbsentOp result, otherwise it will not work properly with IsPresentOp
460       // (absent boxes are null descriptor addresses, not descriptors containing
461       // a null base address).
462       if (Fortran::semantics::IsOptional(sym)) {
463         auto boxTy = mlir::cast<fir::BaseBoxType>(box.getType());
464         auto eleTy = boxTy.getEleTy();
465         if (!fir::isa_ref_type(eleTy))
466           eleTy = builder.getRefType(eleTy);
467         auto addr = builder.create<fir::BoxAddrOp>(loc, eleTy, box);
468         mlir::Value isPresent = builder.genIsNotNullAddr(loc, addr);
469         auto absentBox = builder.create<fir::AbsentOp>(loc, boxTy);
470         box = builder.create<mlir::arith::SelectOp>(loc, isPresent, box,
471                                                     absentBox);
472       }
473       fir::BoxValue boxValue(box, lbounds, /*explicitParams=*/std::nullopt);
474       bindCapturedSymbol(sym, boxValue, converter, args.symMap);
475     }
476   }
477 
478 private:
479   /// Can the fir.box from the host link be read into simpler values ?
480   /// Later, without the symbol information, it might not be possible
481   /// to tell if the fir::BoxValue from the host link is contiguous.
482   static bool
canReadCapturedBoxValue(Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym)483   canReadCapturedBoxValue(Fortran::lower::AbstractConverter &converter,
484                           const Fortran::semantics::Symbol &sym) {
485     bool isScalarOrContiguous =
486         sym.Rank() == 0 || Fortran::evaluate::IsSimplyContiguous(
487                                Fortran::evaluate::AsGenericExpr(sym).value(),
488                                converter.getFoldingContext());
489     const Fortran::semantics::DeclTypeSpec *type = sym.GetType();
490     bool isPolymorphic = type && type->IsPolymorphic();
491     return isScalarOrContiguous && !isPolymorphic &&
492            !isDerivedWithLenParameters(sym) &&
493            !Fortran::evaluate::IsAssumedRank(sym);
494   }
495 };
496 } // namespace
497 
498 /// Dispatch \p visitor to the CapturedSymbols which is handling how host
499 /// association is implemented for this kind of symbols. This ensures the same
500 /// dispatch decision is taken when building the tuple type, when creating the
501 /// tuple, and when instantiating host associated variables from it.
502 template <typename T>
503 static typename T::Result
walkCaptureCategories(T visitor,Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym)504 walkCaptureCategories(T visitor, Fortran::lower::AbstractConverter &converter,
505                       const Fortran::semantics::Symbol &sym) {
506   if (isDerivedWithLenParameters(sym))
507     // Should be boxed.
508     TODO(converter.genLocation(sym.name()),
509          "host associated derived type with length parameters");
510   Fortran::lower::BoxAnalyzer ba;
511   // Do not analyze procedures, they may be subroutines with no types that would
512   // crash the analysis.
513   if (Fortran::semantics::IsProcedure(sym))
514     return CapturedProcedure::visit(visitor, converter, sym, ba);
515   ba.analyze(sym);
516   if (Fortran::semantics::IsAllocatableOrPointer(sym) ||
517       sym.GetUltimate().test(Fortran::semantics::Symbol::Flag::CrayPointee))
518     return CapturedAllocatableAndPointer::visit(visitor, converter, sym, ba);
519   if (ba.isArray()) // include assumed-ranks.
520     return CapturedArrays::visit(visitor, converter, sym, ba);
521   if (Fortran::semantics::IsPolymorphic(sym))
522     return CapturedPolymorphicScalar::visit(visitor, converter, sym, ba);
523   if (ba.isChar())
524     return CapturedCharacterScalars::visit(visitor, converter, sym, ba);
525   assert(ba.isTrivial() && "must be trivial scalar");
526   return CapturedSimpleScalars::visit(visitor, converter, sym, ba);
527 }
528 
529 // `t` should be the result of getArgumentType, which has a type of
530 // `!fir.ref<tuple<...>>`.
unwrapTupleTy(mlir::Type t)531 static mlir::TupleType unwrapTupleTy(mlir::Type t) {
532   return mlir::cast<mlir::TupleType>(fir::dyn_cast_ptrEleTy(t));
533 }
534 
genTupleCoor(fir::FirOpBuilder & builder,mlir::Location loc,mlir::Type varTy,mlir::Value tupleArg,mlir::Value offset)535 static mlir::Value genTupleCoor(fir::FirOpBuilder &builder, mlir::Location loc,
536                                 mlir::Type varTy, mlir::Value tupleArg,
537                                 mlir::Value offset) {
538   // fir.ref<fir.ref> and fir.ptr<fir.ref> are forbidden. Use
539   // fir.llvm_ptr if needed.
540   auto ty = mlir::isa<fir::ReferenceType>(varTy)
541                 ? mlir::Type(fir::LLVMPointerType::get(varTy))
542                 : mlir::Type(builder.getRefType(varTy));
543   return builder.create<fir::CoordinateOp>(loc, ty, tupleArg, offset);
544 }
545 
addSymbolsToBind(const llvm::SetVector<const Fortran::semantics::Symbol * > & symbols,const Fortran::semantics::Scope & hostScope)546 void Fortran::lower::HostAssociations::addSymbolsToBind(
547     const llvm::SetVector<const Fortran::semantics::Symbol *> &symbols,
548     const Fortran::semantics::Scope &hostScope) {
549   assert(tupleSymbols.empty() && globalSymbols.empty() &&
550          "must be initially empty");
551   this->hostScope = &hostScope;
552   for (const auto *s : symbols)
553     // GlobalOp are created for non-global threadprivate variable,
554     //  so considering them as globals.
555     if (Fortran::lower::symbolIsGlobal(*s) ||
556         (*s).test(Fortran::semantics::Symbol::Flag::OmpThreadprivate)) {
557       // The ultimate symbol is stored here so that global symbols from the
558       // host scope can later be searched in this set.
559       globalSymbols.insert(&s->GetUltimate());
560     } else {
561       tupleSymbols.insert(s);
562     }
563 }
564 
hostProcedureBindings(Fortran::lower::AbstractConverter & converter,Fortran::lower::SymMap & symMap)565 void Fortran::lower::HostAssociations::hostProcedureBindings(
566     Fortran::lower::AbstractConverter &converter,
567     Fortran::lower::SymMap &symMap) {
568   if (tupleSymbols.empty())
569     return;
570 
571   // Create the tuple variable.
572   mlir::TupleType tupTy = unwrapTupleTy(getArgumentType(converter));
573   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
574   mlir::Location loc = converter.getCurrentLocation();
575   auto hostTuple = builder.create<fir::AllocaOp>(loc, tupTy);
576   mlir::IntegerType offTy = builder.getIntegerType(32);
577 
578   // Walk the list of tupleSymbols and update the pointers in the tuple.
579   for (auto s : llvm::enumerate(tupleSymbols)) {
580     auto indexInTuple = s.index();
581     mlir::Value off = builder.createIntegerConstant(loc, offTy, indexInTuple);
582     mlir::Type varTy = tupTy.getType(indexInTuple);
583     mlir::Value eleOff = genTupleCoor(builder, loc, varTy, hostTuple, off);
584     InstantiateHostTuple instantiateHostTuple{
585         converter.getSymbolExtendedValue(*s.value(), &symMap), eleOff, loc};
586     walkCaptureCategories(instantiateHostTuple, converter, *s.value());
587   }
588 
589   converter.bindHostAssocTuple(hostTuple);
590 }
591 
internalProcedureBindings(Fortran::lower::AbstractConverter & converter,Fortran::lower::SymMap & symMap)592 void Fortran::lower::HostAssociations::internalProcedureBindings(
593     Fortran::lower::AbstractConverter &converter,
594     Fortran::lower::SymMap &symMap) {
595   if (!globalSymbols.empty()) {
596     assert(hostScope && "host scope must have been set");
597     Fortran::lower::AggregateStoreMap storeMap;
598     // The host scope variable list is required to deal with host variables
599     // that are equivalenced and requires instantiating the right global
600     // AggregateStore.
601     for (auto &hostVariable : pft::getScopeVariableList(*hostScope))
602       if ((hostVariable.isAggregateStore() && hostVariable.isGlobal()) ||
603           (hostVariable.hasSymbol() &&
604            globalSymbols.contains(&hostVariable.getSymbol().GetUltimate()))) {
605         Fortran::lower::instantiateVariable(converter, hostVariable, symMap,
606                                             storeMap);
607         // Generate threadprivate Op for host associated variables.
608         if (hostVariable.hasSymbol() &&
609             hostVariable.getSymbol().test(
610                 Fortran::semantics::Symbol::Flag::OmpThreadprivate))
611           Fortran::lower::genThreadprivateOp(converter, hostVariable);
612       }
613   }
614   if (tupleSymbols.empty())
615     return;
616 
617   // Find the argument with the tuple type. The argument ought to be appended.
618   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
619   mlir::Type argTy = getArgumentType(converter);
620   mlir::TupleType tupTy = unwrapTupleTy(argTy);
621   mlir::Location loc = converter.getCurrentLocation();
622   mlir::func::FuncOp func = builder.getFunction();
623   mlir::Value tupleArg;
624   for (auto [ty, arg] : llvm::reverse(llvm::zip(
625            func.getFunctionType().getInputs(), func.front().getArguments())))
626     if (ty == argTy) {
627       tupleArg = arg;
628       break;
629     }
630   if (!tupleArg)
631     fir::emitFatalError(loc, "no host association argument found");
632 
633   converter.bindHostAssocTuple(tupleArg);
634 
635   mlir::IntegerType offTy = builder.getIntegerType(32);
636 
637   // Walk the list and add the bindings to the symbol table.
638   for (auto s : llvm::enumerate(tupleSymbols)) {
639     mlir::Value off = builder.createIntegerConstant(loc, offTy, s.index());
640     mlir::Type varTy = tupTy.getType(s.index());
641     mlir::Value eleOff = genTupleCoor(builder, loc, varTy, tupleArg, off);
642     mlir::Value valueInTuple = builder.create<fir::LoadOp>(loc, eleOff);
643     GetFromTuple getFromTuple{symMap, valueInTuple, loc};
644     walkCaptureCategories(getFromTuple, converter, *s.value());
645   }
646 }
647 
getArgumentType(Fortran::lower::AbstractConverter & converter)648 mlir::Type Fortran::lower::HostAssociations::getArgumentType(
649     Fortran::lower::AbstractConverter &converter) {
650   if (tupleSymbols.empty())
651     return {};
652   if (argType)
653     return argType;
654 
655   // Walk the list of Symbols and create their types. Wrap them in a reference
656   // to a tuple.
657   mlir::MLIRContext *ctxt = &converter.getMLIRContext();
658   llvm::SmallVector<mlir::Type> tupleTys;
659   for (const Fortran::semantics::Symbol *sym : tupleSymbols)
660     tupleTys.emplace_back(
661         walkCaptureCategories(GetTypeInTuple{}, converter, *sym));
662   argType = fir::ReferenceType::get(mlir::TupleType::get(ctxt, tupleTys));
663   return argType;
664 }
665