xref: /llvm-project/flang/lib/Optimizer/Transforms/StackArrays.cpp (revision fac349a169976f822fb27f03e623fa0d28aec1f3)
1 //===- StackArrays.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
11 #include "flang/Optimizer/Dialect/FIRAttr.h"
12 #include "flang/Optimizer/Dialect/FIRDialect.h"
13 #include "flang/Optimizer/Dialect/FIROps.h"
14 #include "flang/Optimizer/Dialect/FIRType.h"
15 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
16 #include "flang/Optimizer/Transforms/Passes.h"
17 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
18 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
19 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
20 #include "mlir/Analysis/DataFlowFramework.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"
22 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/Diagnostics.h"
25 #include "mlir/IR/Value.h"
26 #include "mlir/Interfaces/LoopLikeInterface.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Support/LogicalResult.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30 #include "mlir/Transforms/Passes.h"
31 #include "llvm/ADT/DenseMap.h"
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/PointerUnion.h"
34 #include "llvm/Support/Casting.h"
35 #include "llvm/Support/raw_ostream.h"
36 #include <optional>
37 
namespace fir {
// Pulls in the tablegen-generated pass base class (fir::impl::StackArraysBase)
// declared in Passes.td.
#define GEN_PASS_DEF_STACKARRAYS
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir
42 
#define DEBUG_TYPE "stack-arrays"

/// Command-line cap on how many fir.allocmem operations are analysed in a
/// single function; the dataflow analysis can be slow on very large functions.
/// A value of 0 disables the limit.
static llvm::cl::opt<std::size_t> maxAllocsPerFunc(
    "stack-arrays-max-allocs",
    llvm::cl::desc("The maximum number of heap allocations to consider in one "
                   "function before skipping (to save compilation time). Set "
                   "to 0 for no limit."),
    llvm::cl::init(1000), llvm::cl::Hidden);
51 
52 namespace {
53 
/// The state of an SSA value at each program point. Joining two different
/// states always yields Unknown (see ::join below).
enum class AllocationState {
  /// This means that the allocation state of a variable cannot be determined
  /// at this program point, e.g. because one route through a conditional freed
  /// the variable and the other route didn't.
  /// This asserts a known-unknown: different from the unknown-unknown of having
  /// no AllocationState stored for a particular SSA value
  Unknown,
  /// Means this SSA value was allocated on the heap in this function and has
  /// now been freed
  Freed,
  /// Means this SSA value was allocated on the heap in this function and is a
  /// candidate for moving to the stack
  Allocated,
};
69 
/// Stores where an alloca should be inserted. If the PointerUnion is an
/// Operation the alloca should be inserted /after/ the operation. If it is a
/// block, the alloca can be placed anywhere in that block.
class InsertionPoint {
  // Either the operation to insert after, or the block to insert into
  llvm::PointerUnion<mlir::Operation *, mlir::Block *> location;
  // Whether the alloca must be wrapped in llvm.stacksave/llvm.stackrestore
  // (used when it has to be placed inside a loop)
  bool saveRestoreStack;

  /// Get contained pointer type or nullptr
  template <class T>
  T *tryGetPtr() const {
    if (location.is<T *>())
      return location.get<T *>();
    return nullptr;
  }

public:
  template <class T>
  InsertionPoint(T *ptr, bool saveRestoreStack = false)
      : location(ptr), saveRestoreStack{saveRestoreStack} {}
  InsertionPoint(std::nullptr_t null)
      : location(null), saveRestoreStack{false} {}

  /// Get contained operation, or nullptr
  mlir::Operation *tryGetOperation() const {
    return tryGetPtr<mlir::Operation>();
  }

  /// Get contained block, or nullptr
  mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); }

  /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave
  /// intrinsic should be added before the alloca, and an llvm.stackrestore
  /// intrinsic should be added where the freemem is
  bool shouldSaveRestoreStack() const { return saveRestoreStack; }

  /// True iff a location (operation or block) was set
  operator bool() const { return tryGetOperation() || tryGetBlock(); }

  bool operator==(const InsertionPoint &rhs) const {
    return (location == rhs.location) &&
           (saveRestoreStack == rhs.saveRestoreStack);
  }

  bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); }
};
114 
/// Maps SSA values to their AllocationState at a particular program point.
/// Also caches the insertion points for the new alloca operations
class LatticePoint : public mlir::dataflow::AbstractDenseLattice {
  // Maps all values we are interested in to states
  llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap;

public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint)
  using AbstractDenseLattice::AbstractDenseLattice;

  bool operator==(const LatticePoint &rhs) const {
    return stateMap == rhs.stateMap;
  }

  /// Join the lattice across control-flow edges
  mlir::ChangeResult join(const AbstractDenseLattice &lattice) override;

  void print(llvm::raw_ostream &os) const override;

  /// Clear all modifications
  mlir::ChangeResult reset();

  /// Set the state of an SSA value
  mlir::ChangeResult set(mlir::Value value, AllocationState state);

  /// Get fir.allocmem ops which were allocated in this function and always
  /// freed before the function returns, plus where to insert replacement
  /// fir.alloca ops
  void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const;

  /// Look up the tracked state of \p val, or std::nullopt if untracked
  std::optional<AllocationState> get(mlir::Value val) const;
};
147 
/// Dense forward dataflow analysis tracking the AllocationState of each
/// fir.allocmem result through the function.
class AllocationAnalysis
    : public mlir::dataflow::DenseForwardDataFlowAnalysis<LatticePoint> {
public:
  using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;

  void visitOperation(mlir::Operation *op, const LatticePoint &before,
                      LatticePoint *after) override;

  /// At an entry point, the last modifications of all memory resources are
  /// yet to be determined
  void setToEntryState(LatticePoint *lattice) override;

protected:
  /// Visit control flow operations and decide whether to call visitOperation
  /// to apply the transfer function
  void processOperation(mlir::Operation *op) override;
};
165 
/// Drives analysis to find candidate fir.allocmem operations which could be
/// moved to the stack. Intended to be used with mlir::Pass::getAnalysis
class StackArraysAnalysisWrapper {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper)

  // Maps fir.allocmem -> place to insert alloca
  using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>;

  // The operation argument is unused; the constructor signature is required
  // by the mlir::Pass::getAnalysis machinery
  StackArraysAnalysisWrapper(mlir::Operation *op) {}

  // returns nullptr if analysis failed
  const AllocMemMap *getCandidateOps(mlir::Operation *func);

private:
  // Cached per-function results, keyed by the func.func operation
  llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;

  /// Run the dataflow analysis on \p func and populate funcMaps with the
  /// candidate allocations found
  mlir::LogicalResult analyseFunction(mlir::Operation *func);
};
185 
/// Converts a fir.allocmem to a fir.alloca
class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
public:
  explicit AllocMemConversion(
      mlir::MLIRContext *ctx,
      const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
      : OpRewritePattern(ctx), candidateOps{candidateOps} {}

  mlir::LogicalResult
  matchAndRewrite(fir::AllocMemOp allocmem,
                  mlir::PatternRewriter &rewriter) const override;

  /// Determine where to insert the alloca operation. The returned value should
  /// be checked to see if it is inside a loop
  static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc);

private:
  /// Handle to the DFA (already run). Not owned; must outlive this pattern.
  const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;

  /// If we failed to find an insertion point not inside a loop, see if it would
  /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
  static InsertionPoint findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc);

  /// Returns the alloca if it was successfully inserted, otherwise {}
  std::optional<fir::AllocaOp>
  insertAlloca(fir::AllocMemOp &oldAlloc,
               mlir::PatternRewriter &rewriter) const;

  /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem
  void insertStackSaveRestore(fir::AllocMemOp &oldAlloc,
                              mlir::PatternRewriter &rewriter) const;
};
219 
/// Module pass: runs the stack-arrays transformation over every func.func in
/// the module.
class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> {
public:
  StackArraysPass() = default;
  StackArraysPass(const StackArraysPass &pass);

  llvm::StringRef getDescription() const override;

