xref: /llvm-project/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h (revision 1ed65febd996eaa018164e880c87a9e9afc6f68d)
1 //===- SPIRVConvergenceRegionAnalysis.h ------------------------*- C++ -*--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The analysis determines the convergence region for each basic block of
10 // the module, and provides a tree-like structure describing the region
11 // hierarchy.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
16 #define LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
17 
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/Analysis/CFG.h"
20 #include "llvm/Analysis/LoopInfo.h"
21 #include "llvm/IR/Dominators.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include <iostream>
24 #include <optional>
25 #include <unordered_set>
26 
27 namespace llvm {
28 class SPIRVSubtarget;
29 class MachineFunction;
30 class MachineModuleInfo;
31 
32 namespace SPIRV {
33 
34 // Returns the first convergence intrinsic found in |BB|, |nullopt| otherwise.
35 std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB);
36 std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB);
37 
38 // Describes a hierarchy of convergence regions.
39 // A convergence region defines a CFG for which the execution flow can diverge
40 // starting from the entry block, but should reconverge back before the end of
41 // the exit blocks.
42 class ConvergenceRegion {
43   DominatorTree &DT;
44   LoopInfo &LI;
45 
46 public:
47   // The parent region of this region, if any.
48   ConvergenceRegion *Parent = nullptr;
49   // The sub-regions contained in this region, if any.
50   SmallVector<ConvergenceRegion *> Children = {};
51   // The convergence instruction linked to this region, if any.
52   std::optional<IntrinsicInst *> ConvergenceToken = std::nullopt;
53   // The only block with a predecessor outside of this region.
54   BasicBlock *Entry = nullptr;
55   // All the blocks with an edge leaving this convergence region.
56   SmallPtrSet<BasicBlock *, 2> Exits = {};
57   // All the blocks that belongs to this region, including its subregions'.
58   SmallPtrSet<BasicBlock *, 8> Blocks = {};
59 
60   // Creates a single convergence region encapsulating the whole function |F|.
61   ConvergenceRegion(DominatorTree &DT, LoopInfo &LI, Function &F);
62 
63   // Creates a single convergence region defined by entry and exits nodes, a
64   // list of blocks, and possibly a convergence token.
65   ConvergenceRegion(DominatorTree &DT, LoopInfo &LI,
66                     std::optional<IntrinsicInst *> ConvergenceToken,
67                     BasicBlock *Entry, SmallPtrSet<BasicBlock *, 8> &&Blocks,
68                     SmallPtrSet<BasicBlock *, 2> &&Exits);
69 
70   ConvergenceRegion(ConvergenceRegion &&CR)
71       : DT(CR.DT), LI(CR.LI), Parent(std::move(CR.Parent)),
72         Children(std::move(CR.Children)),
73         ConvergenceToken(std::move(CR.ConvergenceToken)),
74         Entry(std::move(CR.Entry)), Exits(std::move(CR.Exits)),
75         Blocks(std::move(CR.Blocks)) {}
76 
77   ConvergenceRegion(const ConvergenceRegion &other) = delete;
78 
79   // Returns true if the given basic block belongs to this region, or to one of
80   // its subregion.
81   bool contains(const BasicBlock *BB) const { return Blocks.count(BB) != 0; }
82 
83   void releaseMemory();
84 
85   // Write to the debug output this region's hierarchy.
86   // |IndentSize| defines the number of tabs to print before any new line.
87   void dump(const unsigned IndentSize = 0) const;
88 };
89 
90 // Holds a ConvergenceRegion hierarchy.
91 class ConvergenceRegionInfo {
92   // The convergence region this structure holds.
93   ConvergenceRegion *TopLevelRegion;
94 
95 public:
96   ConvergenceRegionInfo() : TopLevelRegion(nullptr) {}
97 
98   // Creates a new ConvergenceRegionInfo. Ownership of the TopLevelRegion is
99   // passed to this object.
100   ConvergenceRegionInfo(ConvergenceRegion *TopLevelRegion)
101       : TopLevelRegion(TopLevelRegion) {}
102 
103   ~ConvergenceRegionInfo() { releaseMemory(); }
104 
105   ConvergenceRegionInfo(ConvergenceRegionInfo &&LHS)
106       : TopLevelRegion(LHS.TopLevelRegion) {
107     if (TopLevelRegion != LHS.TopLevelRegion) {
108       releaseMemory();
109       TopLevelRegion = LHS.TopLevelRegion;
110     }
111     LHS.TopLevelRegion = nullptr;
112   }
113 
114   ConvergenceRegionInfo &operator=(ConvergenceRegionInfo &&LHS) {
115     if (TopLevelRegion != LHS.TopLevelRegion) {
116       releaseMemory();
117       TopLevelRegion = LHS.TopLevelRegion;
118     }
119     LHS.TopLevelRegion = nullptr;
120     return *this;
121   }
122 
123   void releaseMemory() {
124     if (TopLevelRegion == nullptr)
125       return;
126 
127     TopLevelRegion->releaseMemory();
128     delete TopLevelRegion;
129     TopLevelRegion = nullptr;
130   }
131 
132   const ConvergenceRegion *getTopLevelRegion() const { return TopLevelRegion; }
133   ConvergenceRegion *getWritableTopLevelRegion() const {
134     return TopLevelRegion;
135   }
136 };
137 
138 } // namespace SPIRV
139 
140 // Wrapper around the function above to use it with the legacy pass manager.
141 class SPIRVConvergenceRegionAnalysisWrapperPass : public FunctionPass {
142   SPIRV::ConvergenceRegionInfo CRI;
143 
144 public:
145   static char ID;
146 
147   SPIRVConvergenceRegionAnalysisWrapperPass();
148 
149   void getAnalysisUsage(AnalysisUsage &AU) const override {
150     AU.setPreservesAll();
151     AU.addRequired<LoopInfoWrapperPass>();
152     AU.addRequired<DominatorTreeWrapperPass>();
153   };
154 
155   bool runOnFunction(Function &F) override;
156 
157   SPIRV::ConvergenceRegionInfo &getRegionInfo() { return CRI; }
158   const SPIRV::ConvergenceRegionInfo &getRegionInfo() const { return CRI; }
159 };
160 
161 // Wrapper around the function above to use it with the new pass manager.
162 class SPIRVConvergenceRegionAnalysis
163     : public AnalysisInfoMixin<SPIRVConvergenceRegionAnalysis> {
164   friend AnalysisInfoMixin<SPIRVConvergenceRegionAnalysis>;
165   static AnalysisKey Key;
166 
167 public:
168   using Result = SPIRV::ConvergenceRegionInfo;
169 
170   Result run(Function &F, FunctionAnalysisManager &AM);
171 };
172 
173 namespace SPIRV {
174 ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT,
175                                             LoopInfo &LI);
176 } // namespace SPIRV
177 
178 } // namespace llvm
179 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
180