//===- StackArrays.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

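// StackArrays moves heap allocations (fir.allocmem) of array temporaries to
// stack allocations (fir.alloca) when the analysis below proves that the
// allocation is freed on every path out of the function. A minimal sketch of
// the rewrite (FIR syntax abbreviated; exact types depend on the program):
//
//   %mem = fir.allocmem !fir.array<42xi32>
//   ... uses of %mem ...
//   fir.freemem %mem : !fir.heap<!fir.array<42xi32>>
//
// becomes
//
//   %stack = fir.alloca !fir.array<42xi32>
//   %mem = fir.convert %stack
//       : (!fir.ref<!fir.array<42xi32>>) -> !fir.heap<!fir.array<42xi32>>
//   ... uses of %mem ...
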
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

namespace fir {
#define GEN_PASS_DEF_STACKARRAYS
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

#define DEBUG_TYPE "stack-arrays"

static llvm::cl::opt<std::size_t> maxAllocsPerFunc(
    "stack-arrays-max-allocs",
    llvm::cl::desc("The maximum number of heap allocations to consider in one "
                   "function before skipping (to save compilation time). Set "
                   "to 0 for no limit."),
    llvm::cl::init(1000), llvm::cl::Hidden);
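
// When running the pass standalone, this limit can be adjusted from the
// command line; a sketch, assuming the usual fir-opt tool registration:
//   fir-opt --stack-arrays --stack-arrays-max-allocs=0 in.fir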

namespace {

/// The state of an SSA value at each program point
enum class AllocationState {
  /// This means that the allocation state of a variable cannot be determined
  /// at this program point, e.g. because one route through a conditional freed
  /// the variable and the other route didn't.
  /// This asserts a known-unknown: different from the unknown-unknown of having
  /// no AllocationState stored for a particular SSA value
  Unknown,
  /// Means this SSA value was allocated on the heap in this function and has
  /// now been freed
  Freed,
  /// Means this SSA value was allocated on the heap in this function and is a
  /// candidate for moving to the stack
  Allocated,
};
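
// A sketch of how states evolve while walking forwards through a function:
//   %1 = fir.allocmem !fir.array<42xi32>   // %1: Allocated
//   ...
//   fir.freemem %1                         // %1: Freed
// If only one branch of a conditional frees %1, Allocated and Freed join to
// Unknown at the merge point and %1 is not a candidate (see ::join below).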

/// Stores where an alloca should be inserted. If the PointerUnion is an
/// Operation, the alloca should be inserted /after/ the operation. If it is a
/// block, the alloca can be placed anywhere in that block.
class InsertionPoint {
  llvm::PointerUnion<mlir::Operation *, mlir::Block *> location;
  bool saveRestoreStack;

  /// Get contained pointer type or nullptr
  template <class T>
  T *tryGetPtr() const {
    if (location.is<T *>())
      return location.get<T *>();
    return nullptr;
  }

public:
  template <class T>
  InsertionPoint(T *ptr, bool saveRestoreStack = false)
      : location(ptr), saveRestoreStack{saveRestoreStack} {}
  InsertionPoint(std::nullptr_t null)
      : location(null), saveRestoreStack{false} {}

  /// Get contained operation, or nullptr
  mlir::Operation *tryGetOperation() const {
    return tryGetPtr<mlir::Operation>();
  }

  /// Get contained block, or nullptr
  mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); }

  /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave
  /// intrinsic should be added before the alloca, and an llvm.stackrestore
  /// intrinsic should be added where the freemem is
  bool shouldSaveRestoreStack() const { return saveRestoreStack; }

  operator bool() const { return tryGetOperation() || tryGetBlock(); }

  bool operator==(const InsertionPoint &rhs) const {
    return (location == rhs.location) &&
           (saveRestoreStack == rhs.saveRestoreStack);
  }

  bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); }
};

/// Maps SSA values to their AllocationState at a particular program point.
class LatticePoint : public mlir::dataflow::AbstractDenseLattice {
  // Maps all values we are interested in to states
  llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap;

public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint)
  using AbstractDenseLattice::AbstractDenseLattice;

  bool operator==(const LatticePoint &rhs) const {
    return stateMap == rhs.stateMap;
  }

  /// Join the lattice across control-flow edges
  mlir::ChangeResult join(const AbstractDenseLattice &lattice) override;

  void print(llvm::raw_ostream &os) const override;

  /// Clear all modifications
  mlir::ChangeResult reset();

  /// Set the state of an SSA value
  mlir::ChangeResult set(mlir::Value value, AllocationState state);

  /// Get values allocated by fir.allocmem in this function which are always
  /// freed before the function returns
  void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const;

  std::optional<AllocationState> get(mlir::Value val) const;
};

class AllocationAnalysis
    : public mlir::dataflow::DenseForwardDataFlowAnalysis<LatticePoint> {
public:
  using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;

  void visitOperation(mlir::Operation *op, const LatticePoint &before,
                      LatticePoint *after) override;

  /// At an entry point, nothing is known to have been allocated yet, so reset
  /// the lattice to the empty state
  void setToEntryState(LatticePoint *lattice) override;

protected:
  /// Visit control flow operations and decide whether to call visitOperation
  /// to apply the transfer function
  void processOperation(mlir::Operation *op) override;
};

/// Drives analysis to find candidate fir.allocmem operations which could be
/// moved to the stack. Intended to be used with mlir::Pass::getAnalysis
class StackArraysAnalysisWrapper {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper)

  // Maps fir.allocmem -> place to insert alloca
  using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>;

  StackArraysAnalysisWrapper(mlir::Operation *op) {}

  // returns nullptr if analysis failed
  const AllocMemMap *getCandidateOps(mlir::Operation *func);

private:
  llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;

  llvm::LogicalResult analyseFunction(mlir::Operation *func);
};

/// Converts a fir.allocmem to a fir.alloca
class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
public:
  explicit AllocMemConversion(
      mlir::MLIRContext *ctx,
      const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
      : OpRewritePattern(ctx), candidateOps{candidateOps} {}

  llvm::LogicalResult
  matchAndRewrite(fir::AllocMemOp allocmem,
                  mlir::PatternRewriter &rewriter) const override;

  /// Determine where to insert the alloca operation. The returned value should
  /// be checked to see if it is inside a loop
  static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc);

private:
  /// Handle to the DFA (already run)
  const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;

  /// If we failed to find an insertion point not inside a loop, see if it
  /// would be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
  static InsertionPoint findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc);

  /// Returns the alloca if it was successfully inserted, otherwise {}
  std::optional<fir::AllocaOp>
  insertAlloca(fir::AllocMemOp &oldAlloc,
               mlir::PatternRewriter &rewriter) const;

  /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem
  void insertStackSaveRestore(fir::AllocMemOp &oldAlloc,
                              mlir::PatternRewriter &rewriter) const;
};

class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> {
public:
  StackArraysPass() = default;
  StackArraysPass(const StackArraysPass &pass);

  llvm::StringRef getDescription() const override;

  void runOnOperation() override;
  void runOnFunc(mlir::Operation *func);

private:
  Statistic runCount{this, "stackArraysRunCount",
                     "Number of heap allocations moved to the stack"};
};

} // namespace

static void print(llvm::raw_ostream &os, AllocationState state) {
  switch (state) {
  case AllocationState::Unknown:
    os << "Unknown";
    break;
  case AllocationState::Freed:
    os << "Freed";
    break;
  case AllocationState::Allocated:
    os << "Allocated";
    break;
  }
}

/// Join two AllocationStates for the same value coming from different CFG
/// blocks
static AllocationState join(AllocationState lhs, AllocationState rhs) {
  //           | Allocated | Freed     | Unknown
  // ========= | ========= | ========= | =========
  // Allocated | Allocated | Unknown   | Unknown
  // Freed     | Unknown   | Freed     | Unknown
  // Unknown   | Unknown   | Unknown   | Unknown
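  //
  // e.g. join(Allocated, Freed) == Unknown, so an allocation freed on only
  // one path through a conditional never becomes a stack candidate.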
  if (lhs == rhs)
    return lhs;
  return AllocationState::Unknown;
}

mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) {
  const auto &rhs = static_cast<const LatticePoint &>(lattice);
  mlir::ChangeResult changed = mlir::ChangeResult::NoChange;

  // add everything from rhs to map, handling cases where values are in both
  for (const auto &[value, rhsState] : rhs.stateMap) {
    auto it = stateMap.find(value);
    if (it != stateMap.end()) {
      // value is present in both maps
      AllocationState myState = it->second;
      AllocationState newState = ::join(myState, rhsState);
      if (newState != myState) {
        changed = mlir::ChangeResult::Change;
        it->getSecond() = newState;
      }
    } else {
      // value not present in current map: add it
      stateMap.insert({value, rhsState});
      changed = mlir::ChangeResult::Change;
    }
  }

  return changed;
}

void LatticePoint::print(llvm::raw_ostream &os) const {
  for (const auto &[value, state] : stateMap) {
    os << "\n * " << value << ": ";
    ::print(os, state);
  }
}

mlir::ChangeResult LatticePoint::reset() {
  if (stateMap.empty())
    return mlir::ChangeResult::NoChange;
  stateMap.clear();
  return mlir::ChangeResult::Change;
}

mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) {
  if (stateMap.count(value)) {
    // already in map
    AllocationState &oldState = stateMap[value];
    if (oldState != state) {
      stateMap[value] = state;
      return mlir::ChangeResult::Change;
    }
    return mlir::ChangeResult::NoChange;
  }
  stateMap.insert({value, state});
  return mlir::ChangeResult::Change;
}

/// Get values which were allocated in this function and always freed before
/// the function returns
void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const {
  for (auto &[value, state] : stateMap) {
    if (state == AllocationState::Freed)
      out.insert(value);
  }
}

std::optional<AllocationState> LatticePoint::get(mlir::Value val) const {
  auto it = stateMap.find(val);
  if (it == stateMap.end())
    return {};
  return it->second;
}

void AllocationAnalysis::visitOperation(mlir::Operation *op,
                                        const LatticePoint &before,
                                        LatticePoint *after) {
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
                          << "\n");
  LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n");

  // propagate before -> after
  mlir::ChangeResult changed = after->join(before);

  if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) {
    assert(op->getNumResults() == 1 && "fir.allocmem has one result");
    auto attr = op->getAttrOfType<fir::MustBeHeapAttr>(
        fir::MustBeHeapAttr::getAttrName());
    if (attr && attr.getValue()) {
      LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
      // skip allocation marked not to be moved
      return;
    }

    auto retTy = allocmem.getAllocatedType();
    if (!mlir::isa<fir::SequenceType>(retTy)) {
      LLVM_DEBUG(llvm::dbgs()
                 << "--Allocation is not for an array: skipping\n");
      return;
    }

    mlir::Value result = op->getResult(0);
    changed |= after->set(result, AllocationState::Allocated);
  } else if (mlir::isa<fir::FreeMemOp>(op)) {
    assert(op->getNumOperands() == 1 && "fir.freemem has one operand");
    mlir::Value operand = op->getOperand(0);

    // Note: StackArrays is scheduled in the pass pipeline after lowering hlfir
    // to fir. Therefore, we only need to handle `fir::DeclareOp`s.
    if (auto declareOp =
            llvm::dyn_cast_if_present<fir::DeclareOp>(operand.getDefiningOp()))
      operand = declareOp.getMemref();

    std::optional<AllocationState> operandState = before.get(operand);
    if (operandState && *operandState == AllocationState::Allocated) {
      // don't tag things not allocated in this function as freed, so that we
      // don't think they are candidates for moving to the stack
      changed |= after->set(operand, AllocationState::Freed);
    }
  } else if (mlir::isa<fir::ResultOp>(op)) {
    mlir::Operation *parent = op->getParentOp();
    LatticePoint *parentLattice = getLattice(parent);
    assert(parentLattice);
    mlir::ChangeResult parentChanged = parentLattice->join(*after);
    propagateIfChanged(parentLattice, parentChanged);
  }

  // we pass lattices straight through fir.call because called functions should
  // not deallocate flang-generated array temporaries

  LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n");
  propagateIfChanged(after, changed);
}

void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
  propagateIfChanged(lattice, lattice->reset());
}

/// Mostly a copy of AbstractDenseForwardDataFlowAnalysis::processOperation -
/// the difference being that call operations are passed through to the
/// transfer function
void AllocationAnalysis::processOperation(mlir::Operation *op) {
  // If the containing block is not executable, bail out.
  if (!getOrCreateFor<mlir::dataflow::Executable>(op, op->getBlock())->isLive())
    return;

  // Get the dense lattice to update
  mlir::dataflow::AbstractDenseLattice *after = getLattice(op);

  // If this op implements region control-flow, then control-flow dictates its
  // transfer function.
  if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op))
    return visitRegionBranchOperation(op, branch, after);

  // pass call operations through to the transfer function

  // Get the dense state before the execution of the op.
  const mlir::dataflow::AbstractDenseLattice *before;
  if (mlir::Operation *prev = op->getPrevNode())
    before = getLatticeFor(op, prev);
  else
    before = getLatticeFor(op, op->getBlock());

  // Invoke the operation transfer function
  visitOperationImpl(op, *before, after);
}

llvm::LogicalResult
StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
  assert(mlir::isa<mlir::func::FuncOp>(func));
  size_t nAllocs = 0;
  func->walk([&nAllocs](fir::AllocMemOp) { nAllocs++; });
  // don't bother with the analysis if there are no heap allocations
  if (nAllocs == 0)
    return mlir::success();
  if ((maxAllocsPerFunc != 0) && (nAllocs > maxAllocsPerFunc)) {
    LLVM_DEBUG(llvm::dbgs() << "Skipping stack arrays for function with "
                            << nAllocs << " heap allocations\n");
    return mlir::success();
  }

  mlir::DataFlowSolver solver;
  // Constant propagation is required for dead code analysis, and dead code
  // analysis is required to mark blocks live (needed by MLIR's dense DFA).
  solver.load<mlir::dataflow::SparseConstantPropagation>();
  solver.load<mlir::dataflow::DeadCodeAnalysis>();

  auto [it, inserted] = funcMaps.try_emplace(func);
  AllocMemMap &candidateOps = it->second;

  solver.load<AllocationAnalysis>();
  if (failed(solver.initializeAndRun(func))) {
    llvm::errs() << "DataFlowSolver failed!\n";
    return mlir::failure();
  }

  LatticePoint point{func};
  auto joinOperationLattice = [&](mlir::Operation *op) {
    const LatticePoint *lattice = solver.lookupState<LatticePoint>(op);
    // there will be no lattice for an unreachable block
    if (lattice)
      (void)point.join(*lattice);
  };

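  // Join the lattice at every operation which can exit the function or an
  // outlined OpenMP region, so that freedValues below only contains values
  // freed on every path out.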
  func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); });
  func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); });
  func->walk(
      [&](mlir::omp::TerminatorOp child) { joinOperationLattice(child); });
  func->walk([&](mlir::omp::YieldOp child) { joinOperationLattice(child); });

  llvm::DenseSet<mlir::Value> freedValues;
  point.appendFreedValues(freedValues);

  // We only replace allocations which are definitely freed on all routes
  // through the function because otherwise the allocation may have an intended
  // lifetime longer than the current stack frame (e.g. a heap allocation which
  // is then freed by another function).
  for (mlir::Value freedValue : freedValues) {
    fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>();
    InsertionPoint insertionPoint =
        AllocMemConversion::findAllocaInsertionPoint(allocmem);
    if (insertionPoint)
      candidateOps.insert({allocmem, insertionPoint});
  }

  LLVM_DEBUG(for (auto [allocMemOp, _] : candidateOps) {
    llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
  });
  return mlir::success();
}

const StackArraysAnalysisWrapper::AllocMemMap *
StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
  if (!funcMaps.contains(func))
    if (mlir::failed(analyseFunction(func)))
      return nullptr;
  return &funcMaps[func];
}

/// Restore the old allocation type expected by existing code
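/// (i.e. produce the !fir.heap<...> typed value that users expect from the
/// !fir.ref<...> typed fir.alloca result, via fir.convert)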
static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter,
                                         const mlir::Location &loc,
                                         mlir::Value heap, mlir::Value stack) {
  mlir::Type heapTy = heap.getType();
  mlir::Type stackTy = stack.getType();

  if (heapTy == stackTy)
    return stack;

  fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy);
  LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy =
      mlir::cast<fir::ReferenceType>(stackTy);
  assert(firHeapTy.getElementType() == firRefTy.getElementType() &&
         "Allocations must have the same type");

  auto insertionPoint = rewriter.saveInsertionPoint();
  rewriter.setInsertionPointAfter(stack.getDefiningOp());
  mlir::Value conv =
      rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult();
  rewriter.restoreInsertionPoint(insertionPoint);
  return conv;
}

llvm::LogicalResult
AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
                                    mlir::PatternRewriter &rewriter) const {
  auto oldInsertionPt = rewriter.saveInsertionPoint();
  // add alloca operation
  std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter);
  rewriter.restoreInsertionPoint(oldInsertionPt);
  if (!alloca)
    return mlir::failure();

  // remove freemem operations
  llvm::SmallVector<mlir::Operation *> erases;
  for (mlir::Operation *user : allocmem.getOperation()->getUsers()) {
    if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
      for (mlir::Operation *user : declareOp->getUsers()) {
        if (mlir::isa<fir::FreeMemOp>(user))
          erases.push_back(user);
      }
    }

    if (mlir::isa<fir::FreeMemOp>(user))
      erases.push_back(user);
  }

  // Now that we are done iterating the users, it is safe to mutate them
  for (mlir::Operation *erase : erases)
    rewriter.eraseOp(erase);

  // replace references to heap allocation with references to stack allocation
  mlir::Value newValue = convertAllocationType(
      rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult());
  rewriter.replaceAllUsesWith(allocmem.getResult(), newValue);

  // remove allocmem operation
  rewriter.eraseOp(allocmem.getOperation());

  return mlir::success();
}

static bool isInLoop(mlir::Block *block) {
  return mlir::LoopLikeOpInterface::blockIsInLoop(block);
}

static bool isInLoop(mlir::Operation *op) {
  return isInLoop(op->getBlock()) ||
         op->getParentOfType<mlir::LoopLikeOpInterface>();
}

InsertionPoint
AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) {
  // Ideally the alloca should be inserted at the end of the function entry
  // block so that we do not allocate stack space in a loop. However, the
  // operands to the alloca may not be available that early, so insert it
  // after the last operand becomes available.
  // If the old allocmem op was in an OpenMP region then it should not be moved
  // outside of that region.
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: "
                          << oldAlloc << "\n");

  // check that an Operation or Block we are about to return is not in a loop
  auto checkReturn = [&](auto *point) -> InsertionPoint {
    if (isInLoop(point)) {
      mlir::Operation *oldAllocOp = oldAlloc.getOperation();
      if (isInLoop(oldAllocOp)) {
        // where we want to put it is in a loop, and even the old location is in
        // a loop. Give up.
        return findAllocaLoopInsertionPoint(oldAlloc);
      }
      return {oldAllocOp};
    }
    return {point};
  };

  auto oldOmpRegion =
      oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();

  // Find when the last operand value becomes available
  mlir::Block *operandsBlock = nullptr;
  mlir::Operation *lastOperand = nullptr;
  for (mlir::Value operand : oldAlloc.getOperands()) {
    LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n");
    mlir::Operation *op = operand.getDefiningOp();
    if (!op)
      return checkReturn(oldAlloc.getOperation());
    if (!operandsBlock)
      operandsBlock = op->getBlock();
    else if (operandsBlock != op->getBlock()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "----operand declared in a different block!\n");
      // Operation::isBeforeInBlock requires the operations to be in the same
      // block. The best we can do is the location of the allocmem.
      return checkReturn(oldAlloc.getOperation());
    }
    if (!lastOperand || lastOperand->isBeforeInBlock(op))
      lastOperand = op;
  }

  if (lastOperand) {
    // there were value operands to the allocmem so insert after the last one
    LLVM_DEBUG(llvm::dbgs()
               << "--Placing after last operand: " << *lastOperand << "\n");
    // check we aren't moving out of an omp region
    auto lastOpOmpRegion =
        lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
    if (lastOpOmpRegion == oldOmpRegion)
      return checkReturn(lastOperand);
    // Presumably this happened because the operands became ready before the
    // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should
    // imply that oldOmpRegion comes after lastOpOmpRegion.
    return checkReturn(oldOmpRegion.getAllocaBlock());
  }

  // There were no value operands to the allocmem so we are safe to insert it
  // as early as we want

  // handle openmp case
  if (oldOmpRegion)
    return checkReturn(oldOmpRegion.getAllocaBlock());

  // fall back to the function entry block
  mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>();
  assert(func && "This analysis is run on func.func");
  mlir::Block &entryBlock = func.getBlocks().front();
  LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n");
  return checkReturn(&entryBlock);
}

InsertionPoint
AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) {
  mlir::Operation *oldAllocOp = oldAlloc;
  // This is only called as a last resort. We should try to insert at the
  // location of the old allocation, which is inside of a loop, using
  // llvm.stacksave/llvm.stackrestore

  // find freemem ops
  llvm::SmallVector<mlir::Operation *, 1> freeOps;

  for (mlir::Operation *user : oldAllocOp->getUsers()) {
    if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
      for (mlir::Operation *user : declareOp->getUsers()) {
        if (mlir::isa<fir::FreeMemOp>(user))
          freeOps.push_back(user);
      }
    }

    if (mlir::isa<fir::FreeMemOp>(user))
      freeOps.push_back(user);
  }

  assert(freeOps.size() && "DFA should only return freed memory");

  // Don't attempt to reason about a stacksave/stackrestore between different
  // blocks
  for (mlir::Operation *free : freeOps)
    if (free->getBlock() != oldAllocOp->getBlock())
      return {nullptr};

  // Check that there aren't any other stack allocations in between the
  // stack save and stack restore
  // note: for flang generated temporaries there should only be one free op
  for (mlir::Operation *free : freeOps) {
    for (mlir::Operation *op = oldAlloc; op && op != free;
         op = op->getNextNode()) {
      if (mlir::isa<fir::AllocaOp>(op))
        return {nullptr};
    }
  }

  return InsertionPoint{oldAllocOp, /*saveRestoreStack=*/true};
}

std::optional<fir::AllocaOp>
AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc,
                                 mlir::PatternRewriter &rewriter) const {
  auto it = candidateOps.find(oldAlloc.getOperation());
  if (it == candidateOps.end())
    return {};
  InsertionPoint insertionPoint = it->second;
  if (!insertionPoint)
    return {};

  if (insertionPoint.shouldSaveRestoreStack())
    insertStackSaveRestore(oldAlloc, rewriter);

  mlir::Location loc = oldAlloc.getLoc();
  mlir::Type varTy = oldAlloc.getInType();
  if (mlir::Operation *op = insertionPoint.tryGetOperation()) {
    rewriter.setInsertionPointAfter(op);
  } else {
    mlir::Block *block = insertionPoint.tryGetBlock();
    assert(block && "There must be a valid insertion point");
    rewriter.setInsertionPointToStart(block);
  }

  auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef {
    if (opt)
      return *opt;
    return {};
  };

  llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName());
  llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName());
  return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName,
                                        oldAlloc.getTypeparams(),
                                        oldAlloc.getShape());
}

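// A rough sketch of the IR emitted below (the callee symbol names come from
// fir::factory::getLlvmStackSave/getLlvmStackRestore):
//   %sp = fir.call @llvm.stacksave() : () -> !fir.ref<i8>
//   %mem = fir.alloca ...                  // replaces the fir.allocmem
//   ...
//   fir.call @llvm.stackrestore(%sp) : (!fir.ref<i8>) -> ()  // at each freemem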
void AllocMemConversion::insertStackSaveRestore(
    fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const {
  auto oldPoint = rewriter.saveInsertionPoint();
  auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>();
  fir::FirOpBuilder builder{rewriter, mod};

  mlir::func::FuncOp stackSaveFn = fir::factory::getLlvmStackSave(builder);
  mlir::SymbolRefAttr stackSaveSym =
      builder.getSymbolRefAttr(stackSaveFn.getName());

  builder.setInsertionPoint(oldAlloc);
  mlir::Value sp =
      builder
          .create<fir::CallOp>(oldAlloc.getLoc(),
                               stackSaveFn.getFunctionType().getResults(),
                               stackSaveSym, mlir::ValueRange{})
          .getResult(0);

  mlir::func::FuncOp stackRestoreFn =
      fir::factory::getLlvmStackRestore(builder);
  mlir::SymbolRefAttr stackRestoreSym =
      builder.getSymbolRefAttr(stackRestoreFn.getName());

  auto createStackRestoreCall = [&](mlir::Operation *user) {
    builder.setInsertionPoint(user);
    builder.create<fir::CallOp>(user->getLoc(),
                                stackRestoreFn.getFunctionType().getResults(),
                                stackRestoreSym, mlir::ValueRange{sp});
  };

  for (mlir::Operation *user : oldAlloc->getUsers()) {
    if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
      for (mlir::Operation *user : declareOp->getUsers()) {
        if (mlir::isa<fir::FreeMemOp>(user))
          createStackRestoreCall(user);
      }
    }

    if (mlir::isa<fir::FreeMemOp>(user)) {
      createStackRestoreCall(user);
    }
  }

  rewriter.restoreInsertionPoint(oldPoint);
}


StackArraysPass::StackArraysPass(const StackArraysPass &pass)
    : fir::impl::StackArraysBase<StackArraysPass>(pass) {}

llvm::StringRef StackArraysPass::getDescription() const {
  return "Move heap allocated array temporaries to the stack";
}

void StackArraysPass::runOnOperation() {
  mlir::ModuleOp mod = getOperation();

  mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
}

void StackArraysPass::runOnFunc(mlir::Operation *func) {
  assert(mlir::isa<mlir::func::FuncOp>(func));

  auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
  const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
      analysis.getCandidateOps(func);
  if (!candidateOps) {
    signalPassFailure();
    return;
  }

  if (candidateOps->empty())
    return;
  runCount += candidateOps->size();

  llvm::SmallVector<mlir::Operation *> opsToConvert;
  opsToConvert.reserve(candidateOps->size());
  for (auto [op, _] : *candidateOps)
    opsToConvert.push_back(op);

  mlir::MLIRContext &context = getContext();
  mlir::RewritePatternSet patterns(&context);
  mlir::GreedyRewriteConfig config;
  // prevent the pattern driver from merging blocks
  config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;

  patterns.insert<AllocMemConversion>(&context, *candidateOps);
  if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
                                                std::move(patterns), config))) {
    mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
    signalPassFailure();
  }
}