1 //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities for working with Profiling Metadata. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "llvm/IR/ProfDataUtils.h" 14 #include "llvm/ADT/SmallVector.h" 15 #include "llvm/ADT/Twine.h" 16 #include "llvm/IR/Constants.h" 17 #include "llvm/IR/Function.h" 18 #include "llvm/IR/Instructions.h" 19 #include "llvm/IR/LLVMContext.h" 20 #include "llvm/IR/MDBuilder.h" 21 #include "llvm/IR/Metadata.h" 22 #include "llvm/Support/BranchProbability.h" 23 #include "llvm/Support/CommandLine.h" 24 25 using namespace llvm; 26 27 namespace { 28 29 // MD_prof nodes have the following layout 30 // 31 // In general: 32 // { String name, Array of i32 } 33 // 34 // In terms of Types: 35 // { MDString, [i32, i32, ...]} 36 // 37 // Concretely for Branch Weights 38 // { "branch_weights", [i32 1, i32 10000]} 39 // 40 // We maintain some constants here to ensure that we access the branch weights 41 // correctly, and can change the behavior in the future if the layout changes 42 43 // The index at which the weights vector starts 44 constexpr unsigned WeightsIdx = 1; 45 46 // the minimum number of operands for MD_prof nodes with branch weights 47 constexpr unsigned MinBWOps = 3; 48 49 // the minimum number of operands for MD_prof nodes with value profiles 50 constexpr unsigned MinVPOps = 5; 51 52 // We may want to add support for other MD_prof types, so provide an abstraction 53 // for checking the metadata type. 54 bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { 55 // TODO: This routine may be simplified if MD_prof used an enum instead of a 56 // string to differentiate the types of MD_prof nodes. 57 if (!ProfData || !Name || MinOps < 2) 58 return false; 59 60 unsigned NOps = ProfData->getNumOperands(); 61 if (NOps < MinOps) 62 return false; 63 64 auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0)); 65 if (!ProfDataName) 66 return false; 67 68 return ProfDataName->getString() == Name; 69 } 70 71 template <typename T, 72 typename = typename std::enable_if<std::is_arithmetic_v<T>>> 73 static void extractFromBranchWeightMD(const MDNode *ProfileData, 74 SmallVectorImpl<T> &Weights) { 75 assert(isBranchWeightMD(ProfileData) && "wrong metadata"); 76 77 unsigned NOps = ProfileData->getNumOperands(); 78 assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); 79 Weights.resize(NOps - WeightsIdx); 80 81 for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { 82 ConstantInt *Weight = 83 mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 84 assert(Weight && "Malformed branch_weight in MD_prof node"); 85 assert(Weight->getValue().getActiveBits() <= 32 && 86 "Too many bits for uint32_t"); 87 Weights[Idx - WeightsIdx] = Weight->getZExtValue(); 88 } 89 } 90 91 } // namespace 92 93 namespace llvm { 94 95 bool hasProfMD(const Instruction &I) { 96 return I.hasMetadata(LLVMContext::MD_prof); 97 } 98 99 bool isBranchWeightMD(const MDNode *ProfileData) { 100 return isTargetMD(ProfileData, "branch_weights", MinBWOps); 101 } 102 103 bool isValueProfileMD(const MDNode *ProfileData) { 104 return isTargetMD(ProfileData, "VP", MinVPOps); 105 } 106 107 bool hasBranchWeightMD(const Instruction &I) { 108 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 109 return isBranchWeightMD(ProfileData); 110 } 111 112 bool hasCountTypeMD(const Instruction &I) { 113 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 114 // Value profiles record count-type information. 115 if (isValueProfileMD(ProfileData)) 116 return true; 117 // Conservatively assume non CallBase instruction only get taken/not-taken 118 // branch probability, so not interpret them as count. 119 return isa<CallBase>(I) && !isBranchWeightMD(ProfileData); 120 } 121 122 bool hasValidBranchWeightMD(const Instruction &I) { 123 return getValidBranchWeightMDNode(I); 124 } 125 126 MDNode *getBranchWeightMDNode(const Instruction &I) { 127 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 128 if (!isBranchWeightMD(ProfileData)) 129 return nullptr; 130 return ProfileData; 131 } 132 133 MDNode *getValidBranchWeightMDNode(const Instruction &I) { 134 auto *ProfileData = getBranchWeightMDNode(I); 135 if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors()) 136 return ProfileData; 137 return nullptr; 138 } 139 140 void extractFromBranchWeightMD32(const MDNode *ProfileData, 141 SmallVectorImpl<uint32_t> &Weights) { 142 extractFromBranchWeightMD(ProfileData, Weights); 143 } 144 145 void extractFromBranchWeightMD64(const MDNode *ProfileData, 146 SmallVectorImpl<uint64_t> &Weights) { 147 extractFromBranchWeightMD(ProfileData, Weights); 148 } 149 150 bool extractBranchWeights(const MDNode *ProfileData, 151 SmallVectorImpl<uint32_t> &Weights) { 152 if (!isBranchWeightMD(ProfileData)) 153 return false; 154 extractFromBranchWeightMD(ProfileData, Weights); 155 return true; 156 } 157 158 bool extractBranchWeights(const Instruction &I, 159 SmallVectorImpl<uint32_t> &Weights) { 160 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 161 return extractBranchWeights(ProfileData, Weights); 162 } 163 164 bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal, 165 uint64_t &FalseVal) { 166 assert((I.getOpcode() == Instruction::Br || 167 I.getOpcode() == Instruction::Select) && 168 "Looking for branch weights on something besides branch, select, or " 169 "switch"); 170 171 SmallVector<uint32_t, 2> Weights; 172 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 173 if (!extractBranchWeights(ProfileData, Weights)) 174 return false; 175 176 if (Weights.size() > 2) 177 return false; 178 179 TrueVal = Weights[0]; 180 FalseVal = Weights[1]; 181 return true; 182 } 183 184 bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { 185 TotalVal = 0; 186 if (!ProfileData) 187 return false; 188 189 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); 190 if (!ProfDataName) 191 return false; 192 193 if (ProfDataName->getString() == "branch_weights") { 194 for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) { 195 auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 196 assert(V && "Malformed branch_weight in MD_prof node"); 197 TotalVal += V->getValue().getZExtValue(); 198 } 199 return true; 200 } 201 202 if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) { 203 TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2)) 204 ->getValue() 205 .getZExtValue(); 206 return true; 207 } 208 return false; 209 } 210 211 bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { 212 return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); 213 } 214 215 void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) { 216 MDBuilder MDB(I.getContext()); 217 MDNode *BranchWeights = MDB.createBranchWeights(Weights); 218 I.setMetadata(LLVMContext::MD_prof, BranchWeights); 219 } 220 221 void scaleProfData(Instruction &I, uint64_t S, uint64_t T) { 222 assert(T != 0 && "Caller should guarantee"); 223 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 224 if (ProfileData == nullptr) 225 return; 226 227 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); 228 if (!ProfDataName || (ProfDataName->getString() != "branch_weights" && 229 ProfDataName->getString() != "VP")) 230 return; 231 232 if (!hasCountTypeMD(I)) 233 return; 234 235 LLVMContext &C = I.getContext(); 236 237 MDBuilder MDB(C); 238 SmallVector<Metadata *, 3> Vals; 239 Vals.push_back(ProfileData->getOperand(0)); 240 APInt APS(128, S), APT(128, T); 241 if (ProfDataName->getString() == "branch_weights" && 242 ProfileData->getNumOperands() > 0) { 243 // Using APInt::div may be expensive, but most cases should fit 64 bits. 244 APInt Val(128, mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1)) 245 ->getValue() 246 .getZExtValue()); 247 Val *= APS; 248 Vals.push_back(MDB.createConstant(ConstantInt::get( 249 Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX)))); 250 } else if (ProfDataName->getString() == "VP") 251 for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) { 252 // The first value is the key of the value profile, which will not change. 253 Vals.push_back(ProfileData->getOperand(i)); 254 uint64_t Count = 255 mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1)) 256 ->getValue() 257 .getZExtValue(); 258 // Don't scale the magic number. 259 if (Count == NOMORE_ICP_MAGICNUM) { 260 Vals.push_back(ProfileData->getOperand(i + 1)); 261 continue; 262 } 263 // Using APInt::div may be expensive, but most cases should fit 64 bits. 264 APInt Val(128, Count); 265 Val *= APS; 266 Vals.push_back(MDB.createConstant(ConstantInt::get( 267 Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue()))); 268 } 269 I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals)); 270 } 271 272 } // namespace llvm 273