xref: /openbsd-src/gnu/llvm/llvm/include/llvm/ADT/GenericUniformityImpl.h (revision d415bd752c734aee168c4ee86ff32e8cc249eb16)
//===- GenericUniformityImpl.h -----------------------*- C++ -*-----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This template implementation resides in a separate file so that it
10 // does not get injected into every .cpp file that includes the
11 // generic header.
12 //
13 // DO NOT INCLUDE THIS FILE WHEN MERELY USING UNIFORMITYINFO.
14 //
15 // This file should only be included by files that implement a
// specialization of the relevant templates. Currently these are:
17 // - UniformityAnalysis.cpp
18 //
19 // Note: The DEBUG_TYPE macro should be defined before using this
20 // file so that any use of LLVM_DEBUG is associated with the
21 // including file rather than this file.
22 //
23 //===----------------------------------------------------------------------===//
24 ///
25 /// \file
26 /// \brief Implementation of uniformity analysis.
27 ///
28 /// The algorithm is a fixed point iteration that starts with the assumption
29 /// that all control flow and all values are uniform. Starting from sources of
30 /// divergence (whose discovery must be implemented by a CFG- or even
31 /// target-specific derived class), divergence of values is propagated from
32 /// definition to uses in a straight-forward way. The main complexity lies in
33 /// the propagation of the impact of divergent control flow on the divergence of
34 /// values (sync dependencies).
35 ///
36 //===----------------------------------------------------------------------===//
37 
38 #ifndef LLVM_ADT_GENERICUNIFORMITYIMPL_H
39 #define LLVM_ADT_GENERICUNIFORMITYIMPL_H
40 
41 #include "llvm/ADT/GenericUniformityInfo.h"
42 
43 #include "llvm/ADT/SmallPtrSet.h"
44 #include "llvm/ADT/SparseBitVector.h"
45 #include "llvm/ADT/StringExtras.h"
46 #include "llvm/Support/raw_ostream.h"
47 
48 #include <set>
49 
50 #define DEBUG_TYPE "uniformity"
51 
52 using namespace llvm;
53 
54 namespace llvm {
55 
/// Apply std::unique to an arbitrary range \p R, locating its bounds via
/// ADL-aware begin/end. Returns the iterator past the deduplicated prefix;
/// the caller is responsible for erasing the trailing elements.
template <typename Range> auto unique(Range &&R) {
  auto First = adl_begin(R);
  auto Last = adl_end(R);
  return std::unique(First, Last);
}
59 
60 /// Construct a specially modified post-order traversal of cycles.
61 ///
/// The ModifiedPO is constructed using a virtually modified CFG as follows:
///
/// 1. The successors of pre-entry nodes (predecessors of a cycle
///    entry that are outside the cycle) are replaced by the
///    successors of the successors of the header.
67 /// 2. Successors of the cycle header are replaced by the exit blocks
68 ///    of the cycle.
69 ///
70 /// Effectively, we produce a depth-first numbering with the following
71 /// properties:
72 ///
73 /// 1. Nodes after a cycle are numbered earlier than the cycle header.
74 /// 2. The header is numbered earlier than the nodes in the cycle.
75 /// 3. The numbering of the nodes within the cycle forms an interval
76 ///    starting with the header.
77 ///
78 /// Effectively, the virtual modification arranges the nodes in a
79 /// cycle as a DAG with the header as the sole leaf, and successors of
80 /// the header as the roots. A reverse traversal of this numbering has
81 /// the following invariant on the unmodified original CFG:
82 ///
83 ///    Each node is visited after all its predecessors, except if that
84 ///    predecessor is the cycle header.
85 ///
template <typename ContextT> class ModifiedPostOrder {
public:
  using BlockT = typename ContextT::BlockT;
  using FunctionT = typename ContextT::FunctionT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;
  using const_iterator = typename std::vector<BlockT *>::const_iterator;

  ModifiedPostOrder(const ContextT &C) : Context(C) {}

  bool empty() const { return m_order.empty(); }
  size_t size() const { return m_order.size(); }

  void clear() { m_order.clear(); }

  /// \brief Compute the modified post-order numbering for the CFG
  /// described by \p CI.
  void compute(const CycleInfoT &CI);

  /// \brief Whether \p BB has been assigned a position (0 or 1).
  unsigned count(BlockT *BB) const { return POIndex.count(BB); }
  const BlockT *operator[](size_t idx) const { return m_order[idx]; }

  /// \brief Append \p BB as the next block in the numbering.
  ///
  /// \p isReducibleCycleHeader additionally records \p BB as the header of
  /// a reducible cycle; such headers get special treatment when labels are
  /// propagated (see DivergencePropagator::computeJoinPoints).
  void appendBlock(const BlockT &BB, bool isReducibleCycleHeader = false) {
    POIndex[&BB] = m_order.size();
    m_order.push_back(&BB);
    LLVM_DEBUG(dbgs() << "ModifiedPO(" << POIndex[&BB]
                      << "): " << Context.print(&BB) << "\n");
    if (isReducibleCycleHeader)
      ReducibleCycleHeaders.insert(&BB);
  }

  /// \brief Position of \p BB in the numbering; \p BB must be present.
  unsigned getIndex(const BlockT *BB) const {
    assert(POIndex.count(BB));
    return POIndex.lookup(BB);
  }

  /// \brief Whether \p BB was recorded as the header of a reducible cycle.
  bool isReducibleCycleHeader(const BlockT *BB) const {
    return ReducibleCycleHeaders.contains(BB);
  }

private:
  // Blocks in modified post-order.
  SmallVector<const BlockT *> m_order;
  // Inverse mapping: block -> index into m_order.
  DenseMap<const BlockT *, unsigned> POIndex;
  SmallPtrSet<const BlockT *, 32> ReducibleCycleHeaders;
  const ContextT &Context;

  void computeCyclePO(const CycleInfoT &CI, const CycleT *Cycle,
                      SmallPtrSetImpl<BlockT *> &Finalized);

  void computeStackPO(SmallVectorImpl<BlockT *> &Stack, const CycleInfoT &CI,
                      const CycleT *Cycle,
                      SmallPtrSetImpl<BlockT *> &Finalized);
};
138 
139 template <typename> class DivergencePropagator;
140 
141 /// \class GenericSyncDependenceAnalysis
142 ///
143 /// \brief Locate join blocks for disjoint paths starting at a divergent branch.
144 ///
145 /// An analysis per divergent branch that returns the set of basic
146 /// blocks whose phi nodes become divergent due to divergent control.
147 /// These are the blocks that are reachable by two disjoint paths from
148 /// the branch, or cycle exits reachable along a path that is disjoint
149 /// from a path to the cycle latch.
150 
151 // --- Above line is not a doxygen comment; intentionally left blank ---
152 //
153 // Originally implemented in SyncDependenceAnalysis.cpp for DivergenceAnalysis.
154 //
155 // The SyncDependenceAnalysis is used in the UniformityAnalysis to model
156 // control-induced divergence in phi nodes.
157 //
158 // -- Reference --
159 // The algorithm is an extension of Section 5 of
160 //
161 //   An abstract interpretation for SPMD divergence
162 //       on reducible control flow graphs.
163 //   Julian Rosemann, Simon Moll and Sebastian Hack
164 //   POPL '21
165 //
166 //
167 // -- Sync dependence --
168 // Sync dependence characterizes the control flow aspect of the
169 // propagation of branch divergence. For example,
170 //
171 //   %cond = icmp slt i32 %tid, 10
172 //   br i1 %cond, label %then, label %else
173 // then:
174 //   br label %merge
175 // else:
176 //   br label %merge
177 // merge:
178 //   %a = phi i32 [ 0, %then ], [ 1, %else ]
179 //
180 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
181 // because %tid is not on its use-def chains, %a is sync dependent on %tid
182 // because the branch "br i1 %cond" depends on %tid and affects which value %a
183 // is assigned to.
184 //
185 //
186 // -- Reduction to SSA construction --
187 // There are two disjoint paths from A to X, if a certain variant of SSA
188 // construction places a phi node in X under the following set-up scheme.
189 //
190 // This variant of SSA construction ignores incoming undef values.
191 // That is paths from the entry without a definition do not result in
192 // phi nodes.
193 //
194 //       entry
195 //     /      \
196 //    A        \
197 //  /   \       Y
198 // B     C     /
199 //  \   /  \  /
200 //    D     E
201 //     \   /
202 //       F
203 //
204 // Assume that A contains a divergent branch. We are interested
205 // in the set of all blocks where each block is reachable from A
206 // via two disjoint paths. This would be the set {D, F} in this
207 // case.
208 // To generally reduce this query to SSA construction we introduce
209 // a virtual variable x and assign to x different values in each
210 // successor block of A.
211 //
212 //           entry
213 //         /      \
214 //        A        \
215 //      /   \       Y
216 // x = 0   x = 1   /
217 //      \  /   \  /
218 //        D     E
219 //         \   /
220 //           F
221 //
222 // Our flavor of SSA construction for x will construct the following
223 //
224 //            entry
225 //          /      \
226 //         A        \
227 //       /   \       Y
228 // x0 = 0   x1 = 1  /
229 //       \   /   \ /
230 //     x2 = phi   E
231 //         \     /
232 //         x3 = phi
233 //
234 // The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
236 //
237 // -- Remarks --
238 // * In case of cycle exits we need to check for temporal divergence.
239 //   To this end, we check whether the definition of x differs between the
240 //   cycle exit and the cycle header (_after_ SSA construction).
241 //
242 // * In the presence of irreducible control flow, the fixed point is
243 //   reached only after multiple iterations. This is because labels
244 //   reaching the header of a cycle must be repropagated through the
245 //   cycle. This is true even in a reducible cycle, since the labels
246 //   may have been produced by a nested irreducible cycle.
247 //
248 // * Note that SyncDependenceAnalysis is not concerned with the points
//   of convergence in an irreducible cycle. Its only purpose is to
250 //   identify join blocks. The "diverged entry" criterion is
251 //   separately applied on join blocks to determine if an entire
252 //   irreducible cycle is assumed to be divergent.
253 //
254 // * Relevant related work:
255 //     A simple algorithm for global data flow analysis problems.
256 //     Matthew S. Hecht and Jeffrey D. Ullman.
257 //     SIAM Journal on Computing, 4(4):519–532, December 1975.
258 //
template <typename ContextT> class GenericSyncDependenceAnalysis {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;
  using InstructionT = typename ContextT::InstructionT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ConstBlockSet = SmallPtrSet<const BlockT *, 4>;
  using ModifiedPO = ModifiedPostOrder<ContextT>;

  // Reaching-definition labels for the virtual SSA construction, where X
  // is the block containing the divergent branch:
  // * if BlockLabels[B] == C then C is the dominating definition at
  //   block B
  // * if BlockLabels[B] == nullptr then we haven't seen B yet
  // * if BlockLabels[B] == B then:
  //   - B is a join point of disjoint paths from X, or,
  //   - B is an immediate successor of X (initial value), or,
  //   - B is X
  using BlockLabelMap = DenseMap<const BlockT *, const BlockT *>;

  /// Information discovered by the sync dependence analysis for each
  /// divergent branch.
  struct DivergenceDescriptor {
    // Join points of diverged paths.
    ConstBlockSet JoinDivBlocks;
    // Divergent cycle exits
    ConstBlockSet CycleDivBlocks;
    // Labels assigned to blocks on diverged paths.
    BlockLabelMap BlockLabels;
  };

  using DivergencePropagatorT = DivergencePropagator<ContextT>;

  GenericSyncDependenceAnalysis(const ContextT &Context,
                                const DominatorTreeT &DT, const CycleInfoT &CI);

  /// \brief Computes divergent join points and cycle exits caused by branch
  /// divergence in \p Term.
  ///
  /// This returns a pair of sets:
  /// * The set of blocks which are reachable by disjoint paths from
  ///   \p Term.
  /// * The set also contains cycle exits if there are two disjoint paths:
  ///   one from \p Term to the cycle exit and another from \p Term to
  ///   the cycle header.
  const DivergenceDescriptor &getJoinBlocks(const BlockT *DivTermBlock);

private:
  // Shared default-constructed result returned for branches with at most
  // one successor (see getJoinBlocks).
  static DivergenceDescriptor EmptyDivergenceDesc;

  // Modified post-order of the whole function; computed once in the
  // constructor and reused for every queried branch.
  ModifiedPO CyclePO;

  const DominatorTreeT &DT;
  const CycleInfoT &CI;

  // Per-branch memoization of getJoinBlocks results.
  DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
      CachedControlDivDescs;
};
320 
321 /// \brief Analysis that identifies uniform values in a data-parallel
322 /// execution.
323 ///
324 /// This analysis propagates divergence in a data-parallel context
325 /// from sources of divergence to all users. It can be instantiated
326 /// for an IR that provides a suitable SSAContext.
template <typename ContextT> class GenericUniformityAnalysisImpl {
public:
  using BlockT = typename ContextT::BlockT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;
  using ConstValueRefT = typename ContextT::ConstValueRefT;
  using InstructionT = typename ContextT::InstructionT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;

  GenericUniformityAnalysisImpl(const FunctionT &F, const DominatorTreeT &DT,
                                const CycleInfoT &CI,
                                const TargetTransformInfo *TTI)
      : Context(CI.getSSAContext()), F(F), CI(CI), TTI(TTI), DT(DT),
        SDA(Context, DT, CI) {}

  /// \brief Seed the analysis with target/CFG-specific sources of
  /// divergence (implemented per IR specialization).
  void initialize();

  const FunctionT &getFunction() const { return F; }

  /// \brief Mark \p UniVal as a value that is always uniform.
  void addUniformOverride(const InstructionT &Instr);

  /// \brief Mark \p DivVal as a value that is always divergent.
  /// \returns Whether the tracked divergence state of \p DivVal changed.
  bool markDivergent(const InstructionT &I);
  bool markDivergent(ConstValueRefT DivVal);
  bool markDefsDivergent(const InstructionT &Instr,
                         bool AllDefsDivergent = true);

  /// \brief Propagate divergence to all instructions in the region.
  /// Divergence is seeded by calls to \p markDivergent.
  void compute();

  /// \brief Whether any value was marked or analyzed to be divergent.
  bool hasDivergence() const { return !DivergentValues.empty(); }

  /// \brief Whether \p Val will always return a uniform value regardless of its
  /// operands
  bool isAlwaysUniform(const InstructionT &Instr) const;

  bool hasDivergentDefs(const InstructionT &I) const;

  bool isDivergent(const InstructionT &I) const {
    // Terminator divergence is tracked per containing block, not per
    // defined value.
    if (I.isTerminator()) {
      return DivergentTermBlocks.contains(I.getParent());
    }
    return hasDivergentDefs(I);
  };

  /// \brief Whether \p Val is divergent at its definition.
  bool isDivergent(ConstValueRefT V) const { return DivergentValues.count(V); }

  bool hasDivergentTerminator(const BlockT &B) const {
    return DivergentTermBlocks.contains(&B);
  }

  /// \brief Print the analysis state (divergent values/blocks) to \p out.
  void print(raw_ostream &out) const;

protected:
  /// \brief Value/block pair representing a single phi input.
  struct PhiInput {
    ConstValueRefT value;
    BlockT *predBlock;

    PhiInput(ConstValueRefT value, BlockT *predBlock)
        : value(value), predBlock(predBlock) {}
  };

  const ContextT &Context;
  const FunctionT &F;
  const CycleInfoT &CI;
  const TargetTransformInfo *TTI = nullptr;

  // Detected/marked divergent values.
  std::set<ConstValueRefT> DivergentValues;
  // Blocks whose terminator was determined to be divergent.
  SmallPtrSet<const BlockT *, 32> DivergentTermBlocks;

  // Internal worklist for divergence propagation.
  std::vector<const InstructionT *> Worklist;

  /// \brief Mark \p Term as divergent and push all Instructions that become
  /// divergent as a result on the worklist.
  void analyzeControlDivergence(const InstructionT &Term);

private:
  const DominatorTreeT &DT;

  // Recognized cycles with divergent exits.
  SmallPtrSet<const CycleT *, 16> DivergentExitCycles;

  // Cycles assumed to be divergent.
  //
  // We don't use a set here because every insertion needs an explicit
  // traversal of all existing members.
  SmallVector<const CycleT *> AssumedDivergent;

  // The SDA links divergent branches to divergent control-flow joins.
  SyncDependenceAnalysisT SDA;

  // Set of known-uniform values.
  SmallPtrSet<const InstructionT *, 32> UniformOverrides;

  /// \brief Mark all nodes in \p JoinBlock as divergent and push them on
  /// the worklist.
  void taintAndPushAllDefs(const BlockT &JoinBlock);

  /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on
  /// the worklist.
  void taintAndPushPhiNodes(const BlockT &JoinBlock);

  /// \brief Identify all Instructions that become divergent because \p DivExit
  /// is a divergent cycle exit of \p DivCycle. Mark those instructions as
  /// divergent and push them on the worklist.
  void propagateCycleExitDivergence(const BlockT &DivExit,
                                    const CycleT &DivCycle);

  /// \brief Internal implementation function for propagateCycleExitDivergence.
  void analyzeCycleExitDivergence(const CycleT &OuterDivCycle);

  /// \brief Mark all instruction as divergent that use a value defined in \p
  /// OuterDivCycle. Push their users on the worklist.
  void analyzeTemporalDivergence(const InstructionT &I,
                                 const CycleT &OuterDivCycle);

  /// \brief Push all users of \p Val (in the region) to the worklist.
  void pushUsers(const InstructionT &I);
  void pushUsers(ConstValueRefT V);

  bool usesValueFromCycle(const InstructionT &I, const CycleT &DefCycle) const;

  /// \brief Whether \p Val is divergent when read in \p ObservingBlock.
  bool isTemporalDivergent(const BlockT &ObservingBlock,
                           ConstValueRefT Val) const;
};
469 
// Deleter for the analysis implementation, defined out of line so that the
// deletion of ImplT happens where this header (and thus the complete type)
// is available — presumably so GenericUniformityInfo can hold the impl by
// unique_ptr with only a forward declaration; confirm against
// GenericUniformityInfo.h.
template <typename ImplT>
void GenericUniformityAnalysisImplDeleter<ImplT>::operator()(ImplT *Impl) {
  delete Impl;
}
474 
/// Compute divergence starting with a divergent branch.
///
/// One instance is created per divergent terminator block; the result of
/// the propagation is returned as a DivergenceDescriptor by
/// computeJoinPoints().
template <typename ContextT> class DivergencePropagator {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ModifiedPO = ModifiedPostOrder<ContextT>;
  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;

  const ModifiedPO &CyclePOT;
  const DominatorTreeT &DT;
  const CycleInfoT &CI;
  const BlockT &DivTermBlock;
  const ContextT &Context;

  // Track blocks that receive a new label. Every time we relabel a
  // cycle header, we need another pass over the modified post-order in
  // order to propagate the header label. The bit vector also allows
  // us to skip labels that have not changed.
  SparseBitVector<> FreshLabels;

  // divergent join and cycle exit descriptor.
  std::unique_ptr<DivergenceDescriptorT> DivDesc;
  BlockLabelMapT &BlockLabels;

  DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
                       const CycleInfoT &CI, const BlockT &DivTermBlock)
      : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
        Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
        BlockLabels(DivDesc->BlockLabels) {}

  // Debugging helper: dump the current label of every block in reverse
  // modified post-order.
  void printDefs(raw_ostream &Out) {
    Out << "Propagator::BlockLabels {\n";
    for (int BlockIdx = (int)CyclePOT.size() - 1; BlockIdx >= 0; --BlockIdx) {
      const auto *Block = CyclePOT[BlockIdx];
      const auto *Label = BlockLabels[Block];
      Out << Context.print(Block) << "(" << BlockIdx << ") : ";
      if (!Label) {
        Out << "<null>\n";
      } else {
        Out << Context.print(Label) << "\n";
      }
    }
    Out << "}\n";
  }

  // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
  // causes a divergent join.
  bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
    const auto *OldLabel = BlockLabels[&SuccBlock];

    LLVM_DEBUG(dbgs() << "labeling " << Context.print(&SuccBlock) << ":\n"
                      << "\tpushed label: " << Context.print(&PushedLabel)
                      << "\n"
                      << "\told label: " << Context.print(OldLabel) << "\n");

    // Early exit if there is no change in the label.
    if (OldLabel == &PushedLabel)
      return false;

    if (OldLabel != &SuccBlock) {
      auto SuccIdx = CyclePOT.getIndex(&SuccBlock);
      // Assigning a new label, mark this in FreshLabels.
      LLVM_DEBUG(dbgs() << "\tfresh label: " << SuccIdx << "\n");
      FreshLabels.set(SuccIdx);
    }

    // This is not a join if the succ was previously unlabeled.
    if (!OldLabel) {
      LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&PushedLabel)
                        << "\n");
      BlockLabels[&SuccBlock] = &PushedLabel;
      return false;
    }

    // This is a new join. Label the join block as itself, and not as
    // the pushed label.
    LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&SuccBlock) << "\n");
    BlockLabels[&SuccBlock] = &SuccBlock;

    return true;
  }

  // visiting a virtual cycle exit edge from the cycle header --> temporal
  // divergence on join
  bool visitCycleExitEdge(const BlockT &ExitBlock, const BlockT &Label) {
    if (!computeJoin(ExitBlock, Label))
      return false;

    // Identified a divergent cycle exit
    DivDesc->CycleDivBlocks.insert(&ExitBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(&ExitBlock)
                      << "\n");
    return true;
  }

  // process \p SuccBlock with reaching definition \p Label
  bool visitEdge(const BlockT &SuccBlock, const BlockT &Label) {
    if (!computeJoin(SuccBlock, Label))
      return false;

    // Divergent, disjoint paths join.
    DivDesc->JoinDivBlocks.insert(&SuccBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent join: " << Context.print(&SuccBlock)
                      << "\n");
    return true;
  }

  // Fixed-point propagation of reaching labels over the modified
  // post-order, collecting divergent joins and divergent cycle exits
  // into DivDesc (which is moved out to the caller).
  std::unique_ptr<DivergenceDescriptorT> computeJoinPoints() {
    assert(DivDesc);

    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
                      << Context.print(&DivTermBlock) << "\n");

    // Early stopping criterion
    int FloorIdx = CyclePOT.size() - 1;
    const BlockT *FloorLabel = nullptr;
    int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);

    // Bootstrap with branch targets
    auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
    for (const auto *SuccBlock : successors(&DivTermBlock)) {
      if (DivTermCycle && !DivTermCycle->contains(SuccBlock)) {
        // If DivTerm exits the cycle immediately, computeJoin() might
        // not reach SuccBlock with a different label. We need to
        // check for this exit now.
        DivDesc->CycleDivBlocks.insert(SuccBlock);
        LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
                          << Context.print(SuccBlock) << "\n");
      }
      auto SuccIdx = CyclePOT.getIndex(SuccBlock);
      // Each successor initially carries itself as label.
      visitEdge(*SuccBlock, *SuccBlock);
      FloorIdx = std::min<int>(FloorIdx, SuccIdx);
    }

    while (true) {
      auto BlockIdx = FreshLabels.find_last();
      if (BlockIdx == -1 || BlockIdx < FloorIdx)
        break;

      LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));

      FreshLabels.reset(BlockIdx);
      if (BlockIdx == DivTermIdx) {
        LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n");
        continue;
      }

      const auto *Block = CyclePOT[BlockIdx];
      LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
                        << BlockIdx << "\n");

      const auto *Label = BlockLabels[Block];
      assert(Label);

      bool CausedJoin = false;
      int LoweredFloorIdx = FloorIdx;

      // If the current block is the header of a reducible cycle that
      // contains the divergent branch, then the label should be
      // propagated to the cycle exits. Such a header is the "last
      // possible join" of any disjoint paths within this cycle. This
      // prevents detection of spurious joins at the entries of any
      // irreducible child cycles.
      //
      // This conclusion about the header is true for any choice of DFS:
      //
      //   If some DFS has a reducible cycle C with header H, then for
      //   any other DFS, H is the header of a cycle C' that is a
      //   superset of C. For a divergent branch inside the subgraph
      //   C, any join node inside C is either H, or some node
      //   encountered without passing through H.
      //
      auto getReducibleParent = [&](const BlockT *Block) -> const CycleT * {
        if (!CyclePOT.isReducibleCycleHeader(Block))
          return nullptr;
        const auto *BlockCycle = CI.getCycle(Block);
        if (BlockCycle->contains(&DivTermBlock))
          return BlockCycle;
        return nullptr;
      };

      if (const auto *BlockCycle = getReducibleParent(Block)) {
        SmallVector<BlockT *, 4> BlockCycleExits;
        BlockCycle->getExitBlocks(BlockCycleExits);
        for (auto *BlockCycleExit : BlockCycleExits) {
          CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
        }
      } else {
        for (const auto *SuccBlock : successors(Block)) {
          CausedJoin |= visitEdge(*SuccBlock, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
        }
      }

      // Floor update
      if (CausedJoin) {
        // 1. Different labels pushed to successors
        FloorIdx = LoweredFloorIdx;
      } else if (FloorLabel != Label) {
        // 2. No join caused BUT we pushed a label that is different than the
        // last pushed label
        FloorIdx = LoweredFloorIdx;
        FloorLabel = Label;
      }
    }

    LLVM_DEBUG(dbgs() << "Final labeling:\n"; printDefs(dbgs()));

    // Check every cycle containing DivTermBlock for exit divergence.
    // A cycle has exit divergence if the label of an exit block does
    // not match the label of its header.
    for (const auto *Cycle = CI.getCycle(&DivTermBlock); Cycle;
         Cycle = Cycle->getParentCycle()) {
      if (Cycle->isReducible()) {
        // The exit divergence of a reducible cycle is recorded while
        // propagating labels.
        continue;
      }
      SmallVector<BlockT *> Exits;
      Cycle->getExitBlocks(Exits);
      auto *Header = Cycle->getHeader();
      auto *HeaderLabel = BlockLabels[Header];
      for (const auto *Exit : Exits) {
        if (BlockLabels[Exit] != HeaderLabel) {
          // Identified a divergent cycle exit
          DivDesc->CycleDivBlocks.insert(Exit);
          LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(Exit)
                            << "\n");
        }
      }
    }

    return std::move(DivDesc);
  }
};
722 
// Out-of-line definition of the shared, default-constructed descriptor
// returned by getJoinBlocks() for trivially uniform branches.
template <typename ContextT>
typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
    llvm::GenericSyncDependenceAnalysis<ContextT>::EmptyDivergenceDesc;
