//===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an algorithm that returns for a divergent branch
// the set of basic blocks whose phi nodes become divergent due to divergent
// control. These are the blocks that are reachable by two disjoint paths from
// the branch or loop exits that have a reaching path that is disjoint from a
// path to the loop latch.
//
// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
// control-induced divergence in phi nodes.
//
//
// -- Reference --
// The algorithm is presented in Section 5 of
//
//   An abstract interpretation for SPMD divergence
//       on reducible control flow graphs.
//   Julian Rosemann, Simon Moll and Sebastian Hack
//   POPL '21
//
//
// -- Sync dependence --
// Sync dependence characterizes the control flow aspect of the
// propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
//
// -- Reduction to SSA construction --
// There are two disjoint paths from A to X, if a certain variant of SSA
// construction places a phi node in X under the following set-up scheme.
//
// This variant of SSA construction ignores incoming undef values.
// That is, paths from the entry without a definition do not result in
// phi nodes.
//
//       entry
//     /      \
//    A        \
//  /   \       Y
// B     C     /
//  \   /  \  /
//    D     E
//     \   /
//       F
//
// Assume that A contains a divergent branch. We are interested
// in the set of all blocks where each block is reachable from A
// via two disjoint paths. This would be the set {D, F} in this
// case.
// To generally reduce this query to SSA construction we introduce
// a virtual variable x and assign to x different values in each
// successor block of A.
//
//           entry
//         /      \
//        A        \
//      /   \       Y
// x = 0   x = 1   /
//      \  /   \  /
//        D     E
//         \   /
//           F
//
// Our flavor of SSA construction for x will construct the following
//
//            entry
//          /      \
//         A        \
//       /   \       Y
// x0 = 0   x1 = 1  /
//       \   /   \ /
//     x2 = phi   E
//         \   /
//        x3 = phi
//
// The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
//
// -- Remarks --
// * In case of loop exits we need to check the disjoint path criterion for
//   loops. To this end, we check whether the definition of x differs between
//   the loop exit and the loop header (_after_ SSA construction).
//
// -- Known Limitations & Future Work --
// * The algorithm requires reducible loops because the implementation
//   implicitly performs a single iteration of the underlying data flow
//   analysis. This was done for pragmatism, simplicity and speed.
//
//   Relevant related work for extending the algorithm to irreducible control:
//     A simple algorithm for global data flow analysis problems.
//     Matthew S. Hecht and Jeffrey D. Ullman.
//     SIAM Journal on Computing, 4(4):519–532, December 1975.
112 // 113 // * Another reason for requiring reducible loops is that points of 114 // synchronization in irreducible loops aren't 'obvious' - there is no unique 115 // header where threads 'should' synchronize when entering or coming back 116 // around from the latch. 117 // 118 //===----------------------------------------------------------------------===// 119 120 #include "llvm/Analysis/SyncDependenceAnalysis.h" 121 #include "llvm/ADT/SmallPtrSet.h" 122 #include "llvm/Analysis/LoopInfo.h" 123 #include "llvm/IR/BasicBlock.h" 124 #include "llvm/IR/CFG.h" 125 #include "llvm/IR/Dominators.h" 126 #include "llvm/IR/Function.h" 127 128 #include <functional> 129 130 #define DEBUG_TYPE "sync-dependence" 131 132 // The SDA algorithm operates on a modified CFG - we modify the edges leaving 133 // loop headers as follows: 134 // 135 // * We remove all edges leaving all loop headers. 136 // * We add additional edges from the loop headers to their exit blocks. 137 // 138 // The modification is virtual, that is whenever we visit a loop header we 139 // pretend it had different successors. 140 namespace { 141 using namespace llvm; 142 143 // Custom Post-Order Traveral 144 // 145 // We cannot use the vanilla (R)PO computation of LLVM because: 146 // * We (virtually) modify the CFG. 147 // * We want a loop-compact block enumeration, that is the numbers assigned to 148 // blocks of a loop form an interval 149 // 150 using POCB = std::function<void(const BasicBlock &)>; 151 using VisitedSet = std::set<const BasicBlock *>; 152 using BlockStack = std::vector<const BasicBlock *>; 153 154 // forward 155 static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, 156 VisitedSet &Finalized); 157 158 // for a nested region (top-level loop or nested loop) 159 static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop, 160 POCB CallBack, VisitedSet &Finalized) { 161 const auto *LoopHeader = Loop ? 
Loop->getHeader() : nullptr; 162 while (!Stack.empty()) { 163 const auto *NextBB = Stack.back(); 164 165 auto *NestedLoop = LI.getLoopFor(NextBB); 166 bool IsNestedLoop = NestedLoop != Loop; 167 168 // Treat the loop as a node 169 if (IsNestedLoop) { 170 SmallVector<BasicBlock *, 3> NestedExits; 171 NestedLoop->getUniqueExitBlocks(NestedExits); 172 bool PushedNodes = false; 173 for (const auto *NestedExitBB : NestedExits) { 174 if (NestedExitBB == LoopHeader) 175 continue; 176 if (Loop && !Loop->contains(NestedExitBB)) 177 continue; 178 if (Finalized.count(NestedExitBB)) 179 continue; 180 PushedNodes = true; 181 Stack.push_back(NestedExitBB); 182 } 183 if (!PushedNodes) { 184 // All loop exits finalized -> finish this node 185 Stack.pop_back(); 186 computeLoopPO(LI, *NestedLoop, CallBack, Finalized); 187 } 188 continue; 189 } 190 191 // DAG-style 192 bool PushedNodes = false; 193 for (const auto *SuccBB : successors(NextBB)) { 194 if (SuccBB == LoopHeader) 195 continue; 196 if (Loop && !Loop->contains(SuccBB)) 197 continue; 198 if (Finalized.count(SuccBB)) 199 continue; 200 PushedNodes = true; 201 Stack.push_back(SuccBB); 202 } 203 if (!PushedNodes) { 204 // Never push nodes twice 205 Stack.pop_back(); 206 if (!Finalized.insert(NextBB).second) 207 continue; 208 CallBack(*NextBB); 209 } 210 } 211 } 212 213 static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) { 214 VisitedSet Finalized; 215 BlockStack Stack; 216 Stack.reserve(24); // FIXME made-up number 217 Stack.push_back(&F.getEntryBlock()); 218 computeStackPO(Stack, LI, nullptr, CallBack, Finalized); 219 } 220 221 static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, 222 VisitedSet &Finalized) { 223 /// Call CallBack on all loop blocks. 
224 std::vector<const BasicBlock *> Stack; 225 const auto *LoopHeader = Loop.getHeader(); 226 227 // Visit the header last 228 Finalized.insert(LoopHeader); 229 CallBack(*LoopHeader); 230 231 // Initialize with immediate successors 232 for (const auto *BB : successors(LoopHeader)) { 233 if (!Loop.contains(BB)) 234 continue; 235 if (BB == LoopHeader) 236 continue; 237 Stack.push_back(BB); 238 } 239 240 // Compute PO inside region 241 computeStackPO(Stack, LI, &Loop, CallBack, Finalized); 242 } 243 244 } // namespace 245 246 namespace llvm { 247 248 ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc; 249 250 SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, 251 const PostDominatorTree &PDT, 252 const LoopInfo &LI) 253 : DT(DT), PDT(PDT), LI(LI) { 254 computeTopLevelPO(*DT.getRoot()->getParent(), LI, 255 [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); }); 256 } 257 258 SyncDependenceAnalysis::~SyncDependenceAnalysis() = default; 259 260 namespace { 261 // divergence propagator for reducible CFGs 262 struct DivergencePropagator { 263 const ModifiedPO &LoopPOT; 264 const DominatorTree &DT; 265 const PostDominatorTree &PDT; 266 const LoopInfo &LI; 267 const BasicBlock &DivTermBlock; 268 269 // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at 270 // block B 271 // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet 272 // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths 273 // from X or B is an immediate successor of X (initial value). 274 using BlockLabelVec = std::vector<const BasicBlock *>; 275 BlockLabelVec BlockLabels; 276 // divergent join and loop exit descriptor. 
277 std::unique_ptr<ControlDivergenceDesc> DivDesc; 278 279 DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT, 280 const PostDominatorTree &PDT, const LoopInfo &LI, 281 const BasicBlock &DivTermBlock) 282 : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock), 283 BlockLabels(LoopPOT.size(), nullptr), 284 DivDesc(new ControlDivergenceDesc) {} 285 286 void printDefs(raw_ostream &Out) { 287 Out << "Propagator::BlockLabels {\n"; 288 for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) { 289 const auto *Label = BlockLabels[BlockIdx]; 290 Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx 291 << ") : "; 292 if (!Label) { 293 Out << "<null>\n"; 294 } else { 295 Out << Label->getName() << "\n"; 296 } 297 } 298 Out << "}\n"; 299 } 300 301 // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this 302 // causes a divergent join. 303 bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) { 304 auto SuccIdx = LoopPOT.getIndexOf(SuccBlock); 305 306 // unset or same reaching label 307 const auto *OldLabel = BlockLabels[SuccIdx]; 308 if (!OldLabel || (OldLabel == &PushedLabel)) { 309 BlockLabels[SuccIdx] = &PushedLabel; 310 return false; 311 } 312 313 // Update the definition 314 BlockLabels[SuccIdx] = &SuccBlock; 315 return true; 316 } 317 318 // visiting a virtual loop exit edge from the loop header --> temporal 319 // divergence on join 320 bool visitLoopExitEdge(const BasicBlock &ExitBlock, 321 const BasicBlock &DefBlock, bool FromParentLoop) { 322 // Pushing from a non-parent loop cannot cause temporal divergence. 
323 if (!FromParentLoop) 324 return visitEdge(ExitBlock, DefBlock); 325 326 if (!computeJoin(ExitBlock, DefBlock)) 327 return false; 328 329 // Identified a divergent loop exit 330 DivDesc->LoopDivBlocks.insert(&ExitBlock); 331 LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName() 332 << "\n"); 333 return true; 334 } 335 336 // process \p SuccBlock with reaching definition \p DefBlock 337 bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) { 338 if (!computeJoin(SuccBlock, DefBlock)) 339 return false; 340 341 // Divergent, disjoint paths join. 342 DivDesc->JoinDivBlocks.insert(&SuccBlock); 343 LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName()); 344 return true; 345 } 346 347 std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() { 348 assert(DivDesc); 349 350 LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName() 351 << "\n"); 352 353 const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock); 354 355 // Early stopping criterion 356 int FloorIdx = LoopPOT.size() - 1; 357 const BasicBlock *FloorLabel = nullptr; 358 359 // bootstrap with branch targets 360 int BlockIdx = 0; 361 362 for (const auto *SuccBlock : successors(&DivTermBlock)) { 363 auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock); 364 BlockLabels[SuccIdx] = SuccBlock; 365 366 // Find the successor with the highest index to start with 367 BlockIdx = std::max<int>(BlockIdx, SuccIdx); 368 FloorIdx = std::min<int>(FloorIdx, SuccIdx); 369 370 // Identify immediate divergent loop exits 371 if (!DivBlockLoop) 372 continue; 373 374 const auto *BlockLoop = LI.getLoopFor(SuccBlock); 375 if (BlockLoop && DivBlockLoop->contains(BlockLoop)) 376 continue; 377 DivDesc->LoopDivBlocks.insert(SuccBlock); 378 LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: " 379 << SuccBlock->getName() << "\n"); 380 } 381 382 // propagate definitions at the immediate successors of the node in RPO 383 for (; BlockIdx >= FloorIdx; --BlockIdx) { 384 LLVM_DEBUG(dbgs() << 
"Before next visit:\n"; printDefs(dbgs())); 385 386 // Any label available here 387 const auto *Label = BlockLabels[BlockIdx]; 388 if (!Label) 389 continue; 390 391 // Ok. Get the block 392 const auto *Block = LoopPOT.getBlockAt(BlockIdx); 393 LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n"); 394 395 auto *BlockLoop = LI.getLoopFor(Block); 396 bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block; 397 bool CausedJoin = false; 398 int LoweredFloorIdx = FloorIdx; 399 if (IsLoopHeader) { 400 // Disconnect from immediate successors and propagate directly to loop 401 // exits. 402 SmallVector<BasicBlock *, 4> BlockLoopExits; 403 BlockLoop->getExitBlocks(BlockLoopExits); 404 405 bool IsParentLoop = BlockLoop->contains(&DivTermBlock); 406 for (const auto *BlockLoopExit : BlockLoopExits) { 407 CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop); 408 LoweredFloorIdx = std::min<int>(LoweredFloorIdx, 409 LoopPOT.getIndexOf(*BlockLoopExit)); 410 } 411 } else { 412 // Acyclic successor case 413 for (const auto *SuccBlock : successors(Block)) { 414 CausedJoin |= visitEdge(*SuccBlock, *Label); 415 LoweredFloorIdx = 416 std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock)); 417 } 418 } 419 420 // Floor update 421 if (CausedJoin) { 422 // 1. Different labels pushed to successors 423 FloorIdx = LoweredFloorIdx; 424 } else if (FloorLabel != Label) { 425 // 2. No join caused BUT we pushed a label that is different than the 426 // last pushed label 427 FloorIdx = LoweredFloorIdx; 428 FloorLabel = Label; 429 } 430 } 431 432 LLVM_DEBUG(dbgs() << "SDA::joins. 
After propagation:\n"; printDefs(dbgs())); 433 434 return std::move(DivDesc); 435 } 436 }; 437 } // end anonymous namespace 438 439 #ifndef NDEBUG 440 static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) { 441 Out << "["; 442 ListSeparator LS; 443 for (const auto *BB : Blocks) 444 Out << LS << BB->getName(); 445 Out << "]"; 446 } 447 #endif 448 449 const ControlDivergenceDesc & 450 SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) { 451 // trivial case 452 if (Term.getNumSuccessors() <= 1) { 453 return EmptyDivergenceDesc; 454 } 455 456 // already available in cache? 457 auto ItCached = CachedControlDivDescs.find(&Term); 458 if (ItCached != CachedControlDivDescs.end()) 459 return *ItCached->second; 460 461 // compute all join points 462 // Special handling of divergent loop exits is not needed for LCSSA 463 const auto &TermBlock = *Term.getParent(); 464 DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock); 465 auto DivDesc = Propagator.computeJoinPoints(); 466 467 LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n"; 468 dbgs() << "JoinDivBlocks: "; 469 printBlockSet(DivDesc->JoinDivBlocks, dbgs()); 470 dbgs() << "\nLoopDivBlocks: "; 471 printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";); 472 473 auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc)); 474 assert(ItInserted.second); 475 return *ItInserted.first->second; 476 } 477 478 } // namespace llvm 479