1 //===- StackArrays.cpp ----------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/FIRBuilder.h" 10 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" 11 #include "flang/Optimizer/Dialect/FIRAttr.h" 12 #include "flang/Optimizer/Dialect/FIRDialect.h" 13 #include "flang/Optimizer/Dialect/FIROps.h" 14 #include "flang/Optimizer/Dialect/FIRType.h" 15 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 16 #include "flang/Optimizer/Transforms/Passes.h" 17 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 18 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 19 #include "mlir/Analysis/DataFlow/DenseAnalysis.h" 20 #include "mlir/Analysis/DataFlowFramework.h" 21 #include "mlir/Dialect/Func/IR/FuncOps.h" 22 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 23 #include "mlir/IR/Builders.h" 24 #include "mlir/IR/Diagnostics.h" 25 #include "mlir/IR/Value.h" 26 #include "mlir/Interfaces/LoopLikeInterface.h" 27 #include "mlir/Pass/Pass.h" 28 #include "mlir/Support/LogicalResult.h" 29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 30 #include "mlir/Transforms/Passes.h" 31 #include "llvm/ADT/DenseMap.h" 32 #include "llvm/ADT/DenseSet.h" 33 #include "llvm/ADT/PointerUnion.h" 34 #include "llvm/Support/Casting.h" 35 #include "llvm/Support/raw_ostream.h" 36 #include <optional> 37 38 namespace fir { 39 #define GEN_PASS_DEF_STACKARRAYS 40 #include "flang/Optimizer/Transforms/Passes.h.inc" 41 } // namespace fir 42 43 #define DEBUG_TYPE "stack-arrays" 44 45 static llvm::cl::opt<std::size_t> maxAllocsPerFunc( 46 "stack-arrays-max-allocs", 47 llvm::cl::desc("The maximum number of heap allocations to consider in one " 48 "function before skipping (to save compilation time). Set " 49 "to 0 for no limit."), 50 llvm::cl::init(1000), llvm::cl::Hidden); 51 52 namespace { 53 54 /// The state of an SSA value at each program point 55 enum class AllocationState { 56 /// This means that the allocation state of a variable cannot be determined 57 /// at this program point, e.g. because one route through a conditional freed 58 /// the variable and the other route didn't. 59 /// This asserts a known-unknown: different from the unknown-unknown of having 60 /// no AllocationState stored for a particular SSA value 61 Unknown, 62 /// Means this SSA value was allocated on the heap in this function and has 63 /// now been freed 64 Freed, 65 /// Means this SSA value was allocated on the heap in this function and is a 66 /// candidate for moving to the stack 67 Allocated, 68 }; 69 70 /// Stores where an alloca should be inserted. If the PointerUnion is an 71 /// Operation the alloca should be inserted /after/ the operation. If it is a 72 /// block, the alloca can be placed anywhere in that block. 73 class InsertionPoint { 74 llvm::PointerUnion<mlir::Operation *, mlir::Block *> location; 75 bool saveRestoreStack; 76 77 /// Get contained pointer type or nullptr 78 template <class T> 79 T *tryGetPtr() const { 80 if (location.is<T *>()) 81 return location.get<T *>(); 82 return nullptr; 83 } 84 85 public: 86 template <class T> 87 InsertionPoint(T *ptr, bool saveRestoreStack = false) 88 : location(ptr), saveRestoreStack{saveRestoreStack} {} 89 InsertionPoint(std::nullptr_t null) 90 : location(null), saveRestoreStack{false} {} 91 92 /// Get contained operation, or nullptr 93 mlir::Operation *tryGetOperation() const { 94 return tryGetPtr<mlir::Operation>(); 95 } 96 97 /// Get contained block, or nullptr 98 mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); } 99 100 /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave 101 /// intrinsic should be added before the alloca, and an llvm.stackrestore 102 /// intrinsic should be added where the freemem is 103 bool shouldSaveRestoreStack() const { return saveRestoreStack; } 104 105 operator bool() const { return tryGetOperation() || tryGetBlock(); } 106 107 bool operator==(const InsertionPoint &rhs) const { 108 return (location == rhs.location) && 109 (saveRestoreStack == rhs.saveRestoreStack); 110 } 111 112 bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); } 113 }; 114 115 /// Maps SSA values to their AllocationState at a particular program point. 116 /// Also caches the insertion points for the new alloca operations 117 class LatticePoint : public mlir::dataflow::AbstractDenseLattice { 118 // Maps all values we are interested in to states 119 llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap; 120 121 public: 122 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint) 123 using AbstractDenseLattice::AbstractDenseLattice; 124 125 bool operator==(const LatticePoint &rhs) const { 126 return stateMap == rhs.stateMap; 127 } 128 129 /// Join the lattice accross control-flow edges 130 mlir::ChangeResult join(const AbstractDenseLattice &lattice) override; 131 132 void print(llvm::raw_ostream &os) const override; 133 134 /// Clear all modifications 135 mlir::ChangeResult reset(); 136 137 /// Set the state of an SSA value 138 mlir::ChangeResult set(mlir::Value value, AllocationState state); 139 140 /// Get fir.allocmem ops which were allocated in this function and always 141 /// freed before the function returns, plus whre to insert replacement 142 /// fir.alloca ops 143 void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const; 144 145 std::optional<AllocationState> get(mlir::Value val) const; 146 }; 147 148 class AllocationAnalysis 149 : public mlir::dataflow::DenseForwardDataFlowAnalysis<LatticePoint> { 150 public: 151 using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis; 152 153 void visitOperation(mlir::Operation *op, const LatticePoint &before, 154 LatticePoint *after) override; 155 156 /// At an entry point, the last modifications of all memory resources are 157 /// yet to be determined 158 void setToEntryState(LatticePoint *lattice) override; 159 160 protected: 161 /// Visit control flow operations and decide whether to call visitOperation 162 /// to apply the transfer function 163 void processOperation(mlir::Operation *op) override; 164 }; 165 166 /// Drives analysis to find candidate fir.allocmem operations which could be 167 /// moved to the stack. Intended to be used with mlir::Pass::getAnalysis 168 class StackArraysAnalysisWrapper { 169 public: 170 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper) 171 172 // Maps fir.allocmem -> place to insert alloca 173 using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>; 174 175 StackArraysAnalysisWrapper(mlir::Operation *op) {} 176 177 // returns nullptr if analysis failed 178 const AllocMemMap *getCandidateOps(mlir::Operation *func); 179 180 private: 181 llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps; 182 183 mlir::LogicalResult analyseFunction(mlir::Operation *func); 184 }; 185 186 /// Converts a fir.allocmem to a fir.alloca 187 class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> { 188 public: 189 explicit AllocMemConversion( 190 mlir::MLIRContext *ctx, 191 const StackArraysAnalysisWrapper::AllocMemMap &candidateOps) 192 : OpRewritePattern(ctx), candidateOps{candidateOps} {} 193 194 mlir::LogicalResult 195 matchAndRewrite(fir::AllocMemOp allocmem, 196 mlir::PatternRewriter &rewriter) const override; 197 198 /// Determine where to insert the alloca operation. The returned value should 199 /// be checked to see if it is inside a loop 200 static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc); 201 202 private: 203 /// Handle to the DFA (already run) 204 const StackArraysAnalysisWrapper::AllocMemMap &candidateOps; 205 206 /// If we failed to find an insertion point not inside a loop, see if it would 207 /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop 208 static InsertionPoint findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc); 209 210 /// Returns the alloca if it was successfully inserted, otherwise {} 211 std::optional<fir::AllocaOp> 212 insertAlloca(fir::AllocMemOp &oldAlloc, 213 mlir::PatternRewriter &rewriter) const; 214 215 /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem 216 void insertStackSaveRestore(fir::AllocMemOp &oldAlloc, 217 mlir::PatternRewriter &rewriter) const; 218 }; 219 220 class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> { 221 public: 222 StackArraysPass() = default; 223 StackArraysPass(const StackArraysPass &pass); 224 225 llvm::StringRef getDescription() const override; 226 227 void runOnOperation() override; 228 void runOnFunc(mlir::Operation *func); 229 230 private: 231 Statistic runCount{this, "stackArraysRunCount", 232 "Number of heap allocations moved to the stack"}; 233 }; 234 235 } // namespace 236 237 static void print(llvm::raw_ostream &os, AllocationState state) { 238 switch (state) { 239 case AllocationState::Unknown: 240 os << "Unknown"; 241 break; 242 case AllocationState::Freed: 243 os << "Freed"; 244 break; 245 case AllocationState::Allocated: 246 os << "Allocated"; 247 break; 248 } 249 } 250 251 /// Join two AllocationStates for the same value coming from different CFG 252 /// blocks 253 static AllocationState join(AllocationState lhs, AllocationState rhs) { 254 // | Allocated | Freed | Unknown 255 // ========= | ========= | ========= | ========= 256 // Allocated | Allocated | Unknown | Unknown 257 // Freed | Unknown | Freed | Unknown 258 // Unknown | Unknown | Unknown | Unknown 259 if (lhs == rhs) 260 return lhs; 261 return AllocationState::Unknown; 262 } 263 264 mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) { 265 const auto &rhs = static_cast<const LatticePoint &>(lattice); 266 mlir::ChangeResult changed = mlir::ChangeResult::NoChange; 267 268 // add everything from rhs to map, handling cases where values are in both 269 for (const auto &[value, rhsState] : rhs.stateMap) { 270 auto it = stateMap.find(value); 271 if (it != stateMap.end()) { 272 // value is present in both maps 273 AllocationState myState = it->second; 274 AllocationState newState = ::join(myState, rhsState); 275 if (newState != myState) { 276 changed = mlir::ChangeResult::Change; 277 it->getSecond() = newState; 278 } 279 } else { 280 // value not present in current map: add it 281 stateMap.insert({value, rhsState}); 282 changed = mlir::ChangeResult::Change; 283 } 284 } 285 286 return changed; 287 } 288 289 void LatticePoint::print(llvm::raw_ostream &os) const { 290 for (const auto &[value, state] : stateMap) { 291 os << value << ": "; 292 ::print(os, state); 293 } 294 } 295 296 mlir::ChangeResult LatticePoint::reset() { 297 if (stateMap.empty()) 298 return mlir::ChangeResult::NoChange; 299 stateMap.clear(); 300 return mlir::ChangeResult::Change; 301 } 302 303 mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) { 304 if (stateMap.count(value)) { 305 // already in map 306 AllocationState &oldState = stateMap[value]; 307 if (oldState != state) { 308 stateMap[value] = state; 309 return mlir::ChangeResult::Change; 310 } 311 return mlir::ChangeResult::NoChange; 312 } 313 stateMap.insert({value, state}); 314 return mlir::ChangeResult::Change; 315 } 316 317 /// Get values which were allocated in this function and always freed before 318 /// the function returns 319 void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const { 320 for (auto &[value, state] : stateMap) { 321 if (state == AllocationState::Freed) 322 out.insert(value); 323 } 324 } 325 326 std::optional<AllocationState> LatticePoint::get(mlir::Value val) const { 327 auto it = stateMap.find(val); 328 if (it == stateMap.end()) 329 return {}; 330 return it->second; 331 } 332 333 void AllocationAnalysis::visitOperation(mlir::Operation *op, 334 const LatticePoint &before, 335 LatticePoint *after) { 336 LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op 337 << "\n"); 338 LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n"); 339 340 // propagate before -> after 341 mlir::ChangeResult changed = after->join(before); 342 343 if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) { 344 assert(op->getNumResults() == 1 && "fir.allocmem has one result"); 345 auto attr = op->getAttrOfType<fir::MustBeHeapAttr>( 346 fir::MustBeHeapAttr::getAttrName()); 347 if (attr && attr.getValue()) { 348 LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n"); 349 // skip allocation marked not to be moved 350 return; 351 } 352 353 auto retTy = allocmem.getAllocatedType(); 354 if (!mlir::isa<fir::SequenceType>(retTy)) { 355 LLVM_DEBUG(llvm::dbgs() 356 << "--Allocation is not for an array: skipping\n"); 357 return; 358 } 359 360 mlir::Value result = op->getResult(0); 361 changed |= after->set(result, AllocationState::Allocated); 362 } else if (mlir::isa<fir::FreeMemOp>(op)) { 363 assert(op->getNumOperands() == 1 && "fir.freemem has one operand"); 364 mlir::Value operand = op->getOperand(0); 365 std::optional<AllocationState> operandState = before.get(operand); 366 if (operandState && *operandState == AllocationState::Allocated) { 367 // don't tag things not allocated in this function as freed, so that we 368 // don't think they are candidates for moving to the stack 369 changed |= after->set(operand, AllocationState::Freed); 370 } 371 } else if (mlir::isa<fir::ResultOp>(op)) { 372 mlir::Operation *parent = op->getParentOp(); 373 LatticePoint *parentLattice = getLattice(parent); 374 assert(parentLattice); 375 mlir::ChangeResult parentChanged = parentLattice->join(*after); 376 propagateIfChanged(parentLattice, parentChanged); 377 } 378 379 // we pass lattices straight through fir.call because called functions should 380 // not deallocate flang-generated array temporaries 381 382 LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n"); 383 propagateIfChanged(after, changed); 384 } 385 386 void AllocationAnalysis::setToEntryState(LatticePoint *lattice) { 387 propagateIfChanged(lattice, lattice->reset()); 388 } 389 390 /// Mostly a copy of AbstractDenseLattice::processOperation - the difference 391 /// being that call operations are passed through to the transfer function 392 void AllocationAnalysis::processOperation(mlir::Operation *op) { 393 // If the containing block is not executable, bail out. 394 if (!getOrCreateFor<mlir::dataflow::Executable>(op, op->getBlock())->isLive()) 395 return; 396 397 // Get the dense lattice to update 398 mlir::dataflow::AbstractDenseLattice *after = getLattice(op); 399 400 // If this op implements region control-flow, then control-flow dictates its 401 // transfer function. 402 if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) 403 return visitRegionBranchOperation(op, branch, after); 404 405 // pass call operations through to the transfer function 406 407 // Get the dense state before the execution of the op. 408 const mlir::dataflow::AbstractDenseLattice *before; 409 if (mlir::Operation *prev = op->getPrevNode()) 410 before = getLatticeFor(op, prev); 411 else 412 before = getLatticeFor(op, op->getBlock()); 413 414 /// Invoke the operation transfer function 415 visitOperationImpl(op, *before, after); 416 } 417 418 mlir::LogicalResult 419 StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) { 420 assert(mlir::isa<mlir::func::FuncOp>(func)); 421 size_t nAllocs = 0; 422 func->walk([&nAllocs](fir::AllocMemOp) { nAllocs++; }); 423 // don't bother with the analysis if there are no heap allocations 424 if (nAllocs == 0) 425 return mlir::success(); 426 if ((maxAllocsPerFunc != 0) && (nAllocs > maxAllocsPerFunc)) { 427 LLVM_DEBUG(llvm::dbgs() << "Skipping stack arrays for function with " 428 << nAllocs << " heap allocations"); 429 return mlir::success(); 430 } 431 432 mlir::DataFlowSolver solver; 433 // constant propagation is required for dead code analysis, dead code analysis 434 // is required to mark blocks live (required for mlir dense dfa) 435 solver.load<mlir::dataflow::SparseConstantPropagation>(); 436 solver.load<mlir::dataflow::DeadCodeAnalysis>(); 437 438 auto [it, inserted] = funcMaps.try_emplace(func); 439 AllocMemMap &candidateOps = it->second; 440 441 solver.load<AllocationAnalysis>(); 442 if (failed(solver.initializeAndRun(func))) { 443 llvm::errs() << "DataFlowSolver failed!"; 444 return mlir::failure(); 445 } 446 447 LatticePoint point{func}; 448 auto joinOperationLattice = [&](mlir::Operation *op) { 449 const LatticePoint *lattice = solver.lookupState<LatticePoint>(op); 450 // there will be no lattice for an unreachable block 451 if (lattice) 452 (void)point.join(*lattice); 453 }; 454 func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); }); 455 func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); }); 456 llvm::DenseSet<mlir::Value> freedValues; 457 point.appendFreedValues(freedValues); 458 459 // We only replace allocations which are definately freed on all routes 460 // through the function because otherwise the allocation may have an intende 461 // lifetime longer than the current stack frame (e.g. a heap allocation which 462 // is then freed by another function). 463 for (mlir::Value freedValue : freedValues) { 464 fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>(); 465 InsertionPoint insertionPoint = 466 AllocMemConversion::findAllocaInsertionPoint(allocmem); 467 if (insertionPoint) 468 candidateOps.insert({allocmem, insertionPoint}); 469 } 470 471 LLVM_DEBUG(for (auto [allocMemOp, _] 472 : candidateOps) { 473 llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n'; 474 }); 475 return mlir::success(); 476 } 477 478 const StackArraysAnalysisWrapper::AllocMemMap * 479 StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) { 480 if (!funcMaps.contains(func)) 481 if (mlir::failed(analyseFunction(func))) 482 return nullptr; 483 return &funcMaps[func]; 484 } 485 486 /// Restore the old allocation type exected by existing code 487 static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter, 488 const mlir::Location &loc, 489 mlir::Value heap, mlir::Value stack) { 490 mlir::Type heapTy = heap.getType(); 491 mlir::Type stackTy = stack.getType(); 492 493 if (heapTy == stackTy) 494 return stack; 495 496 fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy); 497 LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy = 498 mlir::cast<fir::ReferenceType>(stackTy); 499 assert(firHeapTy.getElementType() == firRefTy.getElementType() && 500 "Allocations must have the same type"); 501 502 auto insertionPoint = rewriter.saveInsertionPoint(); 503 rewriter.setInsertionPointAfter(stack.getDefiningOp()); 504 mlir::Value conv = 505 rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult(); 506 rewriter.restoreInsertionPoint(insertionPoint); 507 return conv; 508 } 509 510 mlir::LogicalResult 511 AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem, 512 mlir::PatternRewriter &rewriter) const { 513 auto oldInsertionPt = rewriter.saveInsertionPoint(); 514 // add alloca operation 515 std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter); 516 rewriter.restoreInsertionPoint(oldInsertionPt); 517 if (!alloca) 518 return mlir::failure(); 519 520 // remove freemem operations 521 llvm::SmallVector<mlir::Operation *> erases; 522 for (mlir::Operation *user : allocmem.getOperation()->getUsers()) 523 if (mlir::isa<fir::FreeMemOp>(user)) 524 erases.push_back(user); 525 // now we are done iterating the users, it is safe to mutate them 526 for (mlir::Operation *erase : erases) 527 rewriter.eraseOp(erase); 528 529 // replace references to heap allocation with references to stack allocation 530 mlir::Value newValue = convertAllocationType( 531 rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult()); 532 rewriter.replaceAllUsesWith(allocmem.getResult(), newValue); 533 534 // remove allocmem operation 535 rewriter.eraseOp(allocmem.getOperation()); 536 537 return mlir::success(); 538 } 539 540 static bool isInLoop(mlir::Block *block) { 541 return mlir::LoopLikeOpInterface::blockIsInLoop(block); 542 } 543 544 static bool isInLoop(mlir::Operation *op) { 545 return isInLoop(op->getBlock()) || 546 op->getParentOfType<mlir::LoopLikeOpInterface>(); 547 } 548 549 InsertionPoint 550 AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) { 551 // Ideally the alloca should be inserted at the end of the function entry 552 // block so that we do not allocate stack space in a loop. However, 553 // the operands to the alloca may not be available that early, so insert it 554 // after the last operand becomes available 555 // If the old allocmem op was in an openmp region then it should not be moved 556 // outside of that 557 LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: " 558 << oldAlloc << "\n"); 559 560 // check that an Operation or Block we are about to return is not in a loop 561 auto checkReturn = [&](auto *point) -> InsertionPoint { 562 if (isInLoop(point)) { 563 mlir::Operation *oldAllocOp = oldAlloc.getOperation(); 564 if (isInLoop(oldAllocOp)) { 565 // where we want to put it is in a loop, and even the old location is in 566 // a loop. Give up. 567 return findAllocaLoopInsertionPoint(oldAlloc); 568 } 569 return {oldAllocOp}; 570 } 571 return {point}; 572 }; 573 574 auto oldOmpRegion = 575 oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>(); 576 577 // Find when the last operand value becomes available 578 mlir::Block *operandsBlock = nullptr; 579 mlir::Operation *lastOperand = nullptr; 580 for (mlir::Value operand : oldAlloc.getOperands()) { 581 LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n"); 582 mlir::Operation *op = operand.getDefiningOp(); 583 if (!op) 584 return checkReturn(oldAlloc.getOperation()); 585 if (!operandsBlock) 586 operandsBlock = op->getBlock(); 587 else if (operandsBlock != op->getBlock()) { 588 LLVM_DEBUG(llvm::dbgs() 589 << "----operand declared in a different block!\n"); 590 // Operation::isBeforeInBlock requires the operations to be in the same 591 // block. The best we can do is the location of the allocmem. 592 return checkReturn(oldAlloc.getOperation()); 593 } 594 if (!lastOperand || lastOperand->isBeforeInBlock(op)) 595 lastOperand = op; 596 } 597 598 if (lastOperand) { 599 // there were value operands to the allocmem so insert after the last one 600 LLVM_DEBUG(llvm::dbgs() 601 << "--Placing after last operand: " << *lastOperand << "\n"); 602 // check we aren't moving out of an omp region 603 auto lastOpOmpRegion = 604 lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>(); 605 if (lastOpOmpRegion == oldOmpRegion) 606 return checkReturn(lastOperand); 607 // Presumably this happened because the operands became ready before the 608 // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should 609 // imply that oldOmpRegion comes after lastOpOmpRegion. 610 return checkReturn(oldOmpRegion.getAllocaBlock()); 611 } 612 613 // There were no value operands to the allocmem so we are safe to insert it 614 // as early as we want 615 616 // handle openmp case 617 if (oldOmpRegion) 618 return checkReturn(oldOmpRegion.getAllocaBlock()); 619 620 // fall back to the function entry block 621 mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>(); 622 assert(func && "This analysis is run on func.func"); 623 mlir::Block &entryBlock = func.getBlocks().front(); 624 LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n"); 625 return checkReturn(&entryBlock); 626 } 627 628 InsertionPoint 629 AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) { 630 mlir::Operation *oldAllocOp = oldAlloc; 631 // This is only called as a last resort. We should try to insert at the 632 // location of the old allocation, which is inside of a loop, using 633 // llvm.stacksave/llvm.stackrestore 634 635 // find freemem ops 636 llvm::SmallVector<mlir::Operation *, 1> freeOps; 637 for (mlir::Operation *user : oldAllocOp->getUsers()) 638 if (mlir::isa<fir::FreeMemOp>(user)) 639 freeOps.push_back(user); 640 assert(freeOps.size() && "DFA should only return freed memory"); 641 642 // Don't attempt to reason about a stacksave/stackrestore between different 643 // blocks 644 for (mlir::Operation *free : freeOps) 645 if (free->getBlock() != oldAllocOp->getBlock()) 646 return {nullptr}; 647 648 // Check that there aren't any other stack allocations in between the 649 // stack save and stack restore 650 // note: for flang generated temporaries there should only be one free op 651 for (mlir::Operation *free : freeOps) { 652 for (mlir::Operation *op = oldAlloc; op && op != free; 653 op = op->getNextNode()) { 654 if (mlir::isa<fir::AllocaOp>(op)) 655 return {nullptr}; 656 } 657 } 658 659 return InsertionPoint{oldAllocOp, /*shouldStackSaveRestore=*/true}; 660 } 661 662 std::optional<fir::AllocaOp> 663 AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc, 664 mlir::PatternRewriter &rewriter) const { 665 auto it = candidateOps.find(oldAlloc.getOperation()); 666 if (it == candidateOps.end()) 667 return {}; 668 InsertionPoint insertionPoint = it->second; 669 if (!insertionPoint) 670 return {}; 671 672 if (insertionPoint.shouldSaveRestoreStack()) 673 insertStackSaveRestore(oldAlloc, rewriter); 674 675 mlir::Location loc = oldAlloc.getLoc(); 676 mlir::Type varTy = oldAlloc.getInType(); 677 if (mlir::Operation *op = insertionPoint.tryGetOperation()) { 678 rewriter.setInsertionPointAfter(op); 679 } else { 680 mlir::Block *block = insertionPoint.tryGetBlock(); 681 assert(block && "There must be a valid insertion point"); 682 rewriter.setInsertionPointToStart(block); 683 } 684 685 auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef { 686 if (opt) 687 return *opt; 688 return {}; 689 }; 690 691 llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName()); 692 llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName()); 693 return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName, 694 oldAlloc.getTypeparams(), 695 oldAlloc.getShape()); 696 } 697 698 void AllocMemConversion::insertStackSaveRestore( 699 fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const { 700 auto oldPoint = rewriter.saveInsertionPoint(); 701 auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>(); 702 fir::FirOpBuilder builder{rewriter, mod}; 703 704 mlir::func::FuncOp stackSaveFn = fir::factory::getLlvmStackSave(builder); 705 mlir::SymbolRefAttr stackSaveSym = 706 builder.getSymbolRefAttr(stackSaveFn.getName()); 707 708 builder.setInsertionPoint(oldAlloc); 709 mlir::Value sp = 710 builder 711 .create<fir::CallOp>(oldAlloc.getLoc(), 712 stackSaveFn.getFunctionType().getResults(), 713 stackSaveSym, mlir::ValueRange{}) 714 .getResult(0); 715 716 mlir::func::FuncOp stackRestoreFn = 717 fir::factory::getLlvmStackRestore(builder); 718 mlir::SymbolRefAttr stackRestoreSym = 719 builder.getSymbolRefAttr(stackRestoreFn.getName()); 720 721 for (mlir::Operation *user : oldAlloc->getUsers()) { 722 if (mlir::isa<fir::FreeMemOp>(user)) { 723 builder.setInsertionPoint(user); 724 builder.create<fir::CallOp>(user->getLoc(), 725 stackRestoreFn.getFunctionType().getResults(), 726 stackRestoreSym, mlir::ValueRange{sp}); 727 } 728 } 729 730 rewriter.restoreInsertionPoint(oldPoint); 731 } 732 733 StackArraysPass::StackArraysPass(const StackArraysPass &pass) 734 : fir::impl::StackArraysBase<StackArraysPass>(pass) {} 735 736 llvm::StringRef StackArraysPass::getDescription() const { 737 return "Move heap allocated array temporaries to the stack"; 738 } 739 740 void StackArraysPass::runOnOperation() { 741 mlir::ModuleOp mod = getOperation(); 742 743 mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); }); 744 } 745 746 void StackArraysPass::runOnFunc(mlir::Operation *func) { 747 assert(mlir::isa<mlir::func::FuncOp>(func)); 748 749 auto &analysis = getAnalysis<StackArraysAnalysisWrapper>(); 750 const StackArraysAnalysisWrapper::AllocMemMap *candidateOps = 751 analysis.getCandidateOps(func); 752 if (!candidateOps) { 753 signalPassFailure(); 754 return; 755 } 756 757 if (candidateOps->empty()) 758 return; 759 runCount += candidateOps->size(); 760 761 llvm::SmallVector<mlir::Operation *> opsToConvert; 762 opsToConvert.reserve(candidateOps->size()); 763 for (auto [op, _] : *candidateOps) 764 opsToConvert.push_back(op); 765 766 mlir::MLIRContext &context = getContext(); 767 mlir::RewritePatternSet patterns(&context); 768 mlir::GreedyRewriteConfig config; 769 // prevent the pattern driver form merging blocks 770 config.enableRegionSimplification = false; 771 772 patterns.insert<AllocMemConversion>(&context, *candidateOps); 773 if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert, 774 std::move(patterns), config))) { 775 mlir::emitError(func->getLoc(), "error in stack arrays optimization\n"); 776 signalPassFailure(); 777 } 778 } 779