xref: /llvm-project/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp (revision 77e6f434ec79db025aa9c7d193179727f1d63714)
//===- SPIRVConvergenceRegionAnalysis.cpp ---------------------------------===//
27b08b436SNathan Gauër //
37b08b436SNathan Gauër // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47b08b436SNathan Gauër // See https://llvm.org/LICENSE.txt for license information.
57b08b436SNathan Gauër // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67b08b436SNathan Gauër //
77b08b436SNathan Gauër //===----------------------------------------------------------------------===//
87b08b436SNathan Gauër //
// The analysis determines the convergence region of each basic block of a
// function, and provides a tree-like structure describing the region
// hierarchy.
127b08b436SNathan Gauër //
137b08b436SNathan Gauër //===----------------------------------------------------------------------===//
147b08b436SNathan Gauër 
#include "SPIRVConvergenceRegionAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/InitializePasses.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include <functional>
#include <optional>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>
237b08b436SNathan Gauër 
247b08b436SNathan Gauër #define DEBUG_TYPE "spirv-convergence-region-analysis"
257b08b436SNathan Gauër 
267b08b436SNathan Gauër using namespace llvm;
277b08b436SNathan Gauër 
287b08b436SNathan Gauër namespace llvm {
297b08b436SNathan Gauër void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
307b08b436SNathan Gauër } // namespace llvm
317b08b436SNathan Gauër 
327b08b436SNathan Gauër INITIALIZE_PASS_BEGIN(SPIRVConvergenceRegionAnalysisWrapperPass,
337b08b436SNathan Gauër                       "convergence-region",
347b08b436SNathan Gauër                       "SPIRV convergence regions analysis", true, true)
357b08b436SNathan Gauër INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
367b08b436SNathan Gauër INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
377b08b436SNathan Gauër INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
387b08b436SNathan Gauër INITIALIZE_PASS_END(SPIRVConvergenceRegionAnalysisWrapperPass,
397b08b436SNathan Gauër                     "convergence-region", "SPIRV convergence regions analysis",
407b08b436SNathan Gauër                     true, true)
417b08b436SNathan Gauër 
427b08b436SNathan Gauër namespace llvm {
437b08b436SNathan Gauër namespace SPIRV {
447b08b436SNathan Gauër namespace {
457b08b436SNathan Gauër 
467b08b436SNathan Gauër template <typename BasicBlockType, typename IntrinsicInstType>
477b08b436SNathan Gauër std::optional<IntrinsicInstType *>
487b08b436SNathan Gauër getConvergenceTokenInternal(BasicBlockType *BB) {
497b08b436SNathan Gauër   static_assert(std::is_const_v<IntrinsicInstType> ==
507b08b436SNathan Gauër                     std::is_const_v<BasicBlockType>,
517b08b436SNathan Gauër                 "Constness must match between input and output.");
527b08b436SNathan Gauër   static_assert(std::is_same_v<BasicBlock, std::remove_const_t<BasicBlockType>>,
537b08b436SNathan Gauër                 "Input must be a basic block.");
547b08b436SNathan Gauër   static_assert(
557b08b436SNathan Gauër       std::is_same_v<IntrinsicInst, std::remove_const_t<IntrinsicInstType>>,
567b08b436SNathan Gauër       "Output type must be an intrinsic instruction.");
577b08b436SNathan Gauër 
587b08b436SNathan Gauër   for (auto &I : *BB) {
59*77e6f434SSameer Sahasrabuddhe     if (auto *CI = dyn_cast<ConvergenceControlInst>(&I)) {
60*77e6f434SSameer Sahasrabuddhe       // Make sure that the anchor or entry intrinsics did not reach here with a
61*77e6f434SSameer Sahasrabuddhe       // parent token. This should have failed the verifier.
62*77e6f434SSameer Sahasrabuddhe       assert(CI->isLoop() ||
63*77e6f434SSameer Sahasrabuddhe              !CI->getOperandBundle(LLVMContext::OB_convergencectrl));
64*77e6f434SSameer Sahasrabuddhe       return CI;
657b08b436SNathan Gauër     }
667b08b436SNathan Gauër 
677b08b436SNathan Gauër     if (auto *CI = dyn_cast<CallInst>(&I)) {
687b08b436SNathan Gauër       auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl);
697b08b436SNathan Gauër       if (!OB.has_value())
707b08b436SNathan Gauër         continue;
717b08b436SNathan Gauër       return dyn_cast<IntrinsicInst>(OB.value().Inputs[0]);
727b08b436SNathan Gauër     }
737b08b436SNathan Gauër   }
747b08b436SNathan Gauër 
757b08b436SNathan Gauër   return std::nullopt;
767b08b436SNathan Gauër }
777b08b436SNathan Gauër 
787b08b436SNathan Gauër // Given a ConvergenceRegion tree with |Start| as its root, finds the smallest
797b08b436SNathan Gauër // region |Entry| belongs to. If |Entry| does not belong to the region defined
807b08b436SNathan Gauër // by |Start|, this function returns |nullptr|.
817b08b436SNathan Gauër ConvergenceRegion *findParentRegion(ConvergenceRegion *Start,
827b08b436SNathan Gauër                                     BasicBlock *Entry) {
837b08b436SNathan Gauër   ConvergenceRegion *Candidate = nullptr;
847b08b436SNathan Gauër   ConvergenceRegion *NextCandidate = Start;
857b08b436SNathan Gauër 
867b08b436SNathan Gauër   while (Candidate != NextCandidate && NextCandidate != nullptr) {
877b08b436SNathan Gauër     Candidate = NextCandidate;
887b08b436SNathan Gauër     NextCandidate = nullptr;
897b08b436SNathan Gauër 
907b08b436SNathan Gauër     // End of the search, we can return.
917b08b436SNathan Gauër     if (Candidate->Children.size() == 0)
927b08b436SNathan Gauër       return Candidate;
937b08b436SNathan Gauër 
947b08b436SNathan Gauër     for (auto *Child : Candidate->Children) {
957b08b436SNathan Gauër       if (Child->Blocks.count(Entry) != 0) {
967b08b436SNathan Gauër         NextCandidate = Child;
977b08b436SNathan Gauër         break;
987b08b436SNathan Gauër       }
997b08b436SNathan Gauër     }
1007b08b436SNathan Gauër   }
1017b08b436SNathan Gauër 
1027b08b436SNathan Gauër   return Candidate;
1037b08b436SNathan Gauër }
1047b08b436SNathan Gauër 
1057b08b436SNathan Gauër } // anonymous namespace
1067b08b436SNathan Gauër 
1077b08b436SNathan Gauër std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB) {
1087b08b436SNathan Gauër   return getConvergenceTokenInternal<BasicBlock, IntrinsicInst>(BB);
1097b08b436SNathan Gauër }
1107b08b436SNathan Gauër 
1117b08b436SNathan Gauër std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB) {
1127b08b436SNathan Gauër   return getConvergenceTokenInternal<const BasicBlock, const IntrinsicInst>(BB);
1137b08b436SNathan Gauër }
1147b08b436SNathan Gauër 
1157b08b436SNathan Gauër ConvergenceRegion::ConvergenceRegion(DominatorTree &DT, LoopInfo &LI,
1167b08b436SNathan Gauër                                      Function &F)
1177b08b436SNathan Gauër     : DT(DT), LI(LI), Parent(nullptr) {
1187b08b436SNathan Gauër   Entry = &F.getEntryBlock();
1197b08b436SNathan Gauër   ConvergenceToken = getConvergenceToken(Entry);
1207b08b436SNathan Gauër   for (auto &B : F) {
1217b08b436SNathan Gauër     Blocks.insert(&B);
1227b08b436SNathan Gauër     if (isa<ReturnInst>(B.getTerminator()))
1237b08b436SNathan Gauër       Exits.insert(&B);
1247b08b436SNathan Gauër   }
1257b08b436SNathan Gauër }
1267b08b436SNathan Gauër 
1277b08b436SNathan Gauër ConvergenceRegion::ConvergenceRegion(
1287b08b436SNathan Gauër     DominatorTree &DT, LoopInfo &LI,
1297b08b436SNathan Gauër     std::optional<IntrinsicInst *> ConvergenceToken, BasicBlock *Entry,
1307b08b436SNathan Gauër     SmallPtrSet<BasicBlock *, 8> &&Blocks, SmallPtrSet<BasicBlock *, 2> &&Exits)
1317b08b436SNathan Gauër     : DT(DT), LI(LI), ConvergenceToken(ConvergenceToken), Entry(Entry),
1327b08b436SNathan Gauër       Exits(std::move(Exits)), Blocks(std::move(Blocks)) {
133e83adfe5SChris B   for ([[maybe_unused]] auto *BB : this->Exits)
1347b08b436SNathan Gauër     assert(this->Blocks.count(BB) != 0);
1357b08b436SNathan Gauër   assert(this->Blocks.count(this->Entry) != 0);
1367b08b436SNathan Gauër }
1377b08b436SNathan Gauër 
1387b08b436SNathan Gauër void ConvergenceRegion::releaseMemory() {
1397b08b436SNathan Gauër   // Parent memory is owned by the parent.
1407b08b436SNathan Gauër   Parent = nullptr;
1417b08b436SNathan Gauër   for (auto *Child : Children) {
1427b08b436SNathan Gauër     Child->releaseMemory();
1437b08b436SNathan Gauër     delete Child;
1447b08b436SNathan Gauër   }
1457b08b436SNathan Gauër   Children.resize(0);
1467b08b436SNathan Gauër }
1477b08b436SNathan Gauër 
1487b08b436SNathan Gauër void ConvergenceRegion::dump(const unsigned IndentSize) const {
1497b08b436SNathan Gauër   const std::string Indent(IndentSize, '\t');
1507b08b436SNathan Gauër   dbgs() << Indent << this << ": {\n";
1517b08b436SNathan Gauër   dbgs() << Indent << "	Parent: " << Parent << "\n";
1527b08b436SNathan Gauër 
1537b08b436SNathan Gauër   if (ConvergenceToken.value_or(nullptr)) {
1547b08b436SNathan Gauër     dbgs() << Indent
1557b08b436SNathan Gauër            << "	ConvergenceToken: " << ConvergenceToken.value()->getName()
1567b08b436SNathan Gauër            << "\n";
1577b08b436SNathan Gauër   }
1587b08b436SNathan Gauër 
1597b08b436SNathan Gauër   if (Entry->getName() != "")
1607b08b436SNathan Gauër     dbgs() << Indent << "	Entry: " << Entry->getName() << "\n";
1617b08b436SNathan Gauër   else
1627b08b436SNathan Gauër     dbgs() << Indent << "	Entry: " << Entry << "\n";
1637b08b436SNathan Gauër 
1647b08b436SNathan Gauër   dbgs() << Indent << "	Exits: { ";
1657b08b436SNathan Gauër   for (const auto &Exit : Exits) {
1667b08b436SNathan Gauër     if (Exit->getName() != "")
1677b08b436SNathan Gauër       dbgs() << Exit->getName() << ", ";
1687b08b436SNathan Gauër     else
1697b08b436SNathan Gauër       dbgs() << Exit << ", ";
1707b08b436SNathan Gauër   }
1717b08b436SNathan Gauër   dbgs() << "	}\n";
1727b08b436SNathan Gauër 
1737b08b436SNathan Gauër   dbgs() << Indent << "	Blocks: { ";
1747b08b436SNathan Gauër   for (const auto &Block : Blocks) {
1757b08b436SNathan Gauër     if (Block->getName() != "")
1767b08b436SNathan Gauër       dbgs() << Block->getName() << ", ";
1777b08b436SNathan Gauër     else
1787b08b436SNathan Gauër       dbgs() << Block << ", ";
1797b08b436SNathan Gauër   }
1807b08b436SNathan Gauër   dbgs() << "	}\n";
1817b08b436SNathan Gauër 
1827b08b436SNathan Gauër   dbgs() << Indent << "	Children: {\n";
1837b08b436SNathan Gauër   for (const auto Child : Children)
1847b08b436SNathan Gauër     Child->dump(IndentSize + 2);
1857b08b436SNathan Gauër   dbgs() << Indent << "	}\n";
1867b08b436SNathan Gauër 
1877b08b436SNathan Gauër   dbgs() << Indent << "}\n";
1887b08b436SNathan Gauër }
1897b08b436SNathan Gauër 
1907b08b436SNathan Gauër class ConvergenceRegionAnalyzer {
1917b08b436SNathan Gauër 
1927b08b436SNathan Gauër public:
1937b08b436SNathan Gauër   ConvergenceRegionAnalyzer(Function &F, DominatorTree &DT, LoopInfo &LI)
1947b08b436SNathan Gauër       : DT(DT), LI(LI), F(F) {}
1957b08b436SNathan Gauër 
1967b08b436SNathan Gauër private:
1977b08b436SNathan Gauër   bool isBackEdge(const BasicBlock *From, const BasicBlock *To) const {
1981ed65febSNathan Gauër     if (From == To)
1991ed65febSNathan Gauër       return true;
2007b08b436SNathan Gauër 
2017b08b436SNathan Gauër     // We only handle loop in the simplified form. This means:
2027b08b436SNathan Gauër     // - a single back-edge, a single latch.
2037b08b436SNathan Gauër     // - meaning the back-edge target can only be the loop header.
2047b08b436SNathan Gauër     // - meaning the From can only be the loop latch.
2057b08b436SNathan Gauër     if (!LI.isLoopHeader(To))
2067b08b436SNathan Gauër       return false;
2077b08b436SNathan Gauër 
2087b08b436SNathan Gauër     auto *L = LI.getLoopFor(To);
2097b08b436SNathan Gauër     if (L->contains(From) && L->isLoopLatch(From))
2107b08b436SNathan Gauër       return true;
2117b08b436SNathan Gauër 
2127b08b436SNathan Gauër     return false;
2137b08b436SNathan Gauër   }
2147b08b436SNathan Gauër 
2157b08b436SNathan Gauër   std::unordered_set<BasicBlock *>
2167b08b436SNathan Gauër   findPathsToMatch(LoopInfo &LI, BasicBlock *From,
2177b08b436SNathan Gauër                    std::function<bool(const BasicBlock *)> isMatch) const {
2187b08b436SNathan Gauër     std::unordered_set<BasicBlock *> Output;
2197b08b436SNathan Gauër 
2207b08b436SNathan Gauër     if (isMatch(From))
2217b08b436SNathan Gauër       Output.insert(From);
2227b08b436SNathan Gauër 
2237b08b436SNathan Gauër     auto *Terminator = From->getTerminator();
2247b08b436SNathan Gauër     for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
2257b08b436SNathan Gauër       auto *To = Terminator->getSuccessor(i);
2261ed65febSNathan Gauër       // Ignore back edges.
2277b08b436SNathan Gauër       if (isBackEdge(From, To))
2287b08b436SNathan Gauër         continue;
2297b08b436SNathan Gauër 
2307b08b436SNathan Gauër       auto ChildSet = findPathsToMatch(LI, To, isMatch);
2317b08b436SNathan Gauër       if (ChildSet.size() == 0)
2327b08b436SNathan Gauër         continue;
2337b08b436SNathan Gauër 
2347b08b436SNathan Gauër       Output.insert(ChildSet.begin(), ChildSet.end());
2357b08b436SNathan Gauër       Output.insert(From);
2367b08b436SNathan Gauër       if (LI.isLoopHeader(From)) {
2377b08b436SNathan Gauër         auto *L = LI.getLoopFor(From);
2387b08b436SNathan Gauër         for (auto *BB : L->getBlocks()) {
2397b08b436SNathan Gauër           Output.insert(BB);
2407b08b436SNathan Gauër         }
2417b08b436SNathan Gauër       }
2427b08b436SNathan Gauër     }
2437b08b436SNathan Gauër 
2447b08b436SNathan Gauër     return Output;
2457b08b436SNathan Gauër   }
2467b08b436SNathan Gauër 
2477b08b436SNathan Gauër   SmallPtrSet<BasicBlock *, 2>
2487b08b436SNathan Gauër   findExitNodes(const SmallPtrSetImpl<BasicBlock *> &RegionBlocks) {
2497b08b436SNathan Gauër     SmallPtrSet<BasicBlock *, 2> Exits;
2507b08b436SNathan Gauër 
2517b08b436SNathan Gauër     for (auto *B : RegionBlocks) {
2527b08b436SNathan Gauër       auto *Terminator = B->getTerminator();
2537b08b436SNathan Gauër       for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
2547b08b436SNathan Gauër         auto *Child = Terminator->getSuccessor(i);
2557b08b436SNathan Gauër         if (RegionBlocks.count(Child) == 0)
2567b08b436SNathan Gauër           Exits.insert(B);
2577b08b436SNathan Gauër       }
2587b08b436SNathan Gauër     }
2597b08b436SNathan Gauër 
2607b08b436SNathan Gauër     return Exits;
2617b08b436SNathan Gauër   }
2627b08b436SNathan Gauër 
2637b08b436SNathan Gauër public:
2647b08b436SNathan Gauër   ConvergenceRegionInfo analyze() {
2657b08b436SNathan Gauër     ConvergenceRegion *TopLevelRegion = new ConvergenceRegion(DT, LI, F);
2667b08b436SNathan Gauër     std::queue<Loop *> ToProcess;
2677b08b436SNathan Gauër     for (auto *L : LI.getLoopsInPreorder())
2687b08b436SNathan Gauër       ToProcess.push(L);
2697b08b436SNathan Gauër 
2707b08b436SNathan Gauër     while (ToProcess.size() != 0) {
2717b08b436SNathan Gauër       auto *L = ToProcess.front();
2727b08b436SNathan Gauër       ToProcess.pop();
2737b08b436SNathan Gauër 
2747b08b436SNathan Gauër       auto CT = getConvergenceToken(L->getHeader());
2757b08b436SNathan Gauër       SmallPtrSet<BasicBlock *, 8> RegionBlocks(L->block_begin(),
2767b08b436SNathan Gauër                                                 L->block_end());
2777b08b436SNathan Gauër       SmallVector<BasicBlock *> LoopExits;
2787b08b436SNathan Gauër       L->getExitingBlocks(LoopExits);
2797b08b436SNathan Gauër       if (CT.has_value()) {
2807b08b436SNathan Gauër         for (auto *Exit : LoopExits) {
2817b08b436SNathan Gauër           auto N = findPathsToMatch(LI, Exit, [&CT](const BasicBlock *block) {
2827b08b436SNathan Gauër             auto Token = getConvergenceToken(block);
2837b08b436SNathan Gauër             if (Token == std::nullopt)
2847b08b436SNathan Gauër               return false;
2857b08b436SNathan Gauër             return Token.value() == CT.value();
2867b08b436SNathan Gauër           });
2877b08b436SNathan Gauër           RegionBlocks.insert(N.begin(), N.end());
2887b08b436SNathan Gauër         }
2897b08b436SNathan Gauër       }
2907b08b436SNathan Gauër 
2917b08b436SNathan Gauër       auto RegionExits = findExitNodes(RegionBlocks);
2927b08b436SNathan Gauër       ConvergenceRegion *Region = new ConvergenceRegion(
2937b08b436SNathan Gauër           DT, LI, CT, L->getHeader(), std::move(RegionBlocks),
2947b08b436SNathan Gauër           std::move(RegionExits));
2957b08b436SNathan Gauër       Region->Parent = findParentRegion(TopLevelRegion, Region->Entry);
2967b08b436SNathan Gauër       assert(Region->Parent != nullptr && "This is impossible.");
2977b08b436SNathan Gauër       Region->Parent->Children.push_back(Region);
2987b08b436SNathan Gauër     }
2997b08b436SNathan Gauër 
3007b08b436SNathan Gauër     return ConvergenceRegionInfo(TopLevelRegion);
3017b08b436SNathan Gauër   }
3027b08b436SNathan Gauër 
3037b08b436SNathan Gauër private:
3047b08b436SNathan Gauër   DominatorTree &DT;
3057b08b436SNathan Gauër   LoopInfo &LI;
3067b08b436SNathan Gauër   Function &F;
3077b08b436SNathan Gauër };
3087b08b436SNathan Gauër 
3097b08b436SNathan Gauër ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT,
3107b08b436SNathan Gauër                                             LoopInfo &LI) {
3117b08b436SNathan Gauër   ConvergenceRegionAnalyzer Analyzer(F, DT, LI);
3127b08b436SNathan Gauër   return Analyzer.analyze();
3137b08b436SNathan Gauër }
3147b08b436SNathan Gauër 
3157b08b436SNathan Gauër } // namespace SPIRV
3167b08b436SNathan Gauër 
3177b08b436SNathan Gauër char SPIRVConvergenceRegionAnalysisWrapperPass::ID = 0;
3187b08b436SNathan Gauër 
3197b08b436SNathan Gauër SPIRVConvergenceRegionAnalysisWrapperPass::
3207b08b436SNathan Gauër     SPIRVConvergenceRegionAnalysisWrapperPass()
3217b08b436SNathan Gauër     : FunctionPass(ID) {}
3227b08b436SNathan Gauër 
3237b08b436SNathan Gauër bool SPIRVConvergenceRegionAnalysisWrapperPass::runOnFunction(Function &F) {
3247b08b436SNathan Gauër   DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
3257b08b436SNathan Gauër   LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
3267b08b436SNathan Gauër 
3277b08b436SNathan Gauër   CRI = SPIRV::getConvergenceRegions(F, DT, LI);
3287b08b436SNathan Gauër   // Nothing was modified.
3297b08b436SNathan Gauër   return false;
3307b08b436SNathan Gauër }
3317b08b436SNathan Gauër 
3327b08b436SNathan Gauër SPIRVConvergenceRegionAnalysis::Result
3337b08b436SNathan Gauër SPIRVConvergenceRegionAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
3347b08b436SNathan Gauër   Result CRI;
3357b08b436SNathan Gauër   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
3367b08b436SNathan Gauër   auto &LI = AM.getResult<LoopAnalysis>(F);
3377b08b436SNathan Gauër   CRI = SPIRV::getConvergenceRegions(F, DT, LI);
3387b08b436SNathan Gauër   return CRI;
3397b08b436SNathan Gauër }
3407b08b436SNathan Gauër 
3417b08b436SNathan Gauër AnalysisKey SPIRVConvergenceRegionAnalysis::Key;
3427b08b436SNathan Gauër 
3437b08b436SNathan Gauër } // namespace llvm
344