xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp (revision cba70550ccf55c6ad3daa621bb8caf3c4ca6cbd7)
1 //===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Merge the multiple exit targets of a convergence region into a single block.
// Each exit target is assigned a constant value. Each exit block stores the
// value of the target it is about to branch to into a stack variable, and a
// load + switch in the new exit target re-routes to the correct basic block.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "Analysis/SPIRVConvergenceRegionAnalysis.h"
16 #include "SPIRV.h"
17 #include "SPIRVSubtarget.h"
18 #include "SPIRVTargetMachine.h"
19 #include "SPIRVUtils.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Analysis/LoopInfo.h"
23 #include "llvm/CodeGen/IntrinsicLowering.h"
24 #include "llvm/IR/CFG.h"
25 #include "llvm/IR/Dominators.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/Intrinsics.h"
29 #include "llvm/IR/IntrinsicsSPIRV.h"
30 #include "llvm/InitializePasses.h"
31 #include "llvm/Transforms/Utils/Cloning.h"
32 #include "llvm/Transforms/Utils/LoopSimplify.h"
33 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
34 
35 using namespace llvm;
36 
37 namespace llvm {
38 void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &);
39 
40 class SPIRVMergeRegionExitTargets : public FunctionPass {
41 public:
42   static char ID;
43 
44   SPIRVMergeRegionExitTargets() : FunctionPass(ID) {
45     initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry());
46   };
47 
48   // Gather all the successors of |BB|.
49   // This function asserts if the terminator neither a branch, switch or return.
50   std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
51     std::unordered_set<BasicBlock *> output;
52     auto *T = BB->getTerminator();
53 
54     if (auto *BI = dyn_cast<BranchInst>(T)) {
55       output.insert(BI->getSuccessor(0));
56       if (BI->isConditional())
57         output.insert(BI->getSuccessor(1));
58       return output;
59     }
60 
61     if (auto *SI = dyn_cast<SwitchInst>(T)) {
62       output.insert(SI->getDefaultDest());
63       for (auto &Case : SI->cases())
64         output.insert(Case.getCaseSuccessor());
65       return output;
66     }
67 
68     assert(isa<ReturnInst>(T) && "Unhandled terminator type.");
69     return output;
70   }
71 
72   /// Create a value in BB set to the value associated with the branch the block
73   /// terminator will take.
74   llvm::Value *createExitVariable(
75       BasicBlock *BB,
76       const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
77     auto *T = BB->getTerminator();
78     if (isa<ReturnInst>(T))
79       return nullptr;
80 
81     IRBuilder<> Builder(BB);
82     Builder.SetInsertPoint(T);
83 
84     if (auto *BI = dyn_cast<BranchInst>(T)) {
85 
86       BasicBlock *LHSTarget = BI->getSuccessor(0);
87       BasicBlock *RHSTarget =
88           BI->isConditional() ? BI->getSuccessor(1) : nullptr;
89 
90       Value *LHS = TargetToValue.count(LHSTarget) != 0
91                        ? TargetToValue.at(LHSTarget)
92                        : nullptr;
93       Value *RHS = TargetToValue.count(RHSTarget) != 0
94                        ? TargetToValue.at(RHSTarget)
95                        : nullptr;
96 
97       if (LHS == nullptr || RHS == nullptr)
98         return LHS == nullptr ? RHS : LHS;
99       return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
100     }
101 
102     // TODO: add support for switch cases.
103     llvm_unreachable("Unhandled terminator type.");
104   }
105 
106   /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
107   void replaceBranchTargets(BasicBlock *BB,
108                             const SmallPtrSet<BasicBlock *, 4> &ToReplace,
109                             BasicBlock *NewTarget) {
110     auto *T = BB->getTerminator();
111     if (isa<ReturnInst>(T))
112       return;
113 
114     if (auto *BI = dyn_cast<BranchInst>(T)) {
115       for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
116         if (ToReplace.count(BI->getSuccessor(i)) != 0)
117           BI->setSuccessor(i, NewTarget);
118       }
119       return;
120     }
121 
122     if (auto *SI = dyn_cast<SwitchInst>(T)) {
123       for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
124         if (ToReplace.count(SI->getSuccessor(i)) != 0)
125           SI->setSuccessor(i, NewTarget);
126       }
127       return;
128     }
129 
130     assert(false && "Unhandled terminator type.");
131   }
132 
133   AllocaInst *CreateVariable(Function &F, Type *Type,
134                              BasicBlock::iterator Position) {
135     const DataLayout &DL = F.getDataLayout();
136     return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
137                           Position);
138   }
139 
140   // Run the pass on the given convergence region, ignoring the sub-regions.
141   // Returns true if the CFG changed, false otherwise.
142   bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
143                                        SPIRV::ConvergenceRegion *CR) {
144     // Gather all the exit targets for this region.
145     SmallPtrSet<BasicBlock *, 4> ExitTargets;
146     for (BasicBlock *Exit : CR->Exits) {
147       for (BasicBlock *Target : gatherSuccessors(Exit)) {
148         if (CR->Blocks.count(Target) == 0)
149           ExitTargets.insert(Target);
150       }
151     }
152 
153     // If we have zero or one exit target, nothing do to.
154     if (ExitTargets.size() <= 1)
155       return false;
156 
157     // Create the new single exit target.
158     auto F = CR->Entry->getParent();
159     auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
160     IRBuilder<> Builder(NewExitTarget);
161 
162     AllocaInst *Variable = CreateVariable(*F, Builder.getInt32Ty(),
163                                           F->begin()->getFirstInsertionPt());
164 
165     // CodeGen output needs to be stable. Using the set as-is would order
166     // the targets differently depending on the allocation pattern.
167     // Sorting per basic-block ordering in the function.
168     std::vector<BasicBlock *> SortedExitTargets;
169     std::vector<BasicBlock *> SortedExits;
170     for (BasicBlock &BB : *F) {
171       if (ExitTargets.count(&BB) != 0)
172         SortedExitTargets.push_back(&BB);
173       if (CR->Exits.count(&BB) != 0)
174         SortedExits.push_back(&BB);
175     }
176 
177     // Creating one constant per distinct exit target. This will be route to the
178     // correct target.
179     DenseMap<BasicBlock *, ConstantInt *> TargetToValue;
180     for (BasicBlock *Target : SortedExitTargets)
181       TargetToValue.insert(
182           std::make_pair(Target, Builder.getInt32(TargetToValue.size())));
183 
184     // Creating one variable per exit node, set to the constant matching the
185     // targeted external block.
186     std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
187     for (auto Exit : SortedExits) {
188       llvm::Value *Value = createExitVariable(Exit, TargetToValue);
189       IRBuilder<> B2(Exit);
190       B2.SetInsertPoint(Exit->getFirstInsertionPt());
191       B2.CreateStore(Value, Variable);
192       ExitToVariable.emplace_back(std::make_pair(Exit, Value));
193     }
194 
195     llvm::Value *Load = Builder.CreateLoad(Builder.getInt32Ty(), Variable);
196 
197     // Creating the switch to jump to the correct exit target.
198     llvm::SwitchInst *Sw = Builder.CreateSwitch(Load, SortedExitTargets[0],
199                                                 SortedExitTargets.size() - 1);
200     for (size_t i = 1; i < SortedExitTargets.size(); i++) {
201       BasicBlock *BB = SortedExitTargets[i];
202       Sw->addCase(TargetToValue[BB], BB);
203     }
204 
205     // Fix exit branches to redirect to the new exit.
206     for (auto Exit : CR->Exits)
207       replaceBranchTargets(Exit, ExitTargets, NewExitTarget);
208 
209     CR = CR->Parent;
210     while (CR) {
211       CR->Blocks.insert(NewExitTarget);
212       CR = CR->Parent;
213     }
214 
215     return true;
216   }
217 
218   /// Run the pass on the given convergence region and sub-regions (DFS).
219   /// Returns true if a region/sub-region was modified, false otherwise.
220   /// This returns as soon as one region/sub-region has been modified.
221   bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR) {
222     for (auto *Child : CR->Children)
223       if (runOnConvergenceRegion(LI, Child))
224         return true;
225 
226     return runOnConvergenceRegionNoRecurse(LI, CR);
227   }
228 
229 #if !NDEBUG
230   /// Validates each edge exiting the region has the same destination basic
231   /// block.
232   void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
233     for (auto *Child : CR->Children)
234       validateRegionExits(Child);
235 
236     std::unordered_set<BasicBlock *> ExitTargets;
237     for (auto *Exit : CR->Exits) {
238       auto Set = gatherSuccessors(Exit);
239       for (auto *BB : Set) {
240         if (CR->Blocks.count(BB) == 0)
241           ExitTargets.insert(BB);
242       }
243     }
244 
245     assert(ExitTargets.size() <= 1);
246   }
247 #endif
248 
249   virtual bool runOnFunction(Function &F) override {
250     LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
251     auto *TopLevelRegion =
252         getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
253             .getRegionInfo()
254             .getWritableTopLevelRegion();
255 
256     // FIXME: very inefficient method: each time a region is modified, we bubble
257     // back up, and recompute the whole convergence region tree. Once the
258     // algorithm is completed and test coverage good enough, rewrite this pass
259     // to be efficient instead of simple.
260     bool modified = false;
261     while (runOnConvergenceRegion(LI, TopLevelRegion)) {
262       modified = true;
263     }
264 
265 #if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
266     validateRegionExits(TopLevelRegion);
267 #endif
268     return modified;
269   }
270 
271   void getAnalysisUsage(AnalysisUsage &AU) const override {
272     AU.addRequired<DominatorTreeWrapperPass>();
273     AU.addRequired<LoopInfoWrapperPass>();
274     AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
275 
276     AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
277     FunctionPass::getAnalysisUsage(AU);
278   }
279 };
280 } // namespace llvm
281 
// Definition of the static pass identifier handed to the FunctionPass
// constructor by SPIRVMergeRegionExitTargets.
char SPIRVMergeRegionExitTargets::ID = 0;

// Pass registration, along with the passes listed as its dependencies.
// NOTE(review): the registered strings say "split region exit blocks" while
// the class and the file are named "merge region exit targets" — confirm
// whether this naming is intentional before relying on it.
INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                      "SPIRV split region exit blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)

INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                    "SPIRV split region exit blocks", false, false)
293 
/// Factory for the pass: returns a newly heap-allocated
/// SPIRVMergeRegionExitTargets instance.
/// NOTE(review): the result is allocated with a raw `new`; presumably the
/// caller (e.g. a pass manager) takes ownership — confirm against the callers.
FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
  return new SPIRVMergeRegionExitTargets();
}
297