xref: /llvm-project/bolt/lib/Passes/IndirectCallPromotion.cpp (revision 3023b15fb1ec00dbe6a1cb630236125f500978ef)
1 //===- bolt/Passes/IndirectCallPromotion.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the IndirectCallPromotion class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "bolt/Passes/IndirectCallPromotion.h"
14 #include "bolt/Core/BinaryFunctionCallGraph.h"
15 #include "bolt/Passes/DataflowInfoManager.h"
16 #include "bolt/Passes/Inliner.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/Support/CommandLine.h"
19 #include <iterator>
20 
21 #define DEBUG_TYPE "ICP"
22 #define DEBUG_VERBOSE(Level, X)                                                \
23   if (opts::Verbosity >= (Level)) {                                            \
24     X;                                                                         \
25   }
26 
27 using namespace llvm;
28 using namespace bolt;
29 
// Command-line options that configure the indirect call promotion (ICP) pass.
// Options declared `extern` are defined in other translation units; the
// `static` ones are registered here.
namespace opts {

extern cl::OptionCategory BoltOptCategory;

extern cl::opt<IndirectCallPromotionType> ICP;
// Verbosity gates the DEBUG_VERBOSE macro and the BOLT-INFO diagnostics.
extern cl::opt<unsigned> Verbosity;
extern cl::opt<unsigned> ExecutionCountThreshold;

// Jump-table thresholds. When selecting hot jump-table indices from the memory
// profile, selection stops at the first index whose access count falls below
// this percentage of the remaining (not yet selected) accesses...
static cl::opt<unsigned> ICPJTRemainingPercentThreshold(
    "icp-jt-remaining-percent-threshold",
    cl::desc("The percentage threshold against remaining unpromoted indirect "
             "call count for the promotion for jump tables"),
    cl::init(30), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));

// ...or below this percentage of the total accesses.
static cl::opt<unsigned> ICPJTTotalPercentThreshold(
    "icp-jt-total-percent-threshold",
    cl::desc(
        "The percentage threshold against total count for the promotion for "
        "jump tables"),
    cl::init(5), cl::Hidden, cl::cat(BoltOptCategory));

