xref: /llvm-project/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp (revision 475ce4c200ca640f1d6ccd097b167a04f009cb18)
1 //===- MachineUniformityAnalysis.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/CodeGen/MachineUniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/CodeGen/MachineCycleAnalysis.h"
12 #include "llvm/CodeGen/MachineDominators.h"
13 #include "llvm/CodeGen/MachineRegisterInfo.h"
14 #include "llvm/CodeGen/MachineSSAContext.h"
15 #include "llvm/CodeGen/TargetInstrInfo.h"
16 #include "llvm/InitializePasses.h"
17 
18 using namespace llvm;
19 
20 template <>
21 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
22     const MachineInstr &I) const {
23   for (auto &op : I.operands()) {
24     if (!op.isReg() || !op.isDef())
25       continue;
26     if (isDivergent(op.getReg()))
27       return true;
28   }
29   return false;
30 }
31 
32 template <>
33 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
34     const MachineInstr &Instr, bool AllDefsDivergent) {
35   bool insertedDivergent = false;
36   const auto &MRI = F.getRegInfo();
37   const auto &TRI = *MRI.getTargetRegisterInfo();
38   for (auto &op : Instr.operands()) {
39     if (!op.isReg() || !op.isDef())
40       continue;
41     if (!op.getReg().isVirtual())
42       continue;
43     assert(!op.getSubReg());
44     if (!AllDefsDivergent) {
45       auto *RC = MRI.getRegClassOrNull(op.getReg());
46       if (RC && !TRI.isDivergentRegClass(RC))
47         continue;
48     }
49     insertedDivergent |= markDivergent(op.getReg());
50   }
51   return insertedDivergent;
52 }
53 
54 template <>
55 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
56   const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
57 
58   for (const MachineBasicBlock &block : F) {
59     for (const MachineInstr &instr : block) {
60       auto uniformity = InstrInfo.getInstructionUniformity(instr);
61       if (uniformity == InstructionUniformity::AlwaysUniform) {
62         addUniformOverride(instr);
63         continue;
64       }
65 
66       if (uniformity == InstructionUniformity::NeverUniform) {
67         markDefsDivergent(instr, /* AllDefsDivergent = */ false);
68       }
69     }
70   }
71 }
72 
73 template <>
74 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
75     Register Reg) {
76   const auto &RegInfo = F.getRegInfo();
77   for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
78     if (isAlwaysUniform(UserInstr))
79       continue;
80     if (markDivergent(UserInstr))
81       Worklist.push_back(&UserInstr);
82   }
83 }
84 
85 template <>
86 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
87     const MachineInstr &Instr) {
88   assert(!isAlwaysUniform(Instr));
89   if (Instr.isTerminator())
90     return;
91   for (const MachineOperand &op : Instr.operands()) {
92     if (op.isReg() && op.isDef() && op.getReg().isVirtual())
93       pushUsers(op.getReg());
94   }
95 }
96 
97 template <>
98 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
99     const MachineInstr &I, const MachineCycle &DefCycle) const {
100   assert(!isAlwaysUniform(I));
101   for (auto &Op : I.operands()) {
102     if (!Op.isReg() || !Op.readsReg())
103       continue;
104     auto Reg = Op.getReg();
105     assert(Reg.isVirtual());
106     auto *Def = F.getRegInfo().getVRegDef(Reg);
107     if (DefCycle.contains(Def->getParent()))
108       return true;
109   }
110   return false;
111 }
112 
113 // This ensures explicit instantiation of
114 // GenericUniformityAnalysisImpl::ImplDeleter::operator()
115 template class llvm::GenericUniformityInfo<MachineSSAContext>;
116 
117 MachineUniformityInfo
118 llvm::computeMachineUniformityInfo(MachineFunction &F,
119                                    const MachineCycleInfo &cycleInfo,
120                                    const MachineDomTree &domTree) {
121   auto &MRI = F.getRegInfo();
122   assert(MRI.isSSA() && "Expected to be run on SSA form!");
123   return MachineUniformityInfo(F, domTree, cycleInfo);
124 }
125 
126 namespace {
127 
128 /// Legacy analysis pass which computes a \ref MachineUniformityInfo.
129 class MachineUniformityAnalysisPass : public MachineFunctionPass {
130   MachineUniformityInfo UI;
131 
132 public:
133   static char ID;
134 
135   MachineUniformityAnalysisPass();
136 
137   MachineUniformityInfo &getUniformityInfo() { return UI; }
138   const MachineUniformityInfo &getUniformityInfo() const { return UI; }
139 
140   bool runOnMachineFunction(MachineFunction &F) override;
141   void getAnalysisUsage(AnalysisUsage &AU) const override;
142   void print(raw_ostream &OS, const Module *M = nullptr) const override;
143 
144   // TODO: verify analysis
145 };
146 
147 class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
148 public:
149   static char ID;
150 
151   MachineUniformityInfoPrinterPass();
152 
153   bool runOnMachineFunction(MachineFunction &F) override;
154   void getAnalysisUsage(AnalysisUsage &AU) const override;
155 };
156 
157 } // namespace
158 
159 char MachineUniformityAnalysisPass::ID = 0;
160 
161 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
162     : MachineFunctionPass(ID) {
163   initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry());
164 }
165 
166 INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
167                       "Machine Uniformity Info Analysis", true, true)
168 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
169 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
170 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
171                     "Machine Uniformity Info Analysis", true, true)
172 
173 void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
174   AU.setPreservesAll();
175   AU.addRequired<MachineCycleInfoWrapperPass>();
176   AU.addRequired<MachineDominatorTree>();
177   MachineFunctionPass::getAnalysisUsage(AU);
178 }
179 
180 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
181   auto &DomTree = getAnalysis<MachineDominatorTree>().getBase();
182   auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
183   UI = computeMachineUniformityInfo(MF, CI, DomTree);
184   return false;
185 }
186 
187 void MachineUniformityAnalysisPass::print(raw_ostream &OS,
188                                           const Module *) const {
189   OS << "MachineUniformityInfo for function: " << UI.getFunction().getName()
190      << "\n";
191   UI.print(OS);
192 }
193 
194 char MachineUniformityInfoPrinterPass::ID = 0;
195 
196 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
197     : MachineFunctionPass(ID) {
198   initializeMachineUniformityInfoPrinterPassPass(
199       *PassRegistry::getPassRegistry());
200 }
201 
202 INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
203                       "print-machine-uniformity",
204                       "Print Machine Uniformity Info Analysis", true, true)
205 INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
206 INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
207                     "print-machine-uniformity",
208                     "Print Machine Uniformity Info Analysis", true, true)
209 
210 void MachineUniformityInfoPrinterPass::getAnalysisUsage(
211     AnalysisUsage &AU) const {
212   AU.setPreservesAll();
213   AU.addRequired<MachineUniformityAnalysisPass>();
214   MachineFunctionPass::getAnalysisUsage(AU);
215 }
216 
217 bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
218     MachineFunction &F) {
219   auto &UI = getAnalysis<MachineUniformityAnalysisPass>();
220   UI.print(errs());
221   return false;
222 }
223