726 
// The modified post-order depends only on the CFG/cycle structure, so it
// is computed once here and reused by every getJoinBlocks() query.
template <typename ContextT>
llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
    const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
    : CyclePO(Context), DT(DT), CI(CI) {
  CyclePO.compute(CI);
}
733 
// Compute (or fetch from cache) the divergence descriptor for the branch
// terminating \p DivTermBlock. Results are memoized per block; branches
// with at most one successor trivially cause no divergence and share a
// static empty descriptor.
template <typename ContextT>
auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
    const BlockT *DivTermBlock) -> const DivergenceDescriptor & {
  // trivial case: a single-successor branch cannot produce disjoint paths
  if (succ_size(DivTermBlock) <= 1) {
    return EmptyDivergenceDesc;
  }

  // already available in cache?
  auto ItCached = CachedControlDivDescs.find(DivTermBlock);
  if (ItCached != CachedControlDivDescs.end())
    return *ItCached->second;

  // compute all join points
  DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock);
  auto DivDesc = Propagator.computeJoinPoints();

  // Debug-only pretty-printer for a block set (referenced solely inside
  // LLVM_DEBUG; the (void) cast below silences -Wunused in release builds).
  auto printBlockSet = [&](ConstBlockSet &Blocks) {
    return Printable([&](raw_ostream &Out) {
      Out << "[";
      ListSeparator LS;
      for (const auto *BB : Blocks) {
        Out << LS << CI.getSSAContext().print(BB);
      }
      Out << "]\n";
    });
  };

  LLVM_DEBUG(
      dbgs() << "\nResult (" << CI.getSSAContext().print(DivTermBlock)
             << "):\n  JoinDivBlocks: " << printBlockSet(DivDesc->JoinDivBlocks)
             << "  CycleDivBlocks: " << printBlockSet(DivDesc->CycleDivBlocks)
             << "\n");
  (void)printBlockSet;

  auto ItInserted =
      CachedControlDivDescs.try_emplace(DivTermBlock, std::move(DivDesc));
  assert(ItInserted.second);
  return *ItInserted.first->second;
}
774 
775 template <typename ContextT>
markDivergent(const InstructionT & I)776 bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
777     const InstructionT &I) {
778   if (I.isTerminator()) {
779     if (DivergentTermBlocks.insert(I.getParent()).second) {
780       LLVM_DEBUG(dbgs() << "marked divergent term block: "
781                         << Context.print(I.getParent()) << "\n");
782       return true;
783     }
784     return false;
785   }
786 
787   return markDefsDivergent(I);
788 }
789 
790 template <typename ContextT>
markDivergent(ConstValueRefT Val)791 bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
792     ConstValueRefT Val) {
793   if (DivergentValues.insert(Val).second) {
794     LLVM_DEBUG(dbgs() << "marked divergent: " << Context.print(Val) << "\n");
795     return true;
796   }
797   return false;
798 }
799 
/// Declare \p Instr as always uniform regardless of its operands (used by
/// CFG- or target-specific code, e.g. for intrinsics known to produce a
/// uniform value).
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride(
    const InstructionT &Instr) {
  UniformOverrides.insert(&Instr);
}
805 
/// Mark \p I divergent if it uses a value defined inside \p OuterDivCycle.
///
/// Such a use observes temporal divergence: threads may leave the divergent
/// cycle in different iterations, so a cycle-defined value can differ
/// across threads outside the cycle.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::analyzeTemporalDivergence(
    const InstructionT &I, const CycleT &OuterDivCycle) {
  // Already divergent; nothing new to learn or propagate.
  if (isDivergent(I))
    return;

  LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << Context.print(&I)
                    << "\n");
  // Only users of values defined inside the cycle are affected.
  if (!usesValueFromCycle(I, OuterDivCycle))
    return;

  // Respect explicit uniformity overrides.
  if (isAlwaysUniform(I))
    return;

  if (markDivergent(I))
    Worklist.push_back(&I);
}
823 
824 // Mark all external users of values defined inside \param
825 // OuterDivCycle as divergent.
826 //
827 // This follows all live out edges wherever they may lead. Potential
828 // users of values defined inside DivCycle could be anywhere in the
829 // dominance region of DivCycle (including its fringes for phi nodes).
830 // A cycle C dominates a block B iff every path from the entry block
831 // to B must pass through a block contained in C. If C is a reducible
832 // cycle (or natural loop), C dominates B iff the header of C
// dominates B. But in general, we iteratively examine cycle
834 // exits and their successors.
835 template <typename ContextT>
analyzeCycleExitDivergence(const CycleT & OuterDivCycle)836 void GenericUniformityAnalysisImpl<ContextT>::analyzeCycleExitDivergence(
837     const CycleT &OuterDivCycle) {
838   // Set of blocks that are dominated by the cycle, i.e., each is only
839   // reachable from paths that pass through the cycle.
840   SmallPtrSet<BlockT *, 16> DomRegion;
841 
842   // The boundary of DomRegion, formed by blocks that are not
843   // dominated by the cycle.
844   SmallVector<BlockT *> DomFrontier;
845   OuterDivCycle.getExitBlocks(DomFrontier);
846 
847   // Returns true if BB is dominated by the cycle.
848   auto isInDomRegion = [&](BlockT *BB) {
849     for (auto *P : predecessors(BB)) {
850       if (OuterDivCycle.contains(P))
851         continue;
852       if (DomRegion.count(P))
853         continue;
854       return false;
855     }
856     return true;
857   };
858 
859   // Keep advancing the frontier along successor edges, while
860   // promoting blocks to DomRegion.
861   while (true) {
862     bool Promoted = false;
863     SmallVector<BlockT *> Temp;
864     for (auto *W : DomFrontier) {
865       if (!isInDomRegion(W)) {
866         Temp.push_back(W);
867         continue;
868       }
869       DomRegion.insert(W);
870       Promoted = true;
871       for (auto *Succ : successors(W)) {
872         if (DomRegion.contains(Succ))
873           continue;
874         Temp.push_back(Succ);
875       }
876     }
877     if (!Promoted)
878       break;
879     DomFrontier = Temp;
880   }
881 
882   // At DomFrontier, only the PHI nodes are affected by temporal
883   // divergence.
884   for (const auto *UserBlock : DomFrontier) {
885     LLVM_DEBUG(dbgs() << "Analyze phis after cycle exit: "
886                       << Context.print(UserBlock) << "\n");
887     for (const auto &Phi : UserBlock->phis()) {
888       LLVM_DEBUG(dbgs() << "  " << Context.print(&Phi) << "\n");
889       analyzeTemporalDivergence(Phi, OuterDivCycle);
890     }
891   }
892 
893   // All instructions inside the dominance region are affected by
894   // temporal divergence.
895   for (const auto *UserBlock : DomRegion) {
896     LLVM_DEBUG(dbgs() << "Analyze non-phi users after cycle exit: "
897                       << Context.print(UserBlock) << "\n");
898     for (const auto &I : *UserBlock) {
899       LLVM_DEBUG(dbgs() << "  " << Context.print(&I) << "\n");
900       analyzeTemporalDivergence(I, OuterDivCycle);
901     }
902   }
903 }
904 
/// Propagate divergence caused by a divergent exit of \p InnerDivCycle at
/// exit block \p DivExit.
///
/// Walks up the cycle hierarchy to the outer-most cycle that still does not
/// contain \p DivExit, records it as having divergent exits, and analyzes
/// the temporal divergence this induces.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::propagateCycleExitDivergence(
    const BlockT &DivExit, const CycleT &InnerDivCycle) {
  LLVM_DEBUG(dbgs() << "\tpropCycleExitDiv " << Context.print(&DivExit)
                    << "\n");
  auto *DivCycle = &InnerDivCycle;
  auto *OuterDivCycle = DivCycle;
  auto *ExitLevelCycle = CI.getCycle(&DivExit);
  // Depth 0 means the exit block lies outside any cycle.
  const unsigned CycleExitDepth =
      ExitLevelCycle ? ExitLevelCycle->getDepth() : 0;

  // Find outer-most cycle that does not contain \p DivExit
  while (DivCycle && DivCycle->getDepth() > CycleExitDepth) {
    LLVM_DEBUG(dbgs() << "  Found exiting cycle: "
                      << Context.print(DivCycle->getHeader()) << "\n");
    OuterDivCycle = DivCycle;
    DivCycle = DivCycle->getParentCycle();
  }
  LLVM_DEBUG(dbgs() << "\tOuter-most exiting cycle: "
                    << Context.print(OuterDivCycle->getHeader()) << "\n");

  // Already processed earlier; the exit analysis is idempotent, skip it.
  if (!DivergentExitCycles.insert(OuterDivCycle).second)
    return;

  // Exit divergence does not matter if the cycle itself is assumed to
  // be divergent.
  for (const auto *C : AssumedDivergent) {
    if (C->contains(OuterDivCycle))
      return;
  }

  analyzeCycleExitDivergence(*OuterDivCycle);
}
938 
939 template <typename ContextT>
taintAndPushAllDefs(const BlockT & BB)940 void GenericUniformityAnalysisImpl<ContextT>::taintAndPushAllDefs(
941     const BlockT &BB) {
942   LLVM_DEBUG(dbgs() << "taintAndPushAllDefs " << Context.print(&BB) << "\n");
943   for (const auto &I : instrs(BB)) {
944     // Terminators do not produce values; they are divergent only if
945     // the condition is divergent. That is handled when the divergent
946     // condition is placed in the worklist.
947     if (I.isTerminator())
948       break;
949 
950     // Mark this as divergent. We don't check if the instruction is
951     // always uniform. In a cycle where the thread convergence is not
952     // statically known, the instruction is not statically converged,
953     // and its outputs cannot be statically uniform.
954     if (markDivergent(I))
955       Worklist.push_back(&I);
956   }
957 }
958 
959 /// Mark divergent phi nodes in a join block
960 template <typename ContextT>
taintAndPushPhiNodes(const BlockT & JoinBlock)961 void GenericUniformityAnalysisImpl<ContextT>::taintAndPushPhiNodes(
962     const BlockT &JoinBlock) {
963   LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << Context.print(&JoinBlock)
964                     << "\n");
965   for (const auto &Phi : JoinBlock.phis()) {
966     if (ContextT::isConstantValuePhi(Phi))
967       continue;
968     if (markDivergent(Phi))
969       Worklist.push_back(&Phi);
970   }
971 }
972 
973 /// Add \p Candidate to \p Cycles if it is not already contained in \p Cycles.
974 ///
975 /// \return true iff \p Candidate was added to \p Cycles.
976 template <typename CycleT>
insertIfNotContained(SmallVector<CycleT * > & Cycles,CycleT * Candidate)977 static bool insertIfNotContained(SmallVector<CycleT *> &Cycles,
978                                  CycleT *Candidate) {
979   if (llvm::any_of(Cycles,
980                    [Candidate](CycleT *C) { return C->contains(Candidate); }))
981     return false;
982   Cycles.push_back(Candidate);
983   return true;
984 }
985 
/// Return the outermost cycle made divergent by branch outside it.
///
/// If two paths that diverged outside an irreducible cycle join
/// inside that cycle, then that whole cycle is assumed to be
/// divergent. This does not apply if the cycle is reducible.
///
/// \returns nullptr if \p Cycle contains the branch or is reducible,
/// otherwise the outermost (irreducible) ancestor cycle of \p Cycle that
/// still excludes \p DivTermBlock.
template <typename CycleT, typename BlockT>
static const CycleT *getExtDivCycle(const CycleT *Cycle,
                                    const BlockT *DivTermBlock,
                                    const BlockT *JoinBlock) {
  assert(Cycle);
  assert(Cycle->contains(JoinBlock));

  // The branch is inside the cycle; this is not an external branch.
  if (Cycle->contains(DivTermBlock))
    return nullptr;

  if (Cycle->isReducible()) {
    assert(Cycle->getHeader() == JoinBlock);
    return nullptr;
  }

  const auto *Parent = Cycle->getParentCycle();
  while (Parent && !Parent->contains(DivTermBlock)) {
    // If the join is inside a child, then the parent must be
    // irreducible. The only join in a reducible cycle is its own
    // header.
    assert(!Parent->isReducible());
    Cycle = Parent;
    Parent = Cycle->getParentCycle();
  }

  LLVM_DEBUG(dbgs() << "cycle made divergent by external branch\n");
  return Cycle;
}
1019 
1020 /// Return the outermost cycle made divergent by branch inside it.
1021 ///
1022 /// This checks the "diverged entry" criterion defined in the
1023 /// docs/ConvergenceAnalysis.html.
1024 template <typename ContextT, typename CycleT, typename BlockT,
1025           typename DominatorTreeT>
1026 static const CycleT *
getIntDivCycle(const CycleT * Cycle,const BlockT * DivTermBlock,const BlockT * JoinBlock,const DominatorTreeT & DT,ContextT & Context)1027 getIntDivCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
1028                const BlockT *JoinBlock, const DominatorTreeT &DT,
1029                ContextT &Context) {
1030   LLVM_DEBUG(dbgs() << "examine join " << Context.print(JoinBlock)
1031                     << "for internal branch " << Context.print(DivTermBlock)
1032                     << "\n");
1033   if (DT.properlyDominates(DivTermBlock, JoinBlock))
1034     return nullptr;
1035 
1036   // Find the smallest common cycle, if one exists.
1037   assert(Cycle && Cycle->contains(JoinBlock));
1038   while (Cycle && !Cycle->contains(DivTermBlock)) {
1039     Cycle = Cycle->getParentCycle();
1040   }
1041   if (!Cycle || Cycle->isReducible())
1042     return nullptr;
1043 
1044   if (DT.properlyDominates(Cycle->getHeader(), JoinBlock))
1045     return nullptr;
1046 
1047   LLVM_DEBUG(dbgs() << "  header " << Context.print(Cycle->getHeader())
1048                     << " does not dominate join\n");
1049 
1050   const auto *Parent = Cycle->getParentCycle();
1051   while (Parent && !DT.properlyDominates(Parent->getHeader(), JoinBlock)) {
1052     LLVM_DEBUG(dbgs() << "  header " << Context.print(Parent->getHeader())
1053                       << " does not dominate join\n");
1054     Cycle = Parent;
1055     Parent = Parent->getParentCycle();
1056   }
1057 
1058   LLVM_DEBUG(dbgs() << "  cycle made divergent by internal branch\n");
1059   return Cycle;
1060 }
1061 
/// Return the outermost cycle made divergent by the branch in
/// \p DivTermBlock joining at \p JoinBlock, considering both the
/// external-branch and the internal-branch (diverged entry) criteria.
/// The internal-branch result takes precedence when both apply.
template <typename ContextT, typename CycleT, typename BlockT,
          typename DominatorTreeT>