// Equivalent thresholds for indirect calls (as opposed to jump tables).
static cl::opt<unsigned> ICPCallsRemainingPercentThreshold(
    "icp-calls-remaining-percent-threshold",
    cl::desc("The percentage threshold against remaining unpromoted indirect "
             "call count for the promotion for calls"),
    cl::init(50), cl::Hidden, cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPCallsTotalPercentThreshold(
    "icp-calls-total-percent-threshold",
    cl::desc(
        "The percentage threshold against total count for the promotion for "
        "calls"),
    cl::init(30), cl::Hidden, cl::cat(BoltOptCategory));

// Misprediction-based heuristics.
static cl::opt<unsigned> ICPMispredictThreshold(
    "indirect-call-promotion-mispredict-threshold",
    cl::desc("misprediction threshold for skipping ICP on an "
             "indirect call"),
    cl::init(0), cl::cat(BoltOptCategory));

static cl::alias ICPMispredictThresholdAlias(
    "icp-mp-threshold",
    cl::desc("alias for --indirect-call-promotion-mispredict-threshold"),
    cl::aliasopt(ICPMispredictThreshold));

static cl::opt<bool> ICPUseMispredicts(
    "indirect-call-promotion-use-mispredicts",
    cl::desc("use misprediction frequency for determining whether or not ICP "
             "should be applied at a callsite.  The "
             "-indirect-call-promotion-mispredict-threshold value will be used "
             "by this heuristic"),
    cl::cat(BoltOptCategory));

static cl::alias ICPUseMispredictsAlias(
    "icp-use-mp",
    cl::desc("alias for --indirect-call-promotion-use-mispredicts"),
    cl::aliasopt(ICPUseMispredicts));

// Limits on the number of promoted targets. For jump tables, a non-zero
// ICPJumpTablesTopN takes precedence over ICPTopN.
static cl::opt<unsigned>
    ICPTopN("indirect-call-promotion-topn",
            cl::desc("limit number of targets to consider when doing indirect "
                     "call promotion. 0 = no limit"),
            cl::init(3), cl::cat(BoltOptCategory));

static cl::alias
    ICPTopNAlias("icp-topn",
                 cl::desc("alias for --indirect-call-promotion-topn"),
                 cl::aliasopt(ICPTopN));

static cl::opt<unsigned> ICPCallsTopN(
    "indirect-call-promotion-calls-topn",
    cl::desc("limit number of targets to consider when doing indirect "
             "call promotion on calls. 0 = no limit"),
    cl::init(0), cl::cat(BoltOptCategory));

static cl::alias ICPCallsTopNAlias(
    "icp-calls-topn",
    cl::desc("alias for --indirect-call-promotion-calls-topn"),
    cl::aliasopt(ICPCallsTopN));

static cl::opt<unsigned> ICPJumpTablesTopN(
    "indirect-call-promotion-jump-tables-topn",
    cl::desc("limit number of targets to consider when doing indirect "
             "call promotion on jump tables. 0 = no limit"),
    cl::init(0), cl::cat(BoltOptCategory));

static cl::alias ICPJumpTablesTopNAlias(
    "icp-jt-topn",
    cl::desc("alias for --indirect-call-promotion-jump-tables-topn"),
    cl::aliasopt(ICPJumpTablesTopN));

// When disabled, the memory-profile based analyses (hot jump-table indices and
// vtable lookup) return early without producing any result.
static cl::opt<bool> EliminateLoads(
    "icp-eliminate-loads",
    cl::desc("enable load elimination using memory profiling data when "
             "performing ICP"),
    cl::init(true), cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPTopCallsites(
    "icp-top-callsites",
    cl::desc("optimize hottest calls until at least this percentage of all "
             "indirect calls frequency is covered. 0 = all callsites"),
    cl::init(99), cl::Hidden, cl::cat(BoltOptCategory));

static cl::list<std::string>
    ICPFuncsList("icp-funcs", cl::CommaSeparated,
                 cl::desc("list of functions to enable ICP for"),
                 cl::value_desc("func1,func2,func3,..."), cl::Hidden,
                 cl::cat(BoltOptCategory));

static cl::opt<bool>
    ICPOldCodeSequence("icp-old-code-sequence",
                       cl::desc("use old code sequence for promoted calls"),
                       cl::Hidden, cl::cat(BoltOptCategory));

// Unless this is set, PIC jump tables yield no promotion targets.
static cl::opt<bool> ICPJumpTablesByTarget(
    "icp-jump-tables-targets",
    cl::desc(
        "for jump tables, optimize indirect jmp targets instead of indices"),
    cl::Hidden, cl::cat(BoltOptCategory));

static cl::alias
    ICPJumpTablesByTargetAlias("icp-jt-targets",
                               cl::desc("alias for --icp-jump-tables-targets"),
                               cl::aliasopt(ICPJumpTablesByTarget));

static cl::opt<bool> ICPPeelForInline(
    "icp-inline", cl::desc("only promote call targets eligible for inlining"),
    cl::Hidden, cl::cat(BoltOptCategory));

} // namespace opts
160 
161 #ifndef NDEBUG
162 static bool verifyProfile(std::map<uint64_t, BinaryFunction> &BFs) {
163   bool IsValid = true;
164   for (auto &BFI : BFs) {
165     BinaryFunction &BF = BFI.second;
166     if (!BF.isSimple())
167       continue;
168     for (const BinaryBasicBlock &BB : BF) {
169       auto BI = BB.branch_info_begin();
170       for (BinaryBasicBlock *SuccBB : BB.successors()) {
171         if (BI->Count != BinaryBasicBlock::COUNT_NO_PROFILE && BI->Count > 0) {
172           if (BB.getKnownExecutionCount() == 0 ||
173               SuccBB->getKnownExecutionCount() == 0) {
174             BF.getBinaryContext().errs()
175                 << "BOLT-WARNING: profile verification failed after ICP for "
176                    "function "
177                 << BF << '\n';
178             IsValid = false;
179           }
180         }
181         ++BI;
182       }
183     }
184   }
185   return IsValid;
186 }
187 #endif
188 
189 namespace llvm {
190 namespace bolt {
191 
192 IndirectCallPromotion::Callsite::Callsite(BinaryFunction &BF,
193                                           const IndirectCallProfile &ICP)
194     : From(BF.getSymbol()), To(ICP.Offset), Mispreds(ICP.Mispreds),
195       Branches(ICP.Count) {
196   if (ICP.Symbol) {
197     To.Sym = ICP.Symbol;
198     To.Addr = 0;
199   }
200 }
201 
202 void IndirectCallPromotion::printDecision(
203     llvm::raw_ostream &OS,
204     std::vector<IndirectCallPromotion::Callsite> &Targets, unsigned N) const {
205   uint64_t TotalCount = 0;
206   uint64_t TotalMispreds = 0;
207   for (const Callsite &S : Targets) {
208     TotalCount += S.Branches;
209     TotalMispreds += S.Mispreds;
210   }
211   if (!TotalCount)
212     TotalCount = 1;
213   if (!TotalMispreds)
214     TotalMispreds = 1;
215 
216   OS << "BOLT-INFO: ICP decision for call site with " << Targets.size()
217      << " targets, Count = " << TotalCount << ", Mispreds = " << TotalMispreds
218      << "\n";
219 
220   size_t I = 0;
221   for (const Callsite &S : Targets) {
222     OS << "Count = " << S.Branches << ", "
223        << format("%.1f", (100.0 * S.Branches) / TotalCount) << ", "
224        << "Mispreds = " << S.Mispreds << ", "
225        << format("%.1f", (100.0 * S.Mispreds) / TotalMispreds);
226     if (I < N)
227       OS << " * to be optimized *";
228     if (!S.JTIndices.empty()) {
229       OS << " Indices:";
230       for (const uint64_t Idx : S.JTIndices)
231         OS << " " << Idx;
232     }
233     OS << "\n";
234     I += S.JTIndices.empty() ? 1 : S.JTIndices.size();
235   }
236 }
237 
238 // Get list of targets for a given call sorted by most frequently
239 // called first.
240 std::vector<IndirectCallPromotion::Callsite>
241 IndirectCallPromotion::getCallTargets(BinaryBasicBlock &BB,
242                                       const MCInst &Inst) const {
243   BinaryFunction &BF = *BB.getFunction();
244   const BinaryContext &BC = BF.getBinaryContext();
245   std::vector<Callsite> Targets;
246 
247   if (const JumpTable *JT = BF.getJumpTable(Inst)) {
248     // Don't support PIC jump tables for now
249     if (!opts::ICPJumpTablesByTarget && JT->Type == JumpTable::JTT_PIC)
250       return Targets;
251     const Location From(BF.getSymbol());
252     const std::pair<size_t, size_t> Range =
253         JT->getEntriesForAddress(BC.MIB->getJumpTable(Inst));
254     assert(JT->Counts.empty() || JT->Counts.size() >= Range.second);
255     JumpTable::JumpInfo DefaultJI;
256     const JumpTable::JumpInfo *JI =
257         JT->Counts.empty() ? &DefaultJI : &JT->Counts[Range.first];
258     const size_t JIAdj = JT->Counts.empty() ? 0 : 1;
259     assert(JT->Type == JumpTable::JTT_PIC ||
260            JT->EntrySize == BC.AsmInfo->getCodePointerSize());
261     for (size_t I = Range.first; I < Range.second; ++I, JI += JIAdj) {
262       MCSymbol *Entry = JT->Entries[I];
263       const BinaryBasicBlock *ToBB = BF.getBasicBlockForLabel(Entry);
264       assert(ToBB || Entry == BF.getFunctionEndLabel() ||
265              Entry == BF.getFunctionEndLabel(FragmentNum::cold()));
266       if (Entry == BF.getFunctionEndLabel() ||
267           Entry == BF.getFunctionEndLabel(FragmentNum::cold()))
268         continue;
269       const Location To(Entry);
270       const BinaryBasicBlock::BinaryBranchInfo &BI = BB.getBranchInfo(*ToBB);
271       Targets.emplace_back(From, To, BI.MispredictedCount, BI.Count,
272                            I - Range.first);
273     }
274 
275     // Sort by symbol then addr.
276     llvm::sort(Targets, [](const Callsite &A, const Callsite &B) {
277       if (A.To.Sym && B.To.Sym)
278         return A.To.Sym < B.To.Sym;
279       else if (A.To.Sym && !B.To.Sym)
280         return true;
281       else if (!A.To.Sym && B.To.Sym)
282         return false;
283       else
284         return A.To.Addr < B.To.Addr;
285     });
286 
287     // Targets may contain multiple entries to the same target, but using
288     // different indices. Their profile will report the same number of branches
289     // for different indices if the target is the same. That's because we don't
290     // profile the index value, but only the target via LBR.
291     auto First = Targets.begin();
292     auto Last = Targets.end();
293     auto Result = First;
294     while (++First != Last) {
295       Callsite &A = *Result;
296       const Callsite &B = *First;
297       if (A.To.Sym && B.To.Sym && A.To.Sym == B.To.Sym)
298         A.JTIndices.insert(A.JTIndices.end(), B.JTIndices.begin(),
299                            B.JTIndices.end());
300       else
301         *(++Result) = *First;
302     }
303     ++Result;
304 
305     LLVM_DEBUG(if (Targets.end() - Result > 0) {
306       dbgs() << "BOLT-INFO: ICP: " << (Targets.end() - Result)
307              << " duplicate targets removed\n";
308     });
309 
310     Targets.erase(Result, Targets.end());
311   } else {
312     // Don't try to optimize PC relative indirect calls.
313     if (Inst.getOperand(0).isReg() &&
314         Inst.getOperand(0).getReg() == BC.MRI->getProgramCounter())
315       return Targets;
316 
317     const auto ICSP = BC.MIB->tryGetAnnotationAs<IndirectCallSiteProfile>(
318         Inst, "CallProfile");
319     if (ICSP) {
320       for (const IndirectCallProfile &CSP : ICSP.get()) {
321         Callsite Site(BF, CSP);
322         if (Site.isValid())
323           Targets.emplace_back(std::move(Site));
324       }
325     }
326   }
327 
328   // Sort by target count, number of indices in case of jump table, and
329   // mispredicts. We prioritize targets with high count, small number of indices
330   // and high mispredicts. Break ties by selecting targets with lower addresses.
331   llvm::stable_sort(Targets, [](const Callsite &A, const Callsite &B) {
332     if (A.Branches != B.Branches)
333       return A.Branches > B.Branches;
334     if (A.JTIndices.size() != B.JTIndices.size())
335       return A.JTIndices.size() < B.JTIndices.size();
336     if (A.Mispreds != B.Mispreds)
337       return A.Mispreds > B.Mispreds;
338     return A.To.Addr < B.To.Addr;
339   });
340 
341   // Remove non-symbol targets
342   llvm::erase_if(Targets, [](const Callsite &CS) { return !CS.To.Sym; });
343 
344   LLVM_DEBUG(if (BF.getJumpTable(Inst)) {
345     uint64_t TotalCount = 0;
346     uint64_t TotalMispreds = 0;
347     for (const Callsite &S : Targets) {
348       TotalCount += S.Branches;
349       TotalMispreds += S.Mispreds;
350     }
351     if (!TotalCount)
352       TotalCount = 1;
353     if (!TotalMispreds)
354       TotalMispreds = 1;
355 
356     dbgs() << "BOLT-INFO: ICP: jump table size = " << Targets.size()
357            << ", Count = " << TotalCount << ", Mispreds = " << TotalMispreds
358            << "\n";
359 
360     size_t I = 0;
361     for (const Callsite &S : Targets) {
362       dbgs() << "Count[" << I << "] = " << S.Branches << ", "
363              << format("%.1f", (100.0 * S.Branches) / TotalCount) << ", "
364              << "Mispreds[" << I << "] = " << S.Mispreds << ", "
365              << format("%.1f", (100.0 * S.Mispreds) / TotalMispreds) << "\n";
366       ++I;
367     }
368   });
369 
370   return Targets;
371 }
372 
373 IndirectCallPromotion::JumpTableInfoType
374 IndirectCallPromotion::maybeGetHotJumpTableTargets(BinaryBasicBlock &BB,
375                                                    MCInst &CallInst,
376                                                    MCInst *&TargetFetchInst,
377                                                    const JumpTable *JT) const {
378   assert(JT && "Can't get jump table addrs for non-jump tables.");
379 
380   BinaryFunction &Function = *BB.getFunction();
381   BinaryContext &BC = Function.getBinaryContext();
382 
383   if (!Function.hasMemoryProfile() || !opts::EliminateLoads)
384     return JumpTableInfoType();
385 
386   JumpTableInfoType HotTargets;
387   MCInst *MemLocInstr;
388   MCInst *PCRelBaseOut;
389   MCInst *FixedEntryLoadInstr;
390   unsigned BaseReg, IndexReg;
391   int64_t DispValue;
392   const MCExpr *DispExpr;
393   MutableArrayRef<MCInst> Insts(&BB.front(), &CallInst);
394   const IndirectBranchType Type = BC.MIB->analyzeIndirectBranch(
395       CallInst, Insts.begin(), Insts.end(), BC.AsmInfo->getCodePointerSize(),
396       MemLocInstr, BaseReg, IndexReg, DispValue, DispExpr, PCRelBaseOut,
397       FixedEntryLoadInstr);
398 
399   assert(MemLocInstr && "There should always be a load for jump tables");
400   if (!MemLocInstr)
401     return JumpTableInfoType();
402 
403   LLVM_DEBUG({
404     dbgs() << "BOLT-INFO: ICP attempting to find memory profiling data for "
405            << "jump table in " << Function << " at @ "
406            << (&CallInst - &BB.front()) << "\n"
407            << "BOLT-INFO: ICP target fetch instructions:\n";
408     BC.printInstruction(dbgs(), *MemLocInstr, 0, &Function);
409     if (MemLocInstr != &CallInst)
410       BC.printInstruction(dbgs(), CallInst, 0, &Function);
411   });
412 
413   DEBUG_VERBOSE(1, {
414     dbgs() << "Jmp info: Type = " << (unsigned)Type << ", "
415            << "BaseReg = " << BC.MRI->getName(BaseReg) << ", "
416            << "IndexReg = " << BC.MRI->getName(IndexReg) << ", "
417            << "DispValue = " << Twine::utohexstr(DispValue) << ", "
418            << "DispExpr = " << DispExpr << ", "
419            << "MemLocInstr = ";
420     BC.printInstruction(dbgs(), *MemLocInstr, 0, &Function);
421     dbgs() << "\n";
422   });
423 
424   ++TotalIndexBasedCandidates;
425 
426   auto ErrorOrMemAccessProfile =
427       BC.MIB->tryGetAnnotationAs<MemoryAccessProfile>(*MemLocInstr,
428                                                       "MemoryAccessProfile");
429   if (!ErrorOrMemAccessProfile) {
430     DEBUG_VERBOSE(1, dbgs()
431                          << "BOLT-INFO: ICP no memory profiling data found\n");
432     return JumpTableInfoType();
433   }
434   MemoryAccessProfile &MemAccessProfile = ErrorOrMemAccessProfile.get();
435 
436   uint64_t ArrayStart;
437   if (DispExpr) {
438     ErrorOr<uint64_t> DispValueOrError =
439         BC.getSymbolValue(*BC.MIB->getTargetSymbol(DispExpr));
440     assert(DispValueOrError && "global symbol needs a value");
441     ArrayStart = *DispValueOrError;
442   } else {
443     ArrayStart = static_cast<uint64_t>(DispValue);
444   }
445 
446   if (BaseReg == BC.MRI->getProgramCounter())
447     ArrayStart += Function.getAddress() + MemAccessProfile.NextInstrOffset;
448 
449   // This is a map of [symbol] -> [count, index] and is used to combine indices
450   // into the jump table since there may be multiple addresses that all have the
451   // same entry.
452   std::map<MCSymbol *, std::pair<uint64_t, uint64_t>> HotTargetMap;
453   const std::pair<size_t, size_t> Range = JT->getEntriesForAddress(ArrayStart);
454 
455   for (const AddressAccess &AccessInfo : MemAccessProfile.AddressAccessInfo) {
456     size_t Index;
457     // Mem data occasionally includes nullprs, ignore them.
458     if (!AccessInfo.MemoryObject && !AccessInfo.Offset)
459       continue;
460 
461     if (AccessInfo.Offset % JT->EntrySize != 0) // ignore bogus data
462       return JumpTableInfoType();
463 
464     if (AccessInfo.MemoryObject) {
465       // Deal with bad/stale data
466       if (!AccessInfo.MemoryObject->getName().starts_with(
467               "JUMP_TABLE/" + Function.getOneName().str()))
468         return JumpTableInfoType();
469       Index =
470           (AccessInfo.Offset - (ArrayStart - JT->getAddress())) / JT->EntrySize;
471     } else {
472       Index = (AccessInfo.Offset - ArrayStart) / JT->EntrySize;
473     }
474 
475     // If Index is out of range it probably means the memory profiling data is
476     // wrong for this instruction, bail out.
477     if (Index >= Range.second) {
478       LLVM_DEBUG(dbgs() << "BOLT-INFO: Index out of range of " << Range.first
479                         << ", " << Range.second << "\n");
480       return JumpTableInfoType();
481     }
482 
483     // Make sure the hot index points at a legal label corresponding to a BB,
484     // e.g. not the end of function (unreachable) label.
485     if (!Function.getBasicBlockForLabel(JT->Entries[Index + Range.first])) {
486       LLVM_DEBUG({
487         dbgs() << "BOLT-INFO: hot index " << Index << " pointing at bogus "
488                << "label " << JT->Entries[Index + Range.first]->getName()
489                << " in jump table:\n";
490         JT->print(dbgs());
491         dbgs() << "HotTargetMap:\n";
492         for (std::pair<MCSymbol *const, std::pair<uint64_t, uint64_t>> &HT :
493              HotTargetMap)
494           dbgs() << "BOLT-INFO: " << HT.first->getName()
495                  << " = (count=" << HT.second.first
496                  << ", index=" << HT.second.second << ")\n";
497       });
498       return JumpTableInfoType();
499     }
500 
501     std::pair<uint64_t, uint64_t> &HotTarget =
502         HotTargetMap[JT->Entries[Index + Range.first]];
503     HotTarget.first += AccessInfo.Count;
504     HotTarget.second = Index;
505   }
506 
507   llvm::copy(llvm::make_second_range(HotTargetMap),
508              std::back_inserter(HotTargets));
509 
510   // Sort with highest counts first.
511   llvm::sort(reverse(HotTargets));
512 
513   LLVM_DEBUG({
514     dbgs() << "BOLT-INFO: ICP jump table hot targets:\n";
515     for (const std::pair<uint64_t, uint64_t> &Target : HotTargets)
516       dbgs() << "BOLT-INFO:  Idx = " << Target.second << ", "
517              << "Count = " << Target.first << "\n";
518   });
519 
520   BC.MIB->getOrCreateAnnotationAs<uint16_t>(CallInst, "JTIndexReg") = IndexReg;
521 
522   TargetFetchInst = MemLocInstr;
523 
524   return HotTargets;
525 }
526 
527 IndirectCallPromotion::SymTargetsType
528 IndirectCallPromotion::findCallTargetSymbols(std::vector<Callsite> &Targets,
529                                              size_t &N, BinaryBasicBlock &BB,
530                                              MCInst &CallInst,
531                                              MCInst *&TargetFetchInst) const {
532   const BinaryContext &BC = BB.getFunction()->getBinaryContext();
533   const JumpTable *JT = BB.getFunction()->getJumpTable(CallInst);
534   SymTargetsType SymTargets;
535 
536   if (!JT) {
537     for (size_t I = 0; I < N; ++I) {
538       assert(Targets[I].To.Sym && "All ICP targets must be to known symbols");
539       assert(Targets[I].JTIndices.empty() &&
540              "Can't have jump table indices for non-jump tables");
541       SymTargets.emplace_back(Targets[I].To.Sym, 0);
542     }
543     return SymTargets;
544   }
545 
546   // Use memory profile to select hot targets.
547   JumpTableInfoType HotTargets =
548       maybeGetHotJumpTableTargets(BB, CallInst, TargetFetchInst, JT);
549 
550   auto findTargetsIndex = [&](uint64_t JTIndex) {
551     for (size_t I = 0; I < Targets.size(); ++I)
552       if (llvm::is_contained(Targets[I].JTIndices, JTIndex))
553         return I;
554     LLVM_DEBUG(dbgs() << "BOLT-ERROR: Unable to find target index for hot jump "
555                       << " table entry in " << *BB.getFunction() << "\n");
556     llvm_unreachable("Hot indices must be referred to by at least one "
557                      "callsite");
558   };
559 
560   if (!HotTargets.empty()) {
561     if (opts::Verbosity >= 1)
562       for (size_t I = 0; I < HotTargets.size(); ++I)
563         BC.outs() << "BOLT-INFO: HotTarget[" << I << "] = ("
564                   << HotTargets[I].first << ", " << HotTargets[I].second
565                   << ")\n";
566 
567     // Recompute hottest targets, now discriminating which index is hot
568     // NOTE: This is a tradeoff. On one hand, we get index information. On the
569     // other hand, info coming from the memory profile is much less accurate
570     // than LBRs. So we may actually end up working with more coarse
571     // profile granularity in exchange for information about indices.
572     std::vector<Callsite> NewTargets;
573     std::map<const MCSymbol *, uint32_t> IndicesPerTarget;
574     uint64_t TotalMemAccesses = 0;
575     for (size_t I = 0; I < HotTargets.size(); ++I) {
576       const uint64_t TargetIndex = findTargetsIndex(HotTargets[I].second);
577       ++IndicesPerTarget[Targets[TargetIndex].To.Sym];
578       TotalMemAccesses += HotTargets[I].first;
579     }
580     uint64_t RemainingMemAccesses = TotalMemAccesses;
581     const size_t TopN =
582         opts::ICPJumpTablesTopN ? opts::ICPJumpTablesTopN : opts::ICPTopN;
583     size_t I = 0;
584     for (; I < HotTargets.size(); ++I) {
585       const uint64_t MemAccesses = HotTargets[I].first;
586       if (100 * MemAccesses <
587           TotalMemAccesses * opts::ICPJTTotalPercentThreshold)
588         break;
589       if (100 * MemAccesses <
590           RemainingMemAccesses * opts::ICPJTRemainingPercentThreshold)
591         break;
592       if (TopN && I >= TopN)
593         break;
594       RemainingMemAccesses -= MemAccesses;
595 
596       const uint64_t JTIndex = HotTargets[I].second;
597       Callsite &Target = Targets[findTargetsIndex(JTIndex)];
598 
599       NewTargets.push_back(Target);
600       std::vector<uint64_t>({JTIndex}).swap(NewTargets.back().JTIndices);
601       llvm::erase(Target.JTIndices, JTIndex);
602 
603       // Keep fixCFG counts sane if more indices use this same target later
604       assert(IndicesPerTarget[Target.To.Sym] > 0 && "wrong map");
605       NewTargets.back().Branches =
606           Target.Branches / IndicesPerTarget[Target.To.Sym];
607       NewTargets.back().Mispreds =
608           Target.Mispreds / IndicesPerTarget[Target.To.Sym];
609       assert(Target.Branches >= NewTargets.back().Branches);
610       assert(Target.Mispreds >= NewTargets.back().Mispreds);
611       Target.Branches -= NewTargets.back().Branches;
612       Target.Mispreds -= NewTargets.back().Mispreds;
613     }
614     llvm::copy(Targets, std::back_inserter(NewTargets));
615     std::swap(NewTargets, Targets);
616     N = I;
617 
618     if (N == 0 && opts::Verbosity >= 1) {
619       BC.outs() << "BOLT-INFO: ICP failed in " << *BB.getFunction() << " in "
620                 << BB.getName() << ": failed to meet thresholds after memory "
621                 << "profile data was loaded.\n";
622       return SymTargets;
623     }
624   }
625 
626   for (size_t I = 0, TgtIdx = 0; I < N; ++TgtIdx) {
627     Callsite &Target = Targets[TgtIdx];
628     assert(Target.To.Sym && "All ICP targets must be to known symbols");
629     assert(!Target.JTIndices.empty() && "Jump tables must have indices");
630     for (uint64_t Idx : Target.JTIndices) {
631       SymTargets.emplace_back(Target.To.Sym, Idx);
632       ++I;
633     }
634   }
635 
636   return SymTargets;
637 }
638 
639 IndirectCallPromotion::MethodInfoType IndirectCallPromotion::maybeGetVtableSyms(
640     BinaryBasicBlock &BB, MCInst &Inst,
641     const SymTargetsType &SymTargets) const {
642   BinaryFunction &Function = *BB.getFunction();
643   BinaryContext &BC = Function.getBinaryContext();
644   std::vector<std::pair<MCSymbol *, uint64_t>> VtableSyms;
645   std::vector<MCInst *> MethodFetchInsns;
646   unsigned VtableReg, MethodReg;
647   uint64_t MethodOffset;
648 
649   assert(!Function.getJumpTable(Inst) &&
650          "Can't get vtable addrs for jump tables.");
651 
652   if (!Function.hasMemoryProfile() || !opts::EliminateLoads)
653     return MethodInfoType();
654 
655   MutableArrayRef<MCInst> Insts(&BB.front(), &Inst + 1);
656   if (!BC.MIB->analyzeVirtualMethodCall(Insts.begin(), Insts.end(),
657                                         MethodFetchInsns, VtableReg, MethodReg,
658                                         MethodOffset)) {
659     DEBUG_VERBOSE(
660         1, dbgs() << "BOLT-INFO: ICP unable to analyze method call in "
661                   << Function << " at @ " << (&Inst - &BB.front()) << "\n");
662     return MethodInfoType();
663   }
664 
665   ++TotalMethodLoadEliminationCandidates;
666 
667   DEBUG_VERBOSE(1, {
668     dbgs() << "BOLT-INFO: ICP found virtual method call in " << Function
669            << " at @ " << (&Inst - &BB.front()) << "\n";
670     dbgs() << "BOLT-INFO: ICP method fetch instructions:\n";
671     for (MCInst *Inst : MethodFetchInsns)
672       BC.printInstruction(dbgs(), *Inst, 0, &Function);
673 
674     if (MethodFetchInsns.back() != &Inst)
675       BC.printInstruction(dbgs(), Inst, 0, &Function);
676   });
677 
678   // Try to get value profiling data for the method load instruction.
679   auto ErrorOrMemAccessProfile =
680       BC.MIB->tryGetAnnotationAs<MemoryAccessProfile>(*MethodFetchInsns.back(),
681                                                       "MemoryAccessProfile");
682   if (!ErrorOrMemAccessProfile) {
683     DEBUG_VERBOSE(1, dbgs()
684                          << "BOLT-INFO: ICP no memory profiling data found\n");
685     return MethodInfoType();
686   }
687   MemoryAccessProfile &MemAccessProfile = ErrorOrMemAccessProfile.get();
688 
689   // Find the vtable that each method belongs to.
690   std::map<const MCSymbol *, uint64_t> MethodToVtable;
691 
692   for (const AddressAccess &AccessInfo : MemAccessProfile.AddressAccessInfo) {
693     uint64_t Address = AccessInfo.Offset;
694     if (AccessInfo.MemoryObject)
695       Address += AccessInfo.MemoryObject->getAddress();
696 
697     // Ignore bogus data.
698     if (!Address)
699       continue;
700 
701     const uint64_t VtableBase = Address - MethodOffset;
702 
703     DEBUG_VERBOSE(1, dbgs() << "BOLT-INFO: ICP vtable = "
704                             << Twine::utohexstr(VtableBase) << "+"
705                             << MethodOffset << "/" << AccessInfo.Count << "\n");
706 
707     if (ErrorOr<uint64_t> MethodAddr = BC.getPointerAtAddress(Address)) {
708       BinaryData *MethodBD = BC.getBinaryDataAtAddress(MethodAddr.get());
709       if (!MethodBD) // skip unknown methods
710         continue;
711       MCSymbol *MethodSym = MethodBD->getSymbol();
712       MethodToVtable[MethodSym] = VtableBase;
713       DEBUG_VERBOSE(1, {
714         const BinaryFunction *Method = BC.getFunctionForSymbol(MethodSym);
715         dbgs() << "BOLT-INFO: ICP found method = "
716                << Twine::utohexstr(MethodAddr.get()) << "/"
717                << (Method ? Method->getPrintName() : "") << "\n";
718       });
719     }
720   }
721 
722   // Find the vtable for each target symbol.
723   for (size_t I = 0; I < SymTargets.size(); ++I) {
724     auto Itr = MethodToVtable.find(SymTargets[I].first);
725     if (Itr != MethodToVtable.end()) {
726       if (BinaryData *BD = BC.getBinaryDataContainingAddress(Itr->second)) {
727         const uint64_t Addend = Itr->second - BD->getAddress();
728         VtableSyms.emplace_back(BD->getSymbol(), Addend);
729         continue;
730       }
731     }
732     // Give up if we can't find the vtable for a method.
733     DEBUG_VERBOSE(1, dbgs() << "BOLT-INFO: ICP can't find vtable for "
734                             << SymTargets[I].first->getName() << "\n");
735     return MethodInfoType();
736   }
737 
738   // Make sure the vtable reg is not clobbered by the argument passing code
739   if (VtableReg != MethodReg) {
740     for (MCInst *CurInst = MethodFetchInsns.front(); CurInst < &Inst;
741          ++CurInst) {
742       const MCInstrDesc &InstrInfo = BC.MII->get(CurInst->getOpcode());
743       if (InstrInfo.hasDefOfPhysReg(*CurInst, VtableReg, *BC.MRI))
744         return MethodInfoType();
745     }
746   }
747 
748   return MethodInfoType(VtableSyms, MethodFetchInsns);
749 }
750 
751 std::vector<std::unique_ptr<BinaryBasicBlock>>
752 IndirectCallPromotion::rewriteCall(
753     BinaryBasicBlock &IndCallBlock, const MCInst &CallInst,
754     MCPlusBuilder::BlocksVectorTy &&ICPcode,
755     const std::vector<MCInst *> &MethodFetchInsns) const {
756   BinaryFunction &Function = *IndCallBlock.getFunction();
757   MCPlusBuilder *MIB = Function.getBinaryContext().MIB.get();
758 
759   // Create new basic blocks with correct code in each one first.
760   std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs;
761   const bool IsTailCallOrJT =
762       (MIB->isTailCall(CallInst) || Function.getJumpTable(CallInst));
763 
764   // If we are tracking the indirect call/jump address, propagate the address to
765   // the ICP code.
766   const std::optional<uint32_t> IndirectInstrOffset = MIB->getOffset(CallInst);
767   if (IndirectInstrOffset) {
768     for (auto &[Symbol, Instructions] : ICPcode)
769       for (MCInst &Inst : Instructions)
770         MIB->setOffset(Inst, *IndirectInstrOffset);
771   }
772 
773   // Move instructions from the tail of the original call block
774   // to the merge block.
775 
776   // Remember any pseudo instructions following a tail call.  These
777   // must be preserved and moved to the original block.
778   InstructionListType TailInsts;
779   const MCInst *TailInst = &CallInst;
780   if (IsTailCallOrJT)
781     while (TailInst + 1 < &(*IndCallBlock.end()) &&
782            MIB->isPseudo(*(TailInst + 1)))
783       TailInsts.push_back(*++TailInst);
784 
785   InstructionListType MovedInst = IndCallBlock.splitInstructions(&CallInst);
786   // Link new BBs to the original input offset of the indirect call site or its
787   // containing BB, so we can map samples recorded in new BBs back to the
788   // original BB seen in the input binary (if using BAT).
789   const uint32_t OrigOffset = IndirectInstrOffset
790                                   ? *IndirectInstrOffset
791                                   : IndCallBlock.getInputOffset();
792 
793   IndCallBlock.eraseInstructions(MethodFetchInsns.begin(),
794                                  MethodFetchInsns.end());
795   if (IndCallBlock.empty() ||
796       (!MethodFetchInsns.empty() && MethodFetchInsns.back() == &CallInst))
797     IndCallBlock.addInstructions(ICPcode.front().second.begin(),
798                                  ICPcode.front().second.end());
799   else
800     IndCallBlock.replaceInstruction(std::prev(IndCallBlock.end()),
801                                     ICPcode.front().second);
802   IndCallBlock.addInstructions(TailInsts.begin(), TailInsts.end());
803 
804   for (auto Itr = ICPcode.begin() + 1; Itr != ICPcode.end(); ++Itr) {
805     MCSymbol *&Sym = Itr->first;
806     InstructionListType &Insts = Itr->second;
807     assert(Sym);
808     std::unique_ptr<BinaryBasicBlock> TBB = Function.createBasicBlock(Sym);
809     TBB->setOffset(OrigOffset);
810     for (MCInst &Inst : Insts) // sanitize new instructions.
811       if (MIB->isCall(Inst))
812         MIB->removeAnnotation(Inst, "CallProfile");
813     TBB->addInstructions(Insts.begin(), Insts.end());
814     NewBBs.emplace_back(std::move(TBB));
815   }
816 
817   // Move tail of instructions from after the original call to
818   // the merge block.
819   if (!IsTailCallOrJT)
820     NewBBs.back()->addInstructions(MovedInst.begin(), MovedInst.end());
821 
822   return NewBBs;
823 }
824 
825 BinaryBasicBlock *
826 IndirectCallPromotion::fixCFG(BinaryBasicBlock &IndCallBlock,
827                               const bool IsTailCall, const bool IsJumpTable,
828                               IndirectCallPromotion::BasicBlocksVector &&NewBBs,
829                               const std::vector<Callsite> &Targets) const {
830   BinaryFunction &Function = *IndCallBlock.getFunction();
831   using BinaryBranchInfo = BinaryBasicBlock::BinaryBranchInfo;
832   BinaryBasicBlock *MergeBlock = nullptr;
833 
834   // Scale indirect call counts to the execution count of the original
835   // basic block containing the indirect call.
836   uint64_t TotalCount = IndCallBlock.getKnownExecutionCount();
837   uint64_t TotalIndirectBranches = 0;
838   for (const Callsite &Target : Targets)
839     TotalIndirectBranches += Target.Branches;
840   if (TotalIndirectBranches == 0)
841     TotalIndirectBranches = 1;
842   BinaryBasicBlock::BranchInfoType BBI;
843   BinaryBasicBlock::BranchInfoType ScaledBBI;
844   for (const Callsite &Target : Targets) {
845     const size_t NumEntries =
846         std::max(static_cast<std::size_t>(1UL), Target.JTIndices.size());
847     for (size_t I = 0; I < NumEntries; ++I) {
848       BBI.push_back(
849           BinaryBranchInfo{(Target.Branches + NumEntries - 1) / NumEntries,
850                            (Target.Mispreds + NumEntries - 1) / NumEntries});
851       ScaledBBI.push_back(
852           BinaryBranchInfo{uint64_t(TotalCount * Target.Branches /
853                                     (NumEntries * TotalIndirectBranches)),
854                            uint64_t(TotalCount * Target.Mispreds /
855                                     (NumEntries * TotalIndirectBranches))});
856     }
857   }
858 
859   if (IsJumpTable) {
860     BinaryBasicBlock *NewIndCallBlock = NewBBs.back().get();
861     IndCallBlock.moveAllSuccessorsTo(NewIndCallBlock);
862 
863     std::vector<MCSymbol *> SymTargets;
864     for (const Callsite &Target : Targets) {
865       const size_t NumEntries =
866           std::max(static_cast<std::size_t>(1UL), Target.JTIndices.size());
867       for (size_t I = 0; I < NumEntries; ++I)
868         SymTargets.push_back(Target.To.Sym);
869     }
870     assert(SymTargets.size() > NewBBs.size() - 1 &&
871            "There must be a target symbol associated with each new BB.");
872 
873     for (uint64_t I = 0; I < NewBBs.size(); ++I) {
874       BinaryBasicBlock *SourceBB = I ? NewBBs[I - 1].get() : &IndCallBlock;
875       SourceBB->setExecutionCount(TotalCount);
876 
877       BinaryBasicBlock *TargetBB =
878           Function.getBasicBlockForLabel(SymTargets[I]);
879       SourceBB->addSuccessor(TargetBB, ScaledBBI[I]); // taken
880 
881       TotalCount -= ScaledBBI[I].Count;
882       SourceBB->addSuccessor(NewBBs[I].get(), TotalCount); // fall-through
883 
884       // Update branch info for the indirect jump.
885       BinaryBasicBlock::BinaryBranchInfo &BranchInfo =
886           NewIndCallBlock->getBranchInfo(*TargetBB);
887       if (BranchInfo.Count > BBI[I].Count)
888         BranchInfo.Count -= BBI[I].Count;
889       else
890         BranchInfo.Count = 0;
891 
892       if (BranchInfo.MispredictedCount > BBI[I].MispredictedCount)
893         BranchInfo.MispredictedCount -= BBI[I].MispredictedCount;
894       else
895         BranchInfo.MispredictedCount = 0;
896     }
897   } else {
898     assert(NewBBs.size() >= 2);
899     assert(NewBBs.size() % 2 == 1 || IndCallBlock.succ_empty());
900     assert(NewBBs.size() % 2 == 1 || IsTailCall);
901 
902     auto ScaledBI = ScaledBBI.begin();
903     auto updateCurrentBranchInfo = [&] {
904       assert(ScaledBI != ScaledBBI.end());
905       TotalCount -= ScaledBI->Count;
906       ++ScaledBI;
907     };
908 
909     if (!IsTailCall) {
910       MergeBlock = NewBBs.back().get();
911       IndCallBlock.moveAllSuccessorsTo(MergeBlock);
912     }
913 
914     // Fix up successors and execution counts.
915     updateCurrentBranchInfo();
916     IndCallBlock.addSuccessor(NewBBs[1].get(), TotalCount);
917     IndCallBlock.addSuccessor(NewBBs[0].get(), ScaledBBI[0]);
918 
919     const size_t Adj = IsTailCall ? 1 : 2;
920     for (size_t I = 0; I < NewBBs.size() - Adj; ++I) {
921       assert(TotalCount <= IndCallBlock.getExecutionCount() ||
922              TotalCount <= uint64_t(TotalIndirectBranches));
923       uint64_t ExecCount = ScaledBBI[(I + 1) / 2].Count;
924       if (I % 2 == 0) {
925         if (MergeBlock)
926           NewBBs[I]->addSuccessor(MergeBlock, ScaledBBI[(I + 1) / 2].Count);
927       } else {
928         assert(I + 2 < NewBBs.size());
929         updateCurrentBranchInfo();
930         NewBBs[I]->addSuccessor(NewBBs[I + 2].get(), TotalCount);
931         NewBBs[I]->addSuccessor(NewBBs[I + 1].get(), ScaledBBI[(I + 1) / 2]);
932         ExecCount += TotalCount;
933       }
934       NewBBs[I]->setExecutionCount(ExecCount);
935     }
936 
937     if (MergeBlock) {
938       // Arrange for the MergeBlock to be the fallthrough for the first
939       // promoted call block.
940       std::unique_ptr<BinaryBasicBlock> MBPtr;
941       std::swap(MBPtr, NewBBs.back());
942       NewBBs.pop_back();
943       NewBBs.emplace(NewBBs.begin() + 1, std::move(MBPtr));
944       // TODO: is COUNT_FALLTHROUGH_EDGE the right thing here?
945       NewBBs.back()->addSuccessor(MergeBlock, TotalCount); // uncond branch
946     }
947   }
948 
949   // Update the execution count.
950   NewBBs.back()->setExecutionCount(TotalCount);
951 
952   // Update BB and BB layout.
953   Function.insertBasicBlocks(&IndCallBlock, std::move(NewBBs));
954   assert(Function.validateCFG());
955 
956   return MergeBlock;
957 }
958 
959 size_t IndirectCallPromotion::canPromoteCallsite(
960     const BinaryBasicBlock &BB, const MCInst &Inst,
961     const std::vector<Callsite> &Targets, uint64_t NumCalls) {
962   BinaryFunction *BF = BB.getFunction();
963   const BinaryContext &BC = BF->getBinaryContext();
964 
965   if (BB.getKnownExecutionCount() < opts::ExecutionCountThreshold)
966     return 0;
967 
968   const bool IsJumpTable = BF->getJumpTable(Inst);
969 
970   auto computeStats = [&](size_t N) {
971     for (size_t I = 0; I < N; ++I)
972       if (IsJumpTable)
973         TotalNumFrequentJmps += Targets[I].Branches;
974       else
975         TotalNumFrequentCalls += Targets[I].Branches;
976   };
977 
978   // If we have no targets (or no calls), skip this callsite.
979   if (Targets.empty() || !NumCalls) {
980     if (opts::Verbosity >= 1) {
981       const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
982       BC.outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
983                 << " in " << BB.getName() << ", calls = " << NumCalls
984                 << ", targets empty or NumCalls == 0.\n";
985     }
986     return 0;
987   }
988 
989   size_t TopN = opts::ICPTopN;
990   if (IsJumpTable)
991     TopN = opts::ICPJumpTablesTopN ? opts::ICPJumpTablesTopN : TopN;
992   else
993     TopN = opts::ICPCallsTopN ? opts::ICPCallsTopN : TopN;
994 
995   const size_t TrialN = TopN ? std::min(TopN, Targets.size()) : Targets.size();
996 
997   if (opts::ICPTopCallsites && !BC.MIB->hasAnnotation(Inst, "DoICP"))
998     return 0;
999 
1000   // Pick the top N targets.
1001   uint64_t TotalMispredictsTopN = 0;
1002   size_t N = 0;
1003 
1004   if (opts::ICPUseMispredicts &&
1005       (!IsJumpTable || opts::ICPJumpTablesByTarget)) {
1006     // Count total number of mispredictions for (at most) the top N targets.
1007     // We may choose a smaller N (TrialN vs. N) if the frequency threshold
1008     // is exceeded by fewer targets.
1009     double Threshold = double(opts::ICPMispredictThreshold);
1010     for (size_t I = 0; I < TrialN && Threshold > 0; ++I, ++N) {
1011       Threshold -= (100.0 * Targets[I].Mispreds) / NumCalls;
1012       TotalMispredictsTopN += Targets[I].Mispreds;
1013     }
1014     computeStats(N);
1015 
1016     // Compute the misprediction frequency of the top N call targets.  If this
1017     // frequency is greater than the threshold, we should try ICP on this
1018     // callsite.
1019     const double TopNFrequency = (100.0 * TotalMispredictsTopN) / NumCalls;
1020     if (TopNFrequency == 0 || TopNFrequency < opts::ICPMispredictThreshold) {
1021       if (opts::Verbosity >= 1) {
1022         const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
1023         BC.outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
1024                   << " in " << BB.getName() << ", calls = " << NumCalls
1025                   << ", top N mis. frequency " << format("%.1f", TopNFrequency)
1026                   << "% < " << opts::ICPMispredictThreshold << "%\n";
1027       }
1028       return 0;
1029     }
1030   } else {
1031     size_t MaxTargets = 0;
1032 
1033     // Count total number of calls for (at most) the top N targets.
1034     // We may choose a smaller N (TrialN vs. N) if the frequency threshold
1035     // is exceeded by fewer targets.
1036     const unsigned TotalThreshold = IsJumpTable
1037                                         ? opts::ICPJTTotalPercentThreshold
1038                                         : opts::ICPCallsTotalPercentThreshold;
1039     const unsigned RemainingThreshold =
1040         IsJumpTable ? opts::ICPJTRemainingPercentThreshold
1041                     : opts::ICPCallsRemainingPercentThreshold;
1042     uint64_t NumRemainingCalls = NumCalls;
1043     for (size_t I = 0; I < TrialN; ++I, ++MaxTargets) {
1044       if (100 * Targets[I].Branches < NumCalls * TotalThreshold)
1045         break;
1046       if (100 * Targets[I].Branches < NumRemainingCalls * RemainingThreshold)
1047         break;
1048       if (N + (Targets[I].JTIndices.empty() ? 1 : Targets[I].JTIndices.size()) >
1049           TrialN)
1050         break;
1051       TotalMispredictsTopN += Targets[I].Mispreds;
1052       NumRemainingCalls -= Targets[I].Branches;
1053       N += Targets[I].JTIndices.empty() ? 1 : Targets[I].JTIndices.size();
1054     }
1055     computeStats(MaxTargets);
1056 
1057     // Don't check misprediction frequency for jump tables -- we don't really
1058     // care as long as we are saving loads from the jump table.
1059     if (!IsJumpTable || opts::ICPJumpTablesByTarget) {
1060       // Compute the misprediction frequency of the top N call targets.  If
1061       // this frequency is less than the threshold, we should skip ICP at
1062       // this callsite.
1063       const double TopNMispredictFrequency =
1064           (100.0 * TotalMispredictsTopN) / NumCalls;
1065 
1066       if (TopNMispredictFrequency < opts::ICPMispredictThreshold) {
1067         if (opts::Verbosity >= 1) {
1068           const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
1069           BC.outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
1070                     << " in " << BB.getName() << ", calls = " << NumCalls
1071                     << ", top N mispredict frequency "
1072                     << format("%.1f", TopNMispredictFrequency) << "% < "
1073                     << opts::ICPMispredictThreshold << "%\n";
1074         }
1075         return 0;
1076       }
1077     }
1078   }
1079 
1080   // Filter by inline-ability of target functions, stop at first target that
1081   // can't be inlined.
1082   if (!IsJumpTable && opts::ICPPeelForInline) {
1083     for (size_t I = 0; I < N; ++I) {
1084       const MCSymbol *TargetSym = Targets[I].To.Sym;
1085       const BinaryFunction *TargetBF = BC.getFunctionForSymbol(TargetSym);
1086       if (!TargetBF || !BinaryFunctionPass::shouldOptimize(*TargetBF) ||
1087           getInliningInfo(*TargetBF).Type == InliningType::INL_NONE) {
1088         N = I;
1089         break;
1090       }
1091     }
1092   }
1093 
1094   // Filter functions that can have ICP applied (for debugging)
1095   if (!opts::ICPFuncsList.empty()) {
1096     for (std::string &Name : opts::ICPFuncsList)
1097       if (BF->hasName(Name))
1098         return N;
1099     return 0;
1100   }
1101 
1102   return N;
1103 }
1104 
1105 void IndirectCallPromotion::printCallsiteInfo(
1106     const BinaryBasicBlock &BB, const MCInst &Inst,
1107     const std::vector<Callsite> &Targets, const size_t N,
1108     uint64_t NumCalls) const {
1109   BinaryContext &BC = BB.getFunction()->getBinaryContext();
1110   const bool IsTailCall = BC.MIB->isTailCall(Inst);
1111   const bool IsJumpTable = BB.getFunction()->getJumpTable(Inst);
1112   const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
1113 
1114   BC.outs() << "BOLT-INFO: ICP candidate branch info: " << *BB.getFunction()
1115             << " @ " << InstIdx << " in " << BB.getName()
1116             << " -> calls = " << NumCalls
1117             << (IsTailCall ? " (tail)" : (IsJumpTable ? " (jump table)" : ""))
1118             << "\n";
1119   for (size_t I = 0; I < N; I++) {
1120     const double Frequency = 100.0 * Targets[I].Branches / NumCalls;
1121     const double MisFrequency = 100.0 * Targets[I].Mispreds / NumCalls;
1122     BC.outs() << "BOLT-INFO:   ";
1123     if (Targets[I].To.Sym)
1124       BC.outs() << Targets[I].To.Sym->getName();
1125     else
1126       BC.outs() << Targets[I].To.Addr;
1127     BC.outs() << ", calls = " << Targets[I].Branches
1128               << ", mispreds = " << Targets[I].Mispreds
1129               << ", taken freq = " << format("%.1f", Frequency) << "%"
1130               << ", mis. freq = " << format("%.1f", MisFrequency) << "%";
1131     bool First = true;
1132     for (uint64_t JTIndex : Targets[I].JTIndices) {
1133       BC.outs() << (First ? ", indices = " : ", ") << JTIndex;
1134       First = false;
1135     }
1136     BC.outs() << "\n";
1137   }
1138 
1139   LLVM_DEBUG({
1140     dbgs() << "BOLT-INFO: ICP original call instruction:";
1141     BC.printInstruction(dbgs(), Inst, Targets[0].From.Addr, nullptr, true);
1142   });
1143 }
1144 
1145 Error IndirectCallPromotion::runOnFunctions(BinaryContext &BC) {
1146   if (opts::ICP == ICP_NONE)
1147     return Error::success();
1148 
1149   auto &BFs = BC.getBinaryFunctions();
1150 
1151   const bool OptimizeCalls = (opts::ICP == ICP_CALLS || opts::ICP == ICP_ALL);
1152   const bool OptimizeJumpTables =
1153       (opts::ICP == ICP_JUMP_TABLES || opts::ICP == ICP_ALL);
1154 
1155   std::unique_ptr<RegAnalysis> RA;
1156   std::unique_ptr<BinaryFunctionCallGraph> CG;
1157   if (OptimizeJumpTables) {
1158     CG.reset(new BinaryFunctionCallGraph(buildCallGraph(BC)));
1159     RA.reset(new RegAnalysis(BC, &BFs, &*CG));
1160   }
1161 
1162   // If icp-top-callsites is enabled, compute the total number of indirect
1163   // calls and then optimize the hottest callsites that contribute to that
1164   // total.
1165   SetVector<BinaryFunction *> Functions;
1166   if (opts::ICPTopCallsites == 0) {
1167     for (auto &KV : BFs)
1168       Functions.insert(&KV.second);
1169   } else {
1170     using IndirectCallsite = std::tuple<uint64_t, MCInst *, BinaryFunction *>;
1171     std::vector<IndirectCallsite> IndirectCalls;
1172     size_t TotalIndirectCalls = 0;
1173 
1174     // Find all the indirect callsites.
1175     for (auto &BFIt : BFs) {
1176       BinaryFunction &Function = BFIt.second;
1177 
1178       if (!shouldOptimize(Function))
1179         continue;
1180 
1181       const bool HasLayout = !Function.getLayout().block_empty();
1182 
1183       for (BinaryBasicBlock &BB : Function) {
1184         if (HasLayout && Function.isSplit() && BB.isCold())
1185           continue;
1186 
1187         for (MCInst &Inst : BB) {
1188           const bool IsJumpTable = Function.getJumpTable(Inst);
1189           const bool HasIndirectCallProfile =
1190               BC.MIB->hasAnnotation(Inst, "CallProfile");
1191           const bool IsDirectCall =
1192               (BC.MIB->isCall(Inst) && BC.MIB->getTargetSymbol(Inst, 0));
1193 
1194           if (!IsDirectCall &&
1195               ((HasIndirectCallProfile && !IsJumpTable && OptimizeCalls) ||
1196                (IsJumpTable && OptimizeJumpTables))) {
1197             uint64_t NumCalls = 0;
1198             for (const Callsite &BInfo : getCallTargets(BB, Inst))
1199               NumCalls += BInfo.Branches;
1200             IndirectCalls.push_back(
1201                 std::make_tuple(NumCalls, &Inst, &Function));
1202             TotalIndirectCalls += NumCalls;
1203           }
1204         }
1205       }
1206     }
1207 
1208     // Sort callsites by execution count.
1209     llvm::sort(reverse(IndirectCalls));
1210 
1211     // Find callsites that contribute to the top "opts::ICPTopCallsites"%
1212     // number of calls.
1213     const float TopPerc = opts::ICPTopCallsites / 100.0f;
1214     int64_t MaxCalls = TotalIndirectCalls * TopPerc;
1215     uint64_t LastFreq = std::numeric_limits<uint64_t>::max();
1216     size_t Num = 0;
1217     for (const IndirectCallsite &IC : IndirectCalls) {
1218       const uint64_t CurFreq = std::get<0>(IC);
1219       // Once we decide to stop, include at least all branches that share the
1220       // same frequency of the last one to avoid non-deterministic behavior
1221       // (e.g. turning on/off ICP depending on the order of functions)
1222       if (MaxCalls <= 0 && CurFreq != LastFreq)
1223         break;
1224       MaxCalls -= CurFreq;
1225       LastFreq = CurFreq;
1226       BC.MIB->addAnnotation(*std::get<1>(IC), "DoICP", true);
1227       Functions.insert(std::get<2>(IC));
1228       ++Num;
1229     }
1230     BC.outs() << "BOLT-INFO: ICP Total indirect calls = " << TotalIndirectCalls
1231               << ", " << Num << " callsites cover " << opts::ICPTopCallsites
1232               << "% of all indirect calls\n";
1233   }
1234 
1235   for (BinaryFunction *FuncPtr : Functions) {
1236     BinaryFunction &Function = *FuncPtr;
1237 
1238     if (!shouldOptimize(Function))
1239       continue;
1240 
1241     const bool HasLayout = !Function.getLayout().block_empty();
1242 
1243     // Total number of indirect calls issued from the current Function.
1244     // (a fraction of TotalIndirectCalls)
1245     uint64_t FuncTotalIndirectCalls = 0;
1246     uint64_t FuncTotalIndirectJmps = 0;
1247 
1248     std::vector<BinaryBasicBlock *> BBs;
1249     for (BinaryBasicBlock &BB : Function) {
1250       // Skip indirect calls in cold blocks.
1251       if (!HasLayout || !Function.isSplit() || !BB.isCold())
1252         BBs.push_back(&BB);
1253     }
1254     if (BBs.empty())
1255       continue;
1256 
1257     DataflowInfoManager Info(Function, RA.get(), nullptr);
1258     while (!BBs.empty()) {
1259       BinaryBasicBlock *BB = BBs.back();
1260       BBs.pop_back();
1261 
1262       for (unsigned Idx = 0; Idx < BB->size(); ++Idx) {
1263         MCInst &Inst = BB->getInstructionAtIndex(Idx);
1264         const ptrdiff_t InstIdx = &Inst - &(*BB->begin());
1265         const bool IsTailCall = BC.MIB->isTailCall(Inst);
1266         const bool HasIndirectCallProfile =
1267             BC.MIB->hasAnnotation(Inst, "CallProfile");
1268         const bool IsJumpTable = Function.getJumpTable(Inst);
1269 
1270         if (BC.MIB->isCall(Inst))
1271           TotalCalls += BB->getKnownExecutionCount();
1272 
1273         if (IsJumpTable && !OptimizeJumpTables)
1274           continue;
1275 
1276         if (!IsJumpTable && (!HasIndirectCallProfile || !OptimizeCalls))
1277           continue;
1278 
1279         // Ignore direct calls.
1280         if (BC.MIB->isCall(Inst) && BC.MIB->getTargetSymbol(Inst, 0))
1281           continue;
1282 
1283         assert((BC.MIB->isCall(Inst) || BC.MIB->isIndirectBranch(Inst)) &&
1284                "expected a call or an indirect jump instruction");
1285 
1286         if (IsJumpTable)
1287           ++TotalJumpTableCallsites;
1288         else
1289           ++TotalIndirectCallsites;
1290 
1291         std::vector<Callsite> Targets = getCallTargets(*BB, Inst);
1292 
1293         // Compute the total number of calls from this particular callsite.
1294         uint64_t NumCalls = 0;
1295         for (const Callsite &BInfo : Targets)
1296           NumCalls += BInfo.Branches;
1297         if (!IsJumpTable)
1298           FuncTotalIndirectCalls += NumCalls;
1299         else
1300           FuncTotalIndirectJmps += NumCalls;
1301 
1302         // If FLAGS regs is alive after this jmp site, do not try
1303         // promoting because we will clobber FLAGS.
1304         if (IsJumpTable) {
1305           ErrorOr<const BitVector &> State =
1306               Info.getLivenessAnalysis().getStateBefore(Inst);
1307           if (!State || (State && (*State)[BC.MIB->getFlagsReg()])) {
1308             if (opts::Verbosity >= 1)
1309               BC.outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
1310                         << InstIdx << " in " << BB->getName()
1311                         << ", calls = " << NumCalls
1312                         << (State ? ", cannot clobber flags reg.\n"
1313                                   : ", no liveness data available.\n");
1314             continue;
1315           }
1316         }
1317 
1318         // Should this callsite be optimized?  Return the number of targets
1319         // to use when promoting this call.  A value of zero means to skip
1320         // this callsite.
1321         size_t N = canPromoteCallsite(*BB, Inst, Targets, NumCalls);
1322 
1323         // If it is a jump table and it failed to meet our initial threshold,
1324         // proceed to findCallTargetSymbols -- it may reevaluate N if
1325         // memory profile is present
1326         if (!N && !IsJumpTable)
1327           continue;
1328 
1329         if (opts::Verbosity >= 1)
1330           printCallsiteInfo(*BB, Inst, Targets, N, NumCalls);
1331 
1332         // Find MCSymbols or absolute addresses for each call target.
1333         MCInst *TargetFetchInst = nullptr;
1334         const SymTargetsType SymTargets =
1335             findCallTargetSymbols(Targets, N, *BB, Inst, TargetFetchInst);
1336 
1337         // findCallTargetSymbols may have changed N if mem profile is available
1338         // for jump tables
1339         if (!N)
1340           continue;
1341 
1342         LLVM_DEBUG(printDecision(dbgs(), Targets, N));
1343 
1344         // If we can't resolve any of the target symbols, punt on this callsite.
1345         // TODO: can this ever happen?
1346         if (SymTargets.size() < N) {
1347           const size_t LastTarget = SymTargets.size();
1348           if (opts::Verbosity >= 1)
1349             BC.outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
1350                       << InstIdx << " in " << BB->getName()
1351                       << ", calls = " << NumCalls
1352                       << ", ICP failed to find target symbol for "
1353                       << Targets[LastTarget].To.Sym->getName() << "\n";
1354           continue;
1355         }
1356 
1357         MethodInfoType MethodInfo;
1358 
1359         if (!IsJumpTable) {
1360           MethodInfo = maybeGetVtableSyms(*BB, Inst, SymTargets);
1361           TotalMethodLoadsEliminated += MethodInfo.first.empty() ? 0 : 1;
1362           LLVM_DEBUG(dbgs()
1363                      << "BOLT-INFO: ICP "
1364                      << (!MethodInfo.first.empty() ? "found" : "did not find")
1365                      << " vtables for all methods.\n");
1366         } else if (TargetFetchInst) {
1367           ++TotalIndexBasedJumps;
1368           MethodInfo.second.push_back(TargetFetchInst);
1369         }
1370 
1371         // Generate new promoted call code for this callsite.
1372         MCPlusBuilder::BlocksVectorTy ICPcode =
1373             (IsJumpTable && !opts::ICPJumpTablesByTarget)
1374                 ? BC.MIB->jumpTablePromotion(Inst, SymTargets,
1375                                              MethodInfo.second, BC.Ctx.get())
1376                 : BC.MIB->indirectCallPromotion(
1377                       Inst, SymTargets, MethodInfo.first, MethodInfo.second,
1378                       opts::ICPOldCodeSequence, BC.Ctx.get());
1379 
1380         if (ICPcode.empty()) {
1381           if (opts::Verbosity >= 1)
1382             BC.outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
1383                       << InstIdx << " in " << BB->getName()
1384                       << ", calls = " << NumCalls
1385                       << ", unable to generate promoted call code.\n";
1386           continue;
1387         }
1388 
1389         LLVM_DEBUG({
1390           uint64_t Offset = Targets[0].From.Addr;
1391           dbgs() << "BOLT-INFO: ICP indirect call code:\n";
1392           for (const auto &entry : ICPcode) {
1393             const MCSymbol *const &Sym = entry.first;
1394             const InstructionListType &Insts = entry.second;
1395             if (Sym)
1396               dbgs() << Sym->getName() << ":\n";
1397             Offset = BC.printInstructions(dbgs(), Insts.begin(), Insts.end(),
1398                                           Offset);
1399           }
1400           dbgs() << "---------------------------------------------------\n";
1401         });
1402 
1403         // Rewrite the CFG with the newly generated ICP code.
1404         std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs =
1405             rewriteCall(*BB, Inst, std::move(ICPcode), MethodInfo.second);
1406 
1407         // Fix the CFG after inserting the new basic blocks.
1408         BinaryBasicBlock *MergeBlock =
1409             fixCFG(*BB, IsTailCall, IsJumpTable, std::move(NewBBs), Targets);
1410 
1411         // Since the tail of the original block was split off and it may contain
1412         // additional indirect calls, we must add the merge block to the set of
1413         // blocks to process.
1414         if (MergeBlock)
1415           BBs.push_back(MergeBlock);
1416 
1417         if (opts::Verbosity >= 1)
1418           BC.outs() << "BOLT-INFO: ICP succeeded in " << Function << " @ "
1419                     << InstIdx << " in " << BB->getName()
1420                     << " -> calls = " << NumCalls << "\n";
1421 
1422         if (IsJumpTable)
1423           ++TotalOptimizedJumpTableCallsites;
1424         else
1425           ++TotalOptimizedIndirectCallsites;
1426 
1427         Modified.insert(&Function);
1428       }
1429     }
1430     TotalIndirectCalls += FuncTotalIndirectCalls;
1431     TotalIndirectJmps += FuncTotalIndirectJmps;
1432   }
1433 
1434   BC.outs()
1435       << "BOLT-INFO: ICP total indirect callsites with profile = "
1436       << TotalIndirectCallsites << "\n"
1437       << "BOLT-INFO: ICP total jump table callsites = "
1438       << TotalJumpTableCallsites << "\n"
1439       << "BOLT-INFO: ICP total number of calls = " << TotalCalls << "\n"
1440       << "BOLT-INFO: ICP percentage of calls that are indirect = "
1441       << format("%.1f", (100.0 * TotalIndirectCalls) / TotalCalls) << "%\n"
1442       << "BOLT-INFO: ICP percentage of indirect calls that can be "
1443          "optimized = "
1444       << format("%.1f", (100.0 * TotalNumFrequentCalls) /
1445                             std::max<size_t>(TotalIndirectCalls, 1))
1446       << "%\n"
1447       << "BOLT-INFO: ICP percentage of indirect callsites that are "
1448          "optimized = "
1449       << format("%.1f", (100.0 * TotalOptimizedIndirectCallsites) /
1450                             std::max<uint64_t>(TotalIndirectCallsites, 1))
1451       << "%\n"
1452       << "BOLT-INFO: ICP number of method load elimination candidates = "
1453       << TotalMethodLoadEliminationCandidates << "\n"
1454       << "BOLT-INFO: ICP percentage of method calls candidates that have "
1455          "loads eliminated = "
1456       << format("%.1f",
1457                 (100.0 * TotalMethodLoadsEliminated) /
1458                     std::max<uint64_t>(TotalMethodLoadEliminationCandidates, 1))
1459       << "%\n"
1460       << "BOLT-INFO: ICP percentage of indirect branches that are "
1461          "optimized = "
1462       << format("%.1f", (100.0 * TotalNumFrequentJmps) /
1463                             std::max<uint64_t>(TotalIndirectJmps, 1))
1464       << "%\n"
1465       << "BOLT-INFO: ICP percentage of jump table callsites that are "
1466       << "optimized = "
1467       << format("%.1f", (100.0 * TotalOptimizedJumpTableCallsites) /
1468                             std::max<uint64_t>(TotalJumpTableCallsites, 1))
1469       << "%\n"
1470       << "BOLT-INFO: ICP number of jump table callsites that can use hot "
1471       << "indices = " << TotalIndexBasedCandidates << "\n"
1472       << "BOLT-INFO: ICP percentage of jump table callsites that use hot "
1473          "indices = "
1474       << format("%.1f", (100.0 * TotalIndexBasedJumps) /
1475                             std::max<uint64_t>(TotalIndexBasedCandidates, 1))
1476       << "%\n";
1477 
1478 #ifndef NDEBUG
1479   verifyProfile(BFs);
1480 #endif
1481   return Error::success();
1482 }
1483 
1484 } // namespace bolt
1485 } // namespace llvm
1486