//===- StackArrays.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
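
// This pass moves heap allocations of Fortran array temporaries
// (fir.allocmem) that are provably freed (fir.freemem) on every path through
// a function onto the stack (fir.alloca). A minimal illustrative example
// (hand-written FIR, not taken from a real test):
//
//   %mem = fir.allocmem !fir.array<42xi32>
//   // ... uses of %mem ...
//   fir.freemem %mem : !fir.heap<!fir.array<42xi32>>
//
// becomes, roughly (a fir.convert keeps the original heap type for users):
//
//   %mem = fir.alloca !fir.array<42xi32>
//   // ... uses of %mem ...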

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

namespace fir {
#define GEN_PASS_DEF_STACKARRAYS
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

#define DEBUG_TYPE "stack-arrays"

static llvm::cl::opt<std::size_t> maxAllocsPerFunc(
    "stack-arrays-max-allocs",
    llvm::cl::desc("The maximum number of heap allocations to consider in one "
                   "function before skipping (to save compilation time). Set "
                   "to 0 for no limit."),
    llvm::cl::init(1000), llvm::cl::Hidden);

namespace {

/// The state of an SSA value at each program point
enum class AllocationState {
  /// This means that the allocation state of a variable cannot be determined
  /// at this program point, e.g. because one route through a conditional freed
  /// the variable and the other route didn't.
  /// This asserts a known-unknown: different from the unknown-unknown of having
  /// no AllocationState stored for a particular SSA value
  Unknown,
  /// Means this SSA value was allocated on the heap in this function and has
  /// now been freed
  Freed,
  /// Means this SSA value was allocated on the heap in this function and is a
  /// candidate for moving to the stack
  Allocated,
};

/// Stores where an alloca should be inserted. If the PointerUnion is an
/// Operation the alloca should be inserted /after/ the operation. If it is a
/// block, the alloca can be placed anywhere in that block.
class InsertionPoint {
  llvm::PointerUnion<mlir::Operation *, mlir::Block *> location;
  bool saveRestoreStack;

  /// Get contained pointer type or nullptr
  template <class T>
  T *tryGetPtr() const {
    // Use llvm::dyn_cast_if_present because location may be null here.
    if (T *ptr = llvm::dyn_cast_if_present<T *>(location))
      return ptr;
    return nullptr;
  }

public:
  template <class T>
  InsertionPoint(T *ptr, bool saveRestoreStack = false)
      : location(ptr), saveRestoreStack{saveRestoreStack} {}
  InsertionPoint(std::nullptr_t null)
      : location(null), saveRestoreStack{false} {}

  /// Get contained operation, or nullptr
  mlir::Operation *tryGetOperation() const {
    return tryGetPtr<mlir::Operation>();
  }

  /// Get contained block, or nullptr
  mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); }

  /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave
  /// intrinsic should be added before the alloca, and an llvm.stackrestore
  /// intrinsic should be added where the freemem is
  bool shouldSaveRestoreStack() const { return saveRestoreStack; }

  operator bool() const { return tryGetOperation() || tryGetBlock(); }

  bool operator==(const InsertionPoint &rhs) const {
    return (location == rhs.location) &&
           (saveRestoreStack == rhs.saveRestoreStack);
  }

  bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); }
};

/// Maps SSA values to their AllocationState at a particular program point.
/// Also caches the insertion points for the new alloca operations
class LatticePoint : public mlir::dataflow::AbstractDenseLattice {
  // Maps all values we are interested in to states
  llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap;

public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint)
  using AbstractDenseLattice::AbstractDenseLattice;

  bool operator==(const LatticePoint &rhs) const {
    return stateMap == rhs.stateMap;
  }

  /// Join the lattice across control-flow edges
  mlir::ChangeResult join(const AbstractDenseLattice &lattice) override;

  void print(llvm::raw_ostream &os) const override;

  /// Clear all modifications
  mlir::ChangeResult reset();

  /// Set the state of an SSA value
  mlir::ChangeResult set(mlir::Value value, AllocationState state);

  /// Get fir.allocmem ops which were allocated in this function and always
  /// freed before the function returns, plus where to insert replacement
  /// fir.alloca ops
  void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const;

  std::optional<AllocationState> get(mlir::Value val) const;
};

class AllocationAnalysis
    : public mlir::dataflow::DenseForwardDataFlowAnalysis<LatticePoint> {
public:
  using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;

  mlir::LogicalResult visitOperation(mlir::Operation *op,
                                     const LatticePoint &before,
                                     LatticePoint *after) override;

  /// At an entry point, the last modifications of all memory resources are
  /// yet to be determined
  void setToEntryState(LatticePoint *lattice) override;

protected:
  /// Visit control flow operations and decide whether to call visitOperation
  /// to apply the transfer function
  mlir::LogicalResult processOperation(mlir::Operation *op) override;
};

/// Drives analysis to find candidate fir.allocmem operations which could be
/// moved to the stack. Intended to be used with mlir::Pass::getAnalysis
class StackArraysAnalysisWrapper {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper)

  // Maps fir.allocmem -> place to insert alloca
  using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>;

  StackArraysAnalysisWrapper(mlir::Operation *op) {}

  // returns nullptr if analysis failed
  const AllocMemMap *getCandidateOps(mlir::Operation *func);

private:
  llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;

  llvm::LogicalResult analyseFunction(mlir::Operation *func);
};

/// Converts a fir.allocmem to a fir.alloca
class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
public:
  explicit AllocMemConversion(
      mlir::MLIRContext *ctx,
      const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
      : OpRewritePattern(ctx), candidateOps{candidateOps} {}

  llvm::LogicalResult
  matchAndRewrite(fir::AllocMemOp allocmem,
                  mlir::PatternRewriter &rewriter) const override;

  /// Determine where to insert the alloca operation. The returned value should
  /// be checked to see if it is inside a loop
  static InsertionPoint
  findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc,
                           const llvm::SmallVector<mlir::Operation *> &freeOps);

private:
  /// Handle to the DFA (already run)
  const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;

  /// If we failed to find an insertion point not inside a loop, see if it would
  /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
  static InsertionPoint findAllocaLoopInsertionPoint(
      fir::AllocMemOp &oldAlloc,
      const llvm::SmallVector<mlir::Operation *> &freeOps);

  /// Returns the alloca if it was successfully inserted, otherwise {}
  std::optional<fir::AllocaOp>
  insertAlloca(fir::AllocMemOp &oldAlloc,
               mlir::PatternRewriter &rewriter) const;

  /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem
  void insertStackSaveRestore(fir::AllocMemOp &oldAlloc,
                              mlir::PatternRewriter &rewriter) const;
};

class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> {
public:
  StackArraysPass() = default;
  StackArraysPass(const StackArraysPass &pass);

  llvm::StringRef getDescription() const override;

  void runOnOperation() override;

private:
  Statistic runCount{this, "stackArraysRunCount",
                     "Number of heap allocations moved to the stack"};
};

} // namespace

static void print(llvm::raw_ostream &os, AllocationState state) {
  switch (state) {
  case AllocationState::Unknown:
    os << "Unknown";
    break;
  case AllocationState::Freed:
    os << "Freed";
    break;
  case AllocationState::Allocated:
    os << "Allocated";
    break;
  }
}

/// Join two AllocationStates for the same value coming from different CFG
/// blocks
static AllocationState join(AllocationState lhs, AllocationState rhs) {
  //           | Allocated | Freed     | Unknown
  // ========= | ========= | ========= | =========
  // Allocated | Allocated | Unknown   | Unknown
  // Freed     | Unknown   | Freed     | Unknown
  // Unknown   | Unknown   | Unknown   | Unknown
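  //
  // For example (an illustrative CFG sketch, not taken from a real test), if
  // one path through a conditional frees a value and the other does not:
  //
  //   %mem = fir.allocmem !fir.array<10xi32>
  //   cf.cond_br %cond, ^freeit, ^join
  // ^freeit:
  //   fir.freemem %mem : !fir.heap<!fir.array<10xi32>>
  //   cf.br ^join
  // ^join:
  //
  // then at ^join this function computes join(Freed, Allocated) == Unknown,
  // so %mem is not a candidate for moving to the stack.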
  if (lhs == rhs)
    return lhs;
  return AllocationState::Unknown;
}

mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) {
  const auto &rhs = static_cast<const LatticePoint &>(lattice);
  mlir::ChangeResult changed = mlir::ChangeResult::NoChange;

  // add everything from rhs to map, handling cases where values are in both
  for (const auto &[value, rhsState] : rhs.stateMap) {
    auto it = stateMap.find(value);
    if (it != stateMap.end()) {
      // value is present in both maps
      AllocationState myState = it->second;
      AllocationState newState = ::join(myState, rhsState);
      if (newState != myState) {
        changed = mlir::ChangeResult::Change;
        it->getSecond() = newState;
      }
    } else {
      // value not present in current map: add it
      stateMap.insert({value, rhsState});
      changed = mlir::ChangeResult::Change;
    }
  }

  return changed;
}

void LatticePoint::print(llvm::raw_ostream &os) const {
  for (const auto &[value, state] : stateMap) {
    os << "\n * " << value << ": ";
    ::print(os, state);
  }
}

mlir::ChangeResult LatticePoint::reset() {
  if (stateMap.empty())
    return mlir::ChangeResult::NoChange;
  stateMap.clear();
  return mlir::ChangeResult::Change;
}

mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) {
  if (stateMap.count(value)) {
    // already in map
    AllocationState &oldState = stateMap[value];
    if (oldState != state) {
      stateMap[value] = state;
      return mlir::ChangeResult::Change;
    }
    return mlir::ChangeResult::NoChange;
  }
  stateMap.insert({value, state});
  return mlir::ChangeResult::Change;
}

/// Get values which were allocated in this function and always freed before
/// the function returns
void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const {
  for (auto &[value, state] : stateMap) {
    if (state == AllocationState::Freed)
      out.insert(value);
  }
}

std::optional<AllocationState> LatticePoint::get(mlir::Value val) const {
  auto it = stateMap.find(val);
  if (it == stateMap.end())
    return {};
  return it->second;
}

static mlir::Value lookThroughDeclaresAndConverts(mlir::Value value) {
  while (mlir::Operation *op = value.getDefiningOp()) {
    if (auto declareOp = llvm::dyn_cast<fir::DeclareOp>(op))
      value = declareOp.getMemref();
    else if (auto convertOp = llvm::dyn_cast<fir::ConvertOp>(op))
      value = convertOp->getOperand(0);
    else
      return value;
  }
  return value;
}
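
// For example (an illustrative def-use chain; FIR syntax abbreviated), given
//
//   %mem = fir.allocmem !fir.array<10xi32>
//   %decl = fir.declare %mem {uniq_name = "tmp"} : (...) -> ...
//   %conv = fir.convert %decl
//     : (!fir.heap<!fir.array<10xi32>>) -> !fir.heap<!fir.array<?xi32>>
//   fir.freemem %conv : !fir.heap<!fir.array<?xi32>>
//
// lookThroughDeclaresAndConverts(%conv) walks back through the fir.convert
// and the fir.declare and returns the fir.allocmem result %mem.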

mlir::LogicalResult AllocationAnalysis::visitOperation(
    mlir::Operation *op, const LatticePoint &before, LatticePoint *after) {
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
                          << "\n");
  LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n");

  // propagate before -> after
  mlir::ChangeResult changed = after->join(before);

  if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) {
    assert(op->getNumResults() == 1 && "fir.allocmem has one result");
    auto attr = op->getAttrOfType<fir::MustBeHeapAttr>(
        fir::MustBeHeapAttr::getAttrName());
    if (attr && attr.getValue()) {
      LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
      // skip allocation marked not to be moved
      return mlir::success();
    }

    auto retTy = allocmem.getAllocatedType();
    if (!mlir::isa<fir::SequenceType>(retTy)) {
      LLVM_DEBUG(llvm::dbgs()
                 << "--Allocation is not for an array: skipping\n");
      return mlir::success();
    }

    mlir::Value result = op->getResult(0);
    changed |= after->set(result, AllocationState::Allocated);
  } else if (mlir::isa<fir::FreeMemOp>(op)) {
    assert(op->getNumOperands() == 1 && "fir.freemem has one operand");
    mlir::Value operand = op->getOperand(0);

    // Note: StackArrays is scheduled in the pass pipeline after lowering hlfir
    // to fir. Therefore, we only need to handle `fir::DeclareOp`s. Also look
    // past converts in case the pointer was changed between different pointer
    // types.
    operand = lookThroughDeclaresAndConverts(operand);

    std::optional<AllocationState> operandState = before.get(operand);
    if (operandState && *operandState == AllocationState::Allocated) {
      // don't tag things not allocated in this function as freed, so that we
      // don't think they are candidates for moving to the stack
      changed |= after->set(operand, AllocationState::Freed);
    }
  } else if (mlir::isa<fir::ResultOp>(op)) {
    mlir::Operation *parent = op->getParentOp();
    LatticePoint *parentLattice = getLattice(getProgramPointAfter(parent));
    assert(parentLattice);
    mlir::ChangeResult parentChanged = parentLattice->join(*after);
    propagateIfChanged(parentLattice, parentChanged);
  }

  // we pass lattices straight through fir.call because called functions should
  // not deallocate flang-generated array temporaries

  LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n");
  propagateIfChanged(after, changed);
  return mlir::success();
}

void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
  propagateIfChanged(lattice, lattice->reset());
}

/// Mostly a copy of AbstractDenseForwardDataFlowAnalysis::processOperation -
/// the difference being that call operations are passed through to the
/// transfer function
mlir::LogicalResult AllocationAnalysis::processOperation(mlir::Operation *op) {
  mlir::ProgramPoint *point = getProgramPointAfter(op);
  // If the containing block is not executable, bail out.
  if (op->getBlock() != nullptr &&
      !getOrCreateFor<mlir::dataflow::Executable>(
           point, getProgramPointBefore(op->getBlock()))
           ->isLive())
    return mlir::success();

  // Get the dense lattice to update
  mlir::dataflow::AbstractDenseLattice *after = getLattice(point);

  // If this op implements region control-flow, then control-flow dictates its
  // transfer function.
  if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
    visitRegionBranchOperation(point, branch, after);
    return mlir::success();
  }

  // pass call operations through to the transfer function

  // Get the dense state before the execution of the op.
  const mlir::dataflow::AbstractDenseLattice *before =
      getLatticeFor(point, getProgramPointBefore(op));

  /// Invoke the operation transfer function
  return visitOperationImpl(op, *before, after);
}

llvm::LogicalResult
StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
  assert(mlir::isa<mlir::func::FuncOp>(func));
  size_t nAllocs = 0;
  func->walk([&nAllocs](fir::AllocMemOp) { nAllocs++; });
  // don't bother with the analysis if there are no heap allocations
  if (nAllocs == 0)
    return mlir::success();
  if ((maxAllocsPerFunc != 0) && (nAllocs > maxAllocsPerFunc)) {
    LLVM_DEBUG(llvm::dbgs() << "Skipping stack arrays for function with "
                            << nAllocs << " heap allocations");
    return mlir::success();
  }

  mlir::DataFlowSolver solver;
  // constant propagation is required for dead code analysis; dead code
  // analysis is required to mark blocks live (required for MLIR dense DFA)
  solver.load<mlir::dataflow::SparseConstantPropagation>();
  solver.load<mlir::dataflow::DeadCodeAnalysis>();

  auto [it, inserted] = funcMaps.try_emplace(func);
  AllocMemMap &candidateOps = it->second;

  solver.load<AllocationAnalysis>();
  if (failed(solver.initializeAndRun(func))) {
    llvm::errs() << "DataFlowSolver failed!";
    return mlir::failure();
  }

  LatticePoint point{solver.getProgramPointAfter(func)};
  auto joinOperationLattice = [&](mlir::Operation *op) {
    const LatticePoint *lattice =
        solver.lookupState<LatticePoint>(solver.getProgramPointAfter(op));
    // there will be no lattice for an unreachable block
    if (lattice)
      (void)point.join(*lattice);
  };

  func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); });
  func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); });
  func->walk(
      [&](mlir::omp::TerminatorOp child) { joinOperationLattice(child); });
  func->walk([&](mlir::omp::YieldOp child) { joinOperationLattice(child); });

  llvm::DenseSet<mlir::Value> freedValues;
  point.appendFreedValues(freedValues);

  // Find all fir.freemem operations corresponding to fir.allocmem
  // in freedValues. It is best to find the association going back
  // from fir.freemem to fir.allocmem through the def-use chains,
  // so that we can use lookThroughDeclaresAndConverts the same way
  // the AllocationAnalysis handles them.
  llvm::DenseMap<mlir::Operation *, llvm::SmallVector<mlir::Operation *>>
      allocToFreeMemMap;
  func->walk([&](fir::FreeMemOp freeOp) {
    mlir::Value memref = lookThroughDeclaresAndConverts(freeOp.getHeapref());
    if (!freedValues.count(memref))
      return;

    auto allocMem = memref.getDefiningOp<fir::AllocMemOp>();
    allocToFreeMemMap[allocMem].push_back(freeOp);
  });

  // We only replace allocations which are definitely freed on all routes
  // through the function because otherwise the allocation may have an intended
  // lifetime longer than the current stack frame (e.g. a heap allocation which
  // is then freed by another function).
  for (mlir::Value freedValue : freedValues) {
    fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>();
    InsertionPoint insertionPoint =
        AllocMemConversion::findAllocaInsertionPoint(
            allocmem, allocToFreeMemMap[allocmem]);
    if (insertionPoint)
      candidateOps.insert({allocmem, insertionPoint});
  }

  LLVM_DEBUG(for (auto [allocMemOp, _]
                  : candidateOps) {
    llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
  });
  return mlir::success();
}

const StackArraysAnalysisWrapper::AllocMemMap *
StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
  if (!funcMaps.contains(func))
    if (mlir::failed(analyseFunction(func)))
      return nullptr;
  return &funcMaps[func];
}

/// Restore the old allocation type expected by existing code
static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter,
                                         const mlir::Location &loc,
                                         mlir::Value heap, mlir::Value stack) {
  mlir::Type heapTy = heap.getType();
  mlir::Type stackTy = stack.getType();

  if (heapTy == stackTy)
    return stack;

  fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy);
  LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy =
      mlir::cast<fir::ReferenceType>(stackTy);
  assert(firHeapTy.getElementType() == firRefTy.getElementType() &&
         "Allocations must have the same type");

  auto insertionPoint = rewriter.saveInsertionPoint();
  rewriter.setInsertionPointAfter(stack.getDefiningOp());
  mlir::Value conv =
      rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult();
  rewriter.restoreInsertionPoint(insertionPoint);
  return conv;
}
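
// For example (a sketch of the resulting IR): if the old fir.allocmem
// produced !fir.heap<!fir.array<10xi32>> and the replacement fir.alloca
// produces !fir.ref<!fir.array<10xi32>>, a conversion is inserted right after
// the alloca so existing users still see the heap type:
//
//   %stack = fir.alloca !fir.array<10xi32>
//   %heap = fir.convert %stack
//     : (!fir.ref<!fir.array<10xi32>>) -> !fir.heap<!fir.array<10xi32>>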

llvm::LogicalResult
AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
                                    mlir::PatternRewriter &rewriter) const {
  auto oldInsertionPt = rewriter.saveInsertionPoint();
  // add alloca operation
  std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter);
  rewriter.restoreInsertionPoint(oldInsertionPt);
  if (!alloca)
    return mlir::failure();

  // remove freemem operations
  llvm::SmallVector<mlir::Operation *> erases;
  mlir::Operation *parent = allocmem->getParentOp();
  // TODO: this shouldn't need to be re-calculated for every allocmem
  parent->walk([&](fir::FreeMemOp freeOp) {
    if (lookThroughDeclaresAndConverts(freeOp->getOperand(0)) == allocmem)
      erases.push_back(freeOp);
  });

  // now we are done iterating the users, it is safe to mutate them
  for (mlir::Operation *erase : erases)
    rewriter.eraseOp(erase);

  // replace references to heap allocation with references to stack allocation
  mlir::Value newValue = convertAllocationType(
      rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult());
  rewriter.replaceAllUsesWith(allocmem.getResult(), newValue);

  // remove allocmem operation
  rewriter.eraseOp(allocmem.getOperation());

  return mlir::success();
}

static bool isInLoop(mlir::Block *block) {
  return mlir::LoopLikeOpInterface::blockIsInLoop(block);
}

static bool isInLoop(mlir::Operation *op) {
  return isInLoop(op->getBlock()) ||
         op->getParentOfType<mlir::LoopLikeOpInterface>();
}

InsertionPoint AllocMemConversion::findAllocaInsertionPoint(
    fir::AllocMemOp &oldAlloc,
    const llvm::SmallVector<mlir::Operation *> &freeOps) {
  // Ideally the alloca should be inserted at the end of the function entry
  // block so that we do not allocate stack space in a loop. However,
  // the operands to the alloca may not be available that early, so insert it
  // after the last operand becomes available.
  // If the old allocmem op was in an openmp region then it should not be moved
  // outside of that region.
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: "
                          << oldAlloc << "\n");

  // check that an Operation or Block we are about to return is not in a loop
  auto checkReturn = [&](auto *point) -> InsertionPoint {
    if (isInLoop(point)) {
      mlir::Operation *oldAllocOp = oldAlloc.getOperation();
      if (isInLoop(oldAllocOp)) {
        // where we want to put it is in a loop, and even the old location is in
        // a loop. Give up.
        return findAllocaLoopInsertionPoint(oldAlloc, freeOps);
      }
      return {oldAllocOp};
    }
    return {point};
  };

  auto oldOmpRegion =
      oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();

  // Find when the last operand value becomes available
  mlir::Block *operandsBlock = nullptr;
  mlir::Operation *lastOperand = nullptr;
  for (mlir::Value operand : oldAlloc.getOperands()) {
    LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n");
    mlir::Operation *op = operand.getDefiningOp();
    if (!op)
      return checkReturn(oldAlloc.getOperation());
    if (!operandsBlock)
      operandsBlock = op->getBlock();
    else if (operandsBlock != op->getBlock()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "----operand declared in a different block!\n");
      // Operation::isBeforeInBlock requires the operations to be in the same
      // block. The best we can do is the location of the allocmem.
      return checkReturn(oldAlloc.getOperation());
    }
    if (!lastOperand || lastOperand->isBeforeInBlock(op))
      lastOperand = op;
  }

  if (lastOperand) {
    // there were value operands to the allocmem so insert after the last one
    LLVM_DEBUG(llvm::dbgs()
               << "--Placing after last operand: " << *lastOperand << "\n");
    // check we aren't moving out of an omp region
    auto lastOpOmpRegion =
        lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
    if (lastOpOmpRegion == oldOmpRegion)
      return checkReturn(lastOperand);
    // Presumably this happened because the operands became ready before the
    // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should
    // imply that oldOmpRegion comes after lastOpOmpRegion.
    return checkReturn(oldOmpRegion.getAllocaBlock());
  }

  // There were no value operands to the allocmem so we are safe to insert it
  // as early as we want

  // handle openmp case
  if (oldOmpRegion)
    return checkReturn(oldOmpRegion.getAllocaBlock());

  // fall back to the function entry block
  mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>();
  assert(func && "This analysis is run on func.func");
  mlir::Block &entryBlock = func.getBlocks().front();
  LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n");
  return checkReturn(&entryBlock);
}
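
// For example (an illustrative placement; not taken from a real test): the
// alloca replacing the allocmem below cannot be hoisted to the entry block
// because its dynamic extent %n is not available there, so it is placed just
// after the definition of %n:
//
//   %n = fir.load %nref : !fir.ref<index>
//   // <- the replacement fir.alloca is inserted here
//   %mem = fir.allocmem !fir.array<?xi32>, %n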

InsertionPoint AllocMemConversion::findAllocaLoopInsertionPoint(
    fir::AllocMemOp &oldAlloc,
    const llvm::SmallVector<mlir::Operation *> &freeOps) {
  mlir::Operation *oldAllocOp = oldAlloc;
  // This is only called as a last resort. We should try to insert at the
  // location of the old allocation, which is inside of a loop, using
  // llvm.stacksave/llvm.stackrestore

  assert(freeOps.size() && "DFA should only return freed memory");

  // Don't attempt to reason about a stacksave/stackrestore between different
  // blocks
  for (mlir::Operation *free : freeOps)
    if (free->getBlock() != oldAllocOp->getBlock())
      return {nullptr};

  // Check that there aren't any other stack allocations in between the
  // stack save and stack restore
  // note: for flang generated temporaries there should only be one free op
  for (mlir::Operation *free : freeOps) {
    for (mlir::Operation *op = oldAlloc; op && op != free;
         op = op->getNextNode()) {
      if (mlir::isa<fir::AllocaOp>(op))
        return {nullptr};
    }
  }

  return InsertionPoint{oldAllocOp, /*saveRestoreStack=*/true};
}
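
// For example (a rough sketch; the exact stack save/restore operations depend
// on how FirOpBuilder emits them): an allocation freed in the same loop block
// becomes a stack allocation bracketed by a stack save/restore pair so that
// the stack space is reclaimed on every iteration:
//
//   ^loop:
//     %sp = <stacksave>
//     %mem = fir.alloca !fir.array<?xi32>, %n
//     // ... uses of %mem ...
//     <stackrestore> %sp
//     cf.br ^loop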

std::optional<fir::AllocaOp>
AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc,
                                 mlir::PatternRewriter &rewriter) const {
  auto it = candidateOps.find(oldAlloc.getOperation());
  if (it == candidateOps.end())
    return {};
  InsertionPoint insertionPoint = it->second;
  if (!insertionPoint)
    return {};

  if (insertionPoint.shouldSaveRestoreStack())
    insertStackSaveRestore(oldAlloc, rewriter);

  mlir::Location loc = oldAlloc.getLoc();
  mlir::Type varTy = oldAlloc.getInType();
  if (mlir::Operation *op = insertionPoint.tryGetOperation()) {
    rewriter.setInsertionPointAfter(op);
  } else {
    mlir::Block *block = insertionPoint.tryGetBlock();
    assert(block && "There must be a valid insertion point");
    rewriter.setInsertionPointToStart(block);
  }

  auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef {
    if (opt)
      return *opt;
    return {};
  };

  llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName());
  llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName());
  return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName,
                                        oldAlloc.getTypeparams(),
                                        oldAlloc.getShape());
}

void AllocMemConversion::insertStackSaveRestore(
    fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const {
  auto oldPoint = rewriter.saveInsertionPoint();
  auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>();
  fir::FirOpBuilder builder{rewriter, mod};

  builder.setInsertionPoint(oldAlloc);
  mlir::Value sp = builder.genStackSave(oldAlloc.getLoc());

  auto createStackRestoreCall = [&](mlir::Operation *user) {
    builder.setInsertionPoint(user);
    builder.genStackRestore(user->getLoc(), sp);
  };

  for (mlir::Operation *user : oldAlloc->getUsers()) {
    if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
      for (mlir::Operation *user : declareOp->getUsers()) {
        if (mlir::isa<fir::FreeMemOp>(user))
          createStackRestoreCall(user);
      }
    }

    if (mlir::isa<fir::FreeMemOp>(user)) {
      createStackRestoreCall(user);
    }
  }

  rewriter.restoreInsertionPoint(oldPoint);
}

StackArraysPass::StackArraysPass(const StackArraysPass &pass)
    : fir::impl::StackArraysBase<StackArraysPass>(pass) {}

llvm::StringRef StackArraysPass::getDescription() const {
  return "Move heap allocated array temporaries to the stack";
}

void StackArraysPass::runOnOperation() {
  mlir::func::FuncOp func = getOperation();

  auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
  const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
      analysis.getCandidateOps(func);
  if (!candidateOps) {
    signalPassFailure();
    return;
  }

  if (candidateOps->empty())
    return;
  runCount += candidateOps->size();

  llvm::SmallVector<mlir::Operation *> opsToConvert;
  opsToConvert.reserve(candidateOps->size());
  for (auto [op, _] : *candidateOps)
    opsToConvert.push_back(op);

  mlir::MLIRContext &context = getContext();
  mlir::RewritePatternSet patterns(&context);
  mlir::GreedyRewriteConfig config;
  // prevent the pattern driver from merging blocks
  config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;

  patterns.insert<AllocMemConversion>(&context, *candidateOps);
  if (mlir::failed(mlir::applyOpPatternsGreedily(
          opsToConvert, std::move(patterns), config))) {
    mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
    signalPassFailure();
  }
}