xref: /llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp (revision 236fda550d36d35a00785938c3e38b0f402aeda6)
1 //===- UniformityAnalysis.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/UniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/Analysis/CycleAnalysis.h"
12 #include "llvm/Analysis/TargetTransformInfo.h"
13 #include "llvm/IR/Dominators.h"
14 #include "llvm/IR/InstIterator.h"
15 #include "llvm/IR/Instructions.h"
16 #include "llvm/InitializePasses.h"
17 
18 using namespace llvm;
19 
20 template <>
21 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
22     const Instruction &I) const {
23   return isDivergent((const Value *)&I);
24 }
25 
26 template <>
27 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
28     const Instruction &Instr) {
29   return markDivergent(cast<Value>(&Instr));
30 }
31 
32 template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
33   for (auto &I : instructions(F)) {
34     if (TTI->isSourceOfDivergence(&I))
35       markDivergent(I);
36     else if (TTI->isAlwaysUniform(&I))
37       addUniformOverride(I);
38   }
39   for (auto &Arg : F.args()) {
40     if (TTI->isSourceOfDivergence(&Arg)) {
41       markDivergent(&Arg);
42     }
43   }
44 }
45 
46 template <>
47 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
48     const Value *V) {
49   for (const auto *User : V->users()) {
50     if (const auto *UserInstr = dyn_cast<const Instruction>(User)) {
51       markDivergent(*UserInstr);
52     }
53   }
54 }
55 
56 template <>
57 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
58     const Instruction &Instr) {
59   assert(!isAlwaysUniform(Instr));
60   if (Instr.isTerminator())
61     return;
62   pushUsers(cast<Value>(&Instr));
63 }
64 
65 template <>
66 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
67     const Instruction &I, const Cycle &DefCycle) const {
68   assert(!isAlwaysUniform(I));
69   for (const Use &U : I.operands()) {
70     if (auto *I = dyn_cast<Instruction>(&U)) {
71       if (DefCycle.contains(I->getParent()))
72         return true;
73     }
74   }
75   return false;
76 }
77 
78 template <>
79 void llvm::GenericUniformityAnalysisImpl<
80     SSAContext>::propagateTemporalDivergence(const Instruction &I,
81                                              const Cycle &DefCycle) {
82   if (isDivergent(I))
83     return;
84   for (auto *User : I.users()) {
85     auto *UserInstr = cast<Instruction>(User);
86     if (DefCycle.contains(UserInstr->getParent()))
87       continue;
88     markDivergent(*UserInstr);
89   }
90 }
91 
92 template <>
93 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
94     const Use &U) const {
95   const auto *V = U.get();
96   if (isDivergent(V))
97     return true;
98   if (const auto *DefInstr = dyn_cast<Instruction>(V)) {
99     const auto *UseInstr = cast<Instruction>(U.getUser());
100     return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
101   }
102   return false;
103 }
104 
// Explicit instantiations for the IR (SSAContext) flavour of the generic
// uniformity analysis. The second one ensures that
// GenericUniformityAnalysisImplDeleter<GenericUniformityAnalysisImpl<
// SSAContext>>::operator() is explicitly instantiated in this translation
// unit.
template class llvm::GenericUniformityInfo<SSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
    llvm::GenericUniformityAnalysisImpl<SSAContext>>;
110 
111 //===----------------------------------------------------------------------===//
112 //  UniformityInfoAnalysis and related pass implementations
113 //===----------------------------------------------------------------------===//
114 
115 llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
116                                                  FunctionAnalysisManager &FAM) {
117   auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
118   auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
119   auto &CI = FAM.getResult<CycleAnalysis>(F);
120   UniformityInfo UI{DT, CI, &TTI};
121   // Skip computation if we can assume everything is uniform.
122   if (TTI.hasBranchDivergence(&F))
123     UI.compute();
124 
125   return UI;
126 }
127 
// Storage for the analysis' static AnalysisKey, which identifies
// UniformityInfoAnalysis to the new pass manager's analysis manager.
AnalysisKey UniformityInfoAnalysis::Key;
129 
130 UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
131     : OS(OS) {}
132 
133 PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
134                                                  FunctionAnalysisManager &AM) {
135   OS << "UniformityInfo for function '" << F.getName() << "':\n";
136   AM.getResult<UniformityInfoAnalysis>(F).print(OS);
137 
138   return PreservedAnalyses::all();
139 }
140 
141 //===----------------------------------------------------------------------===//
142 //  UniformityInfoWrapperPass Implementation
143 //===----------------------------------------------------------------------===//
144 
// Legacy pass manager identifier; passed to FunctionPass in the constructor.
char UniformityInfoWrapperPass::ID = 0;
146 
147 UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {
148   initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry());
149 }
150 
// Legacy pass registration under the command-line name "uniformity", together
// with the analyses it depends on. The two trailing `true` arguments are the
// CFG-only and is-analysis flags.
INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
                      "Uniformity Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
                    "Uniformity Analysis", true, true)
158 
/// Declares the analyses this pass needs. The pass is analysis-only, so it
/// preserves everything.
void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<DominatorTreeWrapperPass>();
  // Required transitively, presumably because the UniformityInfo built in
  // runOnFunction keeps referring to the cycle info after this pass has run
  // (queries on the result) -- NOTE(review): confirm against
  // GenericUniformityInfo.
  AU.addRequiredTransitive<CycleInfoWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();
}
165 
166 bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
167   auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
168   auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
169   auto &targetTransformInfo =
170       getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
171 
172   m_function = &F;
173   m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
174 
175   // Skip computation if we can assume everything is uniform.
176   if (targetTransformInfo.hasBranchDivergence(m_function))
177     m_uniformityInfo.compute();
178 
179   return false;
180 }
181 
182 void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
183   OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
184 }
185 
186 void UniformityInfoWrapperPass::releaseMemory() {
187   m_uniformityInfo = UniformityInfo{};
188   m_function = nullptr;
189 }
190