//===- GenericUniformityImpl.h -----------------------*- C++ -*------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This template implementation resides in a separate file so that it
// does not get injected into every .cpp file that includes the
// generic header.
//
// DO NOT INCLUDE THIS FILE WHEN MERELY USING UNIFORMITYINFO.
//
// This file should only be included by files that implement a
// specialization of the relevant templates. Currently these are:
// - UniformityAnalysis.cpp
//
// Note: The DEBUG_TYPE macro should be defined before using this
// file so that any use of LLVM_DEBUG is associated with the
// including file rather than this file.
//
//===----------------------------------------------------------------------===//
///
/// \file
/// \brief Implementation of uniformity analysis.
///
/// The algorithm is a fixed point iteration that starts with the assumption
/// that all control flow and all values are uniform. Starting from sources of
/// divergence (whose discovery must be implemented by a CFG- or even
/// target-specific derived class), divergence of values is propagated from
/// definition to uses in a straightforward way. The main complexity lies in
/// the propagation of the impact of divergent control flow on the divergence
/// of values (sync dependencies).
///
/// NOTE: In general, no interface exists for a transform to update
/// (Machine)UniformityInfo. Additionally, (Machine)CycleAnalysis is a
/// transitive dependence, but it also does not provide an interface for
/// updating itself. Given that, transforms should not preserve uniformity in
/// their getAnalysisUsage() callback.
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_ADT_GENERICUNIFORMITYIMPL_H
#define LLVM_ADT_GENERICUNIFORMITYIMPL_H

#include "llvm/ADT/GenericUniformityInfo.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SparseBitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "uniformity"

namespace llvm {

/// Construct a specially modified post-order traversal of cycles.
///
/// The ModifiedPO is constructed using a virtually modified CFG as follows:
///
/// 1. The successors of pre-entry nodes (predecessors of a cycle
///    entry that are outside the cycle) are replaced by the
///    successors of the successors of the header.
/// 2. Successors of the cycle header are replaced by the exit blocks
///    of the cycle.
///
/// Effectively, we produce a depth-first numbering with the following
/// properties:
///
/// 1. Nodes after a cycle are numbered earlier than the cycle header.
/// 2. The header is numbered earlier than the nodes in the cycle.
/// 3. The numbering of the nodes within the cycle forms an interval
///    starting with the header.
///
/// Effectively, the virtual modification arranges the nodes in a
/// cycle as a DAG with the header as the sole leaf, and successors of
/// the header as the roots. A reverse traversal of this numbering has
/// the following invariant on the unmodified original CFG:
///
///   Each node is visited after all its predecessors, except if that
///   predecessor is the cycle header.
///
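/// For example (an illustrative sketch, not tied to any particular IR),
/// consider the CFG
///
///   entry -> H,  H -> B,  B -> H,  B -> X
///
/// with the reducible cycle {H, B} and cycle exit X. One valid modified
/// post-order numbering is X = 0, H = 1, B = 2, entry = 3: the block after
/// the cycle (X) is numbered earlier than the header, the header earlier
/// than the cycle body, and the cycle occupies the interval [1, 2] starting
/// with the header.
///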
template <typename ContextT> class ModifiedPostOrder {
public:
  using BlockT = typename ContextT::BlockT;
  using FunctionT = typename ContextT::FunctionT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;
  using const_iterator = typename std::vector<BlockT *>::const_iterator;

  ModifiedPostOrder(const ContextT &C) : Context(C) {}

  bool empty() const { return m_order.empty(); }
  size_t size() const { return m_order.size(); }

  void clear() { m_order.clear(); }
  void compute(const CycleInfoT &CI);

  unsigned count(BlockT *BB) const { return POIndex.count(BB); }
  const BlockT *operator[](size_t idx) const { return m_order[idx]; }

  void appendBlock(const BlockT &BB, bool isReducibleCycleHeader = false) {
    POIndex[&BB] = m_order.size();
    m_order.push_back(&BB);
    LLVM_DEBUG(dbgs() << "ModifiedPO(" << POIndex[&BB]
                      << "): " << Context.print(&BB) << "\n");
    if (isReducibleCycleHeader)
      ReducibleCycleHeaders.insert(&BB);
  }

  unsigned getIndex(const BlockT *BB) const {
    assert(POIndex.count(BB));
    return POIndex.lookup(BB);
  }

  bool isReducibleCycleHeader(const BlockT *BB) const {
    return ReducibleCycleHeaders.contains(BB);
  }

private:
  SmallVector<const BlockT *> m_order;
  DenseMap<const BlockT *, unsigned> POIndex;
  SmallPtrSet<const BlockT *, 32> ReducibleCycleHeaders;
  const ContextT &Context;

  void computeCyclePO(const CycleInfoT &CI, const CycleT *Cycle,
                      SmallPtrSetImpl<const BlockT *> &Finalized);

  void computeStackPO(SmallVectorImpl<const BlockT *> &Stack,
                      const CycleInfoT &CI, const CycleT *Cycle,
                      SmallPtrSetImpl<const BlockT *> &Finalized);
};

template <typename> class DivergencePropagator;

/// \class GenericSyncDependenceAnalysis
///
/// \brief Locate join blocks for disjoint paths starting at a divergent branch.
///
/// An analysis per divergent branch that returns the set of basic
/// blocks whose phi nodes become divergent due to divergent control.
/// These are the blocks that are reachable by two disjoint paths from
/// the branch, or cycle exits reachable along a path that is disjoint
/// from a path to the cycle latch.

// (The implementation notes below are intentionally not part of the doxygen
// comment above.)
//
// Originally implemented in SyncDependenceAnalysis.cpp for DivergenceAnalysis.
//
// The SyncDependenceAnalysis is used in the UniformityAnalysis to model
// control-induced divergence in phi nodes.
//
// -- Reference --
// The algorithm is an extension of Section 5 of
//
//   An abstract interpretation for SPMD divergence
//   on reducible control flow graphs.
//   Julian Rosemann, Simon Moll and Sebastian Hack
//   POPL '21
//
//
// -- Sync dependence --
// Sync dependence characterizes the control flow aspect of the
// propagation of branch divergence.
// For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
//
// -- Reduction to SSA construction --
// There are two disjoint paths from A to X, if a certain variant of SSA
// construction places a phi node in X under the following set-up scheme.
//
// This variant of SSA construction ignores incoming undef values.
// That is, paths from the entry without a definition do not result in
// phi nodes.
//
//          entry
//         /      \
//        A        \
//       / \        Y
//      B   C      /
//       \ / \    /
//        D   E
//         \ /
//          F
//
// Assume that A contains a divergent branch. We are interested
// in the set of all blocks where each block is reachable from A
// via two disjoint paths. This would be the set {D, F} in this
// case.
// To generally reduce this query to SSA construction we introduce
// a virtual variable x and assign to x different values in each
// successor block of A.
//
//            entry
//           /      \
//          A        \
//         / \        Y
//    x = 0   x = 1  /
//         \ / \    /
//          D   E
//           \ /
//            F
//
// Our flavor of SSA construction for x will construct the following
//
//             entry
//            /      \
//           A        \
//          / \        Y
//    x0 = 0   x1 = 1 /
//          \ / \    /
//      x2 = phi  E
//            \  /
//          x3 = phi
//
// The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
//
// -- Remarks --
// * In case of cycle exits we need to check for temporal divergence.
//   To this end, we check whether the definition of x differs between the
//   cycle exit and the cycle header (_after_ SSA construction).
//
// * In the presence of irreducible control flow, the fixed point is
//   reached only after multiple iterations. This is because labels
//   reaching the header of a cycle must be repropagated through the
//   cycle. This is true even in a reducible cycle, since the labels
//   may have been produced by a nested irreducible cycle.
//
// * Note that SyncDependenceAnalysis is not concerned with the points
//   of convergence in an irreducible cycle. Its only purpose is to
//   identify join blocks. The "diverged entry" criterion is
//   separately applied on join blocks to determine if an entire
//   irreducible cycle is assumed to be divergent.
//
// * Relevant related work:
//     A simple algorithm for global data flow analysis problems.
//     Matthew S. Hecht and Jeffrey D. Ullman.
//     SIAM Journal on Computing, 4(4):519–532, December 1975.
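//
// * A small sketch of temporal divergence at a cycle exit (illustrative only;
//   assuming %tid holds the thread ID):
//
//   loop:
//     %i = phi i32 [ 0, %entry ], [ %inc, %loop ]
//     %inc = add i32 %i, 1
//     %exit.cond = icmp eq i32 %inc, %tid
//     br i1 %exit.cond, label %exit, label %loop
//   exit:
//     %use = add i32 %inc, 1
//
//   Inside the loop, %inc is uniform in every iteration, but threads leave
//   the loop in different iterations, so the value of %inc observed in %exit
//   differs per thread. The cycle exit is divergent and %use must be marked
//   divergent.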
//
template <typename ContextT> class GenericSyncDependenceAnalysis {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;
  using InstructionT = typename ContextT::InstructionT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ConstBlockSet = SmallPtrSet<const BlockT *, 4>;
  using ModifiedPO = ModifiedPostOrder<ContextT>;

  // * if BlockLabels[B] == C then C is the dominating definition at
  //   block B
  // * if BlockLabels[B] == nullptr then we haven't seen B yet
  // * if BlockLabels[B] == B then:
  //   - B is a join point of disjoint paths from X, or,
  //   - B is an immediate successor of X (initial value), or,
  //   - B is X
  using BlockLabelMap = DenseMap<const BlockT *, const BlockT *>;

  /// Information discovered by the sync dependence analysis for each
  /// divergent branch.
  struct DivergenceDescriptor {
    // Join points of diverged paths.
    ConstBlockSet JoinDivBlocks;
    // Divergent cycle exits.
    ConstBlockSet CycleDivBlocks;
    // Labels assigned to blocks on diverged paths.
    BlockLabelMap BlockLabels;
  };

  using DivergencePropagatorT = DivergencePropagator<ContextT>;

  GenericSyncDependenceAnalysis(const ContextT &Context,
                                const DominatorTreeT &DT, const CycleInfoT &CI);

  /// \brief Computes divergent join points and cycle exits caused by branch
  /// divergence in \p Term.
  ///
  /// This returns a pair of sets:
  /// * The set of blocks which are reachable by disjoint paths from
  ///   \p Term.
  /// * The set also contains cycle exits if there are two disjoint paths:
  ///   one from \p Term to the cycle exit and another from \p Term to
  ///   the cycle header.
  const DivergenceDescriptor &getJoinBlocks(const BlockT *DivTermBlock);

private:
  static DivergenceDescriptor EmptyDivergenceDesc;

  ModifiedPO CyclePO;

  const DominatorTreeT &DT;
  const CycleInfoT &CI;

  DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
      CachedControlDivDescs;
};

/// \brief Analysis that identifies uniform values in a data-parallel
/// execution.
///
/// This analysis propagates divergence in a data-parallel context
/// from sources of divergence to all users. It can be instantiated
/// for an IR that provides a suitable SSAContext.
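///
/// A minimal usage sketch (assuming a concrete specialization such as the
/// LLVM IR one in UniformityAnalysis.cpp, whose initialize() supplies the
/// sources of divergence):
///
///   GenericUniformityAnalysisImpl<SSAContext> Impl(DT, CI, TTI);
///   Impl.initialize();   // seed divergence sources and uniform overrides
///   Impl.compute();      // fixed-point propagation
///   bool IsDiv = Impl.isDivergent(SomeValue);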
template <typename ContextT> class GenericUniformityAnalysisImpl {
public:
  using BlockT = typename ContextT::BlockT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;
  using ConstValueRefT = typename ContextT::ConstValueRefT;
  using UseT = typename ContextT::UseT;
  using InstructionT = typename ContextT::InstructionT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;

  GenericUniformityAnalysisImpl(const DominatorTreeT &DT, const CycleInfoT &CI,
                                const TargetTransformInfo *TTI)
      : Context(CI.getSSAContext()), F(*Context.getFunction()), CI(CI),
        TTI(TTI), DT(DT), SDA(Context, DT, CI) {}

  void initialize();

  const FunctionT &getFunction() const { return F; }

  /// \brief Mark \p Instr as a value that is always uniform.
  void addUniformOverride(const InstructionT &Instr);

  /// \brief Examine \p I for divergent outputs and add to the worklist.
  void markDivergent(const InstructionT &I);

  /// \brief Mark \p DivVal as a divergent value.
  /// \returns Whether the tracked divergence state of \p DivVal changed.
  bool markDivergent(ConstValueRefT DivVal);

  /// \brief Mark outputs of \p Instr as divergent.
  /// \returns Whether the tracked divergence state of any output has changed.
  bool markDefsDivergent(const InstructionT &Instr);

  /// \brief Propagate divergence to all instructions in the region.
  /// Divergence is seeded by calls to \p markDivergent.
  void compute();

  /// \brief Whether any value was marked or analyzed to be divergent.
  bool hasDivergence() const { return !DivergentValues.empty(); }

  /// \brief Whether \p Instr will always return a uniform value regardless
  /// of its operands.
  bool isAlwaysUniform(const InstructionT &Instr) const;

  bool hasDivergentDefs(const InstructionT &I) const;

  bool isDivergent(const InstructionT &I) const {
    if (I.isTerminator()) {
      return DivergentTermBlocks.contains(I.getParent());
    }
    return hasDivergentDefs(I);
  };

  /// \brief Whether \p V is divergent at its definition.
  bool isDivergent(ConstValueRefT V) const { return DivergentValues.count(V); }

  bool isDivergentUse(const UseT &U) const;

  bool hasDivergentTerminator(const BlockT &B) const {
    return DivergentTermBlocks.contains(&B);
  }

  void print(raw_ostream &out) const;

protected:
  /// \brief Value/block pair representing a single phi input.
  struct PhiInput {
    ConstValueRefT value;
    BlockT *predBlock;

    PhiInput(ConstValueRefT value, BlockT *predBlock)
        : value(value), predBlock(predBlock) {}
  };

  const ContextT &Context;
  const FunctionT &F;
  const CycleInfoT &CI;
  const TargetTransformInfo *TTI = nullptr;

  // Detected/marked divergent values.
  DenseSet<ConstValueRefT> DivergentValues;
  SmallPtrSet<const BlockT *, 32> DivergentTermBlocks;

  // Internal worklist for divergence propagation.
  std::vector<const InstructionT *> Worklist;

  /// \brief Mark \p Term as divergent and push all Instructions that become
  /// divergent as a result on the worklist.
  void analyzeControlDivergence(const InstructionT &Term);

private:
  const DominatorTreeT &DT;

  // Recognized cycles with divergent exits.
  SmallPtrSet<const CycleT *, 16> DivergentExitCycles;

  // Cycles assumed to be divergent.
  //
  // We don't use a set here because every insertion needs an explicit
  // traversal of all existing members.
  SmallVector<const CycleT *> AssumedDivergent;

  // The SDA links divergent branches to divergent control-flow joins.
  SyncDependenceAnalysisT SDA;

  // Set of known-uniform values.
  SmallPtrSet<const InstructionT *, 32> UniformOverrides;

  /// \brief Mark all nodes in \p JoinBlock as divergent and push them on
  /// the worklist.
  void taintAndPushAllDefs(const BlockT &JoinBlock);

  /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on
  /// the worklist.
  void taintAndPushPhiNodes(const BlockT &JoinBlock);

  /// \brief Identify all Instructions that become divergent because \p DivExit
  /// is a divergent cycle exit of \p DivCycle. Mark those instructions as
  /// divergent and push them on the worklist.
  void propagateCycleExitDivergence(const BlockT &DivExit,
                                    const CycleT &DivCycle);

  /// Mark as divergent all external uses of values defined in \p DefCycle.
  void analyzeCycleExitDivergence(const CycleT &DefCycle);

  /// \brief Mark as divergent all uses of \p I that are outside \p DefCycle.
  void propagateTemporalDivergence(const InstructionT &I,
                                   const CycleT &DefCycle);

  /// \brief Push all users of \p Val (in the region) to the worklist.
  void pushUsers(const InstructionT &I);
  void pushUsers(ConstValueRefT V);

  bool usesValueFromCycle(const InstructionT &I, const CycleT &DefCycle) const;

  /// \brief Whether \p Def is divergent when read in \p ObservingBlock.
  bool isTemporalDivergent(const BlockT &ObservingBlock,
                           const InstructionT &Def) const;
};

template <typename ImplT>
void GenericUniformityAnalysisImplDeleter<ImplT>::operator()(ImplT *Impl) {
  delete Impl;
}

/// Compute divergence starting with a divergent branch.
template <typename ContextT> class DivergencePropagator {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ModifiedPO = ModifiedPostOrder<ContextT>;
  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;

  const ModifiedPO &CyclePOT;
  const DominatorTreeT &DT;
  const CycleInfoT &CI;
  const BlockT &DivTermBlock;
  const ContextT &Context;

  // Track blocks that receive a new label. Every time we relabel a
  // cycle header, we make another pass over the modified post-order in
  // order to propagate the header label. The bit vector also allows
  // us to skip labels that have not changed.
  SparseBitVector<> FreshLabels;

  // Divergent join and cycle exit descriptor.
  std::unique_ptr<DivergenceDescriptorT> DivDesc;
  BlockLabelMapT &BlockLabels;

  DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
                       const CycleInfoT &CI, const BlockT &DivTermBlock)
      : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
        Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
        BlockLabels(DivDesc->BlockLabels) {}

  void printDefs(raw_ostream &Out) {
    Out << "Propagator::BlockLabels {\n";
    for (int BlockIdx = (int)CyclePOT.size() - 1; BlockIdx >= 0; --BlockIdx) {
      const auto *Block = CyclePOT[BlockIdx];
      const auto *Label = BlockLabels[Block];
      Out << Context.print(Block) << "(" << BlockIdx << ") : ";
      if (!Label) {
        Out << "<null>\n";
      } else {
        Out << Context.print(Label) << "\n";
      }
    }
    Out << "}\n";
  }

  // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
  // causes a divergent join.
  bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
    const auto *OldLabel = BlockLabels[&SuccBlock];

    LLVM_DEBUG(dbgs() << "labeling " << Context.print(&SuccBlock) << ":\n"
                      << "\tpushed label: " << Context.print(&PushedLabel)
                      << "\n"
                      << "\told label: " << Context.print(OldLabel) << "\n");

    // Early exit if there is no change in the label.
    if (OldLabel == &PushedLabel)
      return false;

    if (OldLabel != &SuccBlock) {
      auto SuccIdx = CyclePOT.getIndex(&SuccBlock);
      // Assigning a new label, mark this in FreshLabels.
      LLVM_DEBUG(dbgs() << "\tfresh label: " << SuccIdx << "\n");
      FreshLabels.set(SuccIdx);
    }

    // This is not a join if the succ was previously unlabeled.
    if (!OldLabel) {
      LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&PushedLabel)
                        << "\n");
      BlockLabels[&SuccBlock] = &PushedLabel;
      return false;
    }

    // This is a new join. Label the join block as itself, and not as
    // the pushed label.
    LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&SuccBlock) << "\n");
    BlockLabels[&SuccBlock] = &SuccBlock;

    return true;
  }

  // Visiting a virtual cycle exit edge from the cycle header implies
  // temporal divergence at the join.
  bool visitCycleExitEdge(const BlockT &ExitBlock, const BlockT &Label) {
    if (!computeJoin(ExitBlock, Label))
      return false;

    // Identified a divergent cycle exit
    DivDesc->CycleDivBlocks.insert(&ExitBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(&ExitBlock)
                      << "\n");
    return true;
  }

  // Process \p SuccBlock with reaching definition \p Label.
  bool visitEdge(const BlockT &SuccBlock, const BlockT &Label) {
    if (!computeJoin(SuccBlock, Label))
      return false;

    // Divergent, disjoint paths join.
    DivDesc->JoinDivBlocks.insert(&SuccBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent join: " << Context.print(&SuccBlock)
                      << "\n");
    return true;
  }

  std::unique_ptr<DivergenceDescriptorT> computeJoinPoints() {
    assert(DivDesc);

    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
                      << Context.print(&DivTermBlock) << "\n");

    // Early stopping criterion
    int FloorIdx = CyclePOT.size() - 1;
    const BlockT *FloorLabel = nullptr;
    int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);

    // Bootstrap with branch targets
    auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
    for (const auto *SuccBlock : successors(&DivTermBlock)) {
      if (DivTermCycle && !DivTermCycle->contains(SuccBlock)) {
        // If DivTerm exits the cycle immediately, computeJoin() might
        // not reach SuccBlock with a different label. We need to
        // check for this exit now.
        DivDesc->CycleDivBlocks.insert(SuccBlock);
        LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
                          << Context.print(SuccBlock) << "\n");
      }
      auto SuccIdx = CyclePOT.getIndex(SuccBlock);
      visitEdge(*SuccBlock, *SuccBlock);
      FloorIdx = std::min<int>(FloorIdx, SuccIdx);
    }

    while (true) {
      auto BlockIdx = FreshLabels.find_last();
      if (BlockIdx == -1 || BlockIdx < FloorIdx)
        break;

      LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));

      FreshLabels.reset(BlockIdx);
      if (BlockIdx == DivTermIdx) {
        LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n");
        continue;
      }

      const auto *Block = CyclePOT[BlockIdx];
      LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
                        << BlockIdx << "\n");

      const auto *Label = BlockLabels[Block];
      assert(Label);

      bool CausedJoin = false;
      int LoweredFloorIdx = FloorIdx;

      // If the current block is the header of a reducible cycle that
      // contains the divergent branch, then the label should be
      // propagated to the cycle exits. Such a header is the "last
      // possible join" of any disjoint paths within this cycle. This
      // prevents detection of spurious joins at the entries of any
      // irreducible child cycles.
      //
      // This conclusion about the header is true for any choice of DFS:
      //
      //   If some DFS has a reducible cycle C with header H, then for
      //   any other DFS, H is the header of a cycle C' that is a
      //   superset of C. For a divergent branch inside the subgraph
      //   C, any join node inside C is either H, or some node
      //   encountered without passing through H.
      //
      auto getReducibleParent = [&](const BlockT *Block) -> const CycleT * {
        if (!CyclePOT.isReducibleCycleHeader(Block))
          return nullptr;
        const auto *BlockCycle = CI.getCycle(Block);
        if (BlockCycle->contains(&DivTermBlock))
          return BlockCycle;
        return nullptr;
      };

      if (const auto *BlockCycle = getReducibleParent(Block)) {
        SmallVector<BlockT *, 4> BlockCycleExits;
        BlockCycle->getExitBlocks(BlockCycleExits);
        for (auto *BlockCycleExit : BlockCycleExits) {
          CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
        }
      } else {
        for (const auto *SuccBlock : successors(Block)) {
          CausedJoin |= visitEdge(*SuccBlock, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
        }
      }

      // Floor update
      if (CausedJoin) {
        // 1. Different labels pushed to successors
        FloorIdx = LoweredFloorIdx;
      } else if (FloorLabel != Label) {
        // 2. No join caused BUT we pushed a label that is different from the
        //    last pushed label
        FloorIdx = LoweredFloorIdx;
        FloorLabel = Label;
      }
    }

    LLVM_DEBUG(dbgs() << "Final labeling:\n"; printDefs(dbgs()));

    // Check every cycle containing DivTermBlock for exit divergence.
    // A cycle has exit divergence if the label of an exit block does
    // not match the label of its header.
    for (const auto *Cycle = CI.getCycle(&DivTermBlock); Cycle;
         Cycle = Cycle->getParentCycle()) {
      if (Cycle->isReducible()) {
        // The exit divergence of a reducible cycle is recorded while
        // propagating labels.
        continue;
      }
      SmallVector<BlockT *> Exits;
      Cycle->getExitBlocks(Exits);
      auto *Header = Cycle->getHeader();
      auto *HeaderLabel = BlockLabels[Header];
      for (const auto *Exit : Exits) {
        if (BlockLabels[Exit] != HeaderLabel) {
          // Identified a divergent cycle exit
          DivDesc->CycleDivBlocks.insert(Exit);
          LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(Exit)
                            << "\n");
        }
      }
    }

    return std::move(DivDesc);
  }
};

template <typename ContextT>
typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
    llvm::GenericSyncDependenceAnalysis<ContextT>::EmptyDivergenceDesc;

template <typename ContextT>
llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
    const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
    : CyclePO(Context), DT(DT), CI(CI) {
  CyclePO.compute(CI);
}

template <typename ContextT>
auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
    const BlockT *DivTermBlock) -> const DivergenceDescriptor & {
  // Trivial case: a block with at most one successor cannot start
  // disjoint paths.
  if (succ_size(DivTermBlock) <= 1) {
    return EmptyDivergenceDesc;
  }

  // Already available in the cache?
  auto ItCached = CachedControlDivDescs.find(DivTermBlock);
  if (ItCached != CachedControlDivDescs.end())
    return *ItCached->second;

  // Compute all join points.
  DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock);
  auto DivDesc = Propagator.computeJoinPoints();

  auto printBlockSet = [&](ConstBlockSet &Blocks) {
    return Printable([&](raw_ostream &Out) {
      Out << "[";
      ListSeparator LS;
      for (const auto *BB : Blocks) {
        Out << LS << CI.getSSAContext().print(BB);
      }
      Out << "]\n";
    });
  };

  LLVM_DEBUG(
      dbgs() << "\nResult (" << CI.getSSAContext().print(DivTermBlock)
             << "):\n JoinDivBlocks: " << printBlockSet(DivDesc->JoinDivBlocks)
             << " CycleDivBlocks: " << printBlockSet(DivDesc->CycleDivBlocks)
             << "\n");
  (void)printBlockSet;

  auto ItInserted =
      CachedControlDivDescs.try_emplace(DivTermBlock, std::move(DivDesc));
  assert(ItInserted.second);
  return *ItInserted.first->second;
}

template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::markDivergent(
    const InstructionT &I) {
  if (isAlwaysUniform(I))
    return;
  bool Marked = false;
  if (I.isTerminator()) {
    Marked = DivergentTermBlocks.insert(I.getParent()).second;
    if (Marked) {
      LLVM_DEBUG(dbgs() << "marked divergent term block: "
                        << Context.print(I.getParent()) << "\n");
    }
  } else {
    Marked = markDefsDivergent(I);
  }

  if (Marked)
    Worklist.push_back(&I);
}

template <typename ContextT>
bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
    ConstValueRefT Val) {
  if (DivergentValues.insert(Val).second) {
    LLVM_DEBUG(dbgs() << "marked divergent: " << Context.print(Val) << "\n");
    return true;
  }
  return false;
}

template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride(
    const InstructionT &Instr) {
  UniformOverrides.insert(&Instr);
}

// Mark as divergent all external uses of values defined in \p DefCycle.
//
// A value V defined by a block B inside \p DefCycle may be used outside the
// cycle only if the use is a PHI in some exit block, or B dominates some exit
// block. Thus, we check uses as follows:
//
// - Check all PHIs in all exit blocks for inputs defined inside \p DefCycle.
// - For every block B inside \p DefCycle that dominates at least one exit
//   block, check all uses outside \p DefCycle.
//
// FIXME: This function does not distinguish between divergent and uniform
// exits. For each divergent exit, only the values that are live at that exit
// need to be propagated as divergent at their use outside the cycle.
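//
// A minimal sketch of the dominating-block case above (illustrative only;
// assuming a single-block cycle with one exit):
//
//   header:                          ; inside DefCycle
//     %v = ...
//     br i1 %c, label %exit, label %header
//   exit:                            ; outside DefCycle
//     %u = add i32 %v, 1             ; external use; header dominates exit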
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::analyzeCycleExitDivergence(
    const CycleT &DefCycle) {
  SmallVector<BlockT *> Exits;
  DefCycle.getExitBlocks(Exits);
  for (auto *Exit : Exits) {
    for (auto &Phi : Exit->phis()) {
      if (usesValueFromCycle(Phi, DefCycle)) {
        markDivergent(Phi);
      }
    }
  }

  for (auto *BB : DefCycle.blocks()) {
    if (!llvm::any_of(Exits,
                      [&](BlockT *Exit) { return DT.dominates(BB, Exit); }))
      continue;
    for (auto &II : *BB) {
      propagateTemporalDivergence(II, DefCycle);
    }
  }
}

template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::propagateCycleExitDivergence(
    const BlockT &DivExit, const CycleT &InnerDivCycle) {
  LLVM_DEBUG(dbgs() << "\tpropCycleExitDiv " << Context.print(&DivExit)
                    << "\n");
  auto *DivCycle = &InnerDivCycle;
  auto *OuterDivCycle = DivCycle;
  auto *ExitLevelCycle = CI.getCycle(&DivExit);
  const unsigned CycleExitDepth =
      ExitLevelCycle ? ExitLevelCycle->getDepth() : 0;

  // Find outer-most cycle that does not contain \p DivExit
  while (DivCycle && DivCycle->getDepth() > CycleExitDepth) {
    LLVM_DEBUG(dbgs() << "  Found exiting cycle: "
                      << Context.print(DivCycle->getHeader()) << "\n");
    OuterDivCycle = DivCycle;
    DivCycle = DivCycle->getParentCycle();
  }
  LLVM_DEBUG(dbgs() << "\tOuter-most exiting cycle: "
                    << Context.print(OuterDivCycle->getHeader()) << "\n");

  if (!DivergentExitCycles.insert(OuterDivCycle).second)
    return;

  // Exit divergence does not matter if the cycle itself is assumed to
  // be divergent.
  for (const auto *C : AssumedDivergent) {
    if (C->contains(OuterDivCycle))
      return;
  }

  analyzeCycleExitDivergence(*OuterDivCycle);
}

template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::taintAndPushAllDefs(
    const BlockT &BB) {
  LLVM_DEBUG(dbgs() << "taintAndPushAllDefs " << Context.print(&BB) << "\n");
  for (const auto &I : instrs(BB)) {
    // Terminators do not produce values; they are divergent only if
    // the condition is divergent. That is handled when the divergent
    // condition is placed in the worklist.
    if (I.isTerminator())
      break;

    markDivergent(I);
  }
}

/// Mark divergent phi nodes in a join block
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::taintAndPushPhiNodes(
    const BlockT &JoinBlock) {
  LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << Context.print(&JoinBlock)
                    << "\n");
  for (const auto &Phi : JoinBlock.phis()) {
    // FIXME: The non-undef value is not constant per se; it just happens to be
    // uniform and may not dominate this PHI. So assuming that the same value
    // reaches along all incoming edges may itself be undefined behaviour. This
    // particular interpretation of the undef value was added to
    // DivergenceAnalysis in the following review:
    //
    // https://reviews.llvm.org/D19013
    if (ContextT::isConstantOrUndefValuePhi(Phi))
      continue;
    markDivergent(Phi);
  }
}

/// Add \p Candidate to \p Cycles if it is not already contained in \p Cycles.
///
/// \return true iff \p Candidate was added to \p Cycles.
template <typename CycleT>
static bool insertIfNotContained(SmallVector<CycleT *> &Cycles,
                                 CycleT *Candidate) {
  if (llvm::any_of(Cycles,
                   [Candidate](CycleT *C) { return C->contains(Candidate); }))
    return false;
  Cycles.push_back(Candidate);
  return true;
}

/// Return the outermost cycle made divergent by a branch outside it.
///
/// If two paths that diverged outside an irreducible cycle join
/// inside that cycle, then that whole cycle is assumed to be
/// divergent. This does not apply if the cycle is reducible.
template <typename CycleT, typename BlockT>
static const CycleT *getExtDivCycle(const CycleT *Cycle,
                                    const BlockT *DivTermBlock,
                                    const BlockT *JoinBlock) {
  assert(Cycle);
  assert(Cycle->contains(JoinBlock));

  if (Cycle->contains(DivTermBlock))
    return nullptr;

  const auto *OriginalCycle = Cycle;
  const auto *Parent = Cycle->getParentCycle();
  while (Parent && !Parent->contains(DivTermBlock)) {
    Cycle = Parent;
    Parent = Cycle->getParentCycle();
  }

  // If the original cycle is not the outermost cycle, then the outermost cycle
  // is irreducible. If the outermost cycle were reducible, then external
  // diverged paths would not reach the original inner cycle.
  (void)OriginalCycle;
  assert(Cycle == OriginalCycle || !Cycle->isReducible());

  if (Cycle->isReducible()) {
    assert(Cycle->getHeader() == JoinBlock);
    return nullptr;
  }

  LLVM_DEBUG(dbgs() << "cycle made divergent by external branch\n");
  return Cycle;
}

/// Return the outermost cycle made divergent by a branch inside it.
///
/// This checks the "diverged entry" criterion defined in
/// docs/ConvergenceAnalysis.html.
template <typename ContextT, typename CycleT, typename BlockT,
          typename DominatorTreeT>
static const CycleT *
getIntDivCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
               const BlockT *JoinBlock, const DominatorTreeT &DT,
               ContextT &Context) {
  LLVM_DEBUG(dbgs() << "examine join " << Context.print(JoinBlock)
                    << " for internal branch " << Context.print(DivTermBlock)
                    << "\n");
  if (DT.properlyDominates(DivTermBlock, JoinBlock))
    return nullptr;

  // Find the smallest common cycle, if one exists.
  assert(Cycle && Cycle->contains(JoinBlock));
  while (Cycle && !Cycle->contains(DivTermBlock)) {
    Cycle = Cycle->getParentCycle();
  }
  if (!Cycle || Cycle->isReducible())
    return nullptr;

  if (DT.properlyDominates(Cycle->getHeader(), JoinBlock))
    return nullptr;

  LLVM_DEBUG(dbgs() << " header " << Context.print(Cycle->getHeader())
                    << " does not dominate join\n");

  const auto *Parent = Cycle->getParentCycle();
  while (Parent && !DT.properlyDominates(Parent->getHeader(), JoinBlock)) {
    LLVM_DEBUG(dbgs() << " header " << Context.print(Parent->getHeader())
                      << " does not dominate join\n");
    Cycle = Parent;
    Parent = Parent->getParentCycle();
  }

  LLVM_DEBUG(dbgs() << " cycle made divergent by internal branch\n");
  return Cycle;
}

template <typename ContextT, typename CycleT, typename BlockT,
          typename DominatorTreeT>
static const CycleT *
getOutermostDivergentCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
                           const BlockT *JoinBlock, const DominatorTreeT &DT,
                           ContextT &Context) {
  if (!Cycle)
    return nullptr;

  // First try to expand Cycle to the largest that contains JoinBlock
  // but not DivTermBlock.
  const auto *Ext = getExtDivCycle(Cycle, DivTermBlock, JoinBlock);

  // Continue expanding to the largest cycle that contains both.
  const auto *Int = getIntDivCycle(Cycle, DivTermBlock, JoinBlock, DT, Context);

  if (Int)
    return Int;
  return Ext;
}

template <typename ContextT>
bool GenericUniformityAnalysisImpl<ContextT>::isTemporalDivergent(
    const BlockT &ObservingBlock, const InstructionT &Def) const {
  const BlockT *DefBlock = Def.getParent();
  for (const CycleT *Cycle = CI.getCycle(DefBlock);
       Cycle && !Cycle->contains(&ObservingBlock);
       Cycle = Cycle->getParentCycle()) {
    if (DivergentExitCycles.contains(Cycle)) {
      return true;
    }
  }
  return false;
}

template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::analyzeControlDivergence(
    const InstructionT &Term) {
  const auto *DivTermBlock = Term.getParent();
  DivergentTermBlocks.insert(DivTermBlock);
  LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Context.print(DivTermBlock)
                    << "\n");

  // Don't propagate divergence from unreachable blocks.
  if (!DT.isReachableFromEntry(DivTermBlock))
    return;

  const auto &DivDesc = SDA.getJoinBlocks(DivTermBlock);
  SmallVector<const CycleT *> DivCycles;

  // Iterate over all blocks now reachable by a disjoint path join
  for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
    const auto *Cycle = CI.getCycle(JoinBlock);
    LLVM_DEBUG(dbgs() << "visiting join block " << Context.print(JoinBlock)
                      << "\n");
    if (const auto *Outermost = getOutermostDivergentCycle(
            Cycle, DivTermBlock, JoinBlock, DT, Context)) {
      LLVM_DEBUG(dbgs() << "found divergent cycle\n");
      DivCycles.push_back(Outermost);
      continue;
    }
    taintAndPushPhiNodes(*JoinBlock);
  }

  // Sort by order of decreasing depth. This allows later cycles to be skipped
  // because they are already contained in earlier ones.
  llvm::sort(DivCycles, [](const CycleT *A, const CycleT *B) {
    return A->getDepth() > B->getDepth();
  });

  // Cycles that are assumed divergent due to the diverged entry
  // criterion potentially contain temporal divergence depending on
  // the DFS chosen. Conservatively, all values produced in such a
  // cycle are assumed divergent. "Cycle invariant" values may be
  // assumed uniform, but that requires further analysis.
  for (auto *C : DivCycles) {
    if (!insertIfNotContained(AssumedDivergent, C))
      continue;
    LLVM_DEBUG(dbgs() << "process divergent cycle\n");
    for (const BlockT *BB : C->blocks()) {
      taintAndPushAllDefs(*BB);
    }
  }

  const auto *BranchCycle = CI.getCycle(DivTermBlock);
  assert(DivDesc.CycleDivBlocks.empty() || BranchCycle);
  for (const auto *DivExitBlock : DivDesc.CycleDivBlocks) {
    propagateCycleExitDivergence(*DivExitBlock, *BranchCycle);
  }
}

template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::compute() {
  // Initialize worklist.
  auto DivValuesCopy = DivergentValues;
  for (const auto DivVal : DivValuesCopy) {
    assert(isDivergent(DivVal) && "Worklist invariant violated!");
    pushUsers(DivVal);
  }

  // All values on the Worklist are divergent.
  // Their users may not have been updated yet.
  while (!Worklist.empty()) {
    const InstructionT *I = Worklist.back();
    Worklist.pop_back();

    LLVM_DEBUG(dbgs() << "worklist pop: " << Context.print(I) << "\n");

    if (I->isTerminator()) {
      analyzeControlDivergence(*I);
      continue;
    }

    // Propagate value divergence to users.
    assert(isDivergent(*I) && "Worklist invariant violated!");
    pushUsers(*I);
  }
}

template <typename ContextT>
bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
    const InstructionT &Instr) const {
  return UniformOverrides.contains(&Instr);
}

template <typename ContextT>
GenericUniformityInfo<ContextT>::GenericUniformityInfo(
    const DominatorTreeT &DT, const CycleInfoT &CI,
    const TargetTransformInfo *TTI) {
  DA.reset(new ImplT{DT, CI, TTI});
}

template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
  bool haveDivergentArgs = false;

  // Control flow instructions may be divergent even if their inputs are
  // uniform. Thus, although exceedingly rare, it is possible to have a program
  // with no divergent values but with divergent control structures.
  if (DivergentValues.empty() && DivergentTermBlocks.empty() &&
      DivergentExitCycles.empty()) {
    OS << "ALL VALUES UNIFORM\n";
    return;
  }

  for (const auto &entry : DivergentValues) {
    const BlockT *parent = Context.getDefBlock(entry);
    if (!parent) {
      if (!haveDivergentArgs) {
        OS << "DIVERGENT ARGUMENTS:\n";
        haveDivergentArgs = true;
      }
      OS << " DIVERGENT: " << Context.print(entry) << '\n';
    }
  }

  if (!AssumedDivergent.empty()) {
    OS << "CYCLES ASSUMED DIVERGENT:\n";
    for (const CycleT *cycle : AssumedDivergent) {
      OS << " " << cycle->print(Context) << '\n';
    }
  }

  if (!DivergentExitCycles.empty()) {
    OS << "CYCLES WITH DIVERGENT EXIT:\n";
    for (const CycleT *cycle : DivergentExitCycles) {
      OS << " " << cycle->print(Context) << '\n';
    }
  }

  for (auto &block : F) {
    OS << "\nBLOCK " << Context.print(&block) << '\n';

    OS << "DEFINITIONS\n";
    SmallVector<ConstValueRefT, 16> defs;
    Context.appendBlockDefs(defs, block);
    for (auto value : defs) {
      if (isDivergent(value))
        OS << " DIVERGENT: ";
      else
        OS << "            ";
      OS << Context.print(value) << '\n';
    }

    OS << "TERMINATORS\n";
    SmallVector<const InstructionT *, 8> terms;
    Context.appendBlockTerms(terms, block);
    bool divergentTerminators = hasDivergentTerminator(block);
    for (auto *T : terms) {
      if (divergentTerminators)
        OS << " DIVERGENT: ";
      else
        OS << "            ";
      OS << Context.print(T) << '\n';
    }

    OS << "END BLOCK\n";
  }
}

template <typename ContextT>
bool GenericUniformityInfo<ContextT>::hasDivergence() const {
  return DA->hasDivergence();
}

template <typename ContextT>
const typename ContextT::FunctionT &
GenericUniformityInfo<ContextT>::getFunction() const {
  return DA->getFunction();
}

/// Whether \p V is divergent at its definition.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::isDivergent(ConstValueRefT V) const {
  return DA->isDivergent(V);
}

template <typename ContextT>
bool GenericUniformityInfo<ContextT>::isDivergent(const InstructionT *I) const {
  return DA->isDivergent(*I);
}

template <typename ContextT>
bool GenericUniformityInfo<ContextT>::isDivergentUse(const UseT &U) const {
  return DA->isDivergentUse(U);
}

template <typename ContextT>
bool GenericUniformityInfo<ContextT>::hasDivergentTerminator(const BlockT &B) {
  return DA->hasDivergentTerminator(B);
}

/// \brief Template helper function for printing.
template <typename ContextT>
void GenericUniformityInfo<ContextT>::print(raw_ostream &out) const {
  DA->print(out);
}

template <typename ContextT>
void llvm::ModifiedPostOrder<ContextT>::computeStackPO(
    SmallVectorImpl<const BlockT *> &Stack, const CycleInfoT &CI,
    const CycleT *Cycle, SmallPtrSetImpl<const BlockT *> &Finalized) {
  LLVM_DEBUG(dbgs() << "inside computeStackPO\n");
  while (!Stack.empty()) {
    auto *NextBB = Stack.back();
    if (Finalized.count(NextBB)) {
      Stack.pop_back();
      continue;
    }
    LLVM_DEBUG(dbgs() << "  visiting " << CI.getSSAContext().print(NextBB)
                      << "\n");
    auto *NestedCycle = CI.getCycle(NextBB);
    if (Cycle != NestedCycle && (!Cycle || Cycle->contains(NestedCycle))) {
      LLVM_DEBUG(dbgs() << "  found a cycle\n");
      while (NestedCycle->getParentCycle() != Cycle)
        NestedCycle = NestedCycle->getParentCycle();

      SmallVector<BlockT *, 3> NestedExits;
      NestedCycle->getExitBlocks(NestedExits);
      bool PushedNodes = false;
      for (auto *NestedExitBB : NestedExits) {
        LLVM_DEBUG(dbgs() << "  examine exit: "
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
        if (Cycle && !Cycle->contains(NestedExitBB))
          continue;
        if (Finalized.count(NestedExitBB))
          continue;
        PushedNodes = true;
        Stack.push_back(NestedExitBB);
        LLVM_DEBUG(dbgs() << "  pushed exit: "
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
      }
      if (!PushedNodes) {
        // All loop exits finalized -> finish this node
        Stack.pop_back();
        computeCyclePO(CI, NestedCycle, Finalized);
      }
      continue;
    }

    LLVM_DEBUG(dbgs() << "  no nested cycle, going into DAG\n");
    // DAG-style
    bool PushedNodes = false;
    for (auto *SuccBB : successors(NextBB)) {
      LLVM_DEBUG(dbgs() << "  examine succ: "
                        << CI.getSSAContext().print(SuccBB) << "\n");
      if (Cycle && !Cycle->contains(SuccBB))
        continue;
      if (Finalized.count(SuccBB))
        continue;
      PushedNodes = true;
      Stack.push_back(SuccBB);
      LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(SuccBB)
                        << "\n");
    }
    if (!PushedNodes) {
      // Never push nodes twice
      LLVM_DEBUG(dbgs() << "  finishing node: "
                        << CI.getSSAContext().print(NextBB) << "\n");
      Stack.pop_back();
      Finalized.insert(NextBB);
      appendBlock(*NextBB);
    }
  }
  LLVM_DEBUG(dbgs() << "exited computeStackPO\n");
}

template <typename ContextT>
void ModifiedPostOrder<ContextT>::computeCyclePO(
    const CycleInfoT &CI, const CycleT *Cycle,
    SmallPtrSetImpl<const BlockT *> &Finalized) {
  LLVM_DEBUG(dbgs() << "inside computeCyclePO\n");
  SmallVector<const BlockT *> Stack;
  auto *CycleHeader = Cycle->getHeader();

  LLVM_DEBUG(dbgs() << "  noted header: "
                    << CI.getSSAContext().print(CycleHeader) << "\n");
  assert(!Finalized.count(CycleHeader));
  Finalized.insert(CycleHeader);

  // Visit the header last
  LLVM_DEBUG(dbgs() << "  finishing header: "
                    << CI.getSSAContext().print(CycleHeader) << "\n");
  appendBlock(*CycleHeader, Cycle->isReducible());

  // Initialize with immediate successors
  for (auto *BB : successors(CycleHeader)) {
    LLVM_DEBUG(dbgs() << "  examine succ: " << CI.getSSAContext().print(BB)
                      << "\n");
    if (!Cycle->contains(BB))
      continue;
    if (BB == CycleHeader)
      continue;
    if (!Finalized.count(BB)) {
      LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(BB)
                        << "\n");
<< " pushed succ: " << CI.getSSAContext().print(BB) 1348 << "\n"); 1349 Stack.push_back(BB); 1350 } 1351 } 1352 1353 // Compute PO inside region 1354 computeStackPO(Stack, CI, Cycle, Finalized); 1355 1356 LLVM_DEBUG(dbgs() << "exited computeCyclePO\n"); 1357 } 1358 1359 /// \brief Generically compute the modified post order. 1360 template <typename ContextT> 1361 void llvm::ModifiedPostOrder<ContextT>::compute(const CycleInfoT &CI) { 1362 SmallPtrSet<const BlockT *, 32> Finalized; 1363 SmallVector<const BlockT *> Stack; 1364 auto *F = CI.getFunction(); 1365 Stack.reserve(24); // FIXME made-up number 1366 Stack.push_back(&F->front()); 1367 computeStackPO(Stack, CI, nullptr, Finalized); 1368 } 1369 1370 } // namespace llvm 1371 1372 #undef DEBUG_TYPE 1373 1374 #endif // LLVM_ADT_GENERICUNIFORMITYIMPL_H 1375