xref: /openbsd-src/gnu/llvm/llvm/include/llvm/Analysis/DivergenceAnalysis.h (revision d415bd752c734aee168c4ee86ff32e8cc249eb16)
1 //===- llvm/Analysis/DivergenceAnalysis.h - Divergence Analysis -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // The divergence analysis determines which instructions and branches are
11 // divergent given a set of divergent source instructions.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_ANALYSIS_DIVERGENCEANALYSIS_H
16 #define LLVM_ANALYSIS_DIVERGENCEANALYSIS_H
17 
18 #include "llvm/ADT/DenseSet.h"
19 #include "llvm/Analysis/SyncDependenceAnalysis.h"
20 #include "llvm/IR/PassManager.h"
21 #include <vector>
22 
23 namespace llvm {
24 class Function;
25 class Instruction;
26 class Loop;
27 class raw_ostream;
28 class TargetTransformInfo;
29 class Value;
30 
31 /// \brief Generic divergence analysis for reducible CFGs.
32 ///
33 /// This analysis propagates divergence in a data-parallel context from sources
34 /// of divergence to all users. It requires reducible CFGs. All assignments
35 /// should be in SSA form.
36 class DivergenceAnalysisImpl {
37 public:
38   /// \brief This instance will analyze the whole function \p F or the loop \p
39   /// RegionLoop.
40   ///
41   /// \param RegionLoop if non-null the analysis is restricted to \p RegionLoop.
42   /// Otherwise the whole function is analyzed.
43   /// \param IsLCSSAForm whether the analysis may assume that the IR in the
44   /// region in LCSSA form.
45   DivergenceAnalysisImpl(const Function &F, const Loop *RegionLoop,
46                          const DominatorTree &DT, const LoopInfo &LI,
47                          SyncDependenceAnalysis &SDA, bool IsLCSSAForm);
48 
49   /// \brief The loop that defines the analyzed region (if any).
getRegionLoop()50   const Loop *getRegionLoop() const { return RegionLoop; }
getFunction()51   const Function &getFunction() const { return F; }
52 
53   /// \brief Whether \p BB is part of the region.
54   bool inRegion(const BasicBlock &BB) const;
55   /// \brief Whether \p I is part of the region.
56   bool inRegion(const Instruction &I) const;
57 
58   /// \brief Mark \p UniVal as a value that is always uniform.
59   void addUniformOverride(const Value &UniVal);
60 
61   /// \brief Mark \p DivVal as a value that is always divergent. Will not do so
62   /// if `isAlwaysUniform(DivVal)`.
63   /// \returns Whether the tracked divergence state of \p DivVal changed.
64   bool markDivergent(const Value &DivVal);
65 
66   /// \brief Propagate divergence to all instructions in the region.
67   /// Divergence is seeded by calls to \p markDivergent.
68   void compute();
69 
70   /// \brief Whether any value was marked or analyzed to be divergent.
hasDetectedDivergence()71   bool hasDetectedDivergence() const { return !DivergentValues.empty(); }
72 
73   /// \brief Whether \p Val will always return a uniform value regardless of its
74   /// operands
75   bool isAlwaysUniform(const Value &Val) const;
76 
77   /// \brief Whether \p Val is divergent at its definition.
78   bool isDivergent(const Value &Val) const;
79 
80   /// \brief Whether \p U is divergent. Uses of a uniform value can be
81   /// divergent.
82   bool isDivergentUse(const Use &U) const;
83 
84 private:
85   /// \brief Mark \p Term as divergent and push all Instructions that become
86   /// divergent as a result on the worklist.
87   void analyzeControlDivergence(const Instruction &Term);
88   /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on
89   /// the worklist.
90   void taintAndPushPhiNodes(const BasicBlock &JoinBlock);
91 
92   /// \brief Identify all Instructions that become divergent because \p DivExit
93   /// is a divergent loop exit of \p DivLoop. Mark those instructions as
94   /// divergent and push them on the worklist.
95   void propagateLoopExitDivergence(const BasicBlock &DivExit,
96                                    const Loop &DivLoop);
97 
98   /// \brief Internal implementation function for propagateLoopExitDivergence.
99   void analyzeLoopExitDivergence(const BasicBlock &DivExit,
100                                  const Loop &OuterDivLoop);
101 
102   /// \brief Mark all instruction as divergent that use a value defined in \p
103   /// OuterDivLoop. Push their users on the worklist.
104   void analyzeTemporalDivergence(const Instruction &I,
105                                  const Loop &OuterDivLoop);
106 
107   /// \brief Push all users of \p Val (in the region) to the worklist.
108   void pushUsers(const Value &I);
109 
110   /// \brief Whether \p Val is divergent when read in \p ObservingBlock.
111   bool isTemporalDivergent(const BasicBlock &ObservingBlock,
112                            const Value &Val) const;
113 
114 private:
115   const Function &F;
116   // If regionLoop != nullptr, analysis is only performed within \p RegionLoop.
117   // Otherwise, analyze the whole function
118   const Loop *RegionLoop;
119 
120   const DominatorTree &DT;
121   const LoopInfo &LI;
122 
123   // Recognized divergent loops
124   DenseSet<const Loop *> DivergentLoops;
125 
126   // The SDA links divergent branches to divergent control-flow joins.
127   SyncDependenceAnalysis &SDA;
128 
129   // Use simplified code path for LCSSA form.
130   bool IsLCSSAForm;
131 
132   // Set of known-uniform values.
133   DenseSet<const Value *> UniformOverrides;
134 
135   // Detected/marked divergent values.
136   DenseSet<const Value *> DivergentValues;
137 
138   // Internal worklist for divergence propagation.
139   std::vector<const Instruction *> Worklist;
140 };
141 
142 class DivergenceInfo {
143   Function &F;
144 
145   // If the function contains an irreducible region the divergence
146   // analysis can run indefinitely. We set ContainsIrreducible and no
147   // analysis is actually performed on the function. All values in
148   // this function are conservatively reported as divergent instead.
149   bool ContainsIrreducible = false;
150   std::unique_ptr<SyncDependenceAnalysis> SDA;
151   std::unique_ptr<DivergenceAnalysisImpl> DA;
152 
153 public:
154   DivergenceInfo(Function &F, const DominatorTree &DT,
155                  const PostDominatorTree &PDT, const LoopInfo &LI,
156                  const TargetTransformInfo &TTI, bool KnownReducible);
157 
158   /// Whether any divergence was detected.
hasDivergence()159   bool hasDivergence() const {
160     return ContainsIrreducible || DA->hasDetectedDivergence();
161   }
162 
163   /// The GPU kernel this analysis result is for
getFunction()164   const Function &getFunction() const { return F; }
165 
166   /// Whether \p V is divergent at its definition.
isDivergent(const Value & V)167   bool isDivergent(const Value &V) const {
168     return ContainsIrreducible || DA->isDivergent(V);
169   }
170 
171   /// Whether \p U is divergent. Uses of a uniform value can be divergent.
isDivergentUse(const Use & U)172   bool isDivergentUse(const Use &U) const {
173     return ContainsIrreducible || DA->isDivergentUse(U);
174   }
175 
176   /// Whether \p V is uniform/non-divergent.
isUniform(const Value & V)177   bool isUniform(const Value &V) const { return !isDivergent(V); }
178 
179   /// Whether \p U is uniform/non-divergent. Uses of a uniform value can be
180   /// divergent.
isUniformUse(const Use & U)181   bool isUniformUse(const Use &U) const { return !isDivergentUse(U); }
182 };
183 
184 /// \brief Divergence analysis frontend for GPU kernels.
185 class DivergenceAnalysis : public AnalysisInfoMixin<DivergenceAnalysis> {
186   friend AnalysisInfoMixin<DivergenceAnalysis>;
187 
188   static AnalysisKey Key;
189 
190 public:
191   using Result = DivergenceInfo;
192 
193   /// Runs the divergence analysis on @F, a GPU kernel
194   Result run(Function &F, FunctionAnalysisManager &AM);
195 };
196 
197 /// Printer pass to dump divergence analysis results.
198 struct DivergenceAnalysisPrinterPass
199     : public PassInfoMixin<DivergenceAnalysisPrinterPass> {
DivergenceAnalysisPrinterPassDivergenceAnalysisPrinterPass200   DivergenceAnalysisPrinterPass(raw_ostream &OS) : OS(OS) {}
201 
202   PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM);
203 
204 private:
205   raw_ostream &OS;
206 }; // class DivergenceAnalysisPrinterPass
207 
208 } // namespace llvm
209 
210 #endif // LLVM_ANALYSIS_DIVERGENCEANALYSIS_H
211