xref: /freebsd-src/contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1bdd1243dSDimitry Andric //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
2bdd1243dSDimitry Andric //
3bdd1243dSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4bdd1243dSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5bdd1243dSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bdd1243dSDimitry Andric //
7bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
8bdd1243dSDimitry Andric //
9bdd1243dSDimitry Andric // This file implements utilities for working with Profiling Metadata.
10bdd1243dSDimitry Andric //
11bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
12bdd1243dSDimitry Andric 
13bdd1243dSDimitry Andric #include "llvm/IR/ProfDataUtils.h"
14bdd1243dSDimitry Andric #include "llvm/ADT/SmallVector.h"
15bdd1243dSDimitry Andric #include "llvm/ADT/Twine.h"
16bdd1243dSDimitry Andric #include "llvm/IR/Constants.h"
17bdd1243dSDimitry Andric #include "llvm/IR/Function.h"
18bdd1243dSDimitry Andric #include "llvm/IR/Instructions.h"
19bdd1243dSDimitry Andric #include "llvm/IR/LLVMContext.h"
205f757f3fSDimitry Andric #include "llvm/IR/MDBuilder.h"
21bdd1243dSDimitry Andric #include "llvm/IR/Metadata.h"
22*0fca6ea1SDimitry Andric #include "llvm/IR/ProfDataUtils.h"
23bdd1243dSDimitry Andric #include "llvm/Support/BranchProbability.h"
24bdd1243dSDimitry Andric #include "llvm/Support/CommandLine.h"
25bdd1243dSDimitry Andric 
26bdd1243dSDimitry Andric using namespace llvm;
27bdd1243dSDimitry Andric 
28bdd1243dSDimitry Andric namespace {
29bdd1243dSDimitry Andric 
30bdd1243dSDimitry Andric // MD_prof nodes have the following layout
31bdd1243dSDimitry Andric //
32bdd1243dSDimitry Andric // In general:
33bdd1243dSDimitry Andric // { String name,         Array of i32   }
34bdd1243dSDimitry Andric //
35bdd1243dSDimitry Andric // In terms of Types:
36bdd1243dSDimitry Andric // { MDString,            [i32, i32, ...]}
37bdd1243dSDimitry Andric //
// Concretely for Branch Weights
// { "branch_weights",    [i32 1, i32 10000]}
//
// Branch weights may also carry an optional origin string between the name
// and the weights, e.g.
// { "branch_weights", "expected", [i32 1, i32 10000]}
//
// We maintain some constants here to ensure that we access the branch weights
// correctly, and can change the behavior in the future if the layout changes.
43bdd1243dSDimitry Andric 
// Smallest operand count accepted for a "branch_weights" MD_prof node: the
// name string plus two further operands (two weights, or an origin string
// followed by a single weight).
constexpr unsigned MinBWOps{3};

// Smallest operand count accepted for a "VP" (value profile) MD_prof node:
// the "VP" name, two header operands (operand 2 is read as the total count by
// extractProfTotalWeight) and at least one value/count pair.
constexpr unsigned MinVPOps{5};
49*0fca6ea1SDimitry Andric 
50bdd1243dSDimitry Andric // We may want to add support for other MD_prof types, so provide an abstraction
51bdd1243dSDimitry Andric // for checking the metadata type.
52bdd1243dSDimitry Andric bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
53bdd1243dSDimitry Andric   // TODO: This routine may be simplified if MD_prof used an enum instead of a
54bdd1243dSDimitry Andric   // string to differentiate the types of MD_prof nodes.
55bdd1243dSDimitry Andric   if (!ProfData || !Name || MinOps < 2)
56bdd1243dSDimitry Andric     return false;
57bdd1243dSDimitry Andric 
58bdd1243dSDimitry Andric   unsigned NOps = ProfData->getNumOperands();
59bdd1243dSDimitry Andric   if (NOps < MinOps)
60bdd1243dSDimitry Andric     return false;
61bdd1243dSDimitry Andric 
62bdd1243dSDimitry Andric   auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0));
63bdd1243dSDimitry Andric   if (!ProfDataName)
64bdd1243dSDimitry Andric     return false;
65bdd1243dSDimitry Andric 
66*0fca6ea1SDimitry Andric   return ProfDataName->getString() == Name;
67*0fca6ea1SDimitry Andric }
68*0fca6ea1SDimitry Andric 
69*0fca6ea1SDimitry Andric template <typename T,
70*0fca6ea1SDimitry Andric           typename = typename std::enable_if<std::is_arithmetic_v<T>>>
71*0fca6ea1SDimitry Andric static void extractFromBranchWeightMD(const MDNode *ProfileData,
72*0fca6ea1SDimitry Andric                                       SmallVectorImpl<T> &Weights) {
73*0fca6ea1SDimitry Andric   assert(isBranchWeightMD(ProfileData) && "wrong metadata");
74*0fca6ea1SDimitry Andric 
75*0fca6ea1SDimitry Andric   unsigned NOps = ProfileData->getNumOperands();
76*0fca6ea1SDimitry Andric   unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
77*0fca6ea1SDimitry Andric   assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
78*0fca6ea1SDimitry Andric   Weights.resize(NOps - WeightsIdx);
79*0fca6ea1SDimitry Andric 
80*0fca6ea1SDimitry Andric   for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
81*0fca6ea1SDimitry Andric     ConstantInt *Weight =
82*0fca6ea1SDimitry Andric         mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
83*0fca6ea1SDimitry Andric     assert(Weight && "Malformed branch_weight in MD_prof node");
84*0fca6ea1SDimitry Andric     assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
85*0fca6ea1SDimitry Andric            "Too many bits for MD_prof branch_weight");
86*0fca6ea1SDimitry Andric     Weights[Idx - WeightsIdx] = Weight->getZExtValue();
87*0fca6ea1SDimitry Andric   }
88bdd1243dSDimitry Andric }
89bdd1243dSDimitry Andric 
90bdd1243dSDimitry Andric } // namespace
91bdd1243dSDimitry Andric 
92bdd1243dSDimitry Andric namespace llvm {
93bdd1243dSDimitry Andric 
94bdd1243dSDimitry Andric bool hasProfMD(const Instruction &I) {
95*0fca6ea1SDimitry Andric   return I.hasMetadata(LLVMContext::MD_prof);
96bdd1243dSDimitry Andric }
97bdd1243dSDimitry Andric 
98bdd1243dSDimitry Andric bool isBranchWeightMD(const MDNode *ProfileData) {
99bdd1243dSDimitry Andric   return isTargetMD(ProfileData, "branch_weights", MinBWOps);
100bdd1243dSDimitry Andric }
101bdd1243dSDimitry Andric 
102*0fca6ea1SDimitry Andric bool isValueProfileMD(const MDNode *ProfileData) {
103*0fca6ea1SDimitry Andric   return isTargetMD(ProfileData, "VP", MinVPOps);
104*0fca6ea1SDimitry Andric }
105*0fca6ea1SDimitry Andric 
106bdd1243dSDimitry Andric bool hasBranchWeightMD(const Instruction &I) {
107bdd1243dSDimitry Andric   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
108bdd1243dSDimitry Andric   return isBranchWeightMD(ProfileData);
109bdd1243dSDimitry Andric }
110bdd1243dSDimitry Andric 
111*0fca6ea1SDimitry Andric bool hasCountTypeMD(const Instruction &I) {
112*0fca6ea1SDimitry Andric   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
113*0fca6ea1SDimitry Andric   // Value profiles record count-type information.
114*0fca6ea1SDimitry Andric   if (isValueProfileMD(ProfileData))
115*0fca6ea1SDimitry Andric     return true;
116*0fca6ea1SDimitry Andric   // Conservatively assume non CallBase instruction only get taken/not-taken
117*0fca6ea1SDimitry Andric   // branch probability, so not interpret them as count.
118*0fca6ea1SDimitry Andric   return isa<CallBase>(I) && !isBranchWeightMD(ProfileData);
119*0fca6ea1SDimitry Andric }
120*0fca6ea1SDimitry Andric 
121bdd1243dSDimitry Andric bool hasValidBranchWeightMD(const Instruction &I) {
122bdd1243dSDimitry Andric   return getValidBranchWeightMDNode(I);
123bdd1243dSDimitry Andric }
124bdd1243dSDimitry Andric 
125*0fca6ea1SDimitry Andric bool hasBranchWeightOrigin(const Instruction &I) {
126*0fca6ea1SDimitry Andric   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
127*0fca6ea1SDimitry Andric   return hasBranchWeightOrigin(ProfileData);
128*0fca6ea1SDimitry Andric }
129*0fca6ea1SDimitry Andric 
130*0fca6ea1SDimitry Andric bool hasBranchWeightOrigin(const MDNode *ProfileData) {
131*0fca6ea1SDimitry Andric   if (!isBranchWeightMD(ProfileData))
132*0fca6ea1SDimitry Andric     return false;
133*0fca6ea1SDimitry Andric   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
134*0fca6ea1SDimitry Andric   // NOTE: if we ever have more types of branch weight provenance,
135*0fca6ea1SDimitry Andric   // we need to check the string value is "expected". For now, we
136*0fca6ea1SDimitry Andric   // supply a more generic API, and avoid the spurious comparisons.
137*0fca6ea1SDimitry Andric   assert(ProfDataName == nullptr || ProfDataName->getString() == "expected");
138*0fca6ea1SDimitry Andric   return ProfDataName != nullptr;
139*0fca6ea1SDimitry Andric }
140*0fca6ea1SDimitry Andric 
141*0fca6ea1SDimitry Andric unsigned getBranchWeightOffset(const MDNode *ProfileData) {
142*0fca6ea1SDimitry Andric   return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
143*0fca6ea1SDimitry Andric }
144*0fca6ea1SDimitry Andric 
145*0fca6ea1SDimitry Andric unsigned getNumBranchWeights(const MDNode &ProfileData) {
146*0fca6ea1SDimitry Andric   return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData);
147*0fca6ea1SDimitry Andric }
148*0fca6ea1SDimitry Andric 
149bdd1243dSDimitry Andric MDNode *getBranchWeightMDNode(const Instruction &I) {
150bdd1243dSDimitry Andric   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
151bdd1243dSDimitry Andric   if (!isBranchWeightMD(ProfileData))
152bdd1243dSDimitry Andric     return nullptr;
153bdd1243dSDimitry Andric   return ProfileData;
154bdd1243dSDimitry Andric }
155bdd1243dSDimitry Andric 
156bdd1243dSDimitry Andric MDNode *getValidBranchWeightMDNode(const Instruction &I) {
157bdd1243dSDimitry Andric   auto *ProfileData = getBranchWeightMDNode(I);
158*0fca6ea1SDimitry Andric   if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
159bdd1243dSDimitry Andric     return ProfileData;
160bdd1243dSDimitry Andric   return nullptr;
161bdd1243dSDimitry Andric }
162bdd1243dSDimitry Andric 
/// Copies the weights of the "branch_weights" node \p ProfileData into
/// \p Weights as 32-bit values, resizing \p Weights to the number of weights.
/// Unlike extractBranchWeights, the node kind is not checked here. Callers
/// must pass a node that satisfies isBranchWeightMD and whose weights fit in
/// 32 bits; both are only asserted.
void extractFromBranchWeightMD32(const MDNode *ProfileData,
                                 SmallVectorImpl<uint32_t> &Weights) {
  extractFromBranchWeightMD(ProfileData, Weights);
}
167*0fca6ea1SDimitry Andric 
/// Copies the weights of the "branch_weights" node \p ProfileData into
/// \p Weights as 64-bit values, resizing \p Weights to the number of weights.
/// The node kind is not checked here. Callers must pass a node that satisfies
/// isBranchWeightMD; this is only asserted.
void extractFromBranchWeightMD64(const MDNode *ProfileData,
                                 SmallVectorImpl<uint64_t> &Weights) {
  extractFromBranchWeightMD(ProfileData, Weights);
}
1725f757f3fSDimitry Andric 
173bdd1243dSDimitry Andric bool extractBranchWeights(const MDNode *ProfileData,
174bdd1243dSDimitry Andric                           SmallVectorImpl<uint32_t> &Weights) {
175bdd1243dSDimitry Andric   if (!isBranchWeightMD(ProfileData))
176bdd1243dSDimitry Andric     return false;
1775f757f3fSDimitry Andric   extractFromBranchWeightMD(ProfileData, Weights);
1785f757f3fSDimitry Andric   return true;
179bdd1243dSDimitry Andric }
180bdd1243dSDimitry Andric 
181bdd1243dSDimitry Andric bool extractBranchWeights(const Instruction &I,
182bdd1243dSDimitry Andric                           SmallVectorImpl<uint32_t> &Weights) {
183bdd1243dSDimitry Andric   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
184bdd1243dSDimitry Andric   return extractBranchWeights(ProfileData, Weights);
185bdd1243dSDimitry Andric }
186bdd1243dSDimitry Andric 
187bdd1243dSDimitry Andric bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
188bdd1243dSDimitry Andric                           uint64_t &FalseVal) {
189bdd1243dSDimitry Andric   assert((I.getOpcode() == Instruction::Br ||
190bdd1243dSDimitry Andric           I.getOpcode() == Instruction::Select) &&
191bdd1243dSDimitry Andric          "Looking for branch weights on something besides branch, select, or "
192bdd1243dSDimitry Andric          "switch");
193bdd1243dSDimitry Andric 
194bdd1243dSDimitry Andric   SmallVector<uint32_t, 2> Weights;
195bdd1243dSDimitry Andric   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
196bdd1243dSDimitry Andric   if (!extractBranchWeights(ProfileData, Weights))
197bdd1243dSDimitry Andric     return false;
198bdd1243dSDimitry Andric 
199bdd1243dSDimitry Andric   if (Weights.size() > 2)
200bdd1243dSDimitry Andric     return false;
201bdd1243dSDimitry Andric 
202bdd1243dSDimitry Andric   TrueVal = Weights[0];
203bdd1243dSDimitry Andric   FalseVal = Weights[1];
204bdd1243dSDimitry Andric   return true;
205bdd1243dSDimitry Andric }
206bdd1243dSDimitry Andric 
207bdd1243dSDimitry Andric bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
208bdd1243dSDimitry Andric   TotalVal = 0;
209bdd1243dSDimitry Andric   if (!ProfileData)
210bdd1243dSDimitry Andric     return false;
211bdd1243dSDimitry Andric 
212bdd1243dSDimitry Andric   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
213bdd1243dSDimitry Andric   if (!ProfDataName)
214bdd1243dSDimitry Andric     return false;
215bdd1243dSDimitry Andric 
216*0fca6ea1SDimitry Andric   if (ProfDataName->getString() == "branch_weights") {
217*0fca6ea1SDimitry Andric     unsigned Offset = getBranchWeightOffset(ProfileData);
218*0fca6ea1SDimitry Andric     for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
219bdd1243dSDimitry Andric       auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
220bdd1243dSDimitry Andric       assert(V && "Malformed branch_weight in MD_prof node");
221bdd1243dSDimitry Andric       TotalVal += V->getValue().getZExtValue();
222bdd1243dSDimitry Andric     }
223bdd1243dSDimitry Andric     return true;
224bdd1243dSDimitry Andric   }
225bdd1243dSDimitry Andric 
226*0fca6ea1SDimitry Andric   if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) {
227bdd1243dSDimitry Andric     TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
228bdd1243dSDimitry Andric                    ->getValue()
229bdd1243dSDimitry Andric                    .getZExtValue();
230bdd1243dSDimitry Andric     return true;
231bdd1243dSDimitry Andric   }
232bdd1243dSDimitry Andric   return false;
233bdd1243dSDimitry Andric }
234bdd1243dSDimitry Andric 
235bdd1243dSDimitry Andric bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
236bdd1243dSDimitry Andric   return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
237bdd1243dSDimitry Andric }
238bdd1243dSDimitry Andric 
239*0fca6ea1SDimitry Andric void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
240*0fca6ea1SDimitry Andric                       bool IsExpected) {
2415f757f3fSDimitry Andric   MDBuilder MDB(I.getContext());
242*0fca6ea1SDimitry Andric   MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
2435f757f3fSDimitry Andric   I.setMetadata(LLVMContext::MD_prof, BranchWeights);
2445f757f3fSDimitry Andric }
2455f757f3fSDimitry Andric 
246*0fca6ea1SDimitry Andric void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
247*0fca6ea1SDimitry Andric   assert(T != 0 && "Caller should guarantee");
248*0fca6ea1SDimitry Andric   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
249*0fca6ea1SDimitry Andric   if (ProfileData == nullptr)
250*0fca6ea1SDimitry Andric     return;
251*0fca6ea1SDimitry Andric 
252*0fca6ea1SDimitry Andric   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
253*0fca6ea1SDimitry Andric   if (!ProfDataName || (ProfDataName->getString() != "branch_weights" &&
254*0fca6ea1SDimitry Andric                         ProfDataName->getString() != "VP"))
255*0fca6ea1SDimitry Andric     return;
256*0fca6ea1SDimitry Andric 
257*0fca6ea1SDimitry Andric   if (!hasCountTypeMD(I))
258*0fca6ea1SDimitry Andric     return;
259*0fca6ea1SDimitry Andric 
260*0fca6ea1SDimitry Andric   LLVMContext &C = I.getContext();
261*0fca6ea1SDimitry Andric 
262*0fca6ea1SDimitry Andric   MDBuilder MDB(C);
263*0fca6ea1SDimitry Andric   SmallVector<Metadata *, 3> Vals;
264*0fca6ea1SDimitry Andric   Vals.push_back(ProfileData->getOperand(0));
265*0fca6ea1SDimitry Andric   APInt APS(128, S), APT(128, T);
266*0fca6ea1SDimitry Andric   if (ProfDataName->getString() == "branch_weights" &&
267*0fca6ea1SDimitry Andric       ProfileData->getNumOperands() > 0) {
268*0fca6ea1SDimitry Andric     // Using APInt::div may be expensive, but most cases should fit 64 bits.
269*0fca6ea1SDimitry Andric     APInt Val(128,
270*0fca6ea1SDimitry Andric               mdconst::dyn_extract<ConstantInt>(
271*0fca6ea1SDimitry Andric                   ProfileData->getOperand(getBranchWeightOffset(ProfileData)))
272*0fca6ea1SDimitry Andric                   ->getValue()
273*0fca6ea1SDimitry Andric                   .getZExtValue());
274*0fca6ea1SDimitry Andric     Val *= APS;
275*0fca6ea1SDimitry Andric     Vals.push_back(MDB.createConstant(ConstantInt::get(
276*0fca6ea1SDimitry Andric         Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX))));
277*0fca6ea1SDimitry Andric   } else if (ProfDataName->getString() == "VP")
278*0fca6ea1SDimitry Andric     for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) {
279*0fca6ea1SDimitry Andric       // The first value is the key of the value profile, which will not change.
280*0fca6ea1SDimitry Andric       Vals.push_back(ProfileData->getOperand(i));
281*0fca6ea1SDimitry Andric       uint64_t Count =
282*0fca6ea1SDimitry Andric           mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1))
283*0fca6ea1SDimitry Andric               ->getValue()
284*0fca6ea1SDimitry Andric               .getZExtValue();
285*0fca6ea1SDimitry Andric       // Don't scale the magic number.
286*0fca6ea1SDimitry Andric       if (Count == NOMORE_ICP_MAGICNUM) {
287*0fca6ea1SDimitry Andric         Vals.push_back(ProfileData->getOperand(i + 1));
288*0fca6ea1SDimitry Andric         continue;
289*0fca6ea1SDimitry Andric       }
290*0fca6ea1SDimitry Andric       // Using APInt::div may be expensive, but most cases should fit 64 bits.
291*0fca6ea1SDimitry Andric       APInt Val(128, Count);
292*0fca6ea1SDimitry Andric       Val *= APS;
293*0fca6ea1SDimitry Andric       Vals.push_back(MDB.createConstant(ConstantInt::get(
294*0fca6ea1SDimitry Andric           Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue())));
295*0fca6ea1SDimitry Andric     }
296*0fca6ea1SDimitry Andric   I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals));
297*0fca6ea1SDimitry Andric }
298*0fca6ea1SDimitry Andric 
299bdd1243dSDimitry Andric } // namespace llvm
300