1cc14bf22STom Eccles //===- StackArrays.cpp ----------------------------------------------------===// 2cc14bf22STom Eccles // 3cc14bf22STom Eccles // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4cc14bf22STom Eccles // See https://llvm.org/LICENSE.txt for license information. 5cc14bf22STom Eccles // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6cc14bf22STom Eccles // 7cc14bf22STom Eccles //===----------------------------------------------------------------------===// 8cc14bf22STom Eccles 9cc14bf22STom Eccles #include "flang/Optimizer/Builder/FIRBuilder.h" 10cc14bf22STom Eccles #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" 11cc14bf22STom Eccles #include "flang/Optimizer/Dialect/FIRAttr.h" 12cc14bf22STom Eccles #include "flang/Optimizer/Dialect/FIRDialect.h" 13cc14bf22STom Eccles #include "flang/Optimizer/Dialect/FIROps.h" 14cc14bf22STom Eccles #include "flang/Optimizer/Dialect/FIRType.h" 15b07ef9e7SRenaud-K #include "flang/Optimizer/Dialect/Support/FIRContext.h" 16cc14bf22STom Eccles #include "flang/Optimizer/Transforms/Passes.h" 17cc14bf22STom Eccles #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 18cc14bf22STom Eccles #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 19cc14bf22STom Eccles #include "mlir/Analysis/DataFlow/DenseAnalysis.h" 20cc14bf22STom Eccles #include "mlir/Analysis/DataFlowFramework.h" 21cc14bf22STom Eccles #include "mlir/Dialect/Func/IR/FuncOps.h" 22cc14bf22STom Eccles #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 23cc14bf22STom Eccles #include "mlir/IR/Builders.h" 24cc14bf22STom Eccles #include "mlir/IR/Diagnostics.h" 25cc14bf22STom Eccles #include "mlir/IR/Value.h" 26cc14bf22STom Eccles #include "mlir/Interfaces/LoopLikeInterface.h" 27cc14bf22STom Eccles #include "mlir/Pass/Pass.h" 28408f4196STom Eccles #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 29cc14bf22STom Eccles #include "mlir/Transforms/Passes.h" 30cc14bf22STom Eccles #include "llvm/ADT/DenseMap.h" 31cc14bf22STom Eccles #include "llvm/ADT/DenseSet.h" 32cc14bf22STom Eccles #include "llvm/ADT/PointerUnion.h" 33cc14bf22STom Eccles #include "llvm/Support/Casting.h" 34cc14bf22STom Eccles #include "llvm/Support/raw_ostream.h" 35cc14bf22STom Eccles #include <optional> 36cc14bf22STom Eccles 37cc14bf22STom Eccles namespace fir { 38cc14bf22STom Eccles #define GEN_PASS_DEF_STACKARRAYS 39cc14bf22STom Eccles #include "flang/Optimizer/Transforms/Passes.h.inc" 40cc14bf22STom Eccles } // namespace fir 41cc14bf22STom Eccles 42cc14bf22STom Eccles #define DEBUG_TYPE "stack-arrays" 43cc14bf22STom Eccles 44e2153241STom Eccles static llvm::cl::opt<std::size_t> maxAllocsPerFunc( 45e2153241STom Eccles "stack-arrays-max-allocs", 46e2153241STom Eccles llvm::cl::desc("The maximum number of heap allocations to consider in one " 47e2153241STom Eccles "function before skipping (to save compilation time). Set " 48e2153241STom Eccles "to 0 for no limit."), 49e2153241STom Eccles llvm::cl::init(1000), llvm::cl::Hidden); 50e2153241STom Eccles 51cc14bf22STom Eccles namespace { 52cc14bf22STom Eccles 53cc14bf22STom Eccles /// The state of an SSA value at each program point 54cc14bf22STom Eccles enum class AllocationState { 55cc14bf22STom Eccles /// This means that the allocation state of a variable cannot be determined 56cc14bf22STom Eccles /// at this program point, e.g. because one route through a conditional freed 57cc14bf22STom Eccles /// the variable and the other route didn't. 58cc14bf22STom Eccles /// This asserts a known-unknown: different from the unknown-unknown of having 59cc14bf22STom Eccles /// no AllocationState stored for a particular SSA value 60cc14bf22STom Eccles Unknown, 61cc14bf22STom Eccles /// Means this SSA value was allocated on the heap in this function and has 62cc14bf22STom Eccles /// now been freed 63cc14bf22STom Eccles Freed, 64cc14bf22STom Eccles /// Means this SSA value was allocated on the heap in this function and is a 65cc14bf22STom Eccles /// candidate for moving to the stack 66cc14bf22STom Eccles Allocated, 67cc14bf22STom Eccles }; 68cc14bf22STom Eccles 69cc14bf22STom Eccles /// Stores where an alloca should be inserted. If the PointerUnion is an 70cc14bf22STom Eccles /// Operation the alloca should be inserted /after/ the operation. If it is a 71cc14bf22STom Eccles /// block, the alloca can be placed anywhere in that block. 72cc14bf22STom Eccles class InsertionPoint { 73cc14bf22STom Eccles llvm::PointerUnion<mlir::Operation *, mlir::Block *> location; 74cc14bf22STom Eccles bool saveRestoreStack; 75cc14bf22STom Eccles 76cc14bf22STom Eccles /// Get contained pointer type or nullptr 77cc14bf22STom Eccles template <class T> 78cc14bf22STom Eccles T *tryGetPtr() const { 79392651a7SKazu Hirata // Use llvm::dyn_cast_if_present because location may be null here. 80392651a7SKazu Hirata if (T *ptr = llvm::dyn_cast_if_present<T *>(location)) 81392651a7SKazu Hirata return ptr; 82cc14bf22STom Eccles return nullptr; 83cc14bf22STom Eccles } 84cc14bf22STom Eccles 85cc14bf22STom Eccles public: 86cc14bf22STom Eccles template <class T> 87cc14bf22STom Eccles InsertionPoint(T *ptr, bool saveRestoreStack = false) 88cc14bf22STom Eccles : location(ptr), saveRestoreStack{saveRestoreStack} {} 89cc14bf22STom Eccles InsertionPoint(std::nullptr_t null) 90cc14bf22STom Eccles : location(null), saveRestoreStack{false} {} 91cc14bf22STom Eccles 92cc14bf22STom Eccles /// Get contained operation, or nullptr 93cc14bf22STom Eccles mlir::Operation *tryGetOperation() const { 94cc14bf22STom Eccles return tryGetPtr<mlir::Operation>(); 95cc14bf22STom Eccles } 96cc14bf22STom Eccles 97cc14bf22STom Eccles /// Get contained block, or nullptr 98cc14bf22STom Eccles mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); } 99cc14bf22STom Eccles 100cc14bf22STom Eccles /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave 101cc14bf22STom Eccles /// intrinsic should be added before the alloca, and an llvm.stackrestore 102cc14bf22STom Eccles /// intrinsic should be added where the freemem is 103cc14bf22STom Eccles bool shouldSaveRestoreStack() const { return saveRestoreStack; } 104cc14bf22STom Eccles 105cc14bf22STom Eccles operator bool() const { return tryGetOperation() || tryGetBlock(); } 106cc14bf22STom Eccles 107cc14bf22STom Eccles bool operator==(const InsertionPoint &rhs) const { 108cc14bf22STom Eccles return (location == rhs.location) && 109cc14bf22STom Eccles (saveRestoreStack == rhs.saveRestoreStack); 110cc14bf22STom Eccles } 111cc14bf22STom Eccles 112cc14bf22STom Eccles bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); } 113cc14bf22STom Eccles }; 114cc14bf22STom Eccles 115cc14bf22STom Eccles /// Maps SSA values to their AllocationState at a particular program point. 116cc14bf22STom Eccles /// Also caches the insertion points for the new alloca operations 117cc14bf22STom Eccles class LatticePoint : public mlir::dataflow::AbstractDenseLattice { 118cc14bf22STom Eccles // Maps all values we are interested in to states 119cc14bf22STom Eccles llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap; 120cc14bf22STom Eccles 121cc14bf22STom Eccles public: 122cc14bf22STom Eccles MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint) 123cc14bf22STom Eccles using AbstractDenseLattice::AbstractDenseLattice; 124cc14bf22STom Eccles 125cc14bf22STom Eccles bool operator==(const LatticePoint &rhs) const { 126cc14bf22STom Eccles return stateMap == rhs.stateMap; 127cc14bf22STom Eccles } 128cc14bf22STom Eccles 129cc14bf22STom Eccles /// Join the lattice accross control-flow edges 130cc14bf22STom Eccles mlir::ChangeResult join(const AbstractDenseLattice &lattice) override; 131cc14bf22STom Eccles 132cc14bf22STom Eccles void print(llvm::raw_ostream &os) const override; 133cc14bf22STom Eccles 134cc14bf22STom Eccles /// Clear all modifications 135cc14bf22STom Eccles mlir::ChangeResult reset(); 136cc14bf22STom Eccles 137cc14bf22STom Eccles /// Set the state of an SSA value 138cc14bf22STom Eccles mlir::ChangeResult set(mlir::Value value, AllocationState state); 139cc14bf22STom Eccles 140cc14bf22STom Eccles /// Get fir.allocmem ops which were allocated in this function and always 141cc14bf22STom Eccles /// freed before the function returns, plus whre to insert replacement 142cc14bf22STom Eccles /// fir.alloca ops 143cc14bf22STom Eccles void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const; 144cc14bf22STom Eccles 145cc14bf22STom Eccles std::optional<AllocationState> get(mlir::Value val) const; 146cc14bf22STom Eccles }; 147cc14bf22STom Eccles 148cc14bf22STom Eccles class AllocationAnalysis 149b2b7efb9SAlex Zinenko : public mlir::dataflow::DenseForwardDataFlowAnalysis<LatticePoint> { 150cc14bf22STom Eccles public: 151b2b7efb9SAlex Zinenko using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis; 152cc14bf22STom Eccles 15315e915a4SIvan Butygin mlir::LogicalResult visitOperation(mlir::Operation *op, 15415e915a4SIvan Butygin const LatticePoint &before, 155cc14bf22STom Eccles LatticePoint *after) override; 156cc14bf22STom Eccles 157cc14bf22STom Eccles /// At an entry point, the last modifications of all memory resources are 158cc14bf22STom Eccles /// yet to be determined 159cc14bf22STom Eccles void setToEntryState(LatticePoint *lattice) override; 160cc14bf22STom Eccles 161cc14bf22STom Eccles protected: 162cc14bf22STom Eccles /// Visit control flow operations and decide whether to call visitOperation 163cc14bf22STom Eccles /// to apply the transfer function 16415e915a4SIvan Butygin mlir::LogicalResult processOperation(mlir::Operation *op) override; 165cc14bf22STom Eccles }; 166cc14bf22STom Eccles 167cc14bf22STom Eccles /// Drives analysis to find candidate fir.allocmem operations which could be 168cc14bf22STom Eccles /// moved to the stack. Intended to be used with mlir::Pass::getAnalysis 169cc14bf22STom Eccles class StackArraysAnalysisWrapper { 170cc14bf22STom Eccles public: 171cc14bf22STom Eccles MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper) 172cc14bf22STom Eccles 173cc14bf22STom Eccles // Maps fir.allocmem -> place to insert alloca 174cc14bf22STom Eccles using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>; 175cc14bf22STom Eccles 176cc14bf22STom Eccles StackArraysAnalysisWrapper(mlir::Operation *op) {} 177cc14bf22STom Eccles 178408f4196STom Eccles // returns nullptr if analysis failed 179408f4196STom Eccles const AllocMemMap *getCandidateOps(mlir::Operation *func); 180cc14bf22STom Eccles 181cc14bf22STom Eccles private: 182cc14bf22STom Eccles llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps; 183cc14bf22STom Eccles 184db791b27SRamkumar Ramachandra llvm::LogicalResult analyseFunction(mlir::Operation *func); 185cc14bf22STom Eccles }; 186cc14bf22STom Eccles 187cc14bf22STom Eccles /// Converts a fir.allocmem to a fir.alloca 188cc14bf22STom Eccles class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> { 189cc14bf22STom Eccles public: 190408f4196STom Eccles explicit AllocMemConversion( 191cc14bf22STom Eccles mlir::MLIRContext *ctx, 192408f4196STom Eccles const StackArraysAnalysisWrapper::AllocMemMap &candidateOps) 193408f4196STom Eccles : OpRewritePattern(ctx), candidateOps{candidateOps} {} 194cc14bf22STom Eccles 195db791b27SRamkumar Ramachandra llvm::LogicalResult 196cc14bf22STom Eccles matchAndRewrite(fir::AllocMemOp allocmem, 197cc14bf22STom Eccles mlir::PatternRewriter &rewriter) const override; 198cc14bf22STom Eccles 199cc14bf22STom Eccles /// Determine where to insert the alloca operation. The returned value should 200cc14bf22STom Eccles /// be checked to see if it is inside a loop 201*e3cd88a7SSlava Zakharin static InsertionPoint 202*e3cd88a7SSlava Zakharin findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc, 203*e3cd88a7SSlava Zakharin const llvm::SmallVector<mlir::Operation *> &freeOps); 204cc14bf22STom Eccles 205cc14bf22STom Eccles private: 206408f4196STom Eccles /// Handle to the DFA (already run) 207408f4196STom Eccles const StackArraysAnalysisWrapper::AllocMemMap &candidateOps; 208cc14bf22STom Eccles 209cc14bf22STom Eccles /// If we failed to find an insertion point not inside a loop, see if it would 210cc14bf22STom Eccles /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop 211*e3cd88a7SSlava Zakharin static InsertionPoint findAllocaLoopInsertionPoint( 212*e3cd88a7SSlava Zakharin fir::AllocMemOp &oldAlloc, 213*e3cd88a7SSlava Zakharin const llvm::SmallVector<mlir::Operation *> &freeOps); 214cc14bf22STom Eccles 215cc14bf22STom Eccles /// Returns the alloca if it was successfully inserted, otherwise {} 216cc14bf22STom Eccles std::optional<fir::AllocaOp> 217cc14bf22STom Eccles insertAlloca(fir::AllocMemOp &oldAlloc, 218cc14bf22STom Eccles mlir::PatternRewriter &rewriter) const; 219cc14bf22STom Eccles 220cc14bf22STom Eccles /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem 221cc14bf22STom Eccles void insertStackSaveRestore(fir::AllocMemOp &oldAlloc, 222cc14bf22STom Eccles mlir::PatternRewriter &rewriter) const; 223cc14bf22STom Eccles }; 224cc14bf22STom Eccles 225cc14bf22STom Eccles class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> { 226cc14bf22STom Eccles public: 227cc14bf22STom Eccles StackArraysPass() = default; 228cc14bf22STom Eccles StackArraysPass(const StackArraysPass &pass); 229cc14bf22STom Eccles 230cc14bf22STom Eccles llvm::StringRef getDescription() const override; 231cc14bf22STom Eccles 232cc14bf22STom Eccles void runOnOperation() override; 233cc14bf22STom Eccles 234cc14bf22STom Eccles private: 235cc14bf22STom Eccles Statistic runCount{this, "stackArraysRunCount", 236cc14bf22STom Eccles "Number of heap allocations moved to the stack"}; 237cc14bf22STom Eccles }; 238cc14bf22STom Eccles 239cc14bf22STom Eccles } // namespace 240cc14bf22STom Eccles 241cc14bf22STom Eccles static void print(llvm::raw_ostream &os, AllocationState state) { 242cc14bf22STom Eccles switch (state) { 243cc14bf22STom Eccles case AllocationState::Unknown: 244cc14bf22STom Eccles os << "Unknown"; 245cc14bf22STom Eccles break; 246cc14bf22STom Eccles case AllocationState::Freed: 247cc14bf22STom Eccles os << "Freed"; 248cc14bf22STom Eccles break; 249cc14bf22STom Eccles case AllocationState::Allocated: 250cc14bf22STom Eccles os << "Allocated"; 251cc14bf22STom Eccles break; 252cc14bf22STom Eccles } 253cc14bf22STom Eccles } 254cc14bf22STom Eccles 255cc14bf22STom Eccles /// Join two AllocationStates for the same value coming from different CFG 256cc14bf22STom Eccles /// blocks 257cc14bf22STom Eccles static AllocationState join(AllocationState lhs, AllocationState rhs) { 258cc14bf22STom Eccles // | Allocated | Freed | Unknown 259cc14bf22STom Eccles // ========= | ========= | ========= | ========= 260cc14bf22STom Eccles // Allocated | Allocated | Unknown | Unknown 261cc14bf22STom Eccles // Freed | Unknown | Freed | Unknown 262cc14bf22STom Eccles // Unknown | Unknown | Unknown | Unknown 263cc14bf22STom Eccles if (lhs == rhs) 264cc14bf22STom Eccles return lhs; 265cc14bf22STom Eccles return AllocationState::Unknown; 266cc14bf22STom Eccles } 267cc14bf22STom Eccles 268cc14bf22STom Eccles mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) { 269cc14bf22STom Eccles const auto &rhs = static_cast<const LatticePoint &>(lattice); 270cc14bf22STom Eccles mlir::ChangeResult changed = mlir::ChangeResult::NoChange; 271cc14bf22STom Eccles 272cc14bf22STom Eccles // add everything from rhs to map, handling cases where values are in both 273cc14bf22STom Eccles for (const auto &[value, rhsState] : rhs.stateMap) { 274cc14bf22STom Eccles auto it = stateMap.find(value); 275cc14bf22STom Eccles if (it != stateMap.end()) { 276cc14bf22STom Eccles // value is present in both maps 277cc14bf22STom Eccles AllocationState myState = it->second; 278cc14bf22STom Eccles AllocationState newState = ::join(myState, rhsState); 279cc14bf22STom Eccles if (newState != myState) { 280cc14bf22STom Eccles changed = mlir::ChangeResult::Change; 281cc14bf22STom Eccles it->getSecond() = newState; 282cc14bf22STom Eccles } 283cc14bf22STom Eccles } else { 284cc14bf22STom Eccles // value not present in current map: add it 285cc14bf22STom Eccles stateMap.insert({value, rhsState}); 286cc14bf22STom Eccles changed = mlir::ChangeResult::Change; 287cc14bf22STom Eccles } 288cc14bf22STom Eccles } 289cc14bf22STom Eccles 290cc14bf22STom Eccles return changed; 291cc14bf22STom Eccles } 292cc14bf22STom Eccles 293cc14bf22STom Eccles void LatticePoint::print(llvm::raw_ostream &os) const { 294cc14bf22STom Eccles for (const auto &[value, state] : stateMap) { 295464d321eSKareem Ergawy os << "\n * " << value << ": "; 296cc14bf22STom Eccles ::print(os, state); 297cc14bf22STom Eccles } 298cc14bf22STom Eccles } 299cc14bf22STom Eccles 300cc14bf22STom Eccles mlir::ChangeResult LatticePoint::reset() { 301cc14bf22STom Eccles if (stateMap.empty()) 302cc14bf22STom Eccles return mlir::ChangeResult::NoChange; 303cc14bf22STom Eccles stateMap.clear(); 304cc14bf22STom Eccles return mlir::ChangeResult::Change; 305cc14bf22STom Eccles } 306cc14bf22STom Eccles 307cc14bf22STom Eccles mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) { 308cc14bf22STom Eccles if (stateMap.count(value)) { 309cc14bf22STom Eccles // already in map 310cc14bf22STom Eccles AllocationState &oldState = stateMap[value]; 311cc14bf22STom Eccles if (oldState != state) { 312cc14bf22STom Eccles stateMap[value] = state; 313cc14bf22STom Eccles return mlir::ChangeResult::Change; 314cc14bf22STom Eccles } 315cc14bf22STom Eccles return mlir::ChangeResult::NoChange; 316cc14bf22STom Eccles } 317cc14bf22STom Eccles stateMap.insert({value, state}); 318cc14bf22STom Eccles return mlir::ChangeResult::Change; 319cc14bf22STom Eccles } 320cc14bf22STom Eccles 321cc14bf22STom Eccles /// Get values which were allocated in this function and always freed before 322cc14bf22STom Eccles /// the function returns 323cc14bf22STom Eccles void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const { 324cc14bf22STom Eccles for (auto &[value, state] : stateMap) { 325cc14bf22STom Eccles if (state == AllocationState::Freed) 326cc14bf22STom Eccles out.insert(value); 327cc14bf22STom Eccles } 328cc14bf22STom Eccles } 329cc14bf22STom Eccles 330cc14bf22STom Eccles std::optional<AllocationState> LatticePoint::get(mlir::Value val) const { 331cc14bf22STom Eccles auto it = stateMap.find(val); 332cc14bf22STom Eccles if (it == stateMap.end()) 333cc14bf22STom Eccles return {}; 334cc14bf22STom Eccles return it->second; 335cc14bf22STom Eccles } 336cc14bf22STom Eccles 337303249c4STom Eccles static mlir::Value lookThroughDeclaresAndConverts(mlir::Value value) { 338303249c4STom Eccles while (mlir::Operation *op = value.getDefiningOp()) { 339303249c4STom Eccles if (auto declareOp = llvm::dyn_cast<fir::DeclareOp>(op)) 340303249c4STom Eccles value = declareOp.getMemref(); 341303249c4STom Eccles else if (auto convertOp = llvm::dyn_cast<fir::ConvertOp>(op)) 342303249c4STom Eccles value = convertOp->getOperand(0); 343303249c4STom Eccles else 344303249c4STom Eccles return value; 345303249c4STom Eccles } 346303249c4STom Eccles return value; 347303249c4STom Eccles } 348303249c4STom Eccles 34915e915a4SIvan Butygin mlir::LogicalResult AllocationAnalysis::visitOperation( 35015e915a4SIvan Butygin mlir::Operation *op, const LatticePoint &before, LatticePoint *after) { 351cc14bf22STom Eccles LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op 352cc14bf22STom Eccles << "\n"); 353cc14bf22STom Eccles LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n"); 354cc14bf22STom Eccles 355cc14bf22STom Eccles // propagate before -> after 356cc14bf22STom Eccles mlir::ChangeResult changed = after->join(before); 357cc14bf22STom Eccles 358cc14bf22STom Eccles if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) { 359cc14bf22STom Eccles assert(op->getNumResults() == 1 && "fir.allocmem has one result"); 360cc14bf22STom Eccles auto attr = op->getAttrOfType<fir::MustBeHeapAttr>( 361cc14bf22STom Eccles fir::MustBeHeapAttr::getAttrName()); 362cc14bf22STom Eccles if (attr && attr.getValue()) { 363cc14bf22STom Eccles LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n"); 364cc14bf22STom Eccles // skip allocation marked not to be moved 36515e915a4SIvan Butygin return mlir::success(); 366cc14bf22STom Eccles } 367cc14bf22STom Eccles 368cc14bf22STom Eccles auto retTy = allocmem.getAllocatedType(); 369fac349a1SChristian Sigg if (!mlir::isa<fir::SequenceType>(retTy)) { 370cc14bf22STom Eccles LLVM_DEBUG(llvm::dbgs() 371cc14bf22STom Eccles << "--Allocation is not for an array: skipping\n"); 37215e915a4SIvan Butygin return mlir::success(); 373cc14bf22STom Eccles } 374cc14bf22STom Eccles 375cc14bf22STom Eccles mlir::Value result = op->getResult(0); 376cc14bf22STom Eccles changed |= after->set(result, AllocationState::Allocated); 377cc14bf22STom Eccles } else if (mlir::isa<fir::FreeMemOp>(op)) { 378cc14bf22STom Eccles assert(op->getNumOperands() == 1 && "fir.freemem has one operand"); 379cc14bf22STom Eccles mlir::Value operand = op->getOperand(0); 380464d321eSKareem Ergawy 381464d321eSKareem Ergawy // Note: StackArrays is scheduled in the pass pipeline after lowering hlfir 382303249c4STom Eccles // to fir. Therefore, we only need to handle `fir::DeclareOp`s. Also look 383303249c4STom Eccles // past converts in case the pointer was changed between different pointer 384303249c4STom Eccles // types. 385303249c4STom Eccles operand = lookThroughDeclaresAndConverts(operand); 386464d321eSKareem Ergawy 387cc14bf22STom Eccles std::optional<AllocationState> operandState = before.get(operand); 388cc14bf22STom Eccles if (operandState && *operandState == AllocationState::Allocated) { 389cc14bf22STom Eccles // don't tag things not allocated in this function as freed, so that we 390cc14bf22STom Eccles // don't think they are candidates for moving to the stack 391cc14bf22STom Eccles changed |= after->set(operand, AllocationState::Freed); 392cc14bf22STom Eccles } 393cc14bf22STom Eccles } else if (mlir::isa<fir::ResultOp>(op)) { 394cc14bf22STom Eccles mlir::Operation *parent = op->getParentOp(); 3954b3f251bSdonald chen LatticePoint *parentLattice = getLattice(getProgramPointAfter(parent)); 396cc14bf22STom Eccles assert(parentLattice); 397cc14bf22STom Eccles mlir::ChangeResult parentChanged = parentLattice->join(*after); 398cc14bf22STom Eccles propagateIfChanged(parentLattice, parentChanged); 399cc14bf22STom Eccles } 400cc14bf22STom Eccles 401cc14bf22STom Eccles // we pass lattices straight through fir.call because called functions should 402cc14bf22STom Eccles // not deallocate flang-generated array temporaries 403cc14bf22STom Eccles 404cc14bf22STom Eccles LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n"); 405cc14bf22STom Eccles propagateIfChanged(after, changed); 40615e915a4SIvan Butygin return mlir::success(); 407cc14bf22STom Eccles } 408cc14bf22STom Eccles 409cc14bf22STom Eccles void AllocationAnalysis::setToEntryState(LatticePoint *lattice) { 410cc14bf22STom Eccles propagateIfChanged(lattice, lattice->reset()); 411cc14bf22STom Eccles } 412cc14bf22STom Eccles 413cc14bf22STom Eccles /// Mostly a copy of AbstractDenseLattice::processOperation - the difference 414cc14bf22STom Eccles /// being that call operations are passed through to the transfer function 41515e915a4SIvan Butygin mlir::LogicalResult AllocationAnalysis::processOperation(mlir::Operation *op) { 4164b3f251bSdonald chen mlir::ProgramPoint *point = getProgramPointAfter(op); 417cc14bf22STom Eccles // If the containing block is not executable, bail out. 4184b3f251bSdonald chen if (op->getBlock() != nullptr && 4194b3f251bSdonald chen !getOrCreateFor<mlir::dataflow::Executable>( 4204b3f251bSdonald chen point, getProgramPointBefore(op->getBlock())) 4214b3f251bSdonald chen ->isLive()) 42215e915a4SIvan Butygin return mlir::success(); 423cc14bf22STom Eccles 424cc14bf22STom Eccles // Get the dense lattice to update 4254b3f251bSdonald chen mlir::dataflow::AbstractDenseLattice *after = getLattice(point); 426cc14bf22STom Eccles 427cc14bf22STom Eccles // If this op implements region control-flow, then control-flow dictates its 428cc14bf22STom Eccles // transfer function. 42915e915a4SIvan Butygin if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) { 4304b3f251bSdonald chen visitRegionBranchOperation(point, branch, after); 43115e915a4SIvan Butygin return mlir::success(); 43215e915a4SIvan Butygin } 433cc14bf22STom Eccles 434cc14bf22STom Eccles // pass call operations through to the transfer function 435cc14bf22STom Eccles 436cc14bf22STom Eccles // Get the dense state before the execution of the op. 4374b3f251bSdonald chen const mlir::dataflow::AbstractDenseLattice *before = 4384b3f251bSdonald chen getLatticeFor(point, getProgramPointBefore(op)); 439cc14bf22STom Eccles 440cc14bf22STom Eccles /// Invoke the operation transfer function 44115e915a4SIvan Butygin return visitOperationImpl(op, *before, after); 442cc14bf22STom Eccles } 443cc14bf22STom Eccles 444db791b27SRamkumar Ramachandra llvm::LogicalResult 445408f4196STom Eccles StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) { 446cc14bf22STom Eccles assert(mlir::isa<mlir::func::FuncOp>(func)); 447e2153241STom Eccles size_t nAllocs = 0; 448e2153241STom Eccles func->walk([&nAllocs](fir::AllocMemOp) { nAllocs++; }); 449e2153241STom Eccles // don't bother with the analysis if there are no heap allocations 450e2153241STom Eccles if (nAllocs == 0) 451e2153241STom Eccles return mlir::success(); 452e2153241STom Eccles if ((maxAllocsPerFunc != 0) && (nAllocs > maxAllocsPerFunc)) { 453e2153241STom Eccles LLVM_DEBUG(llvm::dbgs() << "Skipping stack arrays for function with " 454e2153241STom Eccles << nAllocs << " heap allocations"); 455e2153241STom Eccles return mlir::success(); 456e2153241STom Eccles } 457e2153241STom Eccles 458cc14bf22STom Eccles mlir::DataFlowSolver solver; 459cc14bf22STom Eccles // constant propagation is required for dead code analysis, dead code analysis 460cc14bf22STom Eccles // is required to mark blocks live (required for mlir dense dfa) 461cc14bf22STom Eccles solver.load<mlir::dataflow::SparseConstantPropagation>(); 462cc14bf22STom Eccles solver.load<mlir::dataflow::DeadCodeAnalysis>(); 463cc14bf22STom Eccles 464cc14bf22STom Eccles auto [it, inserted] = funcMaps.try_emplace(func); 465cc14bf22STom Eccles AllocMemMap &candidateOps = it->second; 466cc14bf22STom Eccles 467cc14bf22STom Eccles solver.load<AllocationAnalysis>(); 468cc14bf22STom Eccles if (failed(solver.initializeAndRun(func))) { 469cc14bf22STom Eccles llvm::errs() << "DataFlowSolver failed!"; 470408f4196STom Eccles return mlir::failure(); 471cc14bf22STom Eccles } 472cc14bf22STom Eccles 4734b3f251bSdonald chen LatticePoint point{solver.getProgramPointAfter(func)}; 4747a49d50fSTom Eccles auto joinOperationLattice = [&](mlir::Operation *op) { 4754b3f251bSdonald chen const LatticePoint *lattice = 4764b3f251bSdonald chen solver.lookupState<LatticePoint>(solver.getProgramPointAfter(op)); 477cc14bf22STom Eccles // there will be no lattice for an unreachable block 478cc14bf22STom Eccles if (lattice) 479c50de57fSKazu Hirata (void)point.join(*lattice); 4807a49d50fSTom Eccles }; 481698b42ccSKareem Ergawy 4827a49d50fSTom Eccles func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); }); 4837a49d50fSTom Eccles func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); }); 484464d321eSKareem Ergawy func->walk( 485464d321eSKareem Ergawy [&](mlir::omp::TerminatorOp child) { joinOperationLattice(child); }); 486698b42ccSKareem Ergawy func->walk([&](mlir::omp::YieldOp child) { joinOperationLattice(child); }); 487464d321eSKareem Ergawy 488cc14bf22STom Eccles llvm::DenseSet<mlir::Value> freedValues; 489cc14bf22STom Eccles point.appendFreedValues(freedValues); 490cc14bf22STom Eccles 491*e3cd88a7SSlava Zakharin // Find all fir.freemem operations corresponding to fir.allocmem 492*e3cd88a7SSlava Zakharin // in freedValues. It is best to find the association going back 493*e3cd88a7SSlava Zakharin // from fir.freemem to fir.allocmem through the def-use chains, 494*e3cd88a7SSlava Zakharin // so that we can use lookThroughDeclaresAndConverts same way 495*e3cd88a7SSlava Zakharin // the AllocationAnalysis is handling them. 496*e3cd88a7SSlava Zakharin llvm::DenseMap<mlir::Operation *, llvm::SmallVector<mlir::Operation *>> 497*e3cd88a7SSlava Zakharin allocToFreeMemMap; 498*e3cd88a7SSlava Zakharin func->walk([&](fir::FreeMemOp freeOp) { 499*e3cd88a7SSlava Zakharin mlir::Value memref = lookThroughDeclaresAndConverts(freeOp.getHeapref()); 500*e3cd88a7SSlava Zakharin if (!freedValues.count(memref)) 501*e3cd88a7SSlava Zakharin return; 502*e3cd88a7SSlava Zakharin 503*e3cd88a7SSlava Zakharin auto allocMem = memref.getDefiningOp<fir::AllocMemOp>(); 504*e3cd88a7SSlava Zakharin allocToFreeMemMap[allocMem].push_back(freeOp); 505*e3cd88a7SSlava Zakharin }); 506*e3cd88a7SSlava Zakharin 507cc14bf22STom Eccles // We only replace allocations which are definately freed on all routes 508cc14bf22STom Eccles // through the function because otherwise the allocation may have an intende 509cc14bf22STom Eccles // lifetime longer than the current stack frame (e.g. a heap allocation which 510cc14bf22STom Eccles // is then freed by another function). 511cc14bf22STom Eccles for (mlir::Value freedValue : freedValues) { 512cc14bf22STom Eccles fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>(); 513cc14bf22STom Eccles InsertionPoint insertionPoint = 514*e3cd88a7SSlava Zakharin AllocMemConversion::findAllocaInsertionPoint( 515*e3cd88a7SSlava Zakharin allocmem, allocToFreeMemMap[allocmem]); 516cc14bf22STom Eccles if (insertionPoint) 517cc14bf22STom Eccles candidateOps.insert({allocmem, insertionPoint}); 518cc14bf22STom Eccles } 519cc14bf22STom Eccles 520cc14bf22STom Eccles LLVM_DEBUG(for (auto [allocMemOp, _] 521cc14bf22STom Eccles : candidateOps) { 522cc14bf22STom Eccles llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n'; 523cc14bf22STom Eccles }); 524408f4196STom Eccles return mlir::success(); 525cc14bf22STom Eccles } 526cc14bf22STom Eccles 527408f4196STom Eccles const StackArraysAnalysisWrapper::AllocMemMap * 528cc14bf22STom Eccles StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) { 529408f4196STom Eccles if (!funcMaps.contains(func)) 530408f4196STom Eccles if (mlir::failed(analyseFunction(func))) 531408f4196STom Eccles return nullptr; 532408f4196STom Eccles return &funcMaps[func]; 533cc14bf22STom Eccles } 534cc14bf22STom Eccles 535775de675STom Eccles /// Restore the old allocation type exected by existing code 536775de675STom Eccles static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter, 537775de675STom Eccles const mlir::Location &loc, 538775de675STom Eccles mlir::Value heap, mlir::Value stack) { 539775de675STom Eccles mlir::Type heapTy = heap.getType(); 540775de675STom Eccles mlir::Type stackTy = stack.getType(); 541775de675STom Eccles 542775de675STom Eccles if (heapTy == stackTy) 543775de675STom Eccles return stack; 544775de675STom Eccles 545775de675STom Eccles fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy); 54676c3c5bcSTom Eccles LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy = 54776c3c5bcSTom Eccles mlir::cast<fir::ReferenceType>(stackTy); 548775de675STom Eccles assert(firHeapTy.getElementType() == firRefTy.getElementType() && 549775de675STom Eccles "Allocations must have the same type"); 550775de675STom Eccles 551775de675STom Eccles auto insertionPoint = rewriter.saveInsertionPoint(); 552775de675STom Eccles rewriter.setInsertionPointAfter(stack.getDefiningOp()); 553775de675STom Eccles mlir::Value conv = 554775de675STom Eccles rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult(); 555775de675STom Eccles rewriter.restoreInsertionPoint(insertionPoint); 556775de675STom Eccles return conv; 557775de675STom Eccles } 558775de675STom Eccles 559db791b27SRamkumar Ramachandra llvm::LogicalResult 560cc14bf22STom Eccles AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem, 561cc14bf22STom Eccles mlir::PatternRewriter &rewriter) const { 562cc14bf22STom Eccles auto oldInsertionPt = rewriter.saveInsertionPoint(); 563cc14bf22STom Eccles // add alloca operation 564cc14bf22STom Eccles std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter); 565cc14bf22STom Eccles rewriter.restoreInsertionPoint(oldInsertionPt); 566cc14bf22STom Eccles if (!alloca) 567cc14bf22STom Eccles return mlir::failure(); 568cc14bf22STom Eccles 569cc14bf22STom Eccles // remove freemem operations 570408f4196STom Eccles llvm::SmallVector<mlir::Operation *> erases; 571303249c4STom Eccles mlir::Operation *parent = allocmem->getParentOp(); 572303249c4STom Eccles // TODO: this shouldn't need to be re-calculated for every allocmem 573303249c4STom Eccles parent->walk([&](fir::FreeMemOp freeOp) { 574303249c4STom Eccles if (lookThroughDeclaresAndConverts(freeOp->getOperand(0)) == allocmem) 575303249c4STom Eccles erases.push_back(freeOp); 576303249c4STom Eccles }); 577464d321eSKareem Ergawy 578408f4196STom Eccles // now we are done iterating the users, it is safe to mutate them 579408f4196STom Eccles for (mlir::Operation *erase : erases) 580408f4196STom Eccles rewriter.eraseOp(erase); 581cc14bf22STom Eccles 582cc14bf22STom Eccles // replace references to heap allocation with references to stack allocation 583775de675STom Eccles mlir::Value newValue = convertAllocationType( 584775de675STom Eccles rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult()); 585775de675STom Eccles rewriter.replaceAllUsesWith(allocmem.getResult(), newValue); 586cc14bf22STom Eccles 587cc14bf22STom Eccles // remove allocmem operation 588cc14bf22STom Eccles rewriter.eraseOp(allocmem.getOperation()); 589cc14bf22STom Eccles 590cc14bf22STom Eccles return mlir::success(); 591cc14bf22STom Eccles } 592cc14bf22STom Eccles 593cc14bf22STom Eccles static bool isInLoop(mlir::Block *block) { 594d5ea1b22STom Eccles return mlir::LoopLikeOpInterface::blockIsInLoop(block); 595cc14bf22STom Eccles } 596cc14bf22STom Eccles 597cc14bf22STom Eccles static bool isInLoop(mlir::Operation *op) { 598cc14bf22STom Eccles return isInLoop(op->getBlock()) || 599cc14bf22STom Eccles op->getParentOfType<mlir::LoopLikeOpInterface>(); 600cc14bf22STom Eccles } 601cc14bf22STom Eccles 602*e3cd88a7SSlava Zakharin InsertionPoint AllocMemConversion::findAllocaInsertionPoint( 603*e3cd88a7SSlava Zakharin fir::AllocMemOp &oldAlloc, 604*e3cd88a7SSlava Zakharin const llvm::SmallVector<mlir::Operation *> &freeOps) { 605cc14bf22STom Eccles // Ideally the alloca should be inserted at the end of the function entry 606cc14bf22STom Eccles // block so that we do not allocate stack space in a loop. However, 607cc14bf22STom Eccles // the operands to the alloca may not be available that early, so insert it 608cc14bf22STom Eccles // after the last operand becomes available 609cc14bf22STom Eccles // If the old allocmem op was in an openmp region then it should not be moved 610cc14bf22STom Eccles // outside of that 611cc14bf22STom Eccles LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: " 612cc14bf22STom Eccles << oldAlloc << "\n"); 613cc14bf22STom Eccles 614cc14bf22STom Eccles // check that an Operation or Block we are about to return is not in a loop 615cc14bf22STom Eccles auto checkReturn = [&](auto *point) -> InsertionPoint { 616cc14bf22STom Eccles if (isInLoop(point)) { 617cc14bf22STom Eccles mlir::Operation *oldAllocOp = oldAlloc.getOperation(); 618cc14bf22STom Eccles if (isInLoop(oldAllocOp)) { 619cc14bf22STom Eccles // where we want to put it is in a loop, and even the old location is in 620cc14bf22STom Eccles // a loop. Give up. 621*e3cd88a7SSlava Zakharin return findAllocaLoopInsertionPoint(oldAlloc, freeOps); 622cc14bf22STom Eccles } 623cc14bf22STom Eccles return {oldAllocOp}; 624cc14bf22STom Eccles } 625cc14bf22STom Eccles return {point}; 626cc14bf22STom Eccles }; 627cc14bf22STom Eccles 628cc14bf22STom Eccles auto oldOmpRegion = 629cc14bf22STom Eccles oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>(); 630cc14bf22STom Eccles 631cc14bf22STom Eccles // Find when the last operand value becomes available 632cc14bf22STom Eccles mlir::Block *operandsBlock = nullptr; 633cc14bf22STom Eccles mlir::Operation *lastOperand = nullptr; 634cc14bf22STom Eccles for (mlir::Value operand : oldAlloc.getOperands()) { 635cc14bf22STom Eccles LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n"); 636cc14bf22STom Eccles mlir::Operation *op = operand.getDefiningOp(); 637cc14bf22STom Eccles if (!op) 638cc14bf22STom Eccles return checkReturn(oldAlloc.getOperation()); 639cc14bf22STom Eccles if (!operandsBlock) 640cc14bf22STom Eccles operandsBlock = op->getBlock(); 641cc14bf22STom Eccles else if (operandsBlock != op->getBlock()) { 642cc14bf22STom Eccles LLVM_DEBUG(llvm::dbgs() 643cc14bf22STom Eccles << "----operand declared in a different block!\n"); 644cc14bf22STom Eccles // Operation::isBeforeInBlock requires the operations to be in the same 645cc14bf22STom Eccles // block. The best we can do is the location of the allocmem. 646cc14bf22STom Eccles return checkReturn(oldAlloc.getOperation()); 647cc14bf22STom Eccles } 648cc14bf22STom Eccles if (!lastOperand || lastOperand->isBeforeInBlock(op)) 649cc14bf22STom Eccles lastOperand = op; 650cc14bf22STom Eccles } 651cc14bf22STom Eccles 652cc14bf22STom Eccles if (lastOperand) { 653cc14bf22STom Eccles // there were value operands to the allocmem so insert after the last one 654cc14bf22STom Eccles LLVM_DEBUG(llvm::dbgs() 655cc14bf22STom Eccles << "--Placing after last operand: " << *lastOperand << "\n"); 656cc14bf22STom Eccles // check we aren't moving out of an omp region 657cc14bf22STom Eccles auto lastOpOmpRegion = 658cc14bf22STom Eccles lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>(); 659cc14bf22STom Eccles if (lastOpOmpRegion == oldOmpRegion) 660cc14bf22STom Eccles return checkReturn(lastOperand); 661cc14bf22STom Eccles // Presumably this happened because the operands became ready before the 662cc14bf22STom Eccles // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should 663cc14bf22STom Eccles // imply that oldOmpRegion comes after lastOpOmpRegion. 664cc14bf22STom Eccles return checkReturn(oldOmpRegion.getAllocaBlock()); 665cc14bf22STom Eccles } 666cc14bf22STom Eccles 667cc14bf22STom Eccles // There were no value operands to the allocmem so we are safe to insert it 668cc14bf22STom Eccles // as early as we want 669cc14bf22STom Eccles 670cc14bf22STom Eccles // handle openmp case 671cc14bf22STom Eccles if (oldOmpRegion) 672cc14bf22STom Eccles return checkReturn(oldOmpRegion.getAllocaBlock()); 673cc14bf22STom Eccles 674cc14bf22STom Eccles // fall back to the function entry block 675cc14bf22STom Eccles mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>(); 676cc14bf22STom Eccles assert(func && "This analysis is run on func.func"); 677cc14bf22STom Eccles mlir::Block &entryBlock = func.getBlocks().front(); 678cc14bf22STom Eccles LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n"); 679cc14bf22STom Eccles return checkReturn(&entryBlock); 680cc14bf22STom Eccles } 681cc14bf22STom Eccles 682*e3cd88a7SSlava Zakharin InsertionPoint AllocMemConversion::findAllocaLoopInsertionPoint( 683*e3cd88a7SSlava Zakharin fir::AllocMemOp &oldAlloc, 684*e3cd88a7SSlava Zakharin const llvm::SmallVector<mlir::Operation *> &freeOps) { 685cc14bf22STom Eccles mlir::Operation *oldAllocOp = oldAlloc; 686cc14bf22STom Eccles // This is only called as a last resort. We should try to insert at the 687cc14bf22STom Eccles // location of the old allocation, which is inside of a loop, using 688cc14bf22STom Eccles // llvm.stacksave/llvm.stackrestore 689cc14bf22STom Eccles 690cc14bf22STom Eccles assert(freeOps.size() && "DFA should only return freed memory"); 691cc14bf22STom Eccles 692cc14bf22STom Eccles // Don't attempt to reason about a stacksave/stackrestore between different 693cc14bf22STom Eccles // blocks 694cc14bf22STom Eccles for (mlir::Operation *free : freeOps) 695cc14bf22STom Eccles if (free->getBlock() != oldAllocOp->getBlock()) 696cc14bf22STom Eccles return {nullptr}; 697cc14bf22STom Eccles 698cc14bf22STom Eccles // Check that there aren't any other stack allocations in between the 699cc14bf22STom Eccles // stack save and stack restore 700cc14bf22STom Eccles // note: for flang generated temporaries there should only be one free op 701cc14bf22STom Eccles for (mlir::Operation *free : freeOps) { 702cc14bf22STom Eccles for (mlir::Operation *op = oldAlloc; op && op != free; 703cc14bf22STom Eccles op = op->getNextNode()) { 704cc14bf22STom Eccles if (mlir::isa<fir::AllocaOp>(op)) 705cc14bf22STom Eccles return {nullptr}; 706cc14bf22STom Eccles } 707cc14bf22STom Eccles } 708cc14bf22STom Eccles 709cc14bf22STom Eccles return InsertionPoint{oldAllocOp, /*shouldStackSaveRestore=*/true}; 710cc14bf22STom Eccles } 711cc14bf22STom Eccles 712cc14bf22STom Eccles std::optional<fir::AllocaOp> 713cc14bf22STom Eccles AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc, 714cc14bf22STom Eccles mlir::PatternRewriter &rewriter) const { 715cc14bf22STom Eccles auto it = candidateOps.find(oldAlloc.getOperation()); 716cc14bf22STom Eccles if (it == candidateOps.end()) 717cc14bf22STom Eccles return {}; 718cc14bf22STom Eccles InsertionPoint insertionPoint = it->second; 719cc14bf22STom Eccles if (!insertionPoint) 720cc14bf22STom Eccles return {}; 721cc14bf22STom Eccles 722cc14bf22STom Eccles if (insertionPoint.shouldSaveRestoreStack()) 723cc14bf22STom Eccles insertStackSaveRestore(oldAlloc, rewriter); 724cc14bf22STom Eccles 725cc14bf22STom Eccles mlir::Location loc = oldAlloc.getLoc(); 726cc14bf22STom Eccles mlir::Type varTy = oldAlloc.getInType(); 727cc14bf22STom Eccles if (mlir::Operation *op = insertionPoint.tryGetOperation()) { 728cc14bf22STom Eccles rewriter.setInsertionPointAfter(op); 729cc14bf22STom Eccles } else { 730cc14bf22STom Eccles mlir::Block *block = insertionPoint.tryGetBlock(); 731cc14bf22STom Eccles assert(block && "There must be a valid insertion point"); 732cc14bf22STom Eccles rewriter.setInsertionPointToStart(block); 733cc14bf22STom Eccles } 734cc14bf22STom Eccles 735cc14bf22STom Eccles auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef { 736cc14bf22STom Eccles if (opt) 737cc14bf22STom Eccles return *opt; 738cc14bf22STom Eccles return {}; 739cc14bf22STom Eccles }; 740cc14bf22STom Eccles 741cc14bf22STom Eccles llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName()); 742cc14bf22STom Eccles llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName()); 743cc14bf22STom Eccles return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName, 744cc14bf22STom Eccles oldAlloc.getTypeparams(), 745cc14bf22STom Eccles oldAlloc.getShape()); 746cc14bf22STom Eccles } 747cc14bf22STom Eccles 748cc14bf22STom Eccles void AllocMemConversion::insertStackSaveRestore( 749cc14bf22STom Eccles fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const { 750cc14bf22STom Eccles auto oldPoint = rewriter.saveInsertionPoint(); 751cc14bf22STom Eccles auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>(); 75253cc33b0STom Eccles fir::FirOpBuilder builder{rewriter, mod}; 753cc14bf22STom Eccles 754cc14bf22STom Eccles builder.setInsertionPoint(oldAlloc); 7555aaf384bSTom Eccles mlir::Value sp = builder.genStackSave(oldAlloc.getLoc()); 756cc14bf22STom Eccles 757464d321eSKareem Ergawy auto createStackRestoreCall = [&](mlir::Operation *user) { 758cc14bf22STom Eccles builder.setInsertionPoint(user); 7595aaf384bSTom Eccles builder.genStackRestore(user->getLoc(), sp); 760464d321eSKareem Ergawy }; 761464d321eSKareem Ergawy 762464d321eSKareem Ergawy for (mlir::Operation *user : oldAlloc->getUsers()) { 763464d321eSKareem Ergawy if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) { 764464d321eSKareem Ergawy for (mlir::Operation *user : declareOp->getUsers()) { 765464d321eSKareem Ergawy if (mlir::isa<fir::FreeMemOp>(user)) 766464d321eSKareem Ergawy createStackRestoreCall(user); 767464d321eSKareem Ergawy } 768464d321eSKareem Ergawy } 769464d321eSKareem Ergawy 770464d321eSKareem Ergawy if (mlir::isa<fir::FreeMemOp>(user)) { 771464d321eSKareem Ergawy createStackRestoreCall(user); 772cc14bf22STom Eccles } 773cc14bf22STom Eccles } 774cc14bf22STom Eccles 775cc14bf22STom Eccles rewriter.restoreInsertionPoint(oldPoint); 776cc14bf22STom Eccles } 777cc14bf22STom Eccles 778cc14bf22STom Eccles StackArraysPass::StackArraysPass(const StackArraysPass &pass) 779cc14bf22STom Eccles : fir::impl::StackArraysBase<StackArraysPass>(pass) {} 780cc14bf22STom Eccles 781cc14bf22STom Eccles llvm::StringRef StackArraysPass::getDescription() const { 782cc14bf22STom Eccles return "Move heap allocated array temporaries to the stack"; 783cc14bf22STom Eccles } 784cc14bf22STom Eccles 785cc14bf22STom Eccles void StackArraysPass::runOnOperation() { 7861e64864cSTom Eccles mlir::func::FuncOp func = getOperation(); 787cc14bf22STom Eccles 788cc14bf22STom Eccles auto &analysis = getAnalysis<StackArraysAnalysisWrapper>(); 789408f4196STom Eccles const StackArraysAnalysisWrapper::AllocMemMap *candidateOps = 790408f4196STom Eccles analysis.getCandidateOps(func); 791408f4196STom Eccles if (!candidateOps) { 792cc14bf22STom Eccles signalPassFailure(); 793cc14bf22STom Eccles return; 794cc14bf22STom Eccles } 795cc14bf22STom Eccles 796408f4196STom Eccles if (candidateOps->empty()) 797cc14bf22STom Eccles return; 798408f4196STom Eccles runCount += candidateOps->size(); 799408f4196STom Eccles 800408f4196STom Eccles llvm::SmallVector<mlir::Operation *> opsToConvert; 801408f4196STom Eccles opsToConvert.reserve(candidateOps->size()); 802408f4196STom Eccles for (auto [op, _] : *candidateOps) 803408f4196STom Eccles opsToConvert.push_back(op); 804cc14bf22STom Eccles 805cc14bf22STom Eccles mlir::MLIRContext &context = getContext(); 806cc14bf22STom Eccles mlir::RewritePatternSet patterns(&context); 807408f4196STom Eccles mlir::GreedyRewriteConfig config; 808408f4196STom Eccles // prevent the pattern driver form merging blocks 809a506279eSMehdi Amini config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; 810cc14bf22STom Eccles 811408f4196STom Eccles patterns.insert<AllocMemConversion>(&context, *candidateOps); 81209dfc571SJacques Pienaar if (mlir::failed(mlir::applyOpPatternsGreedily( 81309dfc571SJacques Pienaar opsToConvert, std::move(patterns), config))) { 814cc14bf22STom Eccles mlir::emitError(func->getLoc(), "error in stack arrays optimization\n"); 815cc14bf22STom Eccles signalPassFailure(); 816cc14bf22STom Eccles } 817cc14bf22STom Eccles } 818