//===- bolt/Passes/IndirectCallPromotion.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the IndirectCallPromotion class.
//
//===----------------------------------------------------------------------===//

#include "bolt/Passes/IndirectCallPromotion.h"
#include "bolt/Passes/BinaryFunctionCallGraph.h"
#include "bolt/Passes/DataflowInfoManager.h"
#include "bolt/Passes/Inliner.h"
#include "llvm/Support/CommandLine.h"

#define DEBUG_TYPE "ICP"
#define DEBUG_VERBOSE(Level, X)                                                \
  if (opts::Verbosity >= (Level)) {                                            \
    X;                                                                         \
  }

using namespace llvm;
using namespace bolt;

namespace opts {

extern cl::OptionCategory BoltOptCategory;

extern cl::opt<IndirectCallPromotionType> ICP;
extern cl::opt<unsigned> Verbosity;
extern cl::opt<unsigned> ExecutionCountThreshold;

static cl::opt<unsigned> ICPJTRemainingPercentThreshold(
    "icp-jt-remaining-percent-threshold",
    cl::desc("The percentage threshold against remaining unpromoted indirect "
             "call count for the promotion for jump tables"),
    cl::init(30), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPJTTotalPercentThreshold(
    "icp-jt-total-percent-threshold",
    cl::desc(
        "The percentage threshold against total count for the promotion for "
        "jump tables"),
    cl::init(5), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPCallsRemainingPercentThreshold(
    "icp-calls-remaining-percent-threshold",
    cl::desc("The percentage threshold against remaining unpromoted indirect "
             "call count for the promotion for calls"),
    cl::init(50), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPCallsTotalPercentThreshold(
    "icp-calls-total-percent-threshold",
    cl::desc(
        "The percentage threshold against total count for the promotion for "
        "calls"),
    cl::init(30), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPMispredictThreshold(
    "indirect-call-promotion-mispredict-threshold",
    cl::desc("misprediction threshold for skipping ICP on an "
             "indirect call"),
    cl::init(0), cl::ZeroOrMore, cl::cat(BoltOptCategory));

static cl::opt<bool> ICPUseMispredicts(
    "indirect-call-promotion-use-mispredicts",
    cl::desc("use misprediction frequency for determining whether or not ICP "
             "should be applied at a callsite.  The "
             "-indirect-call-promotion-mispredict-threshold value will be used "
             "by this heuristic"),
    cl::ZeroOrMore, cl::cat(BoltOptCategory));

static cl::opt<unsigned>
    ICPTopN("indirect-call-promotion-topn",
            cl::desc("limit number of targets to consider when doing indirect "
                     "call promotion. 0 = no limit"),
            cl::init(3), cl::ZeroOrMore, cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPCallsTopN(
    "indirect-call-promotion-calls-topn",
    cl::desc("limit number of targets to consider when doing indirect "
             "call promotion on calls. 0 = no limit"),
    cl::init(0), cl::ZeroOrMore, cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPJumpTablesTopN(
    "indirect-call-promotion-jump-tables-topn",
    cl::desc("limit number of targets to consider when doing indirect "
             "call promotion on jump tables. 0 = no limit"),
    cl::init(0), cl::ZeroOrMore, cl::cat(BoltOptCategory));

static cl::opt<bool> EliminateLoads(
    "icp-eliminate-loads",
    cl::desc("enable load elimination using memory profiling data when "
             "performing ICP"),
    cl::init(true), cl::ZeroOrMore, cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPTopCallsites(
    "icp-top-callsites",
    cl::desc("optimize hottest calls until at least this percentage of all "
             "indirect calls frequency is covered. 0 = all callsites"),
    cl::init(99), cl::Hidden, cl::ZeroOrMore, cl::cat(BoltOptCategory));

static cl::list<std::string>
    ICPFuncsList("icp-funcs", cl::CommaSeparated,
                 cl::desc("list of functions to enable ICP for"),
                 cl::value_desc("func1,func2,func3,..."), cl::Hidden,
                 cl::cat(BoltOptCategory));

static cl::opt<bool>
    ICPOldCodeSequence("icp-old-code-sequence",
                       cl::desc("use old code sequence for promoted calls"),
                       cl::init(false), cl::ZeroOrMore, cl::Hidden,
                       cl::cat(BoltOptCategory));

static cl::opt<bool> ICPJumpTablesByTarget(
    "icp-jump-tables-targets",
    cl::desc(
        "for jump tables, optimize indirect jmp targets instead of indices"),
    cl::init(false), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));

static cl::opt<bool> ICPPeelForInline(
    "icp-inline", cl::desc("only promote call targets eligible for inlining"),
    cl::init(false), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));

} // namespace opts

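// Sanity-check profile data after ICP: any CFG edge with a positive count must
// connect blocks that both have a non-zero execution count. Emits a warning
// for each function that fails the check and returns false if any does.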
static bool verifyProfile(std::map<uint64_t, BinaryFunction> &BFs) {
  bool IsValid = true;
  for (auto &BFI : BFs) {
    BinaryFunction &BF = BFI.second;
    if (!BF.isSimple())
      continue;
    for (BinaryBasicBlock *BB : BF.layout()) {
      auto BI = BB->branch_info_begin();
      for (BinaryBasicBlock *SuccBB : BB->successors()) {
        if (BI->Count != BinaryBasicBlock::COUNT_NO_PROFILE && BI->Count > 0) {
          if (BB->getKnownExecutionCount() == 0 ||
              SuccBB->getKnownExecutionCount() == 0) {
            errs() << "BOLT-WARNING: profile verification failed after ICP for "
                      "function "
                   << BF << '\n';
            IsValid = false;
          }
        }
        ++BI;
      }
    }
  }
  return IsValid;
}

namespace llvm {
namespace bolt {

IndirectCallPromotion::Callsite::Callsite(BinaryFunction &BF,
                                          const IndirectCallProfile &ICP)
    : From(BF.getSymbol()), To(ICP.Offset), Mispreds(ICP.Mispreds),
      Branches(ICP.Count) {
  if (ICP.Symbol) {
    To.Sym = ICP.Symbol;
    To.Addr = 0;
  }
}

void IndirectCallPromotion::printDecision(
    llvm::raw_ostream &OS,
    std::vector<IndirectCallPromotion::Callsite> &Targets, unsigned N) const {
  uint64_t TotalCount = 0;
  uint64_t TotalMispreds = 0;
  for (const Callsite &S : Targets) {
    TotalCount += S.Branches;
    TotalMispreds += S.Mispreds;
  }
  if (!TotalCount)
    TotalCount = 1;
  if (!TotalMispreds)
    TotalMispreds = 1;

  OS << "BOLT-INFO: ICP decision for call site with " << Targets.size()
     << " targets, Count = " << TotalCount << ", Mispreds = " << TotalMispreds
     << "\n";

  size_t I = 0;
  for (const Callsite &S : Targets) {
    OS << "Count = " << S.Branches << ", "
       << format("%.1f", (100.0 * S.Branches) / TotalCount) << ", "
       << "Mispreds = " << S.Mispreds << ", "
       << format("%.1f", (100.0 * S.Mispreds) / TotalMispreds);
    if (I < N)
      OS << " * to be optimized *";
    if (!S.JTIndices.empty()) {
      OS << " Indices:";
      for (const uint64_t Idx : S.JTIndices)
        OS << " " << Idx;
    }
    OS << "\n";
    I += S.JTIndices.empty() ? 1 : S.JTIndices.size();
  }
}

// Get list of targets for a given call sorted by most frequently
// called first.
std::vector<IndirectCallPromotion::Callsite>
IndirectCallPromotion::getCallTargets(BinaryBasicBlock &BB,
                                      const MCInst &Inst) const {
  BinaryFunction &BF = *BB.getFunction();
  const BinaryContext &BC = BF.getBinaryContext();
  std::vector<Callsite> Targets;

  if (const JumpTable *JT = BF.getJumpTable(Inst)) {
    // Don't support PIC jump tables for now
    if (!opts::ICPJumpTablesByTarget && JT->Type == JumpTable::JTT_PIC)
      return Targets;
    const Location From(BF.getSymbol());
    const std::pair<size_t, size_t> Range =
        JT->getEntriesForAddress(BC.MIB->getJumpTable(Inst));
    assert(JT->Counts.empty() || JT->Counts.size() >= Range.second);
    JumpTable::JumpInfo DefaultJI;
    const JumpTable::JumpInfo *JI =
        JT->Counts.empty() ? &DefaultJI : &JT->Counts[Range.first];
    const size_t JIAdj = JT->Counts.empty() ? 0 : 1;
    assert(JT->Type == JumpTable::JTT_PIC ||
           JT->EntrySize == BC.AsmInfo->getCodePointerSize());
    for (size_t I = Range.first; I < Range.second; ++I, JI += JIAdj) {
      MCSymbol *Entry = JT->Entries[I];
      assert(BF.getBasicBlockForLabel(Entry) ||
             Entry == BF.getFunctionEndLabel() ||
             Entry == BF.getFunctionColdEndLabel());
      if (Entry == BF.getFunctionEndLabel() ||
          Entry == BF.getFunctionColdEndLabel())
        continue;
      const Location To(Entry);
      const BinaryBasicBlock::BinaryBranchInfo &BI = BB.getBranchInfo(Entry);
      Targets.emplace_back(From, To, BI.MispredictedCount, BI.Count,
                           I - Range.first);
    }

    // Sort by symbol then addr.
    std::sort(Targets.begin(), Targets.end(),
              [](const Callsite &A, const Callsite &B) {
                if (A.To.Sym && B.To.Sym)
                  return A.To.Sym < B.To.Sym;
                else if (A.To.Sym && !B.To.Sym)
                  return true;
                else if (!A.To.Sym && B.To.Sym)
                  return false;
                else
                  return A.To.Addr < B.To.Addr;
              });

    // Targets may contain multiple entries to the same target, but using
    // different indices. Their profile will report the same number of branches
    // for different indices if the target is the same. That's because we don't
    // profile the index value, but only the target via LBR.
    auto First = Targets.begin();
    auto Last = Targets.end();
    auto Result = First;
    while (++First != Last) {
      Callsite &A = *Result;
      const Callsite &B = *First;
      if (A.To.Sym && B.To.Sym && A.To.Sym == B.To.Sym)
        A.JTIndices.insert(A.JTIndices.end(), B.JTIndices.begin(),
                           B.JTIndices.end());
      else
        *(++Result) = *First;
    }
    ++Result;

    LLVM_DEBUG(if (Targets.end() - Result > 0) {
      dbgs() << "BOLT-INFO: ICP: " << (Targets.end() - Result)
             << " duplicate targets removed\n";
    });

    Targets.erase(Result, Targets.end());
  } else {
    // Don't try to optimize PC relative indirect calls.
    if (Inst.getOperand(0).isReg() &&
        Inst.getOperand(0).getReg() == BC.MRI->getProgramCounter())
      return Targets;

    const auto ICSP = BC.MIB->tryGetAnnotationAs<IndirectCallSiteProfile>(
        Inst, "CallProfile");
    if (ICSP) {
      for (const IndirectCallProfile &CSP : ICSP.get()) {
        Callsite Site(BF, CSP);
        if (Site.isValid())
          Targets.emplace_back(std::move(Site));
      }
    }
  }

  // Sort by target count, number of indices in case of jump table, and
  // mispredicts. We prioritize targets with high count, small number of indices
  // and high mispredicts. Break ties by selecting targets with lower addresses.
  std::stable_sort(Targets.begin(), Targets.end(),
                   [](const Callsite &A, const Callsite &B) {
                     if (A.Branches != B.Branches)
                       return A.Branches > B.Branches;
                     if (A.JTIndices.size() != B.JTIndices.size())
                       return A.JTIndices.size() < B.JTIndices.size();
                     if (A.Mispreds != B.Mispreds)
                       return A.Mispreds > B.Mispreds;
                     return A.To.Addr < B.To.Addr;
                   });

  // Remove non-symbol targets
  auto Last = std::remove_if(Targets.begin(), Targets.end(),
                             [](const Callsite &CS) { return !CS.To.Sym; });
  Targets.erase(Last, Targets.end());

  LLVM_DEBUG(if (BF.getJumpTable(Inst)) {
    uint64_t TotalCount = 0;
    uint64_t TotalMispreds = 0;
    for (const Callsite &S : Targets) {
      TotalCount += S.Branches;
      TotalMispreds += S.Mispreds;
    }
    if (!TotalCount)
      TotalCount = 1;
    if (!TotalMispreds)
      TotalMispreds = 1;

    dbgs() << "BOLT-INFO: ICP: jump table size = " << Targets.size()
           << ", Count = " << TotalCount << ", Mispreds = " << TotalMispreds
           << "\n";

    size_t I = 0;
    for (const Callsite &S : Targets) {
      dbgs() << "Count[" << I << "] = " << S.Branches << ", "
             << format("%.1f", (100.0 * S.Branches) / TotalCount) << ", "
             << "Mispreds[" << I << "] = " << S.Mispreds << ", "
             << format("%.1f", (100.0 * S.Mispreds) / TotalMispreds) << "\n";
      ++I;
    }
  });

  return Targets;
}

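// Use memory access profile data attached to the jump table load instruction
// to find the hottest jump table entries. Returns (count, index) pairs sorted
// by descending count and reports the jump table load via TargetFetchInst.
// Returns an empty vector if no usable memory profile data is available.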
IndirectCallPromotion::JumpTableInfoType
IndirectCallPromotion::maybeGetHotJumpTableTargets(BinaryBasicBlock &BB,
                                                   MCInst &CallInst,
                                                   MCInst *&TargetFetchInst,
                                                   const JumpTable *JT) const {
  assert(JT && "Can't get jump table addrs for non-jump tables.");

  BinaryFunction &Function = *BB.getFunction();
  BinaryContext &BC = Function.getBinaryContext();

  if (!Function.hasMemoryProfile() || !opts::EliminateLoads)
    return JumpTableInfoType();

  JumpTableInfoType HotTargets;
  MCInst *MemLocInstr;
  MCInst *PCRelBaseOut;
  unsigned BaseReg, IndexReg;
  int64_t DispValue;
  const MCExpr *DispExpr;
  MutableArrayRef<MCInst> Insts(&BB.front(), &CallInst);
  const IndirectBranchType Type = BC.MIB->analyzeIndirectBranch(
      CallInst, Insts.begin(), Insts.end(), BC.AsmInfo->getCodePointerSize(),
      MemLocInstr, BaseReg, IndexReg, DispValue, DispExpr, PCRelBaseOut);

  assert(MemLocInstr && "There should always be a load for jump tables");
  if (!MemLocInstr)
    return JumpTableInfoType();

  LLVM_DEBUG({
    dbgs() << "BOLT-INFO: ICP attempting to find memory profiling data for "
           << "jump table in " << Function << " at @ "
           << (&CallInst - &BB.front()) << "\n"
           << "BOLT-INFO: ICP target fetch instructions:\n";
    BC.printInstruction(dbgs(), *MemLocInstr, 0, &Function);
    if (MemLocInstr != &CallInst)
      BC.printInstruction(dbgs(), CallInst, 0, &Function);
  });

  DEBUG_VERBOSE(1, {
    dbgs() << "Jmp info: Type = " << (unsigned)Type << ", "
           << "BaseReg = " << BC.MRI->getName(BaseReg) << ", "
           << "IndexReg = " << BC.MRI->getName(IndexReg) << ", "
           << "DispValue = " << Twine::utohexstr(DispValue) << ", "
           << "DispExpr = " << DispExpr << ", "
           << "MemLocInstr = ";
    BC.printInstruction(dbgs(), *MemLocInstr, 0, &Function);
    dbgs() << "\n";
  });

  ++TotalIndexBasedCandidates;

  auto ErrorOrMemAccesssProfile =
      BC.MIB->tryGetAnnotationAs<MemoryAccessProfile>(*MemLocInstr,
                                                      "MemoryAccessProfile");
  if (!ErrorOrMemAccesssProfile) {
    DEBUG_VERBOSE(1, dbgs()
                         << "BOLT-INFO: ICP no memory profiling data found\n");
    return JumpTableInfoType();
  }
  MemoryAccessProfile &MemAccessProfile = ErrorOrMemAccesssProfile.get();

  uint64_t ArrayStart;
  if (DispExpr) {
    ErrorOr<uint64_t> DispValueOrError =
        BC.getSymbolValue(*BC.MIB->getTargetSymbol(DispExpr));
    assert(DispValueOrError && "global symbol needs a value");
    ArrayStart = *DispValueOrError;
  } else {
    ArrayStart = static_cast<uint64_t>(DispValue);
  }

  if (BaseReg == BC.MRI->getProgramCounter())
    ArrayStart += Function.getAddress() + MemAccessProfile.NextInstrOffset;

  // This is a map of [symbol] -> [count, index] and is used to combine indices
  // into the jump table since there may be multiple addresses that all have the
  // same entry.
  std::map<MCSymbol *, std::pair<uint64_t, uint64_t>> HotTargetMap;
  const std::pair<size_t, size_t> Range = JT->getEntriesForAddress(ArrayStart);

  for (const AddressAccess &AccessInfo : MemAccessProfile.AddressAccessInfo) {
    size_t Index;
    // Mem data occasionally includes nullptrs, ignore them.
    if (!AccessInfo.MemoryObject && !AccessInfo.Offset)
      continue;

    if (AccessInfo.Offset % JT->EntrySize != 0) // ignore bogus data
      return JumpTableInfoType();

    if (AccessInfo.MemoryObject) {
      // Deal with bad/stale data
      if (!AccessInfo.MemoryObject->getName().startswith(
              "JUMP_TABLE/" + Function.getOneName().str()))
        return JumpTableInfoType();
      Index =
          (AccessInfo.Offset - (ArrayStart - JT->getAddress())) / JT->EntrySize;
    } else {
      Index = (AccessInfo.Offset - ArrayStart) / JT->EntrySize;
    }

    // If Index is out of range it probably means the memory profiling data is
    // wrong for this instruction, bail out.
    if (Index >= Range.second) {
      LLVM_DEBUG(dbgs() << "BOLT-INFO: Index out of range of " << Range.first
                        << ", " << Range.second << "\n");
      return JumpTableInfoType();
    }

    // Make sure the hot index points at a legal label corresponding to a BB,
    // e.g. not the end of function (unreachable) label.
    if (!Function.getBasicBlockForLabel(JT->Entries[Index + Range.first])) {
      LLVM_DEBUG({
        dbgs() << "BOLT-INFO: hot index " << Index << " pointing at bogus "
               << "label " << JT->Entries[Index + Range.first]->getName()
               << " in jump table:\n";
        JT->print(dbgs());
        dbgs() << "HotTargetMap:\n";
        for (std::pair<MCSymbol *const, std::pair<uint64_t, uint64_t>> &HT :
             HotTargetMap)
          dbgs() << "BOLT-INFO: " << HT.first->getName()
                 << " = (count=" << HT.second.first
                 << ", index=" << HT.second.second << ")\n";
      });
      return JumpTableInfoType();
    }

    std::pair<uint64_t, uint64_t> &HotTarget =
        HotTargetMap[JT->Entries[Index + Range.first]];
    HotTarget.first += AccessInfo.Count;
    HotTarget.second = Index;
  }

  std::transform(
      HotTargetMap.begin(), HotTargetMap.end(), std::back_inserter(HotTargets),
      [](const std::pair<MCSymbol *, std::pair<uint64_t, uint64_t>> &A) {
        return A.second;
      });

  // Sort with highest counts first.
  std::sort(HotTargets.rbegin(), HotTargets.rend());

  LLVM_DEBUG({
    dbgs() << "BOLT-INFO: ICP jump table hot targets:\n";
    for (const std::pair<uint64_t, uint64_t> &Target : HotTargets)
      dbgs() << "BOLT-INFO:  Idx = " << Target.second << ", "
             << "Count = " << Target.first << "\n";
  });

  BC.MIB->getOrCreateAnnotationAs<uint16_t>(CallInst, "JTIndexReg") = IndexReg;

  TargetFetchInst = MemLocInstr;

  return HotTargets;
}

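// Resolve the top N call targets to (symbol, jump table index) pairs. For
// plain indirect calls this is a direct translation of Targets. For jump
// tables, hot-index information from the memory profile (if present) is used
// to split targets per index, and N may be updated to reflect how many
// indices meet the promotion thresholds.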
IndirectCallPromotion::SymTargetsType
IndirectCallPromotion::findCallTargetSymbols(std::vector<Callsite> &Targets,
                                             size_t &N, BinaryBasicBlock &BB,
                                             MCInst &CallInst,
                                             MCInst *&TargetFetchInst) const {
  const JumpTable *JT = BB.getFunction()->getJumpTable(CallInst);
  SymTargetsType SymTargets;

  if (!JT) {
    for (size_t I = 0; I < N; ++I) {
      assert(Targets[I].To.Sym && "All ICP targets must be to known symbols");
      assert(Targets[I].JTIndices.empty() &&
             "Can't have jump table indices for non-jump tables");
      SymTargets.emplace_back(Targets[I].To.Sym, 0);
    }
    return SymTargets;
  }

  // Use memory profile to select hot targets.
  JumpTableInfoType HotTargets =
      maybeGetHotJumpTableTargets(BB, CallInst, TargetFetchInst, JT);

  auto findTargetsIndex = [&](uint64_t JTIndex) {
    for (size_t I = 0; I < Targets.size(); ++I)
      if (llvm::is_contained(Targets[I].JTIndices, JTIndex))
        return I;
    LLVM_DEBUG(dbgs() << "BOLT-ERROR: Unable to find target index for hot jump "
                      << "table entry in " << *BB.getFunction() << "\n");
    llvm_unreachable("Hot indices must be referred to by at least one "
                     "callsite");
  };

  if (!HotTargets.empty()) {
    if (opts::Verbosity >= 1)
      for (size_t I = 0; I < HotTargets.size(); ++I)
        outs() << "BOLT-INFO: HotTarget[" << I << "] = (" << HotTargets[I].first
               << ", " << HotTargets[I].second << ")\n";

    // Recompute hottest targets, now discriminating which index is hot
    // NOTE: This is a tradeoff. On one hand, we get index information. On the
    // other hand, info coming from the memory profile is much less accurate
    // than LBRs. So we may actually end up working with more coarse
    // profile granularity in exchange for information about indices.
    std::vector<Callsite> NewTargets;
    std::map<const MCSymbol *, uint32_t> IndicesPerTarget;
    uint64_t TotalMemAccesses = 0;
    for (size_t I = 0; I < HotTargets.size(); ++I) {
      const uint64_t TargetIndex = findTargetsIndex(HotTargets[I].second);
      ++IndicesPerTarget[Targets[TargetIndex].To.Sym];
      TotalMemAccesses += HotTargets[I].first;
    }
    uint64_t RemainingMemAccesses = TotalMemAccesses;
    const size_t TopN =
        opts::ICPJumpTablesTopN ? opts::ICPJumpTablesTopN : opts::ICPTopN;
    size_t I = 0;
    for (; I < HotTargets.size(); ++I) {
      const uint64_t MemAccesses = HotTargets[I].first;
      if (100 * MemAccesses <
          TotalMemAccesses * opts::ICPJTTotalPercentThreshold)
        break;
      if (100 * MemAccesses <
          RemainingMemAccesses * opts::ICPJTRemainingPercentThreshold)
        break;
      if (TopN && I >= TopN)
        break;
      RemainingMemAccesses -= MemAccesses;

      const uint64_t JTIndex = HotTargets[I].second;
      Callsite &Target = Targets[findTargetsIndex(JTIndex)];

      NewTargets.push_back(Target);
      std::vector<uint64_t>({JTIndex}).swap(NewTargets.back().JTIndices);
      Target.JTIndices.erase(std::remove(Target.JTIndices.begin(),
                                         Target.JTIndices.end(), JTIndex),
                             Target.JTIndices.end());

      // Keep fixCFG counts sane if more indices use this same target later
      assert(IndicesPerTarget[Target.To.Sym] > 0 && "wrong map");
      NewTargets.back().Branches =
          Target.Branches / IndicesPerTarget[Target.To.Sym];
      NewTargets.back().Mispreds =
          Target.Mispreds / IndicesPerTarget[Target.To.Sym];
      assert(Target.Branches >= NewTargets.back().Branches);
      assert(Target.Mispreds >= NewTargets.back().Mispreds);
      Target.Branches -= NewTargets.back().Branches;
      Target.Mispreds -= NewTargets.back().Mispreds;
    }
    std::copy(Targets.begin(), Targets.end(), std::back_inserter(NewTargets));
    std::swap(NewTargets, Targets);
    N = I;

    if (N == 0 && opts::Verbosity >= 1) {
      outs() << "BOLT-INFO: ICP failed in " << *BB.getFunction() << " in "
             << BB.getName() << ": failed to meet thresholds after memory "
             << "profile data was loaded.\n";
      return SymTargets;
    }
  }

  for (size_t I = 0, TgtIdx = 0; I < N; ++TgtIdx) {
    Callsite &Target = Targets[TgtIdx];
    assert(Target.To.Sym && "All ICP targets must be to known symbols");
    assert(!Target.JTIndices.empty() && "Jump tables must have indices");
    for (uint64_t Idx : Target.JTIndices) {
      SymTargets.emplace_back(Target.To.Sym, Idx);
      ++I;
    }
  }

  return SymTargets;
}

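// For virtual method calls, use the memory access profile on the method load
// to map each target method to the vtable it was loaded from. Returns the
// (vtable symbol, addend) pair for every target plus the method fetch
// instructions that can be removed, or an empty MethodInfoType if any vtable
// cannot be determined or the vtable register may be clobbered before the
// call.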
IndirectCallPromotion::MethodInfoType IndirectCallPromotion::maybeGetVtableSyms(
    BinaryBasicBlock &BB, MCInst &Inst,
    const SymTargetsType &SymTargets) const {
  BinaryFunction &Function = *BB.getFunction();
  BinaryContext &BC = Function.getBinaryContext();
  std::vector<std::pair<MCSymbol *, uint64_t>> VtableSyms;
  std::vector<MCInst *> MethodFetchInsns;
  unsigned VtableReg, MethodReg;
  uint64_t MethodOffset;

  assert(!Function.getJumpTable(Inst) &&
         "Can't get vtable addrs for jump tables.");

  if (!Function.hasMemoryProfile() || !opts::EliminateLoads)
    return MethodInfoType();

  MutableArrayRef<MCInst> Insts(&BB.front(), &Inst + 1);
  if (!BC.MIB->analyzeVirtualMethodCall(Insts.begin(), Insts.end(),
                                        MethodFetchInsns, VtableReg, MethodReg,
                                        MethodOffset)) {
    DEBUG_VERBOSE(
        1, dbgs() << "BOLT-INFO: ICP unable to analyze method call in "
                  << Function << " at @ " << (&Inst - &BB.front()) << "\n");
    return MethodInfoType();
  }

  ++TotalMethodLoadEliminationCandidates;

  DEBUG_VERBOSE(1, {
    dbgs() << "BOLT-INFO: ICP found virtual method call in " << Function
           << " at @ " << (&Inst - &BB.front()) << "\n";
    dbgs() << "BOLT-INFO: ICP method fetch instructions:\n";
    for (MCInst *Inst : MethodFetchInsns)
      BC.printInstruction(dbgs(), *Inst, 0, &Function);

    if (MethodFetchInsns.back() != &Inst)
      BC.printInstruction(dbgs(), Inst, 0, &Function);
  });

  // Try to get value profiling data for the method load instruction.
  auto ErrorOrMemAccesssProfile =
      BC.MIB->tryGetAnnotationAs<MemoryAccessProfile>(*MethodFetchInsns.back(),
                                                      "MemoryAccessProfile");
  if (!ErrorOrMemAccesssProfile) {
    DEBUG_VERBOSE(1, dbgs()
                         << "BOLT-INFO: ICP no memory profiling data found\n");
    return MethodInfoType();
  }
  MemoryAccessProfile &MemAccessProfile = ErrorOrMemAccesssProfile.get();

  // Find the vtable that each method belongs to.
  std::map<const MCSymbol *, uint64_t> MethodToVtable;

  for (const AddressAccess &AccessInfo : MemAccessProfile.AddressAccessInfo) {
    uint64_t Address = AccessInfo.Offset;
    if (AccessInfo.MemoryObject)
      Address += AccessInfo.MemoryObject->getAddress();

    // Ignore bogus data.
    if (!Address)
      continue;

    const uint64_t VtableBase = Address - MethodOffset;

    DEBUG_VERBOSE(1, dbgs() << "BOLT-INFO: ICP vtable = "
                            << Twine::utohexstr(VtableBase) << "+"
                            << MethodOffset << "/" << AccessInfo.Count << "\n");

    if (ErrorOr<uint64_t> MethodAddr = BC.getPointerAtAddress(Address)) {
      BinaryData *MethodBD = BC.getBinaryDataAtAddress(MethodAddr.get());
      if (!MethodBD) // skip unknown methods
        continue;
      MCSymbol *MethodSym = MethodBD->getSymbol();
      MethodToVtable[MethodSym] = VtableBase;
      DEBUG_VERBOSE(1, {
        const BinaryFunction *Method = BC.getFunctionForSymbol(MethodSym);
        dbgs() << "BOLT-INFO: ICP found method = "
               << Twine::utohexstr(MethodAddr.get()) << "/"
               << (Method ? Method->getPrintName() : "") << "\n";
      });
    }
  }

  // Find the vtable for each target symbol.
  for (size_t I = 0; I < SymTargets.size(); ++I) {
    auto Itr = MethodToVtable.find(SymTargets[I].first);
    if (Itr != MethodToVtable.end()) {
      if (BinaryData *BD = BC.getBinaryDataContainingAddress(Itr->second)) {
        const uint64_t Addend = Itr->second - BD->getAddress();
        VtableSyms.emplace_back(BD->getSymbol(), Addend);
        continue;
      }
    }
    // Give up if we can't find the vtable for a method.
    DEBUG_VERBOSE(1, dbgs() << "BOLT-INFO: ICP can't find vtable for "
                            << SymTargets[I].first->getName() << "\n");
    return MethodInfoType();
  }

  // Make sure the vtable reg is not clobbered by the argument passing code
  if (VtableReg != MethodReg) {
    for (MCInst *CurInst = MethodFetchInsns.front(); CurInst < &Inst;
         ++CurInst) {
      const MCInstrDesc &InstrInfo = BC.MII->get(CurInst->getOpcode());
      if (InstrInfo.hasDefOfPhysReg(*CurInst, VtableReg, *BC.MRI))
        return MethodInfoType();
    }
  }

  return MethodInfoType(VtableSyms, MethodFetchInsns);
}

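// Replace the indirect call/jump in IndCallBlock with the promoted code
// sequence in ICPcode. The first block of ICPcode goes into the original
// block; the remaining blocks are materialized as new basic blocks. For
// non-tail calls, instructions that followed the original call are moved to
// the last new block (the merge block). Returns the newly created blocks.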
std::vector<std::unique_ptr<BinaryBasicBlock>>
IndirectCallPromotion::rewriteCall(
    BinaryBasicBlock &IndCallBlock, const MCInst &CallInst,
    MCPlusBuilder::BlocksVectorTy &&ICPcode,
    const std::vector<MCInst *> &MethodFetchInsns) const {
  BinaryFunction &Function = *IndCallBlock.getFunction();
  MCPlusBuilder *MIB = Function.getBinaryContext().MIB.get();

  // Create new basic blocks with correct code in each one first.
  std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs;
  const bool IsTailCallOrJT =
      (MIB->isTailCall(CallInst) || Function.getJumpTable(CallInst));

  // Move instructions from the tail of the original call block
  // to the merge block.

  // Remember any pseudo instructions following a tail call.  These
  // must be preserved and moved to the original block.
  InstructionListType TailInsts;
  const MCInst *TailInst = &CallInst;
  if (IsTailCallOrJT)
    while (TailInst + 1 < &(*IndCallBlock.end()) &&
           MIB->isPseudo(*(TailInst + 1)))
      TailInsts.push_back(*++TailInst);

  InstructionListType MovedInst = IndCallBlock.splitInstructions(&CallInst);
  // Link new BBs to the original input offset of the BB where the indirect
  // call site is, so we can map samples recorded in new BBs back to the
  // original BB seen in the input binary (if using BAT)
  const uint32_t OrigOffset = IndCallBlock.getInputOffset();

  IndCallBlock.eraseInstructions(MethodFetchInsns.begin(),
                                 MethodFetchInsns.end());
  if (IndCallBlock.empty() ||
      (!MethodFetchInsns.empty() && MethodFetchInsns.back() == &CallInst))
    IndCallBlock.addInstructions(ICPcode.front().second.begin(),
                                 ICPcode.front().second.end());
  else
    IndCallBlock.replaceInstruction(std::prev(IndCallBlock.end()),
                                    ICPcode.front().second);
  IndCallBlock.addInstructions(TailInsts.begin(), TailInsts.end());

  for (auto Itr = ICPcode.begin() + 1; Itr != ICPcode.end(); ++Itr) {
    MCSymbol *&Sym = Itr->first;
    InstructionListType &Insts = Itr->second;
    assert(Sym);
    std::unique_ptr<BinaryBasicBlock> TBB =
        Function.createBasicBlock(OrigOffset, Sym);
    for (MCInst &Inst : Insts) // sanitize new instructions.
      if (MIB->isCall(Inst))
        MIB->removeAnnotation(Inst, "CallProfile");
    TBB->addInstructions(Insts.begin(), Insts.end());
    NewBBs.emplace_back(std::move(TBB));
  }

  // Move tail of instructions from after the original call to
  // the merge block.
  if (!IsTailCallOrJT)
    NewBBs.back()->addInstructions(MovedInst.begin(), MovedInst.end());

  return NewBBs;
}

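// Connect the blocks created by rewriteCall into the CFG and distribute the
// original block's execution count across the promoted targets and
// fall-through paths. Returns the merge block for promoted non-tail calls,
// or nullptr otherwise.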
BinaryBasicBlock *
IndirectCallPromotion::fixCFG(BinaryBasicBlock &IndCallBlock,
                              const bool IsTailCall, const bool IsJumpTable,
                              IndirectCallPromotion::BasicBlocksVector &&NewBBs,
                              const std::vector<Callsite> &Targets) const {
  BinaryFunction &Function = *IndCallBlock.getFunction();
  using BinaryBranchInfo = BinaryBasicBlock::BinaryBranchInfo;
  BinaryBasicBlock *MergeBlock = nullptr;

  // Scale indirect call counts to the execution count of the original
  // basic block containing the indirect call.
  uint64_t TotalCount = IndCallBlock.getKnownExecutionCount();
  uint64_t TotalIndirectBranches = 0;
  for (const Callsite &Target : Targets)
    TotalIndirectBranches += Target.Branches;
  if (TotalIndirectBranches == 0)
    TotalIndirectBranches = 1;
  BinaryBasicBlock::BranchInfoType BBI;
  BinaryBasicBlock::BranchInfoType ScaledBBI;
  for (const Callsite &Target : Targets) {
    const size_t NumEntries =
        std::max(static_cast<std::size_t>(1UL), Target.JTIndices.size());
    for (size_t I = 0; I < NumEntries; ++I) {
      BBI.push_back(
          BinaryBranchInfo{(Target.Branches + NumEntries - 1) / NumEntries,
                           (Target.Mispreds + NumEntries - 1) / NumEntries});
      ScaledBBI.push_back(
          BinaryBranchInfo{uint64_t(TotalCount * Target.Branches /
                                    (NumEntries * TotalIndirectBranches)),
                           uint64_t(TotalCount * Target.Mispreds /
                                    (NumEntries * TotalIndirectBranches))});
    }
  }

  if (IsJumpTable) {
    BinaryBasicBlock *NewIndCallBlock = NewBBs.back().get();
    IndCallBlock.moveAllSuccessorsTo(NewIndCallBlock);

    std::vector<MCSymbol *> SymTargets;
    for (const Callsite &Target : Targets) {
      const size_t NumEntries =
          std::max(static_cast<std::size_t>(1UL), Target.JTIndices.size());
      for (size_t I = 0; I < NumEntries; ++I)
        SymTargets.push_back(Target.To.Sym);
    }
    assert(SymTargets.size() > NewBBs.size() - 1 &&
           "There must be a target symbol associated with each new BB.");

    for (uint64_t I = 0; I < NewBBs.size(); ++I) {
      BinaryBasicBlock *SourceBB = I ? NewBBs[I - 1].get() : &IndCallBlock;
      SourceBB->setExecutionCount(TotalCount);

      BinaryBasicBlock *TargetBB =
          Function.getBasicBlockForLabel(SymTargets[I]);
      SourceBB->addSuccessor(TargetBB, ScaledBBI[I]); // taken

      TotalCount -= ScaledBBI[I].Count;
      SourceBB->addSuccessor(NewBBs[I].get(), TotalCount); // fall-through

      // Update branch info for the indirect jump.
      BinaryBasicBlock::BinaryBranchInfo &BranchInfo =
          NewIndCallBlock->getBranchInfo(*TargetBB);
      if (BranchInfo.Count > BBI[I].Count)
        BranchInfo.Count -= BBI[I].Count;
      else
        BranchInfo.Count = 0;

      if (BranchInfo.MispredictedCount > BBI[I].MispredictedCount)
        BranchInfo.MispredictedCount -= BBI[I].MispredictedCount;
      else
        BranchInfo.MispredictedCount = 0;
    }
  } else {
    assert(NewBBs.size() >= 2);
    assert(NewBBs.size() % 2 == 1 || IndCallBlock.succ_empty());
    assert(NewBBs.size() % 2 == 1 || IsTailCall);

    auto ScaledBI = ScaledBBI.begin();
    auto updateCurrentBranchInfo = [&] {
      assert(ScaledBI != ScaledBBI.end());
      TotalCount -= ScaledBI->Count;
      ++ScaledBI;
    };

    if (!IsTailCall) {
      MergeBlock = NewBBs.back().get();
      IndCallBlock.moveAllSuccessorsTo(MergeBlock);
    }

    // Fix up successors and execution counts.
    updateCurrentBranchInfo();
    IndCallBlock.addSuccessor(NewBBs[1].get(), TotalCount);
    IndCallBlock.addSuccessor(NewBBs[0].get(), ScaledBBI[0]);

    const size_t Adj = IsTailCall ? 1 : 2;
    for (size_t I = 0; I < NewBBs.size() - Adj; ++I) {
      assert(TotalCount <= IndCallBlock.getExecutionCount() ||
             TotalCount <= uint64_t(TotalIndirectBranches));
      uint64_t ExecCount = ScaledBBI[(I + 1) / 2].Count;
      if (I % 2 == 0) {
        if (MergeBlock)
          NewBBs[I]->addSuccessor(MergeBlock, ScaledBBI[(I + 1) / 2].Count);
      } else {
        assert(I + 2 < NewBBs.size());
        updateCurrentBranchInfo();
        NewBBs[I]->addSuccessor(NewBBs[I + 2].get(), TotalCount);
        NewBBs[I]->addSuccessor(NewBBs[I + 1].get(), ScaledBBI[(I + 1) / 2]);
        ExecCount += TotalCount;
      }
      NewBBs[I]->setExecutionCount(ExecCount);
    }

    if (MergeBlock) {
      // Arrange for the MergeBlock to be the fallthrough for the first
      // promoted call block.
      std::unique_ptr<BinaryBasicBlock> MBPtr;
      std::swap(MBPtr, NewBBs.back());
      NewBBs.pop_back();
      NewBBs.emplace(NewBBs.begin() + 1, std::move(MBPtr));
      // TODO: is COUNT_FALLTHROUGH_EDGE the right thing here?
      NewBBs.back()->addSuccessor(MergeBlock, TotalCount); // uncond branch
    }
  }

  // Update the execution count.
  NewBBs.back()->setExecutionCount(TotalCount);

  // Update BB and BB layout.
  Function.insertBasicBlocks(&IndCallBlock, std::move(NewBBs));
  assert(Function.validateCFG());

  return MergeBlock;
}

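// Decide whether a callsite is worth promoting and, if so, how many of its
// targets to use. Applies execution count, total/remaining frequency, and
// misprediction thresholds, plus optional inlining and function name filters.
// Returns 0 to skip the callsite.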
size_t IndirectCallPromotion::canPromoteCallsite(
    const BinaryBasicBlock &BB, const MCInst &Inst,
    const std::vector<Callsite> &Targets, uint64_t NumCalls) {
  BinaryFunction *BF = BB.getFunction();
  const BinaryContext &BC = BF->getBinaryContext();

  if (BB.getKnownExecutionCount() < opts::ExecutionCountThreshold)
    return 0;

  const bool IsJumpTable = BF->getJumpTable(Inst);

  auto computeStats = [&](size_t N) {
    for (size_t I = 0; I < N; ++I)
      if (IsJumpTable)
        TotalNumFrequentJmps += Targets[I].Branches;
      else
        TotalNumFrequentCalls += Targets[I].Branches;
  };

  // If we have no targets (or no calls), skip this callsite.
  if (Targets.empty() || !NumCalls) {
    if (opts::Verbosity >= 1) {
      const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
      outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx << " in "
             << BB.getName() << ", calls = " << NumCalls
             << ", targets empty or NumCalls == 0.\n";
    }
    return 0;
  }

  size_t TopN = opts::ICPTopN;
  if (IsJumpTable)
    TopN = opts::ICPJumpTablesTopN ? opts::ICPJumpTablesTopN : TopN;
  else
    TopN = opts::ICPCallsTopN ? opts::ICPCallsTopN : TopN;

  const size_t TrialN = TopN ? std::min(TopN, Targets.size()) : Targets.size();

  if (opts::ICPTopCallsites > 0) {
    if (!BC.MIB->hasAnnotation(Inst, "DoICP"))
      return 0;
  }

  // Pick the top N targets.
  uint64_t TotalMispredictsTopN = 0;
  size_t N = 0;

  if (opts::ICPUseMispredicts &&
      (!IsJumpTable || opts::ICPJumpTablesByTarget)) {
    // Count total number of mispredictions for (at most) the top N targets.
    // We may choose a smaller N (TrialN vs. N) if the frequency threshold
    // is exceeded by fewer targets.
    double Threshold = double(opts::ICPMispredictThreshold);
    for (size_t I = 0; I < TrialN && Threshold > 0; ++I, ++N) {
      Threshold -= (100.0 * Targets[I].Mispreds) / NumCalls;
      TotalMispredictsTopN += Targets[I].Mispreds;
    }
    computeStats(N);

    // Compute the misprediction frequency of the top N call targets.  If this
    // frequency is greater than the threshold, we should try ICP on this
    // callsite.
    const double TopNFrequency = (100.0 * TotalMispredictsTopN) / NumCalls;
    if (TopNFrequency == 0 || TopNFrequency < opts::ICPMispredictThreshold) {
      if (opts::Verbosity >= 1) {
        const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
        outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
               << " in " << BB.getName() << ", calls = " << NumCalls
               << ", top N mis. frequency " << format("%.1f", TopNFrequency)
               << "% < " << opts::ICPMispredictThreshold << "%\n";
      }
      return 0;
    }
  } else {
    size_t MaxTargets = 0;

    // Count total number of calls for (at most) the top N targets.
    // We may choose a smaller N (TrialN vs. N) if the frequency threshold
    // is exceeded by fewer targets.
    const unsigned TotalThreshold = IsJumpTable
                                        ? opts::ICPJTTotalPercentThreshold
                                        : opts::ICPCallsTotalPercentThreshold;
    const unsigned RemainingThreshold =
        IsJumpTable ? opts::ICPJTRemainingPercentThreshold
                    : opts::ICPCallsRemainingPercentThreshold;
    uint64_t NumRemainingCalls = NumCalls;
    for (size_t I = 0; I < TrialN; ++I, ++MaxTargets) {
      if (100 * Targets[I].Branches < NumCalls * TotalThreshold)
        break;
      if (100 * Targets[I].Branches < NumRemainingCalls * RemainingThreshold)
        break;
      if (N + (Targets[I].JTIndices.empty() ? 1 : Targets[I].JTIndices.size()) >
          TrialN)
        break;
      TotalMispredictsTopN += Targets[I].Mispreds;
      NumRemainingCalls -= Targets[I].Branches;
      N += Targets[I].JTIndices.empty() ? 1 : Targets[I].JTIndices.size();
    }
    computeStats(MaxTargets);

    // Don't check misprediction frequency for jump tables -- we don't really
    // care as long as we are saving loads from the jump table.
    if (!IsJumpTable || opts::ICPJumpTablesByTarget) {
      // Compute the misprediction frequency of the top N call targets.  If
      // this frequency is less than the threshold, we should skip ICP at
      // this callsite.
      const double TopNMispredictFrequency =
          (100.0 * TotalMispredictsTopN) / NumCalls;

      if (TopNMispredictFrequency < opts::ICPMispredictThreshold) {
        if (opts::Verbosity >= 1) {
          const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
          outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
                 << " in " << BB.getName() << ", calls = " << NumCalls
                 << ", top N mispredict frequency "
                 << format("%.1f", TopNMispredictFrequency) << "% < "
                 << opts::ICPMispredictThreshold << "%\n";
        }
        return 0;
      }
    }
  }

  // Filter by inline-ability of target functions, stop at first target that
  // can't be inlined.
  if (opts::ICPPeelForInline) {
    for (size_t I = 0; I < N; ++I) {
      const MCSymbol *TargetSym = Targets[I].To.Sym;
      const BinaryFunction *TargetBF = BC.getFunctionForSymbol(TargetSym);
      if (!BinaryFunctionPass::shouldOptimize(*TargetBF) ||
          getInliningInfo(*TargetBF).Type == InliningType::INL_NONE) {
        N = I;
        break;
      }
    }
  }

  // Filter functions that can have ICP applied (for debugging)
  if (!opts::ICPFuncsList.empty()) {
    for (std::string &Name : opts::ICPFuncsList)
      if (BF->hasName(Name))
        return N;
    return 0;
  }

  return N;
}

void IndirectCallPromotion::printCallsiteInfo(
    const BinaryBasicBlock &BB, const MCInst &Inst,
    const std::vector<Callsite> &Targets, const size_t N,
    uint64_t NumCalls) const {
  BinaryContext &BC = BB.getFunction()->getBinaryContext();
  const bool IsTailCall = BC.MIB->isTailCall(Inst);
  const bool IsJumpTable = BB.getFunction()->getJumpTable(Inst);
  const ptrdiff_t InstIdx = &Inst - &(*BB.begin());

  outs() << "BOLT-INFO: ICP candidate branch info: " << *BB.getFunction()
         << " @ " << InstIdx << " in " << BB.getName()
         << " -> calls = " << NumCalls
         << (IsTailCall ? " (tail)" : (IsJumpTable ? " (jump table)" : ""))
         << "\n";
  for (size_t I = 0; I < N; I++) {
    const double Frequency = 100.0 * Targets[I].Branches / NumCalls;
    const double MisFrequency = 100.0 * Targets[I].Mispreds / NumCalls;
    outs() << "BOLT-INFO:   ";
    if (Targets[I].To.Sym)
      outs() << Targets[I].To.Sym->getName();
    else
      outs() << Targets[I].To.Addr;
    outs() << ", calls = " << Targets[I].Branches
           << ", mispreds = " << Targets[I].Mispreds
           << ", taken freq = " << format("%.1f", Frequency) << "%"
           << ", mis. freq = " << format("%.1f", MisFrequency) << "%";
    bool First = true;
    for (uint64_t JTIndex : Targets[I].JTIndices) {
      outs() << (First ? ", indices = " : ", ") << JTIndex;
      First = false;
    }
    outs() << "\n";
  }

  LLVM_DEBUG({
    dbgs() << "BOLT-INFO: ICP original call instruction:";
    BC.printInstruction(dbgs(), Inst, Targets[0].From.Addr, nullptr, true);
  });
}

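// Pass entry point. Optionally restricts ICP to the hottest callsites that
// cover opts::ICPTopCallsites percent of the indirect call profile, then
// promotes each eligible callsite and prints aggregate statistics.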
void IndirectCallPromotion::runOnFunctions(BinaryContext &BC) {
  if (opts::ICP == ICP_NONE)
    return;

  auto &BFs = BC.getBinaryFunctions();

  const bool OptimizeCalls = (opts::ICP == ICP_CALLS || opts::ICP == ICP_ALL);
  const bool OptimizeJumpTables =
      (opts::ICP == ICP_JUMP_TABLES || opts::ICP == ICP_ALL);

  std::unique_ptr<RegAnalysis> RA;
  std::unique_ptr<BinaryFunctionCallGraph> CG;
  if (OptimizeJumpTables) {
    CG.reset(new BinaryFunctionCallGraph(buildCallGraph(BC)));
    RA.reset(new RegAnalysis(BC, &BFs, &*CG));
  }

  // If icp-top-callsites is enabled, compute the total number of indirect
  // calls and then optimize the hottest callsites that contribute to that
  // total.
  SetVector<BinaryFunction *> Functions;
  if (opts::ICPTopCallsites == 0) {
    for (auto &KV : BFs)
      Functions.insert(&KV.second);
  } else {
    using IndirectCallsite = std::tuple<uint64_t, MCInst *, BinaryFunction *>;
    std::vector<IndirectCallsite> IndirectCalls;
    size_t TotalIndirectCalls = 0;

    // Find all the indirect callsites.
    for (auto &BFIt : BFs) {
      BinaryFunction &Function = BFIt.second;

      if (!Function.isSimple() || Function.isIgnored() ||
          !Function.hasProfile())
        continue;

      const bool HasLayout = !Function.layout_empty();

      for (BinaryBasicBlock &BB : Function) {
        if (HasLayout && Function.isSplit() && BB.isCold())
          continue;

        for (MCInst &Inst : BB) {
          const bool IsJumpTable = Function.getJumpTable(Inst);
          const bool HasIndirectCallProfile =
              BC.MIB->hasAnnotation(Inst, "CallProfile");
          const bool IsDirectCall =
              (BC.MIB->isCall(Inst) && BC.MIB->getTargetSymbol(Inst, 0));

          if (!IsDirectCall &&
              ((HasIndirectCallProfile && !IsJumpTable && OptimizeCalls) ||
               (IsJumpTable && OptimizeJumpTables))) {
            uint64_t NumCalls = 0;
            for (const Callsite &BInfo : getCallTargets(BB, Inst))
              NumCalls += BInfo.Branches;
            IndirectCalls.push_back(
                std::make_tuple(NumCalls, &Inst, &Function));
            TotalIndirectCalls += NumCalls;
          }
        }
      }
    }

    // Sort callsites by execution count.
    std::sort(IndirectCalls.rbegin(), IndirectCalls.rend());

    // Find callsites that contribute to the top "opts::ICPTopCallsites"%
    // number of calls.
    const float TopPerc = opts::ICPTopCallsites / 100.0f;
    int64_t MaxCalls = TotalIndirectCalls * TopPerc;
    uint64_t LastFreq = std::numeric_limits<uint64_t>::max();
    size_t Num = 0;
    for (const IndirectCallsite &IC : IndirectCalls) {
      const uint64_t CurFreq = std::get<0>(IC);
      // Once we decide to stop, include at least all branches that share the
      // same frequency of the last one to avoid non-deterministic behavior
      // (e.g. turning on/off ICP depending on the order of functions)
      if (MaxCalls <= 0 && CurFreq != LastFreq)
        break;
      MaxCalls -= CurFreq;
      LastFreq = CurFreq;
      BC.MIB->addAnnotation(*std::get<1>(IC), "DoICP", true);
      Functions.insert(std::get<2>(IC));
      ++Num;
    }
    outs() << "BOLT-INFO: ICP Total indirect calls = " << TotalIndirectCalls
           << ", " << Num << " callsites cover " << opts::ICPTopCallsites
           << "% of all indirect calls\n";
  }

  for (BinaryFunction *FuncPtr : Functions) {
    BinaryFunction &Function = *FuncPtr;

    if (!Function.isSimple() || Function.isIgnored() || !Function.hasProfile())
      continue;

    const bool HasLayout = !Function.layout_empty();

    // Total number of indirect calls issued from the current Function.
    // (a fraction of TotalIndirectCalls)
    uint64_t FuncTotalIndirectCalls = 0;
    uint64_t FuncTotalIndirectJmps = 0;

    std::vector<BinaryBasicBlock *> BBs;
    for (BinaryBasicBlock &BB : Function) {
      // Skip indirect calls in cold blocks.
      if (!HasLayout || !Function.isSplit() || !BB.isCold())
        BBs.push_back(&BB);
    }
    if (BBs.empty())
      continue;

    DataflowInfoManager Info(Function, RA.get(), nullptr);
    while (!BBs.empty()) {
      BinaryBasicBlock *BB = BBs.back();
      BBs.pop_back();

      for (unsigned Idx = 0; Idx < BB->size(); ++Idx) {
        MCInst &Inst = BB->getInstructionAtIndex(Idx);
        const ptrdiff_t InstIdx = &Inst - &(*BB->begin());
        const bool IsTailCall = BC.MIB->isTailCall(Inst);
        const bool HasIndirectCallProfile =
            BC.MIB->hasAnnotation(Inst, "CallProfile");
        const bool IsJumpTable = Function.getJumpTable(Inst);

        if (BC.MIB->isCall(Inst))
          TotalCalls += BB->getKnownExecutionCount();

        if (IsJumpTable && !OptimizeJumpTables)
          continue;

        if (!IsJumpTable && (!HasIndirectCallProfile || !OptimizeCalls))
          continue;

        // Ignore direct calls.
        if (BC.MIB->isCall(Inst) && BC.MIB->getTargetSymbol(Inst, 0))
          continue;

        assert((BC.MIB->isCall(Inst) || BC.MIB->isIndirectBranch(Inst)) &&
               "expected a call or an indirect jump instruction");

        if (IsJumpTable)
          ++TotalJumpTableCallsites;
        else
          ++TotalIndirectCallsites;

        std::vector<Callsite> Targets = getCallTargets(*BB, Inst);

        // Compute the total number of calls from this particular callsite.
        uint64_t NumCalls = 0;
        for (const Callsite &BInfo : Targets)
          NumCalls += BInfo.Branches;
        if (!IsJumpTable)
          FuncTotalIndirectCalls += NumCalls;
        else
          FuncTotalIndirectJmps += NumCalls;

        // If FLAGS regs is alive after this jmp site, do not try
        // promoting because we will clobber FLAGS.
        if (IsJumpTable) {
          ErrorOr<const BitVector &> State =
              Info.getLivenessAnalysis().getStateBefore(Inst);
          if (!State || (State && (*State)[BC.MIB->getFlagsReg()])) {
            if (opts::Verbosity >= 1)
              outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
                     << InstIdx << " in " << BB->getName()
                     << ", calls = " << NumCalls
                     << (State ? ", cannot clobber flags reg.\n"
                               : ", no liveness data available.\n");
            continue;
          }
        }

        // Should this callsite be optimized?  Return the number of targets
        // to use when promoting this call.  A value of zero means to skip
        // this callsite.
        size_t N = canPromoteCallsite(*BB, Inst, Targets, NumCalls);

        // If it is a jump table and it failed to meet our initial threshold,
        // proceed to findCallTargetSymbols -- it may reevaluate N if
        // memory profile is present
        if (!N && !IsJumpTable)
          continue;

        if (opts::Verbosity >= 1)
          printCallsiteInfo(*BB, Inst, Targets, N, NumCalls);

        // Find MCSymbols or absolute addresses for each call target.
        MCInst *TargetFetchInst = nullptr;
        const SymTargetsType SymTargets =
            findCallTargetSymbols(Targets, N, *BB, Inst, TargetFetchInst);

        // findCallTargetSymbols may have changed N if mem profile is available
        // for jump tables
        if (!N)
          continue;

        LLVM_DEBUG(printDecision(dbgs(), Targets, N));

        // If we can't resolve any of the target symbols, punt on this callsite.
        // TODO: can this ever happen?
        if (SymTargets.size() < N) {
          const size_t LastTarget = SymTargets.size();
          if (opts::Verbosity >= 1)
            outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
                   << InstIdx << " in " << BB->getName()
                   << ", calls = " << NumCalls
                   << ", ICP failed to find target symbol for "
                   << Targets[LastTarget].To.Sym->getName() << "\n";
          continue;
        }

        MethodInfoType MethodInfo;

        if (!IsJumpTable) {
          MethodInfo = maybeGetVtableSyms(*BB, Inst, SymTargets);
          TotalMethodLoadsEliminated += MethodInfo.first.empty() ? 0 : 1;
          LLVM_DEBUG(dbgs()
                     << "BOLT-INFO: ICP "
                     << (!MethodInfo.first.empty() ? "found" : "did not find")
                     << " vtables for all methods.\n");
        } else if (TargetFetchInst) {
          ++TotalIndexBasedJumps;
          MethodInfo.second.push_back(TargetFetchInst);
        }

        // Generate new promoted call code for this callsite.
        MCPlusBuilder::BlocksVectorTy ICPcode =
            (IsJumpTable && !opts::ICPJumpTablesByTarget)
                ? BC.MIB->jumpTablePromotion(Inst, SymTargets,
                                             MethodInfo.second, BC.Ctx.get())
                : BC.MIB->indirectCallPromotion(
                      Inst, SymTargets, MethodInfo.first, MethodInfo.second,
                      opts::ICPOldCodeSequence, BC.Ctx.get());

        if (ICPcode.empty()) {
          if (opts::Verbosity >= 1)
            outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
                   << InstIdx << " in " << BB->getName()
                   << ", calls = " << NumCalls
                   << ", unable to generate promoted call code.\n";
          continue;
        }

        LLVM_DEBUG({
          uint64_t Offset = Targets[0].From.Addr;
          dbgs() << "BOLT-INFO: ICP indirect call code:\n";
          for (const auto &entry : ICPcode) {
            const MCSymbol *const &Sym = entry.first;
            const InstructionListType &Insts = entry.second;
            if (Sym)
              dbgs() << Sym->getName() << ":\n";
            Offset = BC.printInstructions(dbgs(), Insts.begin(), Insts.end(),
                                          Offset);
          }
          dbgs() << "---------------------------------------------------\n";
        });

        // Rewrite the CFG with the newly generated ICP code.
        std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs =
            rewriteCall(*BB, Inst, std::move(ICPcode), MethodInfo.second);

        // Fix the CFG after inserting the new basic blocks.
        BinaryBasicBlock *MergeBlock =
            fixCFG(*BB, IsTailCall, IsJumpTable, std::move(NewBBs), Targets);

        // Since the tail of the original block was split off and it may contain
        // additional indirect calls, we must add the merge block to the set of
        // blocks to process.
        if (MergeBlock)
          BBs.push_back(MergeBlock);

        if (opts::Verbosity >= 1)
          outs() << "BOLT-INFO: ICP succeeded in " << Function << " @ "
                 << InstIdx << " in " << BB->getName()
                 << " -> calls = " << NumCalls << "\n";

        if (IsJumpTable)
          ++TotalOptimizedJumpTableCallsites;
        else
          ++TotalOptimizedIndirectCallsites;

        Modified.insert(&Function);
      }
    }
    TotalIndirectCalls += FuncTotalIndirectCalls;
    TotalIndirectJmps += FuncTotalIndirectJmps;
  }

  outs() << "BOLT-INFO: ICP total indirect callsites with profile = "
         << TotalIndirectCallsites << "\n"
         << "BOLT-INFO: ICP total jump table callsites = "
         << TotalJumpTableCallsites << "\n"
         << "BOLT-INFO: ICP total number of calls = " << TotalCalls << "\n"
         << "BOLT-INFO: ICP percentage of calls that are indirect = "
         << format("%.1f", (100.0 * TotalIndirectCalls) / TotalCalls) << "%\n"
         << "BOLT-INFO: ICP percentage of indirect calls that can be "
            "optimized = "
         << format("%.1f", (100.0 * TotalNumFrequentCalls) /
                               std::max<size_t>(TotalIndirectCalls, 1))
         << "%\n"
         << "BOLT-INFO: ICP percentage of indirect callsites that are "
            "optimized = "
         << format("%.1f", (100.0 * TotalOptimizedIndirectCallsites) /
                               std::max<uint64_t>(TotalIndirectCallsites, 1))
         << "%\n"
         << "BOLT-INFO: ICP number of method load elimination candidates = "
         << TotalMethodLoadEliminationCandidates << "\n"
         << "BOLT-INFO: ICP percentage of method calls candidates that have "
            "loads eliminated = "
         << format("%.1f", (100.0 * TotalMethodLoadsEliminated) /
                               std::max<uint64_t>(
                                   TotalMethodLoadEliminationCandidates, 1))
         << "%\n"
         << "BOLT-INFO: ICP percentage of indirect branches that are "
            "optimized = "
         << format("%.1f", (100.0 * TotalNumFrequentJmps) /
                               std::max<uint64_t>(TotalIndirectJmps, 1))
         << "%\n"
         << "BOLT-INFO: ICP percentage of jump table callsites that are "
         << "optimized = "
         << format("%.1f", (100.0 * TotalOptimizedJumpTableCallsites) /
                               std::max<uint64_t>(TotalJumpTableCallsites, 1))
         << "%\n"
         << "BOLT-INFO: ICP number of jump table callsites that can use hot "
         << "indices = " << TotalIndexBasedCandidates << "\n"
         << "BOLT-INFO: ICP percentage of jump table callsites that use hot "
            "indices = "
         << format("%.1f", (100.0 * TotalIndexBasedJumps) /
                               std::max<uint64_t>(TotalIndexBasedCandidates, 1))
         << "%\n";

  (void)verifyProfile;
#ifndef NDEBUG
  verifyProfile(BFs);
#endif
}

} // namespace bolt
} // namespace llvm