xref: /llvm-project/flang/lib/Optimizer/Transforms/StackArrays.cpp (revision 392651a7ec250b82cecabbe9c765c3e5d5677d75)
1 //===- StackArrays.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
11 #include "flang/Optimizer/Dialect/FIRAttr.h"
12 #include "flang/Optimizer/Dialect/FIRDialect.h"
13 #include "flang/Optimizer/Dialect/FIROps.h"
14 #include "flang/Optimizer/Dialect/FIRType.h"
15 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
16 #include "flang/Optimizer/Transforms/Passes.h"
17 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
18 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
19 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
20 #include "mlir/Analysis/DataFlowFramework.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"
22 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/Diagnostics.h"
25 #include "mlir/IR/Value.h"
26 #include "mlir/Interfaces/LoopLikeInterface.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29 #include "mlir/Transforms/Passes.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/PointerUnion.h"
33 #include "llvm/Support/Casting.h"
34 #include "llvm/Support/raw_ostream.h"
35 #include <optional>
36 
37 namespace fir {
38 #define GEN_PASS_DEF_STACKARRAYS
39 #include "flang/Optimizer/Transforms/Passes.h.inc"
40 } // namespace fir
41 
42 #define DEBUG_TYPE "stack-arrays"
43 
/// Command-line cap on how many fir.allocmem operations one function may
/// contain before this pass skips the function entirely (0 disables the cap).
static llvm::cl::opt<std::size_t> maxAllocsPerFunc(
    "stack-arrays-max-allocs",
    llvm::cl::desc("The maximum number of heap allocations to consider in one "
                   "function before skipping (to save compilation time). Set "
                   "to 0 for no limit."),
    llvm::cl::init(1000), llvm::cl::Hidden);
