//===- StackArrays.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

namespace fir {
#define GEN_PASS_DEF_STACKARRAYS
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

#define DEBUG_TYPE "stack-arrays"

static llvm::cl::opt<std::size_t> maxAllocsPerFunc(
    "stack-arrays-max-allocs",
    llvm::cl::desc("The maximum number of heap allocations to consider in one "
                   "function before skipping (to save compilation time). Set "
                   "to 0 for no limit."),
    llvm::cl::init(1000), llvm::cl::Hidden);

namespace {

/// The state of an SSA value at each program point
enum class AllocationState {
  /// This means that the allocation state of a variable cannot be determined
  /// at this program point, e.g. because one route through a conditional freed
  /// the variable and the other route didn't.
  /// This asserts a known-unknown: different from the unknown-unknown of
  /// having no AllocationState stored for a particular SSA value
  Unknown,
  /// Means this SSA value was allocated on the heap in this function and has
  /// now been freed
  Freed,
  /// Means this SSA value was allocated on the heap in this function and is a
  /// candidate for moving to the stack
  Allocated,
};

/// Stores where an alloca should be inserted. If the PointerUnion is an
/// Operation the alloca should be inserted /after/ the operation. If it is a
/// block, the alloca can be placed anywhere in that block.
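///
/// For example (a sketch of the two cases): if the fir.allocmem's shape
/// operands are computed by an earlier operation, the InsertionPoint wraps
/// that operation and the new fir.alloca is created immediately after it;
/// if the allocation has no operands, the InsertionPoint wraps the function
/// entry block and the fir.alloca can go at its start.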
class InsertionPoint {
  llvm::PointerUnion<mlir::Operation *, mlir::Block *> location;
  bool saveRestoreStack;

  /// Get contained pointer type or nullptr
  template <class T>
  T *tryGetPtr() const {
    if (location.is<T *>())
      return location.get<T *>();
    return nullptr;
  }

public:
  template <class T>
  InsertionPoint(T *ptr, bool saveRestoreStack = false)
      : location(ptr), saveRestoreStack{saveRestoreStack} {}
  InsertionPoint(std::nullptr_t null)
      : location(null), saveRestoreStack{false} {}

  /// Get contained operation, or nullptr
  mlir::Operation *tryGetOperation() const {
    return tryGetPtr<mlir::Operation>();
  }

  /// Get contained block, or nullptr
  mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); }

  /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave
  /// intrinsic should be added before the alloca, and an llvm.stackrestore
  /// intrinsic should be added where the freemem is
  bool shouldSaveRestoreStack() const { return saveRestoreStack; }

  operator bool() const { return tryGetOperation() || tryGetBlock(); }

  bool operator==(const InsertionPoint &rhs) const {
    return (location == rhs.location) &&
           (saveRestoreStack == rhs.saveRestoreStack);
  }

  bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); }
};

/// Maps SSA values to their AllocationState at a particular program point.
/// Also caches the insertion points for the new alloca operations
class LatticePoint : public mlir::dataflow::AbstractDenseLattice {
  // Maps all values we are interested in to states
  llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap;

public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint)
  using AbstractDenseLattice::AbstractDenseLattice;

  bool operator==(const LatticePoint &rhs) const {
    return stateMap == rhs.stateMap;
  }

  /// Join the lattice across control-flow edges
  mlir::ChangeResult join(const AbstractDenseLattice &lattice) override;

  void print(llvm::raw_ostream &os) const override;

  /// Clear all modifications
  mlir::ChangeResult reset();

  /// Set the state of an SSA value
  mlir::ChangeResult set(mlir::Value value, AllocationState state);

  /// Get fir.allocmem ops which were allocated in this function and always
  /// freed before the function returns, plus where to insert replacement
  /// fir.alloca ops
  void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const;

  std::optional<AllocationState> get(mlir::Value val) const;
};

class AllocationAnalysis
    : public mlir::dataflow::DenseForwardDataFlowAnalysis<LatticePoint> {
public:
  using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;

  mlir::LogicalResult visitOperation(mlir::Operation *op,
                                     const LatticePoint &before,
                                     LatticePoint *after) override;

  /// At an entry point, the last modifications of all memory resources are
  /// yet to be determined
  void setToEntryState(LatticePoint *lattice) override;

protected:
  /// Visit control flow operations and decide whether to call visitOperation
  /// to apply the transfer function
  mlir::LogicalResult processOperation(mlir::Operation *op) override;
};
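
// Typical use of the wrapper below from within a pass (a sketch mirroring
// runOnOperation at the end of this file):
//   auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
//   if (const auto *candidates = analysis.getCandidateOps(func))
//     ...; // maps each fir.allocmem to its alloca insertion point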

/// Drives analysis to find candidate fir.allocmem operations which could be
/// moved to the stack. Intended to be used with mlir::Pass::getAnalysis
class StackArraysAnalysisWrapper {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper)

  // Maps fir.allocmem -> place to insert alloca
  using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>;

  StackArraysAnalysisWrapper(mlir::Operation *op) {}

  // returns nullptr if analysis failed
  const AllocMemMap *getCandidateOps(mlir::Operation *func);

private:
  llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;

  llvm::LogicalResult analyseFunction(mlir::Operation *func);
};

/// Converts a fir.allocmem to a fir.alloca
class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
public:
  explicit AllocMemConversion(
      mlir::MLIRContext *ctx,
      const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
      : OpRewritePattern(ctx), candidateOps{candidateOps} {}

  llvm::LogicalResult
  matchAndRewrite(fir::AllocMemOp allocmem,
                  mlir::PatternRewriter &rewriter) const override;

  /// Determine where to insert the alloca operation. The returned value
  /// should be checked to see if it is inside a loop
  static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc);

private:
  /// Handle to the DFA (already run)
  const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;

  /// If we failed to find an insertion point not inside a loop, see if it
  /// would be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
  static InsertionPoint findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc);

  /// Returns the alloca if it was successfully inserted, otherwise {}
  std::optional<fir::AllocaOp>
  insertAlloca(fir::AllocMemOp &oldAlloc,
               mlir::PatternRewriter &rewriter) const;

  /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem
  void insertStackSaveRestore(fir::AllocMemOp &oldAlloc,
                              mlir::PatternRewriter &rewriter) const;
};

class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> {
public:
  StackArraysPass() = default;
  StackArraysPass(const StackArraysPass &pass);

  llvm::StringRef getDescription() const override;

  void runOnOperation() override;

private:
  Statistic runCount{this, "stackArraysRunCount",
                     "Number of heap allocations moved to the stack"};
};

} // namespace

static void print(llvm::raw_ostream &os, AllocationState state) {
  switch (state) {
  case AllocationState::Unknown:
    os << "Unknown";
    break;
  case AllocationState::Freed:
    os << "Freed";
    break;
  case AllocationState::Allocated:
    os << "Allocated";
    break;
  }
}

/// Join two AllocationStates for the same value coming from different CFG
/// blocks
static AllocationState join(AllocationState lhs, AllocationState rhs) {
  //           | Allocated | Freed     | Unknown
  // ========= | ========= | ========= | =========
  // Allocated | Allocated | Unknown   | Unknown
  // Freed     | Unknown   | Freed     | Unknown
  // Unknown   | Unknown   | Unknown   | Unknown
  if (lhs == rhs)
    return lhs;
  return AllocationState::Unknown;
}
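
// A worked example of the join: if the then-branch of a conditional frees an
// allocation but the else-branch does not, the states Freed and Allocated
// meet at the merge point and join to Unknown, so the allocation is not
// treated as a candidate for moving to the stack.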

mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) {
  const auto &rhs = static_cast<const LatticePoint &>(lattice);
  mlir::ChangeResult changed = mlir::ChangeResult::NoChange;

  // add everything from rhs to map, handling cases where values are in both
  for (const auto &[value, rhsState] : rhs.stateMap) {
    auto it = stateMap.find(value);
    if (it != stateMap.end()) {
      // value is present in both maps
      AllocationState myState = it->second;
      AllocationState newState = ::join(myState, rhsState);
      if (newState != myState) {
        changed = mlir::ChangeResult::Change;
        it->getSecond() = newState;
      }
    } else {
      // value not present in current map: add it
      stateMap.insert({value, rhsState});
      changed = mlir::ChangeResult::Change;
    }
  }

  return changed;
}

void LatticePoint::print(llvm::raw_ostream &os) const {
  for (const auto &[value, state] : stateMap) {
    os << "\n * " << value << ": ";
    ::print(os, state);
  }
}

mlir::ChangeResult LatticePoint::reset() {
  if (stateMap.empty())
    return mlir::ChangeResult::NoChange;
  stateMap.clear();
  return mlir::ChangeResult::Change;
}

mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) {
  if (stateMap.count(value)) {
    // already in map
    AllocationState &oldState = stateMap[value];
    if (oldState != state) {
      stateMap[value] = state;
      return mlir::ChangeResult::Change;
    }
    return mlir::ChangeResult::NoChange;
  }
  stateMap.insert({value, state});
  return mlir::ChangeResult::Change;
}

/// Get values which were allocated in this function and always freed before
/// the function returns
void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const {
  for (auto &[value, state] : stateMap) {
    if (state == AllocationState::Freed)
      out.insert(value);
  }
}

std::optional<AllocationState> LatticePoint::get(mlir::Value val) const {
  auto it = stateMap.find(val);
  if (it == stateMap.end())
    return {};
  return it->second;
}

mlir::LogicalResult AllocationAnalysis::visitOperation(
    mlir::Operation *op, const LatticePoint &before, LatticePoint *after) {
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
                          << "\n");
  LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n");

  // propagate before -> after
  mlir::ChangeResult changed = after->join(before);

  if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) {
    assert(op->getNumResults() == 1 && "fir.allocmem has one result");
    auto attr = op->getAttrOfType<fir::MustBeHeapAttr>(
        fir::MustBeHeapAttr::getAttrName());
    if (attr && attr.getValue()) {
      LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
      // skip allocation marked not to be moved
      return mlir::success();
    }

    auto retTy = allocmem.getAllocatedType();
    if (!mlir::isa<fir::SequenceType>(retTy)) {
      LLVM_DEBUG(llvm::dbgs()
                 << "--Allocation is not for an array: skipping\n");
      return mlir::success();
    }

    mlir::Value result = op->getResult(0);
    changed |= after->set(result, AllocationState::Allocated);
  } else if (mlir::isa<fir::FreeMemOp>(op)) {
    assert(op->getNumOperands() == 1 && "fir.freemem has one operand");
    mlir::Value operand = op->getOperand(0);

    // Note: StackArrays is scheduled in the pass pipeline after lowering hlfir
    // to fir. Therefore, we only need to handle `fir::DeclareOp`s.
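    // e.g. (a sketch of the relevant FIR): for
    //   %1 = fir.allocmem !fir.array<?xi32>, %n
    //   %2 = fir.declare %1(%shape) {uniq_name = "..."} : ...
    //   fir.freemem %2 : !fir.heap<!fir.array<?xi32>>
    // the freed value we must mark is %1 (the fir.allocmem result), which is
    // found by looking through the fir.declare.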
    if (auto declareOp =
            llvm::dyn_cast_if_present<fir::DeclareOp>(operand.getDefiningOp()))
      operand = declareOp.getMemref();

    std::optional<AllocationState> operandState = before.get(operand);
    if (operandState && *operandState == AllocationState::Allocated) {
      // don't tag things not allocated in this function as freed, so that we
      // don't think they are candidates for moving to the stack
      changed |= after->set(operand, AllocationState::Freed);
    }
  } else if (mlir::isa<fir::ResultOp>(op)) {
    mlir::Operation *parent = op->getParentOp();
    LatticePoint *parentLattice = getLattice(getProgramPointAfter(parent));
    assert(parentLattice);
    mlir::ChangeResult parentChanged = parentLattice->join(*after);
    propagateIfChanged(parentLattice, parentChanged);
  }

  // we pass lattices straight through fir.call because called functions should
  // not deallocate flang-generated array temporaries

  LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n");
  propagateIfChanged(after, changed);
  return mlir::success();
}

void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
  propagateIfChanged(lattice, lattice->reset());
}

/// Mostly a copy of AbstractDenseForwardDataFlowAnalysis::processOperation -
/// the difference being that call operations are passed through to the
/// transfer function
mlir::LogicalResult AllocationAnalysis::processOperation(mlir::Operation *op) {
  mlir::ProgramPoint *point = getProgramPointAfter(op);
  // If the containing block is not executable, bail out.
  if (op->getBlock() != nullptr &&
      !getOrCreateFor<mlir::dataflow::Executable>(
           point, getProgramPointBefore(op->getBlock()))
           ->isLive())
    return mlir::success();

  // Get the dense lattice to update
  mlir::dataflow::AbstractDenseLattice *after = getLattice(point);

  // If this op implements region control-flow, then control-flow dictates its
  // transfer function.
  if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
    visitRegionBranchOperation(point, branch, after);
    return mlir::success();
  }

  // pass call operations through to the transfer function

  // Get the dense state before the execution of the op.
  const mlir::dataflow::AbstractDenseLattice *before =
      getLatticeFor(point, getProgramPointBefore(op));

  /// Invoke the operation transfer function
  return visitOperationImpl(op, *before, after);
}

llvm::LogicalResult
StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
  assert(mlir::isa<mlir::func::FuncOp>(func));
  size_t nAllocs = 0;
  func->walk([&nAllocs](fir::AllocMemOp) { nAllocs++; });
  // don't bother with the analysis if there are no heap allocations
  if (nAllocs == 0)
    return mlir::success();
  if ((maxAllocsPerFunc != 0) && (nAllocs > maxAllocsPerFunc)) {
    LLVM_DEBUG(llvm::dbgs() << "Skipping stack arrays for function with "
                            << nAllocs << " heap allocations");
    return mlir::success();
  }

  mlir::DataFlowSolver solver;
  // constant propagation is required for dead code analysis, and dead code
  // analysis is required to mark blocks live (needed for the MLIR dense DFA)
  solver.load<mlir::dataflow::SparseConstantPropagation>();
  solver.load<mlir::dataflow::DeadCodeAnalysis>();

  auto [it, inserted] = funcMaps.try_emplace(func);
  AllocMemMap &candidateOps = it->second;

  solver.load<AllocationAnalysis>();
  if (failed(solver.initializeAndRun(func))) {
    llvm::errs() << "DataFlowSolver failed!";
    return mlir::failure();
  }

  LatticePoint point{solver.getProgramPointAfter(func)};
  auto joinOperationLattice = [&](mlir::Operation *op) {
    const LatticePoint *lattice =
        solver.lookupState<LatticePoint>(solver.getProgramPointAfter(op));
    // there will be no lattice for an unreachable block
    if (lattice)
      (void)point.join(*lattice);
  };

  func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); });
  func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); });
  func->walk(
      [&](mlir::omp::TerminatorOp child) { joinOperationLattice(child); });
  func->walk([&](mlir::omp::YieldOp child) { joinOperationLattice(child); });

  llvm::DenseSet<mlir::Value> freedValues;
  point.appendFreedValues(freedValues);

  // We only replace allocations which are definitely freed on all routes
  // through the function because otherwise the allocation may have an intended
  // lifetime longer than the current stack frame (e.g. a heap allocation which
  // is then freed by another function).
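  // e.g. (a sketch): a flang-generated temporary like
  //   %1 = fir.allocmem !fir.array<42xi32>
  //   ...
  //   fir.freemem %1 : !fir.heap<!fir.array<42xi32>>
  // reaches every function exit in state Freed, so it is a candidate; an
  // allocation that escapes stays Allocated (or joins to Unknown) and is
  // left on the heap.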
  for (mlir::Value freedValue : freedValues) {
    fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>();
    InsertionPoint insertionPoint =
        AllocMemConversion::findAllocaInsertionPoint(allocmem);
    if (insertionPoint)
      candidateOps.insert({allocmem, insertionPoint});
  }

  LLVM_DEBUG(for (auto [allocMemOp, _]
                  : candidateOps) {
    llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
  });
  return mlir::success();
}

const StackArraysAnalysisWrapper::AllocMemMap *
StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
  if (!funcMaps.contains(func))
    if (mlir::failed(analyseFunction(func)))
      return nullptr;
  return &funcMaps[func];
}

/// Restore the old allocation type expected by existing code
static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter,
                                         const mlir::Location &loc,
                                         mlir::Value heap, mlir::Value stack) {
  mlir::Type heapTy = heap.getType();
  mlir::Type stackTy = stack.getType();

  if (heapTy == stackTy)
    return stack;

  fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy);
  LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy =
      mlir::cast<fir::ReferenceType>(stackTy);
  assert(firHeapTy.getElementType() == firRefTy.getElementType() &&
         "Allocations must have the same type");

  auto insertionPoint = rewriter.saveInsertionPoint();
  rewriter.setInsertionPointAfter(stack.getDefiningOp());
  mlir::Value conv =
      rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult();
  rewriter.restoreInsertionPoint(insertionPoint);
  return conv;
}

llvm::LogicalResult
AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
                                    mlir::PatternRewriter &rewriter) const {
  auto oldInsertionPt = rewriter.saveInsertionPoint();
  // add alloca operation
  std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter);
  rewriter.restoreInsertionPoint(oldInsertionPt);
  if (!alloca)
    return mlir::failure();

  // remove freemem operations
  llvm::SmallVector<mlir::Operation *> erases;
  for (mlir::Operation *user : allocmem.getOperation()->getUsers()) {
    if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
      for (mlir::Operation *user : declareOp->getUsers()) {
        if (mlir::isa<fir::FreeMemOp>(user))
          erases.push_back(user);
      }
    }

    if (mlir::isa<fir::FreeMemOp>(user))
      erases.push_back(user);
  }

  // now we are done iterating the users, it is safe to mutate them
  for (mlir::Operation *erase : erases)
    rewriter.eraseOp(erase);

  // replace references to heap allocation with references to stack allocation
  mlir::Value newValue = convertAllocationType(
      rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult());
  rewriter.replaceAllUsesWith(allocmem.getResult(), newValue);

  // remove allocmem operation
  rewriter.eraseOp(allocmem.getOperation());

  return mlir::success();
}

static bool isInLoop(mlir::Block *block) {
  return mlir::LoopLikeOpInterface::blockIsInLoop(block);
}

static bool isInLoop(mlir::Operation *op) {
  return isInLoop(op->getBlock()) ||
         op->getParentOfType<mlir::LoopLikeOpInterface>();
}
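
// A worked example (a sketch) of the placement chosen below: for
//   %c = arith.constant 0 : index
//   %n = fir.load %ref : !fir.ref<index>
//   %1 = fir.allocmem !fir.array<?xi32>, %n
// the replacement fir.alloca is inserted immediately after the fir.load, the
// earliest point at which all of its operands are available; if the
// allocation has no operands it can instead go in the function entry block
// (or the OpenMP region's alloca block).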

InsertionPoint
AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) {
  // Ideally the alloca should be inserted at the end of the function entry
  // block so that we do not allocate stack space in a loop. However, the
  // operands to the alloca may not be available that early, so insert it
  // after the last operand becomes available.
  // If the old allocmem op was in an openmp region then it should not be
  // moved outside of that region.
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: "
                          << oldAlloc << "\n");

  // check that an Operation or Block we are about to return is not in a loop
  auto checkReturn = [&](auto *point) -> InsertionPoint {
    if (isInLoop(point)) {
      mlir::Operation *oldAllocOp = oldAlloc.getOperation();
      if (isInLoop(oldAllocOp)) {
        // where we want to put it is in a loop, and even the old location is
        // in a loop. Fall back to stack save/restore inside the loop.
        return findAllocaLoopInsertionPoint(oldAlloc);
      }
      return {oldAllocOp};
    }
    return {point};
  };

  auto oldOmpRegion =
      oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();

  // Find when the last operand value becomes available
  mlir::Block *operandsBlock = nullptr;
  mlir::Operation *lastOperand = nullptr;
  for (mlir::Value operand : oldAlloc.getOperands()) {
    LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n");
    mlir::Operation *op = operand.getDefiningOp();
    if (!op)
      return checkReturn(oldAlloc.getOperation());
    if (!operandsBlock)
      operandsBlock = op->getBlock();
    else if (operandsBlock != op->getBlock()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "----operand declared in a different block!\n");
      // Operation::isBeforeInBlock requires the operations to be in the same
      // block. The best we can do is the location of the allocmem.
      return checkReturn(oldAlloc.getOperation());
    }
    if (!lastOperand || lastOperand->isBeforeInBlock(op))
      lastOperand = op;
  }

  if (lastOperand) {
    // there were value operands to the allocmem so insert after the last one
    LLVM_DEBUG(llvm::dbgs()
               << "--Placing after last operand: " << *lastOperand << "\n");
    // check we aren't moving out of an omp region
    auto lastOpOmpRegion =
        lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
    if (lastOpOmpRegion == oldOmpRegion)
      return checkReturn(lastOperand);
    // Presumably this happened because the operands became ready before the
    // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should
    // imply that oldOmpRegion comes after lastOpOmpRegion.
    return checkReturn(oldOmpRegion.getAllocaBlock());
  }

  // There were no value operands to the allocmem so we are safe to insert it
  // as early as we want

  // handle openmp case
  if (oldOmpRegion)
    return checkReturn(oldOmpRegion.getAllocaBlock());

  // fall back to the function entry block
  mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>();
  assert(func && "This analysis is run on func.func");
  mlir::Block &entryBlock = func.getBlocks().front();
  LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n");
  return checkReturn(&entryBlock);
}
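
// When the only viable insertion point is inside a loop, the rewrite brackets
// the allocation with a stack save/restore pair so each iteration releases
// its stack memory. Roughly (a sketch, not verbatim compiler output):
//   %sp = llvm.intr.stacksave : !llvm.ptr
//   %1 = fir.alloca !fir.array<?xi32>, %n
//   ...
//   llvm.intr.stackrestore %sp : !llvm.ptr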

InsertionPoint
AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) {
  mlir::Operation *oldAllocOp = oldAlloc;
  // This is only called as a last resort. We should try to insert at the
  // location of the old allocation, which is inside of a loop, using
  // llvm.stacksave/llvm.stackrestore

  // find freemem ops
  llvm::SmallVector<mlir::Operation *, 1> freeOps;

  for (mlir::Operation *user : oldAllocOp->getUsers()) {
    if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
      for (mlir::Operation *user : declareOp->getUsers()) {
        if (mlir::isa<fir::FreeMemOp>(user))
          freeOps.push_back(user);
      }
    }

    if (mlir::isa<fir::FreeMemOp>(user))
      freeOps.push_back(user);
  }

  assert(freeOps.size() && "DFA should only return freed memory");

  // Don't attempt to reason about a stacksave/stackrestore between different
  // blocks
  for (mlir::Operation *free : freeOps)
    if (free->getBlock() != oldAllocOp->getBlock())
      return {nullptr};

  // Check that there aren't any other stack allocations in between the
  // stack save and stack restore
  // note: for flang generated temporaries there should only be one free op
  for (mlir::Operation *free : freeOps) {
    for (mlir::Operation *op = oldAlloc; op && op != free;
         op = op->getNextNode()) {
      if (mlir::isa<fir::AllocaOp>(op))
        return {nullptr};
    }
  }

  return InsertionPoint{oldAllocOp, /*saveRestoreStack=*/true};
}

std::optional<fir::AllocaOp>
AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc,
                                 mlir::PatternRewriter &rewriter) const {
  auto it = candidateOps.find(oldAlloc.getOperation());
  if (it == candidateOps.end())
    return {};
  InsertionPoint insertionPoint = it->second;
  if (!insertionPoint)
    return {};

  if (insertionPoint.shouldSaveRestoreStack())
    insertStackSaveRestore(oldAlloc, rewriter);

  mlir::Location loc = oldAlloc.getLoc();
  mlir::Type varTy = oldAlloc.getInType();
  if (mlir::Operation *op = insertionPoint.tryGetOperation()) {
    rewriter.setInsertionPointAfter(op);
  } else {
    mlir::Block *block = insertionPoint.tryGetBlock();
    assert(block && "There must be a valid insertion point");
    rewriter.setInsertionPointToStart(block);
  }

  auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef {
    if (opt)
      return *opt;
    return {};
  };

  llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName());
  llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName());
  return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName,
                                        oldAlloc.getTypeparams(),
                                        oldAlloc.getShape());
}

void AllocMemConversion::insertStackSaveRestore(
    fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const {
  auto oldPoint = rewriter.saveInsertionPoint();
  auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>();
  fir::FirOpBuilder builder{rewriter, mod};

  builder.setInsertionPoint(oldAlloc);
  mlir::Value sp = builder.genStackSave(oldAlloc.getLoc());

  auto createStackRestoreCall = [&](mlir::Operation *user) {
    builder.setInsertionPoint(user);
    builder.genStackRestore(user->getLoc(), sp);
  };

  for (mlir::Operation *user : oldAlloc->getUsers()) {
    if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
      for (mlir::Operation *user : declareOp->getUsers()) {
        if (mlir::isa<fir::FreeMemOp>(user))
          createStackRestoreCall(user);
      }
    }

    if (mlir::isa<fir::FreeMemOp>(user)) {
      createStackRestoreCall(user);
    }
  }

  rewriter.restoreInsertionPoint(oldPoint);
}
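
// The pass below can be exercised in isolation with the usual FIR opt tool,
// e.g. `fir-opt --stack-arrays input.fir` (assuming the standard pass
// registration generated from Passes.td via Passes.h.inc above).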

StackArraysPass::StackArraysPass(const StackArraysPass &pass)
    : fir::impl::StackArraysBase<StackArraysPass>(pass) {}

llvm::StringRef StackArraysPass::getDescription() const {
  return "Move heap allocated array temporaries to the stack";
}

void StackArraysPass::runOnOperation() {
  mlir::func::FuncOp func = getOperation();

  auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
  const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
      analysis.getCandidateOps(func);
  if (!candidateOps) {
    signalPassFailure();
    return;
  }

  if (candidateOps->empty())
    return;
  runCount += candidateOps->size();

  llvm::SmallVector<mlir::Operation *> opsToConvert;
  opsToConvert.reserve(candidateOps->size());
  for (auto [op, _] : *candidateOps)
    opsToConvert.push_back(op);

  mlir::MLIRContext &context = getContext();
  mlir::RewritePatternSet patterns(&context);
  mlir::GreedyRewriteConfig config;
  // prevent the pattern driver from merging blocks
  config.enableRegionSimplification =
      mlir::GreedySimplifyRegionLevel::Disabled;

  patterns.insert<AllocMemConversion>(&context, *candidateOps);
  if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
                                                std::move(patterns), config))) {
    mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
    signalPassFailure();
  }
}