17b08b436SNathan Gauër //===- ConvergenceRegionAnalysis.h -----------------------------*- C++ -*--===// 27b08b436SNathan Gauër // 37b08b436SNathan Gauër // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 47b08b436SNathan Gauër // See https://llvm.org/LICENSE.txt for license information. 57b08b436SNathan Gauër // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 67b08b436SNathan Gauër // 77b08b436SNathan Gauër //===----------------------------------------------------------------------===// 87b08b436SNathan Gauër // 97b08b436SNathan Gauër // The analysis determines the convergence region for each basic block of 107b08b436SNathan Gauër // the module, and provides a tree-like structure describing the region 117b08b436SNathan Gauër // hierarchy. 127b08b436SNathan Gauër // 137b08b436SNathan Gauër //===----------------------------------------------------------------------===// 147b08b436SNathan Gauër 157b08b436SNathan Gauër #include "SPIRVConvergenceRegionAnalysis.h" 167b08b436SNathan Gauër #include "llvm/Analysis/LoopInfo.h" 177b08b436SNathan Gauër #include "llvm/IR/Dominators.h" 187b08b436SNathan Gauër #include "llvm/IR/IntrinsicInst.h" 197b08b436SNathan Gauër #include "llvm/InitializePasses.h" 207b08b436SNathan Gauër #include "llvm/Transforms/Utils/LoopSimplify.h" 217b08b436SNathan Gauër #include <optional> 227b08b436SNathan Gauër #include <queue> 237b08b436SNathan Gauër 247b08b436SNathan Gauër #define DEBUG_TYPE "spirv-convergence-region-analysis" 257b08b436SNathan Gauër 267b08b436SNathan Gauër using namespace llvm; 277b08b436SNathan Gauër 287b08b436SNathan Gauër namespace llvm { 297b08b436SNathan Gauër void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &); 307b08b436SNathan Gauër } // namespace llvm 317b08b436SNathan Gauër 327b08b436SNathan Gauër INITIALIZE_PASS_BEGIN(SPIRVConvergenceRegionAnalysisWrapperPass, 337b08b436SNathan Gauër "convergence-region", 347b08b436SNathan Gauër "SPIRV convergence regions analysis", true, true) 357b08b436SNathan Gauër INITIALIZE_PASS_DEPENDENCY(LoopSimplify) 367b08b436SNathan Gauër INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 377b08b436SNathan Gauër INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 387b08b436SNathan Gauër INITIALIZE_PASS_END(SPIRVConvergenceRegionAnalysisWrapperPass, 397b08b436SNathan Gauër "convergence-region", "SPIRV convergence regions analysis", 407b08b436SNathan Gauër true, true) 417b08b436SNathan Gauër 427b08b436SNathan Gauër namespace llvm { 437b08b436SNathan Gauër namespace SPIRV { 447b08b436SNathan Gauër namespace { 457b08b436SNathan Gauër 467b08b436SNathan Gauër template <typename BasicBlockType, typename IntrinsicInstType> 477b08b436SNathan Gauër std::optional<IntrinsicInstType *> 487b08b436SNathan Gauër getConvergenceTokenInternal(BasicBlockType *BB) { 497b08b436SNathan Gauër static_assert(std::is_const_v<IntrinsicInstType> == 507b08b436SNathan Gauër std::is_const_v<BasicBlockType>, 517b08b436SNathan Gauër "Constness must match between input and output."); 527b08b436SNathan Gauër static_assert(std::is_same_v<BasicBlock, std::remove_const_t<BasicBlockType>>, 537b08b436SNathan Gauër "Input must be a basic block."); 547b08b436SNathan Gauër static_assert( 557b08b436SNathan Gauër std::is_same_v<IntrinsicInst, std::remove_const_t<IntrinsicInstType>>, 567b08b436SNathan Gauër "Output type must be an intrinsic instruction."); 577b08b436SNathan Gauër 587b08b436SNathan Gauër for (auto &I : *BB) { 59*77e6f434SSameer Sahasrabuddhe if (auto *CI = dyn_cast<ConvergenceControlInst>(&I)) { 60*77e6f434SSameer Sahasrabuddhe // Make sure that the anchor or entry intrinsics did not reach here with a 61*77e6f434SSameer Sahasrabuddhe // parent token. This should have failed the verifier. 62*77e6f434SSameer Sahasrabuddhe assert(CI->isLoop() || 63*77e6f434SSameer Sahasrabuddhe !CI->getOperandBundle(LLVMContext::OB_convergencectrl)); 64*77e6f434SSameer Sahasrabuddhe return CI; 657b08b436SNathan Gauër } 667b08b436SNathan Gauër 677b08b436SNathan Gauër if (auto *CI = dyn_cast<CallInst>(&I)) { 687b08b436SNathan Gauër auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl); 697b08b436SNathan Gauër if (!OB.has_value()) 707b08b436SNathan Gauër continue; 717b08b436SNathan Gauër return dyn_cast<IntrinsicInst>(OB.value().Inputs[0]); 727b08b436SNathan Gauër } 737b08b436SNathan Gauër } 747b08b436SNathan Gauër 757b08b436SNathan Gauër return std::nullopt; 767b08b436SNathan Gauër } 777b08b436SNathan Gauër 787b08b436SNathan Gauër // Given a ConvergenceRegion tree with |Start| as its root, finds the smallest 797b08b436SNathan Gauër // region |Entry| belongs to. If |Entry| does not belong to the region defined 807b08b436SNathan Gauër // by |Start|, this function returns |nullptr|. 817b08b436SNathan Gauër ConvergenceRegion *findParentRegion(ConvergenceRegion *Start, 827b08b436SNathan Gauër BasicBlock *Entry) { 837b08b436SNathan Gauër ConvergenceRegion *Candidate = nullptr; 847b08b436SNathan Gauër ConvergenceRegion *NextCandidate = Start; 857b08b436SNathan Gauër 867b08b436SNathan Gauër while (Candidate != NextCandidate && NextCandidate != nullptr) { 877b08b436SNathan Gauër Candidate = NextCandidate; 887b08b436SNathan Gauër NextCandidate = nullptr; 897b08b436SNathan Gauër 907b08b436SNathan Gauër // End of the search, we can return. 917b08b436SNathan Gauër if (Candidate->Children.size() == 0) 927b08b436SNathan Gauër return Candidate; 937b08b436SNathan Gauër 947b08b436SNathan Gauër for (auto *Child : Candidate->Children) { 957b08b436SNathan Gauër if (Child->Blocks.count(Entry) != 0) { 967b08b436SNathan Gauër NextCandidate = Child; 977b08b436SNathan Gauër break; 987b08b436SNathan Gauër } 997b08b436SNathan Gauër } 1007b08b436SNathan Gauër } 1017b08b436SNathan Gauër 1027b08b436SNathan Gauër return Candidate; 1037b08b436SNathan Gauër } 1047b08b436SNathan Gauër 1057b08b436SNathan Gauër } // anonymous namespace 1067b08b436SNathan Gauër 1077b08b436SNathan Gauër std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB) { 1087b08b436SNathan Gauër return getConvergenceTokenInternal<BasicBlock, IntrinsicInst>(BB); 1097b08b436SNathan Gauër } 1107b08b436SNathan Gauër 1117b08b436SNathan Gauër std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB) { 1127b08b436SNathan Gauër return getConvergenceTokenInternal<const BasicBlock, const IntrinsicInst>(BB); 1137b08b436SNathan Gauër } 1147b08b436SNathan Gauër 1157b08b436SNathan Gauër ConvergenceRegion::ConvergenceRegion(DominatorTree &DT, LoopInfo &LI, 1167b08b436SNathan Gauër Function &F) 1177b08b436SNathan Gauër : DT(DT), LI(LI), Parent(nullptr) { 1187b08b436SNathan Gauër Entry = &F.getEntryBlock(); 1197b08b436SNathan Gauër ConvergenceToken = getConvergenceToken(Entry); 1207b08b436SNathan Gauër for (auto &B : F) { 1217b08b436SNathan Gauër Blocks.insert(&B); 1227b08b436SNathan Gauër if (isa<ReturnInst>(B.getTerminator())) 1237b08b436SNathan Gauër Exits.insert(&B); 1247b08b436SNathan Gauër } 1257b08b436SNathan Gauër } 1267b08b436SNathan Gauër 1277b08b436SNathan Gauër ConvergenceRegion::ConvergenceRegion( 1287b08b436SNathan Gauër DominatorTree &DT, LoopInfo &LI, 1297b08b436SNathan Gauër std::optional<IntrinsicInst *> ConvergenceToken, BasicBlock *Entry, 1307b08b436SNathan Gauër SmallPtrSet<BasicBlock *, 8> &&Blocks, SmallPtrSet<BasicBlock *, 2> &&Exits) 1317b08b436SNathan Gauër : DT(DT), LI(LI), ConvergenceToken(ConvergenceToken), Entry(Entry), 1327b08b436SNathan Gauër Exits(std::move(Exits)), Blocks(std::move(Blocks)) { 133e83adfe5SChris B for ([[maybe_unused]] auto *BB : this->Exits) 1347b08b436SNathan Gauër assert(this->Blocks.count(BB) != 0); 1357b08b436SNathan Gauër assert(this->Blocks.count(this->Entry) != 0); 1367b08b436SNathan Gauër } 1377b08b436SNathan Gauër 1387b08b436SNathan Gauër void ConvergenceRegion::releaseMemory() { 1397b08b436SNathan Gauër // Parent memory is owned by the parent. 1407b08b436SNathan Gauër Parent = nullptr; 1417b08b436SNathan Gauër for (auto *Child : Children) { 1427b08b436SNathan Gauër Child->releaseMemory(); 1437b08b436SNathan Gauër delete Child; 1447b08b436SNathan Gauër } 1457b08b436SNathan Gauër Children.resize(0); 1467b08b436SNathan Gauër } 1477b08b436SNathan Gauër 1487b08b436SNathan Gauër void ConvergenceRegion::dump(const unsigned IndentSize) const { 1497b08b436SNathan Gauër const std::string Indent(IndentSize, '\t'); 1507b08b436SNathan Gauër dbgs() << Indent << this << ": {\n"; 1517b08b436SNathan Gauër dbgs() << Indent << " Parent: " << Parent << "\n"; 1527b08b436SNathan Gauër 1537b08b436SNathan Gauër if (ConvergenceToken.value_or(nullptr)) { 1547b08b436SNathan Gauër dbgs() << Indent 1557b08b436SNathan Gauër << " ConvergenceToken: " << ConvergenceToken.value()->getName() 1567b08b436SNathan Gauër << "\n"; 1577b08b436SNathan Gauër } 1587b08b436SNathan Gauër 1597b08b436SNathan Gauër if (Entry->getName() != "") 1607b08b436SNathan Gauër dbgs() << Indent << " Entry: " << Entry->getName() << "\n"; 1617b08b436SNathan Gauër else 1627b08b436SNathan Gauër dbgs() << Indent << " Entry: " << Entry << "\n"; 1637b08b436SNathan Gauër 1647b08b436SNathan Gauër dbgs() << Indent << " Exits: { "; 1657b08b436SNathan Gauër for (const auto &Exit : Exits) { 1667b08b436SNathan Gauër if (Exit->getName() != "") 1677b08b436SNathan Gauër dbgs() << Exit->getName() << ", "; 1687b08b436SNathan Gauër else 1697b08b436SNathan Gauër dbgs() << Exit << ", "; 1707b08b436SNathan Gauër } 1717b08b436SNathan Gauër dbgs() << " }\n"; 1727b08b436SNathan Gauër 1737b08b436SNathan Gauër dbgs() << Indent << " Blocks: { "; 1747b08b436SNathan Gauër for (const auto &Block : Blocks) { 1757b08b436SNathan Gauër if (Block->getName() != "") 1767b08b436SNathan Gauër dbgs() << Block->getName() << ", "; 1777b08b436SNathan Gauër else 1787b08b436SNathan Gauër dbgs() << Block << ", "; 1797b08b436SNathan Gauër } 1807b08b436SNathan Gauër dbgs() << " }\n"; 1817b08b436SNathan Gauër 1827b08b436SNathan Gauër dbgs() << Indent << " Children: {\n"; 1837b08b436SNathan Gauër for (const auto Child : Children) 1847b08b436SNathan Gauër Child->dump(IndentSize + 2); 1857b08b436SNathan Gauër dbgs() << Indent << " }\n"; 1867b08b436SNathan Gauër 1877b08b436SNathan Gauër dbgs() << Indent << "}\n"; 1887b08b436SNathan Gauër } 1897b08b436SNathan Gauër 1907b08b436SNathan Gauër class ConvergenceRegionAnalyzer { 1917b08b436SNathan Gauër 1927b08b436SNathan Gauër public: 1937b08b436SNathan Gauër ConvergenceRegionAnalyzer(Function &F, DominatorTree &DT, LoopInfo &LI) 1947b08b436SNathan Gauër : DT(DT), LI(LI), F(F) {} 1957b08b436SNathan Gauër 1967b08b436SNathan Gauër private: 1977b08b436SNathan Gauër bool isBackEdge(const BasicBlock *From, const BasicBlock *To) const { 1981ed65febSNathan Gauër if (From == To) 1991ed65febSNathan Gauër return true; 2007b08b436SNathan Gauër 2017b08b436SNathan Gauër // We only handle loop in the simplified form. This means: 2027b08b436SNathan Gauër // - a single back-edge, a single latch. 2037b08b436SNathan Gauër // - meaning the back-edge target can only be the loop header. 2047b08b436SNathan Gauër // - meaning the From can only be the loop latch. 2057b08b436SNathan Gauër if (!LI.isLoopHeader(To)) 2067b08b436SNathan Gauër return false; 2077b08b436SNathan Gauër 2087b08b436SNathan Gauër auto *L = LI.getLoopFor(To); 2097b08b436SNathan Gauër if (L->contains(From) && L->isLoopLatch(From)) 2107b08b436SNathan Gauër return true; 2117b08b436SNathan Gauër 2127b08b436SNathan Gauër return false; 2137b08b436SNathan Gauër } 2147b08b436SNathan Gauër 2157b08b436SNathan Gauër std::unordered_set<BasicBlock *> 2167b08b436SNathan Gauër findPathsToMatch(LoopInfo &LI, BasicBlock *From, 2177b08b436SNathan Gauër std::function<bool(const BasicBlock *)> isMatch) const { 2187b08b436SNathan Gauër std::unordered_set<BasicBlock *> Output; 2197b08b436SNathan Gauër 2207b08b436SNathan Gauër if (isMatch(From)) 2217b08b436SNathan Gauër Output.insert(From); 2227b08b436SNathan Gauër 2237b08b436SNathan Gauër auto *Terminator = From->getTerminator(); 2247b08b436SNathan Gauër for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) { 2257b08b436SNathan Gauër auto *To = Terminator->getSuccessor(i); 2261ed65febSNathan Gauër // Ignore back edges. 2277b08b436SNathan Gauër if (isBackEdge(From, To)) 2287b08b436SNathan Gauër continue; 2297b08b436SNathan Gauër 2307b08b436SNathan Gauër auto ChildSet = findPathsToMatch(LI, To, isMatch); 2317b08b436SNathan Gauër if (ChildSet.size() == 0) 2327b08b436SNathan Gauër continue; 2337b08b436SNathan Gauër 2347b08b436SNathan Gauër Output.insert(ChildSet.begin(), ChildSet.end()); 2357b08b436SNathan Gauër Output.insert(From); 2367b08b436SNathan Gauër if (LI.isLoopHeader(From)) { 2377b08b436SNathan Gauër auto *L = LI.getLoopFor(From); 2387b08b436SNathan Gauër for (auto *BB : L->getBlocks()) { 2397b08b436SNathan Gauër Output.insert(BB); 2407b08b436SNathan Gauër } 2417b08b436SNathan Gauër } 2427b08b436SNathan Gauër } 2437b08b436SNathan Gauër 2447b08b436SNathan Gauër return Output; 2457b08b436SNathan Gauër } 2467b08b436SNathan Gauër 2477b08b436SNathan Gauër SmallPtrSet<BasicBlock *, 2> 2487b08b436SNathan Gauër findExitNodes(const SmallPtrSetImpl<BasicBlock *> &RegionBlocks) { 2497b08b436SNathan Gauër SmallPtrSet<BasicBlock *, 2> Exits; 2507b08b436SNathan Gauër 2517b08b436SNathan Gauër for (auto *B : RegionBlocks) { 2527b08b436SNathan Gauër auto *Terminator = B->getTerminator(); 2537b08b436SNathan Gauër for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) { 2547b08b436SNathan Gauër auto *Child = Terminator->getSuccessor(i); 2557b08b436SNathan Gauër if (RegionBlocks.count(Child) == 0) 2567b08b436SNathan Gauër Exits.insert(B); 2577b08b436SNathan Gauër } 2587b08b436SNathan Gauër } 2597b08b436SNathan Gauër 2607b08b436SNathan Gauër return Exits; 2617b08b436SNathan Gauër } 2627b08b436SNathan Gauër 2637b08b436SNathan Gauër public: 2647b08b436SNathan Gauër ConvergenceRegionInfo analyze() { 2657b08b436SNathan Gauër ConvergenceRegion *TopLevelRegion = new ConvergenceRegion(DT, LI, F); 2667b08b436SNathan Gauër std::queue<Loop *> ToProcess; 2677b08b436SNathan Gauër for (auto *L : LI.getLoopsInPreorder()) 2687b08b436SNathan Gauër ToProcess.push(L); 2697b08b436SNathan Gauër 2707b08b436SNathan Gauër while (ToProcess.size() != 0) { 2717b08b436SNathan Gauër auto *L = ToProcess.front(); 2727b08b436SNathan Gauër ToProcess.pop(); 2737b08b436SNathan Gauër 2747b08b436SNathan Gauër auto CT = getConvergenceToken(L->getHeader()); 2757b08b436SNathan Gauër SmallPtrSet<BasicBlock *, 8> RegionBlocks(L->block_begin(), 2767b08b436SNathan Gauër L->block_end()); 2777b08b436SNathan Gauër SmallVector<BasicBlock *> LoopExits; 2787b08b436SNathan Gauër L->getExitingBlocks(LoopExits); 2797b08b436SNathan Gauër if (CT.has_value()) { 2807b08b436SNathan Gauër for (auto *Exit : LoopExits) { 2817b08b436SNathan Gauër auto N = findPathsToMatch(LI, Exit, [&CT](const BasicBlock *block) { 2827b08b436SNathan Gauër auto Token = getConvergenceToken(block); 2837b08b436SNathan Gauër if (Token == std::nullopt) 2847b08b436SNathan Gauër return false; 2857b08b436SNathan Gauër return Token.value() == CT.value(); 2867b08b436SNathan Gauër }); 2877b08b436SNathan Gauër RegionBlocks.insert(N.begin(), N.end()); 2887b08b436SNathan Gauër } 2897b08b436SNathan Gauër } 2907b08b436SNathan Gauër 2917b08b436SNathan Gauër auto RegionExits = findExitNodes(RegionBlocks); 2927b08b436SNathan Gauër ConvergenceRegion *Region = new ConvergenceRegion( 2937b08b436SNathan Gauër DT, LI, CT, L->getHeader(), std::move(RegionBlocks), 2947b08b436SNathan Gauër std::move(RegionExits)); 2957b08b436SNathan Gauër Region->Parent = findParentRegion(TopLevelRegion, Region->Entry); 2967b08b436SNathan Gauër assert(Region->Parent != nullptr && "This is impossible."); 2977b08b436SNathan Gauër Region->Parent->Children.push_back(Region); 2987b08b436SNathan Gauër } 2997b08b436SNathan Gauër 3007b08b436SNathan Gauër return ConvergenceRegionInfo(TopLevelRegion); 3017b08b436SNathan Gauër } 3027b08b436SNathan Gauër 3037b08b436SNathan Gauër private: 3047b08b436SNathan Gauër DominatorTree &DT; 3057b08b436SNathan Gauër LoopInfo &LI; 3067b08b436SNathan Gauër Function &F; 3077b08b436SNathan Gauër }; 3087b08b436SNathan Gauër 3097b08b436SNathan Gauër ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT, 3107b08b436SNathan Gauër LoopInfo &LI) { 3117b08b436SNathan Gauër ConvergenceRegionAnalyzer Analyzer(F, DT, LI); 3127b08b436SNathan Gauër return Analyzer.analyze(); 3137b08b436SNathan Gauër } 3147b08b436SNathan Gauër 3157b08b436SNathan Gauër } // namespace SPIRV 3167b08b436SNathan Gauër 3177b08b436SNathan Gauër char SPIRVConvergenceRegionAnalysisWrapperPass::ID = 0; 3187b08b436SNathan Gauër 3197b08b436SNathan Gauër SPIRVConvergenceRegionAnalysisWrapperPass:: 3207b08b436SNathan Gauër SPIRVConvergenceRegionAnalysisWrapperPass() 3217b08b436SNathan Gauër : FunctionPass(ID) {} 3227b08b436SNathan Gauër 3237b08b436SNathan Gauër bool SPIRVConvergenceRegionAnalysisWrapperPass::runOnFunction(Function &F) { 3247b08b436SNathan Gauër DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 3257b08b436SNathan Gauër LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 3267b08b436SNathan Gauër 3277b08b436SNathan Gauër CRI = SPIRV::getConvergenceRegions(F, DT, LI); 3287b08b436SNathan Gauër // Nothing was modified. 3297b08b436SNathan Gauër return false; 3307b08b436SNathan Gauër } 3317b08b436SNathan Gauër 3327b08b436SNathan Gauër SPIRVConvergenceRegionAnalysis::Result 3337b08b436SNathan Gauër SPIRVConvergenceRegionAnalysis::run(Function &F, FunctionAnalysisManager &AM) { 3347b08b436SNathan Gauër Result CRI; 3357b08b436SNathan Gauër auto &DT = AM.getResult<DominatorTreeAnalysis>(F); 3367b08b436SNathan Gauër auto &LI = AM.getResult<LoopAnalysis>(F); 3377b08b436SNathan Gauër CRI = SPIRV::getConvergenceRegions(F, DT, LI); 3387b08b436SNathan Gauër return CRI; 3397b08b436SNathan Gauër } 3407b08b436SNathan Gauër 3417b08b436SNathan Gauër AnalysisKey SPIRVConvergenceRegionAnalysis::Key; 3427b08b436SNathan Gauër 3437b08b436SNathan Gauër } // namespace llvm 344