1 //===- MachineUniformityAnalysis.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/CodeGen/MachineUniformityAnalysis.h" 10 #include "llvm/ADT/GenericUniformityImpl.h" 11 #include "llvm/CodeGen/MachineCycleAnalysis.h" 12 #include "llvm/CodeGen/MachineDominators.h" 13 #include "llvm/CodeGen/MachineRegisterInfo.h" 14 #include "llvm/CodeGen/MachineSSAContext.h" 15 #include "llvm/CodeGen/TargetInstrInfo.h" 16 #include "llvm/InitializePasses.h" 17 18 using namespace llvm; 19 20 template <> 21 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs( 22 const MachineInstr &I) const { 23 for (auto &op : I.operands()) { 24 if (!op.isReg() || !op.isDef()) 25 continue; 26 if (isDivergent(op.getReg())) 27 return true; 28 } 29 return false; 30 } 31 32 template <> 33 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent( 34 const MachineInstr &Instr) { 35 bool insertedDivergent = false; 36 const auto &MRI = F.getRegInfo(); 37 const auto &RBI = *F.getSubtarget().getRegBankInfo(); 38 const auto &TRI = *MRI.getTargetRegisterInfo(); 39 for (auto &op : Instr.operands()) { 40 if (!op.isReg() || !op.isDef()) 41 continue; 42 if (!op.getReg().isVirtual()) 43 continue; 44 assert(!op.getSubReg()); 45 if (TRI.isUniformReg(MRI, RBI, op.getReg())) 46 continue; 47 insertedDivergent |= markDivergent(op.getReg()); 48 } 49 return insertedDivergent; 50 } 51 52 template <> 53 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() { 54 const auto &InstrInfo = *F.getSubtarget().getInstrInfo(); 55 56 for (const MachineBasicBlock &block : F) { 57 for (const MachineInstr &instr : block) { 58 auto uniformity = InstrInfo.getInstructionUniformity(instr); 59 if (uniformity == InstructionUniformity::AlwaysUniform) { 60 addUniformOverride(instr); 61 continue; 62 } 63 64 if (uniformity == InstructionUniformity::NeverUniform) { 65 markDivergent(instr); 66 } 67 } 68 } 69 } 70 71 template <> 72 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( 73 Register Reg) { 74 assert(isDivergent(Reg)); 75 const auto &RegInfo = F.getRegInfo(); 76 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { 77 markDivergent(UserInstr); 78 } 79 } 80 81 template <> 82 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( 83 const MachineInstr &Instr) { 84 assert(!isAlwaysUniform(Instr)); 85 if (Instr.isTerminator()) 86 return; 87 for (const MachineOperand &op : Instr.operands()) { 88 if (!op.isReg() || !op.isDef()) 89 continue; 90 auto Reg = op.getReg(); 91 if (isDivergent(Reg)) 92 pushUsers(Reg); 93 } 94 } 95 96 template <> 97 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle( 98 const MachineInstr &I, const MachineCycle &DefCycle) const { 99 assert(!isAlwaysUniform(I)); 100 for (auto &Op : I.operands()) { 101 if (!Op.isReg() || !Op.readsReg()) 102 continue; 103 auto Reg = Op.getReg(); 104 105 // FIXME: Physical registers need to be properly checked instead of always 106 // returning true 107 if (Reg.isPhysical()) 108 return true; 109 110 auto *Def = F.getRegInfo().getVRegDef(Reg); 111 if (DefCycle.contains(Def->getParent())) 112 return true; 113 } 114 return false; 115 } 116 117 template <> 118 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>:: 119 propagateTemporalDivergence(const MachineInstr &I, 120 const MachineCycle &DefCycle) { 121 const auto &RegInfo = F.getRegInfo(); 122 for (auto &Op : I.operands()) { 123 if (!Op.isReg() || !Op.isDef()) 124 continue; 125 if (!Op.getReg().isVirtual()) 126 continue; 127 auto Reg = Op.getReg(); 128 if (isDivergent(Reg)) 129 continue; 130 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { 131 if (DefCycle.contains(UserInstr.getParent())) 132 continue; 133 markDivergent(UserInstr); 134 } 135 } 136 } 137 138 template <> 139 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse( 140 const MachineOperand &U) const { 141 if (!U.isReg()) 142 return false; 143 144 auto Reg = U.getReg(); 145 if (isDivergent(Reg)) 146 return true; 147 148 const auto &RegInfo = F.getRegInfo(); 149 auto *Def = RegInfo.getOneDef(Reg); 150 if (!Def) 151 return true; 152 153 auto *DefInstr = Def->getParent(); 154 auto *UseInstr = U.getParent(); 155 return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); 156 } 157 158 // This ensures explicit instantiation of 159 // GenericUniformityAnalysisImpl::ImplDeleter::operator() 160 template class llvm::GenericUniformityInfo<MachineSSAContext>; 161 template struct llvm::GenericUniformityAnalysisImplDeleter< 162 llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>; 163 164 MachineUniformityInfo 165 llvm::computeMachineUniformityInfo(MachineFunction &F, 166 const MachineCycleInfo &cycleInfo, 167 const MachineDomTree &domTree) { 168 assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!"); 169 return MachineUniformityInfo(F, domTree, cycleInfo); 170 } 171 172 namespace { 173 174 /// Legacy analysis pass which computes a \ref MachineUniformityInfo. 175 class MachineUniformityAnalysisPass : public MachineFunctionPass { 176 MachineUniformityInfo UI; 177 178 public: 179 static char ID; 180 181 MachineUniformityAnalysisPass(); 182 183 MachineUniformityInfo &getUniformityInfo() { return UI; } 184 const MachineUniformityInfo &getUniformityInfo() const { return UI; } 185 186 bool runOnMachineFunction(MachineFunction &F) override; 187 void getAnalysisUsage(AnalysisUsage &AU) const override; 188 void print(raw_ostream &OS, const Module *M = nullptr) const override; 189 190 // TODO: verify analysis 191 }; 192 193 class MachineUniformityInfoPrinterPass : public MachineFunctionPass { 194 public: 195 static char ID; 196 197 MachineUniformityInfoPrinterPass(); 198 199 bool runOnMachineFunction(MachineFunction &F) override; 200 void getAnalysisUsage(AnalysisUsage &AU) const override; 201 }; 202 203 } // namespace 204 205 char MachineUniformityAnalysisPass::ID = 0; 206 207 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass() 208 : MachineFunctionPass(ID) { 209 initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry()); 210 } 211 212 INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity", 213 "Machine Uniformity Info Analysis", true, true) 214 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass) 215 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree) 216 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity", 217 "Machine Uniformity Info Analysis", true, true) 218 219 void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const { 220 AU.setPreservesAll(); 221 AU.addRequired<MachineCycleInfoWrapperPass>(); 222 AU.addRequired<MachineDominatorTree>(); 223 MachineFunctionPass::getAnalysisUsage(AU); 224 } 225 226 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) { 227 auto &DomTree = getAnalysis<MachineDominatorTree>().getBase(); 228 auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo(); 229 UI = computeMachineUniformityInfo(MF, CI, DomTree); 230 return false; 231 } 232 233 void MachineUniformityAnalysisPass::print(raw_ostream &OS, 234 const Module *) const { 235 OS << "MachineUniformityInfo for function: " << UI.getFunction().getName() 236 << "\n"; 237 UI.print(OS); 238 } 239 240 char MachineUniformityInfoPrinterPass::ID = 0; 241 242 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass() 243 : MachineFunctionPass(ID) { 244 initializeMachineUniformityInfoPrinterPassPass( 245 *PassRegistry::getPassRegistry()); 246 } 247 248 INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass, 249 "print-machine-uniformity", 250 "Print Machine Uniformity Info Analysis", true, true) 251 INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass) 252 INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass, 253 "print-machine-uniformity", 254 "Print Machine Uniformity Info Analysis", true, true) 255 256 void MachineUniformityInfoPrinterPass::getAnalysisUsage( 257 AnalysisUsage &AU) const { 258 AU.setPreservesAll(); 259 AU.addRequired<MachineUniformityAnalysisPass>(); 260 MachineFunctionPass::getAnalysisUsage(AU); 261 } 262 263 bool MachineUniformityInfoPrinterPass::runOnMachineFunction( 264 MachineFunction &F) { 265 auto &UI = getAnalysis<MachineUniformityAnalysisPass>(); 266 UI.print(errs()); 267 return false; 268 } 269