xref: /llvm-project/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp (revision b0f0dd2554c726e5192ad8c98fb7a2f08c37994c)
1 //===- MachineUniformityAnalysis.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/CodeGen/MachineUniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/CodeGen/MachineCycleAnalysis.h"
12 #include "llvm/CodeGen/MachineDominators.h"
13 #include "llvm/CodeGen/MachineRegisterInfo.h"
14 #include "llvm/CodeGen/MachineSSAContext.h"
15 #include "llvm/CodeGen/TargetInstrInfo.h"
16 #include "llvm/InitializePasses.h"
17 
18 using namespace llvm;
19 
20 template <>
21 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
22     const MachineInstr &I) const {
23   for (auto &op : I.operands()) {
24     if (!op.isReg() || !op.isDef())
25       continue;
26     if (isDivergent(op.getReg()))
27       return true;
28   }
29   return false;
30 }
31 
32 template <>
33 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
34     const MachineInstr &Instr, bool AllDefsDivergent) {
35   bool insertedDivergent = false;
36   const auto &MRI = F.getRegInfo();
37   const auto &TRI = *MRI.getTargetRegisterInfo();
38   for (auto &op : Instr.operands()) {
39     if (!op.isReg() || !op.isDef())
40       continue;
41     if (!op.getReg().isVirtual())
42       continue;
43     assert(!op.getSubReg());
44     if (!AllDefsDivergent) {
45       auto *RC = MRI.getRegClassOrNull(op.getReg());
46       if (RC && !TRI.isDivergentRegClass(RC))
47         continue;
48     }
49     insertedDivergent |= markDivergent(op.getReg());
50   }
51   return insertedDivergent;
52 }
53 
54 template <>
55 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
56   const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
57 
58   for (const MachineBasicBlock &block : F) {
59     for (const MachineInstr &instr : block) {
60       auto uniformity = InstrInfo.getInstructionUniformity(instr);
61       if (uniformity == InstructionUniformity::AlwaysUniform) {
62         addUniformOverride(instr);
63         continue;
64       }
65 
66       if (uniformity == InstructionUniformity::NeverUniform) {
67         markDefsDivergent(instr, /* AllDefsDivergent = */ false);
68       }
69     }
70   }
71 }
72 
73 template <>
74 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
75     Register Reg) {
76   const auto &RegInfo = F.getRegInfo();
77   for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
78     if (markDivergent(UserInstr))
79       Worklist.push_back(&UserInstr);
80   }
81 }
82 
83 template <>
84 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
85     const MachineInstr &Instr) {
86   assert(!isAlwaysUniform(Instr));
87   if (Instr.isTerminator())
88     return;
89   for (const MachineOperand &op : Instr.operands()) {
90     if (op.isReg() && op.isDef() && op.getReg().isVirtual())
91       pushUsers(op.getReg());
92   }
93 }
94 
95 template <>
96 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
97     const MachineInstr &I, const MachineCycle &DefCycle) const {
98   assert(!isAlwaysUniform(I));
99   for (auto &Op : I.operands()) {
100     if (!Op.isReg() || !Op.readsReg())
101       continue;
102     auto Reg = Op.getReg();
103 
104     // FIXME: Physical registers need to be properly checked instead of always
105     // returning true
106     if (Reg.isPhysical())
107       return true;
108 
109     auto *Def = F.getRegInfo().getVRegDef(Reg);
110     if (DefCycle.contains(Def->getParent()))
111       return true;
112   }
113   return false;
114 }
115 
116 template <>
117 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
118     propagateTemporalDivergence(const MachineInstr &I,
119                                 const MachineCycle &DefCycle) {
120   const auto &RegInfo = F.getRegInfo();
121   for (auto &Op : I.operands()) {
122     if (!Op.isReg() || !Op.isDef())
123       continue;
124     if (!Op.getReg().isVirtual())
125       continue;
126     auto Reg = Op.getReg();
127     if (isDivergent(Reg))
128       continue;
129     for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
130       if (DefCycle.contains(UserInstr.getParent()))
131         continue;
132       if (markDivergent(UserInstr))
133         Worklist.push_back(&UserInstr);
134     }
135   }
136 }
137 
138 template <>
139 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
140     const MachineOperand &U) const {
141   if (!U.isReg())
142     return false;
143 
144   auto Reg = U.getReg();
145   if (isDivergent(Reg))
146     return true;
147 
148   const auto &RegInfo = F.getRegInfo();
149   auto *Def = RegInfo.getOneDef(Reg);
150   if (!Def)
151     return true;
152 
153   auto *DefInstr = Def->getParent();
154   auto *UseInstr = U.getParent();
155   return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
156 }
157 
158 // This ensures explicit instantiation of
159 // GenericUniformityAnalysisImpl::ImplDeleter::operator()
160 template class llvm::GenericUniformityInfo<MachineSSAContext>;
161 template struct llvm::GenericUniformityAnalysisImplDeleter<
162     llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>;
163 
164 MachineUniformityInfo
165 llvm::computeMachineUniformityInfo(MachineFunction &F,
166                                    const MachineCycleInfo &cycleInfo,
167                                    const MachineDomTree &domTree) {
168   assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
169   return MachineUniformityInfo(F, domTree, cycleInfo);
170 }
171 
172 namespace {
173 
174 /// Legacy analysis pass which computes a \ref MachineUniformityInfo.
175 class MachineUniformityAnalysisPass : public MachineFunctionPass {
176   MachineUniformityInfo UI;
177 
178 public:
179   static char ID;
180 
181   MachineUniformityAnalysisPass();
182 
183   MachineUniformityInfo &getUniformityInfo() { return UI; }
184   const MachineUniformityInfo &getUniformityInfo() const { return UI; }
185 
186   bool runOnMachineFunction(MachineFunction &F) override;
187   void getAnalysisUsage(AnalysisUsage &AU) const override;
188   void print(raw_ostream &OS, const Module *M = nullptr) const override;
189 
190   // TODO: verify analysis
191 };
192 
193 class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
194 public:
195   static char ID;
196 
197   MachineUniformityInfoPrinterPass();
198 
199   bool runOnMachineFunction(MachineFunction &F) override;
200   void getAnalysisUsage(AnalysisUsage &AU) const override;
201 };
202 
203 } // namespace
204 
205 char MachineUniformityAnalysisPass::ID = 0;
206 
207 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
208     : MachineFunctionPass(ID) {
209   initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry());
210 }
211 
212 INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
213                       "Machine Uniformity Info Analysis", true, true)
214 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
215 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
216 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
217                     "Machine Uniformity Info Analysis", true, true)
218 
219 void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
220   AU.setPreservesAll();
221   AU.addRequired<MachineCycleInfoWrapperPass>();
222   AU.addRequired<MachineDominatorTree>();
223   MachineFunctionPass::getAnalysisUsage(AU);
224 }
225 
226 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
227   auto &DomTree = getAnalysis<MachineDominatorTree>().getBase();
228   auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
229   UI = computeMachineUniformityInfo(MF, CI, DomTree);
230   return false;
231 }
232 
233 void MachineUniformityAnalysisPass::print(raw_ostream &OS,
234                                           const Module *) const {
235   OS << "MachineUniformityInfo for function: " << UI.getFunction().getName()
236      << "\n";
237   UI.print(OS);
238 }
239 
240 char MachineUniformityInfoPrinterPass::ID = 0;
241 
242 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
243     : MachineFunctionPass(ID) {
244   initializeMachineUniformityInfoPrinterPassPass(
245       *PassRegistry::getPassRegistry());
246 }
247 
248 INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
249                       "print-machine-uniformity",
250                       "Print Machine Uniformity Info Analysis", true, true)
251 INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
252 INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
253                     "print-machine-uniformity",
254                     "Print Machine Uniformity Info Analysis", true, true)
255 
256 void MachineUniformityInfoPrinterPass::getAnalysisUsage(
257     AnalysisUsage &AU) const {
258   AU.setPreservesAll();
259   AU.addRequired<MachineUniformityAnalysisPass>();
260   MachineFunctionPass::getAnalysisUsage(AU);
261 }
262 
263 bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
264     MachineFunction &F) {
265   auto &UI = getAnalysis<MachineUniformityAnalysisPass>();
266   UI.print(errs());
267   return false;
268 }
269