//===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an algorithm that returns for a divergent branch
// the set of basic blocks whose phi nodes become divergent due to divergent
// control. These are the blocks that are reachable by two disjoint paths from
// the branch, or loop exits that have a reaching path that is disjoint from a
// path to the loop latch.
//
// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
// control-induced divergence in phi nodes.
//
// -- Summary --
// The SyncDependenceAnalysis lazily computes sync dependences [3].
// The analysis evaluates the disjoint path criterion [2] by a reduction
// to SSA construction. The SSA construction algorithm is implemented as
// a simple data-flow analysis [1].
//
// [1] "A Simple, Fast Dominance Algorithm", SPI '01, Cooper, Harvey and Kennedy
// [2] "Efficiently Computing Static Single Assignment Form
//     and the Control Dependence Graph", TOPLAS '91,
//     Cytron, Ferrante, Rosen, Wegman and Zadeck
// [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack
// [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira
//
// -- Sync dependence --
// Sync dependence [4] characterizes the control flow aspect of the
// propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID.
// Although %a is not data dependent on %tid because %tid is not on its
// use-def chains, %a is sync dependent on %tid because the branch
// "br i1 %cond" depends on %tid and affects which value %a is assigned to.
//
// -- Reduction to SSA construction --
// There are two disjoint paths from A to X, if a certain variant of SSA
// construction places a phi node in X under the following set-up scheme [2].
//
// This variant of SSA construction ignores incoming undef values.
// That is, paths from the entry without a definition do not result in
// phi nodes.
//
//       entry
//     /      \
//    A        \
//  /   \       Y
// B     C     /
//  \   /  \  /
//    D     E
//     \   /
//       F
// Assume that A contains a divergent branch. We are interested
// in the set of all blocks where each block is reachable from A
// via two disjoint paths. This would be the set {D, F} in this
// case.
// To generally reduce this query to SSA construction we introduce
// a virtual variable x and assign to x different values in each
// successor block of A.
//           entry
//         /      \
//        A        \
//      /   \       Y
// x = 0   x = 1   /
//      \  /   \  /
//        D     E
//         \   /
//           F
// Our flavor of SSA construction for x will construct the following:
//           entry
//         /      \
//        A        \
//      /   \       Y
// x0 = 0   x1 = 1 /
//      \  /   \  /
//      x2=phi  E
//         \   /
//        x3=phi
// The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
//
// -- Remarks --
// In case of loop exits we need to check the disjoint path criterion for loops
// [2]. To this end, we check whether the definition of x differs between the
// loop exit and the loop header (_after_ SSA construction).
99 // 100 //===----------------------------------------------------------------------===// 101 #include "llvm/Analysis/SyncDependenceAnalysis.h" 102 #include "llvm/ADT/PostOrderIterator.h" 103 #include "llvm/ADT/SmallPtrSet.h" 104 #include "llvm/Analysis/PostDominators.h" 105 #include "llvm/IR/BasicBlock.h" 106 #include "llvm/IR/CFG.h" 107 #include "llvm/IR/Dominators.h" 108 #include "llvm/IR/Function.h" 109 110 #include <functional> 111 #include <stack> 112 #include <unordered_set> 113 114 #define DEBUG_TYPE "sync-dependence" 115 116 // The SDA algorithm operates on a modified CFG - we modify the edges leaving 117 // loop headers as follows: 118 // 119 // * We remove all edges leaving all loop headers. 120 // * We add additional edges from the loop headers to their exit blocks. 121 // 122 // The modification is virtual, that is whenever we visit a loop header we 123 // pretend it had different successors. 124 namespace { 125 using namespace llvm; 126 127 // Custom Post-Order Traveral 128 // 129 // We cannot use the vanilla (R)PO computation of LLVM because: 130 // * We (virtually) modify the CFG. 131 // * We want a loop-compact block enumeration, that is the numbers assigned by 132 // the traveral to the blocks of a loop are an interval. 133 using POCB = std::function<void(const BasicBlock &)>; 134 using VisitedSet = std::set<const BasicBlock *>; 135 using BlockStack = std::vector<const BasicBlock *>; 136 137 // forward 138 static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, 139 VisitedSet &Finalized); 140 141 // for a nested region (top-level loop or nested loop) 142 static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop, 143 POCB CallBack, VisitedSet &Finalized) { 144 const auto *LoopHeader = Loop ? 
Loop->getHeader() : nullptr; 145 while (!Stack.empty()) { 146 const auto *NextBB = Stack.back(); 147 148 auto *NestedLoop = LI.getLoopFor(NextBB); 149 bool IsNestedLoop = NestedLoop != Loop; 150 151 // Treat the loop as a node 152 if (IsNestedLoop) { 153 SmallVector<BasicBlock *, 3> NestedExits; 154 NestedLoop->getUniqueExitBlocks(NestedExits); 155 bool PushedNodes = false; 156 for (const auto *NestedExitBB : NestedExits) { 157 if (NestedExitBB == LoopHeader) 158 continue; 159 if (Loop && !Loop->contains(NestedExitBB)) 160 continue; 161 if (Finalized.count(NestedExitBB)) 162 continue; 163 PushedNodes = true; 164 Stack.push_back(NestedExitBB); 165 } 166 if (!PushedNodes) { 167 // All loop exits finalized -> finish this node 168 Stack.pop_back(); 169 computeLoopPO(LI, *NestedLoop, CallBack, Finalized); 170 } 171 continue; 172 } 173 174 // DAG-style 175 bool PushedNodes = false; 176 for (const auto *SuccBB : successors(NextBB)) { 177 if (SuccBB == LoopHeader) 178 continue; 179 if (Loop && !Loop->contains(SuccBB)) 180 continue; 181 if (Finalized.count(SuccBB)) 182 continue; 183 PushedNodes = true; 184 Stack.push_back(SuccBB); 185 } 186 if (!PushedNodes) { 187 // Never push nodes twice 188 Stack.pop_back(); 189 if (!Finalized.insert(NextBB).second) 190 continue; 191 CallBack(*NextBB); 192 } 193 } 194 } 195 196 static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) { 197 VisitedSet Finalized; 198 BlockStack Stack; 199 Stack.reserve(24); // FIXME made-up number 200 Stack.push_back(&F.getEntryBlock()); 201 computeStackPO(Stack, LI, nullptr, CallBack, Finalized); 202 } 203 204 static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, 205 VisitedSet &Finalized) { 206 /// Call CallBack on all loop blocks. 
207 std::vector<const BasicBlock *> Stack; 208 const auto *LoopHeader = Loop.getHeader(); 209 210 // Visit the header last 211 Finalized.insert(LoopHeader); 212 CallBack(*LoopHeader); 213 214 // Initialize with immediate successors 215 for (const auto *BB : successors(LoopHeader)) { 216 if (!Loop.contains(BB)) 217 continue; 218 if (BB == LoopHeader) 219 continue; 220 Stack.push_back(BB); 221 } 222 223 // Compute PO inside region 224 computeStackPO(Stack, LI, &Loop, CallBack, Finalized); 225 } 226 227 } // namespace 228 229 namespace llvm { 230 231 ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc; 232 233 SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, 234 const PostDominatorTree &PDT, 235 const LoopInfo &LI) 236 : DT(DT), PDT(PDT), LI(LI) { 237 computeTopLevelPO(*DT.getRoot()->getParent(), LI, 238 [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); }); 239 } 240 241 SyncDependenceAnalysis::~SyncDependenceAnalysis() {} 242 243 // divergence propagator for reducible CFGs 244 struct DivergencePropagator { 245 const ModifiedPO &LoopPOT; 246 const DominatorTree &DT; 247 const PostDominatorTree &PDT; 248 const LoopInfo &LI; 249 const BasicBlock &DivTermBlock; 250 251 // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at 252 // block B 253 // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet 254 // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths 255 // from X or B is an immediate successor of X (initial value). 256 using BlockLabelVec = std::vector<const BasicBlock *>; 257 BlockLabelVec BlockLabels; 258 // divergent join and loop exit descriptor. 
259 std::unique_ptr<ControlDivergenceDesc> DivDesc; 260 261 DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT, 262 const PostDominatorTree &PDT, const LoopInfo &LI, 263 const BasicBlock &DivTermBlock) 264 : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock), 265 BlockLabels(LoopPOT.size(), nullptr), 266 DivDesc(new ControlDivergenceDesc) {} 267 268 void printDefs(raw_ostream &Out) { 269 Out << "Propagator::BlockLabels {\n"; 270 for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) { 271 const auto *Label = BlockLabels[BlockIdx]; 272 Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx 273 << ") : "; 274 if (!Label) { 275 Out << "<null>\n"; 276 } else { 277 Out << Label->getName() << "\n"; 278 } 279 } 280 Out << "}\n"; 281 } 282 283 // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this 284 // causes a divergent join. 285 bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) { 286 auto SuccIdx = LoopPOT.getIndexOf(SuccBlock); 287 288 // unset or same reaching label 289 const auto *OldLabel = BlockLabels[SuccIdx]; 290 if (!OldLabel || (OldLabel == &PushedLabel)) { 291 BlockLabels[SuccIdx] = &PushedLabel; 292 return false; 293 } 294 295 // Update the definition 296 BlockLabels[SuccIdx] = &SuccBlock; 297 return true; 298 } 299 300 // visiting a virtual loop exit edge from the loop header --> temporal 301 // divergence on join 302 bool visitLoopExitEdge(const BasicBlock &ExitBlock, 303 const BasicBlock &DefBlock, bool FromParentLoop) { 304 // Pushing from a non-parent loop cannot cause temporal divergence. 
305 if (!FromParentLoop) 306 return visitEdge(ExitBlock, DefBlock); 307 308 if (!computeJoin(ExitBlock, DefBlock)) 309 return false; 310 311 // Identified a divergent loop exit 312 DivDesc->LoopDivBlocks.insert(&ExitBlock); 313 LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName() 314 << "\n"); 315 return true; 316 } 317 318 // process \p SuccBlock with reaching definition \p DefBlock 319 bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) { 320 if (!computeJoin(SuccBlock, DefBlock)) 321 return false; 322 323 // Divergent, disjoint paths join. 324 DivDesc->JoinDivBlocks.insert(&SuccBlock); 325 LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName()); 326 return true; 327 } 328 329 std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() { 330 assert(DivDesc); 331 332 LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName() 333 << "\n"); 334 335 const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock); 336 337 // Early stopping criterion 338 int FloorIdx = LoopPOT.size() - 1; 339 const BasicBlock *FloorLabel = nullptr; 340 341 // bootstrap with branch targets 342 int BlockIdx = 0; 343 344 for (const auto *SuccBlock : successors(&DivTermBlock)) { 345 auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock); 346 BlockLabels[SuccIdx] = SuccBlock; 347 348 // Find the successor with the highest index to start with 349 BlockIdx = std::max<int>(BlockIdx, SuccIdx); 350 FloorIdx = std::min<int>(FloorIdx, SuccIdx); 351 352 // Identify immediate divergent loop exits 353 if (!DivBlockLoop) 354 continue; 355 356 const auto *BlockLoop = LI.getLoopFor(SuccBlock); 357 if (BlockLoop && DivBlockLoop->contains(BlockLoop)) 358 continue; 359 DivDesc->LoopDivBlocks.insert(SuccBlock); 360 LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: " 361 << SuccBlock->getName() << "\n"); 362 } 363 364 // propagate definitions at the immediate successors of the node in RPO 365 for (; BlockIdx >= FloorIdx; --BlockIdx) { 366 LLVM_DEBUG(dbgs() << 
"Before next visit:\n"; printDefs(dbgs())); 367 368 // Any label available here 369 const auto *Label = BlockLabels[BlockIdx]; 370 if (!Label) 371 continue; 372 373 // Ok. Get the block 374 const auto *Block = LoopPOT.getBlockAt(BlockIdx); 375 LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n"); 376 377 auto *BlockLoop = LI.getLoopFor(Block); 378 bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block; 379 bool CausedJoin = false; 380 int LoweredFloorIdx = FloorIdx; 381 if (IsLoopHeader) { 382 // Disconnect from immediate successors and propagate directly to loop 383 // exits. 384 SmallVector<BasicBlock *, 4> BlockLoopExits; 385 BlockLoop->getExitBlocks(BlockLoopExits); 386 387 bool IsParentLoop = BlockLoop->contains(&DivTermBlock); 388 for (const auto *BlockLoopExit : BlockLoopExits) { 389 CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop); 390 LoweredFloorIdx = std::min<int>(LoweredFloorIdx, 391 LoopPOT.getIndexOf(*BlockLoopExit)); 392 } 393 } else { 394 // Acyclic successor case 395 for (const auto *SuccBlock : successors(Block)) { 396 CausedJoin |= visitEdge(*SuccBlock, *Label); 397 LoweredFloorIdx = 398 std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock)); 399 } 400 } 401 402 // Floor update 403 if (CausedJoin) { 404 // 1. Different labels pushed to successors 405 FloorIdx = LoweredFloorIdx; 406 } else if (FloorLabel != Label) { 407 // 2. No join caused BUT we pushed a label that is different than the 408 // last pushed label 409 FloorIdx = LoweredFloorIdx; 410 FloorLabel = Label; 411 } 412 } 413 414 LLVM_DEBUG(dbgs() << "SDA::joins. 
After propagation:\n"; printDefs(dbgs())); 415 416 return std::move(DivDesc); 417 } 418 }; 419 420 #ifndef NDEBUG 421 static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) { 422 Out << "["; 423 ListSeparator LS; 424 for (const auto *BB : Blocks) 425 Out << LS << BB->getName(); 426 Out << "]"; 427 } 428 #endif 429 430 const ControlDivergenceDesc & 431 SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) { 432 // trivial case 433 if (Term.getNumSuccessors() <= 1) { 434 return EmptyDivergenceDesc; 435 } 436 437 // already available in cache? 438 auto ItCached = CachedControlDivDescs.find(&Term); 439 if (ItCached != CachedControlDivDescs.end()) 440 return *ItCached->second; 441 442 // compute all join points 443 // Special handling of divergent loop exits is not needed for LCSSA 444 const auto &TermBlock = *Term.getParent(); 445 DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock); 446 auto DivDesc = Propagator.computeJoinPoints(); 447 448 LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n"; 449 dbgs() << "JoinDivBlocks: "; 450 printBlockSet(DivDesc->JoinDivBlocks, dbgs()); 451 dbgs() << "\nLoopDivBlocks: "; 452 printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";); 453 454 auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc)); 455 assert(ItInserted.second); 456 return *ItInserted.first->second; 457 } 458 459 } // namespace llvm 460