1 //===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Merge the multiple exit targets of a convergence region into a single block. 10 // Each exit target will be assigned a constant value, and a phi node + switch 11 // will allow the new exit target to re-route to the correct basic block. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "Analysis/SPIRVConvergenceRegionAnalysis.h" 16 #include "SPIRV.h" 17 #include "SPIRVSubtarget.h" 18 #include "SPIRVTargetMachine.h" 19 #include "SPIRVUtils.h" 20 #include "llvm/ADT/DenseMap.h" 21 #include "llvm/ADT/SmallPtrSet.h" 22 #include "llvm/Analysis/LoopInfo.h" 23 #include "llvm/CodeGen/IntrinsicLowering.h" 24 #include "llvm/IR/CFG.h" 25 #include "llvm/IR/Dominators.h" 26 #include "llvm/IR/IRBuilder.h" 27 #include "llvm/IR/IntrinsicInst.h" 28 #include "llvm/IR/Intrinsics.h" 29 #include "llvm/IR/IntrinsicsSPIRV.h" 30 #include "llvm/InitializePasses.h" 31 #include "llvm/Transforms/Utils/Cloning.h" 32 #include "llvm/Transforms/Utils/LoopSimplify.h" 33 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" 34 35 using namespace llvm; 36 37 namespace llvm { 38 void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &); 39 40 class SPIRVMergeRegionExitTargets : public FunctionPass { 41 public: 42 static char ID; 43 44 SPIRVMergeRegionExitTargets() : FunctionPass(ID) { 45 initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry()); 46 }; 47 48 // Gather all the successors of |BB|. 49 // This function asserts if the terminator neither a branch, switch or return. 50 std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) { 51 std::unordered_set<BasicBlock *> output; 52 auto *T = BB->getTerminator(); 53 54 if (auto *BI = dyn_cast<BranchInst>(T)) { 55 output.insert(BI->getSuccessor(0)); 56 if (BI->isConditional()) 57 output.insert(BI->getSuccessor(1)); 58 return output; 59 } 60 61 if (auto *SI = dyn_cast<SwitchInst>(T)) { 62 output.insert(SI->getDefaultDest()); 63 for (auto &Case : SI->cases()) 64 output.insert(Case.getCaseSuccessor()); 65 return output; 66 } 67 68 assert(isa<ReturnInst>(T) && "Unhandled terminator type."); 69 return output; 70 } 71 72 /// Create a value in BB set to the value associated with the branch the block 73 /// terminator will take. 74 llvm::Value *createExitVariable( 75 BasicBlock *BB, 76 const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) { 77 auto *T = BB->getTerminator(); 78 if (isa<ReturnInst>(T)) 79 return nullptr; 80 81 IRBuilder<> Builder(BB); 82 Builder.SetInsertPoint(T); 83 84 if (auto *BI = dyn_cast<BranchInst>(T)) { 85 86 BasicBlock *LHSTarget = BI->getSuccessor(0); 87 BasicBlock *RHSTarget = 88 BI->isConditional() ? BI->getSuccessor(1) : nullptr; 89 90 Value *LHS = TargetToValue.count(LHSTarget) != 0 91 ? TargetToValue.at(LHSTarget) 92 : nullptr; 93 Value *RHS = TargetToValue.count(RHSTarget) != 0 94 ? TargetToValue.at(RHSTarget) 95 : nullptr; 96 97 if (LHS == nullptr || RHS == nullptr) 98 return LHS == nullptr ? RHS : LHS; 99 return Builder.CreateSelect(BI->getCondition(), LHS, RHS); 100 } 101 102 // TODO: add support for switch cases. 103 llvm_unreachable("Unhandled terminator type."); 104 } 105 106 /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|. 107 void replaceBranchTargets(BasicBlock *BB, 108 const SmallPtrSet<BasicBlock *, 4> &ToReplace, 109 BasicBlock *NewTarget) { 110 auto *T = BB->getTerminator(); 111 if (isa<ReturnInst>(T)) 112 return; 113 114 if (auto *BI = dyn_cast<BranchInst>(T)) { 115 for (size_t i = 0; i < BI->getNumSuccessors(); i++) { 116 if (ToReplace.count(BI->getSuccessor(i)) != 0) 117 BI->setSuccessor(i, NewTarget); 118 } 119 return; 120 } 121 122 if (auto *SI = dyn_cast<SwitchInst>(T)) { 123 for (size_t i = 0; i < SI->getNumSuccessors(); i++) { 124 if (ToReplace.count(SI->getSuccessor(i)) != 0) 125 SI->setSuccessor(i, NewTarget); 126 } 127 return; 128 } 129 130 assert(false && "Unhandled terminator type."); 131 } 132 133 AllocaInst *CreateVariable(Function &F, Type *Type, 134 BasicBlock::iterator Position) { 135 const DataLayout &DL = F.getDataLayout(); 136 return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg", 137 Position); 138 } 139 140 // Run the pass on the given convergence region, ignoring the sub-regions. 141 // Returns true if the CFG changed, false otherwise. 142 bool runOnConvergenceRegionNoRecurse(LoopInfo &LI, 143 SPIRV::ConvergenceRegion *CR) { 144 // Gather all the exit targets for this region. 145 SmallPtrSet<BasicBlock *, 4> ExitTargets; 146 for (BasicBlock *Exit : CR->Exits) { 147 for (BasicBlock *Target : gatherSuccessors(Exit)) { 148 if (CR->Blocks.count(Target) == 0) 149 ExitTargets.insert(Target); 150 } 151 } 152 153 // If we have zero or one exit target, nothing do to. 154 if (ExitTargets.size() <= 1) 155 return false; 156 157 // Create the new single exit target. 158 auto F = CR->Entry->getParent(); 159 auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F); 160 IRBuilder<> Builder(NewExitTarget); 161 162 AllocaInst *Variable = CreateVariable(*F, Builder.getInt32Ty(), 163 F->begin()->getFirstInsertionPt()); 164 165 // CodeGen output needs to be stable. Using the set as-is would order 166 // the targets differently depending on the allocation pattern. 167 // Sorting per basic-block ordering in the function. 168 std::vector<BasicBlock *> SortedExitTargets; 169 std::vector<BasicBlock *> SortedExits; 170 for (BasicBlock &BB : *F) { 171 if (ExitTargets.count(&BB) != 0) 172 SortedExitTargets.push_back(&BB); 173 if (CR->Exits.count(&BB) != 0) 174 SortedExits.push_back(&BB); 175 } 176 177 // Creating one constant per distinct exit target. This will be route to the 178 // correct target. 179 DenseMap<BasicBlock *, ConstantInt *> TargetToValue; 180 for (BasicBlock *Target : SortedExitTargets) 181 TargetToValue.insert( 182 std::make_pair(Target, Builder.getInt32(TargetToValue.size()))); 183 184 // Creating one variable per exit node, set to the constant matching the 185 // targeted external block. 186 std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable; 187 for (auto Exit : SortedExits) { 188 llvm::Value *Value = createExitVariable(Exit, TargetToValue); 189 IRBuilder<> B2(Exit); 190 B2.SetInsertPoint(Exit->getFirstInsertionPt()); 191 B2.CreateStore(Value, Variable); 192 ExitToVariable.emplace_back(std::make_pair(Exit, Value)); 193 } 194 195 llvm::Value *Load = Builder.CreateLoad(Builder.getInt32Ty(), Variable); 196 197 // Creating the switch to jump to the correct exit target. 198 llvm::SwitchInst *Sw = Builder.CreateSwitch(Load, SortedExitTargets[0], 199 SortedExitTargets.size() - 1); 200 for (size_t i = 1; i < SortedExitTargets.size(); i++) { 201 BasicBlock *BB = SortedExitTargets[i]; 202 Sw->addCase(TargetToValue[BB], BB); 203 } 204 205 // Fix exit branches to redirect to the new exit. 206 for (auto Exit : CR->Exits) 207 replaceBranchTargets(Exit, ExitTargets, NewExitTarget); 208 209 CR = CR->Parent; 210 while (CR) { 211 CR->Blocks.insert(NewExitTarget); 212 CR = CR->Parent; 213 } 214 215 return true; 216 } 217 218 /// Run the pass on the given convergence region and sub-regions (DFS). 219 /// Returns true if a region/sub-region was modified, false otherwise. 220 /// This returns as soon as one region/sub-region has been modified. 221 bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR) { 222 for (auto *Child : CR->Children) 223 if (runOnConvergenceRegion(LI, Child)) 224 return true; 225 226 return runOnConvergenceRegionNoRecurse(LI, CR); 227 } 228 229 #if !NDEBUG 230 /// Validates each edge exiting the region has the same destination basic 231 /// block. 232 void validateRegionExits(const SPIRV::ConvergenceRegion *CR) { 233 for (auto *Child : CR->Children) 234 validateRegionExits(Child); 235 236 std::unordered_set<BasicBlock *> ExitTargets; 237 for (auto *Exit : CR->Exits) { 238 auto Set = gatherSuccessors(Exit); 239 for (auto *BB : Set) { 240 if (CR->Blocks.count(BB) == 0) 241 ExitTargets.insert(BB); 242 } 243 } 244 245 assert(ExitTargets.size() <= 1); 246 } 247 #endif 248 249 virtual bool runOnFunction(Function &F) override { 250 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 251 auto *TopLevelRegion = 252 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>() 253 .getRegionInfo() 254 .getWritableTopLevelRegion(); 255 256 // FIXME: very inefficient method: each time a region is modified, we bubble 257 // back up, and recompute the whole convergence region tree. Once the 258 // algorithm is completed and test coverage good enough, rewrite this pass 259 // to be efficient instead of simple. 260 bool modified = false; 261 while (runOnConvergenceRegion(LI, TopLevelRegion)) { 262 modified = true; 263 } 264 265 #if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS) 266 validateRegionExits(TopLevelRegion); 267 #endif 268 return modified; 269 } 270 271 void getAnalysisUsage(AnalysisUsage &AU) const override { 272 AU.addRequired<DominatorTreeWrapperPass>(); 273 AU.addRequired<LoopInfoWrapperPass>(); 274 AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>(); 275 276 AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>(); 277 FunctionPass::getAnalysisUsage(AU); 278 } 279 }; 280 } // namespace llvm 281 282 char SPIRVMergeRegionExitTargets::ID = 0; 283 284 INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks", 285 "SPIRV split region exit blocks", false, false) 286 INITIALIZE_PASS_DEPENDENCY(LoopSimplify) 287 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 288 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 289 INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass) 290 291 INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks", 292 "SPIRV split region exit blocks", false, false) 293 294 FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() { 295 return new SPIRVMergeRegionExitTargets(); 296 } 297