1 //===- GenericUniformAnalysis.cpp --------------------*- C++ -*------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This template implementation resides in a separate file so that it
10 // does not get injected into every .cpp file that includes the
11 // generic header.
12 //
13 // DO NOT INCLUDE THIS FILE WHEN MERELY USING UNIFORMITYINFO.
14 //
15 // This file should only be included by files that implement a
// specialization of the relevant templates. Currently these are:
17 // - UniformityAnalysis.cpp
18 //
19 // Note: The DEBUG_TYPE macro should be defined before using this
20 // file so that any use of LLVM_DEBUG is associated with the
21 // including file rather than this file.
22 //
23 //===----------------------------------------------------------------------===//
24 ///
25 /// \file
26 /// \brief Implementation of uniformity analysis.
27 ///
28 /// The algorithm is a fixed point iteration that starts with the assumption
29 /// that all control flow and all values are uniform. Starting from sources of
30 /// divergence (whose discovery must be implemented by a CFG- or even
31 /// target-specific derived class), divergence of values is propagated from
32 /// definition to uses in a straight-forward way. The main complexity lies in
33 /// the propagation of the impact of divergent control flow on the divergence of
34 /// values (sync dependencies).
35 ///
36 //===----------------------------------------------------------------------===//
37
38 #ifndef LLVM_ADT_GENERICUNIFORMITYIMPL_H
39 #define LLVM_ADT_GENERICUNIFORMITYIMPL_H
40
41 #include "llvm/ADT/GenericUniformityInfo.h"
42
43 #include "llvm/ADT/SmallPtrSet.h"
44 #include "llvm/ADT/SparseBitVector.h"
45 #include "llvm/ADT/StringExtras.h"
46 #include "llvm/Support/raw_ostream.h"
47
48 #include <set>
49
50 #define DEBUG_TYPE "uniformity"
51
52 using namespace llvm;
53
54 namespace llvm {
55
/// ADL-aware range wrapper around std::unique: shifts unique elements of
/// \p R to the front and returns the iterator past the last unique element.
template <typename Range> auto unique(Range &&R) {
  auto Begin = adl_begin(R);
  auto End = adl_end(R);
  return std::unique(Begin, End);
}
59
60 /// Construct a specially modified post-order traversal of cycles.
61 ///
/// The ModifiedPO is constructed using a virtually modified CFG as follows:
///
/// 1. The successors of pre-entry nodes (predecessors of a cycle
///    entry that are outside the cycle) are replaced by the
///    successors of the successors of the header.
/// 2. Successors of the cycle header are replaced by the exit blocks
///    of the cycle.
69 ///
70 /// Effectively, we produce a depth-first numbering with the following
71 /// properties:
72 ///
73 /// 1. Nodes after a cycle are numbered earlier than the cycle header.
74 /// 2. The header is numbered earlier than the nodes in the cycle.
75 /// 3. The numbering of the nodes within the cycle forms an interval
76 /// starting with the header.
77 ///
78 /// Effectively, the virtual modification arranges the nodes in a
79 /// cycle as a DAG with the header as the sole leaf, and successors of
80 /// the header as the roots. A reverse traversal of this numbering has
81 /// the following invariant on the unmodified original CFG:
82 ///
83 /// Each node is visited after all its predecessors, except if that
84 /// predecessor is the cycle header.
85 ///
template <typename ContextT> class ModifiedPostOrder {
public:
  using BlockT = typename ContextT::BlockT;
  using FunctionT = typename ContextT::FunctionT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;
  using const_iterator = typename std::vector<BlockT *>::const_iterator;

  ModifiedPostOrder(const ContextT &C) : Context(C) {}

  /// Whether the traversal has been computed (no blocks numbered yet).
  bool empty() const { return m_order.empty(); }
  /// Number of blocks recorded in the traversal.
  size_t size() const { return m_order.size(); }

  /// Discard the computed ordering.
  void clear() { m_order.clear(); }
  /// Compute the modified post-order for the CFG described by \p CI
  /// (defined out-of-line).
  void compute(const CycleInfoT &CI);

  /// Whether \p BB has been assigned an index (returns 0 or 1).
  unsigned count(BlockT *BB) const { return POIndex.count(BB); }
  const BlockT *operator[](size_t idx) const { return m_order[idx]; }

  /// Record \p BB as the next block of the traversal, remembering
  /// whether it is the header of a reducible cycle.
  void appendBlock(const BlockT &BB, bool isReducibleCycleHeader = false) {
    POIndex[&BB] = m_order.size();
    m_order.push_back(&BB);
    LLVM_DEBUG(dbgs() << "ModifiedPO(" << POIndex[&BB]
                      << "): " << Context.print(&BB) << "\n");
    if (isReducibleCycleHeader)
      ReducibleCycleHeaders.insert(&BB);
  }

  /// Position of \p BB in the traversal; \p BB must have been appended.
  unsigned getIndex(const BlockT *BB) const {
    assert(POIndex.count(BB));
    return POIndex.lookup(BB);
  }

  /// Whether \p BB was flagged as a reducible cycle header by appendBlock().
  bool isReducibleCycleHeader(const BlockT *BB) const {
    return ReducibleCycleHeaders.contains(BB);
  }

private:
  // Blocks in traversal order; the index in this vector is the PO number.
  SmallVector<const BlockT *> m_order;
  // Inverse mapping from block to its position in m_order.
  DenseMap<const BlockT *, unsigned> POIndex;
  // Headers of reducible cycles, as flagged via appendBlock().
  SmallPtrSet<const BlockT *, 32> ReducibleCycleHeaders;
  const ContextT &Context;

  void computeCyclePO(const CycleInfoT &CI, const CycleT *Cycle,
                      SmallPtrSetImpl<BlockT *> &Finalized);

  void computeStackPO(SmallVectorImpl<BlockT *> &Stack, const CycleInfoT &CI,
                      const CycleT *Cycle,
                      SmallPtrSetImpl<BlockT *> &Finalized);
};
138
// Forward declaration; the class is defined later in this file.
template <typename> class DivergencePropagator;
140
141 /// \class GenericSyncDependenceAnalysis
142 ///
143 /// \brief Locate join blocks for disjoint paths starting at a divergent branch.
144 ///
145 /// An analysis per divergent branch that returns the set of basic
146 /// blocks whose phi nodes become divergent due to divergent control.
147 /// These are the blocks that are reachable by two disjoint paths from
148 /// the branch, or cycle exits reachable along a path that is disjoint
149 /// from a path to the cycle latch.
150
151 // --- Above line is not a doxygen comment; intentionally left blank ---
152 //
153 // Originally implemented in SyncDependenceAnalysis.cpp for DivergenceAnalysis.
154 //
155 // The SyncDependenceAnalysis is used in the UniformityAnalysis to model
156 // control-induced divergence in phi nodes.
157 //
158 // -- Reference --
159 // The algorithm is an extension of Section 5 of
160 //
161 // An abstract interpretation for SPMD divergence
162 // on reducible control flow graphs.
163 // Julian Rosemann, Simon Moll and Sebastian Hack
164 // POPL '21
165 //
166 //
167 // -- Sync dependence --
168 // Sync dependence characterizes the control flow aspect of the
169 // propagation of branch divergence. For example,
170 //
171 // %cond = icmp slt i32 %tid, 10
172 // br i1 %cond, label %then, label %else
173 // then:
174 // br label %merge
175 // else:
176 // br label %merge
177 // merge:
178 // %a = phi i32 [ 0, %then ], [ 1, %else ]
179 //
180 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
181 // because %tid is not on its use-def chains, %a is sync dependent on %tid
182 // because the branch "br i1 %cond" depends on %tid and affects which value %a
183 // is assigned to.
184 //
185 //
186 // -- Reduction to SSA construction --
187 // There are two disjoint paths from A to X, if a certain variant of SSA
188 // construction places a phi node in X under the following set-up scheme.
189 //
190 // This variant of SSA construction ignores incoming undef values.
191 // That is paths from the entry without a definition do not result in
192 // phi nodes.
193 //
194 // entry
195 // / \
196 // A \
197 // / \ Y
198 // B C /
199 // \ / \ /
200 // D E
201 // \ /
202 // F
203 //
204 // Assume that A contains a divergent branch. We are interested
205 // in the set of all blocks where each block is reachable from A
206 // via two disjoint paths. This would be the set {D, F} in this
207 // case.
208 // To generally reduce this query to SSA construction we introduce
209 // a virtual variable x and assign to x different values in each
210 // successor block of A.
211 //
212 // entry
213 // / \
214 // A \
215 // / \ Y
216 // x = 0 x = 1 /
217 // \ / \ /
218 // D E
219 // \ /
220 // F
221 //
222 // Our flavor of SSA construction for x will construct the following
223 //
224 // entry
225 // / \
226 // A \
227 // / \ Y
228 // x0 = 0 x1 = 1 /
229 // \ / \ /
230 // x2 = phi E
231 // \ /
232 // x3 = phi
233 //
234 // The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
236 //
237 // -- Remarks --
238 // * In case of cycle exits we need to check for temporal divergence.
239 // To this end, we check whether the definition of x differs between the
240 // cycle exit and the cycle header (_after_ SSA construction).
241 //
242 // * In the presence of irreducible control flow, the fixed point is
243 // reached only after multiple iterations. This is because labels
244 // reaching the header of a cycle must be repropagated through the
245 // cycle. This is true even in a reducible cycle, since the labels
246 // may have been produced by a nested irreducible cycle.
247 //
248 // * Note that SyncDependenceAnalysis is not concerned with the points
// of convergence in an irreducible cycle. Its only purpose is to
250 // identify join blocks. The "diverged entry" criterion is
251 // separately applied on join blocks to determine if an entire
252 // irreducible cycle is assumed to be divergent.
253 //
254 // * Relevant related work:
255 // A simple algorithm for global data flow analysis problems.
256 // Matthew S. Hecht and Jeffrey D. Ullman.
257 // SIAM Journal on Computing, 4(4):519–532, December 1975.
258 //
template <typename ContextT> class GenericSyncDependenceAnalysis {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;
  using InstructionT = typename ContextT::InstructionT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ConstBlockSet = SmallPtrSet<const BlockT *, 4>;
  using ModifiedPO = ModifiedPostOrder<ContextT>;

  // * if BlockLabels[B] == C then C is the dominating definition at
  //   block B
  // * if BlockLabels[B] == nullptr then we haven't seen B yet
  // * if BlockLabels[B] == B then:
  //   - B is a join point of disjoint paths from X, or,
  //   - B is an immediate successor of X (initial value), or,
  //   - B is X
  using BlockLabelMap = DenseMap<const BlockT *, const BlockT *>;

  /// Information discovered by the sync dependence analysis for each
  /// divergent branch.
  struct DivergenceDescriptor {
    // Join points of diverged paths.
    ConstBlockSet JoinDivBlocks;
    // Divergent cycle exits.
    ConstBlockSet CycleDivBlocks;
    // Labels assigned to blocks on diverged paths.
    BlockLabelMap BlockLabels;
  };

  using DivergencePropagatorT = DivergencePropagator<ContextT>;

  GenericSyncDependenceAnalysis(const ContextT &Context,
                                const DominatorTreeT &DT, const CycleInfoT &CI);

  /// \brief Computes divergent join points and cycle exits caused by branch
  /// divergence in \p Term.
  ///
  /// This returns a pair of sets:
  /// * The set of blocks which are reachable by disjoint paths from
  ///   \p Term.
  /// * The set also contains cycle exits if there are two disjoint
  ///   paths: one from \p Term to the cycle exit and another from
  ///   \p Term to the cycle header.
  const DivergenceDescriptor &getJoinBlocks(const BlockT *DivTermBlock);

private:
  // Shared descriptor returned for branches that trivially cannot
  // diverge (fewer than two successors).
  static DivergenceDescriptor EmptyDivergenceDesc;

  // Modified post-order of the CFG, computed once at construction.
  ModifiedPO CyclePO;

  const DominatorTreeT &DT;
  const CycleInfoT &CI;

  // Memoized per-branch results of getJoinBlocks().
  DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
      CachedControlDivDescs;
};
320
/// \brief Analysis that identifies uniform values in a data-parallel
/// execution.
///
/// This analysis propagates divergence in a data-parallel context
/// from sources of divergence to all users. It can be instantiated
/// for an IR that provides a suitable SSAContext.
template <typename ContextT> class GenericUniformityAnalysisImpl {
public:
  using BlockT = typename ContextT::BlockT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;
  using ConstValueRefT = typename ContextT::ConstValueRefT;
  using InstructionT = typename ContextT::InstructionT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;

  GenericUniformityAnalysisImpl(const FunctionT &F, const DominatorTreeT &DT,
                                const CycleInfoT &CI,
                                const TargetTransformInfo *TTI)
      : Context(CI.getSSAContext()), F(F), CI(CI), TTI(TTI), DT(DT),
        SDA(Context, DT, CI) {}

  void initialize();

  const FunctionT &getFunction() const { return F; }

  /// \brief Mark \p Instr as a value that is always uniform.
  void addUniformOverride(const InstructionT &Instr);

  /// \brief Mark \p DivVal as a value that is always divergent.
  /// \returns Whether the tracked divergence state of \p DivVal changed.
  bool markDivergent(const InstructionT &I);
  bool markDivergent(ConstValueRefT DivVal);
  bool markDefsDivergent(const InstructionT &Instr,
                         bool AllDefsDivergent = true);

  /// \brief Propagate divergence to all instructions in the region.
  /// Divergence is seeded by calls to \p markDivergent.
  void compute();

  /// \brief Whether any value was marked or analyzed to be divergent.
  bool hasDivergence() const { return !DivergentValues.empty(); }

  /// \brief Whether \p Instr will always return a uniform value regardless
  /// of its operands.
  bool isAlwaysUniform(const InstructionT &Instr) const;

  bool hasDivergentDefs(const InstructionT &I) const;

  /// \brief Whether \p I is divergent. Terminators are tracked by their
  /// parent block; all other instructions by their definitions.
  bool isDivergent(const InstructionT &I) const {
    if (I.isTerminator()) {
      return DivergentTermBlocks.contains(I.getParent());
    }
    return hasDivergentDefs(I);
  };

  /// \brief Whether \p V is divergent at its definition.
  bool isDivergent(ConstValueRefT V) const { return DivergentValues.count(V); }

  /// \brief Whether block \p B ends in a divergent terminator.
  bool hasDivergentTerminator(const BlockT &B) const {
    return DivergentTermBlocks.contains(&B);
  }

  void print(raw_ostream &out) const;

protected:
  /// \brief Value/block pair representing a single phi input.
  struct PhiInput {
    ConstValueRefT value;
    BlockT *predBlock;

    PhiInput(ConstValueRefT value, BlockT *predBlock)
        : value(value), predBlock(predBlock) {}
  };

  const ContextT &Context;
  const FunctionT &F;
  const CycleInfoT &CI;
  const TargetTransformInfo *TTI = nullptr;

  // Detected/marked divergent values.
  std::set<ConstValueRefT> DivergentValues;
  // Blocks whose terminator was marked divergent.
  SmallPtrSet<const BlockT *, 32> DivergentTermBlocks;

  // Internal worklist for divergence propagation.
  std::vector<const InstructionT *> Worklist;

  /// \brief Mark \p Term as divergent and push all Instructions that become
  /// divergent as a result on the worklist.
  void analyzeControlDivergence(const InstructionT &Term);

private:
  const DominatorTreeT &DT;

  // Recognized cycles with divergent exits.
  SmallPtrSet<const CycleT *, 16> DivergentExitCycles;

  // Cycles assumed to be divergent.
  //
  // We don't use a set here because every insertion needs an explicit
  // traversal of all existing members.
  SmallVector<const CycleT *> AssumedDivergent;

  // The SDA links divergent branches to divergent control-flow joins.
  SyncDependenceAnalysisT SDA;

  // Set of known-uniform values.
  SmallPtrSet<const InstructionT *, 32> UniformOverrides;

  /// \brief Mark all definitions in \p JoinBlock as divergent and push them
  /// on the worklist.
  void taintAndPushAllDefs(const BlockT &JoinBlock);

  /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on
  /// the worklist.
  void taintAndPushPhiNodes(const BlockT &JoinBlock);

  /// \brief Identify all Instructions that become divergent because \p DivExit
  /// is a divergent cycle exit of \p DivCycle. Mark those instructions as
  /// divergent and push them on the worklist.
  void propagateCycleExitDivergence(const BlockT &DivExit,
                                    const CycleT &DivCycle);

  /// \brief Internal implementation function for propagateCycleExitDivergence.
  void analyzeCycleExitDivergence(const CycleT &OuterDivCycle);

  /// \brief Mark all instructions as divergent that use a value defined in
  /// \p OuterDivCycle. Push their users on the worklist.
  void analyzeTemporalDivergence(const InstructionT &I,
                                 const CycleT &OuterDivCycle);

  /// \brief Push all users of \p Val (in the region) to the worklist.
  void pushUsers(const InstructionT &I);
  void pushUsers(ConstValueRefT V);

  bool usesValueFromCycle(const InstructionT &I, const CycleT &DefCycle) const;

  /// \brief Whether \p Val is divergent when read in \p ObservingBlock.
  bool isTemporalDivergent(const BlockT &ObservingBlock,
                           ConstValueRefT Val) const;
};
469
/// Out-of-line deleter so a unique_ptr holding the impl can be declared
/// where only a forward declaration of ImplT is visible.
template <typename ImplT>
void GenericUniformityAnalysisImplDeleter<ImplT>::operator()(ImplT *Impl) {
  delete Impl;
}
474
/// Compute divergence starting with a divergent branch.
template <typename ContextT> class DivergencePropagator {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ModifiedPO = ModifiedPostOrder<ContextT>;
  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;

  const ModifiedPO &CyclePOT;
  const DominatorTreeT &DT;
  const CycleInfoT &CI;
  const BlockT &DivTermBlock;
  const ContextT &Context;

  // Track blocks that receive a new label. Every time we relabel a
  // cycle header, we make another pass over the modified post-order
  // in order to propagate the header label. The bit vector also
  // allows us to skip labels that have not changed.
  SparseBitVector<> FreshLabels;

  // Divergent join and cycle exit descriptor.
  std::unique_ptr<DivergenceDescriptorT> DivDesc;
  BlockLabelMapT &BlockLabels;

  DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
                       const CycleInfoT &CI, const BlockT &DivTermBlock)
      : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
        Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
        BlockLabels(DivDesc->BlockLabels) {}

  /// Print the label assigned to every block in reverse modified
  /// post-order (debugging aid).
  void printDefs(raw_ostream &Out) {
    Out << "Propagator::BlockLabels {\n";
    for (int BlockIdx = (int)CyclePOT.size() - 1; BlockIdx >= 0; --BlockIdx) {
      const auto *Block = CyclePOT[BlockIdx];
      const auto *Label = BlockLabels[Block];
      Out << Context.print(Block) << "(" << BlockIdx << ") : ";
      if (!Label) {
        Out << "<null>\n";
      } else {
        Out << Context.print(Label) << "\n";
      }
    }
    Out << "}\n";
  }

  // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
  // causes a divergent join.
  bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
    const auto *OldLabel = BlockLabels[&SuccBlock];

    LLVM_DEBUG(dbgs() << "labeling " << Context.print(&SuccBlock) << ":\n"
                      << "\tpushed label: " << Context.print(&PushedLabel)
                      << "\n"
                      << "\told label: " << Context.print(OldLabel) << "\n");

    // Early exit if there is no change in the label.
    if (OldLabel == &PushedLabel)
      return false;

    if (OldLabel != &SuccBlock) {
      auto SuccIdx = CyclePOT.getIndex(&SuccBlock);
      // Assigning a new label, mark this in FreshLabels.
      LLVM_DEBUG(dbgs() << "\tfresh label: " << SuccIdx << "\n");
      FreshLabels.set(SuccIdx);
    }

    // This is not a join if the succ was previously unlabeled.
    if (!OldLabel) {
      LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&PushedLabel)
                        << "\n");
      BlockLabels[&SuccBlock] = &PushedLabel;
      return false;
    }

    // This is a new join. Label the join block as itself, and not as
    // the pushed label.
    LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&SuccBlock) << "\n");
    BlockLabels[&SuccBlock] = &SuccBlock;

    return true;
  }

  // Visiting a virtual cycle exit edge from the cycle header implies
  // temporal divergence on the join.
  bool visitCycleExitEdge(const BlockT &ExitBlock, const BlockT &Label) {
    if (!computeJoin(ExitBlock, Label))
      return false;

    // Identified a divergent cycle exit
    DivDesc->CycleDivBlocks.insert(&ExitBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(&ExitBlock)
                      << "\n");
    return true;
  }

  // Process \p SuccBlock with reaching definition \p Label.
  bool visitEdge(const BlockT &SuccBlock, const BlockT &Label) {
    if (!computeJoin(SuccBlock, Label))
      return false;

    // Divergent, disjoint paths join.
    DivDesc->JoinDivBlocks.insert(&SuccBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent join: " << Context.print(&SuccBlock)
                      << "\n");
    return true;
  }

  /// Run label propagation to a fixed point and return the divergence
  /// descriptor discovered for DivTermBlock.
  std::unique_ptr<DivergenceDescriptorT> computeJoinPoints() {
    assert(DivDesc);

    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
                      << Context.print(&DivTermBlock) << "\n");

    // Early stopping criterion
    int FloorIdx = CyclePOT.size() - 1;
    const BlockT *FloorLabel = nullptr;
    int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);

    // Bootstrap with branch targets
    auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
    for (const auto *SuccBlock : successors(&DivTermBlock)) {
      if (DivTermCycle && !DivTermCycle->contains(SuccBlock)) {
        // If DivTerm exits the cycle immediately, computeJoin() might
        // not reach SuccBlock with a different label. We need to
        // check for this exit now.
        DivDesc->CycleDivBlocks.insert(SuccBlock);
        LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
                          << Context.print(SuccBlock) << "\n");
      }
      auto SuccIdx = CyclePOT.getIndex(SuccBlock);
      // Each branch target starts out labeled with itself (the
      // "definition" of the virtual variable for this edge).
      visitEdge(*SuccBlock, *SuccBlock);
      FloorIdx = std::min<int>(FloorIdx, SuccIdx);
    }

    while (true) {
      auto BlockIdx = FreshLabels.find_last();
      if (BlockIdx == -1 || BlockIdx < FloorIdx)
        break;

      LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));

      FreshLabels.reset(BlockIdx);
      if (BlockIdx == DivTermIdx) {
        LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n");
        continue;
      }

      const auto *Block = CyclePOT[BlockIdx];
      LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
                        << BlockIdx << "\n");

      const auto *Label = BlockLabels[Block];
      assert(Label);

      bool CausedJoin = false;
      int LoweredFloorIdx = FloorIdx;

      // If the current block is the header of a reducible cycle that
      // contains the divergent branch, then the label should be
      // propagated to the cycle exits. Such a header is the "last
      // possible join" of any disjoint paths within this cycle. This
      // prevents detection of spurious joins at the entries of any
      // irreducible child cycles.
      //
      // This conclusion about the header is true for any choice of DFS:
      //
      //   If some DFS has a reducible cycle C with header H, then for
      //   any other DFS, H is the header of a cycle C' that is a
      //   superset of C. For a divergent branch inside the subgraph
      //   C, any join node inside C is either H, or some node
      //   encountered without passing through H.
      //
      auto getReducibleParent = [&](const BlockT *Block) -> const CycleT * {
        if (!CyclePOT.isReducibleCycleHeader(Block))
          return nullptr;
        const auto *BlockCycle = CI.getCycle(Block);
        if (BlockCycle->contains(&DivTermBlock))
          return BlockCycle;
        return nullptr;
      };

      if (const auto *BlockCycle = getReducibleParent(Block)) {
        // Reducible cycle header: push the label over the virtual
        // edges to the cycle's exit blocks.
        SmallVector<BlockT *, 4> BlockCycleExits;
        BlockCycle->getExitBlocks(BlockCycleExits);
        for (auto *BlockCycleExit : BlockCycleExits) {
          CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
        }
      } else {
        for (const auto *SuccBlock : successors(Block)) {
          CausedJoin |= visitEdge(*SuccBlock, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
        }
      }

      // Floor update
      if (CausedJoin) {
        // 1. Different labels pushed to successors
        FloorIdx = LoweredFloorIdx;
      } else if (FloorLabel != Label) {
        // 2. No join caused BUT we pushed a label that is different than the
        // last pushed label
        FloorIdx = LoweredFloorIdx;
        FloorLabel = Label;
      }
    }

    LLVM_DEBUG(dbgs() << "Final labeling:\n"; printDefs(dbgs()));

    // Check every cycle containing DivTermBlock for exit divergence.
    // A cycle has exit divergence if the label of an exit block does
    // not match the label of its header.
    for (const auto *Cycle = CI.getCycle(&DivTermBlock); Cycle;
         Cycle = Cycle->getParentCycle()) {
      if (Cycle->isReducible()) {
        // The exit divergence of a reducible cycle is recorded while
        // propagating labels.
        continue;
      }
      SmallVector<BlockT *> Exits;
      Cycle->getExitBlocks(Exits);
      auto *Header = Cycle->getHeader();
      auto *HeaderLabel = BlockLabels[Header];
      for (const auto *Exit : Exits) {
        if (BlockLabels[Exit] != HeaderLabel) {
          // Identified a divergent cycle exit
          DivDesc->CycleDivBlocks.insert(Exit);
          LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(Exit)
                            << "\n");
        }
      }
    }

    return std::move(DivDesc);
  }
};
722
// Out-of-line definition of the static empty descriptor that
// getJoinBlocks() returns for branches with at most one successor.
template <typename ContextT>
typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
llvm::GenericSyncDependenceAnalysis<ContextT>::EmptyDivergenceDesc;
726
/// Construct the analysis and eagerly compute the modified post-order
/// traversal for the whole function.
template <typename ContextT>
llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
    const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
    : CyclePO(Context), DT(DT), CI(CI) {
  CyclePO.compute(CI);
}
733
734 template <typename ContextT>
735 auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
736 const BlockT *DivTermBlock) -> const DivergenceDescriptor & {
737 // trivial case
738 if (succ_size(DivTermBlock) <= 1) {
739 return EmptyDivergenceDesc;
740 }
741
742 // already available in cache?
743 auto ItCached = CachedControlDivDescs.find(DivTermBlock);
744 if (ItCached != CachedControlDivDescs.end())
745 return *ItCached->second;
746
747 // compute all join points
748 DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock);
749 auto DivDesc = Propagator.computeJoinPoints();
750
751 auto printBlockSet = [&](ConstBlockSet &Blocks) {
752 return Printable([&](raw_ostream &Out) {
753 Out << "[";
754 ListSeparator LS;
755 for (const auto *BB : Blocks) {
756 Out << LS << CI.getSSAContext().print(BB);
757 }
758 Out << "]\n";
759 });
760 };
761
762 LLVM_DEBUG(
763 dbgs() << "\nResult (" << CI.getSSAContext().print(DivTermBlock)
764 << "):\n JoinDivBlocks: " << printBlockSet(DivDesc->JoinDivBlocks)
765 << " CycleDivBlocks: " << printBlockSet(DivDesc->CycleDivBlocks)
766 << "\n");
767 (void)printBlockSet;
768
769 auto ItInserted =
770 CachedControlDivDescs.try_emplace(DivTermBlock, std::move(DivDesc));
771 assert(ItInserted.second);
772 return *ItInserted.first->second;
773 }
774
775 template <typename ContextT>
markDivergent(const InstructionT & I)776 bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
777 const InstructionT &I) {
778 if (I.isTerminator()) {
779 if (DivergentTermBlocks.insert(I.getParent()).second) {
780 LLVM_DEBUG(dbgs() << "marked divergent term block: "
781 << Context.print(I.getParent()) << "\n");
782 return true;
783 }
784 return false;
785 }
786
787 return markDefsDivergent(I);
788 }
789
790 template <typename ContextT>
markDivergent(ConstValueRefT Val)791 bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
792 ConstValueRefT Val) {
793 if (DivergentValues.insert(Val).second) {
794 LLVM_DEBUG(dbgs() << "marked divergent: " << Context.print(Val) << "\n");
795 return true;
796 }
797 return false;
798 }
799
/// Record \p Instr in the set of values assumed to be always uniform.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride(
    const InstructionT &Instr) {
  UniformOverrides.insert(&Instr);
}
805
/// Mark \p I divergent if it observes a value from the divergent cycle
/// \p OuterDivCycle, and queue it for further propagation.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::analyzeTemporalDivergence(
    const InstructionT &I, const CycleT &OuterDivCycle) {
  // Already divergent; nothing more to do for this instruction.
  if (isDivergent(I))
    return;

  LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << Context.print(&I)
                    << "\n");
  // Only instructions reading a value defined inside the cycle can be
  // affected by its temporal divergence.
  if (!usesValueFromCycle(I, OuterDivCycle))
    return;

  // Respect explicit uniformity overrides.
  if (isAlwaysUniform(I))
    return;

  // Newly divergent: push it so its users get reprocessed.
  if (markDivergent(I))
    Worklist.push_back(&I);
}
823
824 // Mark all external users of values defined inside \param
825 // OuterDivCycle as divergent.
826 //
827 // This follows all live out edges wherever they may lead. Potential
828 // users of values defined inside DivCycle could be anywhere in the
829 // dominance region of DivCycle (including its fringes for phi nodes).
830 // A cycle C dominates a block B iff every path from the entry block
831 // to B must pass through a block contained in C. If C is a reducible
832 // cycle (or natural loop), C dominates B iff the header of C
// dominates B. But in general, we iteratively examine cycle
834 // exits and their successors.
835 template <typename ContextT>
analyzeCycleExitDivergence(const CycleT & OuterDivCycle)836 void GenericUniformityAnalysisImpl<ContextT>::analyzeCycleExitDivergence(
837 const CycleT &OuterDivCycle) {
838 // Set of blocks that are dominated by the cycle, i.e., each is only
839 // reachable from paths that pass through the cycle.
840 SmallPtrSet<BlockT *, 16> DomRegion;
841
842 // The boundary of DomRegion, formed by blocks that are not
843 // dominated by the cycle.
844 SmallVector<BlockT *> DomFrontier;
845 OuterDivCycle.getExitBlocks(DomFrontier);
846
847 // Returns true if BB is dominated by the cycle.
848 auto isInDomRegion = [&](BlockT *BB) {
849 for (auto *P : predecessors(BB)) {
850 if (OuterDivCycle.contains(P))
851 continue;
852 if (DomRegion.count(P))
853 continue;
854 return false;
855 }
856 return true;
857 };
858
859 // Keep advancing the frontier along successor edges, while
860 // promoting blocks to DomRegion.
861 while (true) {
862 bool Promoted = false;
863 SmallVector<BlockT *> Temp;
864 for (auto *W : DomFrontier) {
865 if (!isInDomRegion(W)) {
866 Temp.push_back(W);
867 continue;
868 }
869 DomRegion.insert(W);
870 Promoted = true;
871 for (auto *Succ : successors(W)) {
872 if (DomRegion.contains(Succ))
873 continue;
874 Temp.push_back(Succ);
875 }
876 }
877 if (!Promoted)
878 break;
879 DomFrontier = Temp;
880 }
881
882 // At DomFrontier, only the PHI nodes are affected by temporal
883 // divergence.
884 for (const auto *UserBlock : DomFrontier) {
885 LLVM_DEBUG(dbgs() << "Analyze phis after cycle exit: "
886 << Context.print(UserBlock) << "\n");
887 for (const auto &Phi : UserBlock->phis()) {
888 LLVM_DEBUG(dbgs() << " " << Context.print(&Phi) << "\n");
889 analyzeTemporalDivergence(Phi, OuterDivCycle);
890 }
891 }
892
893 // All instructions inside the dominance region are affected by
894 // temporal divergence.
895 for (const auto *UserBlock : DomRegion) {
896 LLVM_DEBUG(dbgs() << "Analyze non-phi users after cycle exit: "
897 << Context.print(UserBlock) << "\n");
898 for (const auto &I : *UserBlock) {
899 LLVM_DEBUG(dbgs() << " " << Context.print(&I) << "\n");
900 analyzeTemporalDivergence(I, OuterDivCycle);
901 }
902 }
903 }
904
905 template <typename ContextT>
propagateCycleExitDivergence(const BlockT & DivExit,const CycleT & InnerDivCycle)906 void GenericUniformityAnalysisImpl<ContextT>::propagateCycleExitDivergence(
907 const BlockT &DivExit, const CycleT &InnerDivCycle) {
908 LLVM_DEBUG(dbgs() << "\tpropCycleExitDiv " << Context.print(&DivExit)
909 << "\n");
910 auto *DivCycle = &InnerDivCycle;
911 auto *OuterDivCycle = DivCycle;
912 auto *ExitLevelCycle = CI.getCycle(&DivExit);
913 const unsigned CycleExitDepth =
914 ExitLevelCycle ? ExitLevelCycle->getDepth() : 0;
915
916 // Find outer-most cycle that does not contain \p DivExit
917 while (DivCycle && DivCycle->getDepth() > CycleExitDepth) {
918 LLVM_DEBUG(dbgs() << " Found exiting cycle: "
919 << Context.print(DivCycle->getHeader()) << "\n");
920 OuterDivCycle = DivCycle;
921 DivCycle = DivCycle->getParentCycle();
922 }
923 LLVM_DEBUG(dbgs() << "\tOuter-most exiting cycle: "
924 << Context.print(OuterDivCycle->getHeader()) << "\n");
925
926 if (!DivergentExitCycles.insert(OuterDivCycle).second)
927 return;
928
929 // Exit divergence does not matter if the cycle itself is assumed to
930 // be divergent.
931 for (const auto *C : AssumedDivergent) {
932 if (C->contains(OuterDivCycle))
933 return;
934 }
935
936 analyzeCycleExitDivergence(*OuterDivCycle);
937 }
938
939 template <typename ContextT>
taintAndPushAllDefs(const BlockT & BB)940 void GenericUniformityAnalysisImpl<ContextT>::taintAndPushAllDefs(
941 const BlockT &BB) {
942 LLVM_DEBUG(dbgs() << "taintAndPushAllDefs " << Context.print(&BB) << "\n");
943 for (const auto &I : instrs(BB)) {
944 // Terminators do not produce values; they are divergent only if
945 // the condition is divergent. That is handled when the divergent
946 // condition is placed in the worklist.
947 if (I.isTerminator())
948 break;
949
950 // Mark this as divergent. We don't check if the instruction is
951 // always uniform. In a cycle where the thread convergence is not
952 // statically known, the instruction is not statically converged,
953 // and its outputs cannot be statically uniform.
954 if (markDivergent(I))
955 Worklist.push_back(&I);
956 }
957 }
958
959 /// Mark divergent phi nodes in a join block
960 template <typename ContextT>
taintAndPushPhiNodes(const BlockT & JoinBlock)961 void GenericUniformityAnalysisImpl<ContextT>::taintAndPushPhiNodes(
962 const BlockT &JoinBlock) {
963 LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << Context.print(&JoinBlock)
964 << "\n");
965 for (const auto &Phi : JoinBlock.phis()) {
966 if (ContextT::isConstantValuePhi(Phi))
967 continue;
968 if (markDivergent(Phi))
969 Worklist.push_back(&Phi);
970 }
971 }
972
973 /// Add \p Candidate to \p Cycles if it is not already contained in \p Cycles.
974 ///
975 /// \return true iff \p Candidate was added to \p Cycles.
976 template <typename CycleT>
insertIfNotContained(SmallVector<CycleT * > & Cycles,CycleT * Candidate)977 static bool insertIfNotContained(SmallVector<CycleT *> &Cycles,
978 CycleT *Candidate) {
979 if (llvm::any_of(Cycles,
980 [Candidate](CycleT *C) { return C->contains(Candidate); }))
981 return false;
982 Cycles.push_back(Candidate);
983 return true;
984 }
985
/// Return the outermost cycle made divergent by branch outside it.
///
/// If two paths that diverged outside an irreducible cycle join
/// inside that cycle, then that whole cycle is assumed to be
/// divergent. This does not apply if the cycle is reducible.
///
/// \p Cycle must contain \p JoinBlock. Returns null when the branch is
/// not external to the cycle or the cycle is reducible.
template <typename CycleT, typename BlockT>
static const CycleT *getExtDivCycle(const CycleT *Cycle,
                                    const BlockT *DivTermBlock,
                                    const BlockT *JoinBlock) {
  assert(Cycle);
  assert(Cycle->contains(JoinBlock));

  // The divergent branch is inside the cycle, so this is not an
  // *external* branch; the internal case is handled by getIntDivCycle.
  if (Cycle->contains(DivTermBlock))
    return nullptr;

  if (Cycle->isReducible()) {
    // The only join of disjoint paths inside a reducible cycle is its
    // header, and a join there does not make the cycle divergent.
    assert(Cycle->getHeader() == JoinBlock);
    return nullptr;
  }

  // Expand to the outermost ancestor cycle that still excludes the
  // divergent branch block.
  const auto *Parent = Cycle->getParentCycle();
  while (Parent && !Parent->contains(DivTermBlock)) {
    // If the join is inside a child, then the parent must be
    // irreducible. The only join in a reducible cycle is its own
    // header.
    assert(!Parent->isReducible());
    Cycle = Parent;
    Parent = Cycle->getParentCycle();
  }

  LLVM_DEBUG(dbgs() << "cycle made divergent by external branch\n");
  return Cycle;
}
1019
1020 /// Return the outermost cycle made divergent by branch inside it.
1021 ///
1022 /// This checks the "diverged entry" criterion defined in the
1023 /// docs/ConvergenceAnalysis.html.
1024 template <typename ContextT, typename CycleT, typename BlockT,
1025 typename DominatorTreeT>
1026 static const CycleT *
getIntDivCycle(const CycleT * Cycle,const BlockT * DivTermBlock,const BlockT * JoinBlock,const DominatorTreeT & DT,ContextT & Context)1027 getIntDivCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
1028 const BlockT *JoinBlock, const DominatorTreeT &DT,
1029 ContextT &Context) {
1030 LLVM_DEBUG(dbgs() << "examine join " << Context.print(JoinBlock)
1031 << "for internal branch " << Context.print(DivTermBlock)
1032 << "\n");
1033 if (DT.properlyDominates(DivTermBlock, JoinBlock))
1034 return nullptr;
1035
1036 // Find the smallest common cycle, if one exists.
1037 assert(Cycle && Cycle->contains(JoinBlock));
1038 while (Cycle && !Cycle->contains(DivTermBlock)) {
1039 Cycle = Cycle->getParentCycle();
1040 }
1041 if (!Cycle || Cycle->isReducible())
1042 return nullptr;
1043
1044 if (DT.properlyDominates(Cycle->getHeader(), JoinBlock))
1045 return nullptr;
1046
1047 LLVM_DEBUG(dbgs() << " header " << Context.print(Cycle->getHeader())
1048 << " does not dominate join\n");
1049
1050 const auto *Parent = Cycle->getParentCycle();
1051 while (Parent && !DT.properlyDominates(Parent->getHeader(), JoinBlock)) {
1052 LLVM_DEBUG(dbgs() << " header " << Context.print(Parent->getHeader())
1053 << " does not dominate join\n");
1054 Cycle = Parent;
1055 Parent = Parent->getParentCycle();
1056 }
1057
1058 LLVM_DEBUG(dbgs() << " cycle made divergent by internal branch\n");
1059 return Cycle;
1060 }
1061
/// Combine the external and internal divergent-cycle criteria for a
/// join block, preferring the internal candidate when both apply.
template <typename ContextT, typename CycleT, typename BlockT,
          typename DominatorTreeT>
