1*6047deb7SPaul Kirth #include "llvm/IR/ProfDataUtils.h" 2*6047deb7SPaul Kirth #include "llvm/ADT/SmallVector.h" 3*6047deb7SPaul Kirth #include "llvm/ADT/Twine.h" 4*6047deb7SPaul Kirth #include "llvm/IR/Constants.h" 5*6047deb7SPaul Kirth #include "llvm/IR/Function.h" 6*6047deb7SPaul Kirth #include "llvm/IR/Instructions.h" 7*6047deb7SPaul Kirth #include "llvm/IR/LLVMContext.h" 8*6047deb7SPaul Kirth #include "llvm/IR/Metadata.h" 9*6047deb7SPaul Kirth #include "llvm/Support/BranchProbability.h" 10*6047deb7SPaul Kirth #include "llvm/Support/CommandLine.h" 11*6047deb7SPaul Kirth 12*6047deb7SPaul Kirth using namespace llvm; 13*6047deb7SPaul Kirth 14*6047deb7SPaul Kirth namespace { 15*6047deb7SPaul Kirth 16*6047deb7SPaul Kirth // MD_prof nodes have the following layout 17*6047deb7SPaul Kirth // 18*6047deb7SPaul Kirth // In general: 19*6047deb7SPaul Kirth // { String name, Array of i32 } 20*6047deb7SPaul Kirth // 21*6047deb7SPaul Kirth // In terms of Types: 22*6047deb7SPaul Kirth // { MDString, [i32, i32, ...]} 23*6047deb7SPaul Kirth // 24*6047deb7SPaul Kirth // Concretely for Branch Weights 25*6047deb7SPaul Kirth // { "branch_weights", [i32 1, i32 10000]} 26*6047deb7SPaul Kirth // 27*6047deb7SPaul Kirth // We maintain some constants here to ensure that we access the branch weights 28*6047deb7SPaul Kirth // correctly, and can change the behavior in the future if the layout changes 29*6047deb7SPaul Kirth 30*6047deb7SPaul Kirth // The index at which the weights vector starts 31*6047deb7SPaul Kirth constexpr unsigned WeightsIdx = 1; 32*6047deb7SPaul Kirth 33*6047deb7SPaul Kirth // the minimum number of operands for MD_prof nodes with branch weights 34*6047deb7SPaul Kirth constexpr unsigned MinBWOps = 3; 35*6047deb7SPaul Kirth 36*6047deb7SPaul Kirth bool extractWeights(const MDNode *ProfileData, 37*6047deb7SPaul Kirth SmallVectorImpl<uint32_t> &Weights) { 38*6047deb7SPaul Kirth // Assume preconditions are already met (i.e. this is valid metadata) 39*6047deb7SPaul Kirth assert(ProfileData && "ProfileData was nullptr in extractWeights"); 40*6047deb7SPaul Kirth unsigned NOps = ProfileData->getNumOperands(); 41*6047deb7SPaul Kirth 42*6047deb7SPaul Kirth assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); 43*6047deb7SPaul Kirth Weights.resize(NOps - WeightsIdx); 44*6047deb7SPaul Kirth 45*6047deb7SPaul Kirth for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { 46*6047deb7SPaul Kirth ConstantInt *Weight = 47*6047deb7SPaul Kirth mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 48*6047deb7SPaul Kirth assert(Weight && "Malformed branch_weight in MD_prof node"); 49*6047deb7SPaul Kirth assert(Weight->getValue().getActiveBits() <= 32 && 50*6047deb7SPaul Kirth "Too many bits for uint32_t"); 51*6047deb7SPaul Kirth Weights[Idx - WeightsIdx] = Weight->getZExtValue(); 52*6047deb7SPaul Kirth } 53*6047deb7SPaul Kirth return true; 54*6047deb7SPaul Kirth } 55*6047deb7SPaul Kirth 56*6047deb7SPaul Kirth // We may want to add support for other MD_prof types, so provide an abstraction 57*6047deb7SPaul Kirth // for checking the metadata type. 58*6047deb7SPaul Kirth bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { 59*6047deb7SPaul Kirth // TODO: This routine may be simplified if MD_prof used an enum instead of a 60*6047deb7SPaul Kirth // string to differentiate the types of MD_prof nodes. 61*6047deb7SPaul Kirth if (!ProfData || !Name || MinOps < 2) 62*6047deb7SPaul Kirth return false; 63*6047deb7SPaul Kirth 64*6047deb7SPaul Kirth unsigned NOps = ProfData->getNumOperands(); 65*6047deb7SPaul Kirth if (NOps < MinOps) 66*6047deb7SPaul Kirth return false; 67*6047deb7SPaul Kirth 68*6047deb7SPaul Kirth auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0)); 69*6047deb7SPaul Kirth if (!ProfDataName) 70*6047deb7SPaul Kirth return false; 71*6047deb7SPaul Kirth 72*6047deb7SPaul Kirth return ProfDataName->getString().equals(Name); 73*6047deb7SPaul Kirth } 74*6047deb7SPaul Kirth 75*6047deb7SPaul Kirth } // namespace 76*6047deb7SPaul Kirth 77*6047deb7SPaul Kirth namespace llvm { 78*6047deb7SPaul Kirth 79*6047deb7SPaul Kirth bool hasProfMD(const Instruction &I) { 80*6047deb7SPaul Kirth return nullptr != I.getMetadata(LLVMContext::MD_prof); 81*6047deb7SPaul Kirth } 82*6047deb7SPaul Kirth 83*6047deb7SPaul Kirth bool isBranchWeightMD(const MDNode *ProfileData) { 84*6047deb7SPaul Kirth return isTargetMD(ProfileData, "branch_weights", MinBWOps); 85*6047deb7SPaul Kirth } 86*6047deb7SPaul Kirth 87*6047deb7SPaul Kirth bool hasBranchWeightMD(const Instruction &I) { 88*6047deb7SPaul Kirth auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 89*6047deb7SPaul Kirth return isBranchWeightMD(ProfileData); 90*6047deb7SPaul Kirth } 91*6047deb7SPaul Kirth 92*6047deb7SPaul Kirth bool extractBranchWeights(const MDNode *ProfileData, 93*6047deb7SPaul Kirth SmallVectorImpl<uint32_t> &Weights) { 94*6047deb7SPaul Kirth if (!isBranchWeightMD(ProfileData)) 95*6047deb7SPaul Kirth return false; 96*6047deb7SPaul Kirth return extractWeights(ProfileData, Weights); 97*6047deb7SPaul Kirth } 98*6047deb7SPaul Kirth 99*6047deb7SPaul Kirth bool extractBranchWeights(const Instruction &I, 100*6047deb7SPaul Kirth SmallVectorImpl<uint32_t> &Weights) { 101*6047deb7SPaul Kirth auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 102*6047deb7SPaul Kirth return extractBranchWeights(ProfileData, Weights); 103*6047deb7SPaul Kirth } 104*6047deb7SPaul Kirth 105*6047deb7SPaul Kirth bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal, 106*6047deb7SPaul Kirth uint64_t &FalseVal) { 107*6047deb7SPaul Kirth assert((I.getOpcode() == Instruction::Br || 108*6047deb7SPaul Kirth I.getOpcode() == Instruction::Select) && 109*6047deb7SPaul Kirth "Looking for branch weights on something besides branch or select"); 110*6047deb7SPaul Kirth 111*6047deb7SPaul Kirth SmallVector<uint32_t, 2> Weights; 112*6047deb7SPaul Kirth auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 113*6047deb7SPaul Kirth if (!extractBranchWeights(ProfileData, Weights)) 114*6047deb7SPaul Kirth return false; 115*6047deb7SPaul Kirth 116*6047deb7SPaul Kirth if (Weights.size() > 2) 117*6047deb7SPaul Kirth return false; 118*6047deb7SPaul Kirth 119*6047deb7SPaul Kirth TrueVal = Weights[0]; 120*6047deb7SPaul Kirth FalseVal = Weights[1]; 121*6047deb7SPaul Kirth return true; 122*6047deb7SPaul Kirth } 123*6047deb7SPaul Kirth 124*6047deb7SPaul Kirth bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { 125*6047deb7SPaul Kirth TotalVal = 0; 126*6047deb7SPaul Kirth if (!ProfileData) 127*6047deb7SPaul Kirth return false; 128*6047deb7SPaul Kirth 129*6047deb7SPaul Kirth auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); 130*6047deb7SPaul Kirth if (!ProfDataName) 131*6047deb7SPaul Kirth return false; 132*6047deb7SPaul Kirth 133*6047deb7SPaul Kirth if (ProfDataName->getString().equals("branch_weights")) { 134*6047deb7SPaul Kirth for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) { 135*6047deb7SPaul Kirth auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 136*6047deb7SPaul Kirth assert(V && "Malformed branch_weight in MD_prof node"); 137*6047deb7SPaul Kirth TotalVal += V->getValue().getZExtValue(); 138*6047deb7SPaul Kirth } 139*6047deb7SPaul Kirth return true; 140*6047deb7SPaul Kirth } else if (ProfDataName->getString().equals("VP") && 141*6047deb7SPaul Kirth ProfileData->getNumOperands() > 3) { 142*6047deb7SPaul Kirth TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2)) 143*6047deb7SPaul Kirth ->getValue() 144*6047deb7SPaul Kirth .getZExtValue(); 145*6047deb7SPaul Kirth return true; 146*6047deb7SPaul Kirth } 147*6047deb7SPaul Kirth return false; 148*6047deb7SPaul Kirth } 149*6047deb7SPaul Kirth 150*6047deb7SPaul Kirth } // namespace llvm 151