xref: /llvm-project/flang/lib/Optimizer/Transforms/StackArrays.cpp (revision 53cc33b00b5bf5cec5de214c1566289c927a83ca)
1 //===- StackArrays.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
11 #include "flang/Optimizer/Dialect/FIRAttr.h"
12 #include "flang/Optimizer/Dialect/FIRDialect.h"
13 #include "flang/Optimizer/Dialect/FIROps.h"
14 #include "flang/Optimizer/Dialect/FIRType.h"
15 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
16 #include "flang/Optimizer/Transforms/Passes.h"
17 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
18 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
19 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
20 #include "mlir/Analysis/DataFlowFramework.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"
22 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/Diagnostics.h"
25 #include "mlir/IR/Value.h"
26 #include "mlir/Interfaces/LoopLikeInterface.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Support/LogicalResult.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30 #include "mlir/Transforms/Passes.h"
31 #include "llvm/ADT/DenseMap.h"
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/PointerUnion.h"
34 #include "llvm/Support/Casting.h"
35 #include "llvm/Support/raw_ostream.h"
36 #include <optional>
37 
38 namespace fir {
39 #define GEN_PASS_DEF_STACKARRAYS
40 #include "flang/Optimizer/Transforms/Passes.h.inc"
41 } // namespace fir
42 
43 #define DEBUG_TYPE "stack-arrays"
44 
45 namespace {
46 
47 /// The state of an SSA value at each program point
48 enum class AllocationState {
49   /// This means that the allocation state of a variable cannot be determined
50   /// at this program point, e.g. because one route through a conditional freed
51   /// the variable and the other route didn't.
52   /// This asserts a known-unknown: different from the unknown-unknown of having
53   /// no AllocationState stored for a particular SSA value
54   Unknown,
55   /// Means this SSA value was allocated on the heap in this function and has
56   /// now been freed
57   Freed,
58   /// Means this SSA value was allocated on the heap in this function and is a
59   /// candidate for moving to the stack
60   Allocated,
61 };
62 
63 /// Stores where an alloca should be inserted. If the PointerUnion is an
64 /// Operation the alloca should be inserted /after/ the operation. If it is a
65 /// block, the alloca can be placed anywhere in that block.
66 class InsertionPoint {
67   llvm::PointerUnion<mlir::Operation *, mlir::Block *> location;
68   bool saveRestoreStack;
69 
70   /// Get contained pointer type or nullptr
71   template <class T>
72   T *tryGetPtr() const {
73     if (location.is<T *>())
74       return location.get<T *>();
75     return nullptr;
76   }
77 
78 public:
79   template <class T>
80   InsertionPoint(T *ptr, bool saveRestoreStack = false)
81       : location(ptr), saveRestoreStack{saveRestoreStack} {}
82   InsertionPoint(std::nullptr_t null)
83       : location(null), saveRestoreStack{false} {}
84 
85   /// Get contained operation, or nullptr
86   mlir::Operation *tryGetOperation() const {
87     return tryGetPtr<mlir::Operation>();
88   }
89 
90   /// Get contained block, or nullptr
91   mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); }
92 
93   /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave
94   /// intrinsic should be added before the alloca, and an llvm.stackrestore
95   /// intrinsic should be added where the freemem is
96   bool shouldSaveRestoreStack() const { return saveRestoreStack; }
97 
98   operator bool() const { return tryGetOperation() || tryGetBlock(); }
99 
100   bool operator==(const InsertionPoint &rhs) const {
101     return (location == rhs.location) &&
102            (saveRestoreStack == rhs.saveRestoreStack);
103   }
104 
105   bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); }
106 };
107 
/// Maps SSA values to their AllocationState at a particular program point
class LatticePoint : public mlir::dataflow::AbstractDenseLattice {
  // Maps all values we are interested in to states
  llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap;

public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint)
  using AbstractDenseLattice::AbstractDenseLattice;

  bool operator==(const LatticePoint &rhs) const {
    return stateMap == rhs.stateMap;
  }

  /// Join the lattice across control-flow edges
  mlir::ChangeResult join(const AbstractDenseLattice &lattice) override;

  void print(llvm::raw_ostream &os) const override;

  /// Clear all modifications
  mlir::ChangeResult reset();

  /// Set the state of an SSA value
  mlir::ChangeResult set(mlir::Value value, AllocationState state);

  /// Add to \p out the values currently in the Freed state: fir.allocmem
  /// results which were allocated in this function and always freed before
  /// the function returns
  void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const;

  /// Get the recorded state for \p val, or std::nullopt if it is not tracked
  std::optional<AllocationState> get(mlir::Value val) const;
};
140 
/// Dense forward dataflow analysis tracking the AllocationState of each
/// fir.allocmem result at every program point
class AllocationAnalysis
    : public mlir::dataflow::DenseDataFlowAnalysis<LatticePoint> {
public:
  using DenseDataFlowAnalysis::DenseDataFlowAnalysis;

  /// Transfer function: derive the lattice \p after the operation from the
  /// lattice \p before it
  void visitOperation(mlir::Operation *op, const LatticePoint &before,
                      LatticePoint *after) override;

  /// At an entry point, the last modifications of all memory resources are
  /// yet to be determined
  void setToEntryState(LatticePoint *lattice) override;

protected:
  /// Visit control flow operations and decide whether to call visitOperation
  /// to apply the transfer function
  void processOperation(mlir::Operation *op) override;
};
158 
/// Drives analysis to find candidate fir.allocmem operations which could be
/// moved to the stack. Intended to be used with mlir::Pass::getAnalysis
class StackArraysAnalysisWrapper {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper)

  // Maps fir.allocmem -> place to insert alloca
  using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>;

  // The operation is unused: the analysis runs lazily, per function, via
  // getCandidateOps
  StackArraysAnalysisWrapper(mlir::Operation *op) {}

  // returns the (cached) candidate map for func, running the analysis on
  // first use; returns nullptr if analysis failed
  const AllocMemMap *getCandidateOps(mlir::Operation *func);

