xref: /llvm-project/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp (revision 77e6f434ec79db025aa9c7d193179727f1d63714)
//===- SPIRVConvergenceRegionAnalysis.cpp -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// The analysis determines the convergence region for each basic block of
// a function, and provides a tree-like structure describing the region
11 // hierarchy.
12 //
13 //===----------------------------------------------------------------------===//
14 
#include "SPIRVConvergenceRegionAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/InitializePasses.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include <functional>
#include <optional>
#include <queue>
#include <unordered_map>
#include <unordered_set>
22 #include <queue>
23 
24 #define DEBUG_TYPE "spirv-convergence-region-analysis"
25 
26 using namespace llvm;
27 
28 namespace llvm {
29 void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
30 } // namespace llvm
31 
32 INITIALIZE_PASS_BEGIN(SPIRVConvergenceRegionAnalysisWrapperPass,
33                       "convergence-region",
34                       "SPIRV convergence regions analysis", true, true)
35 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
36 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
37 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
38 INITIALIZE_PASS_END(SPIRVConvergenceRegionAnalysisWrapperPass,
39                     "convergence-region", "SPIRV convergence regions analysis",
40                     true, true)
41 
42 namespace llvm {
43 namespace SPIRV {
44 namespace {
45 
46 template <typename BasicBlockType, typename IntrinsicInstType>
47 std::optional<IntrinsicInstType *>
48 getConvergenceTokenInternal(BasicBlockType *BB) {
49   static_assert(std::is_const_v<IntrinsicInstType> ==
50                     std::is_const_v<BasicBlockType>,
51                 "Constness must match between input and output.");
52   static_assert(std::is_same_v<BasicBlock, std::remove_const_t<BasicBlockType>>,
53                 "Input must be a basic block.");
54   static_assert(
55       std::is_same_v<IntrinsicInst, std::remove_const_t<IntrinsicInstType>>,
56       "Output type must be an intrinsic instruction.");
57 
58   for (auto &I : *BB) {
59     if (auto *CI = dyn_cast<ConvergenceControlInst>(&I)) {
60       // Make sure that the anchor or entry intrinsics did not reach here with a
61       // parent token. This should have failed the verifier.
62       assert(CI->isLoop() ||
63              !CI->getOperandBundle(LLVMContext::OB_convergencectrl));
64       return CI;
65     }
66 
67     if (auto *CI = dyn_cast<CallInst>(&I)) {
68       auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl);
69       if (!OB.has_value())
70         continue;
71       return dyn_cast<IntrinsicInst>(OB.value().Inputs[0]);
72     }
73   }
74 
75   return std::nullopt;
76 }
77 
78 // Given a ConvergenceRegion tree with |Start| as its root, finds the smallest
79 // region |Entry| belongs to. If |Entry| does not belong to the region defined
80 // by |Start|, this function returns |nullptr|.
81 ConvergenceRegion *findParentRegion(ConvergenceRegion *Start,
82                                     BasicBlock *Entry) {
83   ConvergenceRegion *Candidate = nullptr;
84   ConvergenceRegion *NextCandidate = Start;
85 
86   while (Candidate != NextCandidate && NextCandidate != nullptr) {
87     Candidate = NextCandidate;
88     NextCandidate = nullptr;
89 
90     // End of the search, we can return.
91     if (Candidate->Children.size() == 0)
92       return Candidate;
93 
94     for (auto *Child : Candidate->Children) {
95       if (Child->Blocks.count(Entry) != 0) {
96         NextCandidate = Child;
97         break;
98       }
99     }
100   }
101 
102   return Candidate;
103 }
104 
105 } // anonymous namespace
106 
107 std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB) {
108   return getConvergenceTokenInternal<BasicBlock, IntrinsicInst>(BB);
109 }
110 
111 std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB) {
112   return getConvergenceTokenInternal<const BasicBlock, const IntrinsicInst>(BB);
113 }
114 
115 ConvergenceRegion::ConvergenceRegion(DominatorTree &DT, LoopInfo &LI,
116                                      Function &F)
117     : DT(DT), LI(LI), Parent(nullptr) {
118   Entry = &F.getEntryBlock();
119   ConvergenceToken = getConvergenceToken(Entry);
120   for (auto &B : F) {
121     Blocks.insert(&B);
122     if (isa<ReturnInst>(B.getTerminator()))
123       Exits.insert(&B);
124   }
125 }
126 
127 ConvergenceRegion::ConvergenceRegion(
128     DominatorTree &DT, LoopInfo &LI,
129     std::optional<IntrinsicInst *> ConvergenceToken, BasicBlock *Entry,
130     SmallPtrSet<BasicBlock *, 8> &&Blocks, SmallPtrSet<BasicBlock *, 2> &&Exits)
131     : DT(DT), LI(LI), ConvergenceToken(ConvergenceToken), Entry(Entry),
132       Exits(std::move(Exits)), Blocks(std::move(Blocks)) {
133   for ([[maybe_unused]] auto *BB : this->Exits)
134     assert(this->Blocks.count(BB) != 0);
135   assert(this->Blocks.count(this->Entry) != 0);
136 }
137 
138 void ConvergenceRegion::releaseMemory() {
139   // Parent memory is owned by the parent.
140   Parent = nullptr;
141   for (auto *Child : Children) {
142     Child->releaseMemory();
143     delete Child;
144   }
145   Children.resize(0);
146 }
147 
148 void ConvergenceRegion::dump(const unsigned IndentSize) const {
149   const std::string Indent(IndentSize, '\t');
150   dbgs() << Indent << this << ": {\n";
151   dbgs() << Indent << "	Parent: " << Parent << "\n";
152 
153   if (ConvergenceToken.value_or(nullptr)) {
154     dbgs() << Indent
155            << "	ConvergenceToken: " << ConvergenceToken.value()->getName()
156            << "\n";
157   }
158 
159   if (Entry->getName() != "")
160     dbgs() << Indent << "	Entry: " << Entry->getName() << "\n";
161   else
162     dbgs() << Indent << "	Entry: " << Entry << "\n";
163 
164   dbgs() << Indent << "	Exits: { ";
165   for (const auto &Exit : Exits) {
166     if (Exit->getName() != "")
167       dbgs() << Exit->getName() << ", ";
168     else
169       dbgs() << Exit << ", ";
170   }
171   dbgs() << "	}\n";
172 
173   dbgs() << Indent << "	Blocks: { ";
174   for (const auto &Block : Blocks) {
175     if (Block->getName() != "")
176       dbgs() << Block->getName() << ", ";
177     else
178       dbgs() << Block << ", ";
179   }
180   dbgs() << "	}\n";
181 
182   dbgs() << Indent << "	Children: {\n";
183   for (const auto Child : Children)
184     Child->dump(IndentSize + 2);
185   dbgs() << Indent << "	}\n";
186 
187   dbgs() << Indent << "}\n";
188 }
189 
190 class ConvergenceRegionAnalyzer {
191 
192 public:
193   ConvergenceRegionAnalyzer(Function &F, DominatorTree &DT, LoopInfo &LI)
194       : DT(DT), LI(LI), F(F) {}
195 
196 private:
197   bool isBackEdge(const BasicBlock *From, const BasicBlock *To) const {
198     if (From == To)
199       return true;
200 
201     // We only handle loop in the simplified form. This means:
202     // - a single back-edge, a single latch.
203     // - meaning the back-edge target can only be the loop header.
204     // - meaning the From can only be the loop latch.
205     if (!LI.isLoopHeader(To))
206       return false;
207 
208     auto *L = LI.getLoopFor(To);
209     if (L->contains(From) && L->isLoopLatch(From))
210       return true;
211 
212     return false;
213   }
214 
215   std::unordered_set<BasicBlock *>
216   findPathsToMatch(LoopInfo &LI, BasicBlock *From,
217                    std::function<bool(const BasicBlock *)> isMatch) const {
218     std::unordered_set<BasicBlock *> Output;
219 
220     if (isMatch(From))
221       Output.insert(From);
222 
223     auto *Terminator = From->getTerminator();
224     for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
225       auto *To = Terminator->getSuccessor(i);
226       // Ignore back edges.
227       if (isBackEdge(From, To))
228         continue;
229 
230       auto ChildSet = findPathsToMatch(LI, To, isMatch);
231       if (ChildSet.size() == 0)
232         continue;
233 
234       Output.insert(ChildSet.begin(), ChildSet.end());
235       Output.insert(From);
236       if (LI.isLoopHeader(From)) {
237         auto *L = LI.getLoopFor(From);
238         for (auto *BB : L->getBlocks()) {
239           Output.insert(BB);
240         }
241       }
242     }
243 
244     return Output;
245   }
246 
247   SmallPtrSet<BasicBlock *, 2>
248   findExitNodes(const SmallPtrSetImpl<BasicBlock *> &RegionBlocks) {
249     SmallPtrSet<BasicBlock *, 2> Exits;
250 
251     for (auto *B : RegionBlocks) {
252       auto *Terminator = B->getTerminator();
253       for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
254         auto *Child = Terminator->getSuccessor(i);
255         if (RegionBlocks.count(Child) == 0)
256           Exits.insert(B);
257       }
258     }
259 
260     return Exits;
261   }
262 
263 public:
264   ConvergenceRegionInfo analyze() {
265     ConvergenceRegion *TopLevelRegion = new ConvergenceRegion(DT, LI, F);
266     std::queue<Loop *> ToProcess;
267     for (auto *L : LI.getLoopsInPreorder())
268       ToProcess.push(L);
269 
270     while (ToProcess.size() != 0) {
271       auto *L = ToProcess.front();
272       ToProcess.pop();
273 
274       auto CT = getConvergenceToken(L->getHeader());
275       SmallPtrSet<BasicBlock *, 8> RegionBlocks(L->block_begin(),
276                                                 L->block_end());
277       SmallVector<BasicBlock *> LoopExits;
278       L->getExitingBlocks(LoopExits);
279       if (CT.has_value()) {
280         for (auto *Exit : LoopExits) {
281           auto N = findPathsToMatch(LI, Exit, [&CT](const BasicBlock *block) {
282             auto Token = getConvergenceToken(block);
283             if (Token == std::nullopt)
284               return false;
285             return Token.value() == CT.value();
286           });
287           RegionBlocks.insert(N.begin(), N.end());
288         }
289       }
290 
291       auto RegionExits = findExitNodes(RegionBlocks);
292       ConvergenceRegion *Region = new ConvergenceRegion(
293           DT, LI, CT, L->getHeader(), std::move(RegionBlocks),
294           std::move(RegionExits));
295       Region->Parent = findParentRegion(TopLevelRegion, Region->Entry);
296       assert(Region->Parent != nullptr && "This is impossible.");
297       Region->Parent->Children.push_back(Region);
298     }
299 
300     return ConvergenceRegionInfo(TopLevelRegion);
301   }
302 
303 private:
304   DominatorTree &DT;
305   LoopInfo &LI;
306   Function &F;
307 };
308 
309 ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT,
310                                             LoopInfo &LI) {
311   ConvergenceRegionAnalyzer Analyzer(F, DT, LI);
312   return Analyzer.analyze();
313 }
314 
315 } // namespace SPIRV
316 
317 char SPIRVConvergenceRegionAnalysisWrapperPass::ID = 0;
318 
319 SPIRVConvergenceRegionAnalysisWrapperPass::
320     SPIRVConvergenceRegionAnalysisWrapperPass()
321     : FunctionPass(ID) {}
322 
323 bool SPIRVConvergenceRegionAnalysisWrapperPass::runOnFunction(Function &F) {
324   DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
325   LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
326 
327   CRI = SPIRV::getConvergenceRegions(F, DT, LI);
328   // Nothing was modified.
329   return false;
330 }
331 
332 SPIRVConvergenceRegionAnalysis::Result
333 SPIRVConvergenceRegionAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
334   Result CRI;
335   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
336   auto &LI = AM.getResult<LoopAnalysis>(F);
337   CRI = SPIRV::getConvergenceRegions(F, DT, LI);
338   return CRI;
339 }
340 
341 AnalysisKey SPIRVConvergenceRegionAnalysis::Key;
342 
343 } // namespace llvm
344