xref: /llvm-project/llvm/lib/IR/ProfDataUtils.cpp (revision 6f10b65297707c1e964d570421ab4559dc2928d4)
1a812b39eSPaul Kirth //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
2a812b39eSPaul Kirth //
3a812b39eSPaul Kirth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a812b39eSPaul Kirth // See https://llvm.org/LICENSE.txt for license information.
5a812b39eSPaul Kirth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a812b39eSPaul Kirth //
7a812b39eSPaul Kirth //===----------------------------------------------------------------------===//
8a812b39eSPaul Kirth //
9a812b39eSPaul Kirth // This file implements utilities for working with Profiling Metadata.
10a812b39eSPaul Kirth //
11a812b39eSPaul Kirth //===----------------------------------------------------------------------===//
12a812b39eSPaul Kirth 
136047deb7SPaul Kirth #include "llvm/IR/ProfDataUtils.h"
146047deb7SPaul Kirth #include "llvm/ADT/SmallVector.h"
156047deb7SPaul Kirth #include "llvm/IR/Constants.h"
166047deb7SPaul Kirth #include "llvm/IR/Function.h"
176047deb7SPaul Kirth #include "llvm/IR/Instructions.h"
186047deb7SPaul Kirth #include "llvm/IR/LLVMContext.h"
19cb4627d1SMatthias Braun #include "llvm/IR/MDBuilder.h"
206047deb7SPaul Kirth #include "llvm/IR/Metadata.h"
21294f3ce5SPaul Kirth #include "llvm/IR/ProfDataUtils.h"
226047deb7SPaul Kirth 
236047deb7SPaul Kirth using namespace llvm;
246047deb7SPaul Kirth 
256047deb7SPaul Kirth namespace {
266047deb7SPaul Kirth 
// MD_prof nodes have the following layout:
//
// In general:
// { String name,         Array of integers }
//
// In terms of Types:
// { MDString,            [i32, i32, ...]}
// (value profile "VP" nodes store their counts as i64 instead).
//
// Concretely for Branch Weights:
// { "branch_weights",    [i32 1, i32 10000]}
//
// Branch weights may additionally carry an origin tag between the name and
// the weights, e.g. { "branch_weights", "expected", [i32 1, i32 10000]};
// getBranchWeightOffset() accounts for it when locating the first weight.
//
// We maintain some constants here to ensure that we access the branch weights
// correctly, and can change the behavior in the future if the layout changes.
406047deb7SPaul Kirth 
// Smallest operand count accepted for an MD_prof "branch_weights" node: the
// kind string plus at least two further operands.
constexpr unsigned MinBWOps{3};

// Smallest operand count accepted for an MD_prof "VP" (value profile) node.
constexpr unsigned MinVPOps{5};
4664f4ceb0SMingming Liu 
476047deb7SPaul Kirth // We may want to add support for other MD_prof types, so provide an abstraction
486047deb7SPaul Kirth // for checking the metadata type.
496047deb7SPaul Kirth bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
506047deb7SPaul Kirth   // TODO: This routine may be simplified if MD_prof used an enum instead of a
516047deb7SPaul Kirth   // string to differentiate the types of MD_prof nodes.
526047deb7SPaul Kirth   if (!ProfData || !Name || MinOps < 2)
536047deb7SPaul Kirth     return false;
546047deb7SPaul Kirth 
556047deb7SPaul Kirth   unsigned NOps = ProfData->getNumOperands();
566047deb7SPaul Kirth   if (NOps < MinOps)
576047deb7SPaul Kirth     return false;
586047deb7SPaul Kirth 
596047deb7SPaul Kirth   auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0));
606047deb7SPaul Kirth   if (!ProfDataName)
616047deb7SPaul Kirth     return false;
626047deb7SPaul Kirth 
634e6f6fdaSKazu Hirata   return ProfDataName->getString() == Name;
646047deb7SPaul Kirth }
656047deb7SPaul Kirth 
667538df90SPaul Kirth template <typename T,
677538df90SPaul Kirth           typename = typename std::enable_if<std::is_arithmetic_v<T>>>
687538df90SPaul Kirth static void extractFromBranchWeightMD(const MDNode *ProfileData,
697538df90SPaul Kirth                                       SmallVectorImpl<T> &Weights) {
707538df90SPaul Kirth   assert(isBranchWeightMD(ProfileData) && "wrong metadata");
717538df90SPaul Kirth 
727538df90SPaul Kirth   unsigned NOps = ProfileData->getNumOperands();
73294f3ce5SPaul Kirth   unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
747538df90SPaul Kirth   assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
757538df90SPaul Kirth   Weights.resize(NOps - WeightsIdx);
767538df90SPaul Kirth 
777538df90SPaul Kirth   for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
787538df90SPaul Kirth     ConstantInt *Weight =
797538df90SPaul Kirth         mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
807538df90SPaul Kirth     assert(Weight && "Malformed branch_weight in MD_prof node");
81294f3ce5SPaul Kirth     assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
82294f3ce5SPaul Kirth            "Too many bits for MD_prof branch_weight");
837538df90SPaul Kirth     Weights[Idx - WeightsIdx] = Weight->getZExtValue();
847538df90SPaul Kirth   }
857538df90SPaul Kirth }
867538df90SPaul Kirth 
876047deb7SPaul Kirth } // namespace
886047deb7SPaul Kirth 
896047deb7SPaul Kirth namespace llvm {
906047deb7SPaul Kirth 
916047deb7SPaul Kirth bool hasProfMD(const Instruction &I) {
928f8cab6bSKazu Hirata   return I.hasMetadata(LLVMContext::MD_prof);
936047deb7SPaul Kirth }
946047deb7SPaul Kirth 
956047deb7SPaul Kirth bool isBranchWeightMD(const MDNode *ProfileData) {
966047deb7SPaul Kirth   return isTargetMD(ProfileData, "branch_weights", MinBWOps);
976047deb7SPaul Kirth }
986047deb7SPaul Kirth 
9964f4ceb0SMingming Liu bool isValueProfileMD(const MDNode *ProfileData) {
10064f4ceb0SMingming Liu   return isTargetMD(ProfileData, "VP", MinVPOps);
10164f4ceb0SMingming Liu }
10264f4ceb0SMingming Liu 
1036047deb7SPaul Kirth bool hasBranchWeightMD(const Instruction &I) {
1046047deb7SPaul Kirth   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
1056047deb7SPaul Kirth   return isBranchWeightMD(ProfileData);
1066047deb7SPaul Kirth }
1076047deb7SPaul Kirth 
10864f4ceb0SMingming Liu bool hasCountTypeMD(const Instruction &I) {
10964f4ceb0SMingming Liu   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
11064f4ceb0SMingming Liu   // Value profiles record count-type information.
11164f4ceb0SMingming Liu   if (isValueProfileMD(ProfileData))
11264f4ceb0SMingming Liu     return true;
11364f4ceb0SMingming Liu   // Conservatively assume non CallBase instruction only get taken/not-taken
11464f4ceb0SMingming Liu   // branch probability, so not interpret them as count.
11564f4ceb0SMingming Liu   return isa<CallBase>(I) && !isBranchWeightMD(ProfileData);
11664f4ceb0SMingming Liu }
11764f4ceb0SMingming Liu 
118e741b8c2SChristian Ulmann bool hasValidBranchWeightMD(const Instruction &I) {
119e741b8c2SChristian Ulmann   return getValidBranchWeightMDNode(I);
120e741b8c2SChristian Ulmann }
121e741b8c2SChristian Ulmann 
122294f3ce5SPaul Kirth bool hasBranchWeightOrigin(const Instruction &I) {
123294f3ce5SPaul Kirth   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
124294f3ce5SPaul Kirth   return hasBranchWeightOrigin(ProfileData);
125294f3ce5SPaul Kirth }
126294f3ce5SPaul Kirth 
127294f3ce5SPaul Kirth bool hasBranchWeightOrigin(const MDNode *ProfileData) {
128294f3ce5SPaul Kirth   if (!isBranchWeightMD(ProfileData))
129294f3ce5SPaul Kirth     return false;
130294f3ce5SPaul Kirth   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
131294f3ce5SPaul Kirth   // NOTE: if we ever have more types of branch weight provenance,
132294f3ce5SPaul Kirth   // we need to check the string value is "expected". For now, we
133294f3ce5SPaul Kirth   // supply a more generic API, and avoid the spurious comparisons.
134294f3ce5SPaul Kirth   assert(ProfDataName == nullptr || ProfDataName->getString() == "expected");
135294f3ce5SPaul Kirth   return ProfDataName != nullptr;
136294f3ce5SPaul Kirth }
137294f3ce5SPaul Kirth 
138294f3ce5SPaul Kirth unsigned getBranchWeightOffset(const MDNode *ProfileData) {
139294f3ce5SPaul Kirth   return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
140294f3ce5SPaul Kirth }
141294f3ce5SPaul Kirth 
142a3a44bfbSPaul Kirth unsigned getNumBranchWeights(const MDNode &ProfileData) {
143a3a44bfbSPaul Kirth   return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData);
144a3a44bfbSPaul Kirth }
145a3a44bfbSPaul Kirth 
146e741b8c2SChristian Ulmann MDNode *getBranchWeightMDNode(const Instruction &I) {
147e741b8c2SChristian Ulmann   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
148e741b8c2SChristian Ulmann   if (!isBranchWeightMD(ProfileData))
149e741b8c2SChristian Ulmann     return nullptr;
150e741b8c2SChristian Ulmann   return ProfileData;
151e741b8c2SChristian Ulmann }
152e741b8c2SChristian Ulmann 
153e741b8c2SChristian Ulmann MDNode *getValidBranchWeightMDNode(const Instruction &I) {
154e741b8c2SChristian Ulmann   auto *ProfileData = getBranchWeightMDNode(I);
155a3a44bfbSPaul Kirth   if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
156e741b8c2SChristian Ulmann     return ProfileData;
157e741b8c2SChristian Ulmann   return nullptr;
158e741b8c2SChristian Ulmann }
159e741b8c2SChristian Ulmann 
1607538df90SPaul Kirth void extractFromBranchWeightMD32(const MDNode *ProfileData,
161285e0235SMatthias Braun                                  SmallVectorImpl<uint32_t> &Weights) {
1627538df90SPaul Kirth   extractFromBranchWeightMD(ProfileData, Weights);
163285e0235SMatthias Braun }
1647538df90SPaul Kirth 
1657538df90SPaul Kirth void extractFromBranchWeightMD64(const MDNode *ProfileData,
1667538df90SPaul Kirth                                  SmallVectorImpl<uint64_t> &Weights) {
1677538df90SPaul Kirth   extractFromBranchWeightMD(ProfileData, Weights);
168285e0235SMatthias Braun }
169285e0235SMatthias Braun 
1706047deb7SPaul Kirth bool extractBranchWeights(const MDNode *ProfileData,
1716047deb7SPaul Kirth                           SmallVectorImpl<uint32_t> &Weights) {
1726047deb7SPaul Kirth   if (!isBranchWeightMD(ProfileData))
1736047deb7SPaul Kirth     return false;
174285e0235SMatthias Braun   extractFromBranchWeightMD(ProfileData, Weights);
175285e0235SMatthias Braun   return true;
1766047deb7SPaul Kirth }
1776047deb7SPaul Kirth 
1786047deb7SPaul Kirth bool extractBranchWeights(const Instruction &I,
1796047deb7SPaul Kirth                           SmallVectorImpl<uint32_t> &Weights) {
1806047deb7SPaul Kirth   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
1816047deb7SPaul Kirth   return extractBranchWeights(ProfileData, Weights);
1826047deb7SPaul Kirth }
1836047deb7SPaul Kirth 
1846047deb7SPaul Kirth bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
1856047deb7SPaul Kirth                           uint64_t &FalseVal) {
1866047deb7SPaul Kirth   assert((I.getOpcode() == Instruction::Br ||
1876047deb7SPaul Kirth           I.getOpcode() == Instruction::Select) &&
188e741b8c2SChristian Ulmann          "Looking for branch weights on something besides branch, select, or "
189e741b8c2SChristian Ulmann          "switch");
1906047deb7SPaul Kirth 
1916047deb7SPaul Kirth   SmallVector<uint32_t, 2> Weights;
1926047deb7SPaul Kirth   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
1936047deb7SPaul Kirth   if (!extractBranchWeights(ProfileData, Weights))
1946047deb7SPaul Kirth     return false;
1956047deb7SPaul Kirth 
1966047deb7SPaul Kirth   if (Weights.size() > 2)
1976047deb7SPaul Kirth     return false;
1986047deb7SPaul Kirth 
1996047deb7SPaul Kirth   TrueVal = Weights[0];
2006047deb7SPaul Kirth   FalseVal = Weights[1];
2016047deb7SPaul Kirth   return true;
2026047deb7SPaul Kirth }
2036047deb7SPaul Kirth 
2046047deb7SPaul Kirth bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
2056047deb7SPaul Kirth   TotalVal = 0;
2066047deb7SPaul Kirth   if (!ProfileData)
2076047deb7SPaul Kirth     return false;
2086047deb7SPaul Kirth 
2096047deb7SPaul Kirth   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
2106047deb7SPaul Kirth   if (!ProfDataName)
2116047deb7SPaul Kirth     return false;
2126047deb7SPaul Kirth 
2134e6f6fdaSKazu Hirata   if (ProfDataName->getString() == "branch_weights") {
214294f3ce5SPaul Kirth     unsigned Offset = getBranchWeightOffset(ProfileData);
215294f3ce5SPaul Kirth     for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
216*c71b2122SMatt Arsenault       auto *V = mdconst::extract<ConstantInt>(ProfileData->getOperand(Idx));
2176047deb7SPaul Kirth       TotalVal += V->getValue().getZExtValue();
2186047deb7SPaul Kirth     }
2196047deb7SPaul Kirth     return true;
220deef5b8cSPaul Kirth   }
221deef5b8cSPaul Kirth 
2224e6f6fdaSKazu Hirata   if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) {
2236047deb7SPaul Kirth     TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
2246047deb7SPaul Kirth                    ->getValue()
2256047deb7SPaul Kirth                    .getZExtValue();
2266047deb7SPaul Kirth     return true;
2276047deb7SPaul Kirth   }
2286047deb7SPaul Kirth   return false;
2296047deb7SPaul Kirth }
2306047deb7SPaul Kirth 
231e741b8c2SChristian Ulmann bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
232e741b8c2SChristian Ulmann   return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
233e741b8c2SChristian Ulmann }
234e741b8c2SChristian Ulmann 
235294f3ce5SPaul Kirth void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
236294f3ce5SPaul Kirth                       bool IsExpected) {
237cb4627d1SMatthias Braun   MDBuilder MDB(I.getContext());
238294f3ce5SPaul Kirth   MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
239cb4627d1SMatthias Braun   I.setMetadata(LLVMContext::MD_prof, BranchWeights);
240cb4627d1SMatthias Braun }
241cb4627d1SMatthias Braun 
2422d641858SMingming Liu void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
2432d641858SMingming Liu   assert(T != 0 && "Caller should guarantee");
2442d641858SMingming Liu   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
2452d641858SMingming Liu   if (ProfileData == nullptr)
2462d641858SMingming Liu     return;
2472d641858SMingming Liu 
2482d641858SMingming Liu   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
2494e6f6fdaSKazu Hirata   if (!ProfDataName || (ProfDataName->getString() != "branch_weights" &&
2504e6f6fdaSKazu Hirata                         ProfDataName->getString() != "VP"))
2512d641858SMingming Liu     return;
2522d641858SMingming Liu 
25364f4ceb0SMingming Liu   if (!hasCountTypeMD(I))
25464f4ceb0SMingming Liu     return;
25564f4ceb0SMingming Liu 
2562d641858SMingming Liu   LLVMContext &C = I.getContext();
2572d641858SMingming Liu 
2582d641858SMingming Liu   MDBuilder MDB(C);
2592d641858SMingming Liu   SmallVector<Metadata *, 3> Vals;
2602d641858SMingming Liu   Vals.push_back(ProfileData->getOperand(0));
2612d641858SMingming Liu   APInt APS(128, S), APT(128, T);
2624e6f6fdaSKazu Hirata   if (ProfDataName->getString() == "branch_weights" &&
2632d641858SMingming Liu       ProfileData->getNumOperands() > 0) {
2642d641858SMingming Liu     // Using APInt::div may be expensive, but most cases should fit 64 bits.
265294f3ce5SPaul Kirth     APInt Val(128,
266294f3ce5SPaul Kirth               mdconst::dyn_extract<ConstantInt>(
267294f3ce5SPaul Kirth                   ProfileData->getOperand(getBranchWeightOffset(ProfileData)))
2682d641858SMingming Liu                   ->getValue()
2692d641858SMingming Liu                   .getZExtValue());
2702d641858SMingming Liu     Val *= APS;
2712d641858SMingming Liu     Vals.push_back(MDB.createConstant(ConstantInt::get(
2722d641858SMingming Liu         Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX))));
2734e6f6fdaSKazu Hirata   } else if (ProfDataName->getString() == "VP")
2742d641858SMingming Liu     for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) {
2752d641858SMingming Liu       // The first value is the key of the value profile, which will not change.
2762d641858SMingming Liu       Vals.push_back(ProfileData->getOperand(i));
2772d641858SMingming Liu       uint64_t Count =
2782d641858SMingming Liu           mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1))
2792d641858SMingming Liu               ->getValue()
2802d641858SMingming Liu               .getZExtValue();
2812d641858SMingming Liu       // Don't scale the magic number.
2822d641858SMingming Liu       if (Count == NOMORE_ICP_MAGICNUM) {
2832d641858SMingming Liu         Vals.push_back(ProfileData->getOperand(i + 1));
2842d641858SMingming Liu         continue;
2852d641858SMingming Liu       }
2862d641858SMingming Liu       // Using APInt::div may be expensive, but most cases should fit 64 bits.
2872d641858SMingming Liu       APInt Val(128, Count);
2882d641858SMingming Liu       Val *= APS;
2892d641858SMingming Liu       Vals.push_back(MDB.createConstant(ConstantInt::get(
2902d641858SMingming Liu           Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue())));
2912d641858SMingming Liu     }
2922d641858SMingming Liu   I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals));
2932d641858SMingming Liu }
2942d641858SMingming Liu 
2956047deb7SPaul Kirth } // namespace llvm
296