1bdd1243dSDimitry Andric //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===// 2bdd1243dSDimitry Andric // 3bdd1243dSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4bdd1243dSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5bdd1243dSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6bdd1243dSDimitry Andric // 7bdd1243dSDimitry Andric //===----------------------------------------------------------------------===// 8bdd1243dSDimitry Andric // 9bdd1243dSDimitry Andric // This file implements utilities for working with Profiling Metadata. 10bdd1243dSDimitry Andric // 11bdd1243dSDimitry Andric //===----------------------------------------------------------------------===// 12bdd1243dSDimitry Andric 13bdd1243dSDimitry Andric #include "llvm/IR/ProfDataUtils.h" 14bdd1243dSDimitry Andric #include "llvm/ADT/SmallVector.h" 15bdd1243dSDimitry Andric #include "llvm/ADT/Twine.h" 16bdd1243dSDimitry Andric #include "llvm/IR/Constants.h" 17bdd1243dSDimitry Andric #include "llvm/IR/Function.h" 18bdd1243dSDimitry Andric #include "llvm/IR/Instructions.h" 19bdd1243dSDimitry Andric #include "llvm/IR/LLVMContext.h" 205f757f3fSDimitry Andric #include "llvm/IR/MDBuilder.h" 21bdd1243dSDimitry Andric #include "llvm/IR/Metadata.h" 22*0fca6ea1SDimitry Andric #include "llvm/IR/ProfDataUtils.h" 23bdd1243dSDimitry Andric #include "llvm/Support/BranchProbability.h" 24bdd1243dSDimitry Andric #include "llvm/Support/CommandLine.h" 25bdd1243dSDimitry Andric 26bdd1243dSDimitry Andric using namespace llvm; 27bdd1243dSDimitry Andric 28bdd1243dSDimitry Andric namespace { 29bdd1243dSDimitry Andric 30bdd1243dSDimitry Andric // MD_prof nodes have the following layout 31bdd1243dSDimitry Andric // 32bdd1243dSDimitry Andric // In general: 33bdd1243dSDimitry Andric // { String name, Array of i32 } 34bdd1243dSDimitry Andric // 35bdd1243dSDimitry Andric // In terms of Types: 36bdd1243dSDimitry Andric // { MDString, [i32, i32, ...]} 37bdd1243dSDimitry Andric // 38bdd1243dSDimitry Andric // Concretely for Branch Weights 39bdd1243dSDimitry Andric // { "branch_weights", [i32 1, i32 10000]} 40bdd1243dSDimitry Andric // 41bdd1243dSDimitry Andric // We maintain some constants here to ensure that we access the branch weights 42bdd1243dSDimitry Andric // correctly, and can change the behavior in the future if the layout changes 43bdd1243dSDimitry Andric 44bdd1243dSDimitry Andric // the minimum number of operands for MD_prof nodes with branch weights 45bdd1243dSDimitry Andric constexpr unsigned MinBWOps = 3; 46bdd1243dSDimitry Andric 47*0fca6ea1SDimitry Andric // the minimum number of operands for MD_prof nodes with value profiles 48*0fca6ea1SDimitry Andric constexpr unsigned MinVPOps = 5; 49*0fca6ea1SDimitry Andric 50bdd1243dSDimitry Andric // We may want to add support for other MD_prof types, so provide an abstraction 51bdd1243dSDimitry Andric // for checking the metadata type. 52bdd1243dSDimitry Andric bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { 53bdd1243dSDimitry Andric // TODO: This routine may be simplified if MD_prof used an enum instead of a 54bdd1243dSDimitry Andric // string to differentiate the types of MD_prof nodes. 55bdd1243dSDimitry Andric if (!ProfData || !Name || MinOps < 2) 56bdd1243dSDimitry Andric return false; 57bdd1243dSDimitry Andric 58bdd1243dSDimitry Andric unsigned NOps = ProfData->getNumOperands(); 59bdd1243dSDimitry Andric if (NOps < MinOps) 60bdd1243dSDimitry Andric return false; 61bdd1243dSDimitry Andric 62bdd1243dSDimitry Andric auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0)); 63bdd1243dSDimitry Andric if (!ProfDataName) 64bdd1243dSDimitry Andric return false; 65bdd1243dSDimitry Andric 66*0fca6ea1SDimitry Andric return ProfDataName->getString() == Name; 67*0fca6ea1SDimitry Andric } 68*0fca6ea1SDimitry Andric 69*0fca6ea1SDimitry Andric template <typename T, 70*0fca6ea1SDimitry Andric typename = typename std::enable_if<std::is_arithmetic_v<T>>> 71*0fca6ea1SDimitry Andric static void extractFromBranchWeightMD(const MDNode *ProfileData, 72*0fca6ea1SDimitry Andric SmallVectorImpl<T> &Weights) { 73*0fca6ea1SDimitry Andric assert(isBranchWeightMD(ProfileData) && "wrong metadata"); 74*0fca6ea1SDimitry Andric 75*0fca6ea1SDimitry Andric unsigned NOps = ProfileData->getNumOperands(); 76*0fca6ea1SDimitry Andric unsigned WeightsIdx = getBranchWeightOffset(ProfileData); 77*0fca6ea1SDimitry Andric assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); 78*0fca6ea1SDimitry Andric Weights.resize(NOps - WeightsIdx); 79*0fca6ea1SDimitry Andric 80*0fca6ea1SDimitry Andric for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { 81*0fca6ea1SDimitry Andric ConstantInt *Weight = 82*0fca6ea1SDimitry Andric mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 83*0fca6ea1SDimitry Andric assert(Weight && "Malformed branch_weight in MD_prof node"); 84*0fca6ea1SDimitry Andric assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) && 85*0fca6ea1SDimitry Andric "Too many bits for MD_prof branch_weight"); 86*0fca6ea1SDimitry Andric Weights[Idx - WeightsIdx] = Weight->getZExtValue(); 87*0fca6ea1SDimitry Andric } 88bdd1243dSDimitry Andric } 89bdd1243dSDimitry Andric 90bdd1243dSDimitry Andric } // namespace 91bdd1243dSDimitry Andric 92bdd1243dSDimitry Andric namespace llvm { 93bdd1243dSDimitry Andric 94bdd1243dSDimitry Andric bool hasProfMD(const Instruction &I) { 95*0fca6ea1SDimitry Andric return I.hasMetadata(LLVMContext::MD_prof); 96bdd1243dSDimitry Andric } 97bdd1243dSDimitry Andric 98bdd1243dSDimitry Andric bool isBranchWeightMD(const MDNode *ProfileData) { 99bdd1243dSDimitry Andric return isTargetMD(ProfileData, "branch_weights", MinBWOps); 100bdd1243dSDimitry Andric } 101bdd1243dSDimitry Andric 102*0fca6ea1SDimitry Andric bool isValueProfileMD(const MDNode *ProfileData) { 103*0fca6ea1SDimitry Andric return isTargetMD(ProfileData, "VP", MinVPOps); 104*0fca6ea1SDimitry Andric } 105*0fca6ea1SDimitry Andric 106bdd1243dSDimitry Andric bool hasBranchWeightMD(const Instruction &I) { 107bdd1243dSDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 108bdd1243dSDimitry Andric return isBranchWeightMD(ProfileData); 109bdd1243dSDimitry Andric } 110bdd1243dSDimitry Andric 111*0fca6ea1SDimitry Andric bool hasCountTypeMD(const Instruction &I) { 112*0fca6ea1SDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 113*0fca6ea1SDimitry Andric // Value profiles record count-type information. 114*0fca6ea1SDimitry Andric if (isValueProfileMD(ProfileData)) 115*0fca6ea1SDimitry Andric return true; 116*0fca6ea1SDimitry Andric // Conservatively assume non CallBase instruction only get taken/not-taken 117*0fca6ea1SDimitry Andric // branch probability, so not interpret them as count. 118*0fca6ea1SDimitry Andric return isa<CallBase>(I) && !isBranchWeightMD(ProfileData); 119*0fca6ea1SDimitry Andric } 120*0fca6ea1SDimitry Andric 121bdd1243dSDimitry Andric bool hasValidBranchWeightMD(const Instruction &I) { 122bdd1243dSDimitry Andric return getValidBranchWeightMDNode(I); 123bdd1243dSDimitry Andric } 124bdd1243dSDimitry Andric 125*0fca6ea1SDimitry Andric bool hasBranchWeightOrigin(const Instruction &I) { 126*0fca6ea1SDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 127*0fca6ea1SDimitry Andric return hasBranchWeightOrigin(ProfileData); 128*0fca6ea1SDimitry Andric } 129*0fca6ea1SDimitry Andric 130*0fca6ea1SDimitry Andric bool hasBranchWeightOrigin(const MDNode *ProfileData) { 131*0fca6ea1SDimitry Andric if (!isBranchWeightMD(ProfileData)) 132*0fca6ea1SDimitry Andric return false; 133*0fca6ea1SDimitry Andric auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1)); 134*0fca6ea1SDimitry Andric // NOTE: if we ever have more types of branch weight provenance, 135*0fca6ea1SDimitry Andric // we need to check the string value is "expected". For now, we 136*0fca6ea1SDimitry Andric // supply a more generic API, and avoid the spurious comparisons. 137*0fca6ea1SDimitry Andric assert(ProfDataName == nullptr || ProfDataName->getString() == "expected"); 138*0fca6ea1SDimitry Andric return ProfDataName != nullptr; 139*0fca6ea1SDimitry Andric } 140*0fca6ea1SDimitry Andric 141*0fca6ea1SDimitry Andric unsigned getBranchWeightOffset(const MDNode *ProfileData) { 142*0fca6ea1SDimitry Andric return hasBranchWeightOrigin(ProfileData) ? 2 : 1; 143*0fca6ea1SDimitry Andric } 144*0fca6ea1SDimitry Andric 145*0fca6ea1SDimitry Andric unsigned getNumBranchWeights(const MDNode &ProfileData) { 146*0fca6ea1SDimitry Andric return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData); 147*0fca6ea1SDimitry Andric } 148*0fca6ea1SDimitry Andric 149bdd1243dSDimitry Andric MDNode *getBranchWeightMDNode(const Instruction &I) { 150bdd1243dSDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 151bdd1243dSDimitry Andric if (!isBranchWeightMD(ProfileData)) 152bdd1243dSDimitry Andric return nullptr; 153bdd1243dSDimitry Andric return ProfileData; 154bdd1243dSDimitry Andric } 155bdd1243dSDimitry Andric 156bdd1243dSDimitry Andric MDNode *getValidBranchWeightMDNode(const Instruction &I) { 157bdd1243dSDimitry Andric auto *ProfileData = getBranchWeightMDNode(I); 158*0fca6ea1SDimitry Andric if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors()) 159bdd1243dSDimitry Andric return ProfileData; 160bdd1243dSDimitry Andric return nullptr; 161bdd1243dSDimitry Andric } 162bdd1243dSDimitry Andric 163*0fca6ea1SDimitry Andric void extractFromBranchWeightMD32(const MDNode *ProfileData, 1645f757f3fSDimitry Andric SmallVectorImpl<uint32_t> &Weights) { 165*0fca6ea1SDimitry Andric extractFromBranchWeightMD(ProfileData, Weights); 1665f757f3fSDimitry Andric } 167*0fca6ea1SDimitry Andric 168*0fca6ea1SDimitry Andric void extractFromBranchWeightMD64(const MDNode *ProfileData, 169*0fca6ea1SDimitry Andric SmallVectorImpl<uint64_t> &Weights) { 170*0fca6ea1SDimitry Andric extractFromBranchWeightMD(ProfileData, Weights); 1715f757f3fSDimitry Andric } 1725f757f3fSDimitry Andric 173bdd1243dSDimitry Andric bool extractBranchWeights(const MDNode *ProfileData, 174bdd1243dSDimitry Andric SmallVectorImpl<uint32_t> &Weights) { 175bdd1243dSDimitry Andric if (!isBranchWeightMD(ProfileData)) 176bdd1243dSDimitry Andric return false; 1775f757f3fSDimitry Andric extractFromBranchWeightMD(ProfileData, Weights); 1785f757f3fSDimitry Andric return true; 179bdd1243dSDimitry Andric } 180bdd1243dSDimitry Andric 181bdd1243dSDimitry Andric bool extractBranchWeights(const Instruction &I, 182bdd1243dSDimitry Andric SmallVectorImpl<uint32_t> &Weights) { 183bdd1243dSDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 184bdd1243dSDimitry Andric return extractBranchWeights(ProfileData, Weights); 185bdd1243dSDimitry Andric } 186bdd1243dSDimitry Andric 187bdd1243dSDimitry Andric bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal, 188bdd1243dSDimitry Andric uint64_t &FalseVal) { 189bdd1243dSDimitry Andric assert((I.getOpcode() == Instruction::Br || 190bdd1243dSDimitry Andric I.getOpcode() == Instruction::Select) && 191bdd1243dSDimitry Andric "Looking for branch weights on something besides branch, select, or " 192bdd1243dSDimitry Andric "switch"); 193bdd1243dSDimitry Andric 194bdd1243dSDimitry Andric SmallVector<uint32_t, 2> Weights; 195bdd1243dSDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 196bdd1243dSDimitry Andric if (!extractBranchWeights(ProfileData, Weights)) 197bdd1243dSDimitry Andric return false; 198bdd1243dSDimitry Andric 199bdd1243dSDimitry Andric if (Weights.size() > 2) 200bdd1243dSDimitry Andric return false; 201bdd1243dSDimitry Andric 202bdd1243dSDimitry Andric TrueVal = Weights[0]; 203bdd1243dSDimitry Andric FalseVal = Weights[1]; 204bdd1243dSDimitry Andric return true; 205bdd1243dSDimitry Andric } 206bdd1243dSDimitry Andric 207bdd1243dSDimitry Andric bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { 208bdd1243dSDimitry Andric TotalVal = 0; 209bdd1243dSDimitry Andric if (!ProfileData) 210bdd1243dSDimitry Andric return false; 211bdd1243dSDimitry Andric 212bdd1243dSDimitry Andric auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); 213bdd1243dSDimitry Andric if (!ProfDataName) 214bdd1243dSDimitry Andric return false; 215bdd1243dSDimitry Andric 216*0fca6ea1SDimitry Andric if (ProfDataName->getString() == "branch_weights") { 217*0fca6ea1SDimitry Andric unsigned Offset = getBranchWeightOffset(ProfileData); 218*0fca6ea1SDimitry Andric for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) { 219bdd1243dSDimitry Andric auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 220bdd1243dSDimitry Andric assert(V && "Malformed branch_weight in MD_prof node"); 221bdd1243dSDimitry Andric TotalVal += V->getValue().getZExtValue(); 222bdd1243dSDimitry Andric } 223bdd1243dSDimitry Andric return true; 224bdd1243dSDimitry Andric } 225bdd1243dSDimitry Andric 226*0fca6ea1SDimitry Andric if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) { 227bdd1243dSDimitry Andric TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2)) 228bdd1243dSDimitry Andric ->getValue() 229bdd1243dSDimitry Andric .getZExtValue(); 230bdd1243dSDimitry Andric return true; 231bdd1243dSDimitry Andric } 232bdd1243dSDimitry Andric return false; 233bdd1243dSDimitry Andric } 234bdd1243dSDimitry Andric 235bdd1243dSDimitry Andric bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { 236bdd1243dSDimitry Andric return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); 237bdd1243dSDimitry Andric } 238bdd1243dSDimitry Andric 239*0fca6ea1SDimitry Andric void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights, 240*0fca6ea1SDimitry Andric bool IsExpected) { 2415f757f3fSDimitry Andric MDBuilder MDB(I.getContext()); 242*0fca6ea1SDimitry Andric MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected); 2435f757f3fSDimitry Andric I.setMetadata(LLVMContext::MD_prof, BranchWeights); 2445f757f3fSDimitry Andric } 2455f757f3fSDimitry Andric 246*0fca6ea1SDimitry Andric void scaleProfData(Instruction &I, uint64_t S, uint64_t T) { 247*0fca6ea1SDimitry Andric assert(T != 0 && "Caller should guarantee"); 248*0fca6ea1SDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 249*0fca6ea1SDimitry Andric if (ProfileData == nullptr) 250*0fca6ea1SDimitry Andric return; 251*0fca6ea1SDimitry Andric 252*0fca6ea1SDimitry Andric auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); 253*0fca6ea1SDimitry Andric if (!ProfDataName || (ProfDataName->getString() != "branch_weights" && 254*0fca6ea1SDimitry Andric ProfDataName->getString() != "VP")) 255*0fca6ea1SDimitry Andric return; 256*0fca6ea1SDimitry Andric 257*0fca6ea1SDimitry Andric if (!hasCountTypeMD(I)) 258*0fca6ea1SDimitry Andric return; 259*0fca6ea1SDimitry Andric 260*0fca6ea1SDimitry Andric LLVMContext &C = I.getContext(); 261*0fca6ea1SDimitry Andric 262*0fca6ea1SDimitry Andric MDBuilder MDB(C); 263*0fca6ea1SDimitry Andric SmallVector<Metadata *, 3> Vals; 264*0fca6ea1SDimitry Andric Vals.push_back(ProfileData->getOperand(0)); 265*0fca6ea1SDimitry Andric APInt APS(128, S), APT(128, T); 266*0fca6ea1SDimitry Andric if (ProfDataName->getString() == "branch_weights" && 267*0fca6ea1SDimitry Andric ProfileData->getNumOperands() > 0) { 268*0fca6ea1SDimitry Andric // Using APInt::div may be expensive, but most cases should fit 64 bits. 269*0fca6ea1SDimitry Andric APInt Val(128, 270*0fca6ea1SDimitry Andric mdconst::dyn_extract<ConstantInt>( 271*0fca6ea1SDimitry Andric ProfileData->getOperand(getBranchWeightOffset(ProfileData))) 272*0fca6ea1SDimitry Andric ->getValue() 273*0fca6ea1SDimitry Andric .getZExtValue()); 274*0fca6ea1SDimitry Andric Val *= APS; 275*0fca6ea1SDimitry Andric Vals.push_back(MDB.createConstant(ConstantInt::get( 276*0fca6ea1SDimitry Andric Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX)))); 277*0fca6ea1SDimitry Andric } else if (ProfDataName->getString() == "VP") 278*0fca6ea1SDimitry Andric for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) { 279*0fca6ea1SDimitry Andric // The first value is the key of the value profile, which will not change. 280*0fca6ea1SDimitry Andric Vals.push_back(ProfileData->getOperand(i)); 281*0fca6ea1SDimitry Andric uint64_t Count = 282*0fca6ea1SDimitry Andric mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1)) 283*0fca6ea1SDimitry Andric ->getValue() 284*0fca6ea1SDimitry Andric .getZExtValue(); 285*0fca6ea1SDimitry Andric // Don't scale the magic number. 286*0fca6ea1SDimitry Andric if (Count == NOMORE_ICP_MAGICNUM) { 287*0fca6ea1SDimitry Andric Vals.push_back(ProfileData->getOperand(i + 1)); 288*0fca6ea1SDimitry Andric continue; 289*0fca6ea1SDimitry Andric } 290*0fca6ea1SDimitry Andric // Using APInt::div may be expensive, but most cases should fit 64 bits. 291*0fca6ea1SDimitry Andric APInt Val(128, Count); 292*0fca6ea1SDimitry Andric Val *= APS; 293*0fca6ea1SDimitry Andric Vals.push_back(MDB.createConstant(ConstantInt::get( 294*0fca6ea1SDimitry Andric Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue()))); 295*0fca6ea1SDimitry Andric } 296*0fca6ea1SDimitry Andric I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals)); 297*0fca6ea1SDimitry Andric } 298*0fca6ea1SDimitry Andric 299bdd1243dSDimitry Andric } // namespace llvm 300