1 //===- StackArrays.cpp ----------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/FIRBuilder.h" 10 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" 11 #include "flang/Optimizer/Dialect/FIRAttr.h" 12 #include "flang/Optimizer/Dialect/FIRDialect.h" 13 #include "flang/Optimizer/Dialect/FIROps.h" 14 #include "flang/Optimizer/Dialect/FIRType.h" 15 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 16 #include "flang/Optimizer/Transforms/Passes.h" 17 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 18 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 19 #include "mlir/Analysis/DataFlow/DenseAnalysis.h" 20 #include "mlir/Analysis/DataFlowFramework.h" 21 #include "mlir/Dialect/Func/IR/FuncOps.h" 22 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 23 #include "mlir/IR/Builders.h" 24 #include "mlir/IR/Diagnostics.h" 25 #include "mlir/IR/Value.h" 26 #include "mlir/Interfaces/LoopLikeInterface.h" 27 #include "mlir/Pass/Pass.h" 28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 29 #include "mlir/Transforms/Passes.h" 30 #include "llvm/ADT/DenseMap.h" 31 #include "llvm/ADT/DenseSet.h" 32 #include "llvm/ADT/PointerUnion.h" 33 #include "llvm/Support/Casting.h" 34 #include "llvm/Support/raw_ostream.h" 35 #include <optional> 36 37 namespace fir { 38 #define GEN_PASS_DEF_STACKARRAYS 39 #include "flang/Optimizer/Transforms/Passes.h.inc" 40 } // namespace fir 41 42 #define DEBUG_TYPE "stack-arrays" 43 44 static llvm::cl::opt<std::size_t> maxAllocsPerFunc( 45 "stack-arrays-max-allocs", 46 llvm::cl::desc("The maximum number of heap allocations to consider in one " 47 "function before skipping (to save compilation time). Set " 48 "to 0 for no limit."), 49 llvm::cl::init(1000), llvm::cl::Hidden); 50 51 namespace { 52 53 /// The state of an SSA value at each program point 54 enum class AllocationState { 55 /// This means that the allocation state of a variable cannot be determined 56 /// at this program point, e.g. because one route through a conditional freed 57 /// the variable and the other route didn't. 58 /// This asserts a known-unknown: different from the unknown-unknown of having 59 /// no AllocationState stored for a particular SSA value 60 Unknown, 61 /// Means this SSA value was allocated on the heap in this function and has 62 /// now been freed 63 Freed, 64 /// Means this SSA value was allocated on the heap in this function and is a 65 /// candidate for moving to the stack 66 Allocated, 67 }; 68 69 /// Stores where an alloca should be inserted. If the PointerUnion is an 70 /// Operation the alloca should be inserted /after/ the operation. If it is a 71 /// block, the alloca can be placed anywhere in that block. 72 class InsertionPoint { 73 llvm::PointerUnion<mlir::Operation *, mlir::Block *> location; 74 bool saveRestoreStack; 75 76 /// Get contained pointer type or nullptr 77 template <class T> 78 T *tryGetPtr() const { 79 if (location.is<T *>()) 80 return location.get<T *>(); 81 return nullptr; 82 } 83 84 public: 85 template <class T> 86 InsertionPoint(T *ptr, bool saveRestoreStack = false) 87 : location(ptr), saveRestoreStack{saveRestoreStack} {} 88 InsertionPoint(std::nullptr_t null) 89 : location(null), saveRestoreStack{false} {} 90 91 /// Get contained operation, or nullptr 92 mlir::Operation *tryGetOperation() const { 93 return tryGetPtr<mlir::Operation>(); 94 } 95 96 /// Get contained block, or nullptr 97 mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); } 98 99 /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave 100 /// intrinsic should be added before the alloca, and an llvm.stackrestore 101 /// intrinsic should be added where the freemem is 102 bool shouldSaveRestoreStack() const { return saveRestoreStack; } 103 104 operator bool() const { return tryGetOperation() || tryGetBlock(); } 105 106 bool operator==(const InsertionPoint &rhs) const { 107 return (location == rhs.location) && 108 (saveRestoreStack == rhs.saveRestoreStack); 109 } 110 111 bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); } 112 }; 113 114 /// Maps SSA values to their AllocationState at a particular program point. 115 /// Also caches the insertion points for the new alloca operations 116 class LatticePoint : public mlir::dataflow::AbstractDenseLattice { 117 // Maps all values we are interested in to states 118 llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap; 119 120 public: 121 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint) 122 using AbstractDenseLattice::AbstractDenseLattice; 123 124 bool operator==(const LatticePoint &rhs) const { 125 return stateMap == rhs.stateMap; 126 } 127 128 /// Join the lattice accross control-flow edges 129 mlir::ChangeResult join(const AbstractDenseLattice &lattice) override; 130 131 void print(llvm::raw_ostream &os) const override; 132 133 /// Clear all modifications 134 mlir::ChangeResult reset(); 135 136 /// Set the state of an SSA value 137 mlir::ChangeResult set(mlir::Value value, AllocationState state); 138 139 /// Get fir.allocmem ops which were allocated in this function and always 140 /// freed before the function returns, plus whre to insert replacement 141 /// fir.alloca ops 142 void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const; 143 144 std::optional<AllocationState> get(mlir::Value val) const; 145 }; 146 147 class AllocationAnalysis 148 : public mlir::dataflow::DenseForwardDataFlowAnalysis<LatticePoint> { 149 public: 150 using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis; 151 152 void visitOperation(mlir::Operation *op, const LatticePoint &before, 153 LatticePoint *after) override; 154 155 /// At an entry point, the last modifications of all memory resources are 156 /// yet to be determined 157 void setToEntryState(LatticePoint *lattice) override; 158 159 protected: 160 /// Visit control flow operations and decide whether to call visitOperation 161 /// to apply the transfer function 162 void processOperation(mlir::Operation *op) override; 163 }; 164 165 /// Drives analysis to find candidate fir.allocmem operations which could be 166 /// moved to the stack. Intended to be used with mlir::Pass::getAnalysis 167 class StackArraysAnalysisWrapper { 168 public: 169 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper) 170 171 // Maps fir.allocmem -> place to insert alloca 172 using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>; 173 174 StackArraysAnalysisWrapper(mlir::Operation *op) {} 175 176 // returns nullptr if analysis failed 177 const AllocMemMap *getCandidateOps(mlir::Operation *func); 178 179 private: 180 llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps; 181 182 llvm::LogicalResult analyseFunction(mlir::Operation *func); 183 }; 184 185 /// Converts a fir.allocmem to a fir.alloca 186 class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> { 187 public: 188 explicit AllocMemConversion( 189 mlir::MLIRContext *ctx, 190 const StackArraysAnalysisWrapper::AllocMemMap &candidateOps) 191 : OpRewritePattern(ctx), candidateOps{candidateOps} {} 192 193 llvm::LogicalResult 194 matchAndRewrite(fir::AllocMemOp allocmem, 195 mlir::PatternRewriter &rewriter) const override; 196 197 /// Determine where to insert the alloca operation. The returned value should 198 /// be checked to see if it is inside a loop 199 static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc); 200 201 private: 202 /// Handle to the DFA (already run) 203 const StackArraysAnalysisWrapper::AllocMemMap &candidateOps; 204 205 /// If we failed to find an insertion point not inside a loop, see if it would 206 /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop 207 static InsertionPoint findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc); 208 209 /// Returns the alloca if it was successfully inserted, otherwise {} 210 std::optional<fir::AllocaOp> 211 insertAlloca(fir::AllocMemOp &oldAlloc, 212 mlir::PatternRewriter &rewriter) const; 213 214 /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem 215 void insertStackSaveRestore(fir::AllocMemOp &oldAlloc, 216 mlir::PatternRewriter &rewriter) const; 217 }; 218 219 class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> { 220 public: 221 StackArraysPass() = default; 222 StackArraysPass(const StackArraysPass &pass); 223 224 llvm::StringRef getDescription() const override; 225 226 void runOnOperation() override; 227 void runOnFunc(mlir::Operation *func); 228 229 private: 230 Statistic runCount{this, "stackArraysRunCount", 231 "Number of heap allocations moved to the stack"}; 232 }; 233 234 } // namespace 235 236 static void print(llvm::raw_ostream &os, AllocationState state) { 237 switch (state) { 238 case AllocationState::Unknown: 239 os << "Unknown"; 240 break; 241 case AllocationState::Freed: 242 os << "Freed"; 243 break; 244 case AllocationState::Allocated: 245 os << "Allocated"; 246 break; 247 } 248 } 249 250 /// Join two AllocationStates for the same value coming from different CFG 251 /// blocks 252 static AllocationState join(AllocationState lhs, AllocationState rhs) { 253 // | Allocated | Freed | Unknown 254 // ========= | ========= | ========= | ========= 255 // Allocated | Allocated | Unknown | Unknown 256 // Freed | Unknown | Freed | Unknown 257 // Unknown | Unknown | Unknown | Unknown 258 if (lhs == rhs) 259 return lhs; 260 return AllocationState::Unknown; 261 } 262 263 mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) { 264 const auto &rhs = static_cast<const LatticePoint &>(lattice); 265 mlir::ChangeResult changed = mlir::ChangeResult::NoChange; 266 267 // add everything from rhs to map, handling cases where values are in both 268 for (const auto &[value, rhsState] : rhs.stateMap) { 269 auto it = stateMap.find(value); 270 if (it != stateMap.end()) { 271 // value is present in both maps 272 AllocationState myState = it->second; 273 AllocationState newState = ::join(myState, rhsState); 274 if (newState != myState) { 275 changed = mlir::ChangeResult::Change; 276 it->getSecond() = newState; 277 } 278 } else { 279 // value not present in current map: add it 280 stateMap.insert({value, rhsState}); 281 changed = mlir::ChangeResult::Change; 282 } 283 } 284 285 return changed; 286 } 287 288 void LatticePoint::print(llvm::raw_ostream &os) const { 289 for (const auto &[value, state] : stateMap) { 290 os << "\n * " << value << ": "; 291 ::print(os, state); 292 } 293 } 294 295 mlir::ChangeResult LatticePoint::reset() { 296 if (stateMap.empty()) 297 return mlir::ChangeResult::NoChange; 298 stateMap.clear(); 299 return mlir::ChangeResult::Change; 300 } 301 302 mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) { 303 if (stateMap.count(value)) { 304 // already in map 305 AllocationState &oldState = stateMap[value]; 306 if (oldState != state) { 307 stateMap[value] = state; 308 return mlir::ChangeResult::Change; 309 } 310 return mlir::ChangeResult::NoChange; 311 } 312 stateMap.insert({value, state}); 313 return mlir::ChangeResult::Change; 314 } 315 316 /// Get values which were allocated in this function and always freed before 317 /// the function returns 318 void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const { 319 for (auto &[value, state] : stateMap) { 320 if (state == AllocationState::Freed) 321 out.insert(value); 322 } 323 } 324 325 std::optional<AllocationState> LatticePoint::get(mlir::Value val) const { 326 auto it = stateMap.find(val); 327 if (it == stateMap.end()) 328 return {}; 329 return it->second; 330 } 331 332 void AllocationAnalysis::visitOperation(mlir::Operation *op, 333 const LatticePoint &before, 334 LatticePoint *after) { 335 LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op 336 << "\n"); 337 LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n"); 338 339 // propagate before -> after 340 mlir::ChangeResult changed = after->join(before); 341 342 if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) { 343 assert(op->getNumResults() == 1 && "fir.allocmem has one result"); 344 auto attr = op->getAttrOfType<fir::MustBeHeapAttr>( 345 fir::MustBeHeapAttr::getAttrName()); 346 if (attr && attr.getValue()) { 347 LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n"); 348 // skip allocation marked not to be moved 349 return; 350 } 351 352 auto retTy = allocmem.getAllocatedType(); 353 if (!mlir::isa<fir::SequenceType>(retTy)) { 354 LLVM_DEBUG(llvm::dbgs() 355 << "--Allocation is not for an array: skipping\n"); 356 return; 357 } 358 359 mlir::Value result = op->getResult(0); 360 changed |= after->set(result, AllocationState::Allocated); 361 } else if (mlir::isa<fir::FreeMemOp>(op)) { 362 assert(op->getNumOperands() == 1 && "fir.freemem has one operand"); 363 mlir::Value operand = op->getOperand(0); 364 365 // Note: StackArrays is scheduled in the pass pipeline after lowering hlfir 366 // to fir. Therefore, we only need to handle `fir::DeclareOp`s. 367 if (auto declareOp = 368 llvm::dyn_cast_if_present<fir::DeclareOp>(operand.getDefiningOp())) 369 operand = declareOp.getMemref(); 370 371 std::optional<AllocationState> operandState = before.get(operand); 372 if (operandState && *operandState == AllocationState::Allocated) { 373 // don't tag things not allocated in this function as freed, so that we 374 // don't think they are candidates for moving to the stack 375 changed |= after->set(operand, AllocationState::Freed); 376 } 377 } else if (mlir::isa<fir::ResultOp>(op)) { 378 mlir::Operation *parent = op->getParentOp(); 379 LatticePoint *parentLattice = getLattice(parent); 380 assert(parentLattice); 381 mlir::ChangeResult parentChanged = parentLattice->join(*after); 382 propagateIfChanged(parentLattice, parentChanged); 383 } 384 385 // we pass lattices straight through fir.call because called functions should 386 // not deallocate flang-generated array temporaries 387 388 LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n"); 389 propagateIfChanged(after, changed); 390 } 391 392 void AllocationAnalysis::setToEntryState(LatticePoint *lattice) { 393 propagateIfChanged(lattice, lattice->reset()); 394 } 395 396 /// Mostly a copy of AbstractDenseLattice::processOperation - the difference 397 /// being that call operations are passed through to the transfer function 398 void AllocationAnalysis::processOperation(mlir::Operation *op) { 399 // If the containing block is not executable, bail out. 400 if (!getOrCreateFor<mlir::dataflow::Executable>(op, op->getBlock())->isLive()) 401 return; 402 403 // Get the dense lattice to update 404 mlir::dataflow::AbstractDenseLattice *after = getLattice(op); 405 406 // If this op implements region control-flow, then control-flow dictates its 407 // transfer function. 408 if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) 409 return visitRegionBranchOperation(op, branch, after); 410 411 // pass call operations through to the transfer function 412 413 // Get the dense state before the execution of the op. 414 const mlir::dataflow::AbstractDenseLattice *before; 415 if (mlir::Operation *prev = op->getPrevNode()) 416 before = getLatticeFor(op, prev); 417 else 418 before = getLatticeFor(op, op->getBlock()); 419 420 /// Invoke the operation transfer function 421 visitOperationImpl(op, *before, after); 422 } 423 424 llvm::LogicalResult 425 StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) { 426 assert(mlir::isa<mlir::func::FuncOp>(func)); 427 size_t nAllocs = 0; 428 func->walk([&nAllocs](fir::AllocMemOp) { nAllocs++; }); 429 // don't bother with the analysis if there are no heap allocations 430 if (nAllocs == 0) 431 return mlir::success(); 432 if ((maxAllocsPerFunc != 0) && (nAllocs > maxAllocsPerFunc)) { 433 LLVM_DEBUG(llvm::dbgs() << "Skipping stack arrays for function with " 434 << nAllocs << " heap allocations"); 435 return mlir::success(); 436 } 437 438 mlir::DataFlowSolver solver; 439 // constant propagation is required for dead code analysis, dead code analysis 440 // is required to mark blocks live (required for mlir dense dfa) 441 solver.load<mlir::dataflow::SparseConstantPropagation>(); 442 solver.load<mlir::dataflow::DeadCodeAnalysis>(); 443 444 auto [it, inserted] = funcMaps.try_emplace(func); 445 AllocMemMap &candidateOps = it->second; 446 447 solver.load<AllocationAnalysis>(); 448 if (failed(solver.initializeAndRun(func))) { 449 llvm::errs() << "DataFlowSolver failed!"; 450 return mlir::failure(); 451 } 452 453 LatticePoint point{func}; 454 auto joinOperationLattice = [&](mlir::Operation *op) { 455 const LatticePoint *lattice = solver.lookupState<LatticePoint>(op); 456 // there will be no lattice for an unreachable block 457 if (lattice) 458 (void)point.join(*lattice); 459 }; 460 461 func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); }); 462 func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); }); 463 func->walk( 464 [&](mlir::omp::TerminatorOp child) { joinOperationLattice(child); }); 465 func->walk([&](mlir::omp::YieldOp child) { joinOperationLattice(child); }); 466 467 llvm::DenseSet<mlir::Value> freedValues; 468 point.appendFreedValues(freedValues); 469 470 // We only replace allocations which are definately freed on all routes 471 // through the function because otherwise the allocation may have an intende 472 // lifetime longer than the current stack frame (e.g. a heap allocation which 473 // is then freed by another function). 474 for (mlir::Value freedValue : freedValues) { 475 fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>(); 476 InsertionPoint insertionPoint = 477 AllocMemConversion::findAllocaInsertionPoint(allocmem); 478 if (insertionPoint) 479 candidateOps.insert({allocmem, insertionPoint}); 480 } 481 482 LLVM_DEBUG(for (auto [allocMemOp, _] 483 : candidateOps) { 484 llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n'; 485 }); 486 return mlir::success(); 487 } 488 489 const StackArraysAnalysisWrapper::AllocMemMap * 490 StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) { 491 if (!funcMaps.contains(func)) 492 if (mlir::failed(analyseFunction(func))) 493 return nullptr; 494 return &funcMaps[func]; 495 } 496 497 /// Restore the old allocation type exected by existing code 498 static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter, 499 const mlir::Location &loc, 500 mlir::Value heap, mlir::Value stack) { 501 mlir::Type heapTy = heap.getType(); 502 mlir::Type stackTy = stack.getType(); 503 504 if (heapTy == stackTy) 505 return stack; 506 507 fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy); 508 LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy = 509 mlir::cast<fir::ReferenceType>(stackTy); 510 assert(firHeapTy.getElementType() == firRefTy.getElementType() && 511 "Allocations must have the same type"); 512 513 auto insertionPoint = rewriter.saveInsertionPoint(); 514 rewriter.setInsertionPointAfter(stack.getDefiningOp()); 515 mlir::Value conv = 516 rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult(); 517 rewriter.restoreInsertionPoint(insertionPoint); 518 return conv; 519 } 520 521 llvm::LogicalResult 522 AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem, 523 mlir::PatternRewriter &rewriter) const { 524 auto oldInsertionPt = rewriter.saveInsertionPoint(); 525 // add alloca operation 526 std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter); 527 rewriter.restoreInsertionPoint(oldInsertionPt); 528 if (!alloca) 529 return mlir::failure(); 530 531 // remove freemem operations 532 llvm::SmallVector<mlir::Operation *> erases; 533 for (mlir::Operation *user : allocmem.getOperation()->getUsers()) { 534 if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) { 535 for (mlir::Operation *user : declareOp->getUsers()) { 536 if (mlir::isa<fir::FreeMemOp>(user)) 537 erases.push_back(user); 538 } 539 } 540 541 if (mlir::isa<fir::FreeMemOp>(user)) 542 erases.push_back(user); 543 } 544 545 // now we are done iterating the users, it is safe to mutate them 546 for (mlir::Operation *erase : erases) 547 rewriter.eraseOp(erase); 548 549 // replace references to heap allocation with references to stack allocation 550 mlir::Value newValue = convertAllocationType( 551 rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult()); 552 rewriter.replaceAllUsesWith(allocmem.getResult(), newValue); 553 554 // remove allocmem operation 555 rewriter.eraseOp(allocmem.getOperation()); 556 557 return mlir::success(); 558 } 559 560 static bool isInLoop(mlir::Block *block) { 561 return mlir::LoopLikeOpInterface::blockIsInLoop(block); 562 } 563 564 static bool isInLoop(mlir::Operation *op) { 565 return isInLoop(op->getBlock()) || 566 op->getParentOfType<mlir::LoopLikeOpInterface>(); 567 } 568 569 InsertionPoint 570 AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) { 571 // Ideally the alloca should be inserted at the end of the function entry 572 // block so that we do not allocate stack space in a loop. However, 573 // the operands to the alloca may not be available that early, so insert it 574 // after the last operand becomes available 575 // If the old allocmem op was in an openmp region then it should not be moved 576 // outside of that 577 LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: " 578 << oldAlloc << "\n"); 579 580 // check that an Operation or Block we are about to return is not in a loop 581 auto checkReturn = [&](auto *point) -> InsertionPoint { 582 if (isInLoop(point)) { 583 mlir::Operation *oldAllocOp = oldAlloc.getOperation(); 584 if (isInLoop(oldAllocOp)) { 585 // where we want to put it is in a loop, and even the old location is in 586 // a loop. Give up. 587 return findAllocaLoopInsertionPoint(oldAlloc); 588 } 589 return {oldAllocOp}; 590 } 591 return {point}; 592 }; 593 594 auto oldOmpRegion = 595 oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>(); 596 597 // Find when the last operand value becomes available 598 mlir::Block *operandsBlock = nullptr; 599 mlir::Operation *lastOperand = nullptr; 600 for (mlir::Value operand : oldAlloc.getOperands()) { 601 LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n"); 602 mlir::Operation *op = operand.getDefiningOp(); 603 if (!op) 604 return checkReturn(oldAlloc.getOperation()); 605 if (!operandsBlock) 606 operandsBlock = op->getBlock(); 607 else if (operandsBlock != op->getBlock()) { 608 LLVM_DEBUG(llvm::dbgs() 609 << "----operand declared in a different block!\n"); 610 // Operation::isBeforeInBlock requires the operations to be in the same 611 // block. The best we can do is the location of the allocmem. 612 return checkReturn(oldAlloc.getOperation()); 613 } 614 if (!lastOperand || lastOperand->isBeforeInBlock(op)) 615 lastOperand = op; 616 } 617 618 if (lastOperand) { 619 // there were value operands to the allocmem so insert after the last one 620 LLVM_DEBUG(llvm::dbgs() 621 << "--Placing after last operand: " << *lastOperand << "\n"); 622 // check we aren't moving out of an omp region 623 auto lastOpOmpRegion = 624 lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>(); 625 if (lastOpOmpRegion == oldOmpRegion) 626 return checkReturn(lastOperand); 627 // Presumably this happened because the operands became ready before the 628 // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should 629 // imply that oldOmpRegion comes after lastOpOmpRegion. 630 return checkReturn(oldOmpRegion.getAllocaBlock()); 631 } 632 633 // There were no value operands to the allocmem so we are safe to insert it 634 // as early as we want 635 636 // handle openmp case 637 if (oldOmpRegion) 638 return checkReturn(oldOmpRegion.getAllocaBlock()); 639 640 // fall back to the function entry block 641 mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>(); 642 assert(func && "This analysis is run on func.func"); 643 mlir::Block &entryBlock = func.getBlocks().front(); 644 LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n"); 645 return checkReturn(&entryBlock); 646 } 647 648 InsertionPoint 649 AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) { 650 mlir::Operation *oldAllocOp = oldAlloc; 651 // This is only called as a last resort. We should try to insert at the 652 // location of the old allocation, which is inside of a loop, using 653 // llvm.stacksave/llvm.stackrestore 654 655 // find freemem ops 656 llvm::SmallVector<mlir::Operation *, 1> freeOps; 657 658 for (mlir::Operation *user : oldAllocOp->getUsers()) { 659 if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) { 660 for (mlir::Operation *user : declareOp->getUsers()) { 661 if (mlir::isa<fir::FreeMemOp>(user)) 662 freeOps.push_back(user); 663 } 664 } 665 666 if (mlir::isa<fir::FreeMemOp>(user)) 667 freeOps.push_back(user); 668 } 669 670 assert(freeOps.size() && "DFA should only return freed memory"); 671 672 // Don't attempt to reason about a stacksave/stackrestore between different 673 // blocks 674 for (mlir::Operation *free : freeOps) 675 if (free->getBlock() != oldAllocOp->getBlock()) 676 return {nullptr}; 677 678 // Check that there aren't any other stack allocations in between the 679 // stack save and stack restore 680 // note: for flang generated temporaries there should only be one free op 681 for (mlir::Operation *free : freeOps) { 682 for (mlir::Operation *op = oldAlloc; op && op != free; 683 op = op->getNextNode()) { 684 if (mlir::isa<fir::AllocaOp>(op)) 685 return {nullptr}; 686 } 687 } 688 689 return InsertionPoint{oldAllocOp, /*shouldStackSaveRestore=*/true}; 690 } 691 692 std::optional<fir::AllocaOp> 693 AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc, 694 mlir::PatternRewriter &rewriter) const { 695 auto it = candidateOps.find(oldAlloc.getOperation()); 696 if (it == candidateOps.end()) 697 return {}; 698 InsertionPoint insertionPoint = it->second; 699 if (!insertionPoint) 700 return {}; 701 702 if (insertionPoint.shouldSaveRestoreStack()) 703 insertStackSaveRestore(oldAlloc, rewriter); 704 705 mlir::Location loc = oldAlloc.getLoc(); 706 mlir::Type varTy = oldAlloc.getInType(); 707 if (mlir::Operation *op = insertionPoint.tryGetOperation()) { 708 rewriter.setInsertionPointAfter(op); 709 } else { 710 mlir::Block *block = insertionPoint.tryGetBlock(); 711 assert(block && "There must be a valid insertion point"); 712 rewriter.setInsertionPointToStart(block); 713 } 714 715 auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef { 716 if (opt) 717 return *opt; 718 return {}; 719 }; 720 721 llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName()); 722 llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName()); 723 return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName, 724 oldAlloc.getTypeparams(), 725 oldAlloc.getShape()); 726 } 727 728 void AllocMemConversion::insertStackSaveRestore( 729 fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const { 730 auto oldPoint = rewriter.saveInsertionPoint(); 731 auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>(); 732 fir::FirOpBuilder builder{rewriter, mod}; 733 734 mlir::func::FuncOp stackSaveFn = fir::factory::getLlvmStackSave(builder); 735 mlir::SymbolRefAttr stackSaveSym = 736 builder.getSymbolRefAttr(stackSaveFn.getName()); 737 738 builder.setInsertionPoint(oldAlloc); 739 mlir::Value sp = 740 builder 741 .create<fir::CallOp>(oldAlloc.getLoc(), 742 stackSaveFn.getFunctionType().getResults(), 743 stackSaveSym, mlir::ValueRange{}) 744 .getResult(0); 745 746 mlir::func::FuncOp stackRestoreFn = 747 fir::factory::getLlvmStackRestore(builder); 748 mlir::SymbolRefAttr stackRestoreSym = 749 builder.getSymbolRefAttr(stackRestoreFn.getName()); 750 751 auto createStackRestoreCall = [&](mlir::Operation *user) { 752 builder.setInsertionPoint(user); 753 builder.create<fir::CallOp>(user->getLoc(), 754 stackRestoreFn.getFunctionType().getResults(), 755 stackRestoreSym, mlir::ValueRange{sp}); 756 }; 757 758 for (mlir::Operation *user : oldAlloc->getUsers()) { 759 if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) { 760 for (mlir::Operation *user : declareOp->getUsers()) { 761 if (mlir::isa<fir::FreeMemOp>(user)) 762 createStackRestoreCall(user); 763 } 764 } 765 766 if (mlir::isa<fir::FreeMemOp>(user)) { 767 createStackRestoreCall(user); 768 } 769 } 770 771 rewriter.restoreInsertionPoint(oldPoint); 772 } 773 774 StackArraysPass::StackArraysPass(const StackArraysPass &pass) 775 : fir::impl::StackArraysBase<StackArraysPass>(pass) {} 776 777 llvm::StringRef StackArraysPass::getDescription() const { 778 return "Move heap allocated array temporaries to the stack"; 779 } 780 781 void StackArraysPass::runOnOperation() { 782 mlir::ModuleOp mod = getOperation(); 783 784 mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); }); 785 } 786 787 void StackArraysPass::runOnFunc(mlir::Operation *func) { 788 assert(mlir::isa<mlir::func::FuncOp>(func)); 789 790 auto &analysis = getAnalysis<StackArraysAnalysisWrapper>(); 791 const StackArraysAnalysisWrapper::AllocMemMap *candidateOps = 792 analysis.getCandidateOps(func); 793 if (!candidateOps) { 794 signalPassFailure(); 795 return; 796 } 797 798 if (candidateOps->empty()) 799 return; 800 runCount += candidateOps->size(); 801 802 llvm::SmallVector<mlir::Operation *> opsToConvert; 803 opsToConvert.reserve(candidateOps->size()); 804 for (auto [op, _] : *candidateOps) 805 opsToConvert.push_back(op); 806 807 mlir::MLIRContext &context = getContext(); 808 mlir::RewritePatternSet patterns(&context); 809 mlir::GreedyRewriteConfig config; 810 // prevent the pattern driver form merging blocks 811 config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; 812 813 patterns.insert<AllocMemConversion>(&context, *candidateOps); 814 if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert, 815 std::move(patterns), config))) { 816 mlir::emitError(func->getLoc(), "error in stack arrays optimization\n"); 817 signalPassFailure(); 818 } 819 } 820