1bdd1243dSDimitry Andric //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===// 2bdd1243dSDimitry Andric // 3bdd1243dSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4bdd1243dSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5bdd1243dSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6bdd1243dSDimitry Andric // 7bdd1243dSDimitry Andric //===----------------------------------------------------------------------===// 8bdd1243dSDimitry Andric // 9bdd1243dSDimitry Andric // This file implements utilities for working with Profiling Metadata. 10bdd1243dSDimitry Andric // 11bdd1243dSDimitry Andric //===----------------------------------------------------------------------===// 12bdd1243dSDimitry Andric 13bdd1243dSDimitry Andric #include "llvm/IR/ProfDataUtils.h" 14bdd1243dSDimitry Andric #include "llvm/ADT/SmallVector.h" 15bdd1243dSDimitry Andric #include "llvm/ADT/Twine.h" 16bdd1243dSDimitry Andric #include "llvm/IR/Constants.h" 17bdd1243dSDimitry Andric #include "llvm/IR/Function.h" 18bdd1243dSDimitry Andric #include "llvm/IR/Instructions.h" 19bdd1243dSDimitry Andric #include "llvm/IR/LLVMContext.h" 20*5f757f3fSDimitry Andric #include "llvm/IR/MDBuilder.h" 21bdd1243dSDimitry Andric #include "llvm/IR/Metadata.h" 22bdd1243dSDimitry Andric #include "llvm/Support/BranchProbability.h" 23bdd1243dSDimitry Andric #include "llvm/Support/CommandLine.h" 24bdd1243dSDimitry Andric 25bdd1243dSDimitry Andric using namespace llvm; 26bdd1243dSDimitry Andric 27bdd1243dSDimitry Andric namespace { 28bdd1243dSDimitry Andric 29bdd1243dSDimitry Andric // MD_prof nodes have the following layout 30bdd1243dSDimitry Andric // 31bdd1243dSDimitry Andric // In general: 32bdd1243dSDimitry Andric // { String name, Array of i32 } 33bdd1243dSDimitry Andric // 34bdd1243dSDimitry Andric // In terms of Types: 35bdd1243dSDimitry Andric // { MDString, [i32, i32, ...]} 36bdd1243dSDimitry Andric // 37bdd1243dSDimitry Andric // Concretely for Branch Weights 38bdd1243dSDimitry Andric // { "branch_weights", [i32 1, i32 10000]} 39bdd1243dSDimitry Andric // 40bdd1243dSDimitry Andric // We maintain some constants here to ensure that we access the branch weights 41bdd1243dSDimitry Andric // correctly, and can change the behavior in the future if the layout changes 42bdd1243dSDimitry Andric 43bdd1243dSDimitry Andric // The index at which the weights vector starts 44bdd1243dSDimitry Andric constexpr unsigned WeightsIdx = 1; 45bdd1243dSDimitry Andric 46bdd1243dSDimitry Andric // the minimum number of operands for MD_prof nodes with branch weights 47bdd1243dSDimitry Andric constexpr unsigned MinBWOps = 3; 48bdd1243dSDimitry Andric 49bdd1243dSDimitry Andric // We may want to add support for other MD_prof types, so provide an abstraction 50bdd1243dSDimitry Andric // for checking the metadata type. 51bdd1243dSDimitry Andric bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { 52bdd1243dSDimitry Andric // TODO: This routine may be simplified if MD_prof used an enum instead of a 53bdd1243dSDimitry Andric // string to differentiate the types of MD_prof nodes. 54bdd1243dSDimitry Andric if (!ProfData || !Name || MinOps < 2) 55bdd1243dSDimitry Andric return false; 56bdd1243dSDimitry Andric 57bdd1243dSDimitry Andric unsigned NOps = ProfData->getNumOperands(); 58bdd1243dSDimitry Andric if (NOps < MinOps) 59bdd1243dSDimitry Andric return false; 60bdd1243dSDimitry Andric 61bdd1243dSDimitry Andric auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0)); 62bdd1243dSDimitry Andric if (!ProfDataName) 63bdd1243dSDimitry Andric return false; 64bdd1243dSDimitry Andric 65bdd1243dSDimitry Andric return ProfDataName->getString().equals(Name); 66bdd1243dSDimitry Andric } 67bdd1243dSDimitry Andric 68bdd1243dSDimitry Andric } // namespace 69bdd1243dSDimitry Andric 70bdd1243dSDimitry Andric namespace llvm { 71bdd1243dSDimitry Andric 72bdd1243dSDimitry Andric bool hasProfMD(const Instruction &I) { 73bdd1243dSDimitry Andric return nullptr != I.getMetadata(LLVMContext::MD_prof); 74bdd1243dSDimitry Andric } 75bdd1243dSDimitry Andric 76bdd1243dSDimitry Andric bool isBranchWeightMD(const MDNode *ProfileData) { 77bdd1243dSDimitry Andric return isTargetMD(ProfileData, "branch_weights", MinBWOps); 78bdd1243dSDimitry Andric } 79bdd1243dSDimitry Andric 80bdd1243dSDimitry Andric bool hasBranchWeightMD(const Instruction &I) { 81bdd1243dSDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 82bdd1243dSDimitry Andric return isBranchWeightMD(ProfileData); 83bdd1243dSDimitry Andric } 84bdd1243dSDimitry Andric 85bdd1243dSDimitry Andric bool hasValidBranchWeightMD(const Instruction &I) { 86bdd1243dSDimitry Andric return getValidBranchWeightMDNode(I); 87bdd1243dSDimitry Andric } 88bdd1243dSDimitry Andric 89bdd1243dSDimitry Andric MDNode *getBranchWeightMDNode(const Instruction &I) { 90bdd1243dSDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 91bdd1243dSDimitry Andric if (!isBranchWeightMD(ProfileData)) 92bdd1243dSDimitry Andric return nullptr; 93bdd1243dSDimitry Andric return ProfileData; 94bdd1243dSDimitry Andric } 95bdd1243dSDimitry Andric 96bdd1243dSDimitry Andric MDNode *getValidBranchWeightMDNode(const Instruction &I) { 97bdd1243dSDimitry Andric auto *ProfileData = getBranchWeightMDNode(I); 98bdd1243dSDimitry Andric if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors()) 99bdd1243dSDimitry Andric return ProfileData; 100bdd1243dSDimitry Andric return nullptr; 101bdd1243dSDimitry Andric } 102bdd1243dSDimitry Andric 103*5f757f3fSDimitry Andric void extractFromBranchWeightMD(const MDNode *ProfileData, 104*5f757f3fSDimitry Andric SmallVectorImpl<uint32_t> &Weights) { 105*5f757f3fSDimitry Andric assert(isBranchWeightMD(ProfileData) && "wrong metadata"); 106*5f757f3fSDimitry Andric 107*5f757f3fSDimitry Andric unsigned NOps = ProfileData->getNumOperands(); 108*5f757f3fSDimitry Andric assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); 109*5f757f3fSDimitry Andric Weights.resize(NOps - WeightsIdx); 110*5f757f3fSDimitry Andric 111*5f757f3fSDimitry Andric for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { 112*5f757f3fSDimitry Andric ConstantInt *Weight = 113*5f757f3fSDimitry Andric mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 114*5f757f3fSDimitry Andric assert(Weight && "Malformed branch_weight in MD_prof node"); 115*5f757f3fSDimitry Andric assert(Weight->getValue().getActiveBits() <= 32 && 116*5f757f3fSDimitry Andric "Too many bits for uint32_t"); 117*5f757f3fSDimitry Andric Weights[Idx - WeightsIdx] = Weight->getZExtValue(); 118*5f757f3fSDimitry Andric } 119*5f757f3fSDimitry Andric } 120*5f757f3fSDimitry Andric 121bdd1243dSDimitry Andric bool extractBranchWeights(const MDNode *ProfileData, 122bdd1243dSDimitry Andric SmallVectorImpl<uint32_t> &Weights) { 123bdd1243dSDimitry Andric if (!isBranchWeightMD(ProfileData)) 124bdd1243dSDimitry Andric return false; 125*5f757f3fSDimitry Andric extractFromBranchWeightMD(ProfileData, Weights); 126*5f757f3fSDimitry Andric return true; 127bdd1243dSDimitry Andric } 128bdd1243dSDimitry Andric 129bdd1243dSDimitry Andric bool extractBranchWeights(const Instruction &I, 130bdd1243dSDimitry Andric SmallVectorImpl<uint32_t> &Weights) { 131bdd1243dSDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 132bdd1243dSDimitry Andric return extractBranchWeights(ProfileData, Weights); 133bdd1243dSDimitry Andric } 134bdd1243dSDimitry Andric 135bdd1243dSDimitry Andric bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal, 136bdd1243dSDimitry Andric uint64_t &FalseVal) { 137bdd1243dSDimitry Andric assert((I.getOpcode() == Instruction::Br || 138bdd1243dSDimitry Andric I.getOpcode() == Instruction::Select) && 139bdd1243dSDimitry Andric "Looking for branch weights on something besides branch, select, or " 140bdd1243dSDimitry Andric "switch"); 141bdd1243dSDimitry Andric 142bdd1243dSDimitry Andric SmallVector<uint32_t, 2> Weights; 143bdd1243dSDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 144bdd1243dSDimitry Andric if (!extractBranchWeights(ProfileData, Weights)) 145bdd1243dSDimitry Andric return false; 146bdd1243dSDimitry Andric 147bdd1243dSDimitry Andric if (Weights.size() > 2) 148bdd1243dSDimitry Andric return false; 149bdd1243dSDimitry Andric 150bdd1243dSDimitry Andric TrueVal = Weights[0]; 151bdd1243dSDimitry Andric FalseVal = Weights[1]; 152bdd1243dSDimitry Andric return true; 153bdd1243dSDimitry Andric } 154bdd1243dSDimitry Andric 155bdd1243dSDimitry Andric bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { 156bdd1243dSDimitry Andric TotalVal = 0; 157bdd1243dSDimitry Andric if (!ProfileData) 158bdd1243dSDimitry Andric return false; 159bdd1243dSDimitry Andric 160bdd1243dSDimitry Andric auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); 161bdd1243dSDimitry Andric if (!ProfDataName) 162bdd1243dSDimitry Andric return false; 163bdd1243dSDimitry Andric 164bdd1243dSDimitry Andric if (ProfDataName->getString().equals("branch_weights")) { 165bdd1243dSDimitry Andric for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) { 166bdd1243dSDimitry Andric auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 167bdd1243dSDimitry Andric assert(V && "Malformed branch_weight in MD_prof node"); 168bdd1243dSDimitry Andric TotalVal += V->getValue().getZExtValue(); 169bdd1243dSDimitry Andric } 170bdd1243dSDimitry Andric return true; 171bdd1243dSDimitry Andric } 172bdd1243dSDimitry Andric 173bdd1243dSDimitry Andric if (ProfDataName->getString().equals("VP") && 174bdd1243dSDimitry Andric ProfileData->getNumOperands() > 3) { 175bdd1243dSDimitry Andric TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2)) 176bdd1243dSDimitry Andric ->getValue() 177bdd1243dSDimitry Andric .getZExtValue(); 178bdd1243dSDimitry Andric return true; 179bdd1243dSDimitry Andric } 180bdd1243dSDimitry Andric return false; 181bdd1243dSDimitry Andric } 182bdd1243dSDimitry Andric 183bdd1243dSDimitry Andric bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { 184bdd1243dSDimitry Andric return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); 185bdd1243dSDimitry Andric } 186bdd1243dSDimitry Andric 187*5f757f3fSDimitry Andric void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) { 188*5f757f3fSDimitry Andric MDBuilder MDB(I.getContext()); 189*5f757f3fSDimitry Andric MDNode *BranchWeights = MDB.createBranchWeights(Weights); 190*5f757f3fSDimitry Andric I.setMetadata(LLVMContext::MD_prof, BranchWeights); 191*5f757f3fSDimitry Andric } 192*5f757f3fSDimitry Andric 193bdd1243dSDimitry Andric } // namespace llvm 194