50 
51 namespace {
52 
53 /// The state of an SSA value at each program point
/// The state of an SSA value at each program point.
/// The lattice join of two states is defined by ::join below: equal states
/// join to themselves, any other combination joins to Unknown.
enum class AllocationState {
  /// This means that the allocation state of a variable cannot be determined
  /// at this program point, e.g. because one route through a conditional freed
  /// the variable and the other route didn't.
  /// This asserts a known-unknown: different from the unknown-unknown of having
  /// no AllocationState stored for a particular SSA value
  Unknown,
  /// Means this SSA value was allocated on the heap in this function and has
  /// now been freed
  Freed,
  /// Means this SSA value was allocated on the heap in this function and is a
  /// candidate for moving to the stack
  Allocated,
};
68 
69 /// Stores where an alloca should be inserted. If the PointerUnion is an
70 /// Operation the alloca should be inserted /after/ the operation. If it is a
71 /// block, the alloca can be placed anywhere in that block.
/// Stores where an alloca should be inserted. If the PointerUnion is an
/// Operation the alloca should be inserted /after/ the operation. If it is a
/// block, the alloca can be placed anywhere in that block.
class InsertionPoint {
  // Either the operation to insert after, the block to insert into, or null
  // (meaning no valid insertion point was found).
  llvm::PointerUnion<mlir::Operation *, mlir::Block *> location;
  // Whether the alloca's lifetime must be bracketed with
  // llvm.stacksave/llvm.stackrestore (used when inserting inside a loop).
  bool saveRestoreStack;

  /// Get contained pointer type or nullptr
  template <class T>
  T *tryGetPtr() const {
    // Use llvm::dyn_cast_if_present because location may be null here.
    if (T *ptr = llvm::dyn_cast_if_present<T *>(location))
      return ptr;
    return nullptr;
  }

public:
  template <class T>
  InsertionPoint(T *ptr, bool saveRestoreStack = false)
      : location(ptr), saveRestoreStack{saveRestoreStack} {}
  InsertionPoint(std::nullptr_t null)
      : location(null), saveRestoreStack{false} {}

  /// Get contained operation, or nullptr
  mlir::Operation *tryGetOperation() const {
    return tryGetPtr<mlir::Operation>();
  }

  /// Get contained block, or nullptr
  mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); }

  /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave
  /// intrinsic should be added before the alloca, and an llvm.stackrestore
  /// intrinsic should be added where the freemem is
  bool shouldSaveRestoreStack() const { return saveRestoreStack; }

  /// True when a valid (non-null) insertion location is stored.
  operator bool() const { return tryGetOperation() || tryGetBlock(); }

  bool operator==(const InsertionPoint &rhs) const {
    return (location == rhs.location) &&
           (saveRestoreStack == rhs.saveRestoreStack);
  }

  bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); }
};
114 
115 /// Maps SSA values to their AllocationState at a particular program point.
116 /// Also caches the insertion points for the new alloca operations
/// Maps SSA values to their AllocationState at a particular program point.
/// Also caches the insertion points for the new alloca operations
class LatticePoint : public mlir::dataflow::AbstractDenseLattice {
  // Maps all values we are interested in to states
  llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap;

public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint)
  using AbstractDenseLattice::AbstractDenseLattice;

  bool operator==(const LatticePoint &rhs) const {
    return stateMap == rhs.stateMap;
  }

  /// Join the lattice across control-flow edges
  mlir::ChangeResult join(const AbstractDenseLattice &lattice) override;

  void print(llvm::raw_ostream &os) const override;

  /// Clear all modifications
  mlir::ChangeResult reset();

  /// Set the state of an SSA value
  mlir::ChangeResult set(mlir::Value value, AllocationState state);

  /// Get fir.allocmem ops which were allocated in this function and always
  /// freed before the function returns, plus where to insert replacement
  /// fir.alloca ops
  void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const;

  /// Look up the recorded state for \p val, or std::nullopt if untracked.
  std::optional<AllocationState> get(mlir::Value val) const;
};
147 
/// Forward dense dataflow analysis tracking the AllocationState of every
/// interesting SSA value through the function.
class AllocationAnalysis
    : public mlir::dataflow::DenseForwardDataFlowAnalysis<LatticePoint> {
public:
  using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;

  /// Transfer function: derive \p after from \p before for \p op.
  mlir::LogicalResult visitOperation(mlir::Operation *op,
                                     const LatticePoint &before,
                                     LatticePoint *after) override;

  /// At an entry point, the last modifications of all memory resources are
  /// yet to be determined
  void setToEntryState(LatticePoint *lattice) override;

protected:
  /// Visit control flow operations and decide whether to call visitOperation
  /// to apply the transfer function
  mlir::LogicalResult processOperation(mlir::Operation *op) override;
};
166 
167 /// Drives analysis to find candidate fir.allocmem operations which could be
168 /// moved to the stack. Intended to be used with mlir::Pass::getAnalysis
/// Drives analysis to find candidate fir.allocmem operations which could be
/// moved to the stack. Intended to be used with mlir::Pass::getAnalysis
class StackArraysAnalysisWrapper {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper)

  // Maps fir.allocmem -> place to insert alloca
  using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>;

  // The operation argument is unused; presumably the signature is required by
  // mlir::Pass::getAnalysis — confirm against the analysis framework.
  StackArraysAnalysisWrapper(mlir::Operation *op) {}

  // returns nullptr if analysis failed
  const AllocMemMap *getCandidateOps(mlir::Operation *func);

private:
  // Per-function cache of analysis results, keyed by the func.func operation.
  llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;

  /// Run the dataflow analysis on \p func and record candidates in funcMaps.
  llvm::LogicalResult analyseFunction(mlir::Operation *func);
};
186 
187 /// Converts a fir.allocmem to a fir.alloca
/// Converts a fir.allocmem to a fir.alloca
class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
public:
  explicit AllocMemConversion(
      mlir::MLIRContext *ctx,
      const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
      : OpRewritePattern(ctx), candidateOps{candidateOps} {}

  llvm::LogicalResult
  matchAndRewrite(fir::AllocMemOp allocmem,
                  mlir::PatternRewriter &rewriter) const override;

  /// Determine where to insert the alloca operation. The returned value should
  /// be checked to see if it is inside a loop
  static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc);