private:
  // Per-function cache of results; key is the function operation
  llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;

  /// Run the dataflow analysis on \p func, recording candidates in funcMaps
  mlir::LogicalResult analyseFunction(mlir::Operation *func);
};
178 
/// Converts a fir.allocmem to a fir.alloca
class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
public:
  explicit AllocMemConversion(
      mlir::MLIRContext *ctx,
      const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
      : OpRewritePattern(ctx), candidateOps{candidateOps} {}

  mlir::LogicalResult
  matchAndRewrite(fir::AllocMemOp allocmem,
                  mlir::PatternRewriter &rewriter) const override;

  /// Determine where to insert the alloca operation. The returned value should
  /// be checked to see if it is inside a loop
  static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc);

private:
  /// Handle to the DFA (already run): maps each candidate allocmem to the
  /// place its replacement alloca should be inserted
  const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;

  /// If we failed to find an insertion point not inside a loop, see if it would
  /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
  static InsertionPoint findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc);

  /// Returns the alloca if it was successfully inserted, otherwise {}
  std::optional<fir::AllocaOp>
  insertAlloca(fir::AllocMemOp &oldAlloc,
               mlir::PatternRewriter &rewriter) const;

  /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem
  void insertStackSaveRestore(fir::AllocMemOp &oldAlloc,
                              mlir::PatternRewriter &rewriter) const;
};
212 
/// Pass driver: runs the analysis and the allocmem->alloca rewrite over every
/// function in the module
class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> {
public:
  StackArraysPass() = default;
  StackArraysPass(const StackArraysPass &pass);

  llvm::StringRef getDescription() const override;

