xref: /freebsd-src/contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp (revision 5f757f3ff9144b609b3c433dfd370cc6bdc191ad)
1bdd1243dSDimitry Andric //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
2bdd1243dSDimitry Andric //
3bdd1243dSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4bdd1243dSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5bdd1243dSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bdd1243dSDimitry Andric //
7bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
8bdd1243dSDimitry Andric //
9bdd1243dSDimitry Andric // This file implements utilities for working with Profiling Metadata.
10bdd1243dSDimitry Andric //
11bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
12bdd1243dSDimitry Andric 
13bdd1243dSDimitry Andric #include "llvm/IR/ProfDataUtils.h"
14bdd1243dSDimitry Andric #include "llvm/ADT/SmallVector.h"
15bdd1243dSDimitry Andric #include "llvm/ADT/Twine.h"
16bdd1243dSDimitry Andric #include "llvm/IR/Constants.h"
17bdd1243dSDimitry Andric #include "llvm/IR/Function.h"
18bdd1243dSDimitry Andric #include "llvm/IR/Instructions.h"
19bdd1243dSDimitry Andric #include "llvm/IR/LLVMContext.h"
20*5f757f3fSDimitry Andric #include "llvm/IR/MDBuilder.h"
21bdd1243dSDimitry Andric #include "llvm/IR/Metadata.h"
22bdd1243dSDimitry Andric #include "llvm/Support/BranchProbability.h"
23bdd1243dSDimitry Andric #include "llvm/Support/CommandLine.h"
24bdd1243dSDimitry Andric 
25bdd1243dSDimitry Andric using namespace llvm;
26bdd1243dSDimitry Andric 
27bdd1243dSDimitry Andric namespace {
28bdd1243dSDimitry Andric 
29bdd1243dSDimitry Andric // MD_prof nodes have the following layout
30bdd1243dSDimitry Andric //
31bdd1243dSDimitry Andric // In general:
32bdd1243dSDimitry Andric // { String name,         Array of i32   }
33bdd1243dSDimitry Andric //
34bdd1243dSDimitry Andric // In terms of Types:
35bdd1243dSDimitry Andric // { MDString,            [i32, i32, ...]}
36bdd1243dSDimitry Andric //
37bdd1243dSDimitry Andric // Concretely for Branch Weights
38bdd1243dSDimitry Andric // { "branch_weights",    [i32 1, i32 10000]}
39bdd1243dSDimitry Andric //
40bdd1243dSDimitry Andric // We maintain some constants here to ensure that we access the branch weights
41bdd1243dSDimitry Andric // correctly, and can change the behavior in the future if the layout changes
42bdd1243dSDimitry Andric 
43bdd1243dSDimitry Andric // The index at which the weights vector starts
44bdd1243dSDimitry Andric constexpr unsigned WeightsIdx = 1;
45bdd1243dSDimitry Andric 
46bdd1243dSDimitry Andric // the minimum number of operands for MD_prof nodes with branch weights
47bdd1243dSDimitry Andric constexpr unsigned MinBWOps = 3;
48bdd1243dSDimitry Andric 
49bdd1243dSDimitry Andric // We may want to add support for other MD_prof types, so provide an abstraction
50bdd1243dSDimitry Andric // for checking the metadata type.
51bdd1243dSDimitry Andric bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
52bdd1243dSDimitry Andric   // TODO: This routine may be simplified if MD_prof used an enum instead of a
53bdd1243dSDimitry Andric   // string to differentiate the types of MD_prof nodes.
54bdd1243dSDimitry Andric   if (!ProfData || !Name || MinOps < 2)
55bdd1243dSDimitry Andric     return false;
56bdd1243dSDimitry Andric 
57bdd1243dSDimitry Andric   unsigned NOps = ProfData->getNumOperands();
58bdd1243dSDimitry Andric   if (NOps < MinOps)
59bdd1243dSDimitry Andric     return false;
60bdd1243dSDimitry Andric 
61bdd1243dSDimitry Andric   auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0));
62bdd1243dSDimitry Andric   if (!ProfDataName)
63bdd1243dSDimitry Andric     return false;
64bdd1243dSDimitry Andric 
65bdd1243dSDimitry Andric   return ProfDataName->getString().equals(Name);
66bdd1243dSDimitry Andric }
67bdd1243dSDimitry Andric 
68bdd1243dSDimitry Andric } // namespace
69bdd1243dSDimitry Andric 
70bdd1243dSDimitry Andric namespace llvm {
71bdd1243dSDimitry Andric 
72bdd1243dSDimitry Andric bool hasProfMD(const Instruction &I) {
73bdd1243dSDimitry Andric   return nullptr != I.getMetadata(LLVMContext::MD_prof);
74bdd1243dSDimitry Andric }
75bdd1243dSDimitry Andric 
76bdd1243dSDimitry Andric bool isBranchWeightMD(const MDNode *ProfileData) {
77bdd1243dSDimitry Andric   return isTargetMD(ProfileData, "branch_weights", MinBWOps);
78bdd1243dSDimitry Andric }
79bdd1243dSDimitry Andric 
80bdd1243dSDimitry Andric bool hasBranchWeightMD(const Instruction &I) {
81bdd1243dSDimitry Andric   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
82bdd1243dSDimitry Andric   return isBranchWeightMD(ProfileData);
83bdd1243dSDimitry Andric }
84bdd1243dSDimitry Andric 
85bdd1243dSDimitry Andric bool hasValidBranchWeightMD(const Instruction &I) {
86bdd1243dSDimitry Andric   return getValidBranchWeightMDNode(I);
87bdd1243dSDimitry Andric }
88bdd1243dSDimitry Andric 
89bdd1243dSDimitry Andric MDNode *getBranchWeightMDNode(const Instruction &I) {
90bdd1243dSDimitry Andric   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
91bdd1243dSDimitry Andric   if (!isBranchWeightMD(ProfileData))
92bdd1243dSDimitry Andric     return nullptr;
93bdd1243dSDimitry Andric   return ProfileData;
94bdd1243dSDimitry Andric }
95bdd1243dSDimitry Andric 
96bdd1243dSDimitry Andric MDNode *getValidBranchWeightMDNode(const Instruction &I) {
97bdd1243dSDimitry Andric   auto *ProfileData = getBranchWeightMDNode(I);
98bdd1243dSDimitry Andric   if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors())
99bdd1243dSDimitry Andric     return ProfileData;
100bdd1243dSDimitry Andric   return nullptr;
101bdd1243dSDimitry Andric }
102bdd1243dSDimitry Andric 
103*5f757f3fSDimitry Andric void extractFromBranchWeightMD(const MDNode *ProfileData,
104*5f757f3fSDimitry Andric                                SmallVectorImpl<uint32_t> &Weights) {
105*5f757f3fSDimitry Andric   assert(isBranchWeightMD(ProfileData) && "wrong metadata");
106*5f757f3fSDimitry Andric 
107*5f757f3fSDimitry Andric   unsigned NOps = ProfileData->getNumOperands();
108*5f757f3fSDimitry Andric   assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
109*5f757f3fSDimitry Andric   Weights.resize(NOps - WeightsIdx);
110*5f757f3fSDimitry Andric 
111*5f757f3fSDimitry Andric   for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
112*5f757f3fSDimitry Andric     ConstantInt *Weight =
113*5f757f3fSDimitry Andric         mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
114*5f757f3fSDimitry Andric     assert(Weight && "Malformed branch_weight in MD_prof node");
115*5f757f3fSDimitry Andric     assert(Weight->getValue().getActiveBits() <= 32 &&
116*5f757f3fSDimitry Andric            "Too many bits for uint32_t");
117*5f757f3fSDimitry Andric     Weights[Idx - WeightsIdx] = Weight->getZExtValue();
118*5f757f3fSDimitry Andric   }
119*5f757f3fSDimitry Andric }
120*5f757f3fSDimitry Andric 
121bdd1243dSDimitry Andric bool extractBranchWeights(const MDNode *ProfileData,
122bdd1243dSDimitry Andric                           SmallVectorImpl<uint32_t> &Weights) {
123bdd1243dSDimitry Andric   if (!isBranchWeightMD(ProfileData))
124bdd1243dSDimitry Andric     return false;
125*5f757f3fSDimitry Andric   extractFromBranchWeightMD(ProfileData, Weights);
126*5f757f3fSDimitry Andric   return true;
127bdd1243dSDimitry Andric }
128bdd1243dSDimitry Andric 
129bdd1243dSDimitry Andric bool extractBranchWeights(const Instruction &I,
130bdd1243dSDimitry Andric                           SmallVectorImpl<uint32_t> &Weights) {
131bdd1243dSDimitry Andric   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
132bdd1243dSDimitry Andric   return extractBranchWeights(ProfileData, Weights);
133bdd1243dSDimitry Andric }
134bdd1243dSDimitry Andric 
135bdd1243dSDimitry Andric bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
136bdd1243dSDimitry Andric                           uint64_t &FalseVal) {
137bdd1243dSDimitry Andric   assert((I.getOpcode() == Instruction::Br ||
138bdd1243dSDimitry Andric           I.getOpcode() == Instruction::Select) &&
139bdd1243dSDimitry Andric          "Looking for branch weights on something besides branch, select, or "
140bdd1243dSDimitry Andric          "switch");
141bdd1243dSDimitry Andric 
142bdd1243dSDimitry Andric   SmallVector<uint32_t, 2> Weights;
143bdd1243dSDimitry Andric   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
144bdd1243dSDimitry Andric   if (!extractBranchWeights(ProfileData, Weights))
145bdd1243dSDimitry Andric     return false;
146bdd1243dSDimitry Andric 
147bdd1243dSDimitry Andric   if (Weights.size() > 2)
148bdd1243dSDimitry Andric     return false;
149bdd1243dSDimitry Andric 
150bdd1243dSDimitry Andric   TrueVal = Weights[0];
151bdd1243dSDimitry Andric   FalseVal = Weights[1];
152bdd1243dSDimitry Andric   return true;
153bdd1243dSDimitry Andric }
154bdd1243dSDimitry Andric 
155bdd1243dSDimitry Andric bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
156bdd1243dSDimitry Andric   TotalVal = 0;
157bdd1243dSDimitry Andric   if (!ProfileData)
158bdd1243dSDimitry Andric     return false;
159bdd1243dSDimitry Andric 
160bdd1243dSDimitry Andric   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
161bdd1243dSDimitry Andric   if (!ProfDataName)
162bdd1243dSDimitry Andric     return false;
163bdd1243dSDimitry Andric 
164bdd1243dSDimitry Andric   if (ProfDataName->getString().equals("branch_weights")) {
165bdd1243dSDimitry Andric     for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) {
166bdd1243dSDimitry Andric       auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
167bdd1243dSDimitry Andric       assert(V && "Malformed branch_weight in MD_prof node");
168bdd1243dSDimitry Andric       TotalVal += V->getValue().getZExtValue();
169bdd1243dSDimitry Andric     }
170bdd1243dSDimitry Andric     return true;
171bdd1243dSDimitry Andric   }
172bdd1243dSDimitry Andric 
173bdd1243dSDimitry Andric   if (ProfDataName->getString().equals("VP") &&
174bdd1243dSDimitry Andric       ProfileData->getNumOperands() > 3) {
175bdd1243dSDimitry Andric     TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
176bdd1243dSDimitry Andric                    ->getValue()
177bdd1243dSDimitry Andric                    .getZExtValue();
178bdd1243dSDimitry Andric     return true;
179bdd1243dSDimitry Andric   }
180bdd1243dSDimitry Andric   return false;
181bdd1243dSDimitry Andric }
182bdd1243dSDimitry Andric 
183bdd1243dSDimitry Andric bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
184bdd1243dSDimitry Andric   return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
185bdd1243dSDimitry Andric }
186bdd1243dSDimitry Andric 
187*5f757f3fSDimitry Andric void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
188*5f757f3fSDimitry Andric   MDBuilder MDB(I.getContext());
189*5f757f3fSDimitry Andric   MDNode *BranchWeights = MDB.createBranchWeights(Weights);
190*5f757f3fSDimitry Andric   I.setMetadata(LLVMContext::MD_prof, BranchWeights);
191*5f757f3fSDimitry Andric }
192*5f757f3fSDimitry Andric 
193bdd1243dSDimitry Andric } // namespace llvm
194