1 //===- MachineUniformityAnalysis.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/CodeGen/MachineUniformityAnalysis.h" 10 #include "llvm/ADT/GenericUniformityImpl.h" 11 #include "llvm/CodeGen/MachineCycleAnalysis.h" 12 #include "llvm/CodeGen/MachineDominators.h" 13 #include "llvm/CodeGen/MachineRegisterInfo.h" 14 #include "llvm/CodeGen/MachineSSAContext.h" 15 #include "llvm/CodeGen/TargetInstrInfo.h" 16 #include "llvm/InitializePasses.h" 17 18 using namespace llvm; 19 20 template <> 21 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs( 22 const MachineInstr &I) const { 23 for (auto &op : I.all_defs()) { 24 if (isDivergent(op.getReg())) 25 return true; 26 } 27 return false; 28 } 29 30 template <> 31 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent( 32 const MachineInstr &Instr) { 33 bool insertedDivergent = false; 34 const auto &MRI = F.getRegInfo(); 35 const auto &RBI = *F.getSubtarget().getRegBankInfo(); 36 const auto &TRI = *MRI.getTargetRegisterInfo(); 37 for (auto &op : Instr.all_defs()) { 38 if (!op.getReg().isVirtual()) 39 continue; 40 assert(!op.getSubReg()); 41 if (TRI.isUniformReg(MRI, RBI, op.getReg())) 42 continue; 43 insertedDivergent |= markDivergent(op.getReg()); 44 } 45 return insertedDivergent; 46 } 47 48 template <> 49 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() { 50 const auto &InstrInfo = *F.getSubtarget().getInstrInfo(); 51 52 for (const MachineBasicBlock &block : F) { 53 for (const MachineInstr &instr : block) { 54 auto uniformity = InstrInfo.getInstructionUniformity(instr); 55 if (uniformity == InstructionUniformity::AlwaysUniform) { 56 addUniformOverride(instr); 57 continue; 58 } 59 60 if (uniformity == InstructionUniformity::NeverUniform) { 61 markDivergent(instr); 62 } 63 } 64 } 65 } 66 67 template <> 68 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( 69 Register Reg) { 70 assert(isDivergent(Reg)); 71 const auto &RegInfo = F.getRegInfo(); 72 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { 73 markDivergent(UserInstr); 74 } 75 } 76 77 template <> 78 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( 79 const MachineInstr &Instr) { 80 assert(!isAlwaysUniform(Instr)); 81 if (Instr.isTerminator()) 82 return; 83 for (const MachineOperand &op : Instr.all_defs()) { 84 auto Reg = op.getReg(); 85 if (isDivergent(Reg)) 86 pushUsers(Reg); 87 } 88 } 89 90 template <> 91 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle( 92 const MachineInstr &I, const MachineCycle &DefCycle) const { 93 assert(!isAlwaysUniform(I)); 94 for (auto &Op : I.operands()) { 95 if (!Op.isReg() || !Op.readsReg()) 96 continue; 97 auto Reg = Op.getReg(); 98 99 // FIXME: Physical registers need to be properly checked instead of always 100 // returning true 101 if (Reg.isPhysical()) 102 return true; 103 104 auto *Def = F.getRegInfo().getVRegDef(Reg); 105 if (DefCycle.contains(Def->getParent())) 106 return true; 107 } 108 return false; 109 } 110 111 template <> 112 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>:: 113 propagateTemporalDivergence(const MachineInstr &I, 114 const MachineCycle &DefCycle) { 115 const auto &RegInfo = F.getRegInfo(); 116 for (auto &Op : I.all_defs()) { 117 if (!Op.getReg().isVirtual()) 118 continue; 119 auto Reg = Op.getReg(); 120 if (isDivergent(Reg)) 121 continue; 122 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { 123 if (DefCycle.contains(UserInstr.getParent())) 124 continue; 125 markDivergent(UserInstr); 126 } 127 } 128 } 129 130 template <> 131 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse( 132 const MachineOperand &U) const { 133 if (!U.isReg()) 134 return false; 135 136 auto Reg = U.getReg(); 137 if (isDivergent(Reg)) 138 return true; 139 140 const auto &RegInfo = F.getRegInfo(); 141 auto *Def = RegInfo.getOneDef(Reg); 142 if (!Def) 143 return true; 144 145 auto *DefInstr = Def->getParent(); 146 auto *UseInstr = U.getParent(); 147 return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); 148 } 149 150 // This ensures explicit instantiation of 151 // GenericUniformityAnalysisImpl::ImplDeleter::operator() 152 template class llvm::GenericUniformityInfo<MachineSSAContext>; 153 template struct llvm::GenericUniformityAnalysisImplDeleter< 154 llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>; 155 156 MachineUniformityInfo 157 llvm::computeMachineUniformityInfo(MachineFunction &F, 158 const MachineCycleInfo &cycleInfo, 159 const MachineDomTree &domTree) { 160 assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!"); 161 return MachineUniformityInfo(F, domTree, cycleInfo); 162 } 163 164 namespace { 165 166 /// Legacy analysis pass which computes a \ref MachineUniformityInfo. 167 class MachineUniformityAnalysisPass : public MachineFunctionPass { 168 MachineUniformityInfo UI; 169 170 public: 171 static char ID; 172 173 MachineUniformityAnalysisPass(); 174 175 MachineUniformityInfo &getUniformityInfo() { return UI; } 176 const MachineUniformityInfo &getUniformityInfo() const { return UI; } 177 178 bool runOnMachineFunction(MachineFunction &F) override; 179 void getAnalysisUsage(AnalysisUsage &AU) const override; 180 void print(raw_ostream &OS, const Module *M = nullptr) const override; 181 182 // TODO: verify analysis 183 }; 184 185 class MachineUniformityInfoPrinterPass : public MachineFunctionPass { 186 public: 187 static char ID; 188 189 MachineUniformityInfoPrinterPass(); 190 191 bool runOnMachineFunction(MachineFunction &F) override; 192 void getAnalysisUsage(AnalysisUsage &AU) const override; 193 }; 194 195 } // namespace 196 197 char MachineUniformityAnalysisPass::ID = 0; 198 199 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass() 200 : MachineFunctionPass(ID) { 201 initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry()); 202 } 203 204 INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity", 205 "Machine Uniformity Info Analysis", true, true) 206 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass) 207 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree) 208 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity", 209 "Machine Uniformity Info Analysis", true, true) 210 211 void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const { 212 AU.setPreservesAll(); 213 AU.addRequired<MachineCycleInfoWrapperPass>(); 214 AU.addRequired<MachineDominatorTree>(); 215 MachineFunctionPass::getAnalysisUsage(AU); 216 } 217 218 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) { 219 auto &DomTree = getAnalysis<MachineDominatorTree>().getBase(); 220 auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo(); 221 UI = computeMachineUniformityInfo(MF, CI, DomTree); 222 return false; 223 } 224 225 void MachineUniformityAnalysisPass::print(raw_ostream &OS, 226 const Module *) const { 227 OS << "MachineUniformityInfo for function: " << UI.getFunction().getName() 228 << "\n"; 229 UI.print(OS); 230 } 231 232 char MachineUniformityInfoPrinterPass::ID = 0; 233 234 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass() 235 : MachineFunctionPass(ID) { 236 initializeMachineUniformityInfoPrinterPassPass( 237 *PassRegistry::getPassRegistry()); 238 } 239 240 INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass, 241 "print-machine-uniformity", 242 "Print Machine Uniformity Info Analysis", true, true) 243 INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass) 244 INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass, 245 "print-machine-uniformity", 246 "Print Machine Uniformity Info Analysis", true, true) 247 248 void MachineUniformityInfoPrinterPass::getAnalysisUsage( 249 AnalysisUsage &AU) const { 250 AU.setPreservesAll(); 251 AU.addRequired<MachineUniformityAnalysisPass>(); 252 MachineFunctionPass::getAnalysisUsage(AU); 253 } 254 255 bool MachineUniformityInfoPrinterPass::runOnMachineFunction( 256 MachineFunction &F) { 257 auto &UI = getAnalysis<MachineUniformityAnalysisPass>(); 258 UI.print(errs()); 259 return false; 260 } 261