1a812b39eSPaul Kirth //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===// 2a812b39eSPaul Kirth // 3a812b39eSPaul Kirth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a812b39eSPaul Kirth // See https://llvm.org/LICENSE.txt for license information. 5a812b39eSPaul Kirth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a812b39eSPaul Kirth // 7a812b39eSPaul Kirth //===----------------------------------------------------------------------===// 8a812b39eSPaul Kirth // 9a812b39eSPaul Kirth // This file implements utilities for working with Profiling Metadata. 10a812b39eSPaul Kirth // 11a812b39eSPaul Kirth //===----------------------------------------------------------------------===// 12a812b39eSPaul Kirth 136047deb7SPaul Kirth #include "llvm/IR/ProfDataUtils.h" 146047deb7SPaul Kirth #include "llvm/ADT/SmallVector.h" 156047deb7SPaul Kirth #include "llvm/ADT/Twine.h" 166047deb7SPaul Kirth #include "llvm/IR/Constants.h" 176047deb7SPaul Kirth #include "llvm/IR/Function.h" 186047deb7SPaul Kirth #include "llvm/IR/Instructions.h" 196047deb7SPaul Kirth #include "llvm/IR/LLVMContext.h" 20cb4627d1SMatthias Braun #include "llvm/IR/MDBuilder.h" 216047deb7SPaul Kirth #include "llvm/IR/Metadata.h" 22294f3ce5SPaul Kirth #include "llvm/IR/ProfDataUtils.h" 236047deb7SPaul Kirth #include "llvm/Support/BranchProbability.h" 246047deb7SPaul Kirth #include "llvm/Support/CommandLine.h" 256047deb7SPaul Kirth 266047deb7SPaul Kirth using namespace llvm; 276047deb7SPaul Kirth 286047deb7SPaul Kirth namespace { 296047deb7SPaul Kirth 306047deb7SPaul Kirth // MD_prof nodes have the following layout 316047deb7SPaul Kirth // 326047deb7SPaul Kirth // In general: 336047deb7SPaul Kirth // { String name, Array of i32 } 346047deb7SPaul Kirth // 356047deb7SPaul Kirth // In terms of Types: 366047deb7SPaul Kirth // { MDString, [i32, i32, ...]} 376047deb7SPaul Kirth // 386047deb7SPaul Kirth // Concretely for Branch Weights 396047deb7SPaul Kirth // { "branch_weights", [i32 1, i32 10000]} 406047deb7SPaul Kirth // 416047deb7SPaul Kirth // We maintain some constants here to ensure that we access the branch weights 426047deb7SPaul Kirth // correctly, and can change the behavior in the future if the layout changes 436047deb7SPaul Kirth 446047deb7SPaul Kirth // the minimum number of operands for MD_prof nodes with branch weights 456047deb7SPaul Kirth constexpr unsigned MinBWOps = 3; 466047deb7SPaul Kirth 4764f4ceb0SMingming Liu // the minimum number of operands for MD_prof nodes with value profiles 4864f4ceb0SMingming Liu constexpr unsigned MinVPOps = 5; 4964f4ceb0SMingming Liu 506047deb7SPaul Kirth // We may want to add support for other MD_prof types, so provide an abstraction 516047deb7SPaul Kirth // for checking the metadata type. 526047deb7SPaul Kirth bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { 536047deb7SPaul Kirth // TODO: This routine may be simplified if MD_prof used an enum instead of a 546047deb7SPaul Kirth // string to differentiate the types of MD_prof nodes. 556047deb7SPaul Kirth if (!ProfData || !Name || MinOps < 2) 566047deb7SPaul Kirth return false; 576047deb7SPaul Kirth 586047deb7SPaul Kirth unsigned NOps = ProfData->getNumOperands(); 596047deb7SPaul Kirth if (NOps < MinOps) 606047deb7SPaul Kirth return false; 616047deb7SPaul Kirth 626047deb7SPaul Kirth auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0)); 636047deb7SPaul Kirth if (!ProfDataName) 646047deb7SPaul Kirth return false; 656047deb7SPaul Kirth 664e6f6fdaSKazu Hirata return ProfDataName->getString() == Name; 676047deb7SPaul Kirth } 686047deb7SPaul Kirth 697538df90SPaul Kirth template <typename T, 707538df90SPaul Kirth typename = typename std::enable_if<std::is_arithmetic_v<T>>> 717538df90SPaul Kirth static void extractFromBranchWeightMD(const MDNode *ProfileData, 727538df90SPaul Kirth SmallVectorImpl<T> &Weights) { 737538df90SPaul Kirth assert(isBranchWeightMD(ProfileData) && "wrong metadata"); 747538df90SPaul Kirth 757538df90SPaul Kirth unsigned NOps = ProfileData->getNumOperands(); 76294f3ce5SPaul Kirth unsigned WeightsIdx = getBranchWeightOffset(ProfileData); 777538df90SPaul Kirth assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); 787538df90SPaul Kirth Weights.resize(NOps - WeightsIdx); 797538df90SPaul Kirth 807538df90SPaul Kirth for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { 817538df90SPaul Kirth ConstantInt *Weight = 827538df90SPaul Kirth mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 837538df90SPaul Kirth assert(Weight && "Malformed branch_weight in MD_prof node"); 84294f3ce5SPaul Kirth assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) && 85294f3ce5SPaul Kirth "Too many bits for MD_prof branch_weight"); 867538df90SPaul Kirth Weights[Idx - WeightsIdx] = Weight->getZExtValue(); 877538df90SPaul Kirth } 887538df90SPaul Kirth } 897538df90SPaul Kirth 906047deb7SPaul Kirth } // namespace 916047deb7SPaul Kirth 926047deb7SPaul Kirth namespace llvm { 936047deb7SPaul Kirth 946047deb7SPaul Kirth bool hasProfMD(const Instruction &I) { 958f8cab6bSKazu Hirata return I.hasMetadata(LLVMContext::MD_prof); 966047deb7SPaul Kirth } 976047deb7SPaul Kirth 986047deb7SPaul Kirth bool isBranchWeightMD(const MDNode *ProfileData) { 996047deb7SPaul Kirth return isTargetMD(ProfileData, "branch_weights", MinBWOps); 1006047deb7SPaul Kirth } 1016047deb7SPaul Kirth 10264f4ceb0SMingming Liu bool isValueProfileMD(const MDNode *ProfileData) { 10364f4ceb0SMingming Liu return isTargetMD(ProfileData, "VP", MinVPOps); 10464f4ceb0SMingming Liu } 10564f4ceb0SMingming Liu 1066047deb7SPaul Kirth bool hasBranchWeightMD(const Instruction &I) { 1076047deb7SPaul Kirth auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 1086047deb7SPaul Kirth return isBranchWeightMD(ProfileData); 1096047deb7SPaul Kirth } 1106047deb7SPaul Kirth 11164f4ceb0SMingming Liu bool hasCountTypeMD(const Instruction &I) { 11264f4ceb0SMingming Liu auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 11364f4ceb0SMingming Liu // Value profiles record count-type information. 11464f4ceb0SMingming Liu if (isValueProfileMD(ProfileData)) 11564f4ceb0SMingming Liu return true; 11664f4ceb0SMingming Liu // Conservatively assume non CallBase instruction only get taken/not-taken 11764f4ceb0SMingming Liu // branch probability, so not interpret them as count. 11864f4ceb0SMingming Liu return isa<CallBase>(I) && !isBranchWeightMD(ProfileData); 11964f4ceb0SMingming Liu } 12064f4ceb0SMingming Liu 121e741b8c2SChristian Ulmann bool hasValidBranchWeightMD(const Instruction &I) { 122e741b8c2SChristian Ulmann return getValidBranchWeightMDNode(I); 123e741b8c2SChristian Ulmann } 124e741b8c2SChristian Ulmann 125294f3ce5SPaul Kirth bool hasBranchWeightOrigin(const Instruction &I) { 126294f3ce5SPaul Kirth auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 127294f3ce5SPaul Kirth return hasBranchWeightOrigin(ProfileData); 128294f3ce5SPaul Kirth } 129294f3ce5SPaul Kirth 130294f3ce5SPaul Kirth bool hasBranchWeightOrigin(const MDNode *ProfileData) { 131294f3ce5SPaul Kirth if (!isBranchWeightMD(ProfileData)) 132294f3ce5SPaul Kirth return false; 133294f3ce5SPaul Kirth auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1)); 134294f3ce5SPaul Kirth // NOTE: if we ever have more types of branch weight provenance, 135294f3ce5SPaul Kirth // we need to check the string value is "expected". For now, we 136294f3ce5SPaul Kirth // supply a more generic API, and avoid the spurious comparisons. 137294f3ce5SPaul Kirth assert(ProfDataName == nullptr || ProfDataName->getString() == "expected"); 138294f3ce5SPaul Kirth return ProfDataName != nullptr; 139294f3ce5SPaul Kirth } 140294f3ce5SPaul Kirth 141294f3ce5SPaul Kirth unsigned getBranchWeightOffset(const MDNode *ProfileData) { 142294f3ce5SPaul Kirth return hasBranchWeightOrigin(ProfileData) ? 2 : 1; 143294f3ce5SPaul Kirth } 144294f3ce5SPaul Kirth 145*a3a44bfbSPaul Kirth unsigned getNumBranchWeights(const MDNode &ProfileData) { 146*a3a44bfbSPaul Kirth return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData); 147*a3a44bfbSPaul Kirth } 148*a3a44bfbSPaul Kirth 149e741b8c2SChristian Ulmann MDNode *getBranchWeightMDNode(const Instruction &I) { 150e741b8c2SChristian Ulmann auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 151e741b8c2SChristian Ulmann if (!isBranchWeightMD(ProfileData)) 152e741b8c2SChristian Ulmann return nullptr; 153e741b8c2SChristian Ulmann return ProfileData; 154e741b8c2SChristian Ulmann } 155e741b8c2SChristian Ulmann 156e741b8c2SChristian Ulmann MDNode *getValidBranchWeightMDNode(const Instruction &I) { 157e741b8c2SChristian Ulmann auto *ProfileData = getBranchWeightMDNode(I); 158*a3a44bfbSPaul Kirth if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors()) 159e741b8c2SChristian Ulmann return ProfileData; 160e741b8c2SChristian Ulmann return nullptr; 161e741b8c2SChristian Ulmann } 162e741b8c2SChristian Ulmann 1637538df90SPaul Kirth void extractFromBranchWeightMD32(const MDNode *ProfileData, 164285e0235SMatthias Braun SmallVectorImpl<uint32_t> &Weights) { 1657538df90SPaul Kirth extractFromBranchWeightMD(ProfileData, Weights); 166285e0235SMatthias Braun } 1677538df90SPaul Kirth 1687538df90SPaul Kirth void extractFromBranchWeightMD64(const MDNode *ProfileData, 1697538df90SPaul Kirth SmallVectorImpl<uint64_t> &Weights) { 1707538df90SPaul Kirth extractFromBranchWeightMD(ProfileData, Weights); 171285e0235SMatthias Braun } 172285e0235SMatthias Braun 1736047deb7SPaul Kirth bool extractBranchWeights(const MDNode *ProfileData, 1746047deb7SPaul Kirth SmallVectorImpl<uint32_t> &Weights) { 1756047deb7SPaul Kirth if (!isBranchWeightMD(ProfileData)) 1766047deb7SPaul Kirth return false; 177285e0235SMatthias Braun extractFromBranchWeightMD(ProfileData, Weights); 178285e0235SMatthias Braun return true; 1796047deb7SPaul Kirth } 1806047deb7SPaul Kirth 1816047deb7SPaul Kirth bool extractBranchWeights(const Instruction &I, 1826047deb7SPaul Kirth SmallVectorImpl<uint32_t> &Weights) { 1836047deb7SPaul Kirth auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 1846047deb7SPaul Kirth return extractBranchWeights(ProfileData, Weights); 1856047deb7SPaul Kirth } 1866047deb7SPaul Kirth 1876047deb7SPaul Kirth bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal, 1886047deb7SPaul Kirth uint64_t &FalseVal) { 1896047deb7SPaul Kirth assert((I.getOpcode() == Instruction::Br || 1906047deb7SPaul Kirth I.getOpcode() == Instruction::Select) && 191e741b8c2SChristian Ulmann "Looking for branch weights on something besides branch, select, or " 192e741b8c2SChristian Ulmann "switch"); 1936047deb7SPaul Kirth 1946047deb7SPaul Kirth SmallVector<uint32_t, 2> Weights; 1956047deb7SPaul Kirth auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 1966047deb7SPaul Kirth if (!extractBranchWeights(ProfileData, Weights)) 1976047deb7SPaul Kirth return false; 1986047deb7SPaul Kirth 1996047deb7SPaul Kirth if (Weights.size() > 2) 2006047deb7SPaul Kirth return false; 2016047deb7SPaul Kirth 2026047deb7SPaul Kirth TrueVal = Weights[0]; 2036047deb7SPaul Kirth FalseVal = Weights[1]; 2046047deb7SPaul Kirth return true; 2056047deb7SPaul Kirth } 2066047deb7SPaul Kirth 2076047deb7SPaul Kirth bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { 2086047deb7SPaul Kirth TotalVal = 0; 2096047deb7SPaul Kirth if (!ProfileData) 2106047deb7SPaul Kirth return false; 2116047deb7SPaul Kirth 2126047deb7SPaul Kirth auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); 2136047deb7SPaul Kirth if (!ProfDataName) 2146047deb7SPaul Kirth return false; 2156047deb7SPaul Kirth 2164e6f6fdaSKazu Hirata if (ProfDataName->getString() == "branch_weights") { 217294f3ce5SPaul Kirth unsigned Offset = getBranchWeightOffset(ProfileData); 218294f3ce5SPaul Kirth for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) { 2196ce03ff3SShubham Sandeep Rastogi auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 2206047deb7SPaul Kirth assert(V && "Malformed branch_weight in MD_prof node"); 2216047deb7SPaul Kirth TotalVal += V->getValue().getZExtValue(); 2226047deb7SPaul Kirth } 2236047deb7SPaul Kirth return true; 224deef5b8cSPaul Kirth } 225deef5b8cSPaul Kirth 2264e6f6fdaSKazu Hirata if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) { 2276047deb7SPaul Kirth TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2)) 2286047deb7SPaul Kirth ->getValue() 2296047deb7SPaul Kirth .getZExtValue(); 2306047deb7SPaul Kirth return true; 2316047deb7SPaul Kirth } 2326047deb7SPaul Kirth return false; 2336047deb7SPaul Kirth } 2346047deb7SPaul Kirth 235e741b8c2SChristian Ulmann bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { 236e741b8c2SChristian Ulmann return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); 237e741b8c2SChristian Ulmann } 238e741b8c2SChristian Ulmann 239294f3ce5SPaul Kirth void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights, 240294f3ce5SPaul Kirth bool IsExpected) { 241cb4627d1SMatthias Braun MDBuilder MDB(I.getContext()); 242294f3ce5SPaul Kirth MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected); 243cb4627d1SMatthias Braun I.setMetadata(LLVMContext::MD_prof, BranchWeights); 244cb4627d1SMatthias Braun } 245cb4627d1SMatthias Braun 2462d641858SMingming Liu void scaleProfData(Instruction &I, uint64_t S, uint64_t T) { 2472d641858SMingming Liu assert(T != 0 && "Caller should guarantee"); 2482d641858SMingming Liu auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 2492d641858SMingming Liu if (ProfileData == nullptr) 2502d641858SMingming Liu return; 2512d641858SMingming Liu 2522d641858SMingming Liu auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); 2534e6f6fdaSKazu Hirata if (!ProfDataName || (ProfDataName->getString() != "branch_weights" && 2544e6f6fdaSKazu Hirata ProfDataName->getString() != "VP")) 2552d641858SMingming Liu return; 2562d641858SMingming Liu 25764f4ceb0SMingming Liu if (!hasCountTypeMD(I)) 25864f4ceb0SMingming Liu return; 25964f4ceb0SMingming Liu 2602d641858SMingming Liu LLVMContext &C = I.getContext(); 2612d641858SMingming Liu 2622d641858SMingming Liu MDBuilder MDB(C); 2632d641858SMingming Liu SmallVector<Metadata *, 3> Vals; 2642d641858SMingming Liu Vals.push_back(ProfileData->getOperand(0)); 2652d641858SMingming Liu APInt APS(128, S), APT(128, T); 2664e6f6fdaSKazu Hirata if (ProfDataName->getString() == "branch_weights" && 2672d641858SMingming Liu ProfileData->getNumOperands() > 0) { 2682d641858SMingming Liu // Using APInt::div may be expensive, but most cases should fit 64 bits. 269294f3ce5SPaul Kirth APInt Val(128, 270294f3ce5SPaul Kirth mdconst::dyn_extract<ConstantInt>( 271294f3ce5SPaul Kirth ProfileData->getOperand(getBranchWeightOffset(ProfileData))) 2722d641858SMingming Liu ->getValue() 2732d641858SMingming Liu .getZExtValue()); 2742d641858SMingming Liu Val *= APS; 2752d641858SMingming Liu Vals.push_back(MDB.createConstant(ConstantInt::get( 2762d641858SMingming Liu Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX)))); 2774e6f6fdaSKazu Hirata } else if (ProfDataName->getString() == "VP") 2782d641858SMingming Liu for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) { 2792d641858SMingming Liu // The first value is the key of the value profile, which will not change. 2802d641858SMingming Liu Vals.push_back(ProfileData->getOperand(i)); 2812d641858SMingming Liu uint64_t Count = 2822d641858SMingming Liu mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1)) 2832d641858SMingming Liu ->getValue() 2842d641858SMingming Liu .getZExtValue(); 2852d641858SMingming Liu // Don't scale the magic number. 2862d641858SMingming Liu if (Count == NOMORE_ICP_MAGICNUM) { 2872d641858SMingming Liu Vals.push_back(ProfileData->getOperand(i + 1)); 2882d641858SMingming Liu continue; 2892d641858SMingming Liu } 2902d641858SMingming Liu // Using APInt::div may be expensive, but most cases should fit 64 bits. 2912d641858SMingming Liu APInt Val(128, Count); 2922d641858SMingming Liu Val *= APS; 2932d641858SMingming Liu Vals.push_back(MDB.createConstant(ConstantInt::get( 2942d641858SMingming Liu Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue()))); 2952d641858SMingming Liu } 2962d641858SMingming Liu I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals)); 2972d641858SMingming Liu } 2982d641858SMingming Liu 2996047deb7SPaul Kirth } // namespace llvm 300