private:
  /// Handle to the DFA (already run)
  const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;

  /// If we failed to find an insertion point not inside a loop, see if it would
  /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
  static InsertionPoint findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc);

  /// Returns the alloca if it was successfully inserted, otherwise {}
  std::optional<fir::AllocaOp>
  insertAlloca(fir::AllocMemOp &oldAlloc,
               mlir::PatternRewriter &rewriter) const;

  /// Inserts a stacksave before oldAlloc and a stackrestore before each
  /// freemem (the implementation sets the insertion point to the freemem op,
  /// which inserts before it)
  void insertStackSaveRestore(fir::AllocMemOp &oldAlloc,
                              mlir::PatternRewriter &rewriter) const;
};
220 
/// The pass itself: queries StackArraysAnalysisWrapper for candidate
/// allocations in each func.func and rewrites them with AllocMemConversion.
class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> {
public:
  StackArraysPass() = default;
  StackArraysPass(const StackArraysPass &pass);

  llvm::StringRef getDescription() const override;

  void runOnOperation() override;

private:
  /// Pass statistic: number of heap allocations this pass moved to the stack.
  Statistic runCount{this, "stackArraysRunCount",
                     "Number of heap allocations moved to the stack"};
};
234 
235 } // namespace
236 
237 static void print(llvm::raw_ostream &os, AllocationState state) {
238   switch (state) {
239   case AllocationState::Unknown:
240     os << "Unknown";
241     break;
242   case AllocationState::Freed:
243     os << "Freed";
244     break;
245   case AllocationState::Allocated:
246     os << "Allocated";
247     break;
248   }
249 }
250 
251 /// Join two AllocationStates for the same value coming from different CFG
252 /// blocks
253 static AllocationState join(AllocationState lhs, AllocationState rhs) {
254   //           | Allocated | Freed     | Unknown
255   // ========= | ========= | ========= | =========
256   // Allocated | Allocated | Unknown   | Unknown
257   // Freed     | Unknown   | Freed     | Unknown
258   // Unknown   | Unknown   | Unknown   | Unknown
259   if (lhs == rhs)
260     return lhs;
261   return AllocationState::Unknown;
262 }
263 
264 mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) {
265   const auto &rhs = static_cast<const LatticePoint &>(lattice);
266   mlir::ChangeResult changed = mlir::ChangeResult::NoChange;
267 
268   // add everything from rhs to map, handling cases where values are in both
269   for (const auto &[value, rhsState] : rhs.stateMap) {
270     auto it = stateMap.find(value);
271     if (it != stateMap.end()) {
272       // value is present in both maps
273       AllocationState myState = it->second;
274       AllocationState newState = ::join(myState, rhsState);
275       if (newState != myState) {
276         changed = mlir::ChangeResult::Change;
277         it->getSecond() = newState;
278       }
279     } else {
280       // value not present in current map: add it
281       stateMap.insert({value, rhsState});
282       changed = mlir::ChangeResult::Change;
283     }
284   }
285 
286   return changed;
287 }
288 
289 void LatticePoint::print(llvm::raw_ostream &os) const {
290   for (const auto &[value, state] : stateMap) {
291     os << "\n * " << value << ": ";
292     ::print(os, state);
293   }
294 }
295 
296 mlir::ChangeResult LatticePoint::reset() {
297   if (stateMap.empty())
298     return mlir::ChangeResult::NoChange;
299   stateMap.clear();
300   return mlir::ChangeResult::Change;
301 }
302 
303 mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) {
304   if (stateMap.count(value)) {
305     // already in map
306     AllocationState &oldState = stateMap[value];
307     if (oldState != state) {
308       stateMap[value] = state;
309       return mlir::ChangeResult::Change;
310     }
311     return mlir::ChangeResult::NoChange;
312   }
313   stateMap.insert({value, state});
314   return mlir::ChangeResult::Change;
315 }
316 
317 /// Get values which were allocated in this function and always freed before
318 /// the function returns
319 void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const {
320   for (auto &[value, state] : stateMap) {
321     if (state == AllocationState::Freed)
322       out.insert(value);
323   }
324 }
325 
326 std::optional<AllocationState> LatticePoint::get(mlir::Value val) const {
327   auto it = stateMap.find(val);
328   if (it == stateMap.end())
329     return {};
330   return it->second;
331 }
332 
/// Transfer function for the forward dataflow analysis.
///
/// Starting from `before`, computes `after`:
///  * fir.allocmem: marks the result Allocated, unless the allocation is
///    tagged fir.must_be_heap or does not allocate an array
///  * fir.freemem: marks the freed value Freed, but only if it was Allocated
///    here, so allocations from outside this function never become candidates
///  * fir.result: joins this region's exit lattice into the parent
///    operation's lattice so states survive region-based control flow
mlir::LogicalResult AllocationAnalysis::visitOperation(
    mlir::Operation *op, const LatticePoint &before, LatticePoint *after) {
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
                          << "\n");
  LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n");

  // propagate before -> after
  mlir::ChangeResult changed = after->join(before);

  if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) {
    assert(op->getNumResults() == 1 && "fir.allocmem has one result");
    auto attr = op->getAttrOfType<fir::MustBeHeapAttr>(
        fir::MustBeHeapAttr::getAttrName());
    if (attr && attr.getValue()) {
      LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
      // skip allocation marked not to be moved
      // NOTE(review): this early return skips the propagateIfChanged(after,
      // changed) at the bottom even though the join above may have changed
      // `after` — confirm the framework still propagates that join.
      return mlir::success();
    }

    auto retTy = allocmem.getAllocatedType();
    if (!mlir::isa<fir::SequenceType>(retTy)) {
      LLVM_DEBUG(llvm::dbgs()
                 << "--Allocation is not for an array: skipping\n");
      return mlir::success();
    }

    mlir::Value result = op->getResult(0);
    changed |= after->set(result, AllocationState::Allocated);
  } else if (mlir::isa<fir::FreeMemOp>(op)) {
    assert(op->getNumOperands() == 1 && "fir.freemem has one operand");
    mlir::Value operand = op->getOperand(0);

    // Note: StackArrays is scheduled in the pass pipeline after lowering hlfir
    // to fir. Therefore, we only need to handle `fir::DeclareOp`s. Look
    // through the declare to the underlying allocation.
    if (auto declareOp =
            llvm::dyn_cast_if_present<fir::DeclareOp>(operand.getDefiningOp()))
      operand = declareOp.getMemref();

    std::optional<AllocationState> operandState = before.get(operand);
    if (operandState && *operandState == AllocationState::Allocated) {
      // don't tag things not allocated in this function as freed, so that we
      // don't think they are candidates for moving to the stack
      changed |= after->set(operand, AllocationState::Freed);
    }
  } else if (mlir::isa<fir::ResultOp>(op)) {
    mlir::Operation *parent = op->getParentOp();
    LatticePoint *parentLattice = getLattice(getProgramPointAfter(parent));
    assert(parentLattice);
    mlir::ChangeResult parentChanged = parentLattice->join(*after);
    propagateIfChanged(parentLattice, parentChanged);
  }

  // we pass lattices straight through fir.call because called functions should
  // not deallocate flang-generated array temporaries

  LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n");
  propagateIfChanged(after, changed);
  return mlir::success();
}
392 
393 void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
394   propagateIfChanged(lattice, lattice->reset());
395 }
396 
/// Mostly a copy of AbstractDenseLattice::processOperation - the difference
/// being that call operations are passed through to the transfer function
mlir::LogicalResult AllocationAnalysis::processOperation(mlir::Operation *op) {
  mlir::ProgramPoint *point = getProgramPointAfter(op);
  // If the containing block is not executable, bail out.
  if (op->getBlock() != nullptr &&
      !getOrCreateFor<mlir::dataflow::Executable>(
           point, getProgramPointBefore(op->getBlock()))
           ->isLive())
    return mlir::success();

  // Get the dense lattice to update
  mlir::dataflow::AbstractDenseLattice *after = getLattice(point);

  // If this op implements region control-flow, then control-flow dictates its
  // transfer function.
  if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
    visitRegionBranchOperation(point, branch, after);
    return mlir::success();
  }

  // Unlike the base implementation, call operations fall through to the
  // ordinary transfer function below: called functions should not free
  // flang-generated array temporaries, so calls are treated like any other op.

  // Get the dense state before the execution of the op.
  const mlir::dataflow::AbstractDenseLattice *before =
      getLatticeFor(point, getProgramPointBefore(op));

  /// Invoke the operation transfer function
  return visitOperationImpl(op, *before, after);
}
427 
/// Run the allocation dataflow analysis on \p func and record in funcMaps
/// every fir.allocmem that is freed on all paths out of the function,
/// together with the place the replacement fir.alloca should be inserted.
llvm::LogicalResult
StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
  assert(mlir::isa<mlir::func::FuncOp>(func));
  size_t nAllocs = 0;
  func->walk([&nAllocs](fir::AllocMemOp) { nAllocs++; });
  // don't bother with the analysis if there are no heap allocations
  if (nAllocs == 0)
    return mlir::success();
  // compile-time escape hatch for pathological functions (see cl::opt above)
  if ((maxAllocsPerFunc != 0) && (nAllocs > maxAllocsPerFunc)) {
    LLVM_DEBUG(llvm::dbgs() << "Skipping stack arrays for function with "
                            << nAllocs << " heap allocations");
    return mlir::success();
  }

  mlir::DataFlowSolver solver;
  // constant propagation is required for dead code analysis, dead code analysis
  // is required to mark blocks live (required for mlir dense dfa)
  solver.load<mlir::dataflow::SparseConstantPropagation>();
  solver.load<mlir::dataflow::DeadCodeAnalysis>();

  auto [it, inserted] = funcMaps.try_emplace(func);
  AllocMemMap &candidateOps = it->second;

  solver.load<AllocationAnalysis>();
  if (failed(solver.initializeAndRun(func))) {
    llvm::errs() << "DataFlowSolver failed!";
    return mlir::failure();
  }

  // Join the lattices at every operation which can exit the function (or an
  // OpenMP region) so `point` summarises the state on all paths out.
  LatticePoint point{solver.getProgramPointAfter(func)};
  auto joinOperationLattice = [&](mlir::Operation *op) {
    const LatticePoint *lattice =
        solver.lookupState<LatticePoint>(solver.getProgramPointAfter(op));
    // there will be no lattice for an unreachable block
    if (lattice)
      (void)point.join(*lattice);
  };

  func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); });
  func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); });
  func->walk(
      [&](mlir::omp::TerminatorOp child) { joinOperationLattice(child); });
  func->walk([&](mlir::omp::YieldOp child) { joinOperationLattice(child); });

  llvm::DenseSet<mlir::Value> freedValues;
  point.appendFreedValues(freedValues);

  // We only replace allocations which are definitely freed on all routes
  // through the function because otherwise the allocation may have an intended
  // lifetime longer than the current stack frame (e.g. a heap allocation which
  // is then freed by another function).
  for (mlir::Value freedValue : freedValues) {
    fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>();
    InsertionPoint insertionPoint =
        AllocMemConversion::findAllocaInsertionPoint(allocmem);
    if (insertionPoint)
      candidateOps.insert({allocmem, insertionPoint});
  }

  LLVM_DEBUG(for (auto [allocMemOp, _]
                  : candidateOps) {
    llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
  });
  return mlir::success();
}
493 
494 const StackArraysAnalysisWrapper::AllocMemMap *
495 StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
496   if (!funcMaps.contains(func))
497     if (mlir::failed(analyseFunction(func)))
498       return nullptr;
499   return &funcMaps[func];
500 }
501 
502 /// Restore the old allocation type exected by existing code
503 static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter,
504                                          const mlir::Location &loc,
505                                          mlir::Value heap, mlir::Value stack) {
506   mlir::Type heapTy = heap.getType();
507   mlir::Type stackTy = stack.getType();
508 
509   if (heapTy == stackTy)
510     return stack;
511 
512   fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy);
513   LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy =
514       mlir::cast<fir::ReferenceType>(stackTy);
515   assert(firHeapTy.getElementType() == firRefTy.getElementType() &&
516          "Allocations must have the same type");
517 
518   auto insertionPoint = rewriter.saveInsertionPoint();
519   rewriter.setInsertionPointAfter(stack.getDefiningOp());
520   mlir::Value conv =
521       rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult();
522   rewriter.restoreInsertionPoint(insertionPoint);
523   return conv;
524 }
525 
/// Rewrite one candidate fir.allocmem into a fir.alloca: insert the alloca at
/// the precomputed insertion point, erase every fir.freemem of the old
/// allocation (looking through fir.declare), then redirect remaining uses to
/// the new stack address and erase the allocmem.
llvm::LogicalResult
AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
                                    mlir::PatternRewriter &rewriter) const {
  auto oldInsertionPt = rewriter.saveInsertionPoint();
  // add alloca operation
  std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter);
  rewriter.restoreInsertionPoint(oldInsertionPt);
  if (!alloca)
    return mlir::failure();

  // remove freemem operations; collect first so the use list is not mutated
  // while we are still iterating it
  llvm::SmallVector<mlir::Operation *> erases;
  for (mlir::Operation *user : allocmem.getOperation()->getUsers()) {
    if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
      for (mlir::Operation *user : declareOp->getUsers()) {
        if (mlir::isa<fir::FreeMemOp>(user))
          erases.push_back(user);
      }
    }

    if (mlir::isa<fir::FreeMemOp>(user))
      erases.push_back(user);
  }

  // now we are done iterating the users, it is safe to mutate them
  for (mlir::Operation *erase : erases)
    rewriter.eraseOp(erase);

  // replace references to heap allocation with references to stack allocation
  mlir::Value newValue = convertAllocationType(
      rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult());
  rewriter.replaceAllUsesWith(allocmem.getResult(), newValue);

  // remove allocmem operation
  rewriter.eraseOp(allocmem.getOperation());

  return mlir::success();
}
564 
/// Is this block (transitively) contained in a CFG loop?
static bool isInLoop(mlir::Block *block) {
  return mlir::LoopLikeOpInterface::blockIsInLoop(block);
}
568 
569 static bool isInLoop(mlir::Operation *op) {
570   return isInLoop(op->getBlock()) ||
571          op->getParentOfType<mlir::LoopLikeOpInterface>();
572 }
573 
/// Pick the point where the replacement fir.alloca should be created for
/// \p oldAlloc: the earliest point at which all of its operands are
/// available, preferring the function (or OpenMP-region) entry block, and
/// never moving outside an enclosing OpenMP region. Falls back to
/// findAllocaLoopInsertionPoint when the only viable spot is inside a loop.
InsertionPoint
AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) {
  // Ideally the alloca should be inserted at the end of the function entry
  // block so that we do not allocate stack space in a loop. However,
  // the operands to the alloca may not be available that early, so insert it
  // after the last operand becomes available
  // If the old allocmem op was in an openmp region then it should not be moved
  // outside of that
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: "
                          << oldAlloc << "\n");

  // check that an Operation or Block we are about to return is not in a loop
  auto checkReturn = [&](auto *point) -> InsertionPoint {
    if (isInLoop(point)) {
      mlir::Operation *oldAllocOp = oldAlloc.getOperation();
      if (isInLoop(oldAllocOp)) {
        // where we want to put it is in a loop, and even the old location is in
        // a loop. Give up.
        return findAllocaLoopInsertionPoint(oldAlloc);
      }
      return {oldAllocOp};
    }
    return {point};
  };

  auto oldOmpRegion =
      oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();

  // Find when the last operand value becomes available
  mlir::Block *operandsBlock = nullptr;
  mlir::Operation *lastOperand = nullptr;
  for (mlir::Value operand : oldAlloc.getOperands()) {
    LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n");
    mlir::Operation *op = operand.getDefiningOp();
    // block arguments have no defining op; fall back to the old location
    if (!op)
      return checkReturn(oldAlloc.getOperation());
    if (!operandsBlock)
      operandsBlock = op->getBlock();
    else if (operandsBlock != op->getBlock()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "----operand declared in a different block!\n");
      // Operation::isBeforeInBlock requires the operations to be in the same
      // block. The best we can do is the location of the allocmem.
      return checkReturn(oldAlloc.getOperation());
    }
    if (!lastOperand || lastOperand->isBeforeInBlock(op))
      lastOperand = op;
  }

  if (lastOperand) {
    // there were value operands to the allocmem so insert after the last one
    LLVM_DEBUG(llvm::dbgs()
               << "--Placing after last operand: " << *lastOperand << "\n");
    // check we aren't moving out of an omp region
    auto lastOpOmpRegion =
        lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
    if (lastOpOmpRegion == oldOmpRegion)
      return checkReturn(lastOperand);
    // Presumably this happened because the operands became ready before the
    // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should
    // imply that oldOmpRegion comes after lastOpOmpRegion.
    // NOTE(review): if lastOperand sits inside an omp region while oldAlloc
    // does not, oldOmpRegion is null here and getAllocaBlock() would be
    // invoked on a null interface -- confirm this case cannot occur.
    return checkReturn(oldOmpRegion.getAllocaBlock());
  }

  // There were no value operands to the allocmem so we are safe to insert it
  // as early as we want

  // handle openmp case
  if (oldOmpRegion)
    return checkReturn(oldOmpRegion.getAllocaBlock());

  // fall back to the function entry block
  mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>();
  assert(func && "This analysis is run on func.func");
  mlir::Block &entryBlock = func.getBlocks().front();
  LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n");
  return checkReturn(&entryBlock);
}
652 
/// Last-resort insertion point: keep the alloca at the old allocation's own
/// (in-loop) location but bracket it with llvm.stacksave/llvm.stackrestore
/// so the stack does not keep growing each iteration. Returns null when the
/// save/restore pairing cannot be reasoned about safely.
InsertionPoint
AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) {
  mlir::Operation *oldAllocOp = oldAlloc;
  // This is only called as a last resort. We should try to insert at the
  // location of the old allocation, which is inside of a loop, using
  // llvm.stacksave/llvm.stackrestore

  // find freemem ops (looking through fir.declare users)
  llvm::SmallVector<mlir::Operation *, 1> freeOps;

  for (mlir::Operation *user : oldAllocOp->getUsers()) {
    if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
      for (mlir::Operation *user : declareOp->getUsers()) {
        if (mlir::isa<fir::FreeMemOp>(user))
          freeOps.push_back(user);
      }
    }

    if (mlir::isa<fir::FreeMemOp>(user))
      freeOps.push_back(user);
  }

  assert(freeOps.size() && "DFA should only return freed memory");

  // Don't attempt to reason about a stacksave/stackrestore between different
  // blocks
  for (mlir::Operation *free : freeOps)
    if (free->getBlock() != oldAllocOp->getBlock())
      return {nullptr};

  // Check that there aren't any other stack allocations in between the
  // stack save and stack restore
  // note: for flang generated temporaries there should only be one free op
  for (mlir::Operation *free : freeOps) {
    for (mlir::Operation *op = oldAlloc; op && op != free;
         op = op->getNextNode()) {
      if (mlir::isa<fir::AllocaOp>(op))
        return {nullptr};
    }
  }

  return InsertionPoint{oldAllocOp, /*saveRestoreStack=*/true};
}
696 
/// Create the fir.alloca replacing \p oldAlloc at the insertion point cached
/// in candidateOps, emitting stacksave/stackrestore first when required.
/// Returns {} if oldAlloc is not a candidate or has no valid insertion point.
/// Leaves the rewriter's insertion point at the new alloca; the caller is
/// responsible for restoring it.
std::optional<fir::AllocaOp>
AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc,
                                 mlir::PatternRewriter &rewriter) const {
  auto it = candidateOps.find(oldAlloc.getOperation());
  if (it == candidateOps.end())
    return {};
  InsertionPoint insertionPoint = it->second;
  if (!insertionPoint)
    return {};

  if (insertionPoint.shouldSaveRestoreStack())
    insertStackSaveRestore(oldAlloc, rewriter);

  mlir::Location loc = oldAlloc.getLoc();
  mlir::Type varTy = oldAlloc.getInType();
  // an Operation insertion point means "after the op"; a Block means "at the
  // start of the block" (see InsertionPoint)
  if (mlir::Operation *op = insertionPoint.tryGetOperation()) {
    rewriter.setInsertionPointAfter(op);
  } else {
    mlir::Block *block = insertionPoint.tryGetBlock();
    assert(block && "There must be a valid insertion point");
    rewriter.setInsertionPointToStart(block);
  }

  // fir.allocmem's name attributes are optional; fir.alloca's builder takes
  // plain (possibly empty) StringRefs
  auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef {
    if (opt)
      return *opt;
    return {};
  };

  llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName());
  llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName());
  return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName,
                                        oldAlloc.getTypeparams(),
                                        oldAlloc.getShape());
}
732 
/// Emit llvm.stacksave immediately before \p oldAlloc and llvm.stackrestore
/// immediately before every fir.freemem of the allocation (including ones
/// reached through fir.declare), so an in-loop alloca does not grow the
/// stack on every iteration.
void AllocMemConversion::insertStackSaveRestore(
    fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const {
  auto oldPoint = rewriter.saveInsertionPoint();
  auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>();
  fir::FirOpBuilder builder{rewriter, mod};

  builder.setInsertionPoint(oldAlloc);
  mlir::Value sp = builder.genStackSave(oldAlloc.getLoc());

  // inserts a stackrestore of `sp` just before the given freemem op
  auto createStackRestoreCall = [&](mlir::Operation *user) {
    builder.setInsertionPoint(user);
    builder.genStackRestore(user->getLoc(), sp);
  };

  for (mlir::Operation *user : oldAlloc->getUsers()) {
    // look through fir.declare to find freemem users of the allocation
    if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
      for (mlir::Operation *user : declareOp->getUsers()) {
        if (mlir::isa<fir::FreeMemOp>(user))
          createStackRestoreCall(user);
      }
    }

    if (mlir::isa<fir::FreeMemOp>(user)) {
      createStackRestoreCall(user);
    }
  }

  rewriter.restoreInsertionPoint(oldPoint);
}
762 
// Copy construction delegates to the tablegen-generated base class.
StackArraysPass::StackArraysPass(const StackArraysPass &pass)
    : fir::impl::StackArraysBase<StackArraysPass>(pass) {}