static const CycleT *
getOutermostDivergentCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
                           const BlockT *JoinBlock, const DominatorTreeT &DT,
                           ContextT &Context) {
  if (!Cycle)
    return nullptr;

  // Largest cycle that contains JoinBlock but not DivTermBlock.
  const auto *Ext = getExtDivCycle(Cycle, DivTermBlock, JoinBlock);

  // Largest cycle that contains both, per the diverged-entry criterion.
  const auto *Int = getIntDivCycle(Cycle, DivTermBlock, JoinBlock, DT, Context);

  return Int ? Int : Ext;
}
1082
/// Propagate the effects of a divergent terminator \p Term: taint phi
/// nodes at disjoint-path joins, conservatively taint whole cycles made
/// divergent by the join structure, and propagate cycle-exit divergence.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::analyzeControlDivergence(
    const InstructionT &Term) {
  const auto *DivTermBlock = Term.getParent();
  DivergentTermBlocks.insert(DivTermBlock);
  LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Context.print(DivTermBlock)
                    << "\n");

  // Don't propagate divergence from unreachable blocks.
  if (!DT.isReachableFromEntry(DivTermBlock))
    return;

  const auto &DivDesc = SDA.getJoinBlocks(DivTermBlock);
  SmallVector<const CycleT *> DivCycles;

  // Iterate over all blocks now reachable by a disjoint path join
  for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
    const auto *Cycle = CI.getCycle(JoinBlock);
    LLVM_DEBUG(dbgs() << "visiting join block " << Context.print(JoinBlock)
                      << "\n");
    // A join inside an irreducible cycle makes the whole cycle
    // divergent; collect it for conservative processing below instead
    // of tainting just the join's phi nodes.
    if (const auto *Outermost = getOutermostDivergentCycle(
            Cycle, DivTermBlock, JoinBlock, DT, Context)) {
      LLVM_DEBUG(dbgs() << "found divergent cycle\n");
      DivCycles.push_back(Outermost);
      continue;
    }
    taintAndPushPhiNodes(*JoinBlock);
  }

  // Sort by order of decreasing depth. This allows later cycles to be skipped
  // because they are already contained in earlier ones.
  llvm::sort(DivCycles, [](const CycleT *A, const CycleT *B) {
    return A->getDepth() > B->getDepth();
  });

  // Cycles that are assumed divergent due to the diverged entry
  // criterion potentially contain temporal divergence depending on
  // the DFS chosen. Conservatively, all values produced in such a
  // cycle are assumed divergent. "Cycle invariant" values may be
  // assumed uniform, but that requires further analysis.
  for (auto *C : DivCycles) {
    if (!insertIfNotContained(AssumedDivergent, C))
      continue;
    LLVM_DEBUG(dbgs() << "process divergent cycle\n");
    for (const BlockT *BB : C->blocks()) {
      taintAndPushAllDefs(*BB);
    }
  }

  // Propagate divergence to values that are live across divergent
  // cycle exits. A non-empty CycleDivBlocks implies the branch sits
  // inside some cycle (hence the assert).
  const auto *BranchCycle = CI.getCycle(DivTermBlock);
  assert(DivDesc.CycleDivBlocks.empty() || BranchCycle);
  for (const auto *DivExitBlock : DivDesc.CycleDivBlocks) {
    propagateCycleExitDivergence(*DivExitBlock, *BranchCycle);
  }
}
1138
/// Run the fixed-point iteration: seed the worklist with users of the
/// initially divergent values, then propagate divergence until the
/// worklist is empty.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::compute() {
  // Initialize worklist. Iterate over a copy of the seed set —
  // presumably because pushUsers may grow DivergentValues while it is
  // being walked; confirm before simplifying this away.
  auto DivValuesCopy = DivergentValues;
  for (const auto DivVal : DivValuesCopy) {
    assert(isDivergent(DivVal) && "Worklist invariant violated!");
    pushUsers(DivVal);
  }

  // All values on the Worklist are divergent.
  // Their users may not have been updated yet.
  while (!Worklist.empty()) {
    const InstructionT *I = Worklist.back();
    Worklist.pop_back();

    LLVM_DEBUG(dbgs() << "worklist pop: " << Context.print(I) << "\n");

    // Divergent terminators propagate control divergence (sync
    // dependencies) rather than data divergence.
    if (I->isTerminator()) {
      analyzeControlDivergence(*I);
      continue;
    }

    // propagate value divergence to users
    assert(isDivergent(*I) && "Worklist invariant violated!");
    pushUsers(*I);
  }
}
1166
1167 template <typename ContextT>
isAlwaysUniform(const InstructionT & Instr)1168 bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
1169 const InstructionT &Instr) const {
1170 return UniformOverrides.contains(&Instr);
1171 }
1172
/// Construct the uniformity info for \p Func: build the analysis
/// implementation, initialize its sources of divergence, and run the
/// fixed-point propagation immediately.
template <typename ContextT>
GenericUniformityInfo<ContextT>::GenericUniformityInfo(
    FunctionT &Func, const DominatorTreeT &DT, const CycleInfoT &CI,
    const TargetTransformInfo *TTI) {
    : F(&Func) {
  DA.reset(new ImplT{Func, DT, CI, TTI});
  DA->initialize();
  DA->compute();
}
1182
/// Print the analysis result: divergent arguments, cycles assumed
/// divergent, cycles with divergent exits, and per-block divergence of
/// definitions and terminators.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
  bool haveDivergentArgs = false;

  // Fast path: nothing was marked divergent at all.
  if (DivergentValues.empty()) {
    assert(DivergentTermBlocks.empty());
    assert(DivergentExitCycles.empty());
    OS << "ALL VALUES UNIFORM\n";
    return;
  }

  // Divergent values with no defining block are function arguments.
  for (const auto &entry : DivergentValues) {
    const BlockT *parent = Context.getDefBlock(entry);
    if (!parent) {
      if (!haveDivergentArgs) {
        OS << "DIVERGENT ARGUMENTS:\n";
        haveDivergentArgs = true;
      }
      OS << " DIVERGENT: " << Context.print(entry) << '\n';
    }
  }

  if (!AssumedDivergent.empty()) {
    // NOTE(review): "ASSSUMED" is a typo, but this output is matched by
    // downstream tests — fix only in tandem with those tests.
    OS << "CYCLES ASSSUMED DIVERGENT:\n";
    for (const CycleT *cycle : AssumedDivergent) {
      OS << " " << cycle->print(Context) << '\n';
    }
  }

  if (!DivergentExitCycles.empty()) {
    OS << "CYCLES WITH DIVERGENT EXIT:\n";
    for (const CycleT *cycle : DivergentExitCycles) {
      OS << " " << cycle->print(Context) << '\n';
    }
  }

  for (auto &block : F) {
    OS << "\nBLOCK " << Context.print(&block) << '\n';

    OS << "DEFINITIONS\n";
    SmallVector<ConstValueRefT, 16> defs;
    Context.appendBlockDefs(defs, block);
    for (auto value : defs) {
      if (isDivergent(value))
        OS << " DIVERGENT: ";
      else
        OS << " ";
      OS << Context.print(value) << '\n';
    }

    OS << "TERMINATORS\n";
    SmallVector<const InstructionT *, 8> terms;
    Context.appendBlockTerms(terms, block);
    // Divergence is tracked per-block for terminators, so every
    // terminator of a divergent-terminator block is printed as such.
    bool divergentTerminators = hasDivergentTerminator(block);
    for (auto *T : terms) {
      if (divergentTerminators)
        OS << " DIVERGENT: ";
      else
        OS << " ";
      OS << Context.print(T) << '\n';
    }

    OS << "END BLOCK\n";
  }
}
1248
/// Whether any value or terminator in the function was found divergent.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::hasDivergence() const {
  return DA->hasDivergence();
}
1253
/// Whether \p V is divergent at its definition.
/// Forwards to the analysis implementation.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::isDivergent(ConstValueRefT V) const {
  return DA->isDivergent(V);
}
1259
/// Whether block \p B has a terminator that was marked divergent.
/// Forwards to the analysis implementation.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::hasDivergentTerminator(const BlockT &B) {
  return DA->hasDivergentTerminator(B);
}
1264
/// \brief Helper function for printing; forwards to the implementation.
template <typename ContextT>
void GenericUniformityInfo<ContextT>::print(raw_ostream &out) const {
  DA->print(out);
}
1270
/// Compute a modified post order for the region bounded by \p Cycle
/// (the whole function when \p Cycle is null), driven by \p Stack.
/// Cycles nested directly inside the region are emitted as contiguous
/// sub-sequences via computeCyclePO; all other blocks are handled
/// DAG-style.
template <typename ContextT>
void llvm::ModifiedPostOrder<ContextT>::computeStackPO(
    SmallVectorImpl<BlockT *> &Stack, const CycleInfoT &CI, const CycleT *Cycle,
    SmallPtrSetImpl<BlockT *> &Finalized) {
  LLVM_DEBUG(dbgs() << "inside computeStackPO\n");
  while (!Stack.empty()) {
    auto *NextBB = Stack.back();
    // Already emitted; discard.
    if (Finalized.count(NextBB)) {
      Stack.pop_back();
      continue;
    }
    LLVM_DEBUG(dbgs() << " visiting " << CI.getSSAContext().print(NextBB)
                      << "\n");
    auto *NestedCycle = CI.getCycle(NextBB);
    // The block belongs to a cycle strictly nested inside the current
    // region: treat that whole cycle as a single unit.
    if (Cycle != NestedCycle && (!Cycle || Cycle->contains(NestedCycle))) {
      LLVM_DEBUG(dbgs() << " found a cycle\n");
      // Ascend to the outermost nested cycle directly below \p Cycle.
      while (NestedCycle->getParentCycle() != Cycle)
        NestedCycle = NestedCycle->getParentCycle();

      // Post-order discipline: the nested cycle is emitted only after
      // all of its in-region exits have been finalized.
      SmallVector<BlockT *, 3> NestedExits;
      NestedCycle->getExitBlocks(NestedExits);
      bool PushedNodes = false;
      for (auto *NestedExitBB : NestedExits) {
        LLVM_DEBUG(dbgs() << " examine exit: "
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
        if (Cycle && !Cycle->contains(NestedExitBB))
          continue;
        if (Finalized.count(NestedExitBB))
          continue;
        PushedNodes = true;
        Stack.push_back(NestedExitBB);
        LLVM_DEBUG(dbgs() << " pushed exit: "
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
      }
      if (!PushedNodes) {
        // All loop exits finalized -> finish this node
        Stack.pop_back();
        computeCyclePO(CI, NestedCycle, Finalized);
      }
      continue;
    }

    LLVM_DEBUG(dbgs() << " no nested cycle, going into DAG\n");
    // DAG-style
    bool PushedNodes = false;
    for (auto *SuccBB : successors(NextBB)) {
      LLVM_DEBUG(dbgs() << " examine succ: "
                        << CI.getSSAContext().print(SuccBB) << "\n");
      // Skip successors outside the region or already emitted.
      if (Cycle && !Cycle->contains(SuccBB))
        continue;
      if (Finalized.count(SuccBB))
        continue;
      PushedNodes = true;
      Stack.push_back(SuccBB);
      LLVM_DEBUG(dbgs() << " pushed succ: " << CI.getSSAContext().print(SuccBB)
                        << "\n");
    }
    if (!PushedNodes) {
      // Never push nodes twice
      LLVM_DEBUG(dbgs() << " finishing node: "
                        << CI.getSSAContext().print(NextBB) << "\n");
      Stack.pop_back();
      Finalized.insert(NextBB);
      appendBlock(*NextBB);
    }
  }
  LLVM_DEBUG(dbgs() << "exited computeStackPO\n");
}
1339
/// Emit the blocks of \p Cycle as one contiguous sub-sequence of the
/// modified post order, with the header appended first.
template <typename ContextT>
void ModifiedPostOrder<ContextT>::computeCyclePO(
    const CycleInfoT &CI, const CycleT *Cycle,
    SmallPtrSetImpl<BlockT *> &Finalized) {
  LLVM_DEBUG(dbgs() << "inside computeCyclePO\n");
  SmallVector<BlockT *> Stack;
  auto *CycleHeader = Cycle->getHeader();

  LLVM_DEBUG(dbgs() << " noted header: "
                    << CI.getSSAContext().print(CycleHeader) << "\n");
  assert(!Finalized.count(CycleHeader));
  Finalized.insert(CycleHeader);

  // Visit the header last
  LLVM_DEBUG(dbgs() << " finishing header: "
                    << CI.getSSAContext().print(CycleHeader) << "\n");
  // The extra flag marks headers of reducible cycles; see appendBlock's
  // definition (not visible here) for how the flag is consumed.
  appendBlock(*CycleHeader, Cycle->isReducible());

  // Initialize with immediate successors
  for (auto *BB : successors(CycleHeader)) {
    LLVM_DEBUG(dbgs() << " examine succ: " << CI.getSSAContext().print(BB)
                      << "\n");
    // Stay inside the cycle and do not revisit the header.
    if (!Cycle->contains(BB))
      continue;
    if (BB == CycleHeader)
      continue;
    if (!Finalized.count(BB)) {
      LLVM_DEBUG(dbgs() << " pushed succ: " << CI.getSSAContext().print(BB)
                        << "\n");
      Stack.push_back(BB);
    }
  }

  // Compute PO inside region
  computeStackPO(Stack, CI, Cycle, Finalized);

  LLVM_DEBUG(dbgs() << "exited computeCyclePO\n");
}
1378
1379 /// \brief Generically compute the modified post order.
1380 template <typename ContextT>
compute(const CycleInfoT & CI)1381 void llvm::ModifiedPostOrder<ContextT>::compute(const CycleInfoT &CI) {
1382 SmallPtrSet<BlockT *, 32> Finalized;
1383 SmallVector<BlockT *> Stack;
1384 auto *F = CI.getFunction();
1385 Stack.reserve(24); // FIXME made-up number
1386 Stack.push_back(GraphTraits<FunctionT *>::getEntryNode(F));
1387 computeStackPO(Stack, CI, nullptr, Finalized);
1388 }
1389
1390 } // namespace llvm
1391
1392 #undef DEBUG_TYPE
1393
1394 #endif // LLVM_ADT_GENERICUNIFORMITYIMPL_H
1395