xref: /llvm-project/llvm/lib/IR/ProfDataUtils.cpp (revision 6047deb7c2aa94d9bc2b70b49799d22cce778bd4)
1*6047deb7SPaul Kirth #include "llvm/IR/ProfDataUtils.h"
2*6047deb7SPaul Kirth #include "llvm/ADT/SmallVector.h"
3*6047deb7SPaul Kirth #include "llvm/ADT/Twine.h"
4*6047deb7SPaul Kirth #include "llvm/IR/Constants.h"
5*6047deb7SPaul Kirth #include "llvm/IR/Function.h"
6*6047deb7SPaul Kirth #include "llvm/IR/Instructions.h"
7*6047deb7SPaul Kirth #include "llvm/IR/LLVMContext.h"
8*6047deb7SPaul Kirth #include "llvm/IR/Metadata.h"
9*6047deb7SPaul Kirth #include "llvm/Support/BranchProbability.h"
10*6047deb7SPaul Kirth #include "llvm/Support/CommandLine.h"
11*6047deb7SPaul Kirth 
12*6047deb7SPaul Kirth using namespace llvm;
13*6047deb7SPaul Kirth 
14*6047deb7SPaul Kirth namespace {
15*6047deb7SPaul Kirth 
16*6047deb7SPaul Kirth // MD_prof nodes have the following layout
17*6047deb7SPaul Kirth //
18*6047deb7SPaul Kirth // In general:
19*6047deb7SPaul Kirth // { String name,         Array of i32   }
20*6047deb7SPaul Kirth //
21*6047deb7SPaul Kirth // In terms of Types:
22*6047deb7SPaul Kirth // { MDString,            [i32, i32, ...]}
23*6047deb7SPaul Kirth //
24*6047deb7SPaul Kirth // Concretely for Branch Weights
25*6047deb7SPaul Kirth // { "branch_weights",    [i32 1, i32 10000]}
26*6047deb7SPaul Kirth //
27*6047deb7SPaul Kirth // We maintain some constants here to ensure that we access the branch weights
28*6047deb7SPaul Kirth // correctly, and can change the behavior in the future if the layout changes
29*6047deb7SPaul Kirth 
30*6047deb7SPaul Kirth // The index at which the weights vector starts
31*6047deb7SPaul Kirth constexpr unsigned WeightsIdx = 1;
32*6047deb7SPaul Kirth 
33*6047deb7SPaul Kirth // the minimum number of operands for MD_prof nodes with branch weights
34*6047deb7SPaul Kirth constexpr unsigned MinBWOps = 3;
35*6047deb7SPaul Kirth 
36*6047deb7SPaul Kirth bool extractWeights(const MDNode *ProfileData,
37*6047deb7SPaul Kirth                     SmallVectorImpl<uint32_t> &Weights) {
38*6047deb7SPaul Kirth   // Assume preconditions are already met (i.e. this is valid metadata)
39*6047deb7SPaul Kirth   assert(ProfileData && "ProfileData was nullptr in extractWeights");
40*6047deb7SPaul Kirth   unsigned NOps = ProfileData->getNumOperands();
41*6047deb7SPaul Kirth 
42*6047deb7SPaul Kirth   assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
43*6047deb7SPaul Kirth   Weights.resize(NOps - WeightsIdx);
44*6047deb7SPaul Kirth 
45*6047deb7SPaul Kirth   for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
46*6047deb7SPaul Kirth     ConstantInt *Weight =
47*6047deb7SPaul Kirth         mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
48*6047deb7SPaul Kirth     assert(Weight && "Malformed branch_weight in MD_prof node");
49*6047deb7SPaul Kirth     assert(Weight->getValue().getActiveBits() <= 32 &&
50*6047deb7SPaul Kirth            "Too many bits for uint32_t");
51*6047deb7SPaul Kirth     Weights[Idx - WeightsIdx] = Weight->getZExtValue();
52*6047deb7SPaul Kirth   }
53*6047deb7SPaul Kirth   return true;
54*6047deb7SPaul Kirth }
55*6047deb7SPaul Kirth 
56*6047deb7SPaul Kirth // We may want to add support for other MD_prof types, so provide an abstraction
57*6047deb7SPaul Kirth // for checking the metadata type.
58*6047deb7SPaul Kirth bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
59*6047deb7SPaul Kirth   // TODO: This routine may be simplified if MD_prof used an enum instead of a
60*6047deb7SPaul Kirth   // string to differentiate the types of MD_prof nodes.
61*6047deb7SPaul Kirth   if (!ProfData || !Name || MinOps < 2)
62*6047deb7SPaul Kirth     return false;
63*6047deb7SPaul Kirth 
64*6047deb7SPaul Kirth   unsigned NOps = ProfData->getNumOperands();
65*6047deb7SPaul Kirth   if (NOps < MinOps)
66*6047deb7SPaul Kirth     return false;
67*6047deb7SPaul Kirth 
68*6047deb7SPaul Kirth   auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0));
69*6047deb7SPaul Kirth   if (!ProfDataName)
70*6047deb7SPaul Kirth     return false;
71*6047deb7SPaul Kirth 
72*6047deb7SPaul Kirth   return ProfDataName->getString().equals(Name);
73*6047deb7SPaul Kirth }
74*6047deb7SPaul Kirth 
75*6047deb7SPaul Kirth } // namespace
76*6047deb7SPaul Kirth 
77*6047deb7SPaul Kirth namespace llvm {
78*6047deb7SPaul Kirth 
79*6047deb7SPaul Kirth bool hasProfMD(const Instruction &I) {
80*6047deb7SPaul Kirth   return nullptr != I.getMetadata(LLVMContext::MD_prof);
81*6047deb7SPaul Kirth }
82*6047deb7SPaul Kirth 
83*6047deb7SPaul Kirth bool isBranchWeightMD(const MDNode *ProfileData) {
84*6047deb7SPaul Kirth   return isTargetMD(ProfileData, "branch_weights", MinBWOps);
85*6047deb7SPaul Kirth }
86*6047deb7SPaul Kirth 
87*6047deb7SPaul Kirth bool hasBranchWeightMD(const Instruction &I) {
88*6047deb7SPaul Kirth   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
89*6047deb7SPaul Kirth   return isBranchWeightMD(ProfileData);
90*6047deb7SPaul Kirth }
91*6047deb7SPaul Kirth 
92*6047deb7SPaul Kirth bool extractBranchWeights(const MDNode *ProfileData,
93*6047deb7SPaul Kirth                           SmallVectorImpl<uint32_t> &Weights) {
94*6047deb7SPaul Kirth   if (!isBranchWeightMD(ProfileData))
95*6047deb7SPaul Kirth     return false;
96*6047deb7SPaul Kirth   return extractWeights(ProfileData, Weights);
97*6047deb7SPaul Kirth }
98*6047deb7SPaul Kirth 
99*6047deb7SPaul Kirth bool extractBranchWeights(const Instruction &I,
100*6047deb7SPaul Kirth                           SmallVectorImpl<uint32_t> &Weights) {
101*6047deb7SPaul Kirth   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
102*6047deb7SPaul Kirth   return extractBranchWeights(ProfileData, Weights);
103*6047deb7SPaul Kirth }
104*6047deb7SPaul Kirth 
105*6047deb7SPaul Kirth bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
106*6047deb7SPaul Kirth                           uint64_t &FalseVal) {
107*6047deb7SPaul Kirth   assert((I.getOpcode() == Instruction::Br ||
108*6047deb7SPaul Kirth           I.getOpcode() == Instruction::Select) &&
109*6047deb7SPaul Kirth          "Looking for branch weights on something besides branch or select");
110*6047deb7SPaul Kirth 
111*6047deb7SPaul Kirth   SmallVector<uint32_t, 2> Weights;
112*6047deb7SPaul Kirth   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
113*6047deb7SPaul Kirth   if (!extractBranchWeights(ProfileData, Weights))
114*6047deb7SPaul Kirth     return false;
115*6047deb7SPaul Kirth 
116*6047deb7SPaul Kirth   if (Weights.size() > 2)
117*6047deb7SPaul Kirth     return false;
118*6047deb7SPaul Kirth 
119*6047deb7SPaul Kirth   TrueVal = Weights[0];
120*6047deb7SPaul Kirth   FalseVal = Weights[1];
121*6047deb7SPaul Kirth   return true;
122*6047deb7SPaul Kirth }
123*6047deb7SPaul Kirth 
124*6047deb7SPaul Kirth bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
125*6047deb7SPaul Kirth   TotalVal = 0;
126*6047deb7SPaul Kirth   if (!ProfileData)
127*6047deb7SPaul Kirth     return false;
128*6047deb7SPaul Kirth 
129*6047deb7SPaul Kirth   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
130*6047deb7SPaul Kirth   if (!ProfDataName)
131*6047deb7SPaul Kirth     return false;
132*6047deb7SPaul Kirth 
133*6047deb7SPaul Kirth   if (ProfDataName->getString().equals("branch_weights")) {
134*6047deb7SPaul Kirth     for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) {
135*6047deb7SPaul Kirth       auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
136*6047deb7SPaul Kirth       assert(V && "Malformed branch_weight in MD_prof node");
137*6047deb7SPaul Kirth       TotalVal += V->getValue().getZExtValue();
138*6047deb7SPaul Kirth     }
139*6047deb7SPaul Kirth     return true;
140*6047deb7SPaul Kirth   } else if (ProfDataName->getString().equals("VP") &&
141*6047deb7SPaul Kirth              ProfileData->getNumOperands() > 3) {
142*6047deb7SPaul Kirth     TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
143*6047deb7SPaul Kirth                    ->getValue()
144*6047deb7SPaul Kirth                    .getZExtValue();
145*6047deb7SPaul Kirth     return true;
146*6047deb7SPaul Kirth   }
147*6047deb7SPaul Kirth   return false;
148*6047deb7SPaul Kirth }
149*6047deb7SPaul Kirth 
150*6047deb7SPaul Kirth } // namespace llvm
151