//===- LowerHLFIROrderedAssignments.cpp - Lower HLFIR ordered assignments -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file defines a pass to lower HLFIR ordered assignments.
// Ordered assignments are all the operations with the
// OrderedAssignmentTreeOpInterface that implement user-defined assignments,
// assignments to vector subscripted entities, and assignments inside forall
// and where constructs.
// The pass lowers these operations to regular hlfir.assign operations and
// loops and, if needed, introduces temporary storage to fulfill Fortran
// semantics.
//
// For each rewrite, an analysis builds an evaluation schedule, and then the
// new code is generated by following the evaluation schedule.
//===----------------------------------------------------------------------===//
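//
// Illustration (a sketch, not verified compiler output): a forall assignment
// such as:
//   forall (i=1:10) x(i) = x(11-i)
// cannot be lowered to a single element-per-element assignment loop because
// later iterations would read elements already overwritten by earlier
// iterations. The schedule then uses two runs: a first run evaluates and
// saves all right-hand side values x(11-i) into temporary storage, and a
// second run fetches the saved values and assigns them to x(i).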

#include "ScheduleOrderedAssignments.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/TemporaryStorage.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

namespace hlfir {
#define GEN_PASS_DEF_LOWERHLFIRORDEREDASSIGNMENTS
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir

#define DEBUG_TYPE "flang-ordered-assignment"

// Test-only option to run just the scheduling part (operations are erased
// without codegen). Its only goal is to allow printing and testing the debug
// info.
static llvm::cl::opt<bool> dbgScheduleOnly(
    "flang-dbg-order-assignment-schedule-only",
    llvm::cl::desc("Only run ordered assignment scheduling with no codegen"),
    llvm::cl::init(false));

namespace {

/// Structure that represents a masked expression being lowered. Masked
/// expressions are any expressions inside an hlfir.where. As described in
/// Fortran 2018 section 10.2.3.2, the evaluation of the elemental parts of
/// such expressions must be masked, while the evaluation of the non-elemental
/// parts must not be masked. This structure analyzes the region evaluating
/// the expression and allows splitting the generation of the non-elemental
/// part from the elemental part.
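///
/// As an illustrative sketch (not verified compiler output), in:
///   where (x > 0.) x = x / sum(x)
/// the comparison `x > 0.` and the division are elemental and must only be
/// evaluated for the elements where the mask is true, while `sum(x)` is not
/// elemental and must be evaluated once, outside of the where loops.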
struct MaskedArrayExpr {
  MaskedArrayExpr(mlir::Location loc, mlir::Region &region,
                  bool isOuterMaskExpr);

  /// Generate the non-elemental part. Must be called outside of the
  /// loops created for the WHERE construct.
  void generateNoneElementalPart(fir::FirOpBuilder &builder,
                                 mlir::IRMapping &mapper);

  /// Methods below can only be called once generateNoneElementalPart has been
  /// called.

  /// Return the shape of the expression.
  mlir::Value generateShape(fir::FirOpBuilder &builder,
                            mlir::IRMapping &mapper);
  /// Return the element value of this expression given the current where
  /// loop indices.
  mlir::Value generateElementalParts(fir::FirOpBuilder &builder,
                                     mlir::ValueRange oneBasedIndices,
                                     mlir::IRMapping &mapper);
  /// Generate the cleanup for the non-elemental parts, if any. This must be
  /// called after the loops created for the WHERE construct.
  void generateNoneElementalCleanupIfAny(fir::FirOpBuilder &builder,
                                         mlir::IRMapping &mapper);

  /// Helper to clone the clean-ups of the masked expr region terminator.
  /// This is called outside of the loops for the initial mask, and inside
  /// the loops for the other masked expressions.
  mlir::Operation *generateMaskedExprCleanUps(fir::FirOpBuilder &builder,
                                              mlir::IRMapping &mapper);

  mlir::Location loc;
  mlir::Region &region;
  /// Set of operations that form the elemental parts of the
  /// expression evaluation. These are the hlfir.elemental and
  /// hlfir.elemental_addr operations that form the elemental tree producing
  /// the expression value. hlfir.elemental operations that produce values
  /// used inside transformational operations are not part of this set.
  llvm::SmallSet<mlir::Operation *, 4> elementalParts{};
  /// Was generateNoneElementalPart called?
  bool noneElementalPartWasGenerated = false;
  /// Is this expression the mask expression of the outer where statement?
  /// It is special because its evaluation is not masked by anything yet.
  bool isOuterMaskExpr = false;
};
} // namespace

namespace {
/// Structure that visits an ordered assignment tree and generates code for
/// it according to a schedule.
class OrderedAssignmentRewriter {
public:
  OrderedAssignmentRewriter(fir::FirOpBuilder &builder,
                            hlfir::OrderedAssignmentTreeOpInterface root)
      : builder{builder}, root{root} {}

  /// Generate code for the current run of the schedule.
  void lowerRun(hlfir::Run &run) {
    currentRun = &run;
    walk(root);
    currentRun = nullptr;
    assert(constructStack.empty() && "must exit constructs after a run");
    mapper.clear();
    savedInCurrentRunBeforeUse.clear();
  }

  /// After all runs have been lowered, clean up all the temporary
  /// storage that was created (this does not call final routines).
  void cleanupSavedEntities() {
    for (auto &temp : savedEntities)
      temp.second.destroy(root.getLoc(), builder);
  }

  /// Lowered value for an expression, and the original hlfir.yield if any
  /// clean-up needs to be cloned after usage.
  using ValueAndCleanUp = std::pair<mlir::Value, std::optional<hlfir::YieldOp>>;

private:
  /// Walk the part of an ordered assignment tree node that needs
  /// to be evaluated in the current run.
  void walk(hlfir::OrderedAssignmentTreeOpInterface node);

  /// Generate code when entering a given ordered assignment node.
  void pre(hlfir::ForallOp forallOp);
  void pre(hlfir::ForallIndexOp);
  void pre(hlfir::ForallMaskOp);
  void pre(hlfir::WhereOp whereOp);
  void pre(hlfir::ElseWhereOp elseWhereOp);
  void pre(hlfir::RegionAssignOp);

  /// Generate code when leaving a given ordered assignment node.
  void post(hlfir::ForallOp);
  void post(hlfir::ForallMaskOp);
  void post(hlfir::WhereOp);
  void post(hlfir::ElseWhereOp);
  /// Enter (and maybe create) the fir.if else block of an ElseWhereOp,
  /// but do not generate the elsewhere mask or the new fir.if.
  void enterElsewhere(hlfir::ElseWhereOp);

  /// Are there any leaf regions in the node that must be saved in the current
  /// run?
  bool mustSaveRegionIn(
      hlfir::OrderedAssignmentTreeOpInterface node,
      llvm::SmallVectorImpl<hlfir::SaveEntity> &saveEntities) const;
  /// Should this node be evaluated in the current run? Saving a region in a
  /// node does not imply the node needs to be evaluated.
  bool
  isRequiredInCurrentRun(hlfir::OrderedAssignmentTreeOpInterface node) const;

  /// Generate a scalar value yielded by an ordered assignment tree region.
  /// If the value was not saved in a previous run, this clones the region
  /// code, except the final yield, at the current execution point.
  /// If the value was saved in a previous run, this fetches the saved value
  /// from the temporary storage and returns the value.
  /// Inside Forall, the value will be hoisted outside of the forall loops if
  /// it does not depend on the forall indices.
  /// An optional type can be provided to get a value of a specific type
  /// (the cast will be hoisted if the computation is hoisted).
  mlir::Value generateYieldedScalarValue(
      mlir::Region &region,
      std::optional<mlir::Type> castToType = std::nullopt);

  /// Generate an entity yielded by an ordered assignment tree region, and
  /// optionally return the (uncloned) yield if there is any clean-up that
  /// should be done after using the entity. Like generateYieldedScalarValue,
  /// this will return the saved value if the region was saved in a previous
  /// run.
  ValueAndCleanUp
  generateYieldedEntity(mlir::Region &region,
                        std::optional<mlir::Type> castToType = std::nullopt);

  struct LhsValueAndCleanUp {
    mlir::Value lhs;
    std::optional<hlfir::YieldOp> elementalCleanup;
    mlir::Region *nonElementalCleanup = nullptr;
    std::optional<hlfir::LoopNest> vectorSubscriptLoopNest;
    std::optional<mlir::Value> vectorSubscriptShape;
  };

  /// Generate the left-hand side. If the left-hand side is vector
  /// subscripted (hlfir.elemental_addr), this will create a loop nest
  /// (unless it was already created by a WHERE mask) and return the
  /// element address.
  LhsValueAndCleanUp
  generateYieldedLHS(mlir::Location loc, mlir::Region &lhsRegion,
                     std::optional<hlfir::Entity> loweredRhs = std::nullopt);

  /// If \p maybeYield is present and has a clean-up, generate the clean-up
  /// at the current insertion point (by cloning).
  void generateCleanupIfAny(std::optional<hlfir::YieldOp> maybeYield);
  void generateCleanupIfAny(mlir::Region *cleanupRegion);

  /// Generate a masked entity. This can only be called when whereLoopNest was
  /// set (when an hlfir.where is being visited).
  /// This method returns the scalar element (that may have been previously
  /// saved) for the current indices inside the where loop.
  mlir::Value generateMaskedEntity(mlir::Location loc, mlir::Region &region) {
    MaskedArrayExpr maskedExpr(loc, region, /*isOuterMaskExpr=*/!whereLoopNest);
    return generateMaskedEntity(maskedExpr);
  }
  mlir::Value generateMaskedEntity(MaskedArrayExpr &maskedExpr);

  /// Create a fir.if at the current position inside the where loop nest
  /// given the element value of a mask.
  void generateMaskIfOp(mlir::Value cdt);

  /// Save a value for subsequent runs.
  void generateSaveEntity(hlfir::SaveEntity savedEntity,
                          bool willUseSavedEntityInSameRun);
  void saveLeftHandSide(hlfir::SaveEntity savedEntity,
                        hlfir::RegionAssignOp regionAssignOp);

  /// Get a value if it was saved in this run or a previous run. Returns
  /// nullopt if it has not been saved.
  std::optional<ValueAndCleanUp> getIfSaved(mlir::Region &region);

  /// Generate code before the loop nest for the current run, if any.
  void doBeforeLoopNest(const std::function<void()> &callback) {
    if (constructStack.empty()) {
      callback();
      return;
    }
    auto insertionPoint = builder.saveInsertionPoint();
    builder.setInsertionPoint(constructStack[0]);
    callback();
    builder.restoreInsertionPoint(insertionPoint);
  }

  /// Can the current loop nest iteration number be computed? For simplicity,
  /// this is true if and only if all the bounds and steps of the fir.do_loop
  /// nest dominate the outer loop. The argument is filled with the current
  /// loop nest on success.
  bool currentLoopNestIterationNumberCanBeComputed(
      llvm::SmallVectorImpl<fir::DoLoopOp> &loopNest);

  template <typename T>
  fir::factory::TemporaryStorage *insertSavedEntity(mlir::Region &region,
                                                    T &&temp) {
    auto inserted =
        savedEntities.insert(std::make_pair(&region, std::forward<T>(temp)));
    assert(inserted.second && "temp must have been emplaced");
    return &inserted.first->second;
  }

  fir::FirOpBuilder &builder;

  /// Mapping between the original ordered assignment tree operations and the
  /// operations that have been cloned in the current run. It is reset between
  /// two runs.
  mlir::IRMapping mapper;
  /// Dominance info is used to determine if inner loop bounds are all computed
  /// before the outer loop of the current loop nest. It does not need to be
  /// reset between runs.
  mlir::DominanceInfo dominanceInfo;
  /// Construct stack in the current run. This allows setting back the
  /// insertion point correctly when leaving a node that requires a
  /// fir.do_loop or fir.if operation.
  llvm::SmallVector<mlir::Operation *> constructStack;
  /// Current where loop nest, if any.
  std::optional<hlfir::LoopNest> whereLoopNest;

  /// Map of temporary storage used to keep track of saved entities once the
  /// run that saves them has been lowered. It is kept in-between runs.
  /// llvm::MapVector is used to guarantee a deterministic order when
  /// iterating through savedEntities (e.g. for generating
  /// destruction code for the temporary storages).
  llvm::MapVector<mlir::Region *, fir::factory::TemporaryStorage> savedEntities;
  /// Map holding the values that were saved in the current run and that also
  /// need to be used (because their construct will be visited). It is reset
  /// after each run. It avoids having to store and fetch in the temporary
  /// during the same run, which would require the temporary to have different
  /// fetching and storing counters.
  llvm::DenseMap<mlir::Region *, ValueAndCleanUp> savedInCurrentRunBeforeUse;

  /// Root of the ordered assignment tree being lowered.
  hlfir::OrderedAssignmentTreeOpInterface root;
  /// Pointer to the current run of the schedule being lowered.
  hlfir::Run *currentRun = nullptr;

  /// When allocating temporary storage inline, indicates whether the storage
  /// should be heap or stack allocated. Temporaries allocated with the runtime
  /// are heap allocated by the runtime.
  bool allocateOnHeap = true;
};
} // namespace

void OrderedAssignmentRewriter::walk(
    hlfir::OrderedAssignmentTreeOpInterface node) {
  bool mustVisit =
      isRequiredInCurrentRun(node) || mlir::isa<hlfir::ForallIndexOp>(node);
  llvm::SmallVector<hlfir::SaveEntity> saveEntities;
  mlir::Operation *nodeOp = node.getOperation();
  if (mustSaveRegionIn(node, saveEntities)) {
    mlir::IRRewriter::InsertPoint insertionPoint;
    if (auto elseWhereOp = mlir::dyn_cast<hlfir::ElseWhereOp>(nodeOp)) {
      // An ElseWhere mask to save must be evaluated inside the fir.if else
      // branch of the previous where/elsewhere (its evaluation must be
      // masked by the "pending control mask").
      insertionPoint = builder.saveInsertionPoint();
      enterElsewhere(elseWhereOp);
    }
    for (hlfir::SaveEntity saveEntity : saveEntities)
      generateSaveEntity(saveEntity, mustVisit);
    if (insertionPoint.isSet())
      builder.restoreInsertionPoint(insertionPoint);
  }
  if (mustVisit) {
    llvm::TypeSwitch<mlir::Operation *, void>(nodeOp)
        .Case<hlfir::ForallOp, hlfir::ForallIndexOp, hlfir::ForallMaskOp,
              hlfir::RegionAssignOp, hlfir::WhereOp, hlfir::ElseWhereOp>(
            [&](auto concreteOp) { pre(concreteOp); })
        .Default([](auto) {});
    if (auto *body = node.getSubTreeRegion()) {
      for (mlir::Operation &op : body->getOps())
        if (auto subNode =
                mlir::dyn_cast<hlfir::OrderedAssignmentTreeOpInterface>(op))
          walk(subNode);
      llvm::TypeSwitch<mlir::Operation *, void>(nodeOp)
          .Case<hlfir::ForallOp, hlfir::ForallMaskOp, hlfir::WhereOp,
                hlfir::ElseWhereOp>([&](auto concreteOp) { post(concreteOp); })
          .Default([](auto) {});
    }
  }
}

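/// Illustration (a conceptual sketch, not verified compiler output): an
/// hlfir.forall with an i32 index is rewritten to something like:
///   %lb = ... clone of the lb region code ...
///   %ub = ... clone of the ub region code ...
///   fir.do_loop %i = %lb to %ub step %c1 {
///     %idx = fir.convert %i : (index) -> i32
///     ... body lowered with the old forall index mapped to %idx ...
///   }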
void OrderedAssignmentRewriter::pre(hlfir::ForallOp forallOp) {
  // Create a fir.do_loop given the hlfir.forall control values.
  mlir::Type idxTy = builder.getIndexType();
  mlir::Location loc = forallOp.getLoc();
  mlir::Value lb = generateYieldedScalarValue(forallOp.getLbRegion(), idxTy);
  mlir::Value ub = generateYieldedScalarValue(forallOp.getUbRegion(), idxTy);
  mlir::Value step;
  if (forallOp.getStepRegion().empty()) {
    auto insertionPoint = builder.saveInsertionPoint();
    if (!constructStack.empty())
      builder.setInsertionPoint(constructStack[0]);
    step = builder.createIntegerConstant(loc, idxTy, 1);
    if (!constructStack.empty())
      builder.restoreInsertionPoint(insertionPoint);
  } else {
    step = generateYieldedScalarValue(forallOp.getStepRegion(), idxTy);
  }
  auto doLoop = builder.create<fir::DoLoopOp>(loc, lb, ub, step);
  builder.setInsertionPointToStart(doLoop.getBody());
  mlir::Value oldIndex = forallOp.getForallIndexValue();
  mlir::Value newIndex =
      builder.createConvert(loc, oldIndex.getType(), doLoop.getInductionVar());
  mapper.map(oldIndex, newIndex);
  constructStack.push_back(doLoop);
}

void OrderedAssignmentRewriter::post(hlfir::ForallOp) {
  assert(!constructStack.empty() && "must contain a loop");
  builder.setInsertionPointAfter(constructStack.pop_back_val());
}

void OrderedAssignmentRewriter::pre(hlfir::ForallIndexOp forallIndexOp) {
  mlir::Location loc = forallIndexOp.getLoc();
  mlir::Type intTy = fir::unwrapRefType(forallIndexOp.getType());
  mlir::Value indexVar =
      builder.createTemporary(loc, intTy, forallIndexOp.getName());
  mlir::Value newVal = mapper.lookupOrDefault(forallIndexOp.getIndex());
  builder.createStoreWithConvert(loc, newVal, indexVar);
  mapper.map(forallIndexOp, indexVar);
}

void OrderedAssignmentRewriter::pre(hlfir::ForallMaskOp forallMaskOp) {
  mlir::Location loc = forallMaskOp.getLoc();
  mlir::Value mask = generateYieldedScalarValue(forallMaskOp.getMaskRegion(),
                                                builder.getI1Type());
  auto ifOp = builder.create<fir::IfOp>(loc, std::nullopt, mask, false);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  constructStack.push_back(ifOp);
}

void OrderedAssignmentRewriter::post(hlfir::ForallMaskOp forallMaskOp) {
  assert(!constructStack.empty() && "must contain an ifop");
  builder.setInsertionPointAfter(constructStack.pop_back_val());
}

/// Convert an entity to the type of a given mold.
/// This is intended to help with cases where an HLFIR entity is a value while
/// it must be used as a variable, or vice versa. These mismatches may occur
/// between the types of user-defined assignment block arguments and the actual
/// arguments that were lowered for them. The actual argument may be an
/// in-memory copy while the block argument expects an hlfir.expr.
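/// For example (illustrative), if the block argument expects a trivial value
/// while the lowered actual argument is a variable (fir.ref<T>), a fir.load
/// is emitted; in the opposite direction, the value is associated to a
/// temporary variable with hlfir.associate and the corresponding
/// hlfir.end_associate clean-up is registered.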
static hlfir::Entity
convertToMoldType(mlir::Location loc, fir::FirOpBuilder &builder,
                  hlfir::Entity input, hlfir::Entity mold,
                  llvm::SmallVectorImpl<hlfir::CleanupFunction> &cleanups) {
  if (input.getType() == mold.getType())
    return input;
  fir::FirOpBuilder *b = &builder;
  if (input.isVariable() && mold.isValue()) {
    if (fir::isa_trivial(mold.getType())) {
      // fir.ref<T> to T.
      mlir::Value load = builder.create<fir::LoadOp>(loc, input);
      return hlfir::Entity{builder.createConvert(loc, mold.getType(), load)};
    }
    // fir.ref<T> to hlfir.expr<T>.
    mlir::Value asExpr = builder.create<hlfir::AsExprOp>(loc, input);
    if (asExpr.getType() != mold.getType())
      TODO(loc, "hlfir.expr conversion");
    cleanups.emplace_back([=]() { b->create<hlfir::DestroyOp>(loc, asExpr); });
    return hlfir::Entity{asExpr};
  }
  if (input.isValue() && mold.isVariable()) {
    // T to fir.ref<T>, or hlfir.expr<T> to fir.ref<T>.
    hlfir::AssociateOp associate = hlfir::genAssociateExpr(
        loc, builder, input, mold.getFortranElementType(), ".tmp.val2ref");
    cleanups.emplace_back(
        [=]() { b->create<hlfir::EndAssociateOp>(loc, associate); });
    return hlfir::Entity{associate.getBase()};
  }
  // Variable to variable mismatch (e.g., fir.heap<T> vs fir.ref<T>), or value
  // to value mismatch (e.g. i1 vs fir.logical<4>).
  if (mlir::isa<fir::BaseBoxType>(mold.getType()) &&
      !mlir::isa<fir::BaseBoxType>(input.getType())) {
    // An entity may have been saved without a descriptor while the original
    // value had a descriptor (e.g., it was not contiguous).
    auto emboxed = hlfir::convertToBox(loc, builder, input, mold.getType());
    assert(!emboxed.second && "temp should already be in memory");
    input = hlfir::Entity{fir::getBase(emboxed.first)};
  }
  return hlfir::Entity{builder.createConvert(loc, mold.getType(), input)};
}

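/// Illustration (a sketch, not verified compiler output): for an intrinsic
/// assignment, an op of the form:
///   hlfir.region_assign {
///     ... evaluate RHS ..., hlfir.yield %rhs
///   } to {
///     ... evaluate LHS ..., hlfir.yield %lhs
///   }
/// is rewritten at its scheduled position into clones of the RHS and LHS
/// region bodies (or fetches of previously saved values), followed by
/// "hlfir.assign %rhs to %lhs", plus any yielded clean-ups.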
void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
  mlir::Location loc = regionAssignOp.getLoc();
  auto [rhsValue, oldRhsYield] =
      generateYieldedEntity(regionAssignOp.getRhsRegion());
  hlfir::Entity rhsEntity{rhsValue};
  LhsValueAndCleanUp loweredLhs =
      generateYieldedLHS(loc, regionAssignOp.getLhsRegion(), rhsEntity);
  hlfir::Entity lhsEntity{loweredLhs.lhs};
  if (loweredLhs.vectorSubscriptLoopNest)
    rhsEntity = hlfir::getElementAt(
        loc, builder, rhsEntity,
        loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
  if (!regionAssignOp.getUserDefinedAssignment().empty()) {
    hlfir::Entity userAssignLhs{regionAssignOp.getUserAssignmentLhs()};
    hlfir::Entity userAssignRhs{regionAssignOp.getUserAssignmentRhs()};
    std::optional<hlfir::LoopNest> elementalLoopNest;
    if (lhsEntity.isArray() && userAssignLhs.isScalar()) {
      // Elemental assignment with array arguments (the RHS cannot be an array
      // if the LHS is not).
      mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity);
      elementalLoopNest = hlfir::genLoopNest(loc, builder, shape);
      builder.setInsertionPointToStart(elementalLoopNest->body);
      lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity,
                                      elementalLoopNest->oneBasedIndices);
      rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity,
                                      elementalLoopNest->oneBasedIndices);
    }

    llvm::SmallVector<hlfir::CleanupFunction, 2> argConversionCleanups;
    lhsEntity = convertToMoldType(loc, builder, lhsEntity, userAssignLhs,
                                  argConversionCleanups);
    rhsEntity = convertToMoldType(loc, builder, rhsEntity, userAssignRhs,
                                  argConversionCleanups);
    mapper.map(userAssignLhs, lhsEntity);
    mapper.map(userAssignRhs, rhsEntity);
    for (auto &op :
         regionAssignOp.getUserDefinedAssignment().front().without_terminator())
      (void)builder.clone(op, mapper);
    for (auto &cleanupConversion : argConversionCleanups)
      cleanupConversion();
    if (elementalLoopNest)
      builder.setInsertionPointAfter(elementalLoopNest->outerOp);
  } else {
    // TODO: preserve allocatable assignment aspects for forall once
    // they are conveyed in hlfir.region_assign.
    builder.create<hlfir::AssignOp>(loc, rhsEntity, lhsEntity);
  }
  generateCleanupIfAny(loweredLhs.elementalCleanup);
  if (loweredLhs.vectorSubscriptLoopNest)
    builder.setInsertionPointAfter(loweredLhs.vectorSubscriptLoopNest->outerOp);
  generateCleanupIfAny(oldRhsYield);
  generateCleanupIfAny(loweredLhs.nonElementalCleanup);
}

void OrderedAssignmentRewriter::generateMaskIfOp(mlir::Value cdt) {
  mlir::Location loc = cdt.getLoc();
  cdt = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{cdt});
  cdt = builder.createConvert(loc, builder.getI1Type(), cdt);
  auto ifOp = builder.create<fir::IfOp>(cdt.getLoc(), std::nullopt, cdt,
                                        /*withElseRegion=*/false);
  constructStack.push_back(ifOp.getOperation());
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
}

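/// Illustration (a conceptual sketch, not verified compiler output): for
///   where (mask) x = y
/// the lowering evaluates the non-elemental part of the mask, creates a loop
/// nest over the mask shape, and guards the assignment with the current mask
/// element:
///   fir.do_loop %i = ... {
///     %m = ... mask element at %i ...
///     fir.if %m {
///       ... x(%i) = y(%i) ...
///     }
///   }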
void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
  mlir::Location loc = whereOp.getLoc();
  if (!whereLoopNest) {
    // This is the top-level WHERE. Start a loop nest iterating on the shape of
    // the where mask.
    if (auto maybeSaved = getIfSaved(whereOp.getMaskRegion())) {
      // Use the saved value to get the shape and condition element.
      hlfir::Entity savedMask{maybeSaved->first};
      mlir::Value shape = hlfir::genShape(loc, builder, savedMask);
      whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
      constructStack.push_back(whereLoopNest->outerOp);
      builder.setInsertionPointToStart(whereLoopNest->body);
      mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask,
                                            whereLoopNest->oneBasedIndices);
      generateMaskIfOp(cdt);
      if (maybeSaved->second) {
        // If this is the same run as the one that saved the value, the
        // clean-up was left over to be done now.
        auto insertionPoint = builder.saveInsertionPoint();
        builder.setInsertionPointAfter(whereLoopNest->outerOp);
        generateCleanupIfAny(maybeSaved->second);
        builder.restoreInsertionPoint(insertionPoint);
      }
      return;
    }
    // The mask was not evaluated yet or can be safely re-evaluated.
    MaskedArrayExpr mask(loc, whereOp.getMaskRegion(),
                         /*isOuterMaskExpr=*/true);
    mask.generateNoneElementalPart(builder, mapper);
    mlir::Value shape = mask.generateShape(builder, mapper);
    whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
    constructStack.push_back(whereLoopNest->outerOp);
    builder.setInsertionPointToStart(whereLoopNest->body);
    mlir::Value cdt = generateMaskedEntity(mask);
    generateMaskIfOp(cdt);
    return;
  }
  // The where loops have already been created by a parent WHERE.
  // Generate a fir.if with the value of the current element of the mask
  // inside the loops. The case where the mask was saved is handled in the
  // generateYieldedScalarValue call.
  mlir::Value cdt = generateYieldedScalarValue(whereOp.getMaskRegion());
  generateMaskIfOp(cdt);
}

void OrderedAssignmentRewriter::post(hlfir::WhereOp whereOp) {
  assert(!constructStack.empty() && "must contain a fir.if");
  builder.setInsertionPointAfter(constructStack.pop_back_val());
  // If all where/elsewhere fir.if have been popped, this is the outer whereOp,
  // and the where loop must be exited.
  assert(!constructStack.empty() && "must contain a fir.do_loop or fir.if");
  if (mlir::isa<fir::DoLoopOp>(constructStack.back())) {
    builder.setInsertionPointAfter(constructStack.pop_back_val());
    whereLoopNest.reset();
  }
}

void OrderedAssignmentRewriter::enterElsewhere(hlfir::ElseWhereOp elseWhereOp) {
  // Create an "else" region for the current where/elsewhere fir.if.
  auto ifOp = mlir::dyn_cast<fir::IfOp>(constructStack.back());
  assert(ifOp && "must be an if");
  if (ifOp.getElseRegion().empty()) {
    mlir::Location loc = elseWhereOp.getLoc();
    builder.createBlock(&ifOp.getElseRegion());
    auto end = builder.create<fir::ResultOp>(loc);
    builder.setInsertionPoint(end);
  } else {
    builder.setInsertionPoint(&ifOp.getElseRegion().back().back());
  }
}

void OrderedAssignmentRewriter::pre(hlfir::ElseWhereOp elseWhereOp) {
  enterElsewhere(elseWhereOp);
  if (elseWhereOp.getMaskRegion().empty())
    return;
  // Create a new nested fir.if with the elsewhere mask, if any.
  mlir::Value cdt = generateYieldedScalarValue(elseWhereOp.getMaskRegion());
  generateMaskIfOp(cdt);
}

void OrderedAssignmentRewriter::post(hlfir::ElseWhereOp elseWhereOp) {
  // Exit the ifOp that was created for the elseWhereOp mask, if any.
  if (elseWhereOp.getMaskRegion().empty())
    return;
  assert(!constructStack.empty() && "must contain a fir.if");
  builder.setInsertionPointAfter(constructStack.pop_back_val());
}

/// Is this value a Forall index?
/// Forall indices are block arguments of hlfir.forall bodies, or results
/// of hlfir.forall_index.
static bool isForallIndex(mlir::Value value) {
  if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(value)) {
    if (mlir::Block *block = blockArg.getOwner())
      return block->isEntryBlock() &&
             mlir::isa_and_nonnull<hlfir::ForallOp>(block->getParentOp());
    return false;
  }
  return value.getDefiningOp<hlfir::ForallIndexOp>();
}

static OrderedAssignmentRewriter::ValueAndCleanUp
castIfNeeded(mlir::Location loc, fir::FirOpBuilder &builder,
             OrderedAssignmentRewriter::ValueAndCleanUp valueAndCleanUp,
             std::optional<mlir::Type> castToType) {
  if (!castToType.has_value())
    return valueAndCleanUp;
  mlir::Value cast =
      builder.createConvert(loc, *castToType, valueAndCleanUp.first);
  return {cast, valueAndCleanUp.second};
}

std::optional<OrderedAssignmentRewriter::ValueAndCleanUp>
OrderedAssignmentRewriter::getIfSaved(mlir::Region &region) {
  mlir::Location loc = region.getParentOp()->getLoc();
  // If the region was saved in the same run, use the value that was evaluated
  // instead of fetching the temp, and do the clean-ups, if any, that were
  // delayed. This is done to avoid requiring the temporary stack to have
  // different fetching and storing counters, and also because it produces
  // slightly better code.
  if (auto savedInSameRun = savedInCurrentRunBeforeUse.find(&region);
      savedInSameRun != savedInCurrentRunBeforeUse.end())
    return savedInSameRun->second;
  // If the region was saved in a previous run, fetch the saved value.
  if (auto temp = savedEntities.find(&region); temp != savedEntities.end()) {
    doBeforeLoopNest([&]() { temp->second.resetFetchPosition(loc, builder); });
    return ValueAndCleanUp{temp->second.fetch(loc, builder), std::nullopt};
  }
  return std::nullopt;
}

static hlfir::YieldOp getYield(mlir::Region &region) {
  auto yield = mlir::dyn_cast_or_null<hlfir::YieldOp>(
      region.back().getOperations().back());
  assert(yield && "region computing entities must end with a YieldOp");
  return yield;
}

OrderedAssignmentRewriter::ValueAndCleanUp
OrderedAssignmentRewriter::generateYieldedEntity(
    mlir::Region &region, std::optional<mlir::Type> castToType) {
  mlir::Location loc = region.getParentOp()->getLoc();
  if (auto maybeValueAndCleanUp = getIfSaved(region))
    return castIfNeeded(loc, builder, *maybeValueAndCleanUp, castToType);
  // Otherwise, evaluate the region now.

  // Masked expressions must not evaluate the elemental parts that are masked;
  // they have custom code generation.
  if (whereLoopNest.has_value()) {
    mlir::Value maskedValue = generateMaskedEntity(loc, region);
    return castIfNeeded(loc, builder, {maskedValue, std::nullopt}, castToType);
  }

  auto oldYield = getYield(region);
  // Inside Forall, scalars that do not depend on forall indices can be hoisted
  // here because their evaluation is only allowed to call pure procedures, and
  // if they depend on a variable previously assigned to in a forall assignment,
  // this assignment must have been scheduled in a previous run. Hoisting of
  // scalars is done here to help create simple temporary storage if needed.
  // Inner forall bounds can often be hoisted, and this allows computing the
  // total number of iterations to create temporary storages.
  bool hoistComputation = false;
  if (fir::isa_trivial(oldYield.getEntity().getType()) &&
      !constructStack.empty()) {
    mlir::WalkResult walkResult =
        region.walk([&](mlir::Operation *op) -> mlir::WalkResult {
          if (llvm::any_of(op->getOperands(), [](mlir::Value value) {
                return isForallIndex(value);
              }))
            return mlir::WalkResult::interrupt();
          return mlir::WalkResult::advance();
        });
    hoistComputation = !walkResult.wasInterrupted();
  }
  auto insertionPoint = builder.saveInsertionPoint();
  if (hoistComputation)
    builder.setInsertionPoint(constructStack[0]);

  // Clone all operations except the final hlfir.yield.
  assert(region.hasOneBlock() && "region must contain one block");
  for (auto &op : region.back().without_terminator())
    (void)builder.clone(op, mapper);
  // Get the value for the yielded entity: it may be the result of an operation
  // that was cloned, or it may be the same as the previous value if the yield
  // operand was created before the ordered assignment tree.
  mlir::Value newEntity = mapper.lookupOrDefault(oldYield.getEntity());
  if (castToType.has_value())
    newEntity =
        builder.createConvert(newEntity.getLoc(), *castToType, newEntity);

  if (hoistComputation) {
    // Hoisted trivial scalars clean-up can be done right away; the value is
    // in registers.
    generateCleanupIfAny(oldYield);
    builder.restoreInsertionPoint(insertionPoint);
    return {newEntity, std::nullopt};
  }
  if (oldYield.getCleanup().empty())
    return {newEntity, std::nullopt};
  return {newEntity, oldYield};
}

mlir::Value OrderedAssignmentRewriter::generateYieldedScalarValue(
    mlir::Region &region, std::optional<mlir::Type> castToType) {
  mlir::Location loc = region.getParentOp()->getLoc();
  auto [value, maybeYield] = generateYieldedEntity(region, castToType);
  value = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{value});
  assert(fir::isa_trivial(value.getType()) && "not a trivial scalar value");
  generateCleanupIfAny(maybeYield);
  return value;
}

OrderedAssignmentRewriter::LhsValueAndCleanUp
OrderedAssignmentRewriter::generateYieldedLHS(
    mlir::Location loc, mlir::Region &lhsRegion,
    std::optional<hlfir::Entity> loweredRhs) {
  LhsValueAndCleanUp loweredLhs;
  hlfir::ElementalAddrOp elementalAddrLhs =
      mlir::dyn_cast<hlfir::ElementalAddrOp>(lhsRegion.back().back());
  if (auto temp = savedEntities.find(&lhsRegion); temp != savedEntities.end()) {
    // The LHS address was computed and saved in a previous run. Fetch it.
    doBeforeLoopNest([&]() { temp->second.resetFetchPosition(loc, builder); });
    if (elementalAddrLhs && !whereLoopNest) {
      // Vector subscripted designator addresses are saved element by element.
      // If no "elemental" loops have been created yet, the shape of the RHS,
      // if it is an array, can be used; otherwise, the shape of the vector
      // subscripted designator must be retrieved to generate the "elemental"
      // loop nest.
      if (loweredRhs && loweredRhs->isArray()) {
        // The RHS shape can be used to create the elemental loops and avoid
        // saving the LHS shape.
        loweredLhs.vectorSubscriptShape =
            hlfir::genShape(loc, builder, *loweredRhs);
      } else {
        // If the shape cannot be retrieved from the RHS, it must have been
        // saved. Get it from the temporary.
        auto &vectorTmp =
            temp->second.cast<fir::factory::AnyVectorSubscriptStack>();
        loweredLhs.vectorSubscriptShape = vectorTmp.fetchShape(loc, builder);
      }
      loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
          loc, builder, loweredLhs.vectorSubscriptShape.value());
      builder.setInsertionPointToStart(
          loweredLhs.vectorSubscriptLoopNest->body);
    }
    loweredLhs.lhs = temp->second.fetch(loc, builder);
    return loweredLhs;
  }
  // The LHS has not yet been evaluated and saved. Evaluate it now.
  if (elementalAddrLhs && !whereLoopNest) {
    // This is a vector subscripted entity. The address of its elements must
    // be returned. If no "elemental" loops have been created for a WHERE,
    // create them now based on the vector subscripted designator shape.
    for (auto &op : lhsRegion.front().without_terminator())
      (void)builder.clone(op, mapper);
    loweredLhs.vectorSubscriptShape =
        mapper.lookupOrDefault(elementalAddrLhs.getShape());
    loweredLhs.vectorSubscriptLoopNest =
        hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
                           !elementalAddrLhs.isOrdered());
    builder.setInsertionPointToStart(loweredLhs.vectorSubscriptLoopNest->body);
    mapper.map(elementalAddrLhs.getIndices(),
               loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
    for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
      (void)builder.clone(op, mapper);
    loweredLhs.elementalCleanup = elementalAddrLhs.getYieldOp();
    loweredLhs.lhs =
        mapper.lookupOrDefault(loweredLhs.elementalCleanup->getEntity());
  } else {
    // This is a designator without vector subscripts. Generate it as
    // it is done for other entities.
    auto [lhs, yield] = generateYieldedEntity(lhsRegion);
    loweredLhs.lhs = lhs;
    if (yield && !yield->getCleanup().empty())
      loweredLhs.nonElementalCleanup = &yield->getCleanup();
  }
  return loweredLhs;
}

mlir::Value
OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
  assert(whereLoopNest.has_value() && "must be inside WHERE loop nest");
  auto insertionPoint = builder.saveInsertionPoint();
  if (!maskedExpr.noneElementalPartWasGenerated) {
    // Generate the non-elemental part before the where loops (but inside the
    // current forall loops, if any).
    builder.setInsertionPoint(whereLoopNest->outerOp);
    maskedExpr.generateNoneElementalPart(builder, mapper);
  }
  // Generate the non-elemental part clean-up after the where loops.
  builder.setInsertionPointAfter(whereLoopNest->outerOp);
  maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
  // Generate the value of the current element for the masked expression
  // at the current insertion point (inside the where loops, and any fir.if
  // generated for previous masks).
  builder.restoreInsertionPoint(insertionPoint);
  mlir::Value scalar = maskedExpr.generateElementalParts(
      builder, whereLoopNest->oneBasedIndices, mapper);
  // Generate clean-ups for the elemental parts inside the loops (setting the
  // insertion point so that the assignment will be generated before the
  // clean-ups).
  if (!maskedExpr.isOuterMaskExpr)
    if (mlir::Operation *firstCleanup =
            maskedExpr.generateMaskedExprCleanUps(builder, mapper))
      builder.setInsertionPoint(firstCleanup);
  return scalar;
}

void OrderedAssignmentRewriter::generateCleanupIfAny(
    std::optional<hlfir::YieldOp> maybeYield) {
  if (maybeYield.has_value())
    generateCleanupIfAny(&maybeYield->getCleanup());
}
void OrderedAssignmentRewriter::generateCleanupIfAny(
    mlir::Region *cleanupRegion) {
  if (cleanupRegion && !cleanupRegion->empty()) {
    assert(cleanupRegion->hasOneBlock() && "region must contain one block");
    for (auto &op : cleanupRegion->back().without_terminator())
      builder.clone(op, mapper);
  }
}

bool OrderedAssignmentRewriter::mustSaveRegionIn(
    hlfir::OrderedAssignmentTreeOpInterface node,
    llvm::SmallVectorImpl<hlfir::SaveEntity> &saveEntities) const {
  for (auto &action : currentRun->actions)
    if (hlfir::SaveEntity *savedEntity =
            std::get_if<hlfir::SaveEntity>(&action))
      if (node.getOperation() == savedEntity->yieldRegion->getParentOp())
        saveEntities.push_back(*savedEntity);
  return !saveEntities.empty();
}

bool OrderedAssignmentRewriter::isRequiredInCurrentRun(
    hlfir::OrderedAssignmentTreeOpInterface node) const {
  // hlfir.forall_index operations do not contain saved regions/assignments,
  // but if their hlfir.forall parent was required, they are required too
  // (the forall indices need to be mapped).
  if (mlir::isa<hlfir::ForallIndexOp>(node))
    return true;
  for (auto &action : currentRun->actions)
    if (hlfir::SaveEntity *savedEntity =
            std::get_if<hlfir::SaveEntity>(&action)) {
      // A SaveEntity action does not require evaluating the node that contains
      // it, but it requires evaluating all the parents of the node that
      // contains it. For instance, saving a bound of hlfir.forall B does not
      // require creating the loops for B, but it requires creating the loops
      // for any forall parent A of forall B.
      if (node->isProperAncestor(savedEntity->yieldRegion->getParentOp()))
        return true;
    } else {
      auto assign = std::get<hlfir::RegionAssignOp>(action);
      if (node->isAncestor(assign.getOperation()))
        return true;
    }
  return false;
}

/// Is the apply using all the elemental indices in order?
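/// For instance (an illustrative sketch), applying indices (%j, %i) to an
/// elemental whose indices are (%i, %j), as a transpose would, is not in
/// order.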
static bool isInOrderApply(hlfir::ApplyOp apply,
                           hlfir::ElementalOpInterface elemental) {
  mlir::Region::BlockArgListType elementalIndices = elemental.getIndices();
  if (elementalIndices.size() != apply.getIndices().size())
    return false;
  for (auto [elementalIdx, applyIdx] :
       llvm::zip(elementalIndices, apply.getIndices()))
    if (elementalIdx != applyIdx)
      return false;
  return true;
}

/// Gather the tree of hlfir::ElementalOpInterface use-def, if any, starting
/// from \p elemental, which may be a nullptr.
static void
gatherElementalTree(hlfir::ElementalOpInterface elemental,
                    llvm::SmallPtrSetImpl<mlir::Operation *> &elementalOps,
                    bool isOutOfOrder) {
  if (elemental) {
    // Only inline an applied elemental that must be executed in order if the
    // applying indices are in order. An hlfir::Elemental may have been created
    // for a transformational like transpose, and Fortran 2018 standard
    // section 10.2.3.2, point 10 implies that impure elemental sub-expression
    // evaluations should not be masked if they are the arguments of
    // transformational expressions.
    if (isOutOfOrder && elemental.isOrdered())
      return;
    elementalOps.insert(elemental.getOperation());
    for (mlir::Operation &op : elemental.getElementalRegion().getOps())
      if (auto apply = mlir::dyn_cast<hlfir::ApplyOp>(op)) {
        bool isUnorderedApply =
            isOutOfOrder || !isInOrderApply(apply, elemental);
        auto maybeElemental =
            mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
                apply.getExpr().getDefiningOp());
        gatherElementalTree(maybeElemental, elementalOps, isUnorderedApply);
      }
  }
}

MaskedArrayExpr::MaskedArrayExpr(mlir::Location loc, mlir::Region &region,
                                 bool isOuterMaskExpr)
    : loc{loc}, region{region}, isOuterMaskExpr{isOuterMaskExpr} {
  mlir::Operation &terminator = region.back().back();
  if (auto elementalAddr =
          mlir::dyn_cast<hlfir::ElementalOpInterface>(terminator)) {
    // Vector subscripted designator (hlfir.elemental_addr terminator).
    gatherElementalTree(elementalAddr, elementalParts, /*isOutOfOrder=*/false);
    return;
  }
  // Otherwise, check if the yielded entity is an elemental expression.
  mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
  auto maybeElemental = mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
      entity.getDefiningOp());
  gatherElementalTree(maybeElemental, elementalParts, /*isOutOfOrder=*/false);
}

void MaskedArrayExpr::generateNoneElementalPart(fir::FirOpBuilder &builder,
                                                mlir::IRMapping &mapper) {
  assert(!noneElementalPartWasGenerated &&
         "non-elemental parts already generated");
  if (isOuterMaskExpr) {
    // The outer mask expression is actually not masked; it is treated as a
    // masked expression so that its elemental part, if any, can be inlined in
    // the WHERE loops. But all of the operations outside of hlfir.elemental/
    // hlfir.elemental_addr must be emitted now because their value may be
    // required to deduce the mask shape and the WHERE loop bounds.
    for (mlir::Operation &op : region.back().without_terminator())
      if (!elementalParts.contains(&op))
        (void)builder.clone(op, mapper);
  } else {
    // For actual masked expressions, Fortran requires elemental expressions,
    // even the scalar ones that are not encoded with hlfir.elemental, to be
    // evaluated only when the mask is true. Blindly hoisting the whole scalar
    // SSA tree could be wrong if the scalar computation has side effects and
    // would never have been evaluated (e.g. division by zero) if the mask
    // is fully false. See F'2023 10.2.3.2 point 10.
    // Clone only the bodies of all hlfir.exactly_once operations, which contain
    // the evaluation of sub-expression trees whose roots were non-elemental
    // function calls at the Fortran level (the calls themselves may have been
    // inlined since). These must be evaluated only once as per F'2023 10.2.3.2
    // point 9.
    for (mlir::Operation &op : region.back().without_terminator())
      if (auto exactlyOnce = mlir::dyn_cast<hlfir::ExactlyOnceOp>(op)) {
        for (mlir::Operation &subOp :
             exactlyOnce.getBody().back().without_terminator())
          (void)builder.clone(subOp, mapper);
        mlir::Value oldYield = getYield(exactlyOnce.getBody()).getEntity();
        auto newYield = mapper.lookupOrDefault(oldYield);
        mapper.map(exactlyOnce.getResult(), newYield);
      }
  }
  noneElementalPartWasGenerated = true;
}

mlir::Value MaskedArrayExpr::generateShape(fir::FirOpBuilder &builder,
                                           mlir::IRMapping &mapper) {
  assert(noneElementalPartWasGenerated &&
         "non-elemental part must have been generated");
  mlir::Operation &terminator = region.back().back();
  // If the operation that produced the yielded entity is elemental, it was not
  // cloned, but it holds a shape argument that was cloned. Return the cloned
  // shape.
  if (auto elementalAddrOp = mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator))
    return mapper.lookupOrDefault(elementalAddrOp.getShape());
  mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
  if (auto elemental = entity.getDefiningOp<hlfir::ElementalOp>())
    return mapper.lookupOrDefault(elemental.getShape());
  // Otherwise, the whole entity was cloned, and the shape can be generated
  // from it.
  hlfir::Entity clonedEntity{mapper.lookupOrDefault(entity)};
  return hlfir::genShape(loc, builder, hlfir::Entity{clonedEntity});
}

mlir::Value
MaskedArrayExpr::generateElementalParts(fir::FirOpBuilder &builder,
                                        mlir::ValueRange oneBasedIndices,
                                        mlir::IRMapping &mapper) {
  assert(noneElementalPartWasGenerated &&
         "non-elemental part must have been generated");
  if (!isOuterMaskExpr) {
    // Clone all operations that are not hlfir.exactly_once and that are not
    // hlfir.elemental/hlfir.elemental_addr. (For the outer mask, this was
    // already done outside of the loops.)
    for (mlir::Operation &op : region.back().without_terminator())
      if (!mlir::isa<hlfir::ExactlyOnceOp>(op) && !elementalParts.contains(&op))
        (void)builder.clone(op, mapper);
  }
  // Clone and "index" the bodies of hlfir.elemental/hlfir.elemental_addr.
  mlir::Operation &terminator = region.back().back();
  hlfir::ElementalOpInterface elemental =
      mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator);
  if (!elemental) {
    // If the terminator is not an hlfir.elemental_addr, check if the yielded
    // entity was produced by an hlfir.elemental.
    mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
    elemental = entity.getDefiningOp<hlfir::ElementalOp>();
    if (!elemental) {
      // The yielded entity was not produced by an elemental operation;
      // get its clone from the non-elemental part evaluation and address it.
      hlfir::Entity clonedEntity{mapper.lookupOrDefault(entity)};
      return hlfir::getElementAt(loc, builder, clonedEntity, oneBasedIndices);
    }
  }

  auto mustRecursivelyInline =
      [&](hlfir::ElementalOp appliedElemental) -> bool {
    return elementalParts.contains(appliedElemental.getOperation());
  };
  return inlineElementalOp(loc, builder, elemental, oneBasedIndices, mapper,
                           mustRecursivelyInline);
}

mlir::Operation *
MaskedArrayExpr::generateMaskedExprCleanUps(fir::FirOpBuilder &builder,
                                            mlir::IRMapping &mapper) {
  // Clone the clean-ups from the region itself, except for the destroys
  // of the hlfir.elemental operations that have been inlined.
  mlir::Operation &terminator = region.back().back();
  mlir::Region *cleanupRegion = nullptr;
  if (auto elementalAddr = mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator)) {
    cleanupRegion = &elementalAddr.getCleanup();
  } else {
    auto yieldOp = mlir::cast<hlfir::YieldOp>(terminator);
    cleanupRegion = &yieldOp.getCleanup();
  }
  if (cleanupRegion->empty())
    return nullptr;
  mlir::Operation *firstNewCleanup = nullptr;
  for (mlir::Operation &op : cleanupRegion->front().without_terminator()) {
    if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(op))
      if (elementalParts.contains(destroy.getExpr().getDefiningOp()))
        continue;
    mlir::Operation *cleanup = builder.clone(op, mapper);
    if (!firstNewCleanup)
      firstNewCleanup = cleanup;
  }
  return firstNewCleanup;
}

void MaskedArrayExpr::generateNoneElementalCleanupIfAny(
    fir::FirOpBuilder &builder, mlir::IRMapping &mapper) {
  if (!isOuterMaskExpr) {
    // Clone the clean-ups of hlfir.exactly_once operations (in reverse order
    // to properly deal with stack restores).
    for (mlir::Operation &op :
         llvm::reverse(region.back().without_terminator()))
      if (auto exactlyOnce = mlir::dyn_cast<hlfir::ExactlyOnceOp>(op)) {
        mlir::Region &cleanupRegion =
            getYield(exactlyOnce.getBody()).getCleanup();
        if (!cleanupRegion.empty())
          for (mlir::Operation &cleanupOp :
               cleanupRegion.front().without_terminator())
            (void)builder.clone(cleanupOp, mapper);
      }
  } else {
    // For the outer mask, the region clean-ups must be generated
    // outside of the loops since the mask's non-elemental part
    // is generated before the loops.
    generateMaskedExprCleanUps(builder, mapper);
  }
}

static hlfir::RegionAssignOp
getAssignIfLeftHandSideRegion(mlir::Region &region) {
  auto assign = mlir::dyn_cast<hlfir::RegionAssignOp>(region.getParentOp());
  if (assign && (&assign.getLhsRegion() == &region))
    return assign;
  return nullptr;
}

bool OrderedAssignmentRewriter::currentLoopNestIterationNumberCanBeComputed(
    llvm::SmallVectorImpl<fir::DoLoopOp> &loopNest) {
  if (constructStack.empty())
    return true;
  mlir::Operation *outerLoop = constructStack[0];
  mlir::Operation *currentConstruct = constructStack.back();
  // Loop through the loops until the outer construct is met, and test if the
  // loop operands dominate the outer construct.
  while (currentConstruct) {
    if (auto doLoop = mlir::dyn_cast<fir::DoLoopOp>(currentConstruct)) {
      if (llvm::any_of(doLoop->getOperands(), [&](mlir::Value value) {
            return !dominanceInfo.properlyDominates(value, outerLoop);
          })) {
        return false;
      }
      loopNest.push_back(doLoop);
    }
    if (currentConstruct == outerLoop)
      currentConstruct = nullptr;
    else
      currentConstruct = currentConstruct->getParentOp();
  }
  return true;
}

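/// Compute the total iteration count of a fir.do_loop nest as the product of
/// the trip counts of its loops. For a lb:ub:step triplet, the trip count is
/// max((ub - lb + step) / step, 0); for instance (illustrative), a two-level
/// nest over 1:10:1 and 1:5:2 yields 10 * 3 = 30 iterations.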
static mlir::Value
computeLoopNestIterationNumber(mlir::Location loc, fir::FirOpBuilder &builder,
                               llvm::ArrayRef<fir::DoLoopOp> loopNest) {
  mlir::Value loopExtent;
  for (fir::DoLoopOp doLoop : loopNest) {
    mlir::Value extent = builder.genExtentFromTriplet(
        loc, doLoop.getLowerBound(), doLoop.getUpperBound(), doLoop.getStep(),
        builder.getIndexType());
    if (!loopExtent)
      loopExtent = extent;
    else
      loopExtent = builder.create<mlir::arith::MulIOp>(loc, loopExtent, extent);
  }
  assert(loopExtent && "loopNest must not be empty");
  return loopExtent;
}

/// Return a name for temporary storage that indicates in which context
/// the temporary storage was created.
static llvm::StringRef
getTempName(hlfir::OrderedAssignmentTreeOpInterface root) {
  if (mlir::isa<hlfir::ForallOp>(root.getOperation()))
    return ".tmp.forall";
  if (mlir::isa<hlfir::WhereOp>(root.getOperation()))
    return ".tmp.where";
  return ".tmp.assign";
}

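/// Illustration (a sketch, not verified compiler output): when saving the
/// scalar RHS values of "forall (i=1:n) x(i) = x(n+1-i)" in a first run, the
/// total trip count n can be computed before the loop, so a rank-1 temporary
/// of extent n is allocated outside the loop and each value is pushed into it
/// inside the loop; a later run fetches the values back in the same order to
/// perform the assignments.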
1131 void OrderedAssignmentRewriter::generateSaveEntity(
1132     hlfir::SaveEntity savedEntity, bool willUseSavedEntityInSameRun) {
1133   mlir::Region &region = *savedEntity.yieldRegion;
1134 
1135   if (hlfir::RegionAssignOp regionAssignOp =
1136           getAssignIfLeftHandSideRegion(region)) {
1137     // Need to save the address, not the values.
1138     assert(!willUseSavedEntityInSameRun &&
1139            "lhs cannot be used in the loop nest where it is saved");
1140     return saveLeftHandSide(savedEntity, regionAssignOp);
1141   }
1142 
1143   mlir::Location loc = region.getParentOp()->getLoc();
1144   // Evaluate the region inside the loop nest (if any).
1145   auto [clonedValue, oldYield] = generateYieldedEntity(region);
1146   hlfir::Entity entity{clonedValue};
1147   entity = hlfir::loadTrivialScalar(loc, builder, entity);
1148   mlir::Type entityType = entity.getType();
1149 
1150   llvm::StringRef tempName = getTempName(root);
1151   fir::factory::TemporaryStorage *temp = nullptr;
1152   if (constructStack.empty()) {
1153     // Value evaluated outside of any loops (this may be the first MASK of a
1154     // WHERE construct, or an LHS/RHS temp of hlfir.region_assign outside of
1155     // WHERE/FORALL).
1156     temp = insertSavedEntity(
1157         region, fir::factory::SimpleCopy(loc, builder, entity, tempName));
1158   } else {
1159     // Need to create a temporary for values computed inside loops.
1160     // Create temporary storage outside of the loop nest given the entity
1161     // type (and the loop context).
1162     llvm::SmallVector<fir::DoLoopOp> loopNest;
1163     bool loopShapeCanBePreComputed =
1164         currentLoopNestIterationNumberCanBeComputed(loopNest);
1165     doBeforeLoopNest([&] {
1166       // For simple scalars inside loops whose total iteration number can be
1167       // pre-computed, create a rank-1 array outside of the loops. It will be
1168       // assigned/fetched inside the loops like a normal Fortran array given
1169       // the iteration count.
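      // As an illustration, for `forall (i=1:10) x(i) = x(11-i)` the ten
      // scalar right-hand side values can be saved into a rank-1 temporary
      // of extent 10 and fetched back by iteration number.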
1170       if (loopShapeCanBePreComputed && fir::isa_trivial(entityType)) {
1171         mlir::Value loopExtent =
1172             computeLoopNestIterationNumber(loc, builder, loopNest);
1173         auto sequenceType =
1174             mlir::cast<fir::SequenceType>(builder.getVarLenSeqTy(entityType));
1175         temp = insertSavedEntity(region,
1176                                  fir::factory::HomogeneousScalarStack{
1177                                      loc, builder, sequenceType, loopExtent,
1178                                      /*lenParams=*/{}, allocateOnHeap,
1179                                      /*stackThroughLoops=*/true, tempName});
1180 
1181       } else {
1182         // If the number of iterations is not known, or if the values at
1183         // each iteration may have different shapes, type parameters, or
1184         // dynamic types, use the runtime to create and manage a
1185         // stack-like temporary.
1186         temp = insertSavedEntity(
1187             region, fir::factory::AnyValueStack{loc, builder, entityType});
1188       }
1189     });
1190     // Inside the loop nest (and any fir.if if there are active masks), copy
1191     // the value to the temp and do clean-ups for the value if any.
1192     temp->pushValue(loc, builder, entity);
1193   }
1194 
1195   // Delay the clean-up if the entity will be used in the same run (i.e., the
1196   // parent construct will be visited and needs to be lowered). When possible,
1197   // this is not done for hlfir.expr because keeping the hlfir.expr alive
1198   // would prevent its storage from being moved when creating the temporary
1199   // in bufferization, and that would lead to an extra copy.
1200   if (willUseSavedEntityInSameRun &&
1201       (!temp->canBeFetchedAfterPush() ||
1202        !mlir::isa<hlfir::ExprType>(entity.getType()))) {
1203     auto inserted =
1204         savedInCurrentRunBeforeUse.try_emplace(&region, entity, oldYield);
1205     assert(inserted.second && "entity must have been emplaced");
1206     (void)inserted;
1207   } else {
1208     if (constructStack.empty() &&
1209         mlir::isa<hlfir::RegionAssignOp>(region.getParentOp())) {
1210       // Here the clean-up code is inserted after the original
1211       // RegionAssignOp, so that the assignment code happens
1212       // before the cleanup. We do this only for standalone
1213       // operations, because the clean-up is handled specially
1214       // during lowering of the parent constructs if any
1215       // (e.g. see generateNoneElementalCleanupIfAny for
1216       // WhereOp).
1217       auto insertionPoint = builder.saveInsertionPoint();
1218       builder.setInsertionPointAfter(region.getParentOp());
1219       generateCleanupIfAny(oldYield);
1220       builder.restoreInsertionPoint(insertionPoint);
1221     } else {
1222       generateCleanupIfAny(oldYield);
1223     }
1224   }
1225 }
1226 
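/// Does the right-hand side region of \p regionAssignOp yield an array
/// entity?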
1227 static bool rhsIsArray(hlfir::RegionAssignOp regionAssignOp) {
1228   auto yieldOp = mlir::dyn_cast<hlfir::YieldOp>(
1229       regionAssignOp.getRhsRegion().back().back());
1230   return yieldOp && hlfir::Entity{yieldOp.getEntity()}.isArray();
1231 }
1232 
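/// Save the left-hand side variable yielded by \p savedEntity.yieldRegion:
/// only its address (plus its shape for vector subscripted designators) is
/// saved, not the values it designates.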
1233 void OrderedAssignmentRewriter::saveLeftHandSide(
1234     hlfir::SaveEntity savedEntity, hlfir::RegionAssignOp regionAssignOp) {
1235   mlir::Region &region = *savedEntity.yieldRegion;
1236   mlir::Location loc = region.getParentOp()->getLoc();
1237   LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
1238   fir::factory::TemporaryStorage *temp = nullptr;
1239   if (loweredLhs.vectorSubscriptLoopNest)
1240     constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerOp);
1241   if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
1242     // Vector subscripted entity for which the shape must also be saved on top
1243     // of the element addresses (e.g. the shape may change in each forall
1244     // iteration and is needed to create the elemental loops).
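    // As an illustration, in:
    //   forall (i=1:10) x(vector(1:i)) = 0.
    // the left-hand side shape depends on the forall index and must be
    // saved along with the element addresses.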
1245     mlir::Value shape = loweredLhs.vectorSubscriptShape.value();
1246     int rank = mlir::cast<fir::ShapeType>(shape.getType()).getRank();
1247     const bool shapeIsInvariant =
1248         constructStack.empty() ||
1249         dominanceInfo.properlyDominates(shape, constructStack[0]);
1250     doBeforeLoopNest([&] {
1251       // Outside of any forall/where/elemental loops, create a temporary that
1252       // will both be able to save the vector subscripted designator shape(s)
1253       // and element addresses.
1254       temp =
1255           insertSavedEntity(region, fir::factory::AnyVectorSubscriptStack{
1256                                         loc, builder, loweredLhs.lhs.getType(),
1257                                         shapeIsInvariant, rank});
1258     });
1259     // Save shape before the elemental loop nest created by the vector
1260     // subscripted LHS.
1261     auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
1262     auto insertionPoint = builder.saveInsertionPoint();
1263     builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerOp);
1264     vectorTmp.pushShape(loc, builder, shape);
1265     builder.restoreInsertionPoint(insertionPoint);
1266   } else {
1267     // Otherwise, only save the LHS address.
1268     // If the LHS address dominates the constructs, its SSA value can
1269     // simply be tracked and there is no need to save the address in memory.
1270     // Otherwise, the addresses are stored at each iteration in memory with
1271     // a descriptor stack.
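    // As an illustration, the address of `a(p(i))` inside a forall over `i`
    // depends on the loop index and must be stored at every iteration, while
    // an address only computed from values defined before the forall can be
    // tracked as an SSA value.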
1272     if (constructStack.empty() ||
1273         dominanceInfo.properlyDominates(loweredLhs.lhs, constructStack[0]))
1274       doBeforeLoopNest([&] {
1275         temp = insertSavedEntity(region, fir::factory::SSARegister{});
1276       });
1277     else
1278       doBeforeLoopNest([&] {
1279         temp = insertSavedEntity(
1280             region, fir::factory::AnyVariableStack{loc, builder,
1281                                                    loweredLhs.lhs.getType()});
1282       });
1283   }
1284   temp->pushValue(loc, builder, loweredLhs.lhs);
1285   generateCleanupIfAny(loweredLhs.elementalCleanup);
1286   if (loweredLhs.vectorSubscriptLoopNest) {
1287     constructStack.pop_back();
1288     builder.setInsertionPointAfter(loweredLhs.vectorSubscriptLoopNest->outerOp);
1289   }
1290 }
1291 
1292 /// Lower an ordered assignment tree to fir.do_loop and hlfir.assign given
1293 /// a schedule.
1294 static void lower(hlfir::OrderedAssignmentTreeOpInterface root,
1295                   mlir::PatternRewriter &rewriter, hlfir::Schedule &schedule) {
1296   auto module = root->getParentOfType<mlir::ModuleOp>();
1297   fir::FirOpBuilder builder(rewriter, module);
1298   OrderedAssignmentRewriter assignmentRewriter(builder, root);
1299   for (auto &run : schedule)
1300     assignmentRewriter.lowerRun(run);
1301   assignmentRewriter.cleanupSavedEntities();
1302 }
1303 
1304 /// Shared rewrite entry point for all the ordered assignment tree root
1305 /// operations. It calls the scheduler and then applies the schedule.
1306 static llvm::LogicalResult rewrite(hlfir::OrderedAssignmentTreeOpInterface root,
1307                                    bool tryFusingAssignments,
1308                                    mlir::PatternRewriter &rewriter) {
1309   hlfir::Schedule schedule =
1310       hlfir::buildEvaluationSchedule(root, tryFusingAssignments);
1311 
1312   LLVM_DEBUG(
1313       // Debug option to print the scheduling debug info without doing
1314       // any code generation. The operations are simply erased to avoid
1315       // failing and calling the rewrite patterns on nested operations.
1316       // The only purpose of this is to help test scheduling without
1317       // having to test generated code.
1318       if (dbgScheduleOnly) {
1319         rewriter.eraseOp(root);
1320         return mlir::success();
1321       });
1322   lower(root, rewriter, schedule);
1323   rewriter.eraseOp(root);
1324   return mlir::success();
1325 }
1326 
1327 namespace {
1328 
1329 class ForallOpConversion : public mlir::OpRewritePattern<hlfir::ForallOp> {
1330 public:
1331   explicit ForallOpConversion(mlir::MLIRContext *ctx, bool tryFusingAssignments)
1332       : OpRewritePattern{ctx}, tryFusingAssignments{tryFusingAssignments} {}
1333 
1334   llvm::LogicalResult
1335   matchAndRewrite(hlfir::ForallOp forallOp,
1336                   mlir::PatternRewriter &rewriter) const override {
1337     auto root = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
1338         forallOp.getOperation());
1339     if (mlir::failed(::rewrite(root, tryFusingAssignments, rewriter)))
1340       TODO(forallOp.getLoc(), "FORALL construct or statement in HLFIR");
1341     return mlir::success();
1342   }
1343   const bool tryFusingAssignments;
1344 };
1345 
1346 class WhereOpConversion : public mlir::OpRewritePattern<hlfir::WhereOp> {
1347 public:
1348   explicit WhereOpConversion(mlir::MLIRContext *ctx, bool tryFusingAssignments)
1349       : OpRewritePattern{ctx}, tryFusingAssignments{tryFusingAssignments} {}
1350 
1351   llvm::LogicalResult
1352   matchAndRewrite(hlfir::WhereOp whereOp,
1353                   mlir::PatternRewriter &rewriter) const override {
1354     auto root = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
1355         whereOp.getOperation());
1356     return ::rewrite(root, tryFusingAssignments, rewriter);
1357   }
1358   const bool tryFusingAssignments;
1359 };
1360 
1361 class RegionAssignConversion
1362     : public mlir::OpRewritePattern<hlfir::RegionAssignOp> {
1363 public:
1364   explicit RegionAssignConversion(mlir::MLIRContext *ctx)
1365       : OpRewritePattern{ctx} {}
1366 
1367   llvm::LogicalResult
1368   matchAndRewrite(hlfir::RegionAssignOp regionAssignOp,
1369                   mlir::PatternRewriter &rewriter) const override {
1370     auto root = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
1371         regionAssignOp.getOperation());
1372     return ::rewrite(root, /*tryFusingAssignments=*/false, rewriter);
1373   }
1374 };
1375 
1376 class LowerHLFIROrderedAssignments
1377     : public hlfir::impl::LowerHLFIROrderedAssignmentsBase<
1378           LowerHLFIROrderedAssignments> {
1379 public:
1380   using LowerHLFIROrderedAssignmentsBase<
1381       LowerHLFIROrderedAssignments>::LowerHLFIROrderedAssignmentsBase;
1382 
1383   void runOnOperation() override {
1384     // Running on a ModuleOp because this pass may generate FuncOp
1385     // declarations for runtime calls. This could be a FuncOp pass otherwise.
1386     auto module = this->getOperation();
1387     auto *context = &getContext();
1388     mlir::RewritePatternSet patterns(context);
1389     // Patterns are only defined for the OrderedAssignmentTreeOpInterface
1390     // operations that can be the root of ordered assignments. The other
1391     // operations will be taken care of while rewriting these trees (they
1392     // cannot exist outside of these operations given their verifiers/traits).
1393     patterns.insert<ForallOpConversion, WhereOpConversion>(
1394         context, this->tryFusingAssignments.getValue());
1395     patterns.insert<RegionAssignConversion>(context);
1396     mlir::ConversionTarget target(*context);
1397     target.markUnknownOpDynamicallyLegal([](mlir::Operation *op) {
1398       return !mlir::isa<hlfir::OrderedAssignmentTreeOpInterface>(op);
1399     });
1400     if (mlir::failed(mlir::applyPartialConversion(module, target,
1401                                                   std::move(patterns)))) {
1402       mlir::emitError(mlir::UnknownLoc::get(context),
1403                       "failure in HLFIR ordered assignments lowering pass");
1404       signalPassFailure();
1405     }
1406   }
1407 };
1408 } // namespace
1409