1 //===- MachineUniformityAnalysis.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/CodeGen/MachineUniformityAnalysis.h" 10 #include "llvm/ADT/GenericUniformityImpl.h" 11 #include "llvm/CodeGen/MachineCycleAnalysis.h" 12 #include "llvm/CodeGen/MachineDominators.h" 13 #include "llvm/CodeGen/MachineRegisterInfo.h" 14 #include "llvm/CodeGen/MachineSSAContext.h" 15 #include "llvm/CodeGen/TargetInstrInfo.h" 16 #include "llvm/InitializePasses.h" 17 18 using namespace llvm; 19 20 template <> 21 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs( 22 const MachineInstr &I) const { 23 for (auto &op : I.operands()) { 24 if (!op.isReg() || !op.isDef()) 25 continue; 26 if (isDivergent(op.getReg())) 27 return true; 28 } 29 return false; 30 } 31 32 template <> 33 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent( 34 const MachineInstr &Instr) { 35 bool insertedDivergent = false; 36 const auto &MRI = F.getRegInfo(); 37 const auto &RBI = *F.getSubtarget().getRegBankInfo(); 38 const auto &TRI = *MRI.getTargetRegisterInfo(); 39 for (auto &op : Instr.operands()) { 40 if (!op.isReg() || !op.isDef()) 41 continue; 42 if (!op.getReg().isVirtual()) 43 continue; 44 assert(!op.getSubReg()); 45 if (TRI.isUniformReg(MRI, RBI, op.getReg())) 46 continue; 47 insertedDivergent |= markDivergent(op.getReg()); 48 } 49 return insertedDivergent; 50 } 51 52 template <> 53 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() { 54 const auto &InstrInfo = *F.getSubtarget().getInstrInfo(); 55 56 for (const MachineBasicBlock &block : F) { 57 for (const MachineInstr &instr : block) { 58 auto uniformity = InstrInfo.getInstructionUniformity(instr); 59 if (uniformity == InstructionUniformity::AlwaysUniform) { 60 addUniformOverride(instr); 61 continue; 62 } 63 64 if (uniformity == InstructionUniformity::NeverUniform) { 65 if (markDivergent(instr)) 66 Worklist.push_back(&instr); 67 } 68 } 69 } 70 } 71 72 template <> 73 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( 74 Register Reg) { 75 const auto &RegInfo = F.getRegInfo(); 76 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { 77 if (markDivergent(UserInstr)) 78 Worklist.push_back(&UserInstr); 79 } 80 } 81 82 template <> 83 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( 84 const MachineInstr &Instr) { 85 assert(!isAlwaysUniform(Instr)); 86 if (Instr.isTerminator()) 87 return; 88 for (const MachineOperand &op : Instr.operands()) { 89 if (op.isReg() && op.isDef() && op.getReg().isVirtual()) 90 pushUsers(op.getReg()); 91 } 92 } 93 94 template <> 95 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle( 96 const MachineInstr &I, const MachineCycle &DefCycle) const { 97 assert(!isAlwaysUniform(I)); 98 for (auto &Op : I.operands()) { 99 if (!Op.isReg() || !Op.readsReg()) 100 continue; 101 auto Reg = Op.getReg(); 102 103 // FIXME: Physical registers need to be properly checked instead of always 104 // returning true 105 if (Reg.isPhysical()) 106 return true; 107 108 auto *Def = F.getRegInfo().getVRegDef(Reg); 109 if (DefCycle.contains(Def->getParent())) 110 return true; 111 } 112 return false; 113 } 114 115 template <> 116 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>:: 117 propagateTemporalDivergence(const MachineInstr &I, 118 const MachineCycle &DefCycle) { 119 const auto &RegInfo = F.getRegInfo(); 120 for (auto &Op : I.operands()) { 121 if (!Op.isReg() || !Op.isDef()) 122 continue; 123 if (!Op.getReg().isVirtual()) 124 continue; 125 auto Reg = Op.getReg(); 126 if (isDivergent(Reg)) 127 continue; 128 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { 129 if (DefCycle.contains(UserInstr.getParent())) 130 continue; 131 if (markDivergent(UserInstr)) 132 Worklist.push_back(&UserInstr); 133 } 134 } 135 } 136 137 template <> 138 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse( 139 const MachineOperand &U) const { 140 if (!U.isReg()) 141 return false; 142 143 auto Reg = U.getReg(); 144 if (isDivergent(Reg)) 145 return true; 146 147 const auto &RegInfo = F.getRegInfo(); 148 auto *Def = RegInfo.getOneDef(Reg); 149 if (!Def) 150 return true; 151 152 auto *DefInstr = Def->getParent(); 153 auto *UseInstr = U.getParent(); 154 return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); 155 } 156 157 // This ensures explicit instantiation of 158 // GenericUniformityAnalysisImpl::ImplDeleter::operator() 159 template class llvm::GenericUniformityInfo<MachineSSAContext>; 160 template struct llvm::GenericUniformityAnalysisImplDeleter< 161 llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>; 162 163 MachineUniformityInfo 164 llvm::computeMachineUniformityInfo(MachineFunction &F, 165 const MachineCycleInfo &cycleInfo, 166 const MachineDomTree &domTree) { 167 assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!"); 168 return MachineUniformityInfo(F, domTree, cycleInfo); 169 } 170 171 namespace { 172 173 /// Legacy analysis pass which computes a \ref MachineUniformityInfo. 174 class MachineUniformityAnalysisPass : public MachineFunctionPass { 175 MachineUniformityInfo UI; 176 177 public: 178 static char ID; 179 180 MachineUniformityAnalysisPass(); 181 182 MachineUniformityInfo &getUniformityInfo() { return UI; } 183 const MachineUniformityInfo &getUniformityInfo() const { return UI; } 184 185 bool runOnMachineFunction(MachineFunction &F) override; 186 void getAnalysisUsage(AnalysisUsage &AU) const override; 187 void print(raw_ostream &OS, const Module *M = nullptr) const override; 188 189 // TODO: verify analysis 190 }; 191 192 class MachineUniformityInfoPrinterPass : public MachineFunctionPass { 193 public: 194 static char ID; 195 196 MachineUniformityInfoPrinterPass(); 197 198 bool runOnMachineFunction(MachineFunction &F) override; 199 void getAnalysisUsage(AnalysisUsage &AU) const override; 200 }; 201 202 } // namespace 203 204 char MachineUniformityAnalysisPass::ID = 0; 205 206 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass() 207 : MachineFunctionPass(ID) { 208 initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry()); 209 } 210 211 INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity", 212 "Machine Uniformity Info Analysis", true, true) 213 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass) 214 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree) 215 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity", 216 "Machine Uniformity Info Analysis", true, true) 217 218 void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const { 219 AU.setPreservesAll(); 220 AU.addRequired<MachineCycleInfoWrapperPass>(); 221 AU.addRequired<MachineDominatorTree>(); 222 MachineFunctionPass::getAnalysisUsage(AU); 223 } 224 225 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) { 226 auto &DomTree = getAnalysis<MachineDominatorTree>().getBase(); 227 auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo(); 228 UI = computeMachineUniformityInfo(MF, CI, DomTree); 229 return false; 230 } 231 232 void MachineUniformityAnalysisPass::print(raw_ostream &OS, 233 const Module *) const { 234 OS << "MachineUniformityInfo for function: " << UI.getFunction().getName() 235 << "\n"; 236 UI.print(OS); 237 } 238 239 char MachineUniformityInfoPrinterPass::ID = 0; 240 241 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass() 242 : MachineFunctionPass(ID) { 243 initializeMachineUniformityInfoPrinterPassPass( 244 *PassRegistry::getPassRegistry()); 245 } 246 247 INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass, 248 "print-machine-uniformity", 249 "Print Machine Uniformity Info Analysis", true, true) 250 INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass) 251 INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass, 252 "print-machine-uniformity", 253 "Print Machine Uniformity Info Analysis", true, true) 254 255 void MachineUniformityInfoPrinterPass::getAnalysisUsage( 256 AnalysisUsage &AU) const { 257 AU.setPreservesAll(); 258 AU.addRequired<MachineUniformityAnalysisPass>(); 259 MachineFunctionPass::getAnalysisUsage(AU); 260 } 261 262 bool MachineUniformityInfoPrinterPass::runOnMachineFunction( 263 MachineFunction &F) { 264 auto &UI = getAnalysis<MachineUniformityAnalysisPass>(); 265 UI.print(errs()); 266 return false; 267 } 268