xref: /llvm-project/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp (revision fbe1c0616fa83d39ebad29cfefa020bbebd90057)
1 //===- MachineUniformityAnalysis.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/CodeGen/MachineUniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/CodeGen/MachineCycleAnalysis.h"
12 #include "llvm/CodeGen/MachineDominators.h"
13 #include "llvm/CodeGen/MachineRegisterInfo.h"
14 #include "llvm/CodeGen/MachineSSAContext.h"
15 #include "llvm/CodeGen/TargetInstrInfo.h"
16 #include "llvm/InitializePasses.h"
17 
18 using namespace llvm;
19 
20 template <>
21 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
22     const MachineInstr &I) const {
23   for (auto &op : I.operands()) {
24     if (!op.isReg() || !op.isDef())
25       continue;
26     if (isDivergent(op.getReg()))
27       return true;
28   }
29   return false;
30 }
31 
32 template <>
33 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
34     const MachineInstr &Instr) {
35   bool insertedDivergent = false;
36   const auto &MRI = F.getRegInfo();
37   const auto &RBI = *F.getSubtarget().getRegBankInfo();
38   const auto &TRI = *MRI.getTargetRegisterInfo();
39   for (auto &op : Instr.operands()) {
40     if (!op.isReg() || !op.isDef())
41       continue;
42     if (!op.getReg().isVirtual())
43       continue;
44     assert(!op.getSubReg());
45     if (TRI.isUniformReg(MRI, RBI, op.getReg()))
46       continue;
47     insertedDivergent |= markDivergent(op.getReg());
48   }
49   return insertedDivergent;
50 }
51 
52 template <>
53 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
54   const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
55 
56   for (const MachineBasicBlock &block : F) {
57     for (const MachineInstr &instr : block) {
58       auto uniformity = InstrInfo.getInstructionUniformity(instr);
59       if (uniformity == InstructionUniformity::AlwaysUniform) {
60         addUniformOverride(instr);
61         continue;
62       }
63 
64       if (uniformity == InstructionUniformity::NeverUniform) {
65         if (markDivergent(instr))
66           Worklist.push_back(&instr);
67       }
68     }
69   }
70 }
71 
72 template <>
73 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
74     Register Reg) {
75   const auto &RegInfo = F.getRegInfo();
76   for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
77     if (markDivergent(UserInstr))
78       Worklist.push_back(&UserInstr);
79   }
80 }
81 
82 template <>
83 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
84     const MachineInstr &Instr) {
85   assert(!isAlwaysUniform(Instr));
86   if (Instr.isTerminator())
87     return;
88   for (const MachineOperand &op : Instr.operands()) {
89     if (op.isReg() && op.isDef() && op.getReg().isVirtual())
90       pushUsers(op.getReg());
91   }
92 }
93 
94 template <>
95 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
96     const MachineInstr &I, const MachineCycle &DefCycle) const {
97   assert(!isAlwaysUniform(I));
98   for (auto &Op : I.operands()) {
99     if (!Op.isReg() || !Op.readsReg())
100       continue;
101     auto Reg = Op.getReg();
102 
103     // FIXME: Physical registers need to be properly checked instead of always
104     // returning true
105     if (Reg.isPhysical())
106       return true;
107 
108     auto *Def = F.getRegInfo().getVRegDef(Reg);
109     if (DefCycle.contains(Def->getParent()))
110       return true;
111   }
112   return false;
113 }
114 
115 template <>
116 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
117     propagateTemporalDivergence(const MachineInstr &I,
118                                 const MachineCycle &DefCycle) {
119   const auto &RegInfo = F.getRegInfo();
120   for (auto &Op : I.operands()) {
121     if (!Op.isReg() || !Op.isDef())
122       continue;
123     if (!Op.getReg().isVirtual())
124       continue;
125     auto Reg = Op.getReg();
126     if (isDivergent(Reg))
127       continue;
128     for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
129       if (DefCycle.contains(UserInstr.getParent()))
130         continue;
131       if (markDivergent(UserInstr))
132         Worklist.push_back(&UserInstr);
133     }
134   }
135 }
136 
137 template <>
138 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
139     const MachineOperand &U) const {
140   if (!U.isReg())
141     return false;
142 
143   auto Reg = U.getReg();
144   if (isDivergent(Reg))
145     return true;
146 
147   const auto &RegInfo = F.getRegInfo();
148   auto *Def = RegInfo.getOneDef(Reg);
149   if (!Def)
150     return true;
151 
152   auto *DefInstr = Def->getParent();
153   auto *UseInstr = U.getParent();
154   return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
155 }
156 
157 // This ensures explicit instantiation of
158 // GenericUniformityAnalysisImpl::ImplDeleter::operator()
159 template class llvm::GenericUniformityInfo<MachineSSAContext>;
160 template struct llvm::GenericUniformityAnalysisImplDeleter<
161     llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>;
162 
163 MachineUniformityInfo
164 llvm::computeMachineUniformityInfo(MachineFunction &F,
165                                    const MachineCycleInfo &cycleInfo,
166                                    const MachineDomTree &domTree) {
167   assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
168   return MachineUniformityInfo(F, domTree, cycleInfo);
169 }
170 
171 namespace {
172 
173 /// Legacy analysis pass which computes a \ref MachineUniformityInfo.
174 class MachineUniformityAnalysisPass : public MachineFunctionPass {
175   MachineUniformityInfo UI;
176 
177 public:
178   static char ID;
179 
180   MachineUniformityAnalysisPass();
181 
182   MachineUniformityInfo &getUniformityInfo() { return UI; }
183   const MachineUniformityInfo &getUniformityInfo() const { return UI; }
184 
185   bool runOnMachineFunction(MachineFunction &F) override;
186   void getAnalysisUsage(AnalysisUsage &AU) const override;
187   void print(raw_ostream &OS, const Module *M = nullptr) const override;
188 
189   // TODO: verify analysis
190 };
191 
192 class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
193 public:
194   static char ID;
195 
196   MachineUniformityInfoPrinterPass();
197 
198   bool runOnMachineFunction(MachineFunction &F) override;
199   void getAnalysisUsage(AnalysisUsage &AU) const override;
200 };
201 
202 } // namespace
203 
204 char MachineUniformityAnalysisPass::ID = 0;
205 
206 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
207     : MachineFunctionPass(ID) {
208   initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry());
209 }
210 
211 INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
212                       "Machine Uniformity Info Analysis", true, true)
213 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
214 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
215 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
216                     "Machine Uniformity Info Analysis", true, true)
217 
218 void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
219   AU.setPreservesAll();
220   AU.addRequired<MachineCycleInfoWrapperPass>();
221   AU.addRequired<MachineDominatorTree>();
222   MachineFunctionPass::getAnalysisUsage(AU);
223 }
224 
225 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
226   auto &DomTree = getAnalysis<MachineDominatorTree>().getBase();
227   auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
228   UI = computeMachineUniformityInfo(MF, CI, DomTree);
229   return false;
230 }
231 
232 void MachineUniformityAnalysisPass::print(raw_ostream &OS,
233                                           const Module *) const {
234   OS << "MachineUniformityInfo for function: " << UI.getFunction().getName()
235      << "\n";
236   UI.print(OS);
237 }
238 
239 char MachineUniformityInfoPrinterPass::ID = 0;
240 
241 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
242     : MachineFunctionPass(ID) {
243   initializeMachineUniformityInfoPrinterPassPass(
244       *PassRegistry::getPassRegistry());
245 }
246 
247 INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
248                       "print-machine-uniformity",
249                       "Print Machine Uniformity Info Analysis", true, true)
250 INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
251 INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
252                     "print-machine-uniformity",
253                     "Print Machine Uniformity Info Analysis", true, true)
254 
255 void MachineUniformityInfoPrinterPass::getAnalysisUsage(
256     AnalysisUsage &AU) const {
257   AU.setPreservesAll();
258   AU.addRequired<MachineUniformityAnalysisPass>();
259   MachineFunctionPass::getAnalysisUsage(AU);
260 }
261 
262 bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
263     MachineFunction &F) {
264   auto &UI = getAnalysis<MachineUniformityAnalysisPass>();
265   UI.print(errs());
266   return false;
267 }
268