xref: /llvm-project/flang/lib/Optimizer/Transforms/StackArrays.cpp (revision 4b3f251bada55cfc20a2c72321fa0bbfd7a759d5)
1 //===- StackArrays.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
11 #include "flang/Optimizer/Dialect/FIRAttr.h"
12 #include "flang/Optimizer/Dialect/FIRDialect.h"
13 #include "flang/Optimizer/Dialect/FIROps.h"
14 #include "flang/Optimizer/Dialect/FIRType.h"
15 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
16 #include "flang/Optimizer/Transforms/Passes.h"
17 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
18 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
19 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
20 #include "mlir/Analysis/DataFlowFramework.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"
22 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/Diagnostics.h"
25 #include "mlir/IR/Value.h"
26 #include "mlir/Interfaces/LoopLikeInterface.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29 #include "mlir/Transforms/Passes.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/PointerUnion.h"
33 #include "llvm/Support/Casting.h"
34 #include "llvm/Support/raw_ostream.h"
35 #include <optional>
36 
37 namespace fir {
38 #define GEN_PASS_DEF_STACKARRAYS
39 #include "flang/Optimizer/Transforms/Passes.h.inc"
40 } // namespace fir
41 
42 #define DEBUG_TYPE "stack-arrays"
43 
44 static llvm::cl::opt<std::size_t> maxAllocsPerFunc(
45     "stack-arrays-max-allocs",
46     llvm::cl::desc("The maximum number of heap allocations to consider in one "
47                    "function before skipping (to save compilation time). Set "
48                    "to 0 for no limit."),
49     llvm::cl::init(1000), llvm::cl::Hidden);
50 
51 namespace {
52 
/// Allocation status the analysis records for an SSA value at a program point.
enum class AllocationState {
  /// The status cannot be determined here, for instance when only one arm of a
  /// conditional freed the value. This is an explicit "known unknown" and is
  /// distinct from a value that simply has no AllocationState recorded.
  Unknown,
  /// The value was heap-allocated in this function and has since been freed.
  Freed,
  /// The value was heap-allocated in this function and is a candidate for
  /// being moved to the stack.
  Allocated,
};
68 
69 /// Stores where an alloca should be inserted. If the PointerUnion is an
70 /// Operation the alloca should be inserted /after/ the operation. If it is a
71 /// block, the alloca can be placed anywhere in that block.
72 class InsertionPoint {
73   llvm::PointerUnion<mlir::Operation *, mlir::Block *> location;
74   bool saveRestoreStack;
75 
76   /// Get contained pointer type or nullptr
77   template <class T>
78   T *tryGetPtr() const {
79     if (location.is<T *>())
80       return location.get<T *>();
81     return nullptr;
82   }
83 
84 public:
85   template <class T>
86   InsertionPoint(T *ptr, bool saveRestoreStack = false)
87       : location(ptr), saveRestoreStack{saveRestoreStack} {}
88   InsertionPoint(std::nullptr_t null)
89       : location(null), saveRestoreStack{false} {}
90 
91   /// Get contained operation, or nullptr
92   mlir::Operation *tryGetOperation() const {
93     return tryGetPtr<mlir::Operation>();
94   }
95 
96   /// Get contained block, or nullptr
97   mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); }
98 
99   /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave
100   /// intrinsic should be added before the alloca, and an llvm.stackrestore
101   /// intrinsic should be added where the freemem is
102   bool shouldSaveRestoreStack() const { return saveRestoreStack; }
103 
104   operator bool() const { return tryGetOperation() || tryGetBlock(); }
105 
106   bool operator==(const InsertionPoint &rhs) const {
107     return (location == rhs.location) &&
108            (saveRestoreStack == rhs.saveRestoreStack);
109   }
110 
111   bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); }
112 };
113 
114 /// Maps SSA values to their AllocationState at a particular program point.
115 /// Also caches the insertion points for the new alloca operations
116 class LatticePoint : public mlir::dataflow::AbstractDenseLattice {
117   // Maps all values we are interested in to states
118   llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap;
119 
120 public:
121   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint)
122   using AbstractDenseLattice::AbstractDenseLattice;
123 
124   bool operator==(const LatticePoint &rhs) const {
125     return stateMap == rhs.stateMap;
126   }
127 
128   /// Join the lattice accross control-flow edges
129   mlir::ChangeResult join(const AbstractDenseLattice &lattice) override;
130 
131   void print(llvm::raw_ostream &os) const override;
132 
133   /// Clear all modifications
134   mlir::ChangeResult reset();
135 
136   /// Set the state of an SSA value
137   mlir::ChangeResult set(mlir::Value value, AllocationState state);
138 
139   /// Get fir.allocmem ops which were allocated in this function and always
140   /// freed before the function returns, plus whre to insert replacement
141   /// fir.alloca ops
142   void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const;
143 
144   std::optional<AllocationState> get(mlir::Value val) const;
145 };
146 
147 class AllocationAnalysis
148     : public mlir::dataflow::DenseForwardDataFlowAnalysis<LatticePoint> {
149 public:
150   using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;
151 
152   mlir::LogicalResult visitOperation(mlir::Operation *op,
153                                      const LatticePoint &before,
154                                      LatticePoint *after) override;
155 
156   /// At an entry point, the last modifications of all memory resources are
157   /// yet to be determined
158   void setToEntryState(LatticePoint *lattice) override;
159 
160 protected:
161   /// Visit control flow operations and decide whether to call visitOperation
162   /// to apply the transfer function
163   mlir::LogicalResult processOperation(mlir::Operation *op) override;
164 };
165 
166 /// Drives analysis to find candidate fir.allocmem operations which could be
167 /// moved to the stack. Intended to be used with mlir::Pass::getAnalysis
168 class StackArraysAnalysisWrapper {
169 public:
170   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper)
171 
172   // Maps fir.allocmem -> place to insert alloca
173   using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>;
174 
175   StackArraysAnalysisWrapper(mlir::Operation *op) {}
176 
177   // returns nullptr if analysis failed
178   const AllocMemMap *getCandidateOps(mlir::Operation *func);
179 
180 private:
181   llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;
182 
183   llvm::LogicalResult analyseFunction(mlir::Operation *func);
184 };
185 
186 /// Converts a fir.allocmem to a fir.alloca
187 class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
188 public:
189   explicit AllocMemConversion(
190       mlir::MLIRContext *ctx,
191       const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
192       : OpRewritePattern(ctx), candidateOps{candidateOps} {}
193 
194   llvm::LogicalResult
195   matchAndRewrite(fir::AllocMemOp allocmem,
196                   mlir::PatternRewriter &rewriter) const override;
197 
198   /// Determine where to insert the alloca operation. The returned value should
199   /// be checked to see if it is inside a loop
200   static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc);
201 
202 private:
203   /// Handle to the DFA (already run)
204   const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;
205 
206   /// If we failed to find an insertion point not inside a loop, see if it would
207   /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
208   static InsertionPoint findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc);
209 
210   /// Returns the alloca if it was successfully inserted, otherwise {}
211   std::optional<fir::AllocaOp>
212   insertAlloca(fir::AllocMemOp &oldAlloc,
213                mlir::PatternRewriter &rewriter) const;
214 
215   /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem
216   void insertStackSaveRestore(fir::AllocMemOp &oldAlloc,
217                               mlir::PatternRewriter &rewriter) const;
218 };
219 
220 class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> {
221 public:
222   StackArraysPass() = default;
223   StackArraysPass(const StackArraysPass &pass);
224 
225   llvm::StringRef getDescription() const override;
226 
227   void runOnOperation() override;
228 
229 private:
230   Statistic runCount{this, "stackArraysRunCount",
231                      "Number of heap allocations moved to the stack"};
232 };
233 
234 } // namespace
235 
236 static void print(llvm::raw_ostream &os, AllocationState state) {
237   switch (state) {
238   case AllocationState::Unknown:
239     os << "Unknown";
240     break;
241   case AllocationState::Freed:
242     os << "Freed";
243     break;
244   case AllocationState::Allocated:
245     os << "Allocated";
246     break;
247   }
248 }
249 
250 /// Join two AllocationStates for the same value coming from different CFG
251 /// blocks
252 static AllocationState join(AllocationState lhs, AllocationState rhs) {
253   //           | Allocated | Freed     | Unknown
254   // ========= | ========= | ========= | =========
255   // Allocated | Allocated | Unknown   | Unknown
256   // Freed     | Unknown   | Freed     | Unknown
257   // Unknown   | Unknown   | Unknown   | Unknown
258   if (lhs == rhs)
259     return lhs;
260   return AllocationState::Unknown;
261 }
262 
263 mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) {
264   const auto &rhs = static_cast<const LatticePoint &>(lattice);
265   mlir::ChangeResult changed = mlir::ChangeResult::NoChange;
266 
267   // add everything from rhs to map, handling cases where values are in both
268   for (const auto &[value, rhsState] : rhs.stateMap) {
269     auto it = stateMap.find(value);
270     if (it != stateMap.end()) {
271       // value is present in both maps
272       AllocationState myState = it->second;
273       AllocationState newState = ::join(myState, rhsState);
274       if (newState != myState) {
275         changed = mlir::ChangeResult::Change;
276         it->getSecond() = newState;
277       }
278     } else {
279       // value not present in current map: add it
280       stateMap.insert({value, rhsState});
281       changed = mlir::ChangeResult::Change;
282     }
283   }
284 
285   return changed;
286 }
287 
288 void LatticePoint::print(llvm::raw_ostream &os) const {
289   for (const auto &[value, state] : stateMap) {
290     os << "\n * " << value << ": ";
291     ::print(os, state);
292   }
293 }
294 
295 mlir::ChangeResult LatticePoint::reset() {
296   if (stateMap.empty())
297     return mlir::ChangeResult::NoChange;
298   stateMap.clear();
299   return mlir::ChangeResult::Change;
300 }
301 
302 mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) {
303   if (stateMap.count(value)) {
304     // already in map
305     AllocationState &oldState = stateMap[value];
306     if (oldState != state) {
307       stateMap[value] = state;
308       return mlir::ChangeResult::Change;
309     }
310     return mlir::ChangeResult::NoChange;
311   }
312   stateMap.insert({value, state});
313   return mlir::ChangeResult::Change;
314 }
315 
316 /// Get values which were allocated in this function and always freed before
317 /// the function returns
318 void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const {
319   for (auto &[value, state] : stateMap) {
320     if (state == AllocationState::Freed)
321       out.insert(value);
322   }
323 }
324 
325 std::optional<AllocationState> LatticePoint::get(mlir::Value val) const {
326   auto it = stateMap.find(val);
327   if (it == stateMap.end())
328     return {};
329   return it->second;
330 }
331 
332 mlir::LogicalResult AllocationAnalysis::visitOperation(
333     mlir::Operation *op, const LatticePoint &before, LatticePoint *after) {
334   LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
335                           << "\n");
336   LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n");
337 
338   // propagate before -> after
339   mlir::ChangeResult changed = after->join(before);
340 
341   if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) {
342     assert(op->getNumResults() == 1 && "fir.allocmem has one result");
343     auto attr = op->getAttrOfType<fir::MustBeHeapAttr>(
344         fir::MustBeHeapAttr::getAttrName());
345     if (attr && attr.getValue()) {
346       LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
347       // skip allocation marked not to be moved
348       return mlir::success();
349     }
350 
351     auto retTy = allocmem.getAllocatedType();
352     if (!mlir::isa<fir::SequenceType>(retTy)) {
353       LLVM_DEBUG(llvm::dbgs()
354                  << "--Allocation is not for an array: skipping\n");
355       return mlir::success();
356     }
357 
358     mlir::Value result = op->getResult(0);
359     changed |= after->set(result, AllocationState::Allocated);
360   } else if (mlir::isa<fir::FreeMemOp>(op)) {
361     assert(op->getNumOperands() == 1 && "fir.freemem has one operand");
362     mlir::Value operand = op->getOperand(0);
363 
364     // Note: StackArrays is scheduled in the pass pipeline after lowering hlfir
365     // to fir. Therefore, we only need to handle `fir::DeclareOp`s.
366     if (auto declareOp =
367             llvm::dyn_cast_if_present<fir::DeclareOp>(operand.getDefiningOp()))
368       operand = declareOp.getMemref();
369 
370     std::optional<AllocationState> operandState = before.get(operand);
371     if (operandState && *operandState == AllocationState::Allocated) {
372       // don't tag things not allocated in this function as freed, so that we
373       // don't think they are candidates for moving to the stack
374       changed |= after->set(operand, AllocationState::Freed);
375     }
376   } else if (mlir::isa<fir::ResultOp>(op)) {
377     mlir::Operation *parent = op->getParentOp();
378     LatticePoint *parentLattice = getLattice(getProgramPointAfter(parent));
379     assert(parentLattice);
380     mlir::ChangeResult parentChanged = parentLattice->join(*after);
381     propagateIfChanged(parentLattice, parentChanged);
382   }
383 
384   // we pass lattices straight through fir.call because called functions should
385   // not deallocate flang-generated array temporaries
386 
387   LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n");
388   propagateIfChanged(after, changed);
389   return mlir::success();
390 }
391 
392 void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
393   propagateIfChanged(lattice, lattice->reset());
394 }
395 
396 /// Mostly a copy of AbstractDenseLattice::processOperation - the difference
397 /// being that call operations are passed through to the transfer function
398 mlir::LogicalResult AllocationAnalysis::processOperation(mlir::Operation *op) {
399   mlir::ProgramPoint *point = getProgramPointAfter(op);
400   // If the containing block is not executable, bail out.
401   if (op->getBlock() != nullptr &&
402       !getOrCreateFor<mlir::dataflow::Executable>(
403            point, getProgramPointBefore(op->getBlock()))
404            ->isLive())
405     return mlir::success();
406 
407   // Get the dense lattice to update
408   mlir::dataflow::AbstractDenseLattice *after = getLattice(point);
409 
410   // If this op implements region control-flow, then control-flow dictates its
411   // transfer function.
412   if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
413     visitRegionBranchOperation(point, branch, after);
414     return mlir::success();
415   }
416 
417   // pass call operations through to the transfer function
418 
419   // Get the dense state before the execution of the op.
420   const mlir::dataflow::AbstractDenseLattice *before =
421       getLatticeFor(point, getProgramPointBefore(op));
422 
423   /// Invoke the operation transfer function
424   return visitOperationImpl(op, *before, after);
425 }
426 
427 llvm::LogicalResult
428 StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
429   assert(mlir::isa<mlir::func::FuncOp>(func));
430   size_t nAllocs = 0;
431   func->walk([&nAllocs](fir::AllocMemOp) { nAllocs++; });
432   // don't bother with the analysis if there are no heap allocations
433   if (nAllocs == 0)
434     return mlir::success();
435   if ((maxAllocsPerFunc != 0) && (nAllocs > maxAllocsPerFunc)) {
436     LLVM_DEBUG(llvm::dbgs() << "Skipping stack arrays for function with "
437                             << nAllocs << " heap allocations");
438     return mlir::success();
439   }
440 
441   mlir::DataFlowSolver solver;
442   // constant propagation is required for dead code analysis, dead code analysis
443   // is required to mark blocks live (required for mlir dense dfa)
444   solver.load<mlir::dataflow::SparseConstantPropagation>();
445   solver.load<mlir::dataflow::DeadCodeAnalysis>();
446 
447   auto [it, inserted] = funcMaps.try_emplace(func);
448   AllocMemMap &candidateOps = it->second;
449 
450   solver.load<AllocationAnalysis>();
451   if (failed(solver.initializeAndRun(func))) {
452     llvm::errs() << "DataFlowSolver failed!";
453     return mlir::failure();
454   }
455 
456   LatticePoint point{solver.getProgramPointAfter(func)};
457   auto joinOperationLattice = [&](mlir::Operation *op) {
458     const LatticePoint *lattice =
459         solver.lookupState<LatticePoint>(solver.getProgramPointAfter(op));
460     // there will be no lattice for an unreachable block
461     if (lattice)
462       (void)point.join(*lattice);
463   };
464 
465   func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); });
466   func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); });
467   func->walk(
468       [&](mlir::omp::TerminatorOp child) { joinOperationLattice(child); });
469   func->walk([&](mlir::omp::YieldOp child) { joinOperationLattice(child); });
470 
471   llvm::DenseSet<mlir::Value> freedValues;
472   point.appendFreedValues(freedValues);
473 
474   // We only replace allocations which are definately freed on all routes
475   // through the function because otherwise the allocation may have an intende
476   // lifetime longer than the current stack frame (e.g. a heap allocation which
477   // is then freed by another function).
478   for (mlir::Value freedValue : freedValues) {
479     fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>();
480     InsertionPoint insertionPoint =
481         AllocMemConversion::findAllocaInsertionPoint(allocmem);
482     if (insertionPoint)
483       candidateOps.insert({allocmem, insertionPoint});
484   }
485 
486   LLVM_DEBUG(for (auto [allocMemOp, _]
487                   : candidateOps) {
488     llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
489   });
490   return mlir::success();
491 }
492 
493 const StackArraysAnalysisWrapper::AllocMemMap *
494 StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
495   if (!funcMaps.contains(func))
496     if (mlir::failed(analyseFunction(func)))
497       return nullptr;
498   return &funcMaps[func];
499 }
500 
501 /// Restore the old allocation type exected by existing code
502 static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter,
503                                          const mlir::Location &loc,
504                                          mlir::Value heap, mlir::Value stack) {
505   mlir::Type heapTy = heap.getType();
506   mlir::Type stackTy = stack.getType();
507 
508   if (heapTy == stackTy)
509     return stack;
510 
511   fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy);
512   LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy =
513       mlir::cast<fir::ReferenceType>(stackTy);
514   assert(firHeapTy.getElementType() == firRefTy.getElementType() &&
515          "Allocations must have the same type");
516 
517   auto insertionPoint = rewriter.saveInsertionPoint();
518   rewriter.setInsertionPointAfter(stack.getDefiningOp());
519   mlir::Value conv =
520       rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult();
521   rewriter.restoreInsertionPoint(insertionPoint);
522   return conv;
523 }
524 
525 llvm::LogicalResult
526 AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
527                                     mlir::PatternRewriter &rewriter) const {
528   auto oldInsertionPt = rewriter.saveInsertionPoint();
529   // add alloca operation
530   std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter);
531   rewriter.restoreInsertionPoint(oldInsertionPt);
532   if (!alloca)
533     return mlir::failure();
534 
535   // remove freemem operations
536   llvm::SmallVector<mlir::Operation *> erases;
537   for (mlir::Operation *user : allocmem.getOperation()->getUsers()) {
538     if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
539       for (mlir::Operation *user : declareOp->getUsers()) {
540         if (mlir::isa<fir::FreeMemOp>(user))
541           erases.push_back(user);
542       }
543     }
544 
545     if (mlir::isa<fir::FreeMemOp>(user))
546       erases.push_back(user);
547   }
548 
549   // now we are done iterating the users, it is safe to mutate them
550   for (mlir::Operation *erase : erases)
551     rewriter.eraseOp(erase);
552 
553   // replace references to heap allocation with references to stack allocation
554   mlir::Value newValue = convertAllocationType(
555       rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult());
556   rewriter.replaceAllUsesWith(allocmem.getResult(), newValue);
557 
558   // remove allocmem operation
559   rewriter.eraseOp(allocmem.getOperation());
560 
561   return mlir::success();
562 }
563 
564 static bool isInLoop(mlir::Block *block) {
565   return mlir::LoopLikeOpInterface::blockIsInLoop(block);
566 }
567 
568 static bool isInLoop(mlir::Operation *op) {
569   return isInLoop(op->getBlock()) ||
570          op->getParentOfType<mlir::LoopLikeOpInterface>();
571 }
572 
573 InsertionPoint
574 AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) {
575   // Ideally the alloca should be inserted at the end of the function entry
576   // block so that we do not allocate stack space in a loop. However,
577   // the operands to the alloca may not be available that early, so insert it
578   // after the last operand becomes available
579   // If the old allocmem op was in an openmp region then it should not be moved
580   // outside of that
581   LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: "
582                           << oldAlloc << "\n");
583 
584   // check that an Operation or Block we are about to return is not in a loop
585   auto checkReturn = [&](auto *point) -> InsertionPoint {
586     if (isInLoop(point)) {
587       mlir::Operation *oldAllocOp = oldAlloc.getOperation();
588       if (isInLoop(oldAllocOp)) {
589         // where we want to put it is in a loop, and even the old location is in
590         // a loop. Give up.
591         return findAllocaLoopInsertionPoint(oldAlloc);
592       }
593       return {oldAllocOp};
594     }
595     return {point};
596   };
597 
598   auto oldOmpRegion =
599       oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
600 
601   // Find when the last operand value becomes available
602   mlir::Block *operandsBlock = nullptr;
603   mlir::Operation *lastOperand = nullptr;
604   for (mlir::Value operand : oldAlloc.getOperands()) {
605     LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n");
606     mlir::Operation *op = operand.getDefiningOp();
607     if (!op)
608       return checkReturn(oldAlloc.getOperation());
609     if (!operandsBlock)
610       operandsBlock = op->getBlock();
611     else if (operandsBlock != op->getBlock()) {
612       LLVM_DEBUG(llvm::dbgs()
613                  << "----operand declared in a different block!\n");
614       // Operation::isBeforeInBlock requires the operations to be in the same
615       // block. The best we can do is the location of the allocmem.
616       return checkReturn(oldAlloc.getOperation());
617     }
618     if (!lastOperand || lastOperand->isBeforeInBlock(op))
619       lastOperand = op;
620   }
621 
622   if (lastOperand) {
623     // there were value operands to the allocmem so insert after the last one
624     LLVM_DEBUG(llvm::dbgs()
625                << "--Placing after last operand: " << *lastOperand << "\n");
626     // check we aren't moving out of an omp region
627     auto lastOpOmpRegion =
628         lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
629     if (lastOpOmpRegion == oldOmpRegion)
630       return checkReturn(lastOperand);
631     // Presumably this happened because the operands became ready before the
632     // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should
633     // imply that oldOmpRegion comes after lastOpOmpRegion.
634     return checkReturn(oldOmpRegion.getAllocaBlock());
635   }
636 
637   // There were no value operands to the allocmem so we are safe to insert it
638   // as early as we want
639 
640   // handle openmp case
641   if (oldOmpRegion)
642     return checkReturn(oldOmpRegion.getAllocaBlock());
643 
644   // fall back to the function entry block
645   mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>();
646   assert(func && "This analysis is run on func.func");
647   mlir::Block &entryBlock = func.getBlocks().front();
648   LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n");
649   return checkReturn(&entryBlock);
650 }
651 
652 InsertionPoint
653 AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) {
654   mlir::Operation *oldAllocOp = oldAlloc;
655   // This is only called as a last resort. We should try to insert at the
656   // location of the old allocation, which is inside of a loop, using
657   // llvm.stacksave/llvm.stackrestore
658 
659   // find freemem ops
660   llvm::SmallVector<mlir::Operation *, 1> freeOps;
661 
662   for (mlir::Operation *user : oldAllocOp->getUsers()) {
663     if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
664       for (mlir::Operation *user : declareOp->getUsers()) {
665         if (mlir::isa<fir::FreeMemOp>(user))
666           freeOps.push_back(user);
667       }
668     }
669 
670     if (mlir::isa<fir::FreeMemOp>(user))
671       freeOps.push_back(user);
672   }
673 
674   assert(freeOps.size() && "DFA should only return freed memory");
675 
676   // Don't attempt to reason about a stacksave/stackrestore between different
677   // blocks
678   for (mlir::Operation *free : freeOps)
679     if (free->getBlock() != oldAllocOp->getBlock())
680       return {nullptr};
681 
682   // Check that there aren't any other stack allocations in between the
683   // stack save and stack restore
684   // note: for flang generated temporaries there should only be one free op
685   for (mlir::Operation *free : freeOps) {
686     for (mlir::Operation *op = oldAlloc; op && op != free;
687          op = op->getNextNode()) {
688       if (mlir::isa<fir::AllocaOp>(op))
689         return {nullptr};
690     }
691   }
692 
693   return InsertionPoint{oldAllocOp, /*shouldStackSaveRestore=*/true};
694 }
695 
696 std::optional<fir::AllocaOp>
697 AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc,
698                                  mlir::PatternRewriter &rewriter) const {
699   auto it = candidateOps.find(oldAlloc.getOperation());
700   if (it == candidateOps.end())
701     return {};
702   InsertionPoint insertionPoint = it->second;
703   if (!insertionPoint)
704     return {};
705 
706   if (insertionPoint.shouldSaveRestoreStack())
707     insertStackSaveRestore(oldAlloc, rewriter);
708 
709   mlir::Location loc = oldAlloc.getLoc();
710   mlir::Type varTy = oldAlloc.getInType();
711   if (mlir::Operation *op = insertionPoint.tryGetOperation()) {
712     rewriter.setInsertionPointAfter(op);
713   } else {
714     mlir::Block *block = insertionPoint.tryGetBlock();
715     assert(block && "There must be a valid insertion point");
716     rewriter.setInsertionPointToStart(block);
717   }
718 
719   auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef {
720     if (opt)
721       return *opt;
722     return {};
723   };
724 
725   llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName());
726   llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName());
727   return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName,
728                                         oldAlloc.getTypeparams(),
729                                         oldAlloc.getShape());
730 }
731 
732 void AllocMemConversion::insertStackSaveRestore(
733     fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const {
734   auto oldPoint = rewriter.saveInsertionPoint();
735   auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>();
736   fir::FirOpBuilder builder{rewriter, mod};
737 
738   builder.setInsertionPoint(oldAlloc);
739   mlir::Value sp = builder.genStackSave(oldAlloc.getLoc());
740 
741   auto createStackRestoreCall = [&](mlir::Operation *user) {
742     builder.setInsertionPoint(user);
743     builder.genStackRestore(user->getLoc(), sp);
744   };
745 
746   for (mlir::Operation *user : oldAlloc->getUsers()) {
747     if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
748       for (mlir::Operation *user : declareOp->getUsers()) {
749         if (mlir::isa<fir::FreeMemOp>(user))
750           createStackRestoreCall(user);
751       }
752     }
753 
754     if (mlir::isa<fir::FreeMemOp>(user)) {
755       createStackRestoreCall(user);
756     }
757   }
758 
759   rewriter.restoreInsertionPoint(oldPoint);
760 }
761 
762 StackArraysPass::StackArraysPass(const StackArraysPass &pass)
763     : fir::impl::StackArraysBase<StackArraysPass>(pass) {}
764 
765 llvm::StringRef StackArraysPass::getDescription() const {
766   return "Move heap allocated array temporaries to the stack";
767 }
768 
769 void StackArraysPass::runOnOperation() {
770   mlir::func::FuncOp func = getOperation();
771 
772   auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
773   const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
774       analysis.getCandidateOps(func);
775   if (!candidateOps) {
776     signalPassFailure();
777     return;
778   }
779 
780   if (candidateOps->empty())
781     return;
782   runCount += candidateOps->size();
783 
784   llvm::SmallVector<mlir::Operation *> opsToConvert;
785   opsToConvert.reserve(candidateOps->size());
786   for (auto [op, _] : *candidateOps)
787     opsToConvert.push_back(op);
788 
789   mlir::MLIRContext &context = getContext();
790   mlir::RewritePatternSet patterns(&context);
791   mlir::GreedyRewriteConfig config;
792   // prevent the pattern driver form merging blocks
793   config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
794 
795   patterns.insert<AllocMemConversion>(&context, *candidateOps);
796   if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
797                                                 std::move(patterns), config))) {
798     mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
799     signalPassFailure();
800   }
801 }
802