xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1*0fca6ea1SDimitry Andric //===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
2*0fca6ea1SDimitry Andric //
3*0fca6ea1SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*0fca6ea1SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*0fca6ea1SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*0fca6ea1SDimitry Andric //
7*0fca6ea1SDimitry Andric //===----------------------------------------------------------------------===//
8*0fca6ea1SDimitry Andric //
9*0fca6ea1SDimitry Andric // Merge the multiple exit targets of a convergence region into a single block.
10*0fca6ea1SDimitry Andric // Each exit target will be assigned a constant value, and a phi node + switch
11*0fca6ea1SDimitry Andric // will allow the new exit target to re-route to the correct basic block.
12*0fca6ea1SDimitry Andric //
13*0fca6ea1SDimitry Andric //===----------------------------------------------------------------------===//
14*0fca6ea1SDimitry Andric 
15*0fca6ea1SDimitry Andric #include "Analysis/SPIRVConvergenceRegionAnalysis.h"
16*0fca6ea1SDimitry Andric #include "SPIRV.h"
17*0fca6ea1SDimitry Andric #include "SPIRVSubtarget.h"
18*0fca6ea1SDimitry Andric #include "SPIRVTargetMachine.h"
19*0fca6ea1SDimitry Andric #include "SPIRVUtils.h"
20*0fca6ea1SDimitry Andric #include "llvm/ADT/DenseMap.h"
21*0fca6ea1SDimitry Andric #include "llvm/ADT/SmallPtrSet.h"
22*0fca6ea1SDimitry Andric #include "llvm/Analysis/LoopInfo.h"
23*0fca6ea1SDimitry Andric #include "llvm/CodeGen/IntrinsicLowering.h"
24*0fca6ea1SDimitry Andric #include "llvm/IR/CFG.h"
25*0fca6ea1SDimitry Andric #include "llvm/IR/Dominators.h"
26*0fca6ea1SDimitry Andric #include "llvm/IR/IRBuilder.h"
27*0fca6ea1SDimitry Andric #include "llvm/IR/IntrinsicInst.h"
28*0fca6ea1SDimitry Andric #include "llvm/IR/Intrinsics.h"
29*0fca6ea1SDimitry Andric #include "llvm/IR/IntrinsicsSPIRV.h"
30*0fca6ea1SDimitry Andric #include "llvm/InitializePasses.h"
31*0fca6ea1SDimitry Andric #include "llvm/Transforms/Utils/Cloning.h"
32*0fca6ea1SDimitry Andric #include "llvm/Transforms/Utils/LoopSimplify.h"
33*0fca6ea1SDimitry Andric #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
34*0fca6ea1SDimitry Andric 
35*0fca6ea1SDimitry Andric using namespace llvm;
36*0fca6ea1SDimitry Andric 
37*0fca6ea1SDimitry Andric namespace llvm {
38*0fca6ea1SDimitry Andric void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &);
39*0fca6ea1SDimitry Andric 
40*0fca6ea1SDimitry Andric class SPIRVMergeRegionExitTargets : public FunctionPass {
41*0fca6ea1SDimitry Andric public:
42*0fca6ea1SDimitry Andric   static char ID;
43*0fca6ea1SDimitry Andric 
44*0fca6ea1SDimitry Andric   SPIRVMergeRegionExitTargets() : FunctionPass(ID) {
45*0fca6ea1SDimitry Andric     initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry());
46*0fca6ea1SDimitry Andric   };
47*0fca6ea1SDimitry Andric 
48*0fca6ea1SDimitry Andric   // Gather all the successors of |BB|.
49*0fca6ea1SDimitry Andric   // This function asserts if the terminator neither a branch, switch or return.
50*0fca6ea1SDimitry Andric   std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
51*0fca6ea1SDimitry Andric     std::unordered_set<BasicBlock *> output;
52*0fca6ea1SDimitry Andric     auto *T = BB->getTerminator();
53*0fca6ea1SDimitry Andric 
54*0fca6ea1SDimitry Andric     if (auto *BI = dyn_cast<BranchInst>(T)) {
55*0fca6ea1SDimitry Andric       output.insert(BI->getSuccessor(0));
56*0fca6ea1SDimitry Andric       if (BI->isConditional())
57*0fca6ea1SDimitry Andric         output.insert(BI->getSuccessor(1));
58*0fca6ea1SDimitry Andric       return output;
59*0fca6ea1SDimitry Andric     }
60*0fca6ea1SDimitry Andric 
61*0fca6ea1SDimitry Andric     if (auto *SI = dyn_cast<SwitchInst>(T)) {
62*0fca6ea1SDimitry Andric       output.insert(SI->getDefaultDest());
63*0fca6ea1SDimitry Andric       for (auto &Case : SI->cases())
64*0fca6ea1SDimitry Andric         output.insert(Case.getCaseSuccessor());
65*0fca6ea1SDimitry Andric       return output;
66*0fca6ea1SDimitry Andric     }
67*0fca6ea1SDimitry Andric 
68*0fca6ea1SDimitry Andric     assert(isa<ReturnInst>(T) && "Unhandled terminator type.");
69*0fca6ea1SDimitry Andric     return output;
70*0fca6ea1SDimitry Andric   }
71*0fca6ea1SDimitry Andric 
72*0fca6ea1SDimitry Andric   /// Create a value in BB set to the value associated with the branch the block
73*0fca6ea1SDimitry Andric   /// terminator will take.
74*0fca6ea1SDimitry Andric   llvm::Value *createExitVariable(
75*0fca6ea1SDimitry Andric       BasicBlock *BB,
76*0fca6ea1SDimitry Andric       const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
77*0fca6ea1SDimitry Andric     auto *T = BB->getTerminator();
78*0fca6ea1SDimitry Andric     if (isa<ReturnInst>(T))
79*0fca6ea1SDimitry Andric       return nullptr;
80*0fca6ea1SDimitry Andric 
81*0fca6ea1SDimitry Andric     IRBuilder<> Builder(BB);
82*0fca6ea1SDimitry Andric     Builder.SetInsertPoint(T);
83*0fca6ea1SDimitry Andric 
84*0fca6ea1SDimitry Andric     if (auto *BI = dyn_cast<BranchInst>(T)) {
85*0fca6ea1SDimitry Andric 
86*0fca6ea1SDimitry Andric       BasicBlock *LHSTarget = BI->getSuccessor(0);
87*0fca6ea1SDimitry Andric       BasicBlock *RHSTarget =
88*0fca6ea1SDimitry Andric           BI->isConditional() ? BI->getSuccessor(1) : nullptr;
89*0fca6ea1SDimitry Andric 
90*0fca6ea1SDimitry Andric       Value *LHS = TargetToValue.count(LHSTarget) != 0
91*0fca6ea1SDimitry Andric                        ? TargetToValue.at(LHSTarget)
92*0fca6ea1SDimitry Andric                        : nullptr;
93*0fca6ea1SDimitry Andric       Value *RHS = TargetToValue.count(RHSTarget) != 0
94*0fca6ea1SDimitry Andric                        ? TargetToValue.at(RHSTarget)
95*0fca6ea1SDimitry Andric                        : nullptr;
96*0fca6ea1SDimitry Andric 
97*0fca6ea1SDimitry Andric       if (LHS == nullptr || RHS == nullptr)
98*0fca6ea1SDimitry Andric         return LHS == nullptr ? RHS : LHS;
99*0fca6ea1SDimitry Andric       return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
100*0fca6ea1SDimitry Andric     }
101*0fca6ea1SDimitry Andric 
102*0fca6ea1SDimitry Andric     // TODO: add support for switch cases.
103*0fca6ea1SDimitry Andric     llvm_unreachable("Unhandled terminator type.");
104*0fca6ea1SDimitry Andric   }
105*0fca6ea1SDimitry Andric 
106*0fca6ea1SDimitry Andric   /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
107*0fca6ea1SDimitry Andric   void replaceBranchTargets(BasicBlock *BB,
108*0fca6ea1SDimitry Andric                             const SmallPtrSet<BasicBlock *, 4> &ToReplace,
109*0fca6ea1SDimitry Andric                             BasicBlock *NewTarget) {
110*0fca6ea1SDimitry Andric     auto *T = BB->getTerminator();
111*0fca6ea1SDimitry Andric     if (isa<ReturnInst>(T))
112*0fca6ea1SDimitry Andric       return;
113*0fca6ea1SDimitry Andric 
114*0fca6ea1SDimitry Andric     if (auto *BI = dyn_cast<BranchInst>(T)) {
115*0fca6ea1SDimitry Andric       for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
116*0fca6ea1SDimitry Andric         if (ToReplace.count(BI->getSuccessor(i)) != 0)
117*0fca6ea1SDimitry Andric           BI->setSuccessor(i, NewTarget);
118*0fca6ea1SDimitry Andric       }
119*0fca6ea1SDimitry Andric       return;
120*0fca6ea1SDimitry Andric     }
121*0fca6ea1SDimitry Andric 
122*0fca6ea1SDimitry Andric     if (auto *SI = dyn_cast<SwitchInst>(T)) {
123*0fca6ea1SDimitry Andric       for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
124*0fca6ea1SDimitry Andric         if (ToReplace.count(SI->getSuccessor(i)) != 0)
125*0fca6ea1SDimitry Andric           SI->setSuccessor(i, NewTarget);
126*0fca6ea1SDimitry Andric       }
127*0fca6ea1SDimitry Andric       return;
128*0fca6ea1SDimitry Andric     }
129*0fca6ea1SDimitry Andric 
130*0fca6ea1SDimitry Andric     assert(false && "Unhandled terminator type.");
131*0fca6ea1SDimitry Andric   }
132*0fca6ea1SDimitry Andric 
133*0fca6ea1SDimitry Andric   // Run the pass on the given convergence region, ignoring the sub-regions.
134*0fca6ea1SDimitry Andric   // Returns true if the CFG changed, false otherwise.
135*0fca6ea1SDimitry Andric   bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
136*0fca6ea1SDimitry Andric                                        const SPIRV::ConvergenceRegion *CR) {
137*0fca6ea1SDimitry Andric     // Gather all the exit targets for this region.
138*0fca6ea1SDimitry Andric     SmallPtrSet<BasicBlock *, 4> ExitTargets;
139*0fca6ea1SDimitry Andric     for (BasicBlock *Exit : CR->Exits) {
140*0fca6ea1SDimitry Andric       for (BasicBlock *Target : gatherSuccessors(Exit)) {
141*0fca6ea1SDimitry Andric         if (CR->Blocks.count(Target) == 0)
142*0fca6ea1SDimitry Andric           ExitTargets.insert(Target);
143*0fca6ea1SDimitry Andric       }
144*0fca6ea1SDimitry Andric     }
145*0fca6ea1SDimitry Andric 
146*0fca6ea1SDimitry Andric     // If we have zero or one exit target, nothing do to.
147*0fca6ea1SDimitry Andric     if (ExitTargets.size() <= 1)
148*0fca6ea1SDimitry Andric       return false;
149*0fca6ea1SDimitry Andric 
150*0fca6ea1SDimitry Andric     // Create the new single exit target.
151*0fca6ea1SDimitry Andric     auto F = CR->Entry->getParent();
152*0fca6ea1SDimitry Andric     auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
153*0fca6ea1SDimitry Andric     IRBuilder<> Builder(NewExitTarget);
154*0fca6ea1SDimitry Andric 
155*0fca6ea1SDimitry Andric     // CodeGen output needs to be stable. Using the set as-is would order
156*0fca6ea1SDimitry Andric     // the targets differently depending on the allocation pattern.
157*0fca6ea1SDimitry Andric     // Sorting per basic-block ordering in the function.
158*0fca6ea1SDimitry Andric     std::vector<BasicBlock *> SortedExitTargets;
159*0fca6ea1SDimitry Andric     std::vector<BasicBlock *> SortedExits;
160*0fca6ea1SDimitry Andric     for (BasicBlock &BB : *F) {
161*0fca6ea1SDimitry Andric       if (ExitTargets.count(&BB) != 0)
162*0fca6ea1SDimitry Andric         SortedExitTargets.push_back(&BB);
163*0fca6ea1SDimitry Andric       if (CR->Exits.count(&BB) != 0)
164*0fca6ea1SDimitry Andric         SortedExits.push_back(&BB);
165*0fca6ea1SDimitry Andric     }
166*0fca6ea1SDimitry Andric 
167*0fca6ea1SDimitry Andric     // Creating one constant per distinct exit target. This will be route to the
168*0fca6ea1SDimitry Andric     // correct target.
169*0fca6ea1SDimitry Andric     DenseMap<BasicBlock *, ConstantInt *> TargetToValue;
170*0fca6ea1SDimitry Andric     for (BasicBlock *Target : SortedExitTargets)
171*0fca6ea1SDimitry Andric       TargetToValue.insert(
172*0fca6ea1SDimitry Andric           std::make_pair(Target, Builder.getInt32(TargetToValue.size())));
173*0fca6ea1SDimitry Andric 
174*0fca6ea1SDimitry Andric     // Creating one variable per exit node, set to the constant matching the
175*0fca6ea1SDimitry Andric     // targeted external block.
176*0fca6ea1SDimitry Andric     std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
177*0fca6ea1SDimitry Andric     for (auto Exit : SortedExits) {
178*0fca6ea1SDimitry Andric       llvm::Value *Value = createExitVariable(Exit, TargetToValue);
179*0fca6ea1SDimitry Andric       ExitToVariable.emplace_back(std::make_pair(Exit, Value));
180*0fca6ea1SDimitry Andric     }
181*0fca6ea1SDimitry Andric 
182*0fca6ea1SDimitry Andric     // Gather the correct value depending on the exit we came from.
183*0fca6ea1SDimitry Andric     llvm::PHINode *node =
184*0fca6ea1SDimitry Andric         Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());
185*0fca6ea1SDimitry Andric     for (auto [BB, Value] : ExitToVariable) {
186*0fca6ea1SDimitry Andric       node->addIncoming(Value, BB);
187*0fca6ea1SDimitry Andric     }
188*0fca6ea1SDimitry Andric 
189*0fca6ea1SDimitry Andric     // Creating the switch to jump to the correct exit target.
190*0fca6ea1SDimitry Andric     llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0],
191*0fca6ea1SDimitry Andric                                                 SortedExitTargets.size() - 1);
192*0fca6ea1SDimitry Andric     for (size_t i = 1; i < SortedExitTargets.size(); i++) {
193*0fca6ea1SDimitry Andric       BasicBlock *BB = SortedExitTargets[i];
194*0fca6ea1SDimitry Andric       Sw->addCase(TargetToValue[BB], BB);
195*0fca6ea1SDimitry Andric     }
196*0fca6ea1SDimitry Andric 
197*0fca6ea1SDimitry Andric     // Fix exit branches to redirect to the new exit.
198*0fca6ea1SDimitry Andric     for (auto Exit : CR->Exits)
199*0fca6ea1SDimitry Andric       replaceBranchTargets(Exit, ExitTargets, NewExitTarget);
200*0fca6ea1SDimitry Andric 
201*0fca6ea1SDimitry Andric     return true;
202*0fca6ea1SDimitry Andric   }
203*0fca6ea1SDimitry Andric 
204*0fca6ea1SDimitry Andric   /// Run the pass on the given convergence region and sub-regions (DFS).
205*0fca6ea1SDimitry Andric   /// Returns true if a region/sub-region was modified, false otherwise.
206*0fca6ea1SDimitry Andric   /// This returns as soon as one region/sub-region has been modified.
207*0fca6ea1SDimitry Andric   bool runOnConvergenceRegion(LoopInfo &LI,
208*0fca6ea1SDimitry Andric                               const SPIRV::ConvergenceRegion *CR) {
209*0fca6ea1SDimitry Andric     for (auto *Child : CR->Children)
210*0fca6ea1SDimitry Andric       if (runOnConvergenceRegion(LI, Child))
211*0fca6ea1SDimitry Andric         return true;
212*0fca6ea1SDimitry Andric 
213*0fca6ea1SDimitry Andric     return runOnConvergenceRegionNoRecurse(LI, CR);
214*0fca6ea1SDimitry Andric   }
215*0fca6ea1SDimitry Andric 
216*0fca6ea1SDimitry Andric #if !NDEBUG
217*0fca6ea1SDimitry Andric   /// Validates each edge exiting the region has the same destination basic
218*0fca6ea1SDimitry Andric   /// block.
219*0fca6ea1SDimitry Andric   void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
220*0fca6ea1SDimitry Andric     for (auto *Child : CR->Children)
221*0fca6ea1SDimitry Andric       validateRegionExits(Child);
222*0fca6ea1SDimitry Andric 
223*0fca6ea1SDimitry Andric     std::unordered_set<BasicBlock *> ExitTargets;
224*0fca6ea1SDimitry Andric     for (auto *Exit : CR->Exits) {
225*0fca6ea1SDimitry Andric       auto Set = gatherSuccessors(Exit);
226*0fca6ea1SDimitry Andric       for (auto *BB : Set) {
227*0fca6ea1SDimitry Andric         if (CR->Blocks.count(BB) == 0)
228*0fca6ea1SDimitry Andric           ExitTargets.insert(BB);
229*0fca6ea1SDimitry Andric       }
230*0fca6ea1SDimitry Andric     }
231*0fca6ea1SDimitry Andric 
232*0fca6ea1SDimitry Andric     assert(ExitTargets.size() <= 1);
233*0fca6ea1SDimitry Andric   }
234*0fca6ea1SDimitry Andric #endif
235*0fca6ea1SDimitry Andric 
236*0fca6ea1SDimitry Andric   virtual bool runOnFunction(Function &F) override {
237*0fca6ea1SDimitry Andric     LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
238*0fca6ea1SDimitry Andric     const auto *TopLevelRegion =
239*0fca6ea1SDimitry Andric         getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
240*0fca6ea1SDimitry Andric             .getRegionInfo()
241*0fca6ea1SDimitry Andric             .getTopLevelRegion();
242*0fca6ea1SDimitry Andric 
243*0fca6ea1SDimitry Andric     // FIXME: very inefficient method: each time a region is modified, we bubble
244*0fca6ea1SDimitry Andric     // back up, and recompute the whole convergence region tree. Once the
245*0fca6ea1SDimitry Andric     // algorithm is completed and test coverage good enough, rewrite this pass
246*0fca6ea1SDimitry Andric     // to be efficient instead of simple.
247*0fca6ea1SDimitry Andric     bool modified = false;
248*0fca6ea1SDimitry Andric     while (runOnConvergenceRegion(LI, TopLevelRegion)) {
249*0fca6ea1SDimitry Andric       TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
250*0fca6ea1SDimitry Andric                            .getRegionInfo()
251*0fca6ea1SDimitry Andric                            .getTopLevelRegion();
252*0fca6ea1SDimitry Andric       modified = true;
253*0fca6ea1SDimitry Andric     }
254*0fca6ea1SDimitry Andric 
255*0fca6ea1SDimitry Andric #if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
256*0fca6ea1SDimitry Andric     validateRegionExits(TopLevelRegion);
257*0fca6ea1SDimitry Andric #endif
258*0fca6ea1SDimitry Andric     return modified;
259*0fca6ea1SDimitry Andric   }
260*0fca6ea1SDimitry Andric 
261*0fca6ea1SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
262*0fca6ea1SDimitry Andric     AU.addRequired<DominatorTreeWrapperPass>();
263*0fca6ea1SDimitry Andric     AU.addRequired<LoopInfoWrapperPass>();
264*0fca6ea1SDimitry Andric     AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
265*0fca6ea1SDimitry Andric     FunctionPass::getAnalysisUsage(AU);
266*0fca6ea1SDimitry Andric   }
267*0fca6ea1SDimitry Andric };
268*0fca6ea1SDimitry Andric } // namespace llvm
269*0fca6ea1SDimitry Andric 
// Pass identifier. In the legacy pass manager a pass is identified by the
// address of this variable, so the initial value itself is not significant.
char SPIRVMergeRegionExitTargets::ID = 0;

// Register the pass under the command-line name "split-region-exit-blocks",
// along with the passes it depends on. The trailing `false, false` arguments
// mark it as neither CFG-only nor an analysis.
// NOTE(review): the registered name/description ("split region exit blocks")
// does not describe what the pass does (merging exit targets). LoopSimplify is
// registered as a dependency but is not requested in getAnalysisUsage. Both
// are kept as-is; confirm whether anything relies on them.
INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                      "SPIRV split region exit blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)

INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                    "SPIRV split region exit blocks", false, false)
281*0fca6ea1SDimitry Andric 
282*0fca6ea1SDimitry Andric FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
283*0fca6ea1SDimitry Andric   return new SPIRVMergeRegionExitTargets();
284*0fca6ea1SDimitry Andric }
285