xref: /llvm-project/llvm/lib/IR/ProfDataUtils.cpp (revision a3a44bfbdfefe0928124f9e40d242507f75b87f4)
1a812b39eSPaul Kirth //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
2a812b39eSPaul Kirth //
3a812b39eSPaul Kirth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a812b39eSPaul Kirth // See https://llvm.org/LICENSE.txt for license information.
5a812b39eSPaul Kirth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a812b39eSPaul Kirth //
7a812b39eSPaul Kirth //===----------------------------------------------------------------------===//
8a812b39eSPaul Kirth //
9a812b39eSPaul Kirth // This file implements utilities for working with Profiling Metadata.
10a812b39eSPaul Kirth //
11a812b39eSPaul Kirth //===----------------------------------------------------------------------===//
12a812b39eSPaul Kirth 
136047deb7SPaul Kirth #include "llvm/IR/ProfDataUtils.h"
146047deb7SPaul Kirth #include "llvm/ADT/SmallVector.h"
156047deb7SPaul Kirth #include "llvm/ADT/Twine.h"
166047deb7SPaul Kirth #include "llvm/IR/Constants.h"
176047deb7SPaul Kirth #include "llvm/IR/Function.h"
186047deb7SPaul Kirth #include "llvm/IR/Instructions.h"
196047deb7SPaul Kirth #include "llvm/IR/LLVMContext.h"
20cb4627d1SMatthias Braun #include "llvm/IR/MDBuilder.h"
216047deb7SPaul Kirth #include "llvm/IR/Metadata.h"
22294f3ce5SPaul Kirth #include "llvm/IR/ProfDataUtils.h"
236047deb7SPaul Kirth #include "llvm/Support/BranchProbability.h"
246047deb7SPaul Kirth #include "llvm/Support/CommandLine.h"
256047deb7SPaul Kirth 
266047deb7SPaul Kirth using namespace llvm;
276047deb7SPaul Kirth 
286047deb7SPaul Kirth namespace {
296047deb7SPaul Kirth 
306047deb7SPaul Kirth // MD_prof nodes have the following layout
316047deb7SPaul Kirth //
326047deb7SPaul Kirth // In general:
336047deb7SPaul Kirth // { String name,         Array of i32   }
346047deb7SPaul Kirth //
356047deb7SPaul Kirth // In terms of Types:
366047deb7SPaul Kirth // { MDString,            [i32, i32, ...]}
376047deb7SPaul Kirth //
386047deb7SPaul Kirth // Concretely for Branch Weights
396047deb7SPaul Kirth // { "branch_weights",    [i32 1, i32 10000]}
406047deb7SPaul Kirth //
416047deb7SPaul Kirth // We maintain some constants here to ensure that we access the branch weights
426047deb7SPaul Kirth // correctly, and can change the behavior in the future if the layout changes
436047deb7SPaul Kirth 
// Smallest operand count an MD_prof node must have before it is treated as
// carrying branch weights.
constexpr unsigned MinBWOps = 3;

// Smallest operand count an MD_prof node must have before it is treated as
// carrying a value profile.
constexpr unsigned MinVPOps = 5;
4964f4ceb0SMingming Liu 
506047deb7SPaul Kirth // We may want to add support for other MD_prof types, so provide an abstraction
516047deb7SPaul Kirth // for checking the metadata type.
526047deb7SPaul Kirth bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
536047deb7SPaul Kirth   // TODO: This routine may be simplified if MD_prof used an enum instead of a
546047deb7SPaul Kirth   // string to differentiate the types of MD_prof nodes.
556047deb7SPaul Kirth   if (!ProfData || !Name || MinOps < 2)
566047deb7SPaul Kirth     return false;
576047deb7SPaul Kirth 
586047deb7SPaul Kirth   unsigned NOps = ProfData->getNumOperands();
596047deb7SPaul Kirth   if (NOps < MinOps)
606047deb7SPaul Kirth     return false;
616047deb7SPaul Kirth 
626047deb7SPaul Kirth   auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0));
636047deb7SPaul Kirth   if (!ProfDataName)
646047deb7SPaul Kirth     return false;
656047deb7SPaul Kirth 
664e6f6fdaSKazu Hirata   return ProfDataName->getString() == Name;
676047deb7SPaul Kirth }
686047deb7SPaul Kirth 
697538df90SPaul Kirth template <typename T,
707538df90SPaul Kirth           typename = typename std::enable_if<std::is_arithmetic_v<T>>>
717538df90SPaul Kirth static void extractFromBranchWeightMD(const MDNode *ProfileData,
727538df90SPaul Kirth                                       SmallVectorImpl<T> &Weights) {
737538df90SPaul Kirth   assert(isBranchWeightMD(ProfileData) && "wrong metadata");
747538df90SPaul Kirth 
757538df90SPaul Kirth   unsigned NOps = ProfileData->getNumOperands();
76294f3ce5SPaul Kirth   unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
777538df90SPaul Kirth   assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
787538df90SPaul Kirth   Weights.resize(NOps - WeightsIdx);
797538df90SPaul Kirth 
807538df90SPaul Kirth   for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
817538df90SPaul Kirth     ConstantInt *Weight =
827538df90SPaul Kirth         mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
837538df90SPaul Kirth     assert(Weight && "Malformed branch_weight in MD_prof node");
84294f3ce5SPaul Kirth     assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
85294f3ce5SPaul Kirth            "Too many bits for MD_prof branch_weight");
867538df90SPaul Kirth     Weights[Idx - WeightsIdx] = Weight->getZExtValue();
877538df90SPaul Kirth   }
887538df90SPaul Kirth }
897538df90SPaul Kirth 
906047deb7SPaul Kirth } // namespace
916047deb7SPaul Kirth 
926047deb7SPaul Kirth namespace llvm {
936047deb7SPaul Kirth 
946047deb7SPaul Kirth bool hasProfMD(const Instruction &I) {
958f8cab6bSKazu Hirata   return I.hasMetadata(LLVMContext::MD_prof);
966047deb7SPaul Kirth }
976047deb7SPaul Kirth 
986047deb7SPaul Kirth bool isBranchWeightMD(const MDNode *ProfileData) {
996047deb7SPaul Kirth   return isTargetMD(ProfileData, "branch_weights", MinBWOps);
1006047deb7SPaul Kirth }
1016047deb7SPaul Kirth 
10264f4ceb0SMingming Liu bool isValueProfileMD(const MDNode *ProfileData) {
10364f4ceb0SMingming Liu   return isTargetMD(ProfileData, "VP", MinVPOps);
10464f4ceb0SMingming Liu }
10564f4ceb0SMingming Liu 
1066047deb7SPaul Kirth bool hasBranchWeightMD(const Instruction &I) {
1076047deb7SPaul Kirth   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
1086047deb7SPaul Kirth   return isBranchWeightMD(ProfileData);
1096047deb7SPaul Kirth }
1106047deb7SPaul Kirth 
11164f4ceb0SMingming Liu bool hasCountTypeMD(const Instruction &I) {
11264f4ceb0SMingming Liu   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
11364f4ceb0SMingming Liu   // Value profiles record count-type information.
11464f4ceb0SMingming Liu   if (isValueProfileMD(ProfileData))
11564f4ceb0SMingming Liu     return true;
11664f4ceb0SMingming Liu   // Conservatively assume non CallBase instruction only get taken/not-taken
11764f4ceb0SMingming Liu   // branch probability, so not interpret them as count.
11864f4ceb0SMingming Liu   return isa<CallBase>(I) && !isBranchWeightMD(ProfileData);
11964f4ceb0SMingming Liu }
12064f4ceb0SMingming Liu 
121e741b8c2SChristian Ulmann bool hasValidBranchWeightMD(const Instruction &I) {
122e741b8c2SChristian Ulmann   return getValidBranchWeightMDNode(I);
123e741b8c2SChristian Ulmann }
124e741b8c2SChristian Ulmann 
125294f3ce5SPaul Kirth bool hasBranchWeightOrigin(const Instruction &I) {
126294f3ce5SPaul Kirth   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
127294f3ce5SPaul Kirth   return hasBranchWeightOrigin(ProfileData);
128294f3ce5SPaul Kirth }
129294f3ce5SPaul Kirth 
130294f3ce5SPaul Kirth bool hasBranchWeightOrigin(const MDNode *ProfileData) {
131294f3ce5SPaul Kirth   if (!isBranchWeightMD(ProfileData))
132294f3ce5SPaul Kirth     return false;
133294f3ce5SPaul Kirth   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
134294f3ce5SPaul Kirth   // NOTE: if we ever have more types of branch weight provenance,
135294f3ce5SPaul Kirth   // we need to check the string value is "expected". For now, we
136294f3ce5SPaul Kirth   // supply a more generic API, and avoid the spurious comparisons.
137294f3ce5SPaul Kirth   assert(ProfDataName == nullptr || ProfDataName->getString() == "expected");
138294f3ce5SPaul Kirth   return ProfDataName != nullptr;
139294f3ce5SPaul Kirth }
140294f3ce5SPaul Kirth 
141294f3ce5SPaul Kirth unsigned getBranchWeightOffset(const MDNode *ProfileData) {
142294f3ce5SPaul Kirth   return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
143294f3ce5SPaul Kirth }
144294f3ce5SPaul Kirth 
145*a3a44bfbSPaul Kirth unsigned getNumBranchWeights(const MDNode &ProfileData) {
146*a3a44bfbSPaul Kirth   return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData);
147*a3a44bfbSPaul Kirth }
148*a3a44bfbSPaul Kirth 
149e741b8c2SChristian Ulmann MDNode *getBranchWeightMDNode(const Instruction &I) {
150e741b8c2SChristian Ulmann   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
151e741b8c2SChristian Ulmann   if (!isBranchWeightMD(ProfileData))
152e741b8c2SChristian Ulmann     return nullptr;
153e741b8c2SChristian Ulmann   return ProfileData;
154e741b8c2SChristian Ulmann }
155e741b8c2SChristian Ulmann 
156e741b8c2SChristian Ulmann MDNode *getValidBranchWeightMDNode(const Instruction &I) {
157e741b8c2SChristian Ulmann   auto *ProfileData = getBranchWeightMDNode(I);
158*a3a44bfbSPaul Kirth   if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
159e741b8c2SChristian Ulmann     return ProfileData;
160e741b8c2SChristian Ulmann   return nullptr;
161e741b8c2SChristian Ulmann }
162e741b8c2SChristian Ulmann 
1637538df90SPaul Kirth void extractFromBranchWeightMD32(const MDNode *ProfileData,
164285e0235SMatthias Braun                                  SmallVectorImpl<uint32_t> &Weights) {
1657538df90SPaul Kirth   extractFromBranchWeightMD(ProfileData, Weights);
166285e0235SMatthias Braun }
1677538df90SPaul Kirth 
1687538df90SPaul Kirth void extractFromBranchWeightMD64(const MDNode *ProfileData,
1697538df90SPaul Kirth                                  SmallVectorImpl<uint64_t> &Weights) {
1707538df90SPaul Kirth   extractFromBranchWeightMD(ProfileData, Weights);
171285e0235SMatthias Braun }
172285e0235SMatthias Braun 
1736047deb7SPaul Kirth bool extractBranchWeights(const MDNode *ProfileData,
1746047deb7SPaul Kirth                           SmallVectorImpl<uint32_t> &Weights) {
1756047deb7SPaul Kirth   if (!isBranchWeightMD(ProfileData))
1766047deb7SPaul Kirth     return false;
177285e0235SMatthias Braun   extractFromBranchWeightMD(ProfileData, Weights);
178285e0235SMatthias Braun   return true;
1796047deb7SPaul Kirth }
1806047deb7SPaul Kirth 
1816047deb7SPaul Kirth bool extractBranchWeights(const Instruction &I,
1826047deb7SPaul Kirth                           SmallVectorImpl<uint32_t> &Weights) {
1836047deb7SPaul Kirth   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
1846047deb7SPaul Kirth   return extractBranchWeights(ProfileData, Weights);
1856047deb7SPaul Kirth }
1866047deb7SPaul Kirth 
1876047deb7SPaul Kirth bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
1886047deb7SPaul Kirth                           uint64_t &FalseVal) {
1896047deb7SPaul Kirth   assert((I.getOpcode() == Instruction::Br ||
1906047deb7SPaul Kirth           I.getOpcode() == Instruction::Select) &&
191e741b8c2SChristian Ulmann          "Looking for branch weights on something besides branch, select, or "
192e741b8c2SChristian Ulmann          "switch");
1936047deb7SPaul Kirth 
1946047deb7SPaul Kirth   SmallVector<uint32_t, 2> Weights;
1956047deb7SPaul Kirth   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
1966047deb7SPaul Kirth   if (!extractBranchWeights(ProfileData, Weights))
1976047deb7SPaul Kirth     return false;
1986047deb7SPaul Kirth 
1996047deb7SPaul Kirth   if (Weights.size() > 2)
2006047deb7SPaul Kirth     return false;
2016047deb7SPaul Kirth 
2026047deb7SPaul Kirth   TrueVal = Weights[0];
2036047deb7SPaul Kirth   FalseVal = Weights[1];
2046047deb7SPaul Kirth   return true;
2056047deb7SPaul Kirth }
2066047deb7SPaul Kirth 
2076047deb7SPaul Kirth bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
2086047deb7SPaul Kirth   TotalVal = 0;
2096047deb7SPaul Kirth   if (!ProfileData)
2106047deb7SPaul Kirth     return false;
2116047deb7SPaul Kirth 
2126047deb7SPaul Kirth   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
2136047deb7SPaul Kirth   if (!ProfDataName)
2146047deb7SPaul Kirth     return false;
2156047deb7SPaul Kirth 
2164e6f6fdaSKazu Hirata   if (ProfDataName->getString() == "branch_weights") {
217294f3ce5SPaul Kirth     unsigned Offset = getBranchWeightOffset(ProfileData);
218294f3ce5SPaul Kirth     for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
2196ce03ff3SShubham Sandeep Rastogi       auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
2206047deb7SPaul Kirth       assert(V && "Malformed branch_weight in MD_prof node");
2216047deb7SPaul Kirth       TotalVal += V->getValue().getZExtValue();
2226047deb7SPaul Kirth     }
2236047deb7SPaul Kirth     return true;
224deef5b8cSPaul Kirth   }
225deef5b8cSPaul Kirth 
2264e6f6fdaSKazu Hirata   if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) {
2276047deb7SPaul Kirth     TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
2286047deb7SPaul Kirth                    ->getValue()
2296047deb7SPaul Kirth                    .getZExtValue();
2306047deb7SPaul Kirth     return true;
2316047deb7SPaul Kirth   }
2326047deb7SPaul Kirth   return false;
2336047deb7SPaul Kirth }
2346047deb7SPaul Kirth 
235e741b8c2SChristian Ulmann bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
236e741b8c2SChristian Ulmann   return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
237e741b8c2SChristian Ulmann }
238e741b8c2SChristian Ulmann 
239294f3ce5SPaul Kirth void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
240294f3ce5SPaul Kirth                       bool IsExpected) {
241cb4627d1SMatthias Braun   MDBuilder MDB(I.getContext());
242294f3ce5SPaul Kirth   MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
243cb4627d1SMatthias Braun   I.setMetadata(LLVMContext::MD_prof, BranchWeights);
244cb4627d1SMatthias Braun }
245cb4627d1SMatthias Braun 
2462d641858SMingming Liu void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
2472d641858SMingming Liu   assert(T != 0 && "Caller should guarantee");
2482d641858SMingming Liu   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
2492d641858SMingming Liu   if (ProfileData == nullptr)
2502d641858SMingming Liu     return;
2512d641858SMingming Liu 
2522d641858SMingming Liu   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
2534e6f6fdaSKazu Hirata   if (!ProfDataName || (ProfDataName->getString() != "branch_weights" &&
2544e6f6fdaSKazu Hirata                         ProfDataName->getString() != "VP"))
2552d641858SMingming Liu     return;
2562d641858SMingming Liu 
25764f4ceb0SMingming Liu   if (!hasCountTypeMD(I))
25864f4ceb0SMingming Liu     return;
25964f4ceb0SMingming Liu 
2602d641858SMingming Liu   LLVMContext &C = I.getContext();
2612d641858SMingming Liu 
2622d641858SMingming Liu   MDBuilder MDB(C);
2632d641858SMingming Liu   SmallVector<Metadata *, 3> Vals;
2642d641858SMingming Liu   Vals.push_back(ProfileData->getOperand(0));
2652d641858SMingming Liu   APInt APS(128, S), APT(128, T);
2664e6f6fdaSKazu Hirata   if (ProfDataName->getString() == "branch_weights" &&
2672d641858SMingming Liu       ProfileData->getNumOperands() > 0) {
2682d641858SMingming Liu     // Using APInt::div may be expensive, but most cases should fit 64 bits.
269294f3ce5SPaul Kirth     APInt Val(128,
270294f3ce5SPaul Kirth               mdconst::dyn_extract<ConstantInt>(
271294f3ce5SPaul Kirth                   ProfileData->getOperand(getBranchWeightOffset(ProfileData)))
2722d641858SMingming Liu                   ->getValue()
2732d641858SMingming Liu                   .getZExtValue());
2742d641858SMingming Liu     Val *= APS;
2752d641858SMingming Liu     Vals.push_back(MDB.createConstant(ConstantInt::get(
2762d641858SMingming Liu         Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX))));
2774e6f6fdaSKazu Hirata   } else if (ProfDataName->getString() == "VP")
2782d641858SMingming Liu     for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) {
2792d641858SMingming Liu       // The first value is the key of the value profile, which will not change.
2802d641858SMingming Liu       Vals.push_back(ProfileData->getOperand(i));
2812d641858SMingming Liu       uint64_t Count =
2822d641858SMingming Liu           mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1))
2832d641858SMingming Liu               ->getValue()
2842d641858SMingming Liu               .getZExtValue();
2852d641858SMingming Liu       // Don't scale the magic number.
2862d641858SMingming Liu       if (Count == NOMORE_ICP_MAGICNUM) {
2872d641858SMingming Liu         Vals.push_back(ProfileData->getOperand(i + 1));
2882d641858SMingming Liu         continue;
2892d641858SMingming Liu       }
2902d641858SMingming Liu       // Using APInt::div may be expensive, but most cases should fit 64 bits.
2912d641858SMingming Liu       APInt Val(128, Count);
2922d641858SMingming Liu       Val *= APS;
2932d641858SMingming Liu       Vals.push_back(MDB.createConstant(ConstantInt::get(
2942d641858SMingming Liu           Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue())));
2952d641858SMingming Liu     }
2962d641858SMingming Liu   I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals));
2972d641858SMingming Liu }
2982d641858SMingming Liu 
2996047deb7SPaul Kirth } // namespace llvm
300