1a812b39eSPaul Kirth //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===// 2a812b39eSPaul Kirth // 3a812b39eSPaul Kirth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a812b39eSPaul Kirth // See https://llvm.org/LICENSE.txt for license information. 5a812b39eSPaul Kirth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a812b39eSPaul Kirth // 7a812b39eSPaul Kirth //===----------------------------------------------------------------------===// 8a812b39eSPaul Kirth // 9a812b39eSPaul Kirth // This file implements utilities for working with Profiling Metadata. 10a812b39eSPaul Kirth // 11a812b39eSPaul Kirth //===----------------------------------------------------------------------===// 12a812b39eSPaul Kirth 136047deb7SPaul Kirth #include "llvm/IR/ProfDataUtils.h" 146047deb7SPaul Kirth #include "llvm/ADT/SmallVector.h" 156047deb7SPaul Kirth #include "llvm/ADT/Twine.h" 166047deb7SPaul Kirth #include "llvm/IR/Constants.h" 176047deb7SPaul Kirth #include "llvm/IR/Function.h" 186047deb7SPaul Kirth #include "llvm/IR/Instructions.h" 196047deb7SPaul Kirth #include "llvm/IR/LLVMContext.h" 206047deb7SPaul Kirth #include "llvm/IR/Metadata.h" 216047deb7SPaul Kirth #include "llvm/Support/BranchProbability.h" 226047deb7SPaul Kirth #include "llvm/Support/CommandLine.h" 236047deb7SPaul Kirth 246047deb7SPaul Kirth using namespace llvm; 256047deb7SPaul Kirth 266047deb7SPaul Kirth namespace { 276047deb7SPaul Kirth 286047deb7SPaul Kirth // MD_prof nodes have the following layout 296047deb7SPaul Kirth // 306047deb7SPaul Kirth // In general: 316047deb7SPaul Kirth // { String name, Array of i32 } 326047deb7SPaul Kirth // 336047deb7SPaul Kirth // In terms of Types: 346047deb7SPaul Kirth // { MDString, [i32, i32, ...]} 356047deb7SPaul Kirth // 366047deb7SPaul Kirth // Concretely for Branch Weights 376047deb7SPaul Kirth // { "branch_weights", [i32 1, i32 10000]} 386047deb7SPaul Kirth // 396047deb7SPaul Kirth // We maintain some constants here to ensure that we access the branch weights 406047deb7SPaul Kirth // correctly, and can change the behavior in the future if the layout changes 416047deb7SPaul Kirth 426047deb7SPaul Kirth // The index at which the weights vector starts 436047deb7SPaul Kirth constexpr unsigned WeightsIdx = 1; 446047deb7SPaul Kirth 456047deb7SPaul Kirth // the minimum number of operands for MD_prof nodes with branch weights 466047deb7SPaul Kirth constexpr unsigned MinBWOps = 3; 476047deb7SPaul Kirth 486047deb7SPaul Kirth bool extractWeights(const MDNode *ProfileData, 496047deb7SPaul Kirth SmallVectorImpl<uint32_t> &Weights) { 506047deb7SPaul Kirth // Assume preconditions are already met (i.e. this is valid metadata) 516047deb7SPaul Kirth assert(ProfileData && "ProfileData was nullptr in extractWeights"); 526047deb7SPaul Kirth unsigned NOps = ProfileData->getNumOperands(); 536047deb7SPaul Kirth 546047deb7SPaul Kirth assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); 556047deb7SPaul Kirth Weights.resize(NOps - WeightsIdx); 566047deb7SPaul Kirth 576047deb7SPaul Kirth for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { 586047deb7SPaul Kirth ConstantInt *Weight = 596047deb7SPaul Kirth mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 606047deb7SPaul Kirth assert(Weight && "Malformed branch_weight in MD_prof node"); 616047deb7SPaul Kirth assert(Weight->getValue().getActiveBits() <= 32 && 626047deb7SPaul Kirth "Too many bits for uint32_t"); 636047deb7SPaul Kirth Weights[Idx - WeightsIdx] = Weight->getZExtValue(); 646047deb7SPaul Kirth } 656047deb7SPaul Kirth return true; 666047deb7SPaul Kirth } 676047deb7SPaul Kirth 686047deb7SPaul Kirth // We may want to add support for other MD_prof types, so provide an abstraction 696047deb7SPaul Kirth // for checking the metadata type. 706047deb7SPaul Kirth bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { 716047deb7SPaul Kirth // TODO: This routine may be simplified if MD_prof used an enum instead of a 726047deb7SPaul Kirth // string to differentiate the types of MD_prof nodes. 736047deb7SPaul Kirth if (!ProfData || !Name || MinOps < 2) 746047deb7SPaul Kirth return false; 756047deb7SPaul Kirth 766047deb7SPaul Kirth unsigned NOps = ProfData->getNumOperands(); 776047deb7SPaul Kirth if (NOps < MinOps) 786047deb7SPaul Kirth return false; 796047deb7SPaul Kirth 806047deb7SPaul Kirth auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0)); 816047deb7SPaul Kirth if (!ProfDataName) 826047deb7SPaul Kirth return false; 836047deb7SPaul Kirth 846047deb7SPaul Kirth return ProfDataName->getString().equals(Name); 856047deb7SPaul Kirth } 866047deb7SPaul Kirth 876047deb7SPaul Kirth } // namespace 886047deb7SPaul Kirth 896047deb7SPaul Kirth namespace llvm { 906047deb7SPaul Kirth 916047deb7SPaul Kirth bool hasProfMD(const Instruction &I) { 926047deb7SPaul Kirth return nullptr != I.getMetadata(LLVMContext::MD_prof); 936047deb7SPaul Kirth } 946047deb7SPaul Kirth 956047deb7SPaul Kirth bool isBranchWeightMD(const MDNode *ProfileData) { 966047deb7SPaul Kirth return isTargetMD(ProfileData, "branch_weights", MinBWOps); 976047deb7SPaul Kirth } 986047deb7SPaul Kirth 996047deb7SPaul Kirth bool hasBranchWeightMD(const Instruction &I) { 1006047deb7SPaul Kirth auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 1016047deb7SPaul Kirth return isBranchWeightMD(ProfileData); 1026047deb7SPaul Kirth } 1036047deb7SPaul Kirth 104*e741b8c2SChristian Ulmann bool hasValidBranchWeightMD(const Instruction &I) { 105*e741b8c2SChristian Ulmann return getValidBranchWeightMDNode(I); 106*e741b8c2SChristian Ulmann } 107*e741b8c2SChristian Ulmann 108*e741b8c2SChristian Ulmann MDNode *getBranchWeightMDNode(const Instruction &I) { 109*e741b8c2SChristian Ulmann auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 110*e741b8c2SChristian Ulmann if (!isBranchWeightMD(ProfileData)) 111*e741b8c2SChristian Ulmann return nullptr; 112*e741b8c2SChristian Ulmann return ProfileData; 113*e741b8c2SChristian Ulmann } 114*e741b8c2SChristian Ulmann 115*e741b8c2SChristian Ulmann MDNode *getValidBranchWeightMDNode(const Instruction &I) { 116*e741b8c2SChristian Ulmann auto *ProfileData = getBranchWeightMDNode(I); 117*e741b8c2SChristian Ulmann if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors()) 118*e741b8c2SChristian Ulmann return ProfileData; 119*e741b8c2SChristian Ulmann return nullptr; 120*e741b8c2SChristian Ulmann } 121*e741b8c2SChristian Ulmann 1226047deb7SPaul Kirth bool extractBranchWeights(const MDNode *ProfileData, 1236047deb7SPaul Kirth SmallVectorImpl<uint32_t> &Weights) { 1246047deb7SPaul Kirth if (!isBranchWeightMD(ProfileData)) 1256047deb7SPaul Kirth return false; 1266047deb7SPaul Kirth return extractWeights(ProfileData, Weights); 1276047deb7SPaul Kirth } 1286047deb7SPaul Kirth 1296047deb7SPaul Kirth bool extractBranchWeights(const Instruction &I, 1306047deb7SPaul Kirth SmallVectorImpl<uint32_t> &Weights) { 1316047deb7SPaul Kirth auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 1326047deb7SPaul Kirth return extractBranchWeights(ProfileData, Weights); 1336047deb7SPaul Kirth } 1346047deb7SPaul Kirth 1356047deb7SPaul Kirth bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal, 1366047deb7SPaul Kirth uint64_t &FalseVal) { 1376047deb7SPaul Kirth assert((I.getOpcode() == Instruction::Br || 1386047deb7SPaul Kirth I.getOpcode() == Instruction::Select) && 139*e741b8c2SChristian Ulmann "Looking for branch weights on something besides branch, select, or " 140*e741b8c2SChristian Ulmann "switch"); 1416047deb7SPaul Kirth 1426047deb7SPaul Kirth SmallVector<uint32_t, 2> Weights; 1436047deb7SPaul Kirth auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 1446047deb7SPaul Kirth if (!extractBranchWeights(ProfileData, Weights)) 1456047deb7SPaul Kirth return false; 1466047deb7SPaul Kirth 1476047deb7SPaul Kirth if (Weights.size() > 2) 1486047deb7SPaul Kirth return false; 1496047deb7SPaul Kirth 1506047deb7SPaul Kirth TrueVal = Weights[0]; 1516047deb7SPaul Kirth FalseVal = Weights[1]; 1526047deb7SPaul Kirth return true; 1536047deb7SPaul Kirth } 1546047deb7SPaul Kirth 1556047deb7SPaul Kirth bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { 1566047deb7SPaul Kirth TotalVal = 0; 1576047deb7SPaul Kirth if (!ProfileData) 1586047deb7SPaul Kirth return false; 1596047deb7SPaul Kirth 1606047deb7SPaul Kirth auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); 1616047deb7SPaul Kirth if (!ProfDataName) 1626047deb7SPaul Kirth return false; 1636047deb7SPaul Kirth 1646047deb7SPaul Kirth if (ProfDataName->getString().equals("branch_weights")) { 1656047deb7SPaul Kirth for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) { 1666047deb7SPaul Kirth auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 1676047deb7SPaul Kirth assert(V && "Malformed branch_weight in MD_prof node"); 1686047deb7SPaul Kirth TotalVal += V->getValue().getZExtValue(); 1696047deb7SPaul Kirth } 1706047deb7SPaul Kirth return true; 171deef5b8cSPaul Kirth } 172deef5b8cSPaul Kirth 173deef5b8cSPaul Kirth if (ProfDataName->getString().equals("VP") && 1746047deb7SPaul Kirth ProfileData->getNumOperands() > 3) { 1756047deb7SPaul Kirth TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2)) 1766047deb7SPaul Kirth ->getValue() 1776047deb7SPaul Kirth .getZExtValue(); 1786047deb7SPaul Kirth return true; 1796047deb7SPaul Kirth } 1806047deb7SPaul Kirth return false; 1816047deb7SPaul Kirth } 1826047deb7SPaul Kirth 183*e741b8c2SChristian Ulmann bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { 184*e741b8c2SChristian Ulmann return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); 185*e741b8c2SChristian Ulmann } 186*e741b8c2SChristian Ulmann 1876047deb7SPaul Kirth } // namespace llvm 188