  void runOnOperation() override;
  void runOnFunc(mlir::Operation *func);

private:
  // Counts converted allocations, reported via pass statistics
  Statistic runCount{this, "stackArraysRunCount",
                     "Number of heap allocations moved to the stack"};
};
234 
235 } // namespace
236 
237 static void print(llvm::raw_ostream &os, AllocationState state) {
238   switch (state) {
239   case AllocationState::Unknown:
240     os << "Unknown";
241     break;
242   case AllocationState::Freed:
243     os << "Freed";
244     break;
245   case AllocationState::Allocated:
246     os << "Allocated";
247     break;
248   }
249 }
250 
251 /// Join two AllocationStates for the same value coming from different CFG
252 /// blocks
253 static AllocationState join(AllocationState lhs, AllocationState rhs) {
254   //           | Allocated | Freed     | Unknown
255   // ========= | ========= | ========= | =========
256   // Allocated | Allocated | Unknown   | Unknown
257   // Freed     | Unknown   | Freed     | Unknown
258   // Unknown   | Unknown   | Unknown   | Unknown
259   if (lhs == rhs)
260     return lhs;
261   return AllocationState::Unknown;
262 }
263 
264 mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) {
265   const auto &rhs = static_cast<const LatticePoint &>(lattice);
266   mlir::ChangeResult changed = mlir::ChangeResult::NoChange;
267 
268   // add everything from rhs to map, handling cases where values are in both
269   for (const auto &[value, rhsState] : rhs.stateMap) {
270     auto it = stateMap.find(value);
271     if (it != stateMap.end()) {
272       // value is present in both maps
273       AllocationState myState = it->second;
274       AllocationState newState = ::join(myState, rhsState);
275       if (newState != myState) {
276         changed = mlir::ChangeResult::Change;
277         it->getSecond() = newState;
278       }
279     } else {
280       // value not present in current map: add it
281       stateMap.insert({value, rhsState});
282       changed = mlir::ChangeResult::Change;
283     }
284   }
285 
286   return changed;
287 }
288 
/// Debug-print every tracked value and its state. Note: entries are emitted
/// back-to-back with no separator between consecutive pairs.
void LatticePoint::print(llvm::raw_ostream &os) const {
  for (const auto &[value, state] : stateMap) {
    os << value << ": ";
    ::print(os, state);
  }
}
295 
296 mlir::ChangeResult LatticePoint::reset() {
297   if (stateMap.empty())
298     return mlir::ChangeResult::NoChange;
299   stateMap.clear();
300   return mlir::ChangeResult::Change;
301 }
302 
303 mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) {
304   if (stateMap.count(value)) {
305     // already in map
306     AllocationState &oldState = stateMap[value];
307     if (oldState != state) {
308       stateMap[value] = state;
309       return mlir::ChangeResult::Change;
310     }
311     return mlir::ChangeResult::NoChange;
312   }
313   stateMap.insert({value, state});
314   return mlir::ChangeResult::Change;
315 }
316 
317 /// Get values which were allocated in this function and always freed before
318 /// the function returns
319 void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const {
320   for (auto &[value, state] : stateMap) {
321     if (state == AllocationState::Freed)
322       out.insert(value);
323   }
324 }
325 
326 std::optional<AllocationState> LatticePoint::get(mlir::Value val) const {
327   auto it = stateMap.find(val);
328   if (it == stateMap.end())
329     return {};
330   return it->second;
331 }
332 
/// Transfer function. Propagates `before` into `after`, then:
///  - fir.allocmem of an array type (not marked fir.must_be_heap) marks its
///    result Allocated
///  - fir.freemem of a value this function marked Allocated marks it Freed
///  - fir.result additionally joins its lattice into the parent op's lattice
///    so that states escape nested regions
/// All other operations (including fir.call) pass state through unchanged.
void AllocationAnalysis::visitOperation(mlir::Operation *op,
                                        const LatticePoint &before,
                                        LatticePoint *after) {
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
                          << "\n");
  LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n");

  // propagate before -> after
  mlir::ChangeResult changed = after->join(before);

  if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) {
    assert(op->getNumResults() == 1 && "fir.allocmem has one result");
    auto attr = op->getAttrOfType<fir::MustBeHeapAttr>(
        fir::MustBeHeapAttr::getAttrName());
    if (attr && attr.getValue()) {
      LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
      // skip allocation marked not to be moved
      return;
    }

    auto retTy = allocmem.getAllocatedType();
    if (!mlir::isa<fir::SequenceType>(retTy)) {
      LLVM_DEBUG(llvm::dbgs()
                 << "--Allocation is not for an array: skipping\n");
      return;
    }

    mlir::Value result = op->getResult(0);
    changed |= after->set(result, AllocationState::Allocated);
  } else if (mlir::isa<fir::FreeMemOp>(op)) {
    assert(op->getNumOperands() == 1 && "fir.freemem has one operand");
    mlir::Value operand = op->getOperand(0);
    std::optional<AllocationState> operandState = before.get(operand);
    if (operandState && *operandState == AllocationState::Allocated) {
      // don't tag things not allocated in this function as freed, so that we
      // don't think they are candidates for moving to the stack
      changed |= after->set(operand, AllocationState::Freed);
    }
  } else if (mlir::isa<fir::ResultOp>(op)) {
    // surface states gathered inside the region to the parent operation
    mlir::Operation *parent = op->getParentOp();
    LatticePoint *parentLattice = getLattice(parent);
    assert(parentLattice);
    mlir::ChangeResult parentChanged = parentLattice->join(*after);
    propagateIfChanged(parentLattice, parentChanged);
  }

  // we pass lattices straight through fir.call because called functions should
  // not deallocate flang-generated array temporaries

  LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n");
  propagateIfChanged(after, changed);
}
385 
386 void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
387   propagateIfChanged(lattice, lattice->reset());
388 }
389 
/// Mostly a copy of AbstractDenseLattice::processOperation - the difference
/// being that call operations are passed through to the transfer function
/// (the base implementation gives calls special handling, but here callees
/// are assumed not to free flang-generated temporaries).
void AllocationAnalysis::processOperation(mlir::Operation *op) {
  // If the containing block is not executable, bail out.
  if (!getOrCreateFor<mlir::dataflow::Executable>(op, op->getBlock())->isLive())
    return;

  // Get the dense lattice to update
  mlir::dataflow::AbstractDenseLattice *after = getLattice(op);

  // If this op implements region control-flow, then control-flow dictates its
  // transfer function.
  if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op))
    return visitRegionBranchOperation(op, branch, after);

  // pass call operations through to the transfer function

  // Get the dense state before the execution of the op: either from the
  // previous operation, or from the start of the block.
  const mlir::dataflow::AbstractDenseLattice *before;
  if (mlir::Operation *prev = op->getPrevNode())
    before = getLatticeFor(op, prev);
  else
    before = getLatticeFor(op, op->getBlock());

  /// Invoke the operation transfer function
  visitOperationImpl(op, *before, after);
}
417 
418 mlir::LogicalResult
419 StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
420   assert(mlir::isa<mlir::func::FuncOp>(func));
421   size_t nAllocs = 0;
422   func->walk([&nAllocs](fir::AllocMemOp) { nAllocs++; });
423   // don't bother with the analysis if there are no heap allocations
424   if (nAllocs == 0)
425     return mlir::success();
426   if ((maxAllocsPerFunc != 0) && (nAllocs > maxAllocsPerFunc)) {
427     LLVM_DEBUG(llvm::dbgs() << "Skipping stack arrays for function with "
428                             << nAllocs << " heap allocations");
429     return mlir::success();
430   }
431 
432   mlir::DataFlowSolver solver;
433   // constant propagation is required for dead code analysis, dead code analysis
434   // is required to mark blocks live (required for mlir dense dfa)
435   solver.load<mlir::dataflow::SparseConstantPropagation>();
436   solver.load<mlir::dataflow::DeadCodeAnalysis>();
437 
438   auto [it, inserted] = funcMaps.try_emplace(func);
439   AllocMemMap &candidateOps = it->second;
440 
441   solver.load<AllocationAnalysis>();
442   if (failed(solver.initializeAndRun(func))) {
443     llvm::errs() << "DataFlowSolver failed!";
444     return mlir::failure();
445   }
446 
447   LatticePoint point{func};
448   auto joinOperationLattice = [&](mlir::Operation *op) {
449     const LatticePoint *lattice = solver.lookupState<LatticePoint>(op);
450     // there will be no lattice for an unreachable block
451     if (lattice)
452       (void)point.join(*lattice);
453   };
454   func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); });
455   func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); });
456   llvm::DenseSet<mlir::Value> freedValues;
457   point.appendFreedValues(freedValues);
458 
459   // We only replace allocations which are definately freed on all routes
460   // through the function because otherwise the allocation may have an intende
461   // lifetime longer than the current stack frame (e.g. a heap allocation which
462   // is then freed by another function).
463   for (mlir::Value freedValue : freedValues) {
464     fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>();
465     InsertionPoint insertionPoint =
466         AllocMemConversion::findAllocaInsertionPoint(allocmem);
467     if (insertionPoint)
468       candidateOps.insert({allocmem, insertionPoint});
469   }
470 
471   LLVM_DEBUG(for (auto [allocMemOp, _]
472                   : candidateOps) {
473     llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
474   });
475   return mlir::success();
476 }
477 
478 const StackArraysAnalysisWrapper::AllocMemMap *
479 StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
480   if (!funcMaps.contains(func))
481     if (mlir::failed(analyseFunction(func)))
482       return nullptr;
483   return &funcMaps[func];
484 }
485 
/// Restore the old allocation type expected by existing code: the alloca
/// yields a !fir.ref<T> while the allocmem's users expect !fir.heap<T>, so
/// insert a fir.convert right after the alloca when the types differ.
static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter,
                                         const mlir::Location &loc,
                                         mlir::Value heap, mlir::Value stack) {
  mlir::Type heapTy = heap.getType();
  mlir::Type stackTy = stack.getType();

  if (heapTy == stackTy)
    return stack;

  fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy);
  LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy =
      mlir::cast<fir::ReferenceType>(stackTy);
  assert(firHeapTy.getElementType() == firRefTy.getElementType() &&
         "Allocations must have the same type");

  // place the convert immediately after the alloca, then restore the
  // rewriter's previous insertion point
  auto insertionPoint = rewriter.saveInsertionPoint();
  rewriter.setInsertionPointAfter(stack.getDefiningOp());
  mlir::Value conv =
      rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult();
  rewriter.restoreInsertionPoint(insertionPoint);
  return conv;
}
509 
/// Replace a candidate fir.allocmem with a fir.alloca: insert the alloca at
/// the precomputed insertion point, delete the matching fir.freemem ops,
/// redirect all uses, and erase the allocmem.
mlir::LogicalResult
AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
                                    mlir::PatternRewriter &rewriter) const {
  auto oldInsertionPt = rewriter.saveInsertionPoint();
  // add alloca operation
  std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter);
  rewriter.restoreInsertionPoint(oldInsertionPt);
  if (!alloca)
    return mlir::failure();

  // remove freemem operations
  llvm::SmallVector<mlir::Operation *> erases;
  for (mlir::Operation *user : allocmem.getOperation()->getUsers())
    if (mlir::isa<fir::FreeMemOp>(user))
      erases.push_back(user);
  // now we are done iterating the users, it is safe to mutate them
  for (mlir::Operation *erase : erases)
    rewriter.eraseOp(erase);

  // replace references to heap allocation with references to stack allocation
  mlir::Value newValue = convertAllocationType(
      rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult());
  rewriter.replaceAllUsesWith(allocmem.getResult(), newValue);

  // remove allocmem operation
  rewriter.eraseOp(allocmem.getOperation());

  return mlir::success();
}
539 
540 static bool isInLoop(mlir::Block *block) {
541   return mlir::LoopLikeOpInterface::blockIsInLoop(block);
542 }
543 
544 static bool isInLoop(mlir::Operation *op) {
545   return isInLoop(op->getBlock()) ||
546          op->getParentOfType<mlir::LoopLikeOpInterface>();
547 }
548 
/// Determine where the replacement fir.alloca should be placed for
/// \p oldAlloc. Prefers the earliest point where the alloca's operands are
/// all available (ideally the function entry block), while never hoisting
/// out of an OpenMP region, and avoiding loops where possible.
InsertionPoint
AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) {
  // Ideally the alloca should be inserted at the end of the function entry
  // block so that we do not allocate stack space in a loop. However,
  // the operands to the alloca may not be available that early, so insert it
  // after the last operand becomes available
  // If the old allocmem op was in an openmp region then it should not be moved
  // outside of that
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: "
                          << oldAlloc << "\n");

  // check that an Operation or Block we are about to return is not in a loop
  auto checkReturn = [&](auto *point) -> InsertionPoint {
    if (isInLoop(point)) {
      mlir::Operation *oldAllocOp = oldAlloc.getOperation();
      if (isInLoop(oldAllocOp)) {
        // where we want to put it is in a loop, and even the old location is in
        // a loop. Give up.
        return findAllocaLoopInsertionPoint(oldAlloc);
      }
      // fall back to the allocmem's own (loop-free) location
      return {oldAllocOp};
    }
    return {point};
  };

  auto oldOmpRegion =
      oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();

  // Find when the last operand value becomes available
  mlir::Block *operandsBlock = nullptr;
  mlir::Operation *lastOperand = nullptr;
  for (mlir::Value operand : oldAlloc.getOperands()) {
    LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n");
    mlir::Operation *op = operand.getDefiningOp();
    if (!op)
      // operand is a block argument: give up and stay at the old location
      return checkReturn(oldAlloc.getOperation());
    if (!operandsBlock)
      operandsBlock = op->getBlock();
    else if (operandsBlock != op->getBlock()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "----operand declared in a different block!\n");
      // Operation::isBeforeInBlock requires the operations to be in the same
      // block. The best we can do is the location of the allocmem.
      return checkReturn(oldAlloc.getOperation());
    }
    if (!lastOperand || lastOperand->isBeforeInBlock(op))
      lastOperand = op;
  }

  if (lastOperand) {
    // there were value operands to the allocmem so insert after the last one
    LLVM_DEBUG(llvm::dbgs()
               << "--Placing after last operand: " << *lastOperand << "\n");
    // check we aren't moving out of an omp region
    auto lastOpOmpRegion =
        lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
    if (lastOpOmpRegion == oldOmpRegion)
      return checkReturn(lastOperand);
    // Presumably this happened because the operands became ready before the
    // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should
    // imply that oldOmpRegion comes after lastOpOmpRegion.
    return checkReturn(oldOmpRegion.getAllocaBlock());
  }

  // There were no value operands to the allocmem so we are safe to insert it
  // as early as we want

  // handle openmp case
  if (oldOmpRegion)
    return checkReturn(oldOmpRegion.getAllocaBlock());

  // fall back to the function entry block
  mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>();
  assert(func && "This analysis is run on func.func");
  mlir::Block &entryBlock = func.getBlocks().front();
  LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n");
  return checkReturn(&entryBlock);
}
627 
/// Last-resort placement when both the ideal spot and the allocmem itself are
/// inside a loop: keep the alloca at the old location but bracket it with
/// llvm.stacksave/llvm.stackrestore so the loop doesn't grow the stack.
/// Returns a null InsertionPoint when that would not be safe.
InsertionPoint
AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) {
  mlir::Operation *oldAllocOp = oldAlloc;
  // This is only called as a last resort. We should try to insert at the
  // location of the old allocation, which is inside of a loop, using
  // llvm.stacksave/llvm.stackrestore

  // find freemem ops
  llvm::SmallVector<mlir::Operation *, 1> freeOps;
  for (mlir::Operation *user : oldAllocOp->getUsers())
    if (mlir::isa<fir::FreeMemOp>(user))
      freeOps.push_back(user);
  assert(freeOps.size() && "DFA should only return freed memory");

  // Don't attempt to reason about a stacksave/stackrestore between different
  // blocks
  for (mlir::Operation *free : freeOps)
    if (free->getBlock() != oldAllocOp->getBlock())
      return {nullptr};

  // Check that there aren't any other stack allocations in between the
  // stack save and stack restore
  // note: for flang generated temporaries there should only be one free op
  for (mlir::Operation *free : freeOps) {
    for (mlir::Operation *op = oldAlloc; op && op != free;
         op = op->getNextNode()) {
      if (mlir::isa<fir::AllocaOp>(op))
        return {nullptr};
    }
  }

  return InsertionPoint{oldAllocOp, /*saveRestoreStack=*/true};
}
661 
/// Create the replacement fir.alloca for \p oldAlloc at its precomputed
/// insertion point (after an operation, or at the start of a block), first
/// emitting stacksave/stackrestore when the insertion point requires it.
/// Returns std::nullopt when oldAlloc is not a candidate. Leaves the
/// rewriter's insertion point at the new alloca; the caller restores it.
std::optional<fir::AllocaOp>
AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc,
                                 mlir::PatternRewriter &rewriter) const {
  auto it = candidateOps.find(oldAlloc.getOperation());
  if (it == candidateOps.end())
    return {};
  InsertionPoint insertionPoint = it->second;
  if (!insertionPoint)
    return {};

  if (insertionPoint.shouldSaveRestoreStack())
    insertStackSaveRestore(oldAlloc, rewriter);

  mlir::Location loc = oldAlloc.getLoc();
  mlir::Type varTy = oldAlloc.getInType();
  if (mlir::Operation *op = insertionPoint.tryGetOperation()) {
    rewriter.setInsertionPointAfter(op);
  } else {
    mlir::Block *block = insertionPoint.tryGetBlock();
    assert(block && "There must be a valid insertion point");
    rewriter.setInsertionPointToStart(block);
  }

  // convert the optional name attributes to the empty StringRef expected by
  // the fir.alloca builder
  auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef {
    if (opt)
      return *opt;
    return {};
  };

  llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName());
  llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName());
  return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName,
                                        oldAlloc.getTypeparams(),
                                        oldAlloc.getShape());
}
697 
/// Emit a call to llvm.stacksave immediately before \p oldAlloc and a call to
/// llvm.stackrestore (with the saved stack pointer) immediately before each
/// fir.freemem of its result. The rewriter's insertion point is preserved.
void AllocMemConversion::insertStackSaveRestore(
    fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const {
  auto oldPoint = rewriter.saveInsertionPoint();
  auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>();
  fir::FirOpBuilder builder{rewriter, mod};

  mlir::func::FuncOp stackSaveFn = fir::factory::getLlvmStackSave(builder);
  mlir::SymbolRefAttr stackSaveSym =
      builder.getSymbolRefAttr(stackSaveFn.getName());

  // save the stack pointer just before the allocation
  builder.setInsertionPoint(oldAlloc);
  mlir::Value sp =
      builder
          .create<fir::CallOp>(oldAlloc.getLoc(),
                               stackSaveFn.getFunctionType().getResults(),
                               stackSaveSym, mlir::ValueRange{})
          .getResult(0);

  mlir::func::FuncOp stackRestoreFn =
      fir::factory::getLlvmStackRestore(builder);
  mlir::SymbolRefAttr stackRestoreSym =
      builder.getSymbolRefAttr(stackRestoreFn.getName());

  // restore it at every free of this allocation
  for (mlir::Operation *user : oldAlloc->getUsers()) {
    if (mlir::isa<fir::FreeMemOp>(user)) {
      builder.setInsertionPoint(user);
      builder.create<fir::CallOp>(user->getLoc(),
                                  stackRestoreFn.getFunctionType().getResults(),
                                  stackRestoreSym, mlir::ValueRange{sp});
    }
  }

  rewriter.restoreInsertionPoint(oldPoint);
}
732 
// Copy constructor: required by the MLIR pass infrastructure for pass cloning
StackArraysPass::StackArraysPass(const StackArraysPass &pass)
    : fir::impl::StackArraysBase<StackArraysPass>(pass) {}
735 
736 llvm::StringRef StackArraysPass::getDescription() const {
737   return "Move heap allocated array temporaries to the stack";
738 }
739 
740 void StackArraysPass::runOnOperation() {
741   mlir::ModuleOp mod = getOperation();
742 
743   mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
744 }
745 
/// Run the analysis for one function and greedily apply AllocMemConversion to
/// every candidate fir.allocmem it found. Signals pass failure if either the
/// analysis or the rewrite fails.
void StackArraysPass::runOnFunc(mlir::Operation *func) {
  assert(mlir::isa<mlir::func::FuncOp>(func));

  auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
  const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
      analysis.getCandidateOps(func);
  if (!candidateOps) {
    signalPassFailure();
    return;
  }

  if (candidateOps->empty())
    return;
  runCount += candidateOps->size();

  // snapshot the candidate ops: the map must not be mutated while the pattern
  // driver runs
  llvm::SmallVector<mlir::Operation *> opsToConvert;
  opsToConvert.reserve(candidateOps->size());
  for (auto [op, _] : *candidateOps)
    opsToConvert.push_back(op);

  mlir::MLIRContext &context = getContext();
  mlir::RewritePatternSet patterns(&context);
  mlir::GreedyRewriteConfig config;
  // prevent the pattern driver from merging blocks
  config.enableRegionSimplification = false;

  patterns.insert<AllocMemConversion>(&context, *candidateOps);
  if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
                                                std::move(patterns), config))) {
    mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
    signalPassFailure();
  }
}
779