xref: /llvm-project/flang/lib/Optimizer/Transforms/StackArrays.cpp (revision e3cd88a7be1dfd912bb6e7c7e888e7b442ffb5de)
1cc14bf22STom Eccles //===- StackArrays.cpp ----------------------------------------------------===//
2cc14bf22STom Eccles //
3cc14bf22STom Eccles // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4cc14bf22STom Eccles // See https://llvm.org/LICENSE.txt for license information.
5cc14bf22STom Eccles // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6cc14bf22STom Eccles //
7cc14bf22STom Eccles //===----------------------------------------------------------------------===//
8cc14bf22STom Eccles 
9cc14bf22STom Eccles #include "flang/Optimizer/Builder/FIRBuilder.h"
10cc14bf22STom Eccles #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
11cc14bf22STom Eccles #include "flang/Optimizer/Dialect/FIRAttr.h"
12cc14bf22STom Eccles #include "flang/Optimizer/Dialect/FIRDialect.h"
13cc14bf22STom Eccles #include "flang/Optimizer/Dialect/FIROps.h"
14cc14bf22STom Eccles #include "flang/Optimizer/Dialect/FIRType.h"
15b07ef9e7SRenaud-K #include "flang/Optimizer/Dialect/Support/FIRContext.h"
16cc14bf22STom Eccles #include "flang/Optimizer/Transforms/Passes.h"
17cc14bf22STom Eccles #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
18cc14bf22STom Eccles #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
19cc14bf22STom Eccles #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
20cc14bf22STom Eccles #include "mlir/Analysis/DataFlowFramework.h"
21cc14bf22STom Eccles #include "mlir/Dialect/Func/IR/FuncOps.h"
22cc14bf22STom Eccles #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
23cc14bf22STom Eccles #include "mlir/IR/Builders.h"
24cc14bf22STom Eccles #include "mlir/IR/Diagnostics.h"
25cc14bf22STom Eccles #include "mlir/IR/Value.h"
26cc14bf22STom Eccles #include "mlir/Interfaces/LoopLikeInterface.h"
27cc14bf22STom Eccles #include "mlir/Pass/Pass.h"
28408f4196STom Eccles #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29cc14bf22STom Eccles #include "mlir/Transforms/Passes.h"
30cc14bf22STom Eccles #include "llvm/ADT/DenseMap.h"
31cc14bf22STom Eccles #include "llvm/ADT/DenseSet.h"
32cc14bf22STom Eccles #include "llvm/ADT/PointerUnion.h"
33cc14bf22STom Eccles #include "llvm/Support/Casting.h"
34cc14bf22STom Eccles #include "llvm/Support/raw_ostream.h"
35cc14bf22STom Eccles #include <optional>
36cc14bf22STom Eccles 
37cc14bf22STom Eccles namespace fir {
38cc14bf22STom Eccles #define GEN_PASS_DEF_STACKARRAYS
39cc14bf22STom Eccles #include "flang/Optimizer/Transforms/Passes.h.inc"
40cc14bf22STom Eccles } // namespace fir
41cc14bf22STom Eccles 
42cc14bf22STom Eccles #define DEBUG_TYPE "stack-arrays"
43cc14bf22STom Eccles 
44e2153241STom Eccles static llvm::cl::opt<std::size_t> maxAllocsPerFunc(
45e2153241STom Eccles     "stack-arrays-max-allocs",
46e2153241STom Eccles     llvm::cl::desc("The maximum number of heap allocations to consider in one "
47e2153241STom Eccles                    "function before skipping (to save compilation time). Set "
48e2153241STom Eccles                    "to 0 for no limit."),
49e2153241STom Eccles     llvm::cl::init(1000), llvm::cl::Hidden);
50e2153241STom Eccles 
51cc14bf22STom Eccles namespace {
52cc14bf22STom Eccles 
/// Allocation state recorded for an SSA value at a given program point.
///
/// Having no entry at all for a value (an "unknown unknown") is different from
/// having an explicit Unknown entry (a "known unknown").
enum class AllocationState {
  /// The state cannot be decided here, e.g. because the value was freed along
  /// one incoming control-flow path but not along another.
  Unknown,
  /// The value was heap-allocated in this function and has since been freed.
  Freed,
  /// The value was heap-allocated in this function and has not been freed at
  /// this point; it is a candidate for being moved to the stack.
  Allocated,
};
68cc14bf22STom Eccles 
69cc14bf22STom Eccles /// Stores where an alloca should be inserted. If the PointerUnion is an
70cc14bf22STom Eccles /// Operation the alloca should be inserted /after/ the operation. If it is a
71cc14bf22STom Eccles /// block, the alloca can be placed anywhere in that block.
72cc14bf22STom Eccles class InsertionPoint {
73cc14bf22STom Eccles   llvm::PointerUnion<mlir::Operation *, mlir::Block *> location;
74cc14bf22STom Eccles   bool saveRestoreStack;
75cc14bf22STom Eccles 
76cc14bf22STom Eccles   /// Get contained pointer type or nullptr
77cc14bf22STom Eccles   template <class T>
78cc14bf22STom Eccles   T *tryGetPtr() const {
79392651a7SKazu Hirata     // Use llvm::dyn_cast_if_present because location may be null here.
80392651a7SKazu Hirata     if (T *ptr = llvm::dyn_cast_if_present<T *>(location))
81392651a7SKazu Hirata       return ptr;
82cc14bf22STom Eccles     return nullptr;
83cc14bf22STom Eccles   }
84cc14bf22STom Eccles 
85cc14bf22STom Eccles public:
86cc14bf22STom Eccles   template <class T>
87cc14bf22STom Eccles   InsertionPoint(T *ptr, bool saveRestoreStack = false)
88cc14bf22STom Eccles       : location(ptr), saveRestoreStack{saveRestoreStack} {}
89cc14bf22STom Eccles   InsertionPoint(std::nullptr_t null)
90cc14bf22STom Eccles       : location(null), saveRestoreStack{false} {}
91cc14bf22STom Eccles 
92cc14bf22STom Eccles   /// Get contained operation, or nullptr
93cc14bf22STom Eccles   mlir::Operation *tryGetOperation() const {
94cc14bf22STom Eccles     return tryGetPtr<mlir::Operation>();
95cc14bf22STom Eccles   }
96cc14bf22STom Eccles 
97cc14bf22STom Eccles   /// Get contained block, or nullptr
98cc14bf22STom Eccles   mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); }
99cc14bf22STom Eccles 
100cc14bf22STom Eccles   /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave
101cc14bf22STom Eccles   /// intrinsic should be added before the alloca, and an llvm.stackrestore
102cc14bf22STom Eccles   /// intrinsic should be added where the freemem is
103cc14bf22STom Eccles   bool shouldSaveRestoreStack() const { return saveRestoreStack; }
104cc14bf22STom Eccles 
105cc14bf22STom Eccles   operator bool() const { return tryGetOperation() || tryGetBlock(); }
106cc14bf22STom Eccles 
107cc14bf22STom Eccles   bool operator==(const InsertionPoint &rhs) const {
108cc14bf22STom Eccles     return (location == rhs.location) &&
109cc14bf22STom Eccles            (saveRestoreStack == rhs.saveRestoreStack);
110cc14bf22STom Eccles   }
111cc14bf22STom Eccles 
112cc14bf22STom Eccles   bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); }
113cc14bf22STom Eccles };
114cc14bf22STom Eccles 
115cc14bf22STom Eccles /// Maps SSA values to their AllocationState at a particular program point.
116cc14bf22STom Eccles /// Also caches the insertion points for the new alloca operations
117cc14bf22STom Eccles class LatticePoint : public mlir::dataflow::AbstractDenseLattice {
118cc14bf22STom Eccles   // Maps all values we are interested in to states
119cc14bf22STom Eccles   llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap;
120cc14bf22STom Eccles 
121cc14bf22STom Eccles public:
122cc14bf22STom Eccles   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint)
123cc14bf22STom Eccles   using AbstractDenseLattice::AbstractDenseLattice;
124cc14bf22STom Eccles 
125cc14bf22STom Eccles   bool operator==(const LatticePoint &rhs) const {
126cc14bf22STom Eccles     return stateMap == rhs.stateMap;
127cc14bf22STom Eccles   }
128cc14bf22STom Eccles 
129cc14bf22STom Eccles   /// Join the lattice accross control-flow edges
130cc14bf22STom Eccles   mlir::ChangeResult join(const AbstractDenseLattice &lattice) override;
131cc14bf22STom Eccles 
132cc14bf22STom Eccles   void print(llvm::raw_ostream &os) const override;
133cc14bf22STom Eccles 
134cc14bf22STom Eccles   /// Clear all modifications
135cc14bf22STom Eccles   mlir::ChangeResult reset();
136cc14bf22STom Eccles 
137cc14bf22STom Eccles   /// Set the state of an SSA value
138cc14bf22STom Eccles   mlir::ChangeResult set(mlir::Value value, AllocationState state);
139cc14bf22STom Eccles 
140cc14bf22STom Eccles   /// Get fir.allocmem ops which were allocated in this function and always
141cc14bf22STom Eccles   /// freed before the function returns, plus whre to insert replacement
142cc14bf22STom Eccles   /// fir.alloca ops
143cc14bf22STom Eccles   void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const;
144cc14bf22STom Eccles 
145cc14bf22STom Eccles   std::optional<AllocationState> get(mlir::Value val) const;
146cc14bf22STom Eccles };
147cc14bf22STom Eccles 
148cc14bf22STom Eccles class AllocationAnalysis
149b2b7efb9SAlex Zinenko     : public mlir::dataflow::DenseForwardDataFlowAnalysis<LatticePoint> {
150cc14bf22STom Eccles public:
151b2b7efb9SAlex Zinenko   using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;
152cc14bf22STom Eccles 
15315e915a4SIvan Butygin   mlir::LogicalResult visitOperation(mlir::Operation *op,
15415e915a4SIvan Butygin                                      const LatticePoint &before,
155cc14bf22STom Eccles                                      LatticePoint *after) override;
156cc14bf22STom Eccles 
157cc14bf22STom Eccles   /// At an entry point, the last modifications of all memory resources are
158cc14bf22STom Eccles   /// yet to be determined
159cc14bf22STom Eccles   void setToEntryState(LatticePoint *lattice) override;
160cc14bf22STom Eccles 
161cc14bf22STom Eccles protected:
162cc14bf22STom Eccles   /// Visit control flow operations and decide whether to call visitOperation
163cc14bf22STom Eccles   /// to apply the transfer function
16415e915a4SIvan Butygin   mlir::LogicalResult processOperation(mlir::Operation *op) override;
165cc14bf22STom Eccles };
166cc14bf22STom Eccles 
167cc14bf22STom Eccles /// Drives analysis to find candidate fir.allocmem operations which could be
168cc14bf22STom Eccles /// moved to the stack. Intended to be used with mlir::Pass::getAnalysis
169cc14bf22STom Eccles class StackArraysAnalysisWrapper {
170cc14bf22STom Eccles public:
171cc14bf22STom Eccles   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper)
172cc14bf22STom Eccles 
173cc14bf22STom Eccles   // Maps fir.allocmem -> place to insert alloca
174cc14bf22STom Eccles   using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>;
175cc14bf22STom Eccles 
176cc14bf22STom Eccles   StackArraysAnalysisWrapper(mlir::Operation *op) {}
177cc14bf22STom Eccles 
178408f4196STom Eccles   // returns nullptr if analysis failed
179408f4196STom Eccles   const AllocMemMap *getCandidateOps(mlir::Operation *func);
180cc14bf22STom Eccles 
181cc14bf22STom Eccles private:
182cc14bf22STom Eccles   llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;
183cc14bf22STom Eccles 
184db791b27SRamkumar Ramachandra   llvm::LogicalResult analyseFunction(mlir::Operation *func);
185cc14bf22STom Eccles };
186cc14bf22STom Eccles 
187cc14bf22STom Eccles /// Converts a fir.allocmem to a fir.alloca
188cc14bf22STom Eccles class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
189cc14bf22STom Eccles public:
190408f4196STom Eccles   explicit AllocMemConversion(
191cc14bf22STom Eccles       mlir::MLIRContext *ctx,
192408f4196STom Eccles       const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
193408f4196STom Eccles       : OpRewritePattern(ctx), candidateOps{candidateOps} {}
194cc14bf22STom Eccles 
195db791b27SRamkumar Ramachandra   llvm::LogicalResult
196cc14bf22STom Eccles   matchAndRewrite(fir::AllocMemOp allocmem,
197cc14bf22STom Eccles                   mlir::PatternRewriter &rewriter) const override;
198cc14bf22STom Eccles 
199cc14bf22STom Eccles   /// Determine where to insert the alloca operation. The returned value should
200cc14bf22STom Eccles   /// be checked to see if it is inside a loop
201*e3cd88a7SSlava Zakharin   static InsertionPoint
202*e3cd88a7SSlava Zakharin   findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc,
203*e3cd88a7SSlava Zakharin                            const llvm::SmallVector<mlir::Operation *> &freeOps);
204cc14bf22STom Eccles 
205cc14bf22STom Eccles private:
206408f4196STom Eccles   /// Handle to the DFA (already run)
207408f4196STom Eccles   const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;
208cc14bf22STom Eccles 
209cc14bf22STom Eccles   /// If we failed to find an insertion point not inside a loop, see if it would
210cc14bf22STom Eccles   /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
211*e3cd88a7SSlava Zakharin   static InsertionPoint findAllocaLoopInsertionPoint(
212*e3cd88a7SSlava Zakharin       fir::AllocMemOp &oldAlloc,
213*e3cd88a7SSlava Zakharin       const llvm::SmallVector<mlir::Operation *> &freeOps);
214cc14bf22STom Eccles 
215cc14bf22STom Eccles   /// Returns the alloca if it was successfully inserted, otherwise {}
216cc14bf22STom Eccles   std::optional<fir::AllocaOp>
217cc14bf22STom Eccles   insertAlloca(fir::AllocMemOp &oldAlloc,
218cc14bf22STom Eccles                mlir::PatternRewriter &rewriter) const;
219cc14bf22STom Eccles 
220cc14bf22STom Eccles   /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem
221cc14bf22STom Eccles   void insertStackSaveRestore(fir::AllocMemOp &oldAlloc,
222cc14bf22STom Eccles                               mlir::PatternRewriter &rewriter) const;
223cc14bf22STom Eccles };
224cc14bf22STom Eccles 
225cc14bf22STom Eccles class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> {
226cc14bf22STom Eccles public:
227cc14bf22STom Eccles   StackArraysPass() = default;
228cc14bf22STom Eccles   StackArraysPass(const StackArraysPass &pass);
229cc14bf22STom Eccles 
230cc14bf22STom Eccles   llvm::StringRef getDescription() const override;
231cc14bf22STom Eccles 
232cc14bf22STom Eccles   void runOnOperation() override;
233cc14bf22STom Eccles 
234cc14bf22STom Eccles private:
235cc14bf22STom Eccles   Statistic runCount{this, "stackArraysRunCount",
236cc14bf22STom Eccles                      "Number of heap allocations moved to the stack"};
237cc14bf22STom Eccles };
238cc14bf22STom Eccles 
239cc14bf22STom Eccles } // namespace
240cc14bf22STom Eccles 
241cc14bf22STom Eccles static void print(llvm::raw_ostream &os, AllocationState state) {
242cc14bf22STom Eccles   switch (state) {
243cc14bf22STom Eccles   case AllocationState::Unknown:
244cc14bf22STom Eccles     os << "Unknown";
245cc14bf22STom Eccles     break;
246cc14bf22STom Eccles   case AllocationState::Freed:
247cc14bf22STom Eccles     os << "Freed";
248cc14bf22STom Eccles     break;
249cc14bf22STom Eccles   case AllocationState::Allocated:
250cc14bf22STom Eccles     os << "Allocated";
251cc14bf22STom Eccles     break;
252cc14bf22STom Eccles   }
253cc14bf22STom Eccles }
254cc14bf22STom Eccles 
255cc14bf22STom Eccles /// Join two AllocationStates for the same value coming from different CFG
256cc14bf22STom Eccles /// blocks
257cc14bf22STom Eccles static AllocationState join(AllocationState lhs, AllocationState rhs) {
258cc14bf22STom Eccles   //           | Allocated | Freed     | Unknown
259cc14bf22STom Eccles   // ========= | ========= | ========= | =========
260cc14bf22STom Eccles   // Allocated | Allocated | Unknown   | Unknown
261cc14bf22STom Eccles   // Freed     | Unknown   | Freed     | Unknown
262cc14bf22STom Eccles   // Unknown   | Unknown   | Unknown   | Unknown
263cc14bf22STom Eccles   if (lhs == rhs)
264cc14bf22STom Eccles     return lhs;
265cc14bf22STom Eccles   return AllocationState::Unknown;
266cc14bf22STom Eccles }
267cc14bf22STom Eccles 
268cc14bf22STom Eccles mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) {
269cc14bf22STom Eccles   const auto &rhs = static_cast<const LatticePoint &>(lattice);
270cc14bf22STom Eccles   mlir::ChangeResult changed = mlir::ChangeResult::NoChange;
271cc14bf22STom Eccles 
272cc14bf22STom Eccles   // add everything from rhs to map, handling cases where values are in both
273cc14bf22STom Eccles   for (const auto &[value, rhsState] : rhs.stateMap) {
274cc14bf22STom Eccles     auto it = stateMap.find(value);
275cc14bf22STom Eccles     if (it != stateMap.end()) {
276cc14bf22STom Eccles       // value is present in both maps
277cc14bf22STom Eccles       AllocationState myState = it->second;
278cc14bf22STom Eccles       AllocationState newState = ::join(myState, rhsState);
279cc14bf22STom Eccles       if (newState != myState) {
280cc14bf22STom Eccles         changed = mlir::ChangeResult::Change;
281cc14bf22STom Eccles         it->getSecond() = newState;
282cc14bf22STom Eccles       }
283cc14bf22STom Eccles     } else {
284cc14bf22STom Eccles       // value not present in current map: add it
285cc14bf22STom Eccles       stateMap.insert({value, rhsState});
286cc14bf22STom Eccles       changed = mlir::ChangeResult::Change;
287cc14bf22STom Eccles     }
288cc14bf22STom Eccles   }
289cc14bf22STom Eccles 
290cc14bf22STom Eccles   return changed;
291cc14bf22STom Eccles }
292cc14bf22STom Eccles 
293cc14bf22STom Eccles void LatticePoint::print(llvm::raw_ostream &os) const {
294cc14bf22STom Eccles   for (const auto &[value, state] : stateMap) {
295464d321eSKareem Ergawy     os << "\n * " << value << ": ";
296cc14bf22STom Eccles     ::print(os, state);
297cc14bf22STom Eccles   }
298cc14bf22STom Eccles }
299cc14bf22STom Eccles 
300cc14bf22STom Eccles mlir::ChangeResult LatticePoint::reset() {
301cc14bf22STom Eccles   if (stateMap.empty())
302cc14bf22STom Eccles     return mlir::ChangeResult::NoChange;
303cc14bf22STom Eccles   stateMap.clear();
304cc14bf22STom Eccles   return mlir::ChangeResult::Change;
305cc14bf22STom Eccles }
306cc14bf22STom Eccles 
307cc14bf22STom Eccles mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) {
308cc14bf22STom Eccles   if (stateMap.count(value)) {
309cc14bf22STom Eccles     // already in map
310cc14bf22STom Eccles     AllocationState &oldState = stateMap[value];
311cc14bf22STom Eccles     if (oldState != state) {
312cc14bf22STom Eccles       stateMap[value] = state;
313cc14bf22STom Eccles       return mlir::ChangeResult::Change;
314cc14bf22STom Eccles     }
315cc14bf22STom Eccles     return mlir::ChangeResult::NoChange;
316cc14bf22STom Eccles   }
317cc14bf22STom Eccles   stateMap.insert({value, state});
318cc14bf22STom Eccles   return mlir::ChangeResult::Change;
319cc14bf22STom Eccles }
320cc14bf22STom Eccles 
321cc14bf22STom Eccles /// Get values which were allocated in this function and always freed before
322cc14bf22STom Eccles /// the function returns
323cc14bf22STom Eccles void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const {
324cc14bf22STom Eccles   for (auto &[value, state] : stateMap) {
325cc14bf22STom Eccles     if (state == AllocationState::Freed)
326cc14bf22STom Eccles       out.insert(value);
327cc14bf22STom Eccles   }
328cc14bf22STom Eccles }
329cc14bf22STom Eccles 
330cc14bf22STom Eccles std::optional<AllocationState> LatticePoint::get(mlir::Value val) const {
331cc14bf22STom Eccles   auto it = stateMap.find(val);
332cc14bf22STom Eccles   if (it == stateMap.end())
333cc14bf22STom Eccles     return {};
334cc14bf22STom Eccles   return it->second;
335cc14bf22STom Eccles }
336cc14bf22STom Eccles 
337303249c4STom Eccles static mlir::Value lookThroughDeclaresAndConverts(mlir::Value value) {
338303249c4STom Eccles   while (mlir::Operation *op = value.getDefiningOp()) {
339303249c4STom Eccles     if (auto declareOp = llvm::dyn_cast<fir::DeclareOp>(op))
340303249c4STom Eccles       value = declareOp.getMemref();
341303249c4STom Eccles     else if (auto convertOp = llvm::dyn_cast<fir::ConvertOp>(op))
342303249c4STom Eccles       value = convertOp->getOperand(0);
343303249c4STom Eccles     else
344303249c4STom Eccles       return value;
345303249c4STom Eccles   }
346303249c4STom Eccles   return value;
347303249c4STom Eccles }
348303249c4STom Eccles 
34915e915a4SIvan Butygin mlir::LogicalResult AllocationAnalysis::visitOperation(
35015e915a4SIvan Butygin     mlir::Operation *op, const LatticePoint &before, LatticePoint *after) {
351cc14bf22STom Eccles   LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
352cc14bf22STom Eccles                           << "\n");
353cc14bf22STom Eccles   LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n");
354cc14bf22STom Eccles 
355cc14bf22STom Eccles   // propagate before -> after
356cc14bf22STom Eccles   mlir::ChangeResult changed = after->join(before);
357cc14bf22STom Eccles 
358cc14bf22STom Eccles   if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) {
359cc14bf22STom Eccles     assert(op->getNumResults() == 1 && "fir.allocmem has one result");
360cc14bf22STom Eccles     auto attr = op->getAttrOfType<fir::MustBeHeapAttr>(
361cc14bf22STom Eccles         fir::MustBeHeapAttr::getAttrName());
362cc14bf22STom Eccles     if (attr && attr.getValue()) {
363cc14bf22STom Eccles       LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
364cc14bf22STom Eccles       // skip allocation marked not to be moved
36515e915a4SIvan Butygin       return mlir::success();
366cc14bf22STom Eccles     }
367cc14bf22STom Eccles 
368cc14bf22STom Eccles     auto retTy = allocmem.getAllocatedType();
369fac349a1SChristian Sigg     if (!mlir::isa<fir::SequenceType>(retTy)) {
370cc14bf22STom Eccles       LLVM_DEBUG(llvm::dbgs()
371cc14bf22STom Eccles                  << "--Allocation is not for an array: skipping\n");
37215e915a4SIvan Butygin       return mlir::success();
373cc14bf22STom Eccles     }
374cc14bf22STom Eccles 
375cc14bf22STom Eccles     mlir::Value result = op->getResult(0);
376cc14bf22STom Eccles     changed |= after->set(result, AllocationState::Allocated);
377cc14bf22STom Eccles   } else if (mlir::isa<fir::FreeMemOp>(op)) {
378cc14bf22STom Eccles     assert(op->getNumOperands() == 1 && "fir.freemem has one operand");
379cc14bf22STom Eccles     mlir::Value operand = op->getOperand(0);
380464d321eSKareem Ergawy 
381464d321eSKareem Ergawy     // Note: StackArrays is scheduled in the pass pipeline after lowering hlfir
382303249c4STom Eccles     // to fir. Therefore, we only need to handle `fir::DeclareOp`s. Also look
383303249c4STom Eccles     // past converts in case the pointer was changed between different pointer
384303249c4STom Eccles     // types.
385303249c4STom Eccles     operand = lookThroughDeclaresAndConverts(operand);
386464d321eSKareem Ergawy 
387cc14bf22STom Eccles     std::optional<AllocationState> operandState = before.get(operand);
388cc14bf22STom Eccles     if (operandState && *operandState == AllocationState::Allocated) {
389cc14bf22STom Eccles       // don't tag things not allocated in this function as freed, so that we
390cc14bf22STom Eccles       // don't think they are candidates for moving to the stack
391cc14bf22STom Eccles       changed |= after->set(operand, AllocationState::Freed);
392cc14bf22STom Eccles     }
393cc14bf22STom Eccles   } else if (mlir::isa<fir::ResultOp>(op)) {
394cc14bf22STom Eccles     mlir::Operation *parent = op->getParentOp();
3954b3f251bSdonald chen     LatticePoint *parentLattice = getLattice(getProgramPointAfter(parent));
396cc14bf22STom Eccles     assert(parentLattice);
397cc14bf22STom Eccles     mlir::ChangeResult parentChanged = parentLattice->join(*after);
398cc14bf22STom Eccles     propagateIfChanged(parentLattice, parentChanged);
399cc14bf22STom Eccles   }
400cc14bf22STom Eccles 
401cc14bf22STom Eccles   // we pass lattices straight through fir.call because called functions should
402cc14bf22STom Eccles   // not deallocate flang-generated array temporaries
403cc14bf22STom Eccles 
404cc14bf22STom Eccles   LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n");
405cc14bf22STom Eccles   propagateIfChanged(after, changed);
40615e915a4SIvan Butygin   return mlir::success();
407cc14bf22STom Eccles }
408cc14bf22STom Eccles 
409cc14bf22STom Eccles void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
410cc14bf22STom Eccles   propagateIfChanged(lattice, lattice->reset());
411cc14bf22STom Eccles }
412cc14bf22STom Eccles 
413cc14bf22STom Eccles /// Mostly a copy of AbstractDenseLattice::processOperation - the difference
414cc14bf22STom Eccles /// being that call operations are passed through to the transfer function
41515e915a4SIvan Butygin mlir::LogicalResult AllocationAnalysis::processOperation(mlir::Operation *op) {
4164b3f251bSdonald chen   mlir::ProgramPoint *point = getProgramPointAfter(op);
417cc14bf22STom Eccles   // If the containing block is not executable, bail out.
4184b3f251bSdonald chen   if (op->getBlock() != nullptr &&
4194b3f251bSdonald chen       !getOrCreateFor<mlir::dataflow::Executable>(
4204b3f251bSdonald chen            point, getProgramPointBefore(op->getBlock()))
4214b3f251bSdonald chen            ->isLive())
42215e915a4SIvan Butygin     return mlir::success();
423cc14bf22STom Eccles 
424cc14bf22STom Eccles   // Get the dense lattice to update
4254b3f251bSdonald chen   mlir::dataflow::AbstractDenseLattice *after = getLattice(point);
426cc14bf22STom Eccles 
427cc14bf22STom Eccles   // If this op implements region control-flow, then control-flow dictates its
428cc14bf22STom Eccles   // transfer function.
42915e915a4SIvan Butygin   if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
4304b3f251bSdonald chen     visitRegionBranchOperation(point, branch, after);
43115e915a4SIvan Butygin     return mlir::success();
43215e915a4SIvan Butygin   }
433cc14bf22STom Eccles 
434cc14bf22STom Eccles   // pass call operations through to the transfer function
435cc14bf22STom Eccles 
436cc14bf22STom Eccles   // Get the dense state before the execution of the op.
4374b3f251bSdonald chen   const mlir::dataflow::AbstractDenseLattice *before =
4384b3f251bSdonald chen       getLatticeFor(point, getProgramPointBefore(op));
439cc14bf22STom Eccles 
440cc14bf22STom Eccles   /// Invoke the operation transfer function
44115e915a4SIvan Butygin   return visitOperationImpl(op, *before, after);
442cc14bf22STom Eccles }
443cc14bf22STom Eccles 
444db791b27SRamkumar Ramachandra llvm::LogicalResult
445408f4196STom Eccles StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
446cc14bf22STom Eccles   assert(mlir::isa<mlir::func::FuncOp>(func));
447e2153241STom Eccles   size_t nAllocs = 0;
448e2153241STom Eccles   func->walk([&nAllocs](fir::AllocMemOp) { nAllocs++; });
449e2153241STom Eccles   // don't bother with the analysis if there are no heap allocations
450e2153241STom Eccles   if (nAllocs == 0)
451e2153241STom Eccles     return mlir::success();
452e2153241STom Eccles   if ((maxAllocsPerFunc != 0) && (nAllocs > maxAllocsPerFunc)) {
453e2153241STom Eccles     LLVM_DEBUG(llvm::dbgs() << "Skipping stack arrays for function with "
454e2153241STom Eccles                             << nAllocs << " heap allocations");
455e2153241STom Eccles     return mlir::success();
456e2153241STom Eccles   }
457e2153241STom Eccles 
458cc14bf22STom Eccles   mlir::DataFlowSolver solver;
459cc14bf22STom Eccles   // constant propagation is required for dead code analysis, dead code analysis
460cc14bf22STom Eccles   // is required to mark blocks live (required for mlir dense dfa)
461cc14bf22STom Eccles   solver.load<mlir::dataflow::SparseConstantPropagation>();
462cc14bf22STom Eccles   solver.load<mlir::dataflow::DeadCodeAnalysis>();
463cc14bf22STom Eccles 
464cc14bf22STom Eccles   auto [it, inserted] = funcMaps.try_emplace(func);
465cc14bf22STom Eccles   AllocMemMap &candidateOps = it->second;
466cc14bf22STom Eccles 
467cc14bf22STom Eccles   solver.load<AllocationAnalysis>();
468cc14bf22STom Eccles   if (failed(solver.initializeAndRun(func))) {
469cc14bf22STom Eccles     llvm::errs() << "DataFlowSolver failed!";
470408f4196STom Eccles     return mlir::failure();
471cc14bf22STom Eccles   }
472cc14bf22STom Eccles 
4734b3f251bSdonald chen   LatticePoint point{solver.getProgramPointAfter(func)};
4747a49d50fSTom Eccles   auto joinOperationLattice = [&](mlir::Operation *op) {
4754b3f251bSdonald chen     const LatticePoint *lattice =
4764b3f251bSdonald chen         solver.lookupState<LatticePoint>(solver.getProgramPointAfter(op));
477cc14bf22STom Eccles     // there will be no lattice for an unreachable block
478cc14bf22STom Eccles     if (lattice)
479c50de57fSKazu Hirata       (void)point.join(*lattice);
4807a49d50fSTom Eccles   };
481698b42ccSKareem Ergawy 
4827a49d50fSTom Eccles   func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); });
4837a49d50fSTom Eccles   func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); });
484464d321eSKareem Ergawy   func->walk(
485464d321eSKareem Ergawy       [&](mlir::omp::TerminatorOp child) { joinOperationLattice(child); });
486698b42ccSKareem Ergawy   func->walk([&](mlir::omp::YieldOp child) { joinOperationLattice(child); });
487464d321eSKareem Ergawy 
488cc14bf22STom Eccles   llvm::DenseSet<mlir::Value> freedValues;
489cc14bf22STom Eccles   point.appendFreedValues(freedValues);
490cc14bf22STom Eccles 
491*e3cd88a7SSlava Zakharin   // Find all fir.freemem operations corresponding to fir.allocmem
492*e3cd88a7SSlava Zakharin   // in freedValues. It is best to find the association going back
493*e3cd88a7SSlava Zakharin   // from fir.freemem to fir.allocmem through the def-use chains,
494*e3cd88a7SSlava Zakharin   // so that we can use lookThroughDeclaresAndConverts same way
495*e3cd88a7SSlava Zakharin   // the AllocationAnalysis is handling them.
496*e3cd88a7SSlava Zakharin   llvm::DenseMap<mlir::Operation *, llvm::SmallVector<mlir::Operation *>>
497*e3cd88a7SSlava Zakharin       allocToFreeMemMap;
498*e3cd88a7SSlava Zakharin   func->walk([&](fir::FreeMemOp freeOp) {
499*e3cd88a7SSlava Zakharin     mlir::Value memref = lookThroughDeclaresAndConverts(freeOp.getHeapref());
500*e3cd88a7SSlava Zakharin     if (!freedValues.count(memref))
501*e3cd88a7SSlava Zakharin       return;
502*e3cd88a7SSlava Zakharin 
503*e3cd88a7SSlava Zakharin     auto allocMem = memref.getDefiningOp<fir::AllocMemOp>();
504*e3cd88a7SSlava Zakharin     allocToFreeMemMap[allocMem].push_back(freeOp);
505*e3cd88a7SSlava Zakharin   });
506*e3cd88a7SSlava Zakharin 
507cc14bf22STom Eccles   // We only replace allocations which are definately freed on all routes
508cc14bf22STom Eccles   // through the function because otherwise the allocation may have an intende
509cc14bf22STom Eccles   // lifetime longer than the current stack frame (e.g. a heap allocation which
510cc14bf22STom Eccles   // is then freed by another function).
511cc14bf22STom Eccles   for (mlir::Value freedValue : freedValues) {
512cc14bf22STom Eccles     fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>();
513cc14bf22STom Eccles     InsertionPoint insertionPoint =
514*e3cd88a7SSlava Zakharin         AllocMemConversion::findAllocaInsertionPoint(
515*e3cd88a7SSlava Zakharin             allocmem, allocToFreeMemMap[allocmem]);
516cc14bf22STom Eccles     if (insertionPoint)
517cc14bf22STom Eccles       candidateOps.insert({allocmem, insertionPoint});
518cc14bf22STom Eccles   }
519cc14bf22STom Eccles 
520cc14bf22STom Eccles   LLVM_DEBUG(for (auto [allocMemOp, _]
521cc14bf22STom Eccles                   : candidateOps) {
522cc14bf22STom Eccles     llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
523cc14bf22STom Eccles   });
524408f4196STom Eccles   return mlir::success();
525cc14bf22STom Eccles }
526cc14bf22STom Eccles 
527408f4196STom Eccles const StackArraysAnalysisWrapper::AllocMemMap *
528cc14bf22STom Eccles StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
529408f4196STom Eccles   if (!funcMaps.contains(func))
530408f4196STom Eccles     if (mlir::failed(analyseFunction(func)))
531408f4196STom Eccles       return nullptr;
532408f4196STom Eccles   return &funcMaps[func];
533cc14bf22STom Eccles }
534cc14bf22STom Eccles 
535775de675STom Eccles /// Restore the old allocation type exected by existing code
536775de675STom Eccles static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter,
537775de675STom Eccles                                          const mlir::Location &loc,
538775de675STom Eccles                                          mlir::Value heap, mlir::Value stack) {
539775de675STom Eccles   mlir::Type heapTy = heap.getType();
540775de675STom Eccles   mlir::Type stackTy = stack.getType();
541775de675STom Eccles 
542775de675STom Eccles   if (heapTy == stackTy)
543775de675STom Eccles     return stack;
544775de675STom Eccles 
545775de675STom Eccles   fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy);
54676c3c5bcSTom Eccles   LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy =
54776c3c5bcSTom Eccles       mlir::cast<fir::ReferenceType>(stackTy);
548775de675STom Eccles   assert(firHeapTy.getElementType() == firRefTy.getElementType() &&
549775de675STom Eccles          "Allocations must have the same type");
550775de675STom Eccles 
551775de675STom Eccles   auto insertionPoint = rewriter.saveInsertionPoint();
552775de675STom Eccles   rewriter.setInsertionPointAfter(stack.getDefiningOp());
553775de675STom Eccles   mlir::Value conv =
554775de675STom Eccles       rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult();
555775de675STom Eccles   rewriter.restoreInsertionPoint(insertionPoint);
556775de675STom Eccles   return conv;
557775de675STom Eccles }
558775de675STom Eccles 
559db791b27SRamkumar Ramachandra llvm::LogicalResult
560cc14bf22STom Eccles AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
561cc14bf22STom Eccles                                     mlir::PatternRewriter &rewriter) const {
562cc14bf22STom Eccles   auto oldInsertionPt = rewriter.saveInsertionPoint();
563cc14bf22STom Eccles   // add alloca operation
564cc14bf22STom Eccles   std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter);
565cc14bf22STom Eccles   rewriter.restoreInsertionPoint(oldInsertionPt);
566cc14bf22STom Eccles   if (!alloca)
567cc14bf22STom Eccles     return mlir::failure();
568cc14bf22STom Eccles 
569cc14bf22STom Eccles   // remove freemem operations
570408f4196STom Eccles   llvm::SmallVector<mlir::Operation *> erases;
571303249c4STom Eccles   mlir::Operation *parent = allocmem->getParentOp();
572303249c4STom Eccles   // TODO: this shouldn't need to be re-calculated for every allocmem
573303249c4STom Eccles   parent->walk([&](fir::FreeMemOp freeOp) {
574303249c4STom Eccles     if (lookThroughDeclaresAndConverts(freeOp->getOperand(0)) == allocmem)
575303249c4STom Eccles       erases.push_back(freeOp);
576303249c4STom Eccles   });
577464d321eSKareem Ergawy 
578408f4196STom Eccles   // now we are done iterating the users, it is safe to mutate them
579408f4196STom Eccles   for (mlir::Operation *erase : erases)
580408f4196STom Eccles     rewriter.eraseOp(erase);
581cc14bf22STom Eccles 
582cc14bf22STom Eccles   // replace references to heap allocation with references to stack allocation
583775de675STom Eccles   mlir::Value newValue = convertAllocationType(
584775de675STom Eccles       rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult());
585775de675STom Eccles   rewriter.replaceAllUsesWith(allocmem.getResult(), newValue);
586cc14bf22STom Eccles 
587cc14bf22STom Eccles   // remove allocmem operation
588cc14bf22STom Eccles   rewriter.eraseOp(allocmem.getOperation());
589cc14bf22STom Eccles 
590cc14bf22STom Eccles   return mlir::success();
591cc14bf22STom Eccles }
592cc14bf22STom Eccles 
593cc14bf22STom Eccles static bool isInLoop(mlir::Block *block) {
594d5ea1b22STom Eccles   return mlir::LoopLikeOpInterface::blockIsInLoop(block);
595cc14bf22STom Eccles }
596cc14bf22STom Eccles 
597cc14bf22STom Eccles static bool isInLoop(mlir::Operation *op) {
598cc14bf22STom Eccles   return isInLoop(op->getBlock()) ||
599cc14bf22STom Eccles          op->getParentOfType<mlir::LoopLikeOpInterface>();
600cc14bf22STom Eccles }
601cc14bf22STom Eccles 
602*e3cd88a7SSlava Zakharin InsertionPoint AllocMemConversion::findAllocaInsertionPoint(
603*e3cd88a7SSlava Zakharin     fir::AllocMemOp &oldAlloc,
604*e3cd88a7SSlava Zakharin     const llvm::SmallVector<mlir::Operation *> &freeOps) {
605cc14bf22STom Eccles   // Ideally the alloca should be inserted at the end of the function entry
606cc14bf22STom Eccles   // block so that we do not allocate stack space in a loop. However,
607cc14bf22STom Eccles   // the operands to the alloca may not be available that early, so insert it
608cc14bf22STom Eccles   // after the last operand becomes available
609cc14bf22STom Eccles   // If the old allocmem op was in an openmp region then it should not be moved
610cc14bf22STom Eccles   // outside of that
611cc14bf22STom Eccles   LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: "
612cc14bf22STom Eccles                           << oldAlloc << "\n");
613cc14bf22STom Eccles 
614cc14bf22STom Eccles   // check that an Operation or Block we are about to return is not in a loop
615cc14bf22STom Eccles   auto checkReturn = [&](auto *point) -> InsertionPoint {
616cc14bf22STom Eccles     if (isInLoop(point)) {
617cc14bf22STom Eccles       mlir::Operation *oldAllocOp = oldAlloc.getOperation();
618cc14bf22STom Eccles       if (isInLoop(oldAllocOp)) {
619cc14bf22STom Eccles         // where we want to put it is in a loop, and even the old location is in
620cc14bf22STom Eccles         // a loop. Give up.
621*e3cd88a7SSlava Zakharin         return findAllocaLoopInsertionPoint(oldAlloc, freeOps);
622cc14bf22STom Eccles       }
623cc14bf22STom Eccles       return {oldAllocOp};
624cc14bf22STom Eccles     }
625cc14bf22STom Eccles     return {point};
626cc14bf22STom Eccles   };
627cc14bf22STom Eccles 
628cc14bf22STom Eccles   auto oldOmpRegion =
629cc14bf22STom Eccles       oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
630cc14bf22STom Eccles 
631cc14bf22STom Eccles   // Find when the last operand value becomes available
632cc14bf22STom Eccles   mlir::Block *operandsBlock = nullptr;
633cc14bf22STom Eccles   mlir::Operation *lastOperand = nullptr;
634cc14bf22STom Eccles   for (mlir::Value operand : oldAlloc.getOperands()) {
635cc14bf22STom Eccles     LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n");
636cc14bf22STom Eccles     mlir::Operation *op = operand.getDefiningOp();
637cc14bf22STom Eccles     if (!op)
638cc14bf22STom Eccles       return checkReturn(oldAlloc.getOperation());
639cc14bf22STom Eccles     if (!operandsBlock)
640cc14bf22STom Eccles       operandsBlock = op->getBlock();
641cc14bf22STom Eccles     else if (operandsBlock != op->getBlock()) {
642cc14bf22STom Eccles       LLVM_DEBUG(llvm::dbgs()
643cc14bf22STom Eccles                  << "----operand declared in a different block!\n");
644cc14bf22STom Eccles       // Operation::isBeforeInBlock requires the operations to be in the same
645cc14bf22STom Eccles       // block. The best we can do is the location of the allocmem.
646cc14bf22STom Eccles       return checkReturn(oldAlloc.getOperation());
647cc14bf22STom Eccles     }
648cc14bf22STom Eccles     if (!lastOperand || lastOperand->isBeforeInBlock(op))
649cc14bf22STom Eccles       lastOperand = op;
650cc14bf22STom Eccles   }
651cc14bf22STom Eccles 
652cc14bf22STom Eccles   if (lastOperand) {
653cc14bf22STom Eccles     // there were value operands to the allocmem so insert after the last one
654cc14bf22STom Eccles     LLVM_DEBUG(llvm::dbgs()
655cc14bf22STom Eccles                << "--Placing after last operand: " << *lastOperand << "\n");
656cc14bf22STom Eccles     // check we aren't moving out of an omp region
657cc14bf22STom Eccles     auto lastOpOmpRegion =
658cc14bf22STom Eccles         lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
659cc14bf22STom Eccles     if (lastOpOmpRegion == oldOmpRegion)
660cc14bf22STom Eccles       return checkReturn(lastOperand);
661cc14bf22STom Eccles     // Presumably this happened because the operands became ready before the
662cc14bf22STom Eccles     // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should
663cc14bf22STom Eccles     // imply that oldOmpRegion comes after lastOpOmpRegion.
664cc14bf22STom Eccles     return checkReturn(oldOmpRegion.getAllocaBlock());
665cc14bf22STom Eccles   }
666cc14bf22STom Eccles 
667cc14bf22STom Eccles   // There were no value operands to the allocmem so we are safe to insert it
668cc14bf22STom Eccles   // as early as we want
669cc14bf22STom Eccles 
670cc14bf22STom Eccles   // handle openmp case
671cc14bf22STom Eccles   if (oldOmpRegion)
672cc14bf22STom Eccles     return checkReturn(oldOmpRegion.getAllocaBlock());
673cc14bf22STom Eccles 
674cc14bf22STom Eccles   // fall back to the function entry block
675cc14bf22STom Eccles   mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>();
676cc14bf22STom Eccles   assert(func && "This analysis is run on func.func");
677cc14bf22STom Eccles   mlir::Block &entryBlock = func.getBlocks().front();
678cc14bf22STom Eccles   LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n");
679cc14bf22STom Eccles   return checkReturn(&entryBlock);
680cc14bf22STom Eccles }
681cc14bf22STom Eccles 
682*e3cd88a7SSlava Zakharin InsertionPoint AllocMemConversion::findAllocaLoopInsertionPoint(
683*e3cd88a7SSlava Zakharin     fir::AllocMemOp &oldAlloc,
684*e3cd88a7SSlava Zakharin     const llvm::SmallVector<mlir::Operation *> &freeOps) {
685cc14bf22STom Eccles   mlir::Operation *oldAllocOp = oldAlloc;
686cc14bf22STom Eccles   // This is only called as a last resort. We should try to insert at the
687cc14bf22STom Eccles   // location of the old allocation, which is inside of a loop, using
688cc14bf22STom Eccles   // llvm.stacksave/llvm.stackrestore
689cc14bf22STom Eccles 
690cc14bf22STom Eccles   assert(freeOps.size() && "DFA should only return freed memory");
691cc14bf22STom Eccles 
692cc14bf22STom Eccles   // Don't attempt to reason about a stacksave/stackrestore between different
693cc14bf22STom Eccles   // blocks
694cc14bf22STom Eccles   for (mlir::Operation *free : freeOps)
695cc14bf22STom Eccles     if (free->getBlock() != oldAllocOp->getBlock())
696cc14bf22STom Eccles       return {nullptr};
697cc14bf22STom Eccles 
698cc14bf22STom Eccles   // Check that there aren't any other stack allocations in between the
699cc14bf22STom Eccles   // stack save and stack restore
700cc14bf22STom Eccles   // note: for flang generated temporaries there should only be one free op
701cc14bf22STom Eccles   for (mlir::Operation *free : freeOps) {
702cc14bf22STom Eccles     for (mlir::Operation *op = oldAlloc; op && op != free;
703cc14bf22STom Eccles          op = op->getNextNode()) {
704cc14bf22STom Eccles       if (mlir::isa<fir::AllocaOp>(op))
705cc14bf22STom Eccles         return {nullptr};
706cc14bf22STom Eccles     }
707cc14bf22STom Eccles   }
708cc14bf22STom Eccles 
709cc14bf22STom Eccles   return InsertionPoint{oldAllocOp, /*shouldStackSaveRestore=*/true};
710cc14bf22STom Eccles }
711cc14bf22STom Eccles 
712cc14bf22STom Eccles std::optional<fir::AllocaOp>
713cc14bf22STom Eccles AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc,
714cc14bf22STom Eccles                                  mlir::PatternRewriter &rewriter) const {
715cc14bf22STom Eccles   auto it = candidateOps.find(oldAlloc.getOperation());
716cc14bf22STom Eccles   if (it == candidateOps.end())
717cc14bf22STom Eccles     return {};
718cc14bf22STom Eccles   InsertionPoint insertionPoint = it->second;
719cc14bf22STom Eccles   if (!insertionPoint)
720cc14bf22STom Eccles     return {};
721cc14bf22STom Eccles 
722cc14bf22STom Eccles   if (insertionPoint.shouldSaveRestoreStack())
723cc14bf22STom Eccles     insertStackSaveRestore(oldAlloc, rewriter);
724cc14bf22STom Eccles 
725cc14bf22STom Eccles   mlir::Location loc = oldAlloc.getLoc();
726cc14bf22STom Eccles   mlir::Type varTy = oldAlloc.getInType();
727cc14bf22STom Eccles   if (mlir::Operation *op = insertionPoint.tryGetOperation()) {
728cc14bf22STom Eccles     rewriter.setInsertionPointAfter(op);
729cc14bf22STom Eccles   } else {
730cc14bf22STom Eccles     mlir::Block *block = insertionPoint.tryGetBlock();
731cc14bf22STom Eccles     assert(block && "There must be a valid insertion point");
732cc14bf22STom Eccles     rewriter.setInsertionPointToStart(block);
733cc14bf22STom Eccles   }
734cc14bf22STom Eccles 
735cc14bf22STom Eccles   auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef {
736cc14bf22STom Eccles     if (opt)
737cc14bf22STom Eccles       return *opt;
738cc14bf22STom Eccles     return {};
739cc14bf22STom Eccles   };
740cc14bf22STom Eccles 
741cc14bf22STom Eccles   llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName());
742cc14bf22STom Eccles   llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName());
743cc14bf22STom Eccles   return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName,
744cc14bf22STom Eccles                                         oldAlloc.getTypeparams(),
745cc14bf22STom Eccles                                         oldAlloc.getShape());
746cc14bf22STom Eccles }
747cc14bf22STom Eccles 
748cc14bf22STom Eccles void AllocMemConversion::insertStackSaveRestore(
749cc14bf22STom Eccles     fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const {
750cc14bf22STom Eccles   auto oldPoint = rewriter.saveInsertionPoint();
751cc14bf22STom Eccles   auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>();
75253cc33b0STom Eccles   fir::FirOpBuilder builder{rewriter, mod};
753cc14bf22STom Eccles 
754cc14bf22STom Eccles   builder.setInsertionPoint(oldAlloc);
7555aaf384bSTom Eccles   mlir::Value sp = builder.genStackSave(oldAlloc.getLoc());
756cc14bf22STom Eccles 
757464d321eSKareem Ergawy   auto createStackRestoreCall = [&](mlir::Operation *user) {
758cc14bf22STom Eccles     builder.setInsertionPoint(user);
7595aaf384bSTom Eccles     builder.genStackRestore(user->getLoc(), sp);
760464d321eSKareem Ergawy   };
761464d321eSKareem Ergawy 
762464d321eSKareem Ergawy   for (mlir::Operation *user : oldAlloc->getUsers()) {
763464d321eSKareem Ergawy     if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
764464d321eSKareem Ergawy       for (mlir::Operation *user : declareOp->getUsers()) {
765464d321eSKareem Ergawy         if (mlir::isa<fir::FreeMemOp>(user))
766464d321eSKareem Ergawy           createStackRestoreCall(user);
767464d321eSKareem Ergawy       }
768464d321eSKareem Ergawy     }
769464d321eSKareem Ergawy 
770464d321eSKareem Ergawy     if (mlir::isa<fir::FreeMemOp>(user)) {
771464d321eSKareem Ergawy       createStackRestoreCall(user);
772cc14bf22STom Eccles     }
773cc14bf22STom Eccles   }
774cc14bf22STom Eccles 
775cc14bf22STom Eccles   rewriter.restoreInsertionPoint(oldPoint);
776cc14bf22STom Eccles }
777cc14bf22STom Eccles 
778cc14bf22STom Eccles StackArraysPass::StackArraysPass(const StackArraysPass &pass)
779cc14bf22STom Eccles     : fir::impl::StackArraysBase<StackArraysPass>(pass) {}
780cc14bf22STom Eccles 
781cc14bf22STom Eccles llvm::StringRef StackArraysPass::getDescription() const {
782cc14bf22STom Eccles   return "Move heap allocated array temporaries to the stack";
783cc14bf22STom Eccles }
784cc14bf22STom Eccles 
785cc14bf22STom Eccles void StackArraysPass::runOnOperation() {
7861e64864cSTom Eccles   mlir::func::FuncOp func = getOperation();
787cc14bf22STom Eccles 
788cc14bf22STom Eccles   auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
789408f4196STom Eccles   const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
790408f4196STom Eccles       analysis.getCandidateOps(func);
791408f4196STom Eccles   if (!candidateOps) {
792cc14bf22STom Eccles     signalPassFailure();
793cc14bf22STom Eccles     return;
794cc14bf22STom Eccles   }
795cc14bf22STom Eccles 
796408f4196STom Eccles   if (candidateOps->empty())
797cc14bf22STom Eccles     return;
798408f4196STom Eccles   runCount += candidateOps->size();
799408f4196STom Eccles 
800408f4196STom Eccles   llvm::SmallVector<mlir::Operation *> opsToConvert;
801408f4196STom Eccles   opsToConvert.reserve(candidateOps->size());
802408f4196STom Eccles   for (auto [op, _] : *candidateOps)
803408f4196STom Eccles     opsToConvert.push_back(op);
804cc14bf22STom Eccles 
805cc14bf22STom Eccles   mlir::MLIRContext &context = getContext();
806cc14bf22STom Eccles   mlir::RewritePatternSet patterns(&context);
807408f4196STom Eccles   mlir::GreedyRewriteConfig config;
808408f4196STom Eccles   // prevent the pattern driver form merging blocks
809a506279eSMehdi Amini   config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
810cc14bf22STom Eccles 
811408f4196STom Eccles   patterns.insert<AllocMemConversion>(&context, *candidateOps);
81209dfc571SJacques Pienaar   if (mlir::failed(mlir::applyOpPatternsGreedily(
81309dfc571SJacques Pienaar           opsToConvert, std::move(patterns), config))) {
814cc14bf22STom Eccles     mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
815cc14bf22STom Eccles     signalPassFailure();
816cc14bf22STom Eccles   }
817cc14bf22STom Eccles }
818