1*0fca6ea1SDimitry Andric //===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===// 2*0fca6ea1SDimitry Andric // 3*0fca6ea1SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*0fca6ea1SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5*0fca6ea1SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*0fca6ea1SDimitry Andric // 7*0fca6ea1SDimitry Andric //===----------------------------------------------------------------------===// 8*0fca6ea1SDimitry Andric // 9*0fca6ea1SDimitry Andric // Merge the multiple exit targets of a convergence region into a single block. 10*0fca6ea1SDimitry Andric // Each exit target will be assigned a constant value, and a phi node + switch 11*0fca6ea1SDimitry Andric // will allow the new exit target to re-route to the correct basic block. 12*0fca6ea1SDimitry Andric // 13*0fca6ea1SDimitry Andric //===----------------------------------------------------------------------===// 14*0fca6ea1SDimitry Andric 15*0fca6ea1SDimitry Andric #include "Analysis/SPIRVConvergenceRegionAnalysis.h" 16*0fca6ea1SDimitry Andric #include "SPIRV.h" 17*0fca6ea1SDimitry Andric #include "SPIRVSubtarget.h" 18*0fca6ea1SDimitry Andric #include "SPIRVTargetMachine.h" 19*0fca6ea1SDimitry Andric #include "SPIRVUtils.h" 20*0fca6ea1SDimitry Andric #include "llvm/ADT/DenseMap.h" 21*0fca6ea1SDimitry Andric #include "llvm/ADT/SmallPtrSet.h" 22*0fca6ea1SDimitry Andric #include "llvm/Analysis/LoopInfo.h" 23*0fca6ea1SDimitry Andric #include "llvm/CodeGen/IntrinsicLowering.h" 24*0fca6ea1SDimitry Andric #include "llvm/IR/CFG.h" 25*0fca6ea1SDimitry Andric #include "llvm/IR/Dominators.h" 26*0fca6ea1SDimitry Andric #include "llvm/IR/IRBuilder.h" 27*0fca6ea1SDimitry Andric #include "llvm/IR/IntrinsicInst.h" 28*0fca6ea1SDimitry Andric #include "llvm/IR/Intrinsics.h" 29*0fca6ea1SDimitry Andric #include "llvm/IR/IntrinsicsSPIRV.h" 30*0fca6ea1SDimitry Andric #include "llvm/InitializePasses.h" 31*0fca6ea1SDimitry Andric #include "llvm/Transforms/Utils/Cloning.h" 32*0fca6ea1SDimitry Andric #include "llvm/Transforms/Utils/LoopSimplify.h" 33*0fca6ea1SDimitry Andric #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" 34*0fca6ea1SDimitry Andric 35*0fca6ea1SDimitry Andric using namespace llvm; 36*0fca6ea1SDimitry Andric 37*0fca6ea1SDimitry Andric namespace llvm { 38*0fca6ea1SDimitry Andric void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &); 39*0fca6ea1SDimitry Andric 40*0fca6ea1SDimitry Andric class SPIRVMergeRegionExitTargets : public FunctionPass { 41*0fca6ea1SDimitry Andric public: 42*0fca6ea1SDimitry Andric static char ID; 43*0fca6ea1SDimitry Andric 44*0fca6ea1SDimitry Andric SPIRVMergeRegionExitTargets() : FunctionPass(ID) { 45*0fca6ea1SDimitry Andric initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry()); 46*0fca6ea1SDimitry Andric }; 47*0fca6ea1SDimitry Andric 48*0fca6ea1SDimitry Andric // Gather all the successors of |BB|. 49*0fca6ea1SDimitry Andric // This function asserts if the terminator neither a branch, switch or return. 50*0fca6ea1SDimitry Andric std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) { 51*0fca6ea1SDimitry Andric std::unordered_set<BasicBlock *> output; 52*0fca6ea1SDimitry Andric auto *T = BB->getTerminator(); 53*0fca6ea1SDimitry Andric 54*0fca6ea1SDimitry Andric if (auto *BI = dyn_cast<BranchInst>(T)) { 55*0fca6ea1SDimitry Andric output.insert(BI->getSuccessor(0)); 56*0fca6ea1SDimitry Andric if (BI->isConditional()) 57*0fca6ea1SDimitry Andric output.insert(BI->getSuccessor(1)); 58*0fca6ea1SDimitry Andric return output; 59*0fca6ea1SDimitry Andric } 60*0fca6ea1SDimitry Andric 61*0fca6ea1SDimitry Andric if (auto *SI = dyn_cast<SwitchInst>(T)) { 62*0fca6ea1SDimitry Andric output.insert(SI->getDefaultDest()); 63*0fca6ea1SDimitry Andric for (auto &Case : SI->cases()) 64*0fca6ea1SDimitry Andric output.insert(Case.getCaseSuccessor()); 65*0fca6ea1SDimitry Andric return output; 66*0fca6ea1SDimitry Andric } 67*0fca6ea1SDimitry Andric 68*0fca6ea1SDimitry Andric assert(isa<ReturnInst>(T) && "Unhandled terminator type."); 69*0fca6ea1SDimitry Andric return output; 70*0fca6ea1SDimitry Andric } 71*0fca6ea1SDimitry Andric 72*0fca6ea1SDimitry Andric /// Create a value in BB set to the value associated with the branch the block 73*0fca6ea1SDimitry Andric /// terminator will take. 74*0fca6ea1SDimitry Andric llvm::Value *createExitVariable( 75*0fca6ea1SDimitry Andric BasicBlock *BB, 76*0fca6ea1SDimitry Andric const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) { 77*0fca6ea1SDimitry Andric auto *T = BB->getTerminator(); 78*0fca6ea1SDimitry Andric if (isa<ReturnInst>(T)) 79*0fca6ea1SDimitry Andric return nullptr; 80*0fca6ea1SDimitry Andric 81*0fca6ea1SDimitry Andric IRBuilder<> Builder(BB); 82*0fca6ea1SDimitry Andric Builder.SetInsertPoint(T); 83*0fca6ea1SDimitry Andric 84*0fca6ea1SDimitry Andric if (auto *BI = dyn_cast<BranchInst>(T)) { 85*0fca6ea1SDimitry Andric 86*0fca6ea1SDimitry Andric BasicBlock *LHSTarget = BI->getSuccessor(0); 87*0fca6ea1SDimitry Andric BasicBlock *RHSTarget = 88*0fca6ea1SDimitry Andric BI->isConditional() ? BI->getSuccessor(1) : nullptr; 89*0fca6ea1SDimitry Andric 90*0fca6ea1SDimitry Andric Value *LHS = TargetToValue.count(LHSTarget) != 0 91*0fca6ea1SDimitry Andric ? TargetToValue.at(LHSTarget) 92*0fca6ea1SDimitry Andric : nullptr; 93*0fca6ea1SDimitry Andric Value *RHS = TargetToValue.count(RHSTarget) != 0 94*0fca6ea1SDimitry Andric ? TargetToValue.at(RHSTarget) 95*0fca6ea1SDimitry Andric : nullptr; 96*0fca6ea1SDimitry Andric 97*0fca6ea1SDimitry Andric if (LHS == nullptr || RHS == nullptr) 98*0fca6ea1SDimitry Andric return LHS == nullptr ? RHS : LHS; 99*0fca6ea1SDimitry Andric return Builder.CreateSelect(BI->getCondition(), LHS, RHS); 100*0fca6ea1SDimitry Andric } 101*0fca6ea1SDimitry Andric 102*0fca6ea1SDimitry Andric // TODO: add support for switch cases. 103*0fca6ea1SDimitry Andric llvm_unreachable("Unhandled terminator type."); 104*0fca6ea1SDimitry Andric } 105*0fca6ea1SDimitry Andric 106*0fca6ea1SDimitry Andric /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|. 107*0fca6ea1SDimitry Andric void replaceBranchTargets(BasicBlock *BB, 108*0fca6ea1SDimitry Andric const SmallPtrSet<BasicBlock *, 4> &ToReplace, 109*0fca6ea1SDimitry Andric BasicBlock *NewTarget) { 110*0fca6ea1SDimitry Andric auto *T = BB->getTerminator(); 111*0fca6ea1SDimitry Andric if (isa<ReturnInst>(T)) 112*0fca6ea1SDimitry Andric return; 113*0fca6ea1SDimitry Andric 114*0fca6ea1SDimitry Andric if (auto *BI = dyn_cast<BranchInst>(T)) { 115*0fca6ea1SDimitry Andric for (size_t i = 0; i < BI->getNumSuccessors(); i++) { 116*0fca6ea1SDimitry Andric if (ToReplace.count(BI->getSuccessor(i)) != 0) 117*0fca6ea1SDimitry Andric BI->setSuccessor(i, NewTarget); 118*0fca6ea1SDimitry Andric } 119*0fca6ea1SDimitry Andric return; 120*0fca6ea1SDimitry Andric } 121*0fca6ea1SDimitry Andric 122*0fca6ea1SDimitry Andric if (auto *SI = dyn_cast<SwitchInst>(T)) { 123*0fca6ea1SDimitry Andric for (size_t i = 0; i < SI->getNumSuccessors(); i++) { 124*0fca6ea1SDimitry Andric if (ToReplace.count(SI->getSuccessor(i)) != 0) 125*0fca6ea1SDimitry Andric SI->setSuccessor(i, NewTarget); 126*0fca6ea1SDimitry Andric } 127*0fca6ea1SDimitry Andric return; 128*0fca6ea1SDimitry Andric } 129*0fca6ea1SDimitry Andric 130*0fca6ea1SDimitry Andric assert(false && "Unhandled terminator type."); 131*0fca6ea1SDimitry Andric } 132*0fca6ea1SDimitry Andric 133*0fca6ea1SDimitry Andric // Run the pass on the given convergence region, ignoring the sub-regions. 134*0fca6ea1SDimitry Andric // Returns true if the CFG changed, false otherwise. 135*0fca6ea1SDimitry Andric bool runOnConvergenceRegionNoRecurse(LoopInfo &LI, 136*0fca6ea1SDimitry Andric const SPIRV::ConvergenceRegion *CR) { 137*0fca6ea1SDimitry Andric // Gather all the exit targets for this region. 138*0fca6ea1SDimitry Andric SmallPtrSet<BasicBlock *, 4> ExitTargets; 139*0fca6ea1SDimitry Andric for (BasicBlock *Exit : CR->Exits) { 140*0fca6ea1SDimitry Andric for (BasicBlock *Target : gatherSuccessors(Exit)) { 141*0fca6ea1SDimitry Andric if (CR->Blocks.count(Target) == 0) 142*0fca6ea1SDimitry Andric ExitTargets.insert(Target); 143*0fca6ea1SDimitry Andric } 144*0fca6ea1SDimitry Andric } 145*0fca6ea1SDimitry Andric 146*0fca6ea1SDimitry Andric // If we have zero or one exit target, nothing do to. 147*0fca6ea1SDimitry Andric if (ExitTargets.size() <= 1) 148*0fca6ea1SDimitry Andric return false; 149*0fca6ea1SDimitry Andric 150*0fca6ea1SDimitry Andric // Create the new single exit target. 151*0fca6ea1SDimitry Andric auto F = CR->Entry->getParent(); 152*0fca6ea1SDimitry Andric auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F); 153*0fca6ea1SDimitry Andric IRBuilder<> Builder(NewExitTarget); 154*0fca6ea1SDimitry Andric 155*0fca6ea1SDimitry Andric // CodeGen output needs to be stable. Using the set as-is would order 156*0fca6ea1SDimitry Andric // the targets differently depending on the allocation pattern. 157*0fca6ea1SDimitry Andric // Sorting per basic-block ordering in the function. 158*0fca6ea1SDimitry Andric std::vector<BasicBlock *> SortedExitTargets; 159*0fca6ea1SDimitry Andric std::vector<BasicBlock *> SortedExits; 160*0fca6ea1SDimitry Andric for (BasicBlock &BB : *F) { 161*0fca6ea1SDimitry Andric if (ExitTargets.count(&BB) != 0) 162*0fca6ea1SDimitry Andric SortedExitTargets.push_back(&BB); 163*0fca6ea1SDimitry Andric if (CR->Exits.count(&BB) != 0) 164*0fca6ea1SDimitry Andric SortedExits.push_back(&BB); 165*0fca6ea1SDimitry Andric } 166*0fca6ea1SDimitry Andric 167*0fca6ea1SDimitry Andric // Creating one constant per distinct exit target. This will be route to the 168*0fca6ea1SDimitry Andric // correct target. 169*0fca6ea1SDimitry Andric DenseMap<BasicBlock *, ConstantInt *> TargetToValue; 170*0fca6ea1SDimitry Andric for (BasicBlock *Target : SortedExitTargets) 171*0fca6ea1SDimitry Andric TargetToValue.insert( 172*0fca6ea1SDimitry Andric std::make_pair(Target, Builder.getInt32(TargetToValue.size()))); 173*0fca6ea1SDimitry Andric 174*0fca6ea1SDimitry Andric // Creating one variable per exit node, set to the constant matching the 175*0fca6ea1SDimitry Andric // targeted external block. 176*0fca6ea1SDimitry Andric std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable; 177*0fca6ea1SDimitry Andric for (auto Exit : SortedExits) { 178*0fca6ea1SDimitry Andric llvm::Value *Value = createExitVariable(Exit, TargetToValue); 179*0fca6ea1SDimitry Andric ExitToVariable.emplace_back(std::make_pair(Exit, Value)); 180*0fca6ea1SDimitry Andric } 181*0fca6ea1SDimitry Andric 182*0fca6ea1SDimitry Andric // Gather the correct value depending on the exit we came from. 183*0fca6ea1SDimitry Andric llvm::PHINode *node = 184*0fca6ea1SDimitry Andric Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size()); 185*0fca6ea1SDimitry Andric for (auto [BB, Value] : ExitToVariable) { 186*0fca6ea1SDimitry Andric node->addIncoming(Value, BB); 187*0fca6ea1SDimitry Andric } 188*0fca6ea1SDimitry Andric 189*0fca6ea1SDimitry Andric // Creating the switch to jump to the correct exit target. 190*0fca6ea1SDimitry Andric llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0], 191*0fca6ea1SDimitry Andric SortedExitTargets.size() - 1); 192*0fca6ea1SDimitry Andric for (size_t i = 1; i < SortedExitTargets.size(); i++) { 193*0fca6ea1SDimitry Andric BasicBlock *BB = SortedExitTargets[i]; 194*0fca6ea1SDimitry Andric Sw->addCase(TargetToValue[BB], BB); 195*0fca6ea1SDimitry Andric } 196*0fca6ea1SDimitry Andric 197*0fca6ea1SDimitry Andric // Fix exit branches to redirect to the new exit. 198*0fca6ea1SDimitry Andric for (auto Exit : CR->Exits) 199*0fca6ea1SDimitry Andric replaceBranchTargets(Exit, ExitTargets, NewExitTarget); 200*0fca6ea1SDimitry Andric 201*0fca6ea1SDimitry Andric return true; 202*0fca6ea1SDimitry Andric } 203*0fca6ea1SDimitry Andric 204*0fca6ea1SDimitry Andric /// Run the pass on the given convergence region and sub-regions (DFS). 205*0fca6ea1SDimitry Andric /// Returns true if a region/sub-region was modified, false otherwise. 206*0fca6ea1SDimitry Andric /// This returns as soon as one region/sub-region has been modified. 207*0fca6ea1SDimitry Andric bool runOnConvergenceRegion(LoopInfo &LI, 208*0fca6ea1SDimitry Andric const SPIRV::ConvergenceRegion *CR) { 209*0fca6ea1SDimitry Andric for (auto *Child : CR->Children) 210*0fca6ea1SDimitry Andric if (runOnConvergenceRegion(LI, Child)) 211*0fca6ea1SDimitry Andric return true; 212*0fca6ea1SDimitry Andric 213*0fca6ea1SDimitry Andric return runOnConvergenceRegionNoRecurse(LI, CR); 214*0fca6ea1SDimitry Andric } 215*0fca6ea1SDimitry Andric 216*0fca6ea1SDimitry Andric #if !NDEBUG 217*0fca6ea1SDimitry Andric /// Validates each edge exiting the region has the same destination basic 218*0fca6ea1SDimitry Andric /// block. 219*0fca6ea1SDimitry Andric void validateRegionExits(const SPIRV::ConvergenceRegion *CR) { 220*0fca6ea1SDimitry Andric for (auto *Child : CR->Children) 221*0fca6ea1SDimitry Andric validateRegionExits(Child); 222*0fca6ea1SDimitry Andric 223*0fca6ea1SDimitry Andric std::unordered_set<BasicBlock *> ExitTargets; 224*0fca6ea1SDimitry Andric for (auto *Exit : CR->Exits) { 225*0fca6ea1SDimitry Andric auto Set = gatherSuccessors(Exit); 226*0fca6ea1SDimitry Andric for (auto *BB : Set) { 227*0fca6ea1SDimitry Andric if (CR->Blocks.count(BB) == 0) 228*0fca6ea1SDimitry Andric ExitTargets.insert(BB); 229*0fca6ea1SDimitry Andric } 230*0fca6ea1SDimitry Andric } 231*0fca6ea1SDimitry Andric 232*0fca6ea1SDimitry Andric assert(ExitTargets.size() <= 1); 233*0fca6ea1SDimitry Andric } 234*0fca6ea1SDimitry Andric #endif 235*0fca6ea1SDimitry Andric 236*0fca6ea1SDimitry Andric virtual bool runOnFunction(Function &F) override { 237*0fca6ea1SDimitry Andric LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 238*0fca6ea1SDimitry Andric const auto *TopLevelRegion = 239*0fca6ea1SDimitry Andric getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>() 240*0fca6ea1SDimitry Andric .getRegionInfo() 241*0fca6ea1SDimitry Andric .getTopLevelRegion(); 242*0fca6ea1SDimitry Andric 243*0fca6ea1SDimitry Andric // FIXME: very inefficient method: each time a region is modified, we bubble 244*0fca6ea1SDimitry Andric // back up, and recompute the whole convergence region tree. Once the 245*0fca6ea1SDimitry Andric // algorithm is completed and test coverage good enough, rewrite this pass 246*0fca6ea1SDimitry Andric // to be efficient instead of simple. 247*0fca6ea1SDimitry Andric bool modified = false; 248*0fca6ea1SDimitry Andric while (runOnConvergenceRegion(LI, TopLevelRegion)) { 249*0fca6ea1SDimitry Andric TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>() 250*0fca6ea1SDimitry Andric .getRegionInfo() 251*0fca6ea1SDimitry Andric .getTopLevelRegion(); 252*0fca6ea1SDimitry Andric modified = true; 253*0fca6ea1SDimitry Andric } 254*0fca6ea1SDimitry Andric 255*0fca6ea1SDimitry Andric #if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS) 256*0fca6ea1SDimitry Andric validateRegionExits(TopLevelRegion); 257*0fca6ea1SDimitry Andric #endif 258*0fca6ea1SDimitry Andric return modified; 259*0fca6ea1SDimitry Andric } 260*0fca6ea1SDimitry Andric 261*0fca6ea1SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 262*0fca6ea1SDimitry Andric AU.addRequired<DominatorTreeWrapperPass>(); 263*0fca6ea1SDimitry Andric AU.addRequired<LoopInfoWrapperPass>(); 264*0fca6ea1SDimitry Andric AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>(); 265*0fca6ea1SDimitry Andric FunctionPass::getAnalysisUsage(AU); 266*0fca6ea1SDimitry Andric } 267*0fca6ea1SDimitry Andric }; 268*0fca6ea1SDimitry Andric } // namespace llvm 269*0fca6ea1SDimitry Andric 270*0fca6ea1SDimitry Andric char SPIRVMergeRegionExitTargets::ID = 0; 271*0fca6ea1SDimitry Andric 272*0fca6ea1SDimitry Andric INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks", 273*0fca6ea1SDimitry Andric "SPIRV split region exit blocks", false, false) 274*0fca6ea1SDimitry Andric INITIALIZE_PASS_DEPENDENCY(LoopSimplify) 275*0fca6ea1SDimitry Andric INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 276*0fca6ea1SDimitry Andric INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 277*0fca6ea1SDimitry Andric INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass) 278*0fca6ea1SDimitry Andric 279*0fca6ea1SDimitry Andric INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks", 280*0fca6ea1SDimitry Andric "SPIRV split region exit blocks", false, false) 281*0fca6ea1SDimitry Andric 282*0fca6ea1SDimitry Andric FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() { 283*0fca6ea1SDimitry Andric return new SPIRVMergeRegionExitTargets(); 284*0fca6ea1SDimitry Andric } 285