//===- bolt/Passes/IndirectCallPromotion.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the IndirectCallPromotion class.
//
//===----------------------------------------------------------------------===//

#include "bolt/Passes/IndirectCallPromotion.h"
#include "bolt/Passes/BinaryFunctionCallGraph.h"
#include "bolt/Passes/DataflowInfoManager.h"
#include "bolt/Passes/Inliner.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include <iterator>

#define DEBUG_TYPE "ICP"
#define DEBUG_VERBOSE(Level, X)                                                \
  if (opts::Verbosity >= (Level)) {                                            \
    X;                                                                         \
  }

using namespace llvm;
using namespace bolt;

namespace opts {

extern cl::OptionCategory BoltOptCategory;

extern cl::opt<IndirectCallPromotionType> ICP;
extern cl::opt<unsigned> Verbosity;
extern cl::opt<unsigned> ExecutionCountThreshold;

static cl::opt<unsigned> ICPJTRemainingPercentThreshold(
    "icp-jt-remaining-percent-threshold",
    cl::desc("The percentage threshold against remaining unpromoted indirect "
             "call count for the promotion for jump tables"),
    cl::init(30), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPJTTotalPercentThreshold(
    "icp-jt-total-percent-threshold",
    cl::desc(
        "The percentage threshold against total count for the promotion for "
        "jump tables"),
    cl::init(5), cl::Hidden, cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPCallsRemainingPercentThreshold(
    "icp-calls-remaining-percent-threshold",
    cl::desc("The percentage threshold against remaining unpromoted indirect "
             "call count for the promotion for calls"),
    cl::init(50), cl::Hidden, cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPCallsTotalPercentThreshold(
    "icp-calls-total-percent-threshold",
    cl::desc(
        "The percentage threshold against total count for the promotion for "
        "calls"),
    cl::init(30), cl::Hidden, cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPMispredictThreshold(
    "indirect-call-promotion-mispredict-threshold",
    cl::desc("misprediction threshold for skipping ICP on an "
             "indirect call"),
    cl::init(0), cl::cat(BoltOptCategory));

static cl::alias ICPMispredictThresholdAlias(
    "icp-mp-threshold",
    cl::desc("alias for --indirect-call-promotion-mispredict-threshold"),
    cl::aliasopt(ICPMispredictThreshold));

static cl::opt<bool> ICPUseMispredicts(
    "indirect-call-promotion-use-mispredicts",
    cl::desc("use misprediction frequency for determining whether or not ICP "
             "should be applied at a callsite.  The "
             "-indirect-call-promotion-mispredict-threshold value will be used "
             "by this heuristic"),
    cl::cat(BoltOptCategory));

static cl::alias ICPUseMispredictsAlias(
    "icp-use-mp",
    cl::desc("alias for --indirect-call-promotion-use-mispredicts"),
    cl::aliasopt(ICPUseMispredicts));

static cl::opt<unsigned>
    ICPTopN("indirect-call-promotion-topn",
            cl::desc("limit number of targets to consider when doing indirect "
                     "call promotion. 0 = no limit"),
            cl::init(3), cl::cat(BoltOptCategory));

static cl::alias
    ICPTopNAlias("icp-topn",
                 cl::desc("alias for --indirect-call-promotion-topn"),
                 cl::aliasopt(ICPTopN));

static cl::opt<unsigned> ICPCallsTopN(
    "indirect-call-promotion-calls-topn",
    cl::desc("limit number of targets to consider when doing indirect "
             "call promotion on calls. 0 = no limit"),
    cl::init(0), cl::cat(BoltOptCategory));

static cl::alias ICPCallsTopNAlias(
    "icp-calls-topn",
    cl::desc("alias for --indirect-call-promotion-calls-topn"),
    cl::aliasopt(ICPCallsTopN));

static cl::opt<unsigned> ICPJumpTablesTopN(
    "indirect-call-promotion-jump-tables-topn",
    cl::desc("limit number of targets to consider when doing indirect "
             "call promotion on jump tables. 0 = no limit"),
    cl::init(0), cl::cat(BoltOptCategory));

static cl::alias ICPJumpTablesTopNAlias(
    "icp-jt-topn",
    cl::desc("alias for --indirect-call-promotion-jump-tables-topn"),
    cl::aliasopt(ICPJumpTablesTopN));

static cl::opt<bool> EliminateLoads(
    "icp-eliminate-loads",
    cl::desc("enable load elimination using memory profiling data when "
             "performing ICP"),
    cl::init(true), cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPTopCallsites(
    "icp-top-callsites",
    cl::desc("optimize hottest calls until at least this percentage of all "
             "indirect calls frequency is covered. 0 = all callsites"),
    cl::init(99), cl::Hidden, cl::cat(BoltOptCategory));

static cl::list<std::string>
    ICPFuncsList("icp-funcs", cl::CommaSeparated,
                 cl::desc("list of functions to enable ICP for"),
                 cl::value_desc("func1,func2,func3,..."), cl::Hidden,
                 cl::cat(BoltOptCategory));

static cl::opt<bool>
    ICPOldCodeSequence("icp-old-code-sequence",
                       cl::desc("use old code sequence for promoted calls"),
                       cl::Hidden, cl::cat(BoltOptCategory));

static cl::opt<bool> ICPJumpTablesByTarget(
    "icp-jump-tables-targets",
    cl::desc(
        "for jump tables, optimize indirect jmp targets instead of indices"),
    cl::Hidden, cl::cat(BoltOptCategory));

static cl::alias
    ICPJumpTablesByTargetAlias("icp-jt-targets",
                               cl::desc("alias for --icp-jump-tables-targets"),
                               cl::aliasopt(ICPJumpTablesByTarget));

static cl::opt<bool> ICPPeelForInline(
    "icp-inline", cl::desc("only promote call targets eligible for inlining"),
    cl::Hidden, cl::cat(BoltOptCategory));

} // namespace opts

#ifndef NDEBUG
static bool verifyProfile(std::map<uint64_t, BinaryFunction> &BFs) {
  bool IsValid = true;
  for (auto &BFI : BFs) {
    BinaryFunction &BF = BFI.second;
    if (!BF.isSimple())
      continue;
    for (const BinaryBasicBlock &BB : BF) {
      auto BI = BB.branch_info_begin();
      for (BinaryBasicBlock *SuccBB : BB.successors()) {
        if (BI->Count != BinaryBasicBlock::COUNT_NO_PROFILE && BI->Count > 0) {
          if (BB.getKnownExecutionCount() == 0 ||
              SuccBB->getKnownExecutionCount() == 0) {
            errs() << "BOLT-WARNING: profile verification failed after ICP for "
                      "function "
                   << BF << '\n';
            IsValid = false;
          }
        }
        ++BI;
      }
    }
  }
  return IsValid;
}
#endif

namespace llvm {
namespace bolt {

IndirectCallPromotion::Callsite::Callsite(BinaryFunction &BF,
                                          const IndirectCallProfile &ICP)
    : From(BF.getSymbol()), To(ICP.Offset), Mispreds(ICP.Mispreds),
      Branches(ICP.Count) {
  if (ICP.Symbol) {
    To.Sym = ICP.Symbol;
    To.Addr = 0;
  }
}

void IndirectCallPromotion::printDecision(
    llvm::raw_ostream &OS,
    std::vector<IndirectCallPromotion::Callsite> &Targets, unsigned N) const {
  uint64_t TotalCount = 0;
  uint64_t TotalMispreds = 0;
  for (const Callsite &S : Targets) {
    TotalCount += S.Branches;
    TotalMispreds += S.Mispreds;
  }
  if (!TotalCount)
    TotalCount = 1;
  if (!TotalMispreds)
    TotalMispreds = 1;

  OS << "BOLT-INFO: ICP decision for call site with " << Targets.size()
     << " targets, Count = " << TotalCount << ", Mispreds = " << TotalMispreds
     << "\n";

  size_t I = 0;
  for (const Callsite &S : Targets) {
    OS << "Count = " << S.Branches << ", "
       << format("%.1f", (100.0 * S.Branches) / TotalCount) << ", "
       << "Mispreds = " << S.Mispreds << ", "
       << format("%.1f", (100.0 * S.Mispreds) / TotalMispreds);
    if (I < N)
      OS << " * to be optimized *";
    if (!S.JTIndices.empty()) {
      OS << " Indices:";
      for (const uint64_t Idx : S.JTIndices)
        OS << " " << Idx;
    }
    OS << "\n";
    I += S.JTIndices.empty() ? 1 : S.JTIndices.size();
  }
}

// Get list of targets for a given call sorted by most frequently
// called first.
std::vector<IndirectCallPromotion::Callsite>
IndirectCallPromotion::getCallTargets(BinaryBasicBlock &BB,
                                      const MCInst &Inst) const {
  BinaryFunction &BF = *BB.getFunction();
  const BinaryContext &BC = BF.getBinaryContext();
  std::vector<Callsite> Targets;

  if (const JumpTable *JT = BF.getJumpTable(Inst)) {
    // Don't support PIC jump tables for now
    if (!opts::ICPJumpTablesByTarget && JT->Type == JumpTable::JTT_PIC)
      return Targets;
    const Location From(BF.getSymbol());
    const std::pair<size_t, size_t> Range =
        JT->getEntriesForAddress(BC.MIB->getJumpTable(Inst));
    assert(JT->Counts.empty() || JT->Counts.size() >= Range.second);
    JumpTable::JumpInfo DefaultJI;
    const JumpTable::JumpInfo *JI =
        JT->Counts.empty() ? &DefaultJI : &JT->Counts[Range.first];
    const size_t JIAdj = JT->Counts.empty() ? 0 : 1;
    assert(JT->Type == JumpTable::JTT_PIC ||
           JT->EntrySize == BC.AsmInfo->getCodePointerSize());
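    // Iterate over the jump table entries in range. When no profile counts
    // are recorded, JIAdj is zero and JI keeps pointing at the shared
    // all-zero DefaultJI for every entry.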
    for (size_t I = Range.first; I < Range.second; ++I, JI += JIAdj) {
      MCSymbol *Entry = JT->Entries[I];
      const BinaryBasicBlock *ToBB = BF.getBasicBlockForLabel(Entry);
      assert(ToBB || Entry == BF.getFunctionEndLabel() ||
             Entry == BF.getFunctionEndLabel(FragmentNum::cold()));
      if (Entry == BF.getFunctionEndLabel() ||
          Entry == BF.getFunctionEndLabel(FragmentNum::cold()))
        continue;
      const Location To(Entry);
      const BinaryBasicBlock::BinaryBranchInfo &BI = BB.getBranchInfo(*ToBB);
      Targets.emplace_back(From, To, BI.MispredictedCount, BI.Count,
                           I - Range.first);
    }

    // Sort by symbol then addr.
    llvm::sort(Targets, [](const Callsite &A, const Callsite &B) {
      if (A.To.Sym && B.To.Sym)
        return A.To.Sym < B.To.Sym;
      else if (A.To.Sym && !B.To.Sym)
        return true;
      else if (!A.To.Sym && B.To.Sym)
        return false;
      else
        return A.To.Addr < B.To.Addr;
    });

    // Targets may contain multiple entries for the same target, but with
    // different indices. Their profile will report the same number of branches
    // for different indices if the target is the same. That's because we don't
    // profile the index value, but only the target via LBR.
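    // Merge duplicates in place (a std::unique-style pass): consecutive
    // entries with the same target symbol are folded into the first one by
    // concatenating their JTIndices; Result ends up one past the last unique
    // entry and the leftover tail is erased below.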
    auto First = Targets.begin();
    auto Last = Targets.end();
    auto Result = First;
    while (++First != Last) {
      Callsite &A = *Result;
      const Callsite &B = *First;
      if (A.To.Sym && B.To.Sym && A.To.Sym == B.To.Sym)
        A.JTIndices.insert(A.JTIndices.end(), B.JTIndices.begin(),
                           B.JTIndices.end());
      else
        *(++Result) = *First;
    }
    ++Result;

    LLVM_DEBUG(if (Targets.end() - Result > 0) {
      dbgs() << "BOLT-INFO: ICP: " << (Targets.end() - Result)
             << " duplicate targets removed\n";
    });

    Targets.erase(Result, Targets.end());
  } else {
    // Don't try to optimize PC relative indirect calls.
    if (Inst.getOperand(0).isReg() &&
        Inst.getOperand(0).getReg() == BC.MRI->getProgramCounter())
      return Targets;

    const auto ICSP = BC.MIB->tryGetAnnotationAs<IndirectCallSiteProfile>(
        Inst, "CallProfile");
    if (ICSP) {
      for (const IndirectCallProfile &CSP : ICSP.get()) {
        Callsite Site(BF, CSP);
        if (Site.isValid())
          Targets.emplace_back(std::move(Site));
      }
    }
  }

  // Sort by target count, number of indices in case of jump table, and
  // mispredicts. We prioritize targets with high count, small number of indices
  // and high mispredicts. Break ties by selecting targets with lower addresses.
  llvm::stable_sort(Targets, [](const Callsite &A, const Callsite &B) {
    if (A.Branches != B.Branches)
      return A.Branches > B.Branches;
    if (A.JTIndices.size() != B.JTIndices.size())
      return A.JTIndices.size() < B.JTIndices.size();
    if (A.Mispreds != B.Mispreds)
      return A.Mispreds > B.Mispreds;
    return A.To.Addr < B.To.Addr;
  });

  // Remove non-symbol targets
  llvm::erase_if(Targets, [](const Callsite &CS) { return !CS.To.Sym; });

  LLVM_DEBUG(if (BF.getJumpTable(Inst)) {
    uint64_t TotalCount = 0;
    uint64_t TotalMispreds = 0;
    for (const Callsite &S : Targets) {
      TotalCount += S.Branches;
      TotalMispreds += S.Mispreds;
    }
    if (!TotalCount)
      TotalCount = 1;
    if (!TotalMispreds)
      TotalMispreds = 1;

    dbgs() << "BOLT-INFO: ICP: jump table size = " << Targets.size()
           << ", Count = " << TotalCount << ", Mispreds = " << TotalMispreds
           << "\n";

    size_t I = 0;
    for (const Callsite &S : Targets) {
      dbgs() << "Count[" << I << "] = " << S.Branches << ", "
             << format("%.1f", (100.0 * S.Branches) / TotalCount) << ", "
             << "Mispreds[" << I << "] = " << S.Mispreds << ", "
             << format("%.1f", (100.0 * S.Mispreds) / TotalMispreds) << "\n";
      ++I;
    }
  });

  return Targets;
}

IndirectCallPromotion::JumpTableInfoType
IndirectCallPromotion::maybeGetHotJumpTableTargets(BinaryBasicBlock &BB,
                                                   MCInst &CallInst,
                                                   MCInst *&TargetFetchInst,
                                                   const JumpTable *JT) const {
  assert(JT && "Can't get jump table addrs for non-jump tables.");

  BinaryFunction &Function = *BB.getFunction();
  BinaryContext &BC = Function.getBinaryContext();

  if (!Function.hasMemoryProfile() || !opts::EliminateLoads)
    return JumpTableInfoType();

  JumpTableInfoType HotTargets;
  MCInst *MemLocInstr;
  MCInst *PCRelBaseOut;
  unsigned BaseReg, IndexReg;
  int64_t DispValue;
  const MCExpr *DispExpr;
  MutableArrayRef<MCInst> Insts(&BB.front(), &CallInst);
  const IndirectBranchType Type = BC.MIB->analyzeIndirectBranch(
      CallInst, Insts.begin(), Insts.end(), BC.AsmInfo->getCodePointerSize(),
      MemLocInstr, BaseReg, IndexReg, DispValue, DispExpr, PCRelBaseOut);

  assert(MemLocInstr && "There should always be a load for jump tables");
  if (!MemLocInstr)
    return JumpTableInfoType();

  LLVM_DEBUG({
    dbgs() << "BOLT-INFO: ICP attempting to find memory profiling data for "
           << "jump table in " << Function << " at @ "
           << (&CallInst - &BB.front()) << "\n"
           << "BOLT-INFO: ICP target fetch instructions:\n";
    BC.printInstruction(dbgs(), *MemLocInstr, 0, &Function);
    if (MemLocInstr != &CallInst)
      BC.printInstruction(dbgs(), CallInst, 0, &Function);
  });

  DEBUG_VERBOSE(1, {
    dbgs() << "Jmp info: Type = " << (unsigned)Type << ", "
           << "BaseReg = " << BC.MRI->getName(BaseReg) << ", "
           << "IndexReg = " << BC.MRI->getName(IndexReg) << ", "
           << "DispValue = " << Twine::utohexstr(DispValue) << ", "
           << "DispExpr = " << DispExpr << ", "
           << "MemLocInstr = ";
    BC.printInstruction(dbgs(), *MemLocInstr, 0, &Function);
    dbgs() << "\n";
  });

  ++TotalIndexBasedCandidates;

  auto ErrorOrMemAccessProfile =
      BC.MIB->tryGetAnnotationAs<MemoryAccessProfile>(*MemLocInstr,
                                                      "MemoryAccessProfile");
  if (!ErrorOrMemAccessProfile) {
    DEBUG_VERBOSE(1, dbgs()
                         << "BOLT-INFO: ICP no memory profiling data found\n");
    return JumpTableInfoType();
  }
  MemoryAccessProfile &MemAccessProfile = ErrorOrMemAccessProfile.get();

  uint64_t ArrayStart;
  if (DispExpr) {
    ErrorOr<uint64_t> DispValueOrError =
        BC.getSymbolValue(*BC.MIB->getTargetSymbol(DispExpr));
    assert(DispValueOrError && "global symbol needs a value");
    ArrayStart = *DispValueOrError;
  } else {
    ArrayStart = static_cast<uint64_t>(DispValue);
  }

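  // For RIP-relative addressing the displacement is relative to the next
  // instruction, so rebase it on the address right after the load to obtain
  // the absolute start address of the jump table.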
  if (BaseReg == BC.MRI->getProgramCounter())
    ArrayStart += Function.getAddress() + MemAccessProfile.NextInstrOffset;

  // This is a map of [symbol] -> [count, index] used to combine jump table
  // indices that share the same target, since multiple addresses may all
  // have the same entry.
  std::map<MCSymbol *, std::pair<uint64_t, uint64_t>> HotTargetMap;
  const std::pair<size_t, size_t> Range = JT->getEntriesForAddress(ArrayStart);

  for (const AddressAccess &AccessInfo : MemAccessProfile.AddressAccessInfo) {
    size_t Index;
    // Mem data occasionally includes nullptrs, ignore them.
    if (!AccessInfo.MemoryObject && !AccessInfo.Offset)
      continue;

    if (AccessInfo.Offset % JT->EntrySize != 0) // ignore bogus data
      return JumpTableInfoType();

    if (AccessInfo.MemoryObject) {
      // Deal with bad/stale data
      if (!AccessInfo.MemoryObject->getName().starts_with(
              "JUMP_TABLE/" + Function.getOneName().str()))
        return JumpTableInfoType();
      Index =
          (AccessInfo.Offset - (ArrayStart - JT->getAddress())) / JT->EntrySize;
    } else {
      Index = (AccessInfo.Offset - ArrayStart) / JT->EntrySize;
    }

    // If Index is out of range it probably means the memory profiling data is
    // wrong for this instruction, bail out.
    if (Index >= Range.second) {
      LLVM_DEBUG(dbgs() << "BOLT-INFO: Index out of range of " << Range.first
                        << ", " << Range.second << "\n");
      return JumpTableInfoType();
    }

    // Make sure the hot index points at a legal label corresponding to a BB,
    // e.g. not the end of function (unreachable) label.
    if (!Function.getBasicBlockForLabel(JT->Entries[Index + Range.first])) {
      LLVM_DEBUG({
        dbgs() << "BOLT-INFO: hot index " << Index << " pointing at bogus "
               << "label " << JT->Entries[Index + Range.first]->getName()
               << " in jump table:\n";
        JT->print(dbgs());
        dbgs() << "HotTargetMap:\n";
        for (std::pair<MCSymbol *const, std::pair<uint64_t, uint64_t>> &HT :
             HotTargetMap)
          dbgs() << "BOLT-INFO: " << HT.first->getName()
                 << " = (count=" << HT.second.first
                 << ", index=" << HT.second.second << ")\n";
      });
      return JumpTableInfoType();
    }

    std::pair<uint64_t, uint64_t> &HotTarget =
        HotTargetMap[JT->Entries[Index + Range.first]];
    HotTarget.first += AccessInfo.Count;
    HotTarget.second = Index;
  }

  llvm::copy(llvm::make_second_range(HotTargetMap),
             std::back_inserter(HotTargets));

  // Sort with highest counts first.
  llvm::sort(reverse(HotTargets));

  LLVM_DEBUG({
    dbgs() << "BOLT-INFO: ICP jump table hot targets:\n";
    for (const std::pair<uint64_t, uint64_t> &Target : HotTargets)
      dbgs() << "BOLT-INFO:  Idx = " << Target.second << ", "
             << "Count = " << Target.first << "\n";
  });

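  // Remember which register holds the jump table index so that later
  // index-based promotion can emit comparisons against the hot index values.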
  BC.MIB->getOrCreateAnnotationAs<uint16_t>(CallInst, "JTIndexReg") = IndexReg;

  TargetFetchInst = MemLocInstr;

  return HotTargets;
}

IndirectCallPromotion::SymTargetsType
IndirectCallPromotion::findCallTargetSymbols(std::vector<Callsite> &Targets,
                                             size_t &N, BinaryBasicBlock &BB,
                                             MCInst &CallInst,
                                             MCInst *&TargetFetchInst) const {
  const JumpTable *JT = BB.getFunction()->getJumpTable(CallInst);
  SymTargetsType SymTargets;

  if (!JT) {
    for (size_t I = 0; I < N; ++I) {
      assert(Targets[I].To.Sym && "All ICP targets must be to known symbols");
      assert(Targets[I].JTIndices.empty() &&
             "Can't have jump table indices for non-jump tables");
      SymTargets.emplace_back(Targets[I].To.Sym, 0);
    }
    return SymTargets;
  }

  // Use memory profile to select hot targets.
  JumpTableInfoType HotTargets =
      maybeGetHotJumpTableTargets(BB, CallInst, TargetFetchInst, JT);

  auto findTargetsIndex = [&](uint64_t JTIndex) {
    for (size_t I = 0; I < Targets.size(); ++I)
      if (llvm::is_contained(Targets[I].JTIndices, JTIndex))
        return I;
    LLVM_DEBUG(dbgs() << "BOLT-ERROR: Unable to find target index for hot jump "
                      << "table entry in " << *BB.getFunction() << "\n");
    llvm_unreachable("Hot indices must be referred to by at least one "
                     "callsite");
  };

  if (!HotTargets.empty()) {
    if (opts::Verbosity >= 1)
      for (size_t I = 0; I < HotTargets.size(); ++I)
        outs() << "BOLT-INFO: HotTarget[" << I << "] = (" << HotTargets[I].first
               << ", " << HotTargets[I].second << ")\n";

    // Recompute hottest targets, now discriminating which index is hot
    // NOTE: This is a tradeoff. On one hand, we get index information. On the
    // other hand, info coming from the memory profile is much less accurate
    // than LBRs. So we may actually end up working with more coarse
    // profile granularity in exchange for information about indices.
    std::vector<Callsite> NewTargets;
    std::map<const MCSymbol *, uint32_t> IndicesPerTarget;
    uint64_t TotalMemAccesses = 0;
    for (size_t I = 0; I < HotTargets.size(); ++I) {
      const uint64_t TargetIndex = findTargetsIndex(HotTargets[I].second);
      ++IndicesPerTarget[Targets[TargetIndex].To.Sym];
      TotalMemAccesses += HotTargets[I].first;
    }
    uint64_t RemainingMemAccesses = TotalMemAccesses;
    const size_t TopN =
        opts::ICPJumpTablesTopN ? opts::ICPJumpTablesTopN : opts::ICPTopN;
    size_t I = 0;
    for (; I < HotTargets.size(); ++I) {
      const uint64_t MemAccesses = HotTargets[I].first;
      if (100 * MemAccesses <
          TotalMemAccesses * opts::ICPJTTotalPercentThreshold)
        break;
      if (100 * MemAccesses <
          RemainingMemAccesses * opts::ICPJTRemainingPercentThreshold)
        break;
      if (TopN && I >= TopN)
        break;
      RemainingMemAccesses -= MemAccesses;

      const uint64_t JTIndex = HotTargets[I].second;
      Callsite &Target = Targets[findTargetsIndex(JTIndex)];

      NewTargets.push_back(Target);
      std::vector<uint64_t>({JTIndex}).swap(NewTargets.back().JTIndices);
      llvm::erase(Target.JTIndices, JTIndex);

      // Keep fixCFG counts sane if more indices use this same target later
      assert(IndicesPerTarget[Target.To.Sym] > 0 && "wrong map");
      NewTargets.back().Branches =
          Target.Branches / IndicesPerTarget[Target.To.Sym];
      NewTargets.back().Mispreds =
          Target.Mispreds / IndicesPerTarget[Target.To.Sym];
      assert(Target.Branches >= NewTargets.back().Branches);
      assert(Target.Mispreds >= NewTargets.back().Mispreds);
      Target.Branches -= NewTargets.back().Branches;
      Target.Mispreds -= NewTargets.back().Mispreds;
    }
    llvm::copy(Targets, std::back_inserter(NewTargets));
    std::swap(NewTargets, Targets);
    N = I;

    if (N == 0 && opts::Verbosity >= 1) {
      outs() << "BOLT-INFO: ICP failed in " << *BB.getFunction() << " in "
             << BB.getName() << ": failed to meet thresholds after memory "
             << "profile data was loaded.\n";
      return SymTargets;
    }
  }

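  // Expand each selected target into one (symbol, index) pair per jump table
  // index it serves. Note that N counts indices rather than targets here.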
  for (size_t I = 0, TgtIdx = 0; I < N; ++TgtIdx) {
    Callsite &Target = Targets[TgtIdx];
    assert(Target.To.Sym && "All ICP targets must be to known symbols");
    assert(!Target.JTIndices.empty() && "Jump tables must have indices");
    for (uint64_t Idx : Target.JTIndices) {
      SymTargets.emplace_back(Target.To.Sym, Idx);
      ++I;
    }
  }

  return SymTargets;
}

IndirectCallPromotion::MethodInfoType IndirectCallPromotion::maybeGetVtableSyms(
    BinaryBasicBlock &BB, MCInst &Inst,
    const SymTargetsType &SymTargets) const {
  BinaryFunction &Function = *BB.getFunction();
  BinaryContext &BC = Function.getBinaryContext();
  std::vector<std::pair<MCSymbol *, uint64_t>> VtableSyms;
  std::vector<MCInst *> MethodFetchInsns;
  unsigned VtableReg, MethodReg;
  uint64_t MethodOffset;

  assert(!Function.getJumpTable(Inst) &&
         "Can't get vtable addrs for jump tables.");

  if (!Function.hasMemoryProfile() || !opts::EliminateLoads)
    return MethodInfoType();

  MutableArrayRef<MCInst> Insts(&BB.front(), &Inst + 1);
  if (!BC.MIB->analyzeVirtualMethodCall(Insts.begin(), Insts.end(),
                                        MethodFetchInsns, VtableReg, MethodReg,
                                        MethodOffset)) {
    DEBUG_VERBOSE(
        1, dbgs() << "BOLT-INFO: ICP unable to analyze method call in "
                  << Function << " at @ " << (&Inst - &BB.front()) << "\n");
    return MethodInfoType();
  }

  ++TotalMethodLoadEliminationCandidates;

  DEBUG_VERBOSE(1, {
    dbgs() << "BOLT-INFO: ICP found virtual method call in " << Function
           << " at @ " << (&Inst - &BB.front()) << "\n";
    dbgs() << "BOLT-INFO: ICP method fetch instructions:\n";
    for (MCInst *Inst : MethodFetchInsns)
      BC.printInstruction(dbgs(), *Inst, 0, &Function);

    if (MethodFetchInsns.back() != &Inst)
      BC.printInstruction(dbgs(), Inst, 0, &Function);
  });

  // Try to get value profiling data for the method load instruction.
  auto ErrorOrMemAccessProfile =
      BC.MIB->tryGetAnnotationAs<MemoryAccessProfile>(*MethodFetchInsns.back(),
                                                      "MemoryAccessProfile");
  if (!ErrorOrMemAccessProfile) {
    DEBUG_VERBOSE(1, dbgs()
                         << "BOLT-INFO: ICP no memory profiling data found\n");
    return MethodInfoType();
  }
  MemoryAccessProfile &MemAccessProfile = ErrorOrMemAccessProfile.get();

  // Find the vtable that each method belongs to.
  std::map<const MCSymbol *, uint64_t> MethodToVtable;

  for (const AddressAccess &AccessInfo : MemAccessProfile.AddressAccessInfo) {
    uint64_t Address = AccessInfo.Offset;
    if (AccessInfo.MemoryObject)
      Address += AccessInfo.MemoryObject->getAddress();

    // Ignore bogus data.
    if (!Address)
      continue;

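    // The profiled load address is the vtable address plus MethodOffset, so
    // subtracting the offset recovers the vtable base.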
    const uint64_t VtableBase = Address - MethodOffset;

    DEBUG_VERBOSE(1, dbgs() << "BOLT-INFO: ICP vtable = "
                            << Twine::utohexstr(VtableBase) << "+"
                            << MethodOffset << "/" << AccessInfo.Count << "\n");

    if (ErrorOr<uint64_t> MethodAddr = BC.getPointerAtAddress(Address)) {
      BinaryData *MethodBD = BC.getBinaryDataAtAddress(MethodAddr.get());
      if (!MethodBD) // skip unknown methods
        continue;
      MCSymbol *MethodSym = MethodBD->getSymbol();
      MethodToVtable[MethodSym] = VtableBase;
      DEBUG_VERBOSE(1, {
        const BinaryFunction *Method = BC.getFunctionForSymbol(MethodSym);
        dbgs() << "BOLT-INFO: ICP found method = "
               << Twine::utohexstr(MethodAddr.get()) << "/"
               << (Method ? Method->getPrintName() : "") << "\n";
      });
    }
  }

  // Find the vtable for each target symbol.
  for (size_t I = 0; I < SymTargets.size(); ++I) {
    auto Itr = MethodToVtable.find(SymTargets[I].first);
    if (Itr != MethodToVtable.end()) {
      if (BinaryData *BD = BC.getBinaryDataContainingAddress(Itr->second)) {
        const uint64_t Addend = Itr->second - BD->getAddress();
        VtableSyms.emplace_back(BD->getSymbol(), Addend);
        continue;
      }
    }
    // Give up if we can't find the vtable for a method.
    DEBUG_VERBOSE(1, dbgs() << "BOLT-INFO: ICP can't find vtable for "
                            << SymTargets[I].first->getName() << "\n");
    return MethodInfoType();
  }

  // Make sure the vtable reg is not clobbered by the argument passing code
  if (VtableReg != MethodReg) {
    for (MCInst *CurInst = MethodFetchInsns.front(); CurInst < &Inst;
         ++CurInst) {
      const MCInstrDesc &InstrInfo = BC.MII->get(CurInst->getOpcode());
      if (InstrInfo.hasDefOfPhysReg(*CurInst, VtableReg, *BC.MRI))
        return MethodInfoType();
    }
  }

  return MethodInfoType(VtableSyms, MethodFetchInsns);
}

std::vector<std::unique_ptr<BinaryBasicBlock>>
IndirectCallPromotion::rewriteCall(
    BinaryBasicBlock &IndCallBlock, const MCInst &CallInst,
    MCPlusBuilder::BlocksVectorTy &&ICPcode,
    const std::vector<MCInst *> &MethodFetchInsns) const {
  BinaryFunction &Function = *IndCallBlock.getFunction();
  MCPlusBuilder *MIB = Function.getBinaryContext().MIB.get();

  // Create new basic blocks with correct code in each one first.
  std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs;
  const bool IsTailCallOrJT =
      (MIB->isTailCall(CallInst) || Function.getJumpTable(CallInst));

  // If we are tracking the indirect call/jump address, propagate the address to
  // the ICP code.
  const std::optional<uint32_t> IndirectInstrOffset = MIB->getOffset(CallInst);
  if (IndirectInstrOffset) {
    for (auto &[Symbol, Instructions] : ICPcode)
      for (MCInst &Inst : Instructions)
        MIB->setOffset(Inst, *IndirectInstrOffset);
  }

  // Move instructions from the tail of the original call block
  // to the merge block.

  // Remember any pseudo instructions following a tail call.  These
  // must be preserved and moved to the original block.
  InstructionListType TailInsts;
  const MCInst *TailInst = &CallInst;
  if (IsTailCallOrJT)
    while (TailInst + 1 < &(*IndCallBlock.end()) &&
           MIB->isPseudo(*(TailInst + 1)))
      TailInsts.push_back(*++TailInst);

  InstructionListType MovedInst = IndCallBlock.splitInstructions(&CallInst);
  // Link new BBs to the original input offset of the indirect call site or its
  // containing BB, so we can map samples recorded in new BBs back to the
  // original BB seen in the input binary (if using BAT).
  const uint32_t OrigOffset = IndirectInstrOffset
                                  ? *IndirectInstrOffset
                                  : IndCallBlock.getInputOffset();

  IndCallBlock.eraseInstructions(MethodFetchInsns.begin(),
                                 MethodFetchInsns.end());
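  // The first entry of ICPcode replaces the indirect call (and any eliminated
  // method fetch instructions) in the original block; every following
  // (symbol, instructions) pair becomes a new basic block below.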
  if (IndCallBlock.empty() ||
      (!MethodFetchInsns.empty() && MethodFetchInsns.back() == &CallInst))
    IndCallBlock.addInstructions(ICPcode.front().second.begin(),
                                 ICPcode.front().second.end());
  else
    IndCallBlock.replaceInstruction(std::prev(IndCallBlock.end()),
                                    ICPcode.front().second);
  IndCallBlock.addInstructions(TailInsts.begin(), TailInsts.end());

  for (auto Itr = ICPcode.begin() + 1; Itr != ICPcode.end(); ++Itr) {
    MCSymbol *&Sym = Itr->first;
    InstructionListType &Insts = Itr->second;
    assert(Sym);
    std::unique_ptr<BinaryBasicBlock> TBB = Function.createBasicBlock(Sym);
    TBB->setOffset(OrigOffset);
    for (MCInst &Inst : Insts) // sanitize new instructions.
      if (MIB->isCall(Inst))
        MIB->removeAnnotation(Inst, "CallProfile");
    TBB->addInstructions(Insts.begin(), Insts.end());
    NewBBs.emplace_back(std::move(TBB));
  }

  // Move tail of instructions from after the original call to
  // the merge block.
  if (!IsTailCallOrJT)
    NewBBs.back()->addInstructions(MovedInst.begin(), MovedInst.end());

  return NewBBs;
}

BinaryBasicBlock *
IndirectCallPromotion::fixCFG(BinaryBasicBlock &IndCallBlock,
                              const bool IsTailCall, const bool IsJumpTable,
                              IndirectCallPromotion::BasicBlocksVector &&NewBBs,
                              const std::vector<Callsite> &Targets) const {
  BinaryFunction &Function = *IndCallBlock.getFunction();
  using BinaryBranchInfo = BinaryBasicBlock::BinaryBranchInfo;
  BinaryBasicBlock *MergeBlock = nullptr;

  // Scale indirect call counts to the execution count of the original
  // basic block containing the indirect call.
  uint64_t TotalCount = IndCallBlock.getKnownExecutionCount();
  uint64_t TotalIndirectBranches = 0;
  for (const Callsite &Target : Targets)
    TotalIndirectBranches += Target.Branches;
  if (TotalIndirectBranches == 0)
    TotalIndirectBranches = 1;
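  // BBI holds the raw per-index counts that are later subtracted from the
  // surviving indirect jump's branch info (jump table case), while ScaledBBI
  // rescales those counts to the block's execution count for the new edges.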
  BinaryBasicBlock::BranchInfoType BBI;
  BinaryBasicBlock::BranchInfoType ScaledBBI;
  for (const Callsite &Target : Targets) {
    const size_t NumEntries =
        std::max(static_cast<std::size_t>(1UL), Target.JTIndices.size());
    for (size_t I = 0; I < NumEntries; ++I) {
      BBI.push_back(
          BinaryBranchInfo{(Target.Branches + NumEntries - 1) / NumEntries,
                           (Target.Mispreds + NumEntries - 1) / NumEntries});
      ScaledBBI.push_back(
          BinaryBranchInfo{uint64_t(TotalCount * Target.Branches /
                                    (NumEntries * TotalIndirectBranches)),
                           uint64_t(TotalCount * Target.Mispreds /
                                    (NumEntries * TotalIndirectBranches))});
    }
  }

  if (IsJumpTable) {
    BinaryBasicBlock *NewIndCallBlock = NewBBs.back().get();
    IndCallBlock.moveAllSuccessorsTo(NewIndCallBlock);

    std::vector<MCSymbol *> SymTargets;
    for (const Callsite &Target : Targets) {
      const size_t NumEntries =
          std::max(static_cast<std::size_t>(1UL), Target.JTIndices.size());
      for (size_t I = 0; I < NumEntries; ++I)
        SymTargets.push_back(Target.To.Sym);
    }
    assert(SymTargets.size() > NewBBs.size() - 1 &&
           "There must be a target symbol associated with each new BB.");

    for (uint64_t I = 0; I < NewBBs.size(); ++I) {
      BinaryBasicBlock *SourceBB = I ? NewBBs[I - 1].get() : &IndCallBlock;
      SourceBB->setExecutionCount(TotalCount);

      BinaryBasicBlock *TargetBB =
          Function.getBasicBlockForLabel(SymTargets[I]);
      SourceBB->addSuccessor(TargetBB, ScaledBBI[I]); // taken

      TotalCount -= ScaledBBI[I].Count;
      SourceBB->addSuccessor(NewBBs[I].get(), TotalCount); // fall-through

      // Update branch info for the indirect jump.
      BinaryBasicBlock::BinaryBranchInfo &BranchInfo =
          NewIndCallBlock->getBranchInfo(*TargetBB);
      if (BranchInfo.Count > BBI[I].Count)
        BranchInfo.Count -= BBI[I].Count;
      else
        BranchInfo.Count = 0;

      if (BranchInfo.MispredictedCount > BBI[I].MispredictedCount)
        BranchInfo.MispredictedCount -= BBI[I].MispredictedCount;
      else
        BranchInfo.MispredictedCount = 0;
    }
  } else {
    assert(NewBBs.size() >= 2);
    assert(NewBBs.size() % 2 == 1 || IndCallBlock.succ_empty());
    assert(NewBBs.size() % 2 == 1 || IsTailCall);

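    // For promoted calls the new blocks alternate between direct-call blocks
    // and the remaining compare blocks; the last blocks hold the unpromoted
    // (fallback) indirect call and, unless this is a tail call, the merge
    // block that rejoins the original fall-through code.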
    auto ScaledBI = ScaledBBI.begin();
    auto updateCurrentBranchInfo = [&] {
      assert(ScaledBI != ScaledBBI.end());
      TotalCount -= ScaledBI->Count;
      ++ScaledBI;
    };

    if (!IsTailCall) {
      MergeBlock = NewBBs.back().get();
      IndCallBlock.moveAllSuccessorsTo(MergeBlock);
    }

    // Fix up successors and execution counts.
    updateCurrentBranchInfo();
    IndCallBlock.addSuccessor(NewBBs[1].get(), TotalCount);
    IndCallBlock.addSuccessor(NewBBs[0].get(), ScaledBBI[0]);

    const size_t Adj = IsTailCall ? 1 : 2;
    for (size_t I = 0; I < NewBBs.size() - Adj; ++I) {
      assert(TotalCount <= IndCallBlock.getExecutionCount() ||
             TotalCount <= uint64_t(TotalIndirectBranches));
      uint64_t ExecCount = ScaledBBI[(I + 1) / 2].Count;
      if (I % 2 == 0) {
        if (MergeBlock)
          NewBBs[I]->addSuccessor(MergeBlock, ScaledBBI[(I + 1) / 2].Count);
      } else {
        assert(I + 2 < NewBBs.size());
        updateCurrentBranchInfo();
        NewBBs[I]->addSuccessor(NewBBs[I + 2].get(), TotalCount);
        NewBBs[I]->addSuccessor(NewBBs[I + 1].get(), ScaledBBI[(I + 1) / 2]);
        ExecCount += TotalCount;
      }
      NewBBs[I]->setExecutionCount(ExecCount);
    }

    if (MergeBlock) {
      // Arrange for the MergeBlock to be the fallthrough for the first
      // promoted call block.
      std::unique_ptr<BinaryBasicBlock> MBPtr;
      std::swap(MBPtr, NewBBs.back());
      NewBBs.pop_back();
      NewBBs.emplace(NewBBs.begin() + 1, std::move(MBPtr));
      // TODO: is COUNT_FALLTHROUGH_EDGE the right thing here?
      NewBBs.back()->addSuccessor(MergeBlock, TotalCount); // uncond branch
    }
  }

  // Update the execution count.
  NewBBs.back()->setExecutionCount(TotalCount);

  // Update BB and BB layout.
  Function.insertBasicBlocks(&IndCallBlock, std::move(NewBBs));
  assert(Function.validateCFG());

  return MergeBlock;
}

size_t IndirectCallPromotion::canPromoteCallsite(
    const BinaryBasicBlock &BB, const MCInst &Inst,
    const std::vector<Callsite> &Targets, uint64_t NumCalls) {
  BinaryFunction *BF = BB.getFunction();
  const BinaryContext &BC = BF->getBinaryContext();

  if (BB.getKnownExecutionCount() < opts::ExecutionCountThreshold)
    return 0;

  const bool IsJumpTable = BF->getJumpTable(Inst);

  auto computeStats = [&](size_t N) {
    for (size_t I = 0; I < N; ++I)
      if (IsJumpTable)
        TotalNumFrequentJmps += Targets[I].Branches;
      else
        TotalNumFrequentCalls += Targets[I].Branches;
  };

  // If we have no targets (or no calls), skip this callsite.
  if (Targets.empty() || !NumCalls) {
    if (opts::Verbosity >= 1) {
      const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
      outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx << " in "
             << BB.getName() << ", calls = " << NumCalls
             << ", targets empty or NumCalls == 0.\n";
    }
    return 0;
  }

  size_t TopN = opts::ICPTopN;
  if (IsJumpTable)
    TopN = opts::ICPJumpTablesTopN ? opts::ICPJumpTablesTopN : TopN;
  else
    TopN = opts::ICPCallsTopN ? opts::ICPCallsTopN : TopN;

  const size_t TrialN = TopN ? std::min(TopN, Targets.size()) : Targets.size();

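  // With -icp-top-callsites, only callsites that were marked with the "DoICP"
  // annotation in runOnFunctions() are considered.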
  if (opts::ICPTopCallsites && !BC.MIB->hasAnnotation(Inst, "DoICP"))
    return 0;

  // Pick the top N targets.
  uint64_t TotalMispredictsTopN = 0;
  size_t N = 0;

  if (opts::ICPUseMispredicts &&
      (!IsJumpTable || opts::ICPJumpTablesByTarget)) {
    // Count total number of mispredictions for (at most) the top N targets.
    // We may choose a smaller N (TrialN vs. N) if the frequency threshold
    // is exceeded by fewer targets.
    double Threshold = double(opts::ICPMispredictThreshold);
    for (size_t I = 0; I < TrialN && Threshold > 0; ++I, ++N) {
      Threshold -= (100.0 * Targets[I].Mispreds) / NumCalls;
      TotalMispredictsTopN += Targets[I].Mispreds;
    }
    computeStats(N);

    // Compute the misprediction frequency of the top N call targets.  If this
    // frequency is greater than the threshold, we should try ICP on this
    // callsite.
    const double TopNFrequency = (100.0 * TotalMispredictsTopN) / NumCalls;
    if (TopNFrequency == 0 || TopNFrequency < opts::ICPMispredictThreshold) {
      if (opts::Verbosity >= 1) {
        const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
        outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
               << " in " << BB.getName() << ", calls = " << NumCalls
               << ", top N mis. frequency " << format("%.1f", TopNFrequency)
               << "% < " << opts::ICPMispredictThreshold << "%\n";
      }
      return 0;
    }
  } else {
    size_t MaxTargets = 0;

    // Count total number of calls for (at most) the top N targets.
    // We may choose a smaller N (TrialN vs. N) if the frequency threshold
    // is exceeded by fewer targets.
    const unsigned TotalThreshold = IsJumpTable
                                        ? opts::ICPJTTotalPercentThreshold
                                        : opts::ICPCallsTotalPercentThreshold;
    const unsigned RemainingThreshold =
        IsJumpTable ? opts::ICPJTRemainingPercentThreshold
                    : opts::ICPCallsRemainingPercentThreshold;
    uint64_t NumRemainingCalls = NumCalls;
    for (size_t I = 0; I < TrialN; ++I, ++MaxTargets) {
      if (100 * Targets[I].Branches < NumCalls * TotalThreshold)
        break;
      if (100 * Targets[I].Branches < NumRemainingCalls * RemainingThreshold)
        break;
      if (N + (Targets[I].JTIndices.empty() ? 1 : Targets[I].JTIndices.size()) >
          TrialN)
        break;
      TotalMispredictsTopN += Targets[I].Mispreds;
      NumRemainingCalls -= Targets[I].Branches;
      N += Targets[I].JTIndices.empty() ? 1 : Targets[I].JTIndices.size();
    }
    computeStats(MaxTargets);

    // Don't check misprediction frequency for jump tables -- we don't really
    // care as long as we are saving loads from the jump table.
    if (!IsJumpTable || opts::ICPJumpTablesByTarget) {
      // Compute the misprediction frequency of the top N call targets.  If
      // this frequency is less than the threshold, we should skip ICP at
      // this callsite.
      const double TopNMispredictFrequency =
          (100.0 * TotalMispredictsTopN) / NumCalls;

      if (TopNMispredictFrequency < opts::ICPMispredictThreshold) {
        if (opts::Verbosity >= 1) {
          const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
          outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
                 << " in " << BB.getName() << ", calls = " << NumCalls
                 << ", top N mispredict frequency "
                 << format("%.1f", TopNMispredictFrequency) << "% < "
                 << opts::ICPMispredictThreshold << "%\n";
        }
        return 0;
      }
    }
  }

  // Filter by inline-ability of target functions, stop at first target that
  // can't be inlined.
  if (!IsJumpTable && opts::ICPPeelForInline) {
    for (size_t I = 0; I < N; ++I) {
      const MCSymbol *TargetSym = Targets[I].To.Sym;
      const BinaryFunction *TargetBF = BC.getFunctionForSymbol(TargetSym);
      if (!TargetBF || !BinaryFunctionPass::shouldOptimize(*TargetBF) ||
          getInliningInfo(*TargetBF).Type == InliningType::INL_NONE) {
        N = I;
        break;
      }
    }
  }

  // Filter functions that can have ICP applied (for debugging)
  if (!opts::ICPFuncsList.empty()) {
    for (std::string &Name : opts::ICPFuncsList)
      if (BF->hasName(Name))
        return N;
    return 0;
  }

  return N;
}

void IndirectCallPromotion::printCallsiteInfo(
    const BinaryBasicBlock &BB, const MCInst &Inst,
    const std::vector<Callsite> &Targets, const size_t N,
    uint64_t NumCalls) const {
  BinaryContext &BC = BB.getFunction()->getBinaryContext();
  const bool IsTailCall = BC.MIB->isTailCall(Inst);
  const bool IsJumpTable = BB.getFunction()->getJumpTable(Inst);
  const ptrdiff_t InstIdx = &Inst - &(*BB.begin());

  outs() << "BOLT-INFO: ICP candidate branch info: " << *BB.getFunction()
         << " @ " << InstIdx << " in " << BB.getName()
         << " -> calls = " << NumCalls
         << (IsTailCall ? " (tail)" : (IsJumpTable ? " (jump table)" : ""))
         << "\n";
  for (size_t I = 0; I < N; I++) {
    const double Frequency = 100.0 * Targets[I].Branches / NumCalls;
    const double MisFrequency = 100.0 * Targets[I].Mispreds / NumCalls;
    outs() << "BOLT-INFO:   ";
    if (Targets[I].To.Sym)
      outs() << Targets[I].To.Sym->getName();
    else
      outs() << Targets[I].To.Addr;
    outs() << ", calls = " << Targets[I].Branches
           << ", mispreds = " << Targets[I].Mispreds
           << ", taken freq = " << format("%.1f", Frequency) << "%"
           << ", mis. freq = " << format("%.1f", MisFrequency) << "%";
    bool First = true;
    for (uint64_t JTIndex : Targets[I].JTIndices) {
      outs() << (First ? ", indices = " : ", ") << JTIndex;
      First = false;
    }
    outs() << "\n";
  }

  LLVM_DEBUG({
    dbgs() << "BOLT-INFO: ICP original call instruction:";
    BC.printInstruction(dbgs(), Inst, Targets[0].From.Addr, nullptr, true);
  });
}

Error IndirectCallPromotion::runOnFunctions(BinaryContext &BC) {
  if (opts::ICP == ICP_NONE)
    return Error::success();

  auto &BFs = BC.getBinaryFunctions();

  const bool OptimizeCalls = (opts::ICP == ICP_CALLS || opts::ICP == ICP_ALL);
  const bool OptimizeJumpTables =
      (opts::ICP == ICP_JUMP_TABLES || opts::ICP == ICP_ALL);

  std::unique_ptr<RegAnalysis> RA;
  std::unique_ptr<BinaryFunctionCallGraph> CG;
  if (OptimizeJumpTables) {
    CG.reset(new BinaryFunctionCallGraph(buildCallGraph(BC)));
    RA.reset(new RegAnalysis(BC, &BFs, &*CG));
  }

  // If icp-top-callsites is enabled, compute the total number of indirect
  // calls and then optimize the hottest callsites that contribute to that
  // total.
  SetVector<BinaryFunction *> Functions;
  if (opts::ICPTopCallsites == 0) {
    for (auto &KV : BFs)
      Functions.insert(&KV.second);
  } else {
    using IndirectCallsite = std::tuple<uint64_t, MCInst *, BinaryFunction *>;
    std::vector<IndirectCallsite> IndirectCalls;
    size_t TotalIndirectCalls = 0;

    // Find all the indirect callsites.
    for (auto &BFIt : BFs) {
      BinaryFunction &Function = BFIt.second;

      if (!shouldOptimize(Function))
        continue;

      const bool HasLayout = !Function.getLayout().block_empty();

      for (BinaryBasicBlock &BB : Function) {
        if (HasLayout && Function.isSplit() && BB.isCold())
          continue;

        for (MCInst &Inst : BB) {
          const bool IsJumpTable = Function.getJumpTable(Inst);
          const bool HasIndirectCallProfile =
              BC.MIB->hasAnnotation(Inst, "CallProfile");
          const bool IsDirectCall =
              (BC.MIB->isCall(Inst) && BC.MIB->getTargetSymbol(Inst, 0));

          if (!IsDirectCall &&
              ((HasIndirectCallProfile && !IsJumpTable && OptimizeCalls) ||
               (IsJumpTable && OptimizeJumpTables))) {
            uint64_t NumCalls = 0;
            for (const Callsite &BInfo : getCallTargets(BB, Inst))
              NumCalls += BInfo.Branches;
            IndirectCalls.push_back(
                std::make_tuple(NumCalls, &Inst, &Function));
            TotalIndirectCalls += NumCalls;
          }
        }
      }
    }

    // Sort callsites by execution count.
    llvm::sort(reverse(IndirectCalls));

    // Find callsites that contribute to the top "opts::ICPTopCallsites"%
    // number of calls.
    const float TopPerc = opts::ICPTopCallsites / 100.0f;
    int64_t MaxCalls = TotalIndirectCalls * TopPerc;
    uint64_t LastFreq = std::numeric_limits<uint64_t>::max();
    size_t Num = 0;
    for (const IndirectCallsite &IC : IndirectCalls) {
      const uint64_t CurFreq = std::get<0>(IC);
      // Once we decide to stop, include at least all branches that share the
      // same frequency as the last one to avoid non-deterministic behavior
      // (e.g. turning on/off ICP depending on the order of functions)
      if (MaxCalls <= 0 && CurFreq != LastFreq)
        break;
      MaxCalls -= CurFreq;
      LastFreq = CurFreq;
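      // Mark the callsite so that canPromoteCallsite() will consider it.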
      BC.MIB->addAnnotation(*std::get<1>(IC), "DoICP", true);
      Functions.insert(std::get<2>(IC));
      ++Num;
    }
    outs() << "BOLT-INFO: ICP Total indirect calls = " << TotalIndirectCalls
           << ", " << Num << " callsites cover " << opts::ICPTopCallsites
           << "% of all indirect calls\n";
  }

  for (BinaryFunction *FuncPtr : Functions) {
    BinaryFunction &Function = *FuncPtr;

    if (!shouldOptimize(Function))
      continue;

    const bool HasLayout = !Function.getLayout().block_empty();

    // Total number of indirect calls issued from the current Function.
    // (a fraction of TotalIndirectCalls)
    uint64_t FuncTotalIndirectCalls = 0;
    uint64_t FuncTotalIndirectJmps = 0;

    std::vector<BinaryBasicBlock *> BBs;
    for (BinaryBasicBlock &BB : Function) {
      // Skip indirect calls in cold blocks.
      if (!HasLayout || !Function.isSplit() || !BB.isCold())
        BBs.push_back(&BB);
    }
    if (BBs.empty())
      continue;

    DataflowInfoManager Info(Function, RA.get(), nullptr);
    while (!BBs.empty()) {
      BinaryBasicBlock *BB = BBs.back();
      BBs.pop_back();

      for (unsigned Idx = 0; Idx < BB->size(); ++Idx) {
        MCInst &Inst = BB->getInstructionAtIndex(Idx);
        const ptrdiff_t InstIdx = &Inst - &(*BB->begin());
        const bool IsTailCall = BC.MIB->isTailCall(Inst);
        const bool HasIndirectCallProfile =
            BC.MIB->hasAnnotation(Inst, "CallProfile");
        const bool IsJumpTable = Function.getJumpTable(Inst);

        if (BC.MIB->isCall(Inst))
          TotalCalls += BB->getKnownExecutionCount();

        if (IsJumpTable && !OptimizeJumpTables)
          continue;

        if (!IsJumpTable && (!HasIndirectCallProfile || !OptimizeCalls))
          continue;

        // Ignore direct calls.
        if (BC.MIB->isCall(Inst) && BC.MIB->getTargetSymbol(Inst, 0))
          continue;

        assert((BC.MIB->isCall(Inst) || BC.MIB->isIndirectBranch(Inst)) &&
               "expected a call or an indirect jump instruction");

        if (IsJumpTable)
          ++TotalJumpTableCallsites;
        else
          ++TotalIndirectCallsites;

        std::vector<Callsite> Targets = getCallTargets(*BB, Inst);

        // Compute the total number of calls from this particular callsite.
        uint64_t NumCalls = 0;
        for (const Callsite &BInfo : Targets)
          NumCalls += BInfo.Branches;
        if (!IsJumpTable)
          FuncTotalIndirectCalls += NumCalls;
        else
          FuncTotalIndirectJmps += NumCalls;

        // If the FLAGS register is alive after this jmp site, do not try
        // promoting because we will clobber FLAGS.
        if (IsJumpTable) {
          ErrorOr<const BitVector &> State =
              Info.getLivenessAnalysis().getStateBefore(Inst);
          if (!State || (State && (*State)[BC.MIB->getFlagsReg()])) {
            if (opts::Verbosity >= 1)
              outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
                     << InstIdx << " in " << BB->getName()
                     << ", calls = " << NumCalls
                     << (State ? ", cannot clobber flags reg.\n"
                               : ", no liveness data available.\n");
            continue;
          }
        }

        // Should this callsite be optimized?  Return the number of targets
        // to use when promoting this call.  A value of zero means to skip
        // this callsite.
        size_t N = canPromoteCallsite(*BB, Inst, Targets, NumCalls);

        // If it is a jump table and it failed to meet our initial threshold,
        // proceed to findCallTargetSymbols -- it may reevaluate N if a
        // memory profile is present.
        if (!N && !IsJumpTable)
          continue;

        if (opts::Verbosity >= 1)
          printCallsiteInfo(*BB, Inst, Targets, N, NumCalls);

        // Find MCSymbols or absolute addresses for each call target.
        MCInst *TargetFetchInst = nullptr;
        const SymTargetsType SymTargets =
            findCallTargetSymbols(Targets, N, *BB, Inst, TargetFetchInst);

        // findCallTargetSymbols may have changed N if a memory profile is
        // available for jump tables.
        if (!N)
          continue;

        LLVM_DEBUG(printDecision(dbgs(), Targets, N));

        // If we can't resolve any of the target symbols, punt on this callsite.
        // TODO: can this ever happen?
        if (SymTargets.size() < N) {
          const size_t LastTarget = SymTargets.size();
          if (opts::Verbosity >= 1)
            outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
                   << InstIdx << " in " << BB->getName()
                   << ", calls = " << NumCalls
                   << ", ICP failed to find target symbol for "
                   << Targets[LastTarget].To.Sym->getName() << "\n";
          continue;
        }

        MethodInfoType MethodInfo;

        if (!IsJumpTable) {
          MethodInfo = maybeGetVtableSyms(*BB, Inst, SymTargets);
          TotalMethodLoadsEliminated += MethodInfo.first.empty() ? 0 : 1;
          LLVM_DEBUG(dbgs()
                     << "BOLT-INFO: ICP "
                     << (!MethodInfo.first.empty() ? "found" : "did not find")
                     << " vtables for all methods.\n");
        } else if (TargetFetchInst) {
          ++TotalIndexBasedJumps;
          MethodInfo.second.push_back(TargetFetchInst);
        }

        // Generate new promoted call code for this callsite.
        MCPlusBuilder::BlocksVectorTy ICPcode =
            (IsJumpTable && !opts::ICPJumpTablesByTarget)
                ? BC.MIB->jumpTablePromotion(Inst, SymTargets,
                                             MethodInfo.second, BC.Ctx.get())
                : BC.MIB->indirectCallPromotion(
                      Inst, SymTargets, MethodInfo.first, MethodInfo.second,
                      opts::ICPOldCodeSequence, BC.Ctx.get());

        if (ICPcode.empty()) {
          if (opts::Verbosity >= 1)
            outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
                   << InstIdx << " in " << BB->getName()
                   << ", calls = " << NumCalls
                   << ", unable to generate promoted call code.\n";
          continue;
        }

        LLVM_DEBUG({
          uint64_t Offset = Targets[0].From.Addr;
          dbgs() << "BOLT-INFO: ICP indirect call code:\n";
          for (const auto &entry : ICPcode) {
            const MCSymbol *const &Sym = entry.first;
            const InstructionListType &Insts = entry.second;
            if (Sym)
              dbgs() << Sym->getName() << ":\n";
            Offset = BC.printInstructions(dbgs(), Insts.begin(), Insts.end(),
                                          Offset);
          }
          dbgs() << "---------------------------------------------------\n";
        });

        // Rewrite the CFG with the newly generated ICP code.
        std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs =
            rewriteCall(*BB, Inst, std::move(ICPcode), MethodInfo.second);

        // Fix the CFG after inserting the new basic blocks.
        BinaryBasicBlock *MergeBlock =
            fixCFG(*BB, IsTailCall, IsJumpTable, std::move(NewBBs), Targets);

        // Since the tail of the original block was split off and it may contain
        // additional indirect calls, we must add the merge block to the set of
        // blocks to process.
        if (MergeBlock)
          BBs.push_back(MergeBlock);

        if (opts::Verbosity >= 1)
          outs() << "BOLT-INFO: ICP succeeded in " << Function << " @ "
                 << InstIdx << " in " << BB->getName()
                 << " -> calls = " << NumCalls << "\n";

        if (IsJumpTable)
          ++TotalOptimizedJumpTableCallsites;
        else
          ++TotalOptimizedIndirectCallsites;

        Modified.insert(&Function);
      }
    }
    TotalIndirectCalls += FuncTotalIndirectCalls;
    TotalIndirectJmps += FuncTotalIndirectJmps;
  }

  outs() << "BOLT-INFO: ICP total indirect callsites with profile = "
         << TotalIndirectCallsites << "\n"
         << "BOLT-INFO: ICP total jump table callsites = "
         << TotalJumpTableCallsites << "\n"
         << "BOLT-INFO: ICP total number of calls = " << TotalCalls << "\n"
         << "BOLT-INFO: ICP percentage of calls that are indirect = "
         << format("%.1f", (100.0 * TotalIndirectCalls) / TotalCalls) << "%\n"
         << "BOLT-INFO: ICP percentage of indirect calls that can be "
            "optimized = "
         << format("%.1f", (100.0 * TotalNumFrequentCalls) /
                               std::max<size_t>(TotalIndirectCalls, 1))
         << "%\n"
         << "BOLT-INFO: ICP percentage of indirect callsites that are "
            "optimized = "
         << format("%.1f", (100.0 * TotalOptimizedIndirectCallsites) /
                               std::max<uint64_t>(TotalIndirectCallsites, 1))
         << "%\n"
         << "BOLT-INFO: ICP number of method load elimination candidates = "
         << TotalMethodLoadEliminationCandidates << "\n"
         << "BOLT-INFO: ICP percentage of method calls candidates that have "
            "loads eliminated = "
         << format("%.1f", (100.0 * TotalMethodLoadsEliminated) /
                               std::max<uint64_t>(
                                   TotalMethodLoadEliminationCandidates, 1))
         << "%\n"
         << "BOLT-INFO: ICP percentage of indirect branches that are "
            "optimized = "
         << format("%.1f", (100.0 * TotalNumFrequentJmps) /
                               std::max<uint64_t>(TotalIndirectJmps, 1))
         << "%\n"
         << "BOLT-INFO: ICP percentage of jump table callsites that are "
         << "optimized = "
         << format("%.1f", (100.0 * TotalOptimizedJumpTableCallsites) /
                               std::max<uint64_t>(TotalJumpTableCallsites, 1))
         << "%\n"
         << "BOLT-INFO: ICP number of jump table callsites that can use hot "
         << "indices = " << TotalIndexBasedCandidates << "\n"
         << "BOLT-INFO: ICP percentage of jump table callsites that use hot "
            "indices = "
         << format("%.1f", (100.0 * TotalIndexBasedJumps) /
                               std::max<uint64_t>(TotalIndexBasedCandidates, 1))
         << "%\n";

#ifndef NDEBUG
  verifyProfile(BFs);
#endif
  return Error::success();
}

} // namespace bolt
} // namespace llvm