//===- StackArrays.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

namespace fir {
#define GEN_PASS_DEF_STACKARRAYS
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

#define DEBUG_TYPE "stack-arrays"

static llvm::cl::opt<std::size_t> maxAllocsPerFunc(
    "stack-arrays-max-allocs",
    llvm::cl::desc("The maximum number of heap allocations to consider in one "
                   "function before skipping (to save compilation time). Set "
                   "to 0 for no limit."),
    llvm::cl::init(1000), llvm::cl::Hidden);

namespace {

/// The state of an SSA value at each program point
enum class AllocationState {
  /// This means that the allocation state of a variable cannot be determined
  /// at this program point, e.g. because one route through a conditional freed
  /// the variable and the other route didn't.
  /// This asserts a known-unknown: different from the unknown-unknown of
  /// having no AllocationState stored for a particular SSA value.
  Unknown,
  /// Means this SSA value was allocated on the heap in this function and has
  /// now been freed
  Freed,
  /// Means this SSA value was allocated on the heap in this function and is a
  /// candidate for moving to the stack
  Allocated,
};
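
// An illustrative sketch of how these states evolve (FIR with types elided,
// not taken from a real test):
//   %mem = fir.allocmem !fir.array<42xi32>  // %mem: Allocated
//   fir.if %cond {
//     fir.freemem %mem                      // %mem: Freed on this path only
//   }
//   // after the fir.if, join(Freed, Allocated) = Unknown, so %mem is not a
//   // candidate for moving to the stack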

/// Stores where an alloca should be inserted. If the PointerUnion is an
/// Operation the alloca should be inserted /after/ the operation. If it is a
/// block, the alloca can be placed anywhere in that block.
class InsertionPoint {
  llvm::PointerUnion<mlir::Operation *, mlir::Block *> location;
  bool saveRestoreStack;

  /// Get contained pointer type or nullptr
  template <class T>
  T *tryGetPtr() const {
    // Use llvm::dyn_cast_if_present because location may be null here.
    if (T *ptr = llvm::dyn_cast_if_present<T *>(location))
      return ptr;
    return nullptr;
  }

public:
  template <class T>
  InsertionPoint(T *ptr, bool saveRestoreStack = false)
      : location(ptr), saveRestoreStack{saveRestoreStack} {}
  InsertionPoint(std::nullptr_t null)
      : location(null), saveRestoreStack{false} {}

  /// Get contained operation, or nullptr
  mlir::Operation *tryGetOperation() const {
    return tryGetPtr<mlir::Operation>();
  }

  /// Get contained block, or nullptr
  mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); }

  /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave
  /// intrinsic should be added before the alloca, and an llvm.stackrestore
  /// intrinsic should be added where the freemem is
  bool shouldSaveRestoreStack() const { return saveRestoreStack; }

  operator bool() const { return tryGetOperation() || tryGetBlock(); }

  bool operator==(const InsertionPoint &rhs) const {
    return (location == rhs.location) &&
           (saveRestoreStack == rhs.saveRestoreStack);
  }

  bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); }
};

/// Maps SSA values to their AllocationState at a particular program point.
/// Also caches the insertion points for the new alloca operations
class LatticePoint : public mlir::dataflow::AbstractDenseLattice {
  // Maps all values we are interested in to states
  llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap;

public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint)
  using AbstractDenseLattice::AbstractDenseLattice;

  bool operator==(const LatticePoint &rhs) const {
    return stateMap == rhs.stateMap;
  }

  /// Join the lattice across control-flow edges
  mlir::ChangeResult join(const AbstractDenseLattice &lattice) override;

  void print(llvm::raw_ostream &os) const override;

  /// Clear all modifications
  mlir::ChangeResult reset();

  /// Set the state of an SSA value
  mlir::ChangeResult set(mlir::Value value, AllocationState state);

  /// Add values which were allocated in this function and always freed before
  /// the function returns to the given set
  void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const;

  std::optional<AllocationState> get(mlir::Value val) const;
};

class AllocationAnalysis
    : public mlir::dataflow::DenseForwardDataFlowAnalysis<LatticePoint> {
public:
  using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;

  mlir::LogicalResult visitOperation(mlir::Operation *op,
                                     const LatticePoint &before,
                                     LatticePoint *after) override;

  /// At an entry point, the last modifications of all memory resources are
  /// yet to be determined
  void setToEntryState(LatticePoint *lattice) override;

protected:
  /// Visit control flow operations and decide whether to call visitOperation
  /// to apply the transfer function
  mlir::LogicalResult processOperation(mlir::Operation *op) override;
};

/// Drives analysis to find candidate fir.allocmem operations which could be
/// moved to the stack. Intended to be used with mlir::Pass::getAnalysis
class StackArraysAnalysisWrapper {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper)

  // Maps fir.allocmem -> place to insert alloca
  using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>;

  StackArraysAnalysisWrapper(mlir::Operation *op) {}

  // returns nullptr if analysis failed
  const AllocMemMap *getCandidateOps(mlir::Operation *func);

private:
  llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;

  llvm::LogicalResult analyseFunction(mlir::Operation *func);
};

/// Converts a fir.allocmem to a fir.alloca
class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
public:
  explicit AllocMemConversion(
      mlir::MLIRContext *ctx,
      const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
      : OpRewritePattern(ctx), candidateOps{candidateOps} {}

  llvm::LogicalResult
  matchAndRewrite(fir::AllocMemOp allocmem,
                  mlir::PatternRewriter &rewriter) const override;

  /// Determine where to insert the alloca operation. The returned value should
  /// be checked to see if it is inside a loop
  static InsertionPoint
  findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc,
                           const llvm::SmallVector<mlir::Operation *> &freeOps);

private:
  /// Handle to the DFA (already run)
  const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;

  /// If we failed to find an insertion point not inside a loop, see if it
  /// would be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
  static InsertionPoint findAllocaLoopInsertionPoint(
      fir::AllocMemOp &oldAlloc,
      const llvm::SmallVector<mlir::Operation *> &freeOps);

  /// Returns the alloca if it was successfully inserted, otherwise {}
  std::optional<fir::AllocaOp>
  insertAlloca(fir::AllocMemOp &oldAlloc,
               mlir::PatternRewriter &rewriter) const;

  /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem
  void insertStackSaveRestore(fir::AllocMemOp &oldAlloc,
                              mlir::PatternRewriter &rewriter) const;
};

class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> {
public:
  StackArraysPass() = default;
  StackArraysPass(const StackArraysPass &pass);

  llvm::StringRef getDescription() const override;

  void runOnOperation() override;

private:
  Statistic runCount{this, "stackArraysRunCount",
                     "Number of heap allocations moved to the stack"};
};

} // namespace

static void print(llvm::raw_ostream &os, AllocationState state) {
  switch (state) {
  case AllocationState::Unknown:
    os << "Unknown";
    break;
  case AllocationState::Freed:
    os << "Freed";
    break;
  case AllocationState::Allocated:
    os << "Allocated";
    break;
  }
}

/// Join two AllocationStates for the same value coming from different CFG
/// blocks
static AllocationState join(AllocationState lhs, AllocationState rhs) {
  //           | Allocated | Freed     | Unknown
  // ========= | ========= | ========= | =========
  // Allocated | Allocated | Unknown   | Unknown
  // Freed     | Unknown   | Freed     | Unknown
  // Unknown   | Unknown   | Unknown   | Unknown
  if (lhs == rhs)
    return lhs;
  return AllocationState::Unknown;
}

mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) {
  const auto &rhs = static_cast<const LatticePoint &>(lattice);
  mlir::ChangeResult changed = mlir::ChangeResult::NoChange;

  // add everything from rhs to map, handling cases where values are in both
  for (const auto &[value, rhsState] : rhs.stateMap) {
    auto it = stateMap.find(value);
    if (it != stateMap.end()) {
      // value is present in both maps
      AllocationState myState = it->second;
      AllocationState newState = ::join(myState, rhsState);
      if (newState != myState) {
        changed = mlir::ChangeResult::Change;
        it->getSecond() = newState;
      }
    } else {
      // value not present in current map: add it
      stateMap.insert({value, rhsState});
      changed = mlir::ChangeResult::Change;
    }
  }

  return changed;
}

void LatticePoint::print(llvm::raw_ostream &os) const {
  for (const auto &[value, state] : stateMap) {
    os << "\n * " << value << ": ";
    ::print(os, state);
  }
}

mlir::ChangeResult LatticePoint::reset() {
  if (stateMap.empty())
    return mlir::ChangeResult::NoChange;
  stateMap.clear();
  return mlir::ChangeResult::Change;
}

mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) {
  if (stateMap.count(value)) {
    // already in map
    AllocationState &oldState = stateMap[value];
    if (oldState != state) {
      stateMap[value] = state;
      return mlir::ChangeResult::Change;
    }
    return mlir::ChangeResult::NoChange;
  }
  stateMap.insert({value, state});
  return mlir::ChangeResult::Change;
}

/// Get values which were allocated in this function and always freed before
/// the function returns
void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const {
  for (auto &[value, state] : stateMap) {
    if (state == AllocationState::Freed)
      out.insert(value);
  }
}

std::optional<AllocationState> LatticePoint::get(mlir::Value val) const {
  auto it = stateMap.find(val);
  if (it == stateMap.end())
    return {};
  return it->second;
}
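
/// Trace a value back through fir.declare and fir.convert operations to the
/// value defining the underlying storage. For example (an illustrative
/// sketch, with types elided):
///   %mem  = fir.allocmem !fir.array<42xi32>
///   %conv = fir.convert %mem : (!fir.heap<...>) -> !fir.ref<...>
///   %decl = fir.declare %conv {...}
/// lookThroughDeclaresAndConverts(%decl) returns %mem.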
static mlir::Value lookThroughDeclaresAndConverts(mlir::Value value) {
  while (mlir::Operation *op = value.getDefiningOp()) {
    if (auto declareOp = llvm::dyn_cast<fir::DeclareOp>(op))
      value = declareOp.getMemref();
    else if (auto convertOp = llvm::dyn_cast<fir::ConvertOp>(op))
      value = convertOp->getOperand(0);
    else
      return value;
  }
  return value;
}

mlir::LogicalResult AllocationAnalysis::visitOperation(
    mlir::Operation *op, const LatticePoint &before, LatticePoint *after) {
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
                          << "\n");
  LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n");

  // propagate before -> after
  mlir::ChangeResult changed = after->join(before);

  if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) {
    assert(op->getNumResults() == 1 && "fir.allocmem has one result");
    auto attr = op->getAttrOfType<fir::MustBeHeapAttr>(
        fir::MustBeHeapAttr::getAttrName());
    if (attr && attr.getValue()) {
      LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
      // skip allocation marked not to be moved
      return mlir::success();
    }

    auto retTy = allocmem.getAllocatedType();
    if (!mlir::isa<fir::SequenceType>(retTy)) {
      LLVM_DEBUG(llvm::dbgs()
                 << "--Allocation is not for an array: skipping\n");
      return mlir::success();
    }

    mlir::Value result = op->getResult(0);
    changed |= after->set(result, AllocationState::Allocated);
  } else if (mlir::isa<fir::FreeMemOp>(op)) {
    assert(op->getNumOperands() == 1 && "fir.freemem has one operand");
    mlir::Value operand = op->getOperand(0);

    // Note: StackArrays is scheduled in the pass pipeline after lowering hlfir
    // to fir. Therefore, we only need to handle `fir::DeclareOp`s. Also look
    // past converts in case the pointer was changed between different pointer
    // types.
    operand = lookThroughDeclaresAndConverts(operand);

    std::optional<AllocationState> operandState = before.get(operand);
    if (operandState && *operandState == AllocationState::Allocated) {
      // don't tag things not allocated in this function as freed, so that we
      // don't think they are candidates for moving to the stack
      changed |= after->set(operand, AllocationState::Freed);
    }
  } else if (mlir::isa<fir::ResultOp>(op)) {
    mlir::Operation *parent = op->getParentOp();
    LatticePoint *parentLattice = getLattice(getProgramPointAfter(parent));
    assert(parentLattice);
    mlir::ChangeResult parentChanged = parentLattice->join(*after);
    propagateIfChanged(parentLattice, parentChanged);
  }

  // we pass lattices straight through fir.call because called functions should
  // not deallocate flang-generated array temporaries

  LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n");
  propagateIfChanged(after, changed);
  return mlir::success();
}

void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
  propagateIfChanged(lattice, lattice->reset());
}

/// Mostly a copy of AbstractDenseForwardDataFlowAnalysis::processOperation -
/// the difference being that call operations are passed through to the
/// transfer function
mlir::LogicalResult AllocationAnalysis::processOperation(mlir::Operation *op) {
  mlir::ProgramPoint *point = getProgramPointAfter(op);
  // If the containing block is not executable, bail out.
  if (op->getBlock() != nullptr &&
      !getOrCreateFor<mlir::dataflow::Executable>(
           point, getProgramPointBefore(op->getBlock()))
           ->isLive())
    return mlir::success();

  // Get the dense lattice to update.
  mlir::dataflow::AbstractDenseLattice *after = getLattice(point);

  // If this op implements region control-flow, then control-flow dictates its
  // transfer function.
  if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
    visitRegionBranchOperation(point, branch, after);
    return mlir::success();
  }

  // pass call operations through to the transfer function

  // Get the dense state before the execution of the op.
  const mlir::dataflow::AbstractDenseLattice *before =
      getLatticeFor(point, getProgramPointBefore(op));

  // Invoke the operation transfer function.
  return visitOperationImpl(op, *before, after);
}

llvm::LogicalResult
StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
  assert(mlir::isa<mlir::func::FuncOp>(func));
  size_t nAllocs = 0;
  func->walk([&nAllocs](fir::AllocMemOp) { nAllocs++; });
  // don't bother with the analysis if there are no heap allocations
  if (nAllocs == 0)
    return mlir::success();
  if ((maxAllocsPerFunc != 0) && (nAllocs > maxAllocsPerFunc)) {
    LLVM_DEBUG(llvm::dbgs() << "Skipping stack arrays for function with "
                            << nAllocs << " heap allocations\n");
    return mlir::success();
  }

  mlir::DataFlowSolver solver;
  // constant propagation is required for dead code analysis, and dead code
  // analysis is required to mark blocks live (needed for the MLIR dense DFA)
  solver.load<mlir::dataflow::SparseConstantPropagation>();
  solver.load<mlir::dataflow::DeadCodeAnalysis>();

  auto [it, inserted] = funcMaps.try_emplace(func);
  AllocMemMap &candidateOps = it->second;

  solver.load<AllocationAnalysis>();
  if (failed(solver.initializeAndRun(func))) {
    llvm::errs() << "DataFlowSolver failed!\n";
    return mlir::failure();
  }

  LatticePoint point{solver.getProgramPointAfter(func)};
  auto joinOperationLattice = [&](mlir::Operation *op) {
    const LatticePoint *lattice =
        solver.lookupState<LatticePoint>(solver.getProgramPointAfter(op));
    // there will be no lattice for an unreachable block
    if (lattice)
      (void)point.join(*lattice);
  };

  func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); });
  func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); });
  func->walk(
      [&](mlir::omp::TerminatorOp child) { joinOperationLattice(child); });
  func->walk([&](mlir::omp::YieldOp child) { joinOperationLattice(child); });

  llvm::DenseSet<mlir::Value> freedValues;
  point.appendFreedValues(freedValues);

  // Find all fir.freemem operations corresponding to fir.allocmem
  // in freedValues. It is best to find the association going back
  // from fir.freemem to fir.allocmem through the def-use chains,
  // so that we can use lookThroughDeclaresAndConverts the same way
  // the AllocationAnalysis is handling them.
  llvm::DenseMap<mlir::Operation *, llvm::SmallVector<mlir::Operation *>>
      allocToFreeMemMap;
  func->walk([&](fir::FreeMemOp freeOp) {
    mlir::Value memref = lookThroughDeclaresAndConverts(freeOp.getHeapref());
    if (!freedValues.count(memref))
      return;

    auto allocMem = memref.getDefiningOp<fir::AllocMemOp>();
    allocToFreeMemMap[allocMem].push_back(freeOp);
  });

  // We only replace allocations which are definitely freed on all routes
  // through the function because otherwise the allocation may have an intended
  // lifetime longer than the current stack frame (e.g. a heap allocation which
  // is then freed by another function).
  for (mlir::Value freedValue : freedValues) {
    fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>();
    InsertionPoint insertionPoint =
        AllocMemConversion::findAllocaInsertionPoint(
            allocmem, allocToFreeMemMap[allocmem]);
    if (insertionPoint)
      candidateOps.insert({allocmem, insertionPoint});
  }

  LLVM_DEBUG(for (auto [allocMemOp, _]
                  : candidateOps) {
    llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
  });
  return mlir::success();
}

const StackArraysAnalysisWrapper::AllocMemMap *
StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
  if (!funcMaps.contains(func))
    if (mlir::failed(analyseFunction(func)))
      return nullptr;
  return &funcMaps[func];
}

/// Restore the old allocation type expected by existing code
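/// For example (an illustrative sketch): fir.allocmem yields a !fir.heap<T>
/// while the replacement fir.alloca yields a !fir.ref<T>, so a
///   %conv = fir.convert %stack : (!fir.ref<T>) -> !fir.heap<T>
/// is inserted after the alloca so that existing users of the heap value
/// still type-check.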
static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter,
                                         const mlir::Location &loc,
                                         mlir::Value heap, mlir::Value stack) {
  mlir::Type heapTy = heap.getType();
  mlir::Type stackTy = stack.getType();

  if (heapTy == stackTy)
    return stack;

  fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy);
  LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy =
      mlir::cast<fir::ReferenceType>(stackTy);
  assert(firHeapTy.getElementType() == firRefTy.getElementType() &&
         "Allocations must have the same type");

  auto insertionPoint = rewriter.saveInsertionPoint();
  rewriter.setInsertionPointAfter(stack.getDefiningOp());
  mlir::Value conv =
      rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult();
  rewriter.restoreInsertionPoint(insertionPoint);
  return conv;
}
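
// The overall rewrite, as an illustrative sketch (types elided): given
//   %mem = fir.allocmem !fir.array<42xi32>
//   ...
//   fir.freemem %mem
// the pattern below inserts
//   %stack = fir.alloca !fir.array<42xi32>
// at the insertion point chosen by the analysis, erases the fir.freemem and
// fir.allocmem, and redirects users of %mem to (a fir.convert of) %stack.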

llvm::LogicalResult
AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
                                    mlir::PatternRewriter &rewriter) const {
  auto oldInsertionPt = rewriter.saveInsertionPoint();
  // add alloca operation
  std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter);
  rewriter.restoreInsertionPoint(oldInsertionPt);
  if (!alloca)
    return mlir::failure();

  // remove freemem operations
  llvm::SmallVector<mlir::Operation *> erases;
  mlir::Operation *parent = allocmem->getParentOp();
  // TODO: this shouldn't need to be re-calculated for every allocmem
  parent->walk([&](fir::FreeMemOp freeOp) {
    if (lookThroughDeclaresAndConverts(freeOp->getOperand(0)) == allocmem)
      erases.push_back(freeOp);
  });

  // now we are done iterating the users, it is safe to mutate them
  for (mlir::Operation *erase : erases)
    rewriter.eraseOp(erase);

  // replace references to heap allocation with references to stack allocation
  mlir::Value newValue = convertAllocationType(
      rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult());
  rewriter.replaceAllUsesWith(allocmem.getResult(), newValue);

  // remove allocmem operation
  rewriter.eraseOp(allocmem.getOperation());

  return mlir::success();
}

static bool isInLoop(mlir::Block *block) {
  return mlir::LoopLikeOpInterface::blockIsInLoop(block);
}

static bool isInLoop(mlir::Operation *op) {
  return isInLoop(op->getBlock()) ||
         op->getParentOfType<mlir::LoopLikeOpInterface>();
}

InsertionPoint AllocMemConversion::findAllocaInsertionPoint(
    fir::AllocMemOp &oldAlloc,
    const llvm::SmallVector<mlir::Operation *> &freeOps) {
  // Ideally the alloca should be inserted in the function entry block so that
  // we do not allocate stack space in a loop. However, the operands to the
  // alloca may not be available that early, so insert it after the last
  // operand becomes available.
  // If the old allocmem op was in an OpenMP region then it should not be
  // moved outside of that region.
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: "
                          << oldAlloc << "\n");

  // check that an Operation or Block we are about to return is not in a loop
  auto checkReturn = [&](auto *point) -> InsertionPoint {
    if (isInLoop(point)) {
      mlir::Operation *oldAllocOp = oldAlloc.getOperation();
      if (isInLoop(oldAllocOp)) {
        // where we want to put it is in a loop, and even the old location is
        // in a loop. Give up.
        return findAllocaLoopInsertionPoint(oldAlloc, freeOps);
      }
      return {oldAllocOp};
    }
    return {point};
  };

  auto oldOmpRegion =
      oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();

  // Find when the last operand value becomes available
  mlir::Block *operandsBlock = nullptr;
  mlir::Operation *lastOperand = nullptr;
  for (mlir::Value operand : oldAlloc.getOperands()) {
    LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n");
    mlir::Operation *op = operand.getDefiningOp();
    if (!op)
      return checkReturn(oldAlloc.getOperation());
    if (!operandsBlock)
      operandsBlock = op->getBlock();
    else if (operandsBlock != op->getBlock()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "----operand declared in a different block!\n");
      // Operation::isBeforeInBlock requires the operations to be in the same
      // block. The best we can do is the location of the allocmem.
      return checkReturn(oldAlloc.getOperation());
    }
    if (!lastOperand || lastOperand->isBeforeInBlock(op))
      lastOperand = op;
  }

  if (lastOperand) {
    // there were value operands to the allocmem so insert after the last one
    LLVM_DEBUG(llvm::dbgs()
               << "--Placing after last operand: " << *lastOperand << "\n");
    // check we aren't moving out of an omp region
    auto lastOpOmpRegion =
        lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
    if (lastOpOmpRegion == oldOmpRegion)
      return checkReturn(lastOperand);
    // Presumably this happened because the operands became ready before the
    // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should
    // imply that oldOmpRegion comes after lastOpOmpRegion.
    return checkReturn(oldOmpRegion.getAllocaBlock());
  }

  // There were no value operands to the allocmem so we are safe to insert it
  // as early as we want

  // handle openmp case
  if (oldOmpRegion)
    return checkReturn(oldOmpRegion.getAllocaBlock());

  // fall back to the function entry block
  mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>();
  assert(func && "This analysis is run on func.func");
  mlir::Block &entryBlock = func.getBlocks().front();
  LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n");
  return checkReturn(&entryBlock);
}
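
// When the only safe insertion point is inside a loop, the new alloca is
// bracketed by a stack save and a stack restore so that each iteration's
// stack allocation is reclaimed. An illustrative sketch (the exact ops depend
// on how the builder emits the llvm.stacksave/llvm.stackrestore intrinsics):
//   %sp = <stacksave>
//   %stack = fir.alloca !fir.array<42xi32>
//   ...
//   <stackrestore> %sp   // placed where each fir.freemem was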
InsertionPoint AllocMemConversion::findAllocaLoopInsertionPoint(
    fir::AllocMemOp &oldAlloc,
    const llvm::SmallVector<mlir::Operation *> &freeOps) {
  mlir::Operation *oldAllocOp = oldAlloc;
  // This is only called as a last resort. We should try to insert at the
  // location of the old allocation, which is inside a loop, using
  // llvm.stacksave/llvm.stackrestore

  assert(freeOps.size() && "DFA should only return freed memory");

  // Don't attempt to reason about a stacksave/stackrestore between different
  // blocks
  for (mlir::Operation *free : freeOps)
    if (free->getBlock() != oldAllocOp->getBlock())
      return {nullptr};

  // Check that there aren't any other stack allocations in between the
  // stack save and stack restore
  // note: for flang generated temporaries there should only be one free op
  for (mlir::Operation *free : freeOps) {
    for (mlir::Operation *op = oldAlloc; op && op != free;
         op = op->getNextNode()) {
      if (mlir::isa<fir::AllocaOp>(op))
        return {nullptr};
    }
  }

  return InsertionPoint{oldAllocOp, /*saveRestoreStack=*/true};
}

std::optional<fir::AllocaOp>
AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc,
                                 mlir::PatternRewriter &rewriter) const {
  auto it = candidateOps.find(oldAlloc.getOperation());
  if (it == candidateOps.end())
    return {};
  InsertionPoint insertionPoint = it->second;
  if (!insertionPoint)
    return {};

  if (insertionPoint.shouldSaveRestoreStack())
    insertStackSaveRestore(oldAlloc, rewriter);

  mlir::Location loc = oldAlloc.getLoc();
  mlir::Type varTy = oldAlloc.getInType();
  if (mlir::Operation *op = insertionPoint.tryGetOperation()) {
    rewriter.setInsertionPointAfter(op);
  } else {
    mlir::Block *block = insertionPoint.tryGetBlock();
    assert(block && "There must be a valid insertion point");
    rewriter.setInsertionPointToStart(block);
  }

  auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef {
    if (opt)
      return *opt;
    return {};
  };

  llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName());
  llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName());
  return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName,
                                        oldAlloc.getTypeparams(),
                                        oldAlloc.getShape());
}

void AllocMemConversion::insertStackSaveRestore(
    fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const {
  auto oldPoint = rewriter.saveInsertionPoint();
  auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>();
  fir::FirOpBuilder builder{rewriter, mod};

  builder.setInsertionPoint(oldAlloc);
  mlir::Value sp = builder.genStackSave(oldAlloc.getLoc());

  auto createStackRestoreCall = [&](mlir::Operation *user) {
    builder.setInsertionPoint(user);
    builder.genStackRestore(user->getLoc(), sp);
  };

  for (mlir::Operation *user : oldAlloc->getUsers()) {
    if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
      for (mlir::Operation *user : declareOp->getUsers()) {
        if (mlir::isa<fir::FreeMemOp>(user))
          createStackRestoreCall(user);
      }
    }

    if (mlir::isa<fir::FreeMemOp>(user)) {
      createStackRestoreCall(user);
    }
  }

  rewriter.restoreInsertionPoint(oldPoint);
}

StackArraysPass::StackArraysPass(const StackArraysPass &pass)
    : fir::impl::StackArraysBase<StackArraysPass>(pass) {}

llvm::StringRef StackArraysPass::getDescription() const {
  return "Move heap allocated array temporaries to the stack";
}
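
// Note: assuming the tablegen pass name matches DEBUG_TYPE above, the pass can
// be exercised on its own with e.g. `fir-opt --stack-arrays`.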

void StackArraysPass::runOnOperation() {
  mlir::func::FuncOp func = getOperation();

  auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
  const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
      analysis.getCandidateOps(func);
  if (!candidateOps) {
    signalPassFailure();
    return;
  }

  if (candidateOps->empty())
    return;
  runCount += candidateOps->size();

  llvm::SmallVector<mlir::Operation *> opsToConvert;
  opsToConvert.reserve(candidateOps->size());
  for (auto [op, _] : *candidateOps)
    opsToConvert.push_back(op);

  mlir::MLIRContext &context = getContext();
  mlir::RewritePatternSet patterns(&context);
  mlir::GreedyRewriteConfig config;
  // prevent the pattern driver from merging blocks
  config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;

  patterns.insert<AllocMemConversion>(&context, *candidateOps);
  if (mlir::failed(mlir::applyOpPatternsGreedily(
          opsToConvert, std::move(patterns), config))) {
    mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
    signalPassFailure();
  }
}