xref: /freebsd-src/contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp (revision bdd1243df58e60e85101c09001d9812a789b6bc4)
1*bdd1243dSDimitry Andric //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
2*bdd1243dSDimitry Andric //
3*bdd1243dSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*bdd1243dSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*bdd1243dSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*bdd1243dSDimitry Andric //
7*bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
8*bdd1243dSDimitry Andric //
9*bdd1243dSDimitry Andric // This file implements utilities for working with Profiling Metadata.
10*bdd1243dSDimitry Andric //
11*bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
12*bdd1243dSDimitry Andric 
13*bdd1243dSDimitry Andric #include "llvm/IR/ProfDataUtils.h"
14*bdd1243dSDimitry Andric #include "llvm/ADT/SmallVector.h"
15*bdd1243dSDimitry Andric #include "llvm/ADT/Twine.h"
16*bdd1243dSDimitry Andric #include "llvm/IR/Constants.h"
17*bdd1243dSDimitry Andric #include "llvm/IR/Function.h"
18*bdd1243dSDimitry Andric #include "llvm/IR/Instructions.h"
19*bdd1243dSDimitry Andric #include "llvm/IR/LLVMContext.h"
20*bdd1243dSDimitry Andric #include "llvm/IR/Metadata.h"
21*bdd1243dSDimitry Andric #include "llvm/Support/BranchProbability.h"
22*bdd1243dSDimitry Andric #include "llvm/Support/CommandLine.h"
23*bdd1243dSDimitry Andric 
24*bdd1243dSDimitry Andric using namespace llvm;
25*bdd1243dSDimitry Andric 
26*bdd1243dSDimitry Andric namespace {
27*bdd1243dSDimitry Andric 
// Layout of MD_prof metadata nodes.
//
// Operand 0 is an MDString naming the kind of profile data; the remaining
// operands are the integer payload, stored as flat node operands:
//   { MDString, i32, i32, ... }
//
// For branch weights specifically:
//   !{!"branch_weights", i32 1, i32 10000}
//
// The constants below encode that layout so that every accessor in this file
// agrees on where the weights live. If the layout ever changes, only these
// values need to be updated.

// Operand index of the first weight (operand 0 is the kind string).
constexpr unsigned WeightsIdx{1};

// Minimum operand count accepted for a "branch_weights" node: the kind string
// plus at least two weights.
constexpr unsigned MinBWOps{3};
47*bdd1243dSDimitry Andric 
48*bdd1243dSDimitry Andric bool extractWeights(const MDNode *ProfileData,
49*bdd1243dSDimitry Andric                     SmallVectorImpl<uint32_t> &Weights) {
50*bdd1243dSDimitry Andric   // Assume preconditions are already met (i.e. this is valid metadata)
51*bdd1243dSDimitry Andric   assert(ProfileData && "ProfileData was nullptr in extractWeights");
52*bdd1243dSDimitry Andric   unsigned NOps = ProfileData->getNumOperands();
53*bdd1243dSDimitry Andric 
54*bdd1243dSDimitry Andric   assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
55*bdd1243dSDimitry Andric   Weights.resize(NOps - WeightsIdx);
56*bdd1243dSDimitry Andric 
57*bdd1243dSDimitry Andric   for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
58*bdd1243dSDimitry Andric     ConstantInt *Weight =
59*bdd1243dSDimitry Andric         mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
60*bdd1243dSDimitry Andric     assert(Weight && "Malformed branch_weight in MD_prof node");
61*bdd1243dSDimitry Andric     assert(Weight->getValue().getActiveBits() <= 32 &&
62*bdd1243dSDimitry Andric            "Too many bits for uint32_t");
63*bdd1243dSDimitry Andric     Weights[Idx - WeightsIdx] = Weight->getZExtValue();
64*bdd1243dSDimitry Andric   }
65*bdd1243dSDimitry Andric   return true;
66*bdd1243dSDimitry Andric }
67*bdd1243dSDimitry Andric 
68*bdd1243dSDimitry Andric // We may want to add support for other MD_prof types, so provide an abstraction
69*bdd1243dSDimitry Andric // for checking the metadata type.
70*bdd1243dSDimitry Andric bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
71*bdd1243dSDimitry Andric   // TODO: This routine may be simplified if MD_prof used an enum instead of a
72*bdd1243dSDimitry Andric   // string to differentiate the types of MD_prof nodes.
73*bdd1243dSDimitry Andric   if (!ProfData || !Name || MinOps < 2)
74*bdd1243dSDimitry Andric     return false;
75*bdd1243dSDimitry Andric 
76*bdd1243dSDimitry Andric   unsigned NOps = ProfData->getNumOperands();
77*bdd1243dSDimitry Andric   if (NOps < MinOps)
78*bdd1243dSDimitry Andric     return false;
79*bdd1243dSDimitry Andric 
80*bdd1243dSDimitry Andric   auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0));
81*bdd1243dSDimitry Andric   if (!ProfDataName)
82*bdd1243dSDimitry Andric     return false;
83*bdd1243dSDimitry Andric 
84*bdd1243dSDimitry Andric   return ProfDataName->getString().equals(Name);
85*bdd1243dSDimitry Andric }
86*bdd1243dSDimitry Andric 
87*bdd1243dSDimitry Andric } // namespace
88*bdd1243dSDimitry Andric 
89*bdd1243dSDimitry Andric namespace llvm {
90*bdd1243dSDimitry Andric 
91*bdd1243dSDimitry Andric bool hasProfMD(const Instruction &I) {
92*bdd1243dSDimitry Andric   return nullptr != I.getMetadata(LLVMContext::MD_prof);
93*bdd1243dSDimitry Andric }
94*bdd1243dSDimitry Andric 
95*bdd1243dSDimitry Andric bool isBranchWeightMD(const MDNode *ProfileData) {
96*bdd1243dSDimitry Andric   return isTargetMD(ProfileData, "branch_weights", MinBWOps);
97*bdd1243dSDimitry Andric }
98*bdd1243dSDimitry Andric 
99*bdd1243dSDimitry Andric bool hasBranchWeightMD(const Instruction &I) {
100*bdd1243dSDimitry Andric   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
101*bdd1243dSDimitry Andric   return isBranchWeightMD(ProfileData);
102*bdd1243dSDimitry Andric }
103*bdd1243dSDimitry Andric 
104*bdd1243dSDimitry Andric bool hasValidBranchWeightMD(const Instruction &I) {
105*bdd1243dSDimitry Andric   return getValidBranchWeightMDNode(I);
106*bdd1243dSDimitry Andric }
107*bdd1243dSDimitry Andric 
108*bdd1243dSDimitry Andric MDNode *getBranchWeightMDNode(const Instruction &I) {
109*bdd1243dSDimitry Andric   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
110*bdd1243dSDimitry Andric   if (!isBranchWeightMD(ProfileData))
111*bdd1243dSDimitry Andric     return nullptr;
112*bdd1243dSDimitry Andric   return ProfileData;
113*bdd1243dSDimitry Andric }
114*bdd1243dSDimitry Andric 
115*bdd1243dSDimitry Andric MDNode *getValidBranchWeightMDNode(const Instruction &I) {
116*bdd1243dSDimitry Andric   auto *ProfileData = getBranchWeightMDNode(I);
117*bdd1243dSDimitry Andric   if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors())
118*bdd1243dSDimitry Andric     return ProfileData;
119*bdd1243dSDimitry Andric   return nullptr;
120*bdd1243dSDimitry Andric }
121*bdd1243dSDimitry Andric 
122*bdd1243dSDimitry Andric bool extractBranchWeights(const MDNode *ProfileData,
123*bdd1243dSDimitry Andric                           SmallVectorImpl<uint32_t> &Weights) {
124*bdd1243dSDimitry Andric   if (!isBranchWeightMD(ProfileData))
125*bdd1243dSDimitry Andric     return false;
126*bdd1243dSDimitry Andric   return extractWeights(ProfileData, Weights);
127*bdd1243dSDimitry Andric }
128*bdd1243dSDimitry Andric 
129*bdd1243dSDimitry Andric bool extractBranchWeights(const Instruction &I,
130*bdd1243dSDimitry Andric                           SmallVectorImpl<uint32_t> &Weights) {
131*bdd1243dSDimitry Andric   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
132*bdd1243dSDimitry Andric   return extractBranchWeights(ProfileData, Weights);
133*bdd1243dSDimitry Andric }
134*bdd1243dSDimitry Andric 
135*bdd1243dSDimitry Andric bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
136*bdd1243dSDimitry Andric                           uint64_t &FalseVal) {
137*bdd1243dSDimitry Andric   assert((I.getOpcode() == Instruction::Br ||
138*bdd1243dSDimitry Andric           I.getOpcode() == Instruction::Select) &&
139*bdd1243dSDimitry Andric          "Looking for branch weights on something besides branch, select, or "
140*bdd1243dSDimitry Andric          "switch");
141*bdd1243dSDimitry Andric 
142*bdd1243dSDimitry Andric   SmallVector<uint32_t, 2> Weights;
143*bdd1243dSDimitry Andric   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
144*bdd1243dSDimitry Andric   if (!extractBranchWeights(ProfileData, Weights))
145*bdd1243dSDimitry Andric     return false;
146*bdd1243dSDimitry Andric 
147*bdd1243dSDimitry Andric   if (Weights.size() > 2)
148*bdd1243dSDimitry Andric     return false;
149*bdd1243dSDimitry Andric 
150*bdd1243dSDimitry Andric   TrueVal = Weights[0];
151*bdd1243dSDimitry Andric   FalseVal = Weights[1];
152*bdd1243dSDimitry Andric   return true;
153*bdd1243dSDimitry Andric }
154*bdd1243dSDimitry Andric 
155*bdd1243dSDimitry Andric bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
156*bdd1243dSDimitry Andric   TotalVal = 0;
157*bdd1243dSDimitry Andric   if (!ProfileData)
158*bdd1243dSDimitry Andric     return false;
159*bdd1243dSDimitry Andric 
160*bdd1243dSDimitry Andric   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
161*bdd1243dSDimitry Andric   if (!ProfDataName)
162*bdd1243dSDimitry Andric     return false;
163*bdd1243dSDimitry Andric 
164*bdd1243dSDimitry Andric   if (ProfDataName->getString().equals("branch_weights")) {
165*bdd1243dSDimitry Andric     for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) {
166*bdd1243dSDimitry Andric       auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
167*bdd1243dSDimitry Andric       assert(V && "Malformed branch_weight in MD_prof node");
168*bdd1243dSDimitry Andric       TotalVal += V->getValue().getZExtValue();
169*bdd1243dSDimitry Andric     }
170*bdd1243dSDimitry Andric     return true;
171*bdd1243dSDimitry Andric   }
172*bdd1243dSDimitry Andric 
173*bdd1243dSDimitry Andric   if (ProfDataName->getString().equals("VP") &&
174*bdd1243dSDimitry Andric       ProfileData->getNumOperands() > 3) {
175*bdd1243dSDimitry Andric     TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
176*bdd1243dSDimitry Andric                    ->getValue()
177*bdd1243dSDimitry Andric                    .getZExtValue();
178*bdd1243dSDimitry Andric     return true;
179*bdd1243dSDimitry Andric   }
180*bdd1243dSDimitry Andric   return false;
181*bdd1243dSDimitry Andric }
182*bdd1243dSDimitry Andric 
183*bdd1243dSDimitry Andric bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
184*bdd1243dSDimitry Andric   return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
185*bdd1243dSDimitry Andric }
186*bdd1243dSDimitry Andric 
187*bdd1243dSDimitry Andric } // namespace llvm
188