static const CycleT *
getOutermostDivergentCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
                           const BlockT *JoinBlock, const DominatorTreeT &DT,
                           ContextT &Context) {
  if (!Cycle)
    return nullptr;

  // Largest cycle containing JoinBlock but not DivTermBlock.
  const auto *Ext = getExtDivCycle(Cycle, DivTermBlock, JoinBlock);
  // Largest cycle containing both blocks; preferred when it exists.
  const auto *Int = getIntDivCycle(Cycle, DivTermBlock, JoinBlock, DT, Context);

  return Int ? Int : Ext;
}
1082 
/// Propagate the effect of the divergent terminator \p Term (sync
/// dependence): taint phi nodes at disjoint-path join blocks, taint entire
/// cycles made divergent by the branch, and propagate divergence across
/// divergent cycle exits.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::analyzeControlDivergence(
    const InstructionT &Term) {
  const auto *DivTermBlock = Term.getParent();
  DivergentTermBlocks.insert(DivTermBlock);
  LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Context.print(DivTermBlock)
                    << "\n");

  // Don't propagate divergence from unreachable blocks.
  if (!DT.isReachableFromEntry(DivTermBlock))
    return;

  const auto &DivDesc = SDA.getJoinBlocks(DivTermBlock);
  SmallVector<const CycleT *> DivCycles;

  // Iterate over all blocks now reachable by a disjoint path join
  for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
    const auto *Cycle = CI.getCycle(JoinBlock);
    LLVM_DEBUG(dbgs() << "visiting join block " << Context.print(JoinBlock)
                      << "\n");
    if (const auto *Outermost = getOutermostDivergentCycle(
            Cycle, DivTermBlock, JoinBlock, DT, Context)) {
      LLVM_DEBUG(dbgs() << "found divergent cycle\n");
      DivCycles.push_back(Outermost);
      continue;
    }
    // Ordinary join: only its phi nodes can observe the divergence.
    taintAndPushPhiNodes(*JoinBlock);
  }

  // Sort by order of decreasing depth. This allows later cycles to be skipped
  // because they are already contained in earlier ones.
  llvm::sort(DivCycles, [](const CycleT *A, const CycleT *B) {
    return A->getDepth() > B->getDepth();
  });

  // Cycles that are assumed divergent due to the diverged entry
  // criterion potentially contain temporal divergence depending on
  // the DFS chosen. Conservatively, all values produced in such a
  // cycle are assumed divergent. "Cycle invariant" values may be
  // assumed uniform, but that requires further analysis.
  for (auto *C : DivCycles) {
    if (!insertIfNotContained(AssumedDivergent, C))
      continue;
    LLVM_DEBUG(dbgs() << "process divergent cycle\n");
    for (const BlockT *BB : C->blocks()) {
      taintAndPushAllDefs(*BB);
    }
  }

  // CycleDivBlocks are exits of the cycle containing the branch, so that
  // cycle must exist whenever the list is non-empty.
  const auto *BranchCycle = CI.getCycle(DivTermBlock);
  assert(DivDesc.CycleDivBlocks.empty() || BranchCycle);
  for (const auto *DivExitBlock : DivDesc.CycleDivBlocks) {
    propagateCycleExitDivergence(*DivExitBlock, *BranchCycle);
  }
}
1138 
1139 template <typename ContextT>
compute()1140 void GenericUniformityAnalysisImpl<ContextT>::compute() {
1141   // Initialize worklist.
1142   auto DivValuesCopy = DivergentValues;
1143   for (const auto DivVal : DivValuesCopy) {
1144     assert(isDivergent(DivVal) && "Worklist invariant violated!");
1145     pushUsers(DivVal);
1146   }
1147 
1148   // All values on the Worklist are divergent.
1149   // Their users may not have been updated yet.
1150   while (!Worklist.empty()) {
1151     const InstructionT *I = Worklist.back();
1152     Worklist.pop_back();
1153 
1154     LLVM_DEBUG(dbgs() << "worklist pop: " << Context.print(I) << "\n");
1155 
1156     if (I->isTerminator()) {
1157       analyzeControlDivergence(*I);
1158       continue;
1159     }
1160 
1161     // propagate value divergence to users
1162     assert(isDivergent(*I) && "Worklist invariant violated!");
1163     pushUsers(*I);
1164   }
1165 }
1166 
/// Whether \p Instr was explicitly declared uniform via addUniformOverride().
template <typename ContextT>
bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
    const InstructionT &Instr) const {
  return UniformOverrides.contains(&Instr);
}
1172 
/// Construct the uniformity info for \p Func and eagerly run the analysis:
/// the implementation object is created, initialized with the sources of
/// divergence, and the fixed-point propagation is computed immediately.
template <typename ContextT>
GenericUniformityInfo<ContextT>::GenericUniformityInfo(
    FunctionT &Func, const DominatorTreeT &DT, const CycleInfoT &CI,
    const TargetTransformInfo *TTI)
    : F(&Func) {
  DA.reset(new ImplT{Func, DT, CI, TTI});
  DA->initialize();
  DA->compute();
}
1182 
/// Print a human-readable summary of the analysis: divergent arguments,
/// assumed-divergent cycles, cycles with divergent exits, and the
/// uniformity of every definition and terminator per block.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
  bool haveDivergentArgs = false;

  if (DivergentValues.empty()) {
    // No divergent values implies no divergent control flow either.
    assert(DivergentTermBlocks.empty());
    assert(DivergentExitCycles.empty());
    OS << "ALL VALUES UNIFORM\n";
    return;
  }

  // Divergent values with no defining block are function arguments.
  for (const auto &entry : DivergentValues) {
    const BlockT *parent = Context.getDefBlock(entry);
    if (!parent) {
      if (!haveDivergentArgs) {
        OS << "DIVERGENT ARGUMENTS:\n";
        haveDivergentArgs = true;
      }
      OS << "  DIVERGENT: " << Context.print(entry) << '\n';
    }
  }

  if (!AssumedDivergent.empty()) {
    // NOTE(review): "ASSSUMED" is a typo in the emitted string; preserved
    // as-is since external checks may match this exact output.
    OS << "CYCLES ASSSUMED DIVERGENT:\n";
    for (const CycleT *cycle : AssumedDivergent) {
      OS << "  " << cycle->print(Context) << '\n';
    }
  }

  if (!DivergentExitCycles.empty()) {
    OS << "CYCLES WITH DIVERGENT EXIT:\n";
    for (const CycleT *cycle : DivergentExitCycles) {
      OS << "  " << cycle->print(Context) << '\n';
    }
  }

  for (auto &block : F) {
    OS << "\nBLOCK " << Context.print(&block) << '\n';

    OS << "DEFINITIONS\n";
    SmallVector<ConstValueRefT, 16> defs;
    Context.appendBlockDefs(defs, block);
    for (auto value : defs) {
      if (isDivergent(value))
        OS << "  DIVERGENT: ";
      else
        OS << "             ";
      OS << Context.print(value) << '\n';
    }

    OS << "TERMINATORS\n";
    SmallVector<const InstructionT *, 8> terms;
    Context.appendBlockTerms(terms, block);
    // All terminators of a block share the block's terminator divergence.
    bool divergentTerminators = hasDivergentTerminator(block);
    for (auto *T : terms) {
      if (divergentTerminators)
        OS << "  DIVERGENT: ";
      else
        OS << "             ";
      OS << Context.print(T) << '\n';
    }

    OS << "END BLOCK\n";
  }
}
1248 
/// Whether the analysis found any divergence at all in the function.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::hasDivergence() const {
  return DA->hasDivergence();
}
1253 
/// Whether \p V is divergent at its definition.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::isDivergent(ConstValueRefT V) const {
  return DA->isDivergent(V);
}
1259 
/// Whether block \p B ends in a divergent terminator.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::hasDivergentTerminator(const BlockT &B) {
  return DA->hasDivergentTerminator(B);
}
1264 
/// \brief Helper function for printing; forwards to the implementation.
template <typename ContextT>
void GenericUniformityInfo<ContextT>::print(raw_ostream &out) const {
  DA->print(out);
}
1270 
/// Iterative DFS computing a post order of the region inside \p Cycle
/// (or of the whole function when \p Cycle is null).
///
/// Nested cycles are treated as single nodes: when the DFS reaches a block
/// of a nested cycle, it first finalizes all of that cycle's exits and then
/// recurses into the cycle via computeCyclePO. Blocks in \p Finalized have
/// already been emitted and are skipped.
template <typename ContextT>
void llvm::ModifiedPostOrder<ContextT>::computeStackPO(
    SmallVectorImpl<BlockT *> &Stack, const CycleInfoT &CI, const CycleT *Cycle,
    SmallPtrSetImpl<BlockT *> &Finalized) {
  LLVM_DEBUG(dbgs() << "inside computeStackPO\n");
  while (!Stack.empty()) {
    auto *NextBB = Stack.back();
    // Already emitted in a previous visit; drop it.
    if (Finalized.count(NextBB)) {
      Stack.pop_back();
      continue;
    }
    LLVM_DEBUG(dbgs() << "  visiting " << CI.getSSAContext().print(NextBB)
                      << "\n");
    auto *NestedCycle = CI.getCycle(NextBB);
    if (Cycle != NestedCycle && (!Cycle || Cycle->contains(NestedCycle))) {
      LLVM_DEBUG(dbgs() << "  found a cycle\n");
      // Normalize to the nested cycle that is an immediate child of Cycle.
      while (NestedCycle->getParentCycle() != Cycle)
        NestedCycle = NestedCycle->getParentCycle();

      SmallVector<BlockT *, 3> NestedExits;
      NestedCycle->getExitBlocks(NestedExits);
      bool PushedNodes = false;
      for (auto *NestedExitBB : NestedExits) {
        LLVM_DEBUG(dbgs() << "  examine exit: "
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
        // Stay within the region of the current cycle, if any.
        if (Cycle && !Cycle->contains(NestedExitBB))
          continue;
        if (Finalized.count(NestedExitBB))
          continue;
        PushedNodes = true;
        Stack.push_back(NestedExitBB);
        LLVM_DEBUG(dbgs() << "  pushed exit: "
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
      }
      if (!PushedNodes) {
        // All loop exits finalized -> finish this node
        Stack.pop_back();
        computeCyclePO(CI, NestedCycle, Finalized);
      }
      continue;
    }

    LLVM_DEBUG(dbgs() << "  no nested cycle, going into DAG\n");
    // DAG-style
    bool PushedNodes = false;
    for (auto *SuccBB : successors(NextBB)) {
      LLVM_DEBUG(dbgs() << "  examine succ: "
                        << CI.getSSAContext().print(SuccBB) << "\n");
      if (Cycle && !Cycle->contains(SuccBB))
        continue;
      if (Finalized.count(SuccBB))
        continue;
      PushedNodes = true;
      Stack.push_back(SuccBB);
      LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(SuccBB)
                        << "\n");
    }
    if (!PushedNodes) {
      // Never push nodes twice
      LLVM_DEBUG(dbgs() << "  finishing node: "
                        << CI.getSSAContext().print(NextBB) << "\n");
      Stack.pop_back();
      Finalized.insert(NextBB);
      appendBlock(*NextBB);
    }
  }
  LLVM_DEBUG(dbgs() << "exited computeStackPO\n");
}
1339 
/// Emit the blocks of \p Cycle into the modified post order.
///
/// The header is finalized and emitted first (it appears last when the
/// order is read in reverse), then the rest of the cycle's region is
/// processed by computeStackPO starting from the header's successors.
template <typename ContextT>
void ModifiedPostOrder<ContextT>::computeCyclePO(
    const CycleInfoT &CI, const CycleT *Cycle,
    SmallPtrSetImpl<BlockT *> &Finalized) {
  LLVM_DEBUG(dbgs() << "inside computeCyclePO\n");
  SmallVector<BlockT *> Stack;
  auto *CycleHeader = Cycle->getHeader();

  LLVM_DEBUG(dbgs() << "  noted header: "
                    << CI.getSSAContext().print(CycleHeader) << "\n");
  assert(!Finalized.count(CycleHeader));
  Finalized.insert(CycleHeader);

  // Visit the header last
  LLVM_DEBUG(dbgs() << "  finishing header: "
                    << CI.getSSAContext().print(CycleHeader) << "\n");
  appendBlock(*CycleHeader, Cycle->isReducible());

  // Initialize with immediate successors
  for (auto *BB : successors(CycleHeader)) {
    LLVM_DEBUG(dbgs() << "  examine succ: " << CI.getSSAContext().print(BB)
                      << "\n");
    // Only traverse the region inside this cycle; skip back edges to the
    // header and anything already finalized.
    if (!Cycle->contains(BB))
      continue;
    if (BB == CycleHeader)
      continue;
    if (!Finalized.count(BB)) {
      LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(BB)
                        << "\n");
      Stack.push_back(BB);
    }
  }

  // Compute PO inside region
  computeStackPO(Stack, CI, Cycle, Finalized);

  LLVM_DEBUG(dbgs() << "exited computeCyclePO\n");
}
1378 
/// \brief Generically compute the modified post order.
///
/// Starts the cycle-aware DFS at the function's entry block; cycles are
/// handled as units by computeStackPO/computeCyclePO.
template <typename ContextT>
void llvm::ModifiedPostOrder<ContextT>::compute(const CycleInfoT &CI) {
  SmallPtrSet<BlockT *, 32> Finalized;
  SmallVector<BlockT *> Stack;
  auto *F = CI.getFunction();
  Stack.reserve(24); // FIXME made-up number
  Stack.push_back(GraphTraits<FunctionT *>::getEntryNode(F));
  computeStackPO(Stack, CI, nullptr, Finalized);
}
1389 
1390 } // namespace llvm
1391 
1392 #undef DEBUG_TYPE
1393 
1394 #endif // LLVM_ADT_GENERICUNIFORMITYIMPL_H
1395