1 //===- UniformityAnalysis.cpp ---------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/UniformityAnalysis.h" 10 #include "llvm/ADT/GenericUniformityImpl.h" 11 #include "llvm/Analysis/CycleAnalysis.h" 12 #include "llvm/Analysis/TargetTransformInfo.h" 13 #include "llvm/IR/Dominators.h" 14 #include "llvm/IR/InstIterator.h" 15 #include "llvm/IR/Instructions.h" 16 #include "llvm/InitializePasses.h" 17 18 using namespace llvm; 19 20 template <> 21 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs( 22 const Instruction &I) const { 23 return isDivergent((const Value *)&I); 24 } 25 26 template <> 27 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent( 28 const Instruction &Instr) { 29 return markDivergent(cast<Value>(&Instr)); 30 } 31 32 template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() { 33 for (auto &I : instructions(F)) { 34 if (TTI->isSourceOfDivergence(&I)) 35 markDivergent(I); 36 else if (TTI->isAlwaysUniform(&I)) 37 addUniformOverride(I); 38 } 39 for (auto &Arg : F.args()) { 40 if (TTI->isSourceOfDivergence(&Arg)) { 41 markDivergent(&Arg); 42 } 43 } 44 } 45 46 template <> 47 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers( 48 const Value *V) { 49 for (const auto *User : V->users()) { 50 if (const auto *UserInstr = dyn_cast<const Instruction>(User)) { 51 markDivergent(*UserInstr); 52 } 53 } 54 } 55 56 template <> 57 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers( 58 const Instruction &Instr) { 59 assert(!isAlwaysUniform(Instr)); 60 if (Instr.isTerminator()) 61 return; 62 pushUsers(cast<Value>(&Instr)); 63 } 64 65 template <> 66 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle( 67 const Instruction &I, const Cycle &DefCycle) const { 68 assert(!isAlwaysUniform(I)); 69 for (const Use &U : I.operands()) { 70 if (auto *I = dyn_cast<Instruction>(&U)) { 71 if (DefCycle.contains(I->getParent())) 72 return true; 73 } 74 } 75 return false; 76 } 77 78 template <> 79 void llvm::GenericUniformityAnalysisImpl< 80 SSAContext>::propagateTemporalDivergence(const Instruction &I, 81 const Cycle &DefCycle) { 82 if (isDivergent(I)) 83 return; 84 for (auto *User : I.users()) { 85 auto *UserInstr = cast<Instruction>(User); 86 if (DefCycle.contains(UserInstr->getParent())) 87 continue; 88 markDivergent(*UserInstr); 89 } 90 } 91 92 template <> 93 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse( 94 const Use &U) const { 95 const auto *V = U.get(); 96 if (isDivergent(V)) 97 return true; 98 if (const auto *DefInstr = dyn_cast<Instruction>(V)) { 99 const auto *UseInstr = cast<Instruction>(U.getUser()); 100 return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); 101 } 102 return false; 103 } 104 105 // This ensures explicit instantiation of 106 // GenericUniformityAnalysisImpl::ImplDeleter::operator() 107 template class llvm::GenericUniformityInfo<SSAContext>; 108 template struct llvm::GenericUniformityAnalysisImplDeleter< 109 llvm::GenericUniformityAnalysisImpl<SSAContext>>; 110 111 //===----------------------------------------------------------------------===// 112 // UniformityInfoAnalysis and related pass implementations 113 //===----------------------------------------------------------------------===// 114 115 llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F, 116 FunctionAnalysisManager &FAM) { 117 auto &DT = FAM.getResult<DominatorTreeAnalysis>(F); 118 auto &TTI = FAM.getResult<TargetIRAnalysis>(F); 119 auto &CI = FAM.getResult<CycleAnalysis>(F); 120 UniformityInfo UI{DT, CI, &TTI}; 121 // Skip computation if we can assume everything is uniform. 122 if (TTI.hasBranchDivergence(&F)) 123 UI.compute(); 124 125 return UI; 126 } 127 128 AnalysisKey UniformityInfoAnalysis::Key; 129 130 UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS) 131 : OS(OS) {} 132 133 PreservedAnalyses UniformityInfoPrinterPass::run(Function &F, 134 FunctionAnalysisManager &AM) { 135 OS << "UniformityInfo for function '" << F.getName() << "':\n"; 136 AM.getResult<UniformityInfoAnalysis>(F).print(OS); 137 138 return PreservedAnalyses::all(); 139 } 140 141 //===----------------------------------------------------------------------===// 142 // UniformityInfoWrapperPass Implementation 143 //===----------------------------------------------------------------------===// 144 145 char UniformityInfoWrapperPass::ID = 0; 146 147 UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) { 148 initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry()); 149 } 150 151 INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity", 152 "Uniformity Analysis", true, true) 153 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 154 INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass) 155 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) 156 INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity", 157 "Uniformity Analysis", true, true) 158 159 void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { 160 AU.setPreservesAll(); 161 AU.addRequired<DominatorTreeWrapperPass>(); 162 AU.addRequiredTransitive<CycleInfoWrapperPass>(); 163 AU.addRequired<TargetTransformInfoWrapperPass>(); 164 } 165 166 bool UniformityInfoWrapperPass::runOnFunction(Function &F) { 167 auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult(); 168 auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 169 auto &targetTransformInfo = 170 getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 171 172 m_function = &F; 173 m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo}; 174 175 // Skip computation if we can assume everything is uniform. 176 if (targetTransformInfo.hasBranchDivergence(m_function)) 177 m_uniformityInfo.compute(); 178 179 return false; 180 } 181 182 void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const { 183 OS << "UniformityInfo for function '" << m_function->getName() << "':\n"; 184 } 185 186 void UniformityInfoWrapperPass::releaseMemory() { 187 m_uniformityInfo = UniformityInfo{}; 188 m_function = nullptr; 189 } 190