1 //===- ConvergenceRegionAnalysis.h -----------------------------*- C++ -*--===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The analysis determines the convergence region for each basic block of 10 // the module, and provides a tree-like structure describing the region 11 // hierarchy. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "SPIRVConvergenceRegionAnalysis.h" 16 #include "llvm/Analysis/LoopInfo.h" 17 #include "llvm/IR/Dominators.h" 18 #include "llvm/IR/IntrinsicInst.h" 19 #include "llvm/InitializePasses.h" 20 #include "llvm/Transforms/Utils/LoopSimplify.h" 21 #include <optional> 22 #include <queue> 23 24 #define DEBUG_TYPE "spirv-convergence-region-analysis" 25 26 using namespace llvm; 27 28 namespace llvm { 29 void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &); 30 } // namespace llvm 31 32 INITIALIZE_PASS_BEGIN(SPIRVConvergenceRegionAnalysisWrapperPass, 33 "convergence-region", 34 "SPIRV convergence regions analysis", true, true) 35 INITIALIZE_PASS_DEPENDENCY(LoopSimplify) 36 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 37 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 38 INITIALIZE_PASS_END(SPIRVConvergenceRegionAnalysisWrapperPass, 39 "convergence-region", "SPIRV convergence regions analysis", 40 true, true) 41 42 namespace llvm { 43 namespace SPIRV { 44 namespace { 45 46 template <typename BasicBlockType, typename IntrinsicInstType> 47 std::optional<IntrinsicInstType *> 48 getConvergenceTokenInternal(BasicBlockType *BB) { 49 static_assert(std::is_const_v<IntrinsicInstType> == 50 std::is_const_v<BasicBlockType>, 51 "Constness must match between input and output."); 52 static_assert(std::is_same_v<BasicBlock, std::remove_const_t<BasicBlockType>>, 53 "Input must be a basic block."); 54 static_assert( 55 std::is_same_v<IntrinsicInst, std::remove_const_t<IntrinsicInstType>>, 56 "Output type must be an intrinsic instruction."); 57 58 for (auto &I : *BB) { 59 if (auto *CI = dyn_cast<ConvergenceControlInst>(&I)) { 60 // Make sure that the anchor or entry intrinsics did not reach here with a 61 // parent token. This should have failed the verifier. 62 assert(CI->isLoop() || 63 !CI->getOperandBundle(LLVMContext::OB_convergencectrl)); 64 return CI; 65 } 66 67 if (auto *CI = dyn_cast<CallInst>(&I)) { 68 auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl); 69 if (!OB.has_value()) 70 continue; 71 return dyn_cast<IntrinsicInst>(OB.value().Inputs[0]); 72 } 73 } 74 75 return std::nullopt; 76 } 77 78 // Given a ConvergenceRegion tree with |Start| as its root, finds the smallest 79 // region |Entry| belongs to. If |Entry| does not belong to the region defined 80 // by |Start|, this function returns |nullptr|. 81 ConvergenceRegion *findParentRegion(ConvergenceRegion *Start, 82 BasicBlock *Entry) { 83 ConvergenceRegion *Candidate = nullptr; 84 ConvergenceRegion *NextCandidate = Start; 85 86 while (Candidate != NextCandidate && NextCandidate != nullptr) { 87 Candidate = NextCandidate; 88 NextCandidate = nullptr; 89 90 // End of the search, we can return. 91 if (Candidate->Children.size() == 0) 92 return Candidate; 93 94 for (auto *Child : Candidate->Children) { 95 if (Child->Blocks.count(Entry) != 0) { 96 NextCandidate = Child; 97 break; 98 } 99 } 100 } 101 102 return Candidate; 103 } 104 105 } // anonymous namespace 106 107 std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB) { 108 return getConvergenceTokenInternal<BasicBlock, IntrinsicInst>(BB); 109 } 110 111 std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB) { 112 return getConvergenceTokenInternal<const BasicBlock, const IntrinsicInst>(BB); 113 } 114 115 ConvergenceRegion::ConvergenceRegion(DominatorTree &DT, LoopInfo &LI, 116 Function &F) 117 : DT(DT), LI(LI), Parent(nullptr) { 118 Entry = &F.getEntryBlock(); 119 ConvergenceToken = getConvergenceToken(Entry); 120 for (auto &B : F) { 121 Blocks.insert(&B); 122 if (isa<ReturnInst>(B.getTerminator())) 123 Exits.insert(&B); 124 } 125 } 126 127 ConvergenceRegion::ConvergenceRegion( 128 DominatorTree &DT, LoopInfo &LI, 129 std::optional<IntrinsicInst *> ConvergenceToken, BasicBlock *Entry, 130 SmallPtrSet<BasicBlock *, 8> &&Blocks, SmallPtrSet<BasicBlock *, 2> &&Exits) 131 : DT(DT), LI(LI), ConvergenceToken(ConvergenceToken), Entry(Entry), 132 Exits(std::move(Exits)), Blocks(std::move(Blocks)) { 133 for ([[maybe_unused]] auto *BB : this->Exits) 134 assert(this->Blocks.count(BB) != 0); 135 assert(this->Blocks.count(this->Entry) != 0); 136 } 137 138 void ConvergenceRegion::releaseMemory() { 139 // Parent memory is owned by the parent. 140 Parent = nullptr; 141 for (auto *Child : Children) { 142 Child->releaseMemory(); 143 delete Child; 144 } 145 Children.resize(0); 146 } 147 148 void ConvergenceRegion::dump(const unsigned IndentSize) const { 149 const std::string Indent(IndentSize, '\t'); 150 dbgs() << Indent << this << ": {\n"; 151 dbgs() << Indent << " Parent: " << Parent << "\n"; 152 153 if (ConvergenceToken.value_or(nullptr)) { 154 dbgs() << Indent 155 << " ConvergenceToken: " << ConvergenceToken.value()->getName() 156 << "\n"; 157 } 158 159 if (Entry->getName() != "") 160 dbgs() << Indent << " Entry: " << Entry->getName() << "\n"; 161 else 162 dbgs() << Indent << " Entry: " << Entry << "\n"; 163 164 dbgs() << Indent << " Exits: { "; 165 for (const auto &Exit : Exits) { 166 if (Exit->getName() != "") 167 dbgs() << Exit->getName() << ", "; 168 else 169 dbgs() << Exit << ", "; 170 } 171 dbgs() << " }\n"; 172 173 dbgs() << Indent << " Blocks: { "; 174 for (const auto &Block : Blocks) { 175 if (Block->getName() != "") 176 dbgs() << Block->getName() << ", "; 177 else 178 dbgs() << Block << ", "; 179 } 180 dbgs() << " }\n"; 181 182 dbgs() << Indent << " Children: {\n"; 183 for (const auto Child : Children) 184 Child->dump(IndentSize + 2); 185 dbgs() << Indent << " }\n"; 186 187 dbgs() << Indent << "}\n"; 188 } 189 190 class ConvergenceRegionAnalyzer { 191 192 public: 193 ConvergenceRegionAnalyzer(Function &F, DominatorTree &DT, LoopInfo &LI) 194 : DT(DT), LI(LI), F(F) {} 195 196 private: 197 bool isBackEdge(const BasicBlock *From, const BasicBlock *To) const { 198 if (From == To) 199 return true; 200 201 // We only handle loop in the simplified form. This means: 202 // - a single back-edge, a single latch. 203 // - meaning the back-edge target can only be the loop header. 204 // - meaning the From can only be the loop latch. 205 if (!LI.isLoopHeader(To)) 206 return false; 207 208 auto *L = LI.getLoopFor(To); 209 if (L->contains(From) && L->isLoopLatch(From)) 210 return true; 211 212 return false; 213 } 214 215 std::unordered_set<BasicBlock *> 216 findPathsToMatch(LoopInfo &LI, BasicBlock *From, 217 std::function<bool(const BasicBlock *)> isMatch) const { 218 std::unordered_set<BasicBlock *> Output; 219 220 if (isMatch(From)) 221 Output.insert(From); 222 223 auto *Terminator = From->getTerminator(); 224 for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) { 225 auto *To = Terminator->getSuccessor(i); 226 // Ignore back edges. 227 if (isBackEdge(From, To)) 228 continue; 229 230 auto ChildSet = findPathsToMatch(LI, To, isMatch); 231 if (ChildSet.size() == 0) 232 continue; 233 234 Output.insert(ChildSet.begin(), ChildSet.end()); 235 Output.insert(From); 236 if (LI.isLoopHeader(From)) { 237 auto *L = LI.getLoopFor(From); 238 for (auto *BB : L->getBlocks()) { 239 Output.insert(BB); 240 } 241 } 242 } 243 244 return Output; 245 } 246 247 SmallPtrSet<BasicBlock *, 2> 248 findExitNodes(const SmallPtrSetImpl<BasicBlock *> &RegionBlocks) { 249 SmallPtrSet<BasicBlock *, 2> Exits; 250 251 for (auto *B : RegionBlocks) { 252 auto *Terminator = B->getTerminator(); 253 for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) { 254 auto *Child = Terminator->getSuccessor(i); 255 if (RegionBlocks.count(Child) == 0) 256 Exits.insert(B); 257 } 258 } 259 260 return Exits; 261 } 262 263 public: 264 ConvergenceRegionInfo analyze() { 265 ConvergenceRegion *TopLevelRegion = new ConvergenceRegion(DT, LI, F); 266 std::queue<Loop *> ToProcess; 267 for (auto *L : LI.getLoopsInPreorder()) 268 ToProcess.push(L); 269 270 while (ToProcess.size() != 0) { 271 auto *L = ToProcess.front(); 272 ToProcess.pop(); 273 274 auto CT = getConvergenceToken(L->getHeader()); 275 SmallPtrSet<BasicBlock *, 8> RegionBlocks(L->block_begin(), 276 L->block_end()); 277 SmallVector<BasicBlock *> LoopExits; 278 L->getExitingBlocks(LoopExits); 279 if (CT.has_value()) { 280 for (auto *Exit : LoopExits) { 281 auto N = findPathsToMatch(LI, Exit, [&CT](const BasicBlock *block) { 282 auto Token = getConvergenceToken(block); 283 if (Token == std::nullopt) 284 return false; 285 return Token.value() == CT.value(); 286 }); 287 RegionBlocks.insert(N.begin(), N.end()); 288 } 289 } 290 291 auto RegionExits = findExitNodes(RegionBlocks); 292 ConvergenceRegion *Region = new ConvergenceRegion( 293 DT, LI, CT, L->getHeader(), std::move(RegionBlocks), 294 std::move(RegionExits)); 295 Region->Parent = findParentRegion(TopLevelRegion, Region->Entry); 296 assert(Region->Parent != nullptr && "This is impossible."); 297 Region->Parent->Children.push_back(Region); 298 } 299 300 return ConvergenceRegionInfo(TopLevelRegion); 301 } 302 303 private: 304 DominatorTree &DT; 305 LoopInfo &LI; 306 Function &F; 307 }; 308 309 ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT, 310 LoopInfo &LI) { 311 ConvergenceRegionAnalyzer Analyzer(F, DT, LI); 312 return Analyzer.analyze(); 313 } 314 315 } // namespace SPIRV 316 317 char SPIRVConvergenceRegionAnalysisWrapperPass::ID = 0; 318 319 SPIRVConvergenceRegionAnalysisWrapperPass:: 320 SPIRVConvergenceRegionAnalysisWrapperPass() 321 : FunctionPass(ID) {} 322 323 bool SPIRVConvergenceRegionAnalysisWrapperPass::runOnFunction(Function &F) { 324 DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 325 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 326 327 CRI = SPIRV::getConvergenceRegions(F, DT, LI); 328 // Nothing was modified. 329 return false; 330 } 331 332 SPIRVConvergenceRegionAnalysis::Result 333 SPIRVConvergenceRegionAnalysis::run(Function &F, FunctionAnalysisManager &AM) { 334 Result CRI; 335 auto &DT = AM.getResult<DominatorTreeAnalysis>(F); 336 auto &LI = AM.getResult<LoopAnalysis>(F); 337 CRI = SPIRV::getConvergenceRegions(F, DT, LI); 338 return CRI; 339 } 340 341 AnalysisKey SPIRVConvergenceRegionAnalysis::Key; 342 343 } // namespace llvm 344