1 //===- StackArrays.cpp ----------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/FIRBuilder.h" 10 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" 11 #include "flang/Optimizer/Dialect/FIRAttr.h" 12 #include "flang/Optimizer/Dialect/FIRDialect.h" 13 #include "flang/Optimizer/Dialect/FIROps.h" 14 #include "flang/Optimizer/Dialect/FIRType.h" 15 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 16 #include "flang/Optimizer/Transforms/Passes.h" 17 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 18 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 19 #include "mlir/Analysis/DataFlow/DenseAnalysis.h" 20 #include "mlir/Analysis/DataFlowFramework.h" 21 #include "mlir/Dialect/Func/IR/FuncOps.h" 22 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 23 #include "mlir/IR/Builders.h" 24 #include "mlir/IR/Diagnostics.h" 25 #include "mlir/IR/Value.h" 26 #include "mlir/Interfaces/LoopLikeInterface.h" 27 #include "mlir/Pass/Pass.h" 28 #include "mlir/Support/LogicalResult.h" 29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 30 #include "mlir/Transforms/Passes.h" 31 #include "llvm/ADT/DenseMap.h" 32 #include "llvm/ADT/DenseSet.h" 33 #include "llvm/ADT/PointerUnion.h" 34 #include "llvm/Support/Casting.h" 35 #include "llvm/Support/raw_ostream.h" 36 #include <optional> 37 38 namespace fir { 39 #define GEN_PASS_DEF_STACKARRAYS 40 #include "flang/Optimizer/Transforms/Passes.h.inc" 41 } // namespace fir 42 43 #define DEBUG_TYPE "stack-arrays" 44 45 namespace { 46 47 /// The state of an SSA value at each program point 48 enum class AllocationState { 49 /// This means that the allocation state of a variable cannot be determined 50 /// at this program point, e.g. because one route through a conditional freed 51 /// the variable and the other route didn't. 52 /// This asserts a known-unknown: different from the unknown-unknown of having 53 /// no AllocationState stored for a particular SSA value 54 Unknown, 55 /// Means this SSA value was allocated on the heap in this function and has 56 /// now been freed 57 Freed, 58 /// Means this SSA value was allocated on the heap in this function and is a 59 /// candidate for moving to the stack 60 Allocated, 61 }; 62 63 /// Stores where an alloca should be inserted. If the PointerUnion is an 64 /// Operation the alloca should be inserted /after/ the operation. If it is a 65 /// block, the alloca can be placed anywhere in that block. 66 class InsertionPoint { 67 llvm::PointerUnion<mlir::Operation *, mlir::Block *> location; 68 bool saveRestoreStack; 69 70 /// Get contained pointer type or nullptr 71 template <class T> 72 T *tryGetPtr() const { 73 if (location.is<T *>()) 74 return location.get<T *>(); 75 return nullptr; 76 } 77 78 public: 79 template <class T> 80 InsertionPoint(T *ptr, bool saveRestoreStack = false) 81 : location(ptr), saveRestoreStack{saveRestoreStack} {} 82 InsertionPoint(std::nullptr_t null) 83 : location(null), saveRestoreStack{false} {} 84 85 /// Get contained operation, or nullptr 86 mlir::Operation *tryGetOperation() const { 87 return tryGetPtr<mlir::Operation>(); 88 } 89 90 /// Get contained block, or nullptr 91 mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); } 92 93 /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave 94 /// intrinsic should be added before the alloca, and an llvm.stackrestore 95 /// intrinsic should be added where the freemem is 96 bool shouldSaveRestoreStack() const { return saveRestoreStack; } 97 98 operator bool() const { return tryGetOperation() || tryGetBlock(); } 99 100 bool operator==(const InsertionPoint &rhs) const { 101 return (location == rhs.location) && 102 (saveRestoreStack == rhs.saveRestoreStack); 103 } 104 105 bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); } 106 }; 107 108 /// Maps SSA values to their AllocationState at a particular program point. 109 /// Also caches the insertion points for the new alloca operations 110 class LatticePoint : public mlir::dataflow::AbstractDenseLattice { 111 // Maps all values we are interested in to states 112 llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap; 113 114 public: 115 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint) 116 using AbstractDenseLattice::AbstractDenseLattice; 117 118 bool operator==(const LatticePoint &rhs) const { 119 return stateMap == rhs.stateMap; 120 } 121 122 /// Join the lattice accross control-flow edges 123 mlir::ChangeResult join(const AbstractDenseLattice &lattice) override; 124 125 void print(llvm::raw_ostream &os) const override; 126 127 /// Clear all modifications 128 mlir::ChangeResult reset(); 129 130 /// Set the state of an SSA value 131 mlir::ChangeResult set(mlir::Value value, AllocationState state); 132 133 /// Get fir.allocmem ops which were allocated in this function and always 134 /// freed before the function returns, plus whre to insert replacement 135 /// fir.alloca ops 136 void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const; 137 138 std::optional<AllocationState> get(mlir::Value val) const; 139 }; 140 141 class AllocationAnalysis 142 : public mlir::dataflow::DenseDataFlowAnalysis<LatticePoint> { 143 public: 144 using DenseDataFlowAnalysis::DenseDataFlowAnalysis; 145 146 void visitOperation(mlir::Operation *op, const LatticePoint &before, 147 LatticePoint *after) override; 148 149 /// At an entry point, the last modifications of all memory resources are 150 /// yet to be determined 151 void setToEntryState(LatticePoint *lattice) override; 152 153 protected: 154 /// Visit control flow operations and decide whether to call visitOperation 155 /// to apply the transfer function 156 void processOperation(mlir::Operation *op) override; 157 }; 158 159 /// Drives analysis to find candidate fir.allocmem operations which could be 160 /// moved to the stack. Intended to be used with mlir::Pass::getAnalysis 161 class StackArraysAnalysisWrapper { 162 public: 163 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper) 164 165 // Maps fir.allocmem -> place to insert alloca 166 using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>; 167 168 StackArraysAnalysisWrapper(mlir::Operation *op) {} 169 170 // returns nullptr if analysis failed 171 const AllocMemMap *getCandidateOps(mlir::Operation *func); 172 173 private: 174 llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps; 175 176 mlir::LogicalResult analyseFunction(mlir::Operation *func); 177 }; 178 179 /// Converts a fir.allocmem to a fir.alloca 180 class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> { 181 public: 182 explicit AllocMemConversion( 183 mlir::MLIRContext *ctx, 184 const StackArraysAnalysisWrapper::AllocMemMap &candidateOps) 185 : OpRewritePattern(ctx), candidateOps{candidateOps} {} 186 187 mlir::LogicalResult 188 matchAndRewrite(fir::AllocMemOp allocmem, 189 mlir::PatternRewriter &rewriter) const override; 190 191 /// Determine where to insert the alloca operation. The returned value should 192 /// be checked to see if it is inside a loop 193 static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc); 194 195 private: 196 /// Handle to the DFA (already run) 197 const StackArraysAnalysisWrapper::AllocMemMap &candidateOps; 198 199 /// If we failed to find an insertion point not inside a loop, see if it would 200 /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop 201 static InsertionPoint findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc); 202 203 /// Returns the alloca if it was successfully inserted, otherwise {} 204 std::optional<fir::AllocaOp> 205 insertAlloca(fir::AllocMemOp &oldAlloc, 206 mlir::PatternRewriter &rewriter) const; 207 208 /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem 209 void insertStackSaveRestore(fir::AllocMemOp &oldAlloc, 210 mlir::PatternRewriter &rewriter) const; 211 }; 212 213 class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> { 214 public: 215 StackArraysPass() = default; 216 StackArraysPass(const StackArraysPass &pass); 217 218 llvm::StringRef getDescription() const override; 219 220 void runOnOperation() override; 221 void runOnFunc(mlir::Operation *func); 222 223 private: 224 Statistic runCount{this, "stackArraysRunCount", 225 "Number of heap allocations moved to the stack"}; 226 }; 227 228 } // namespace 229 230 static void print(llvm::raw_ostream &os, AllocationState state) { 231 switch (state) { 232 case AllocationState::Unknown: 233 os << "Unknown"; 234 break; 235 case AllocationState::Freed: 236 os << "Freed"; 237 break; 238 case AllocationState::Allocated: 239 os << "Allocated"; 240 break; 241 } 242 } 243 244 /// Join two AllocationStates for the same value coming from different CFG 245 /// blocks 246 static AllocationState join(AllocationState lhs, AllocationState rhs) { 247 // | Allocated | Freed | Unknown 248 // ========= | ========= | ========= | ========= 249 // Allocated | Allocated | Unknown | Unknown 250 // Freed | Unknown | Freed | Unknown 251 // Unknown | Unknown | Unknown | Unknown 252 if (lhs == rhs) 253 return lhs; 254 return AllocationState::Unknown; 255 } 256 257 mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) { 258 const auto &rhs = static_cast<const LatticePoint &>(lattice); 259 mlir::ChangeResult changed = mlir::ChangeResult::NoChange; 260 261 // add everything from rhs to map, handling cases where values are in both 262 for (const auto &[value, rhsState] : rhs.stateMap) { 263 auto it = stateMap.find(value); 264 if (it != stateMap.end()) { 265 // value is present in both maps 266 AllocationState myState = it->second; 267 AllocationState newState = ::join(myState, rhsState); 268 if (newState != myState) { 269 changed = mlir::ChangeResult::Change; 270 it->getSecond() = newState; 271 } 272 } else { 273 // value not present in current map: add it 274 stateMap.insert({value, rhsState}); 275 changed = mlir::ChangeResult::Change; 276 } 277 } 278 279 return changed; 280 } 281 282 void LatticePoint::print(llvm::raw_ostream &os) const { 283 for (const auto &[value, state] : stateMap) { 284 os << value << ": "; 285 ::print(os, state); 286 } 287 } 288 289 mlir::ChangeResult LatticePoint::reset() { 290 if (stateMap.empty()) 291 return mlir::ChangeResult::NoChange; 292 stateMap.clear(); 293 return mlir::ChangeResult::Change; 294 } 295 296 mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) { 297 if (stateMap.count(value)) { 298 // already in map 299 AllocationState &oldState = stateMap[value]; 300 if (oldState != state) { 301 stateMap[value] = state; 302 return mlir::ChangeResult::Change; 303 } 304 return mlir::ChangeResult::NoChange; 305 } 306 stateMap.insert({value, state}); 307 return mlir::ChangeResult::Change; 308 } 309 310 /// Get values which were allocated in this function and always freed before 311 /// the function returns 312 void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const { 313 for (auto &[value, state] : stateMap) { 314 if (state == AllocationState::Freed) 315 out.insert(value); 316 } 317 } 318 319 std::optional<AllocationState> LatticePoint::get(mlir::Value val) const { 320 auto it = stateMap.find(val); 321 if (it == stateMap.end()) 322 return {}; 323 return it->second; 324 } 325 326 void AllocationAnalysis::visitOperation(mlir::Operation *op, 327 const LatticePoint &before, 328 LatticePoint *after) { 329 LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op 330 << "\n"); 331 LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n"); 332 333 // propagate before -> after 334 mlir::ChangeResult changed = after->join(before); 335 336 if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) { 337 assert(op->getNumResults() == 1 && "fir.allocmem has one result"); 338 auto attr = op->getAttrOfType<fir::MustBeHeapAttr>( 339 fir::MustBeHeapAttr::getAttrName()); 340 if (attr && attr.getValue()) { 341 LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n"); 342 // skip allocation marked not to be moved 343 return; 344 } 345 346 auto retTy = allocmem.getAllocatedType(); 347 if (!retTy.isa<fir::SequenceType>()) { 348 LLVM_DEBUG(llvm::dbgs() 349 << "--Allocation is not for an array: skipping\n"); 350 return; 351 } 352 353 mlir::Value result = op->getResult(0); 354 changed |= after->set(result, AllocationState::Allocated); 355 } else if (mlir::isa<fir::FreeMemOp>(op)) { 356 assert(op->getNumOperands() == 1 && "fir.freemem has one operand"); 357 mlir::Value operand = op->getOperand(0); 358 std::optional<AllocationState> operandState = before.get(operand); 359 if (operandState && *operandState == AllocationState::Allocated) { 360 // don't tag things not allocated in this function as freed, so that we 361 // don't think they are candidates for moving to the stack 362 changed |= after->set(operand, AllocationState::Freed); 363 } 364 } else if (mlir::isa<fir::ResultOp>(op)) { 365 mlir::Operation *parent = op->getParentOp(); 366 LatticePoint *parentLattice = getLattice(parent); 367 assert(parentLattice); 368 mlir::ChangeResult parentChanged = parentLattice->join(*after); 369 propagateIfChanged(parentLattice, parentChanged); 370 } 371 372 // we pass lattices straight through fir.call because called functions should 373 // not deallocate flang-generated array temporaries 374 375 LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n"); 376 propagateIfChanged(after, changed); 377 } 378 379 void AllocationAnalysis::setToEntryState(LatticePoint *lattice) { 380 propagateIfChanged(lattice, lattice->reset()); 381 } 382 383 /// Mostly a copy of AbstractDenseLattice::processOperation - the difference 384 /// being that call operations are passed through to the transfer function 385 void AllocationAnalysis::processOperation(mlir::Operation *op) { 386 // If the containing block is not executable, bail out. 387 if (!getOrCreateFor<mlir::dataflow::Executable>(op, op->getBlock())->isLive()) 388 return; 389 390 // Get the dense lattice to update 391 mlir::dataflow::AbstractDenseLattice *after = getLattice(op); 392 393 // If this op implements region control-flow, then control-flow dictates its 394 // transfer function. 395 if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) 396 return visitRegionBranchOperation(op, branch, after); 397 398 // pass call operations through to the transfer function 399 400 // Get the dense state before the execution of the op. 401 const mlir::dataflow::AbstractDenseLattice *before; 402 if (mlir::Operation *prev = op->getPrevNode()) 403 before = getLatticeFor(op, prev); 404 else 405 before = getLatticeFor(op, op->getBlock()); 406 407 /// Invoke the operation transfer function 408 visitOperationImpl(op, *before, after); 409 } 410 411 mlir::LogicalResult 412 StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) { 413 assert(mlir::isa<mlir::func::FuncOp>(func)); 414 mlir::DataFlowSolver solver; 415 // constant propagation is required for dead code analysis, dead code analysis 416 // is required to mark blocks live (required for mlir dense dfa) 417 solver.load<mlir::dataflow::SparseConstantPropagation>(); 418 solver.load<mlir::dataflow::DeadCodeAnalysis>(); 419 420 auto [it, inserted] = funcMaps.try_emplace(func); 421 AllocMemMap &candidateOps = it->second; 422 423 solver.load<AllocationAnalysis>(); 424 if (failed(solver.initializeAndRun(func))) { 425 llvm::errs() << "DataFlowSolver failed!"; 426 return mlir::failure(); 427 } 428 429 LatticePoint point{func}; 430 auto joinOperationLattice = [&](mlir::Operation *op) { 431 const LatticePoint *lattice = solver.lookupState<LatticePoint>(op); 432 // there will be no lattice for an unreachable block 433 if (lattice) 434 point.join(*lattice); 435 }; 436 func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); }); 437 func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); }); 438 llvm::DenseSet<mlir::Value> freedValues; 439 point.appendFreedValues(freedValues); 440 441 // We only replace allocations which are definately freed on all routes 442 // through the function because otherwise the allocation may have an intende 443 // lifetime longer than the current stack frame (e.g. a heap allocation which 444 // is then freed by another function). 445 for (mlir::Value freedValue : freedValues) { 446 fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>(); 447 InsertionPoint insertionPoint = 448 AllocMemConversion::findAllocaInsertionPoint(allocmem); 449 if (insertionPoint) 450 candidateOps.insert({allocmem, insertionPoint}); 451 } 452 453 LLVM_DEBUG(for (auto [allocMemOp, _] 454 : candidateOps) { 455 llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n'; 456 }); 457 return mlir::success(); 458 } 459 460 const StackArraysAnalysisWrapper::AllocMemMap * 461 StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) { 462 if (!funcMaps.contains(func)) 463 if (mlir::failed(analyseFunction(func))) 464 return nullptr; 465 return &funcMaps[func]; 466 } 467 468 /// Restore the old allocation type exected by existing code 469 static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter, 470 const mlir::Location &loc, 471 mlir::Value heap, mlir::Value stack) { 472 mlir::Type heapTy = heap.getType(); 473 mlir::Type stackTy = stack.getType(); 474 475 if (heapTy == stackTy) 476 return stack; 477 478 fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy); 479 fir::ReferenceType firRefTy = mlir::cast<fir::ReferenceType>(stackTy); 480 assert(firHeapTy.getElementType() == firRefTy.getElementType() && 481 "Allocations must have the same type"); 482 483 auto insertionPoint = rewriter.saveInsertionPoint(); 484 rewriter.setInsertionPointAfter(stack.getDefiningOp()); 485 mlir::Value conv = 486 rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult(); 487 rewriter.restoreInsertionPoint(insertionPoint); 488 return conv; 489 } 490 491 mlir::LogicalResult 492 AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem, 493 mlir::PatternRewriter &rewriter) const { 494 auto oldInsertionPt = rewriter.saveInsertionPoint(); 495 // add alloca operation 496 std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter); 497 rewriter.restoreInsertionPoint(oldInsertionPt); 498 if (!alloca) 499 return mlir::failure(); 500 501 // remove freemem operations 502 llvm::SmallVector<mlir::Operation *> erases; 503 for (mlir::Operation *user : allocmem.getOperation()->getUsers()) 504 if (mlir::isa<fir::FreeMemOp>(user)) 505 erases.push_back(user); 506 // now we are done iterating the users, it is safe to mutate them 507 for (mlir::Operation *erase : erases) 508 rewriter.eraseOp(erase); 509 510 // replace references to heap allocation with references to stack allocation 511 mlir::Value newValue = convertAllocationType( 512 rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult()); 513 rewriter.replaceAllUsesWith(allocmem.getResult(), newValue); 514 515 // remove allocmem operation 516 rewriter.eraseOp(allocmem.getOperation()); 517 518 return mlir::success(); 519 } 520 521 static bool isInLoop(mlir::Block *block) { 522 return mlir::LoopLikeOpInterface::blockIsInLoop(block); 523 } 524 525 static bool isInLoop(mlir::Operation *op) { 526 return isInLoop(op->getBlock()) || 527 op->getParentOfType<mlir::LoopLikeOpInterface>(); 528 } 529 530 InsertionPoint 531 AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) { 532 // Ideally the alloca should be inserted at the end of the function entry 533 // block so that we do not allocate stack space in a loop. However, 534 // the operands to the alloca may not be available that early, so insert it 535 // after the last operand becomes available 536 // If the old allocmem op was in an openmp region then it should not be moved 537 // outside of that 538 LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: " 539 << oldAlloc << "\n"); 540 541 // check that an Operation or Block we are about to return is not in a loop 542 auto checkReturn = [&](auto *point) -> InsertionPoint { 543 if (isInLoop(point)) { 544 mlir::Operation *oldAllocOp = oldAlloc.getOperation(); 545 if (isInLoop(oldAllocOp)) { 546 // where we want to put it is in a loop, and even the old location is in 547 // a loop. Give up. 548 return findAllocaLoopInsertionPoint(oldAlloc); 549 } 550 return {oldAllocOp}; 551 } 552 return {point}; 553 }; 554 555 auto oldOmpRegion = 556 oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>(); 557 558 // Find when the last operand value becomes available 559 mlir::Block *operandsBlock = nullptr; 560 mlir::Operation *lastOperand = nullptr; 561 for (mlir::Value operand : oldAlloc.getOperands()) { 562 LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n"); 563 mlir::Operation *op = operand.getDefiningOp(); 564 if (!op) 565 return checkReturn(oldAlloc.getOperation()); 566 if (!operandsBlock) 567 operandsBlock = op->getBlock(); 568 else if (operandsBlock != op->getBlock()) { 569 LLVM_DEBUG(llvm::dbgs() 570 << "----operand declared in a different block!\n"); 571 // Operation::isBeforeInBlock requires the operations to be in the same 572 // block. The best we can do is the location of the allocmem. 573 return checkReturn(oldAlloc.getOperation()); 574 } 575 if (!lastOperand || lastOperand->isBeforeInBlock(op)) 576 lastOperand = op; 577 } 578 579 if (lastOperand) { 580 // there were value operands to the allocmem so insert after the last one 581 LLVM_DEBUG(llvm::dbgs() 582 << "--Placing after last operand: " << *lastOperand << "\n"); 583 // check we aren't moving out of an omp region 584 auto lastOpOmpRegion = 585 lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>(); 586 if (lastOpOmpRegion == oldOmpRegion) 587 return checkReturn(lastOperand); 588 // Presumably this happened because the operands became ready before the 589 // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should 590 // imply that oldOmpRegion comes after lastOpOmpRegion. 591 return checkReturn(oldOmpRegion.getAllocaBlock()); 592 } 593 594 // There were no value operands to the allocmem so we are safe to insert it 595 // as early as we want 596 597 // handle openmp case 598 if (oldOmpRegion) 599 return checkReturn(oldOmpRegion.getAllocaBlock()); 600 601 // fall back to the function entry block 602 mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>(); 603 assert(func && "This analysis is run on func.func"); 604 mlir::Block &entryBlock = func.getBlocks().front(); 605 LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n"); 606 return checkReturn(&entryBlock); 607 } 608 609 InsertionPoint 610 AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) { 611 mlir::Operation *oldAllocOp = oldAlloc; 612 // This is only called as a last resort. We should try to insert at the 613 // location of the old allocation, which is inside of a loop, using 614 // llvm.stacksave/llvm.stackrestore 615 616 // find freemem ops 617 llvm::SmallVector<mlir::Operation *, 1> freeOps; 618 for (mlir::Operation *user : oldAllocOp->getUsers()) 619 if (mlir::isa<fir::FreeMemOp>(user)) 620 freeOps.push_back(user); 621 assert(freeOps.size() && "DFA should only return freed memory"); 622 623 // Don't attempt to reason about a stacksave/stackrestore between different 624 // blocks 625 for (mlir::Operation *free : freeOps) 626 if (free->getBlock() != oldAllocOp->getBlock()) 627 return {nullptr}; 628 629 // Check that there aren't any other stack allocations in between the 630 // stack save and stack restore 631 // note: for flang generated temporaries there should only be one free op 632 for (mlir::Operation *free : freeOps) { 633 for (mlir::Operation *op = oldAlloc; op && op != free; 634 op = op->getNextNode()) { 635 if (mlir::isa<fir::AllocaOp>(op)) 636 return {nullptr}; 637 } 638 } 639 640 return InsertionPoint{oldAllocOp, /*shouldStackSaveRestore=*/true}; 641 } 642 643 std::optional<fir::AllocaOp> 644 AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc, 645 mlir::PatternRewriter &rewriter) const { 646 auto it = candidateOps.find(oldAlloc.getOperation()); 647 if (it == candidateOps.end()) 648 return {}; 649 InsertionPoint insertionPoint = it->second; 650 if (!insertionPoint) 651 return {}; 652 653 if (insertionPoint.shouldSaveRestoreStack()) 654 insertStackSaveRestore(oldAlloc, rewriter); 655 656 mlir::Location loc = oldAlloc.getLoc(); 657 mlir::Type varTy = oldAlloc.getInType(); 658 if (mlir::Operation *op = insertionPoint.tryGetOperation()) { 659 rewriter.setInsertionPointAfter(op); 660 } else { 661 mlir::Block *block = insertionPoint.tryGetBlock(); 662 assert(block && "There must be a valid insertion point"); 663 rewriter.setInsertionPointToStart(block); 664 } 665 666 auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef { 667 if (opt) 668 return *opt; 669 return {}; 670 }; 671 672 llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName()); 673 llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName()); 674 return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName, 675 oldAlloc.getTypeparams(), 676 oldAlloc.getShape()); 677 } 678 679 void AllocMemConversion::insertStackSaveRestore( 680 fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const { 681 auto oldPoint = rewriter.saveInsertionPoint(); 682 auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>(); 683 fir::FirOpBuilder builder{rewriter, mod}; 684 685 mlir::func::FuncOp stackSaveFn = fir::factory::getLlvmStackSave(builder); 686 mlir::SymbolRefAttr stackSaveSym = 687 builder.getSymbolRefAttr(stackSaveFn.getName()); 688 689 builder.setInsertionPoint(oldAlloc); 690 mlir::Value sp = 691 builder 692 .create<fir::CallOp>(oldAlloc.getLoc(), 693 stackSaveFn.getFunctionType().getResults(), 694 stackSaveSym, mlir::ValueRange{}) 695 .getResult(0); 696 697 mlir::func::FuncOp stackRestoreFn = 698 fir::factory::getLlvmStackRestore(builder); 699 mlir::SymbolRefAttr stackRestoreSym = 700 builder.getSymbolRefAttr(stackRestoreFn.getName()); 701 702 for (mlir::Operation *user : oldAlloc->getUsers()) { 703 if (mlir::isa<fir::FreeMemOp>(user)) { 704 builder.setInsertionPoint(user); 705 builder.create<fir::CallOp>(user->getLoc(), 706 stackRestoreFn.getFunctionType().getResults(), 707 stackRestoreSym, mlir::ValueRange{sp}); 708 } 709 } 710 711 rewriter.restoreInsertionPoint(oldPoint); 712 } 713 714 StackArraysPass::StackArraysPass(const StackArraysPass &pass) 715 : fir::impl::StackArraysBase<StackArraysPass>(pass) {} 716 717 llvm::StringRef StackArraysPass::getDescription() const { 718 return "Move heap allocated array temporaries to the stack"; 719 } 720 721 void StackArraysPass::runOnOperation() { 722 mlir::ModuleOp mod = getOperation(); 723 724 mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); }); 725 } 726 727 void StackArraysPass::runOnFunc(mlir::Operation *func) { 728 assert(mlir::isa<mlir::func::FuncOp>(func)); 729 730 auto &analysis = getAnalysis<StackArraysAnalysisWrapper>(); 731 const StackArraysAnalysisWrapper::AllocMemMap *candidateOps = 732 analysis.getCandidateOps(func); 733 if (!candidateOps) { 734 signalPassFailure(); 735 return; 736 } 737 738 if (candidateOps->empty()) 739 return; 740 runCount += candidateOps->size(); 741 742 llvm::SmallVector<mlir::Operation *> opsToConvert; 743 opsToConvert.reserve(candidateOps->size()); 744 for (auto [op, _] : *candidateOps) 745 opsToConvert.push_back(op); 746 747 mlir::MLIRContext &context = getContext(); 748 mlir::RewritePatternSet patterns(&context); 749 mlir::GreedyRewriteConfig config; 750 // prevent the pattern driver form merging blocks 751 config.enableRegionSimplification = false; 752 753 patterns.insert<AllocMemConversion>(&context, *candidateOps); 754 if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert, 755 std::move(patterns), config))) { 756 mlir::emitError(func->getLoc(), "error in stack arrays optimization\n"); 757 signalPassFailure(); 758 } 759 } 760 761 std::unique_ptr<mlir::Pass> fir::createStackArraysPass() { 762 return std::make_unique<StackArraysPass>(); 763 } 764