//===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopTermFold.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Config/llvm-config.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include #include using namespace llvm; #define DEBUG_TYPE "loop-term-fold" STATISTIC(NumTermFold, "Number of terminating condition fold recognized and performed"); static std::optional> canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, const LoopInfo &LI, const TargetTransformInfo &TTI) { if (!L->isInnermost()) { LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n"); return std::nullopt; } // Only inspect on simple loop structure if (!L->isLoopSimplifyForm()) { LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n"); return std::nullopt; } if (!SE.hasLoopInvariantBackedgeTakenCount(L)) { LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n"); return std::nullopt; } BasicBlock *LoopLatch = L->getLoopLatch(); BranchInst *BI = dyn_cast(LoopLatch->getTerminator()); if (!BI || BI->isUnconditional()) return std::nullopt; auto *TermCond = dyn_cast(BI->getCondition()); if (!TermCond) { LLVM_DEBUG( dbgs() << "Cannot fold on branching condition that is not an ICmpInst"); return std::nullopt; } if (!TermCond->hasOneUse()) { LLVM_DEBUG( dbgs() << "Cannot replace terminating condition with more than one use\n"); return std::nullopt; } BinaryOperator *LHS = dyn_cast(TermCond->getOperand(0)); Value *RHS = TermCond->getOperand(1); if (!LHS || !L->isLoopInvariant(RHS)) // We could pattern match the inverse form of the icmp, but that is // non-canonical, and this pass is running *very* late in the pipeline. return std::nullopt; // Find the IV used by the current exit condition. PHINode *ToFold; Value *ToFoldStart, *ToFoldStep; if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep)) return std::nullopt; // Ensure the simple recurrence is a part of the current loop. if (ToFold->getParent() != L->getHeader()) return std::nullopt; // If that IV isn't dead after we rewrite the exit condition in terms of // another IV, there's no point in doing the transform. if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond)) return std::nullopt; // Inserting instructions in the preheader has a runtime cost, scale // the allowed cost with the loops trip count as best we can. const unsigned ExpansionBudget = [&]() { unsigned Budget = 2 * SCEVCheapExpansionBudget; if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L)) return std::min(Budget, SmallTC); if (std::optional SmallTC = getLoopEstimatedTripCount(L)) return std::min(Budget, *SmallTC); // Unknown trip count, assume long running by default. return Budget; }(); const SCEV *BECount = SE.getBackedgeTakenCount(L); const DataLayout &DL = L->getHeader()->getDataLayout(); SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); PHINode *ToHelpFold = nullptr; const SCEV *TermValueS = nullptr; bool MustDropPoison = false; auto InsertPt = L->getLoopPreheader()->getTerminator(); for (PHINode &PN : L->getHeader()->phis()) { if (ToFold == &PN) continue; if (!SE.isSCEVable(PN.getType())) { LLVM_DEBUG(dbgs() << "IV of phi '" << PN << "' is not SCEV-able, not qualified for the " "terminating condition folding.\n"); continue; } const SCEVAddRecExpr *AddRec = dyn_cast(SE.getSCEV(&PN)); // Only speculate on affine AddRec if (!AddRec || !AddRec->isAffine()) { LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN << "' is not an affine add recursion, not qualified " "for the terminating condition folding.\n"); continue; } // Check that we can compute the value of AddRec on the exiting iteration // without soundness problems. evaluateAtIteration internally needs // to multiply the stride of the iteration number - which may wrap around. // The issue here is subtle because computing the result accounting for // wrap is insufficient. In order to use the result in an exit test, we // must also know that AddRec doesn't take the same value on any previous // iteration. The simplest case to consider is a candidate IV which is // narrower than the trip count (and thus original IV), but this can // also happen due to non-unit strides on the candidate IVs. if (!AddRec->hasNoSelfWrap() || !SE.isKnownNonZero(AddRec->getStepRecurrence(SE))) continue; const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE); const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE); if (!Expander.isSafeToExpand(TermValueSLocal)) { LLVM_DEBUG( dbgs() << "Is not safe to expand terminating value for phi node" << PN << "\n"); continue; } if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI, InsertPt)) { LLVM_DEBUG( dbgs() << "Is too expensive to expand terminating value for phi node" << PN << "\n"); continue; } // The candidate IV may have been otherwise dead and poison from the // very first iteration. If we can't disprove that, we can't use the IV. if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) { LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n"); continue; } // The candidate IV may become poison on the last iteration. If this // value is not branched on, this is a well defined program. We're // about to add a new use to this IV, and we have to ensure we don't // insert UB which didn't previously exist. bool MustDropPoisonLocal = false; Instruction *PostIncV = cast(PN.getIncomingValueForBlock(LoopLatch)); if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(), &DT)) { LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN << "\n"); // If this is a complex recurrance with multiple instructions computing // the backedge value, we might need to strip poison flags from all of // them. if (PostIncV->getOperand(0) != &PN) continue; // In order to perform the transform, we need to drop the poison // generating flags on this instruction (if any). MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags(); } // We pick the last legal alternate IV. We could expore choosing an optimal // alternate IV if we had a decent heuristic to do so. ToHelpFold = &PN; TermValueS = TermValueSLocal; MustDropPoison = MustDropPoisonLocal; } LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs() << "Cannot find other AddRec IV to help folding\n";); LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs() << "\nFound loop that can fold terminating condition\n" << " BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n" << " TermCond: " << *TermCond << "\n" << " BrandInst: " << *BI << "\n" << " ToFold: " << *ToFold << "\n" << " ToHelpFold: " << *ToHelpFold << "\n"); if (!ToFold || !ToHelpFold) return std::nullopt; return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison); } static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT, LoopInfo &LI, const TargetTransformInfo &TTI, TargetLibraryInfo &TLI, MemorySSA *MSSA) { std::unique_ptr MSSAU; if (MSSA) MSSAU = std::make_unique(MSSA); auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI); if (!Opt) return false; auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt; NumTermFold++; BasicBlock *LoopPreheader = L->getLoopPreheader(); BasicBlock *LoopLatch = L->getLoopLatch(); (void)ToFold; LLVM_DEBUG(dbgs() << "To fold phi-node:\n" << *ToFold << "\n" << "New term-cond phi-node:\n" << *ToHelpFold << "\n"); Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader); (void)StartValue; Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch); // See comment in canFoldTermCondOfLoop on why this is sufficient. if (MustDrop) cast(LoopValue)->dropPoisonGeneratingFlags(); // SCEVExpander for both use in preheader and latch const DataLayout &DL = L->getHeader()->getDataLayout(); SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); assert(Expander.isSafeToExpand(TermValueS) && "Terminating value was checked safe in canFoldTerminatingCondition"); // Create new terminating value at loop preheader Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(), LoopPreheader->getTerminator()); LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n" << *StartValue << "\n" << "Terminating value of new term-cond phi-node:\n" << *TermValue << "\n"); // Create new terminating condition at loop latch BranchInst *BI = cast(LoopLatch->getTerminator()); ICmpInst *OldTermCond = cast(BI->getCondition()); IRBuilder<> LatchBuilder(LoopLatch->getTerminator()); Value *NewTermCond = LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue, "lsr_fold_term_cond.replaced_term_cond"); // Swap successors to exit loop body if IV equals to new TermValue if (BI->getSuccessor(0) == L->getHeader()) BI->swapSuccessors(); LLVM_DEBUG(dbgs() << "Old term-cond:\n" << *OldTermCond << "\n" << "New term-cond:\n" << *NewTermCond << "\n"); BI->setCondition(NewTermCond); Expander.clear(); OldTermCond->eraseFromParent(); DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get()); return true; } namespace { class LoopTermFold : public LoopPass { public: static char ID; // Pass ID, replacement for typeid LoopTermFold(); private: bool runOnLoop(Loop *L, LPPassManager &LPM) override; void getAnalysisUsage(AnalysisUsage &AU) const override; }; } // end anonymous namespace LoopTermFold::LoopTermFold() : LoopPass(ID) { initializeLoopTermFoldPass(*PassRegistry::getPassRegistry()); } void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired(); AU.addPreserved(); AU.addPreservedID(LoopSimplifyID); AU.addRequiredID(LoopSimplifyID); AU.addRequired(); AU.addPreserved(); AU.addRequired(); AU.addPreserved(); AU.addRequired(); AU.addRequired(); AU.addPreserved(); } bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) { if (skipLoop(L)) return false; auto &SE = getAnalysis().getSE(); auto &DT = getAnalysis().getDomTree(); auto &LI = getAnalysis().getLoopInfo(); const auto &TTI = getAnalysis().getTTI( *L->getHeader()->getParent()); auto &TLI = getAnalysis().getTLI( *L->getHeader()->getParent()); auto *MSSAAnalysis = getAnalysisIfAvailable(); MemorySSA *MSSA = nullptr; if (MSSAAnalysis) MSSA = &MSSAAnalysis->getMSSA(); return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA); } PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { if (!RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA)) return PreservedAnalyses::all(); auto PA = getLoopPassPreservedAnalyses(); if (AR.MSSA) PA.preserve(); return PA; } char LoopTermFold::ID = 0; INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding", false, false) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding", false, false) Pass *llvm::createLoopTermFoldPass() { return new LoopTermFold(); }