1 //===- MachineUniformityAnalysis.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/CodeGen/MachineUniformityAnalysis.h" 10 #include "llvm/ADT/GenericUniformityImpl.h" 11 #include "llvm/CodeGen/MachineCycleAnalysis.h" 12 #include "llvm/CodeGen/MachineDominators.h" 13 #include "llvm/CodeGen/MachineRegisterInfo.h" 14 #include "llvm/CodeGen/MachineSSAContext.h" 15 #include "llvm/CodeGen/TargetInstrInfo.h" 16 #include "llvm/InitializePasses.h" 17 18 using namespace llvm; 19 20 template <> 21 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs( 22 const MachineInstr &I) const { 23 for (auto &op : I.operands()) { 24 if (!op.isReg() || !op.isDef()) 25 continue; 26 if (isDivergent(op.getReg())) 27 return true; 28 } 29 return false; 30 } 31 32 template <> 33 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent( 34 const MachineInstr &Instr, bool AllDefsDivergent) { 35 bool insertedDivergent = false; 36 const auto &MRI = F.getRegInfo(); 37 const auto &TRI = *MRI.getTargetRegisterInfo(); 38 for (auto &op : Instr.operands()) { 39 if (!op.isReg() || !op.isDef()) 40 continue; 41 if (!op.getReg().isVirtual()) 42 continue; 43 assert(!op.getSubReg()); 44 if (!AllDefsDivergent) { 45 auto *RC = MRI.getRegClassOrNull(op.getReg()); 46 if (RC && !TRI.isDivergentRegClass(RC)) 47 continue; 48 } 49 insertedDivergent |= markDivergent(op.getReg()); 50 } 51 return insertedDivergent; 52 } 53 54 template <> 55 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() { 56 const auto &InstrInfo = *F.getSubtarget().getInstrInfo(); 57 58 for (const MachineBasicBlock &block : F) { 59 for (const MachineInstr &instr : block) { 60 auto uniformity = InstrInfo.getInstructionUniformity(instr); 61 if (uniformity == InstructionUniformity::AlwaysUniform) { 62 addUniformOverride(instr); 63 continue; 64 } 65 66 if (uniformity == InstructionUniformity::NeverUniform) { 67 markDefsDivergent(instr, /* AllDefsDivergent = */ false); 68 } 69 } 70 } 71 } 72 73 template <> 74 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( 75 Register Reg) { 76 const auto &RegInfo = F.getRegInfo(); 77 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { 78 if (markDivergent(UserInstr)) 79 Worklist.push_back(&UserInstr); 80 } 81 } 82 83 template <> 84 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( 85 const MachineInstr &Instr) { 86 assert(!isAlwaysUniform(Instr)); 87 if (Instr.isTerminator()) 88 return; 89 for (const MachineOperand &op : Instr.operands()) { 90 if (op.isReg() && op.isDef() && op.getReg().isVirtual()) 91 pushUsers(op.getReg()); 92 } 93 } 94 95 template <> 96 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle( 97 const MachineInstr &I, const MachineCycle &DefCycle) const { 98 assert(!isAlwaysUniform(I)); 99 for (auto &Op : I.operands()) { 100 if (!Op.isReg() || !Op.readsReg()) 101 continue; 102 auto Reg = Op.getReg(); 103 104 // FIXME: Physical registers need to be properly checked instead of always 105 // returning true 106 if (Reg.isPhysical()) 107 return true; 108 109 auto *Def = F.getRegInfo().getVRegDef(Reg); 110 if (DefCycle.contains(Def->getParent())) 111 return true; 112 } 113 return false; 114 } 115 116 template <> 117 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse( 118 const MachineOperand &U) const { 119 if (!U.isReg()) 120 return false; 121 122 auto Reg = U.getReg(); 123 if (isDivergent(Reg)) 124 return true; 125 126 const auto &RegInfo = F.getRegInfo(); 127 auto *Def = RegInfo.getOneDef(Reg); 128 if (!Def) 129 return true; 130 131 auto *DefInstr = Def->getParent(); 132 auto *UseInstr = U.getParent(); 133 return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); 134 } 135 136 // This ensures explicit instantiation of 137 // GenericUniformityAnalysisImpl::ImplDeleter::operator() 138 template class llvm::GenericUniformityInfo<MachineSSAContext>; 139 template struct llvm::GenericUniformityAnalysisImplDeleter< 140 llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>; 141 142 MachineUniformityInfo 143 llvm::computeMachineUniformityInfo(MachineFunction &F, 144 const MachineCycleInfo &cycleInfo, 145 const MachineDomTree &domTree) { 146 assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!"); 147 return MachineUniformityInfo(F, domTree, cycleInfo); 148 } 149 150 namespace { 151 152 /// Legacy analysis pass which computes a \ref MachineUniformityInfo. 153 class MachineUniformityAnalysisPass : public MachineFunctionPass { 154 MachineUniformityInfo UI; 155 156 public: 157 static char ID; 158 159 MachineUniformityAnalysisPass(); 160 161 MachineUniformityInfo &getUniformityInfo() { return UI; } 162 const MachineUniformityInfo &getUniformityInfo() const { return UI; } 163 164 bool runOnMachineFunction(MachineFunction &F) override; 165 void getAnalysisUsage(AnalysisUsage &AU) const override; 166 void print(raw_ostream &OS, const Module *M = nullptr) const override; 167 168 // TODO: verify analysis 169 }; 170 171 class MachineUniformityInfoPrinterPass : public MachineFunctionPass { 172 public: 173 static char ID; 174 175 MachineUniformityInfoPrinterPass(); 176 177 bool runOnMachineFunction(MachineFunction &F) override; 178 void getAnalysisUsage(AnalysisUsage &AU) const override; 179 }; 180 181 } // namespace 182 183 char MachineUniformityAnalysisPass::ID = 0; 184 185 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass() 186 : MachineFunctionPass(ID) { 187 initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry()); 188 } 189 190 INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity", 191 "Machine Uniformity Info Analysis", true, true) 192 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass) 193 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree) 194 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity", 195 "Machine Uniformity Info Analysis", true, true) 196 197 void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const { 198 AU.setPreservesAll(); 199 AU.addRequired<MachineCycleInfoWrapperPass>(); 200 AU.addRequired<MachineDominatorTree>(); 201 MachineFunctionPass::getAnalysisUsage(AU); 202 } 203 204 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) { 205 auto &DomTree = getAnalysis<MachineDominatorTree>().getBase(); 206 auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo(); 207 UI = computeMachineUniformityInfo(MF, CI, DomTree); 208 return false; 209 } 210 211 void MachineUniformityAnalysisPass::print(raw_ostream &OS, 212 const Module *) const { 213 OS << "MachineUniformityInfo for function: " << UI.getFunction().getName() 214 << "\n"; 215 UI.print(OS); 216 } 217 218 char MachineUniformityInfoPrinterPass::ID = 0; 219 220 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass() 221 : MachineFunctionPass(ID) { 222 initializeMachineUniformityInfoPrinterPassPass( 223 *PassRegistry::getPassRegistry()); 224 } 225 226 INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass, 227 "print-machine-uniformity", 228 "Print Machine Uniformity Info Analysis", true, true) 229 INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass) 230 INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass, 231 "print-machine-uniformity", 232 "Print Machine Uniformity Info Analysis", true, true) 233 234 void MachineUniformityInfoPrinterPass::getAnalysisUsage( 235 AnalysisUsage &AU) const { 236 AU.setPreservesAll(); 237 AU.addRequired<MachineUniformityAnalysisPass>(); 238 MachineFunctionPass::getAnalysisUsage(AU); 239 } 240 241 bool MachineUniformityInfoPrinterPass::runOnMachineFunction( 242 MachineFunction &F) { 243 auto &UI = getAnalysis<MachineUniformityAnalysisPass>(); 244 UI.print(errs()); 245 return false; 246 } 247