765 
/// Human-readable description reported by the pass infrastructure.
llvm::StringRef StackArraysPass::getDescription() const {
  return "Move heap allocated array temporaries to the stack";
}
769 
/// Pass entry point: query the analysis for candidate allocations in this
/// function and greedily rewrite each one with AllocMemConversion.
void StackArraysPass::runOnOperation() {
  mlir::func::FuncOp func = getOperation();

  auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
  const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
      analysis.getCandidateOps(func);
  // nullptr signals that the dataflow analysis itself failed
  if (!candidateOps) {
    signalPassFailure();
    return;
  }

  if (candidateOps->empty())
    return;
  runCount += candidateOps->size();

  llvm::SmallVector<mlir::Operation *> opsToConvert;
  opsToConvert.reserve(candidateOps->size());
  for (auto [op, _] : *candidateOps)
    opsToConvert.push_back(op);

  mlir::MLIRContext &context = getContext();
  mlir::RewritePatternSet patterns(&context);
  mlir::GreedyRewriteConfig config;
  // prevent the pattern driver from merging blocks
  config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;

  patterns.insert<AllocMemConversion>(&context, *candidateOps);
  if (mlir::failed(mlir::applyOpPatternsGreedily(
          opsToConvert, std::move(patterns), config))) {
    mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
    signalPassFailure();
  }
}
803