  void runOnOperation() override;
  void runOnFunc(mlir::Operation *func);

private:
  // Number of allocations moved, reported via pass statistics
  Statistic runCount{this, "stackArraysRunCount",
                     "Number of heap allocations moved to the stack"};
};
227 
228 } // namespace
229 
230 static void print(llvm::raw_ostream &os, AllocationState state) {
231   switch (state) {
232   case AllocationState::Unknown:
233     os << "Unknown";
234     break;
235   case AllocationState::Freed:
236     os << "Freed";
237     break;
238   case AllocationState::Allocated:
239     os << "Allocated";
240     break;
241   }
242 }
243 
244 /// Join two AllocationStates for the same value coming from different CFG
245 /// blocks
246 static AllocationState join(AllocationState lhs, AllocationState rhs) {
247   //           | Allocated | Freed     | Unknown
248   // ========= | ========= | ========= | =========
249   // Allocated | Allocated | Unknown   | Unknown
250   // Freed     | Unknown   | Freed     | Unknown
251   // Unknown   | Unknown   | Unknown   | Unknown
252   if (lhs == rhs)
253     return lhs;
254   return AllocationState::Unknown;
255 }
256 
257 mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) {
258   const auto &rhs = static_cast<const LatticePoint &>(lattice);
259   mlir::ChangeResult changed = mlir::ChangeResult::NoChange;
260 
261   // add everything from rhs to map, handling cases where values are in both
262   for (const auto &[value, rhsState] : rhs.stateMap) {
263     auto it = stateMap.find(value);
264     if (it != stateMap.end()) {
265       // value is present in both maps
266       AllocationState myState = it->second;
267       AllocationState newState = ::join(myState, rhsState);
268       if (newState != myState) {
269         changed = mlir::ChangeResult::Change;
270         it->getSecond() = newState;
271       }
272     } else {
273       // value not present in current map: add it
274       stateMap.insert({value, rhsState});
275       changed = mlir::ChangeResult::Change;
276     }
277   }
278 
279   return changed;
280 }
281 
282 void LatticePoint::print(llvm::raw_ostream &os) const {
283   for (const auto &[value, state] : stateMap) {
284     os << value << ": ";
285     ::print(os, state);
286   }
287 }
288 
289 mlir::ChangeResult LatticePoint::reset() {
290   if (stateMap.empty())
291     return mlir::ChangeResult::NoChange;
292   stateMap.clear();
293   return mlir::ChangeResult::Change;
294 }
295 
296 mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) {
297   if (stateMap.count(value)) {
298     // already in map
299     AllocationState &oldState = stateMap[value];
300     if (oldState != state) {
301       stateMap[value] = state;
302       return mlir::ChangeResult::Change;
303     }
304     return mlir::ChangeResult::NoChange;
305   }
306   stateMap.insert({value, state});
307   return mlir::ChangeResult::Change;
308 }
309 
310 /// Get values which were allocated in this function and always freed before
311 /// the function returns
312 void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const {
313   for (auto &[value, state] : stateMap) {
314     if (state == AllocationState::Freed)
315       out.insert(value);
316   }
317 }
318 
319 std::optional<AllocationState> LatticePoint::get(mlir::Value val) const {
320   auto it = stateMap.find(val);
321   if (it == stateMap.end())
322     return {};
323   return it->second;
324 }
325 
// Transfer function: given the lattice before `op`, compute the lattice after
// it. Only fir.allocmem, fir.freemem and fir.result change the state; all
// other operations pass the lattice through unchanged.
void AllocationAnalysis::visitOperation(mlir::Operation *op,
                                        const LatticePoint &before,
                                        LatticePoint *after) {
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
                          << "\n");
  LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n");

  // propagate before -> after
  mlir::ChangeResult changed = after->join(before);

  if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) {
    assert(op->getNumResults() == 1 && "fir.allocmem has one result");
    auto attr = op->getAttrOfType<fir::MustBeHeapAttr>(
        fir::MustBeHeapAttr::getAttrName());
    if (attr && attr.getValue()) {
      LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
      // skip allocation marked not to be moved
      return;
    }

    // only allocations of arrays (fir.sequence types) are candidates
    auto retTy = allocmem.getAllocatedType();
    if (!retTy.isa<fir::SequenceType>()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "--Allocation is not for an array: skipping\n");
      return;
    }

    mlir::Value result = op->getResult(0);
    changed |= after->set(result, AllocationState::Allocated);
  } else if (mlir::isa<fir::FreeMemOp>(op)) {
    assert(op->getNumOperands() == 1 && "fir.freemem has one operand");
    mlir::Value operand = op->getOperand(0);
    std::optional<AllocationState> operandState = before.get(operand);
    if (operandState && *operandState == AllocationState::Allocated) {
      // don't tag things not allocated in this function as freed, so that we
      // don't think they are candidates for moving to the stack
      changed |= after->set(operand, AllocationState::Freed);
    }
  } else if (mlir::isa<fir::ResultOp>(op)) {
    // forward the state at a region terminator into the parent operation's
    // lattice so it flows out of nested regions
    mlir::Operation *parent = op->getParentOp();
    LatticePoint *parentLattice = getLattice(parent);
    assert(parentLattice);
    mlir::ChangeResult parentChanged = parentLattice->join(*after);
    propagateIfChanged(parentLattice, parentChanged);
  }

  // we pass lattices straight through fir.call because called functions should
  // not deallocate flang-generated array temporaries

  LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n");
  propagateIfChanged(after, changed);
}
378 
379 void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
380   propagateIfChanged(lattice, lattice->reset());
381 }
382 
/// Mostly a copy of the base class's processOperation - the difference
/// being that call operations are passed through to the transfer function
/// instead of receiving special call-site handling
void AllocationAnalysis::processOperation(mlir::Operation *op) {
  // If the containing block is not executable, bail out.
  if (!getOrCreateFor<mlir::dataflow::Executable>(op, op->getBlock())->isLive())
    return;

  // Get the dense lattice to update
  mlir::dataflow::AbstractDenseLattice *after = getLattice(op);

  // If this op implements region control-flow, then control-flow dictates its
  // transfer function.
  if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op))
    return visitRegionBranchOperation(op, branch, after);

  // pass call operations through to the transfer function

  // Get the dense state before the execution of the op: either the lattice
  // after the previous operation, or the block-entry lattice for the first op
  const mlir::dataflow::AbstractDenseLattice *before;
  if (mlir::Operation *prev = op->getPrevNode())
    before = getLatticeFor(op, prev);
  else
    before = getLatticeFor(op, op->getBlock());

  // Invoke the operation transfer function
  visitOperationImpl(op, *before, after);
}
410 
// Run the dataflow analysis on one function, then collect the fir.allocmem
// operations which are always freed on every path to a function exit and
// decide where each replacement fir.alloca should go.
mlir::LogicalResult
StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
  assert(mlir::isa<mlir::func::FuncOp>(func));
  mlir::DataFlowSolver solver;
  // constant propagation is required for dead code analysis, dead code analysis
  // is required to mark blocks live (required for mlir dense dfa)
  solver.load<mlir::dataflow::SparseConstantPropagation>();
  solver.load<mlir::dataflow::DeadCodeAnalysis>();

  // create (or reuse) the cache entry for this function up front so a result
  // is recorded even when no candidates are found
  auto [it, inserted] = funcMaps.try_emplace(func);
  AllocMemMap &candidateOps = it->second;

  solver.load<AllocationAnalysis>();
  if (failed(solver.initializeAndRun(func))) {
    llvm::errs() << "DataFlowSolver failed!";
    return mlir::failure();
  }

  // join the lattices at every function exit (return/unreachable) into one
  LatticePoint point{func};
  auto joinOperationLattice = [&](mlir::Operation *op) {
    const LatticePoint *lattice = solver.lookupState<LatticePoint>(op);
    // there will be no lattice for an unreachable block
    if (lattice)
      point.join(*lattice);
  };
  func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); });
  func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); });
  llvm::DenseSet<mlir::Value> freedValues;
  point.appendFreedValues(freedValues);

  // We only replace allocations which are definitely freed on all routes
  // through the function because otherwise the allocation may have an intended
  // lifetime longer than the current stack frame (e.g. a heap allocation which
  // is then freed by another function).
  for (mlir::Value freedValue : freedValues) {
    // freed values are only ever added for fir.allocmem results (see
    // visitOperation), so the defining op cast is expected to succeed
    fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>();
    InsertionPoint insertionPoint =
        AllocMemConversion::findAllocaInsertionPoint(allocmem);
    if (insertionPoint)
      candidateOps.insert({allocmem, insertionPoint});
  }

  LLVM_DEBUG(for (auto [allocMemOp, _]
                  : candidateOps) {
    llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
  });
  return mlir::success();
}
459 
460 const StackArraysAnalysisWrapper::AllocMemMap *
461 StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
462   if (!funcMaps.contains(func))
463     if (mlir::failed(analyseFunction(func)))
464       return nullptr;
465   return &funcMaps[func];
466 }
467 
468 /// Restore the old allocation type exected by existing code
469 static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter,
470                                          const mlir::Location &loc,
471                                          mlir::Value heap, mlir::Value stack) {
472   mlir::Type heapTy = heap.getType();
473   mlir::Type stackTy = stack.getType();
474 
475   if (heapTy == stackTy)
476     return stack;
477 
478   fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy);
479   fir::ReferenceType firRefTy = mlir::cast<fir::ReferenceType>(stackTy);
480   assert(firHeapTy.getElementType() == firRefTy.getElementType() &&
481          "Allocations must have the same type");
482 
483   auto insertionPoint = rewriter.saveInsertionPoint();
484   rewriter.setInsertionPointAfter(stack.getDefiningOp());
485   mlir::Value conv =
486       rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult();
487   rewriter.restoreInsertionPoint(insertionPoint);
488   return conv;
489 }
490 
// Replace one candidate fir.allocmem with a fir.alloca: insert the alloca,
// erase the matching fir.freemem uses, redirect remaining uses to the new
// allocation, and finally erase the allocmem itself.
mlir::LogicalResult
AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
                                    mlir::PatternRewriter &rewriter) const {
  auto oldInsertionPt = rewriter.saveInsertionPoint();
  // add alloca operation
  std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter);
  rewriter.restoreInsertionPoint(oldInsertionPt);
  if (!alloca)
    return mlir::failure();

  // remove freemem operations
  llvm::SmallVector<mlir::Operation *> erases;
  for (mlir::Operation *user : allocmem.getOperation()->getUsers())
    if (mlir::isa<fir::FreeMemOp>(user))
      erases.push_back(user);
  // now we are done iterating the users, it is safe to mutate them
  for (mlir::Operation *erase : erases)
    rewriter.eraseOp(erase);

  // replace references to heap allocation with references to stack allocation
  mlir::Value newValue = convertAllocationType(
      rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult());
  rewriter.replaceAllUsesWith(allocmem.getResult(), newValue);

  // remove allocmem operation
  rewriter.eraseOp(allocmem.getOperation());

  return mlir::success();
}
520 
521 static bool isInLoop(mlir::Block *block) {
522   return mlir::LoopLikeOpInterface::blockIsInLoop(block);
523 }
524 
525 static bool isInLoop(mlir::Operation *op) {
526   return isInLoop(op->getBlock()) ||
527          op->getParentOfType<mlir::LoopLikeOpInterface>();
528 }
529 
InsertionPoint
AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) {
  // Ideally the alloca should be inserted at the end of the function entry
  // block so that we do not allocate stack space in a loop. However,
  // the operands to the alloca may not be available that early, so insert it
  // after the last operand becomes available
  // If the old allocmem op was in an openmp region then it should not be moved
  // outside of that
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: "
                          << oldAlloc << "\n");

  // check that an Operation or Block we are about to return is not in a loop
  auto checkReturn = [&](auto *point) -> InsertionPoint {
    if (isInLoop(point)) {
      mlir::Operation *oldAllocOp = oldAlloc.getOperation();
      if (isInLoop(oldAllocOp)) {
        // where we want to put it is in a loop, and even the old location is in
        // a loop. Give up.
        return findAllocaLoopInsertionPoint(oldAlloc);
      }
      // fall back to the old allocation site, which is known not to be in a
      // loop
      return {oldAllocOp};
    }
    return {point};
  };

  // innermost OpenMP region containing the allocation (if any); the alloca
  // must not be hoisted out of it
  auto oldOmpRegion =
      oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();

  // Find when the last operand value becomes available
  mlir::Block *operandsBlock = nullptr;
  mlir::Operation *lastOperand = nullptr;
  for (mlir::Value operand : oldAlloc.getOperands()) {
    LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n");
    mlir::Operation *op = operand.getDefiningOp();
    if (!op)
      // block argument operand: cannot reason about availability, stay at the
      // old allocation site
      return checkReturn(oldAlloc.getOperation());
    if (!operandsBlock)
      operandsBlock = op->getBlock();
    else if (operandsBlock != op->getBlock()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "----operand declared in a different block!\n");
      // Operation::isBeforeInBlock requires the operations to be in the same
      // block. The best we can do is the location of the allocmem.
      return checkReturn(oldAlloc.getOperation());
    }
    if (!lastOperand || lastOperand->isBeforeInBlock(op))
      lastOperand = op;
  }

  if (lastOperand) {
    // there were value operands to the allocmem so insert after the last one
    LLVM_DEBUG(llvm::dbgs()
               << "--Placing after last operand: " << *lastOperand << "\n");
    // check we aren't moving out of an omp region
    auto lastOpOmpRegion =
        lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
    if (lastOpOmpRegion == oldOmpRegion)
      return checkReturn(lastOperand);
    // Presumably this happened because the operands became ready before the
    // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should
    // imply that oldOmpRegion comes after lastOpOmpRegion.
    // NOTE(review): this assumes oldOmpRegion is non-null here; an operand
    // defined inside an omp region while the allocmem is outside any omp
    // region would dereference a null interface - confirm this is unreachable
    return checkReturn(oldOmpRegion.getAllocaBlock());
  }

  // There were no value operands to the allocmem so we are safe to insert it
  // as early as we want

  // handle openmp case
  if (oldOmpRegion)
    return checkReturn(oldOmpRegion.getAllocaBlock());

  // fall back to the function entry block
  mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>();
  assert(func && "This analysis is run on func.func");
  mlir::Block &entryBlock = func.getBlocks().front();
  LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n");
  return checkReturn(&entryBlock);
}
608 
// Last-resort placement: keep the alloca at the old allocation site (inside a
// loop) but bracket it with llvm.stacksave/llvm.stackrestore so each loop
// iteration releases its stack space.
InsertionPoint
AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) {
  mlir::Operation *oldAllocOp = oldAlloc;
  // This is only called as a last resort. We should try to insert at the
  // location of the old allocation, which is inside of a loop, using
  // llvm.stacksave/llvm.stackrestore

  // find freemem ops
  llvm::SmallVector<mlir::Operation *, 1> freeOps;
  for (mlir::Operation *user : oldAllocOp->getUsers())
    if (mlir::isa<fir::FreeMemOp>(user))
      freeOps.push_back(user);
  // the DFA only selected values that are freed on all paths
  assert(freeOps.size() && "DFA should only return freed memory");

  // Don't attempt to reason about a stacksave/stackrestore between different
  // blocks
  for (mlir::Operation *free : freeOps)
    if (free->getBlock() != oldAllocOp->getBlock())
      return {nullptr};

  // Check that there aren't any other stack allocations in between the
  // stack save and stack restore
  // note: for flang generated temporaries there should only be one free op
  for (mlir::Operation *free : freeOps) {
    for (mlir::Operation *op = oldAlloc; op && op != free;
         op = op->getNextNode()) {
      // another alloca between save and restore would be freed too early by
      // the stackrestore
      if (mlir::isa<fir::AllocaOp>(op))
        return {nullptr};
    }
  }

  return InsertionPoint{oldAllocOp, /*saveRestoreStack=*/true};
}
642 
// Create the replacement fir.alloca for a candidate fir.allocmem at the
// insertion point the analysis recorded. Returns {} when the allocmem is not
// a candidate or has no valid insertion point.
std::optional<fir::AllocaOp>
AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc,
                                 mlir::PatternRewriter &rewriter) const {
  auto it = candidateOps.find(oldAlloc.getOperation());
  if (it == candidateOps.end())
    return {};
  InsertionPoint insertionPoint = it->second;
  if (!insertionPoint)
    return {};

  // loop case: bracket the allocation with stacksave/stackrestore
  if (insertionPoint.shouldSaveRestoreStack())
    insertStackSaveRestore(oldAlloc, rewriter);

  mlir::Location loc = oldAlloc.getLoc();
  mlir::Type varTy = oldAlloc.getInType();
  // an Operation insertion point means "after that op"; a Block insertion
  // point means "anywhere in the block", so use its start
  if (mlir::Operation *op = insertionPoint.tryGetOperation()) {
    rewriter.setInsertionPointAfter(op);
  } else {
    mlir::Block *block = insertionPoint.tryGetBlock();
    assert(block && "There must be a valid insertion point");
    rewriter.setInsertionPointToStart(block);
  }

  // fir.alloca takes plain StringRefs where fir.allocmem has optional names
  auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef {
    if (opt)
      return *opt;
    return {};
  };

  llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName());
  llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName());
  return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName,
                                        oldAlloc.getTypeparams(),
                                        oldAlloc.getShape());
}
678 
// Emit a call to llvm.stacksave immediately before the old allocation and a
// call to llvm.stackrestore (with the saved pointer) before every freemem use.
void AllocMemConversion::insertStackSaveRestore(
    fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const {
  auto oldPoint = rewriter.saveInsertionPoint();
  auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>();
  fir::FirOpBuilder builder{rewriter, mod};

  // declare (or find) the llvm.stacksave intrinsic wrapper
  mlir::func::FuncOp stackSaveFn = fir::factory::getLlvmStackSave(builder);
  mlir::SymbolRefAttr stackSaveSym =
      builder.getSymbolRefAttr(stackSaveFn.getName());

  builder.setInsertionPoint(oldAlloc);
  mlir::Value sp =
      builder
          .create<fir::CallOp>(oldAlloc.getLoc(),
                               stackSaveFn.getFunctionType().getResults(),
                               stackSaveSym, mlir::ValueRange{})
          .getResult(0);

  // declare (or find) the llvm.stackrestore intrinsic wrapper
  mlir::func::FuncOp stackRestoreFn =
      fir::factory::getLlvmStackRestore(builder);
  mlir::SymbolRefAttr stackRestoreSym =
      builder.getSymbolRefAttr(stackRestoreFn.getName());

  // restore the saved stack pointer at every freemem of this allocation
  for (mlir::Operation *user : oldAlloc->getUsers()) {
    if (mlir::isa<fir::FreeMemOp>(user)) {
      builder.setInsertionPoint(user);
      builder.create<fir::CallOp>(user->getLoc(),
                                  stackRestoreFn.getFunctionType().getResults(),
                                  stackRestoreSym, mlir::ValueRange{sp});
    }
  }

  rewriter.restoreInsertionPoint(oldPoint);
}
713 
// Copy constructor required by the MLIR pass infrastructure; all state
// (including statistics) is handled by the base class
StackArraysPass::StackArraysPass(const StackArraysPass &pass)
    : fir::impl::StackArraysBase<StackArraysPass>(pass) {}
716 
// One-line summary shown in pass documentation/help output
llvm::StringRef StackArraysPass::getDescription() const {
  return "Move heap allocated array temporaries to the stack";
}
720 
721 void StackArraysPass::runOnOperation() {
722   mlir::ModuleOp mod = getOperation();
723 
724   mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
725 }
726 
// Run the analysis for one function and greedily apply AllocMemConversion to
// every candidate fir.allocmem found
void StackArraysPass::runOnFunc(mlir::Operation *func) {
  assert(mlir::isa<mlir::func::FuncOp>(func));

  auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
  const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
      analysis.getCandidateOps(func);
  if (!candidateOps) {
    // the dataflow solver failed for this function
    signalPassFailure();
    return;
  }

  if (candidateOps->empty())
    return;
  runCount += candidateOps->size();

  // snapshot the candidate ops before rewriting mutates the IR
  llvm::SmallVector<mlir::Operation *> opsToConvert;
  opsToConvert.reserve(candidateOps->size());
  for (auto [op, _] : *candidateOps)
    opsToConvert.push_back(op);

  mlir::MLIRContext &context = getContext();
  mlir::RewritePatternSet patterns(&context);
  mlir::GreedyRewriteConfig config;
  // prevent the pattern driver from merging blocks
  config.enableRegionSimplification = false;

  patterns.insert<AllocMemConversion>(&context, *candidateOps);
  if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
                                                std::move(patterns), config))) {
    mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
    signalPassFailure();
  }
}
760 
// Factory used by the pass registration machinery
std::unique_ptr<mlir::Pass> fir::createStackArraysPass() {
  return std::make_unique<StackArraysPass>();
}
764