xref: /llvm-project/bolt/lib/Passes/IndirectCallPromotion.cpp (revision 2ad1c7540eb0e07047911a39d12a12d062d4bbf4)
1 //===- bolt/Passes/IndirectCallPromotion.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the IndirectCallPromotion class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "bolt/Passes/IndirectCallPromotion.h"
14 #include "bolt/Passes/BinaryFunctionCallGraph.h"
15 #include "bolt/Passes/DataflowInfoManager.h"
16 #include "llvm/Support/CommandLine.h"
17 
#define DEBUG_TYPE "ICP"
// Evaluate the statement(s) X only when -v (opts::Verbosity) is at least
// Level. The do/while(false) wrapper makes the expansion a single statement:
// it is safe inside an unbraced if/else and requires a trailing semicolon.
#define DEBUG_VERBOSE(Level, X)                                                \
  do {                                                                         \
    if (opts::Verbosity >= (Level)) {                                          \
      X;                                                                       \
    }                                                                          \
  } while (false)
23 
24 using namespace llvm;
25 using namespace bolt;
26 
27 namespace opts {
28 
29 extern cl::OptionCategory BoltOptCategory;
30 
31 extern cl::opt<IndirectCallPromotionType> ICP;
32 extern cl::opt<unsigned> Verbosity;
33 extern cl::opt<unsigned> ExecutionCountThreshold;
34 
35 static cl::opt<unsigned> ICPJTRemainingPercentThreshold(
36     "icp-jt-remaining-percent-threshold",
37     cl::desc("The percentage threshold against remaining unpromoted indirect "
38              "call count for the promotion for jump tables"),
39     cl::init(30), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));
40 
41 static cl::opt<unsigned> ICPJTTotalPercentThreshold(
42     "icp-jt-total-percent-threshold",
43     cl::desc(
44         "The percentage threshold against total count for the promotion for "
45         "jump tables"),
46     cl::init(5), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));
47 
48 static cl::opt<unsigned> ICPCallsRemainingPercentThreshold(
49     "icp-calls-remaining-percent-threshold",
50     cl::desc("The percentage threshold against remaining unpromoted indirect "
51              "call count for the promotion for calls"),
52     cl::init(50), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));
53 
54 static cl::opt<unsigned> ICPCallsTotalPercentThreshold(
55     "icp-calls-total-percent-threshold",
56     cl::desc(
57         "The percentage threshold against total count for the promotion for "
58         "calls"),
59     cl::init(30), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));
60 
61 static cl::opt<unsigned> ICPMispredictThreshold(
62     "indirect-call-promotion-mispredict-threshold",
63     cl::desc("misprediction threshold for skipping ICP on an "
64              "indirect call"),
65     cl::init(0), cl::ZeroOrMore, cl::cat(BoltOptCategory));
66 
67 static cl::opt<bool> ICPUseMispredicts(
68     "indirect-call-promotion-use-mispredicts",
69     cl::desc("use misprediction frequency for determining whether or not ICP "
70              "should be applied at a callsite.  The "
71              "-indirect-call-promotion-mispredict-threshold value will be used "
72              "by this heuristic"),
73     cl::ZeroOrMore, cl::cat(BoltOptCategory));
74 
75 static cl::opt<unsigned>
76     ICPTopN("indirect-call-promotion-topn",
77             cl::desc("limit number of targets to consider when doing indirect "
78                      "call promotion. 0 = no limit"),
79             cl::init(3), cl::ZeroOrMore, cl::cat(BoltOptCategory));
80 
81 static cl::opt<unsigned> ICPCallsTopN(
82     "indirect-call-promotion-calls-topn",
83     cl::desc("limit number of targets to consider when doing indirect "
84              "call promotion on calls. 0 = no limit"),
85     cl::init(0), cl::ZeroOrMore, cl::cat(BoltOptCategory));
86 
87 static cl::opt<unsigned> ICPJumpTablesTopN(
88     "indirect-call-promotion-jump-tables-topn",
89     cl::desc("limit number of targets to consider when doing indirect "
90              "call promotion on jump tables. 0 = no limit"),
91     cl::init(0), cl::ZeroOrMore, cl::cat(BoltOptCategory));
92 
93 static cl::opt<bool> EliminateLoads(
94     "icp-eliminate-loads",
95     cl::desc("enable load elimination using memory profiling data when "
96              "performing ICP"),
97     cl::init(true), cl::ZeroOrMore, cl::cat(BoltOptCategory));
98 
99 static cl::opt<unsigned> ICPTopCallsites(
100     "icp-top-callsites",
101     cl::desc("optimize hottest calls until at least this percentage of all "
102              "indirect calls frequency is covered. 0 = all callsites"),
103     cl::init(99), cl::Hidden, cl::ZeroOrMore, cl::cat(BoltOptCategory));
104 
105 static cl::list<std::string>
106     ICPFuncsList("icp-funcs", cl::CommaSeparated,
107                  cl::desc("list of functions to enable ICP for"),
108                  cl::value_desc("func1,func2,func3,..."), cl::Hidden,
109                  cl::cat(BoltOptCategory));
110 
111 static cl::opt<bool>
112     ICPOldCodeSequence("icp-old-code-sequence",
113                        cl::desc("use old code sequence for promoted calls"),
114                        cl::init(false), cl::ZeroOrMore, cl::Hidden,
115                        cl::cat(BoltOptCategory));
116 
117 static cl::opt<bool> ICPJumpTablesByTarget(
118     "icp-jump-tables-targets",
119     cl::desc(
120         "for jump tables, optimize indirect jmp targets instead of indices"),
121     cl::init(false), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));
122 
123 } // namespace opts
124 
125 namespace llvm {
126 namespace bolt {
127 
128 namespace {
129 
130 bool verifyProfile(std::map<uint64_t, BinaryFunction> &BFs) {
131   bool IsValid = true;
132   for (auto &BFI : BFs) {
133     BinaryFunction &BF = BFI.second;
134     if (!BF.isSimple())
135       continue;
136     for (BinaryBasicBlock *BB : BF.layout()) {
137       auto BI = BB->branch_info_begin();
138       for (BinaryBasicBlock *SuccBB : BB->successors()) {
139         if (BI->Count != BinaryBasicBlock::COUNT_NO_PROFILE && BI->Count > 0) {
140           if (BB->getKnownExecutionCount() == 0 ||
141               SuccBB->getKnownExecutionCount() == 0) {
142             errs() << "BOLT-WARNING: profile verification failed after ICP for "
143                       "function "
144                    << BF << '\n';
145             IsValid = false;
146           }
147         }
148         ++BI;
149       }
150     }
151   }
152   return IsValid;
153 }
154 
155 } // namespace
156 
157 IndirectCallPromotion::Callsite::Callsite(BinaryFunction &BF,
158                                           const IndirectCallProfile &ICP)
159     : From(BF.getSymbol()), To(ICP.Offset), Mispreds(ICP.Mispreds),
160       Branches(ICP.Count) {
161   if (ICP.Symbol) {
162     To.Sym = ICP.Symbol;
163     To.Addr = 0;
164   }
165 }
166 
167 void IndirectCallPromotion::printDecision(
168     llvm::raw_ostream &OS,
169     std::vector<IndirectCallPromotion::Callsite> &Targets, unsigned N) const {
170   uint64_t TotalCount = 0;
171   uint64_t TotalMispreds = 0;
172   for (const Callsite &S : Targets) {
173     TotalCount += S.Branches;
174     TotalMispreds += S.Mispreds;
175   }
176   if (!TotalCount)
177     TotalCount = 1;
178   if (!TotalMispreds)
179     TotalMispreds = 1;
180 
181   OS << "BOLT-INFO: ICP decision for call site with " << Targets.size()
182      << " targets, Count = " << TotalCount << ", Mispreds = " << TotalMispreds
183      << "\n";
184 
185   size_t I = 0;
186   for (const Callsite &S : Targets) {
187     OS << "Count = " << S.Branches << ", "
188        << format("%.1f", (100.0 * S.Branches) / TotalCount) << ", "
189        << "Mispreds = " << S.Mispreds << ", "
190        << format("%.1f", (100.0 * S.Mispreds) / TotalMispreds);
191     if (I < N)
192       OS << " * to be optimized *";
193     if (!S.JTIndices.empty()) {
194       OS << " Indices:";
195       for (const uint64_t Idx : S.JTIndices)
196         OS << " " << Idx;
197     }
198     OS << "\n";
199     I += S.JTIndices.empty() ? 1 : S.JTIndices.size();
200   }
201 }
202 
203 // Get list of targets for a given call sorted by most frequently
204 // called first.
205 std::vector<IndirectCallPromotion::Callsite>
206 IndirectCallPromotion::getCallTargets(BinaryBasicBlock &BB,
207                                       const MCInst &Inst) const {
208   BinaryFunction &BF = *BB.getFunction();
209   const BinaryContext &BC = BF.getBinaryContext();
210   std::vector<Callsite> Targets;
211 
212   if (const JumpTable *JT = BF.getJumpTable(Inst)) {
213     // Don't support PIC jump tables for now
214     if (!opts::ICPJumpTablesByTarget && JT->Type == JumpTable::JTT_PIC)
215       return Targets;
216     const Location From(BF.getSymbol());
217     const std::pair<size_t, size_t> Range =
218         JT->getEntriesForAddress(BC.MIB->getJumpTable(Inst));
219     assert(JT->Counts.empty() || JT->Counts.size() >= Range.second);
220     JumpTable::JumpInfo DefaultJI;
221     const JumpTable::JumpInfo *JI =
222         JT->Counts.empty() ? &DefaultJI : &JT->Counts[Range.first];
223     const size_t JIAdj = JT->Counts.empty() ? 0 : 1;
224     assert(JT->Type == JumpTable::JTT_PIC ||
225            JT->EntrySize == BC.AsmInfo->getCodePointerSize());
226     for (size_t I = Range.first; I < Range.second; ++I, JI += JIAdj) {
227       MCSymbol *Entry = JT->Entries[I];
228       assert(BF.getBasicBlockForLabel(Entry) ||
229              Entry == BF.getFunctionEndLabel() ||
230              Entry == BF.getFunctionColdEndLabel());
231       if (Entry == BF.getFunctionEndLabel() ||
232           Entry == BF.getFunctionColdEndLabel())
233         continue;
234       const Location To(Entry);
235       const BinaryBasicBlock::BinaryBranchInfo &BI = BB.getBranchInfo(Entry);
236       Targets.emplace_back(From, To, BI.MispredictedCount, BI.Count,
237                            I - Range.first);
238     }
239 
240     // Sort by symbol then addr.
241     std::sort(Targets.begin(), Targets.end(),
242               [](const Callsite &A, const Callsite &B) {
243                 if (A.To.Sym && B.To.Sym)
244                   return A.To.Sym < B.To.Sym;
245                 else if (A.To.Sym && !B.To.Sym)
246                   return true;
247                 else if (!A.To.Sym && B.To.Sym)
248                   return false;
249                 else
250                   return A.To.Addr < B.To.Addr;
251               });
252 
253     // Targets may contain multiple entries to the same target, but using
254     // different indices. Their profile will report the same number of branches
255     // for different indices if the target is the same. That's because we don't
256     // profile the index value, but only the target via LBR.
257     auto First = Targets.begin();
258     auto Last = Targets.end();
259     auto Result = First;
260     while (++First != Last) {
261       Callsite &A = *Result;
262       const Callsite &B = *First;
263       if (A.To.Sym && B.To.Sym && A.To.Sym == B.To.Sym)
264         A.JTIndices.insert(A.JTIndices.end(), B.JTIndices.begin(),
265                            B.JTIndices.end());
266       else
267         *(++Result) = *First;
268     }
269     ++Result;
270 
271     LLVM_DEBUG(if (Targets.end() - Result > 0) {
272       dbgs() << "BOLT-INFO: ICP: " << (Targets.end() - Result)
273              << " duplicate targets removed\n";
274     });
275 
276     Targets.erase(Result, Targets.end());
277   } else {
278     // Don't try to optimize PC relative indirect calls.
279     if (Inst.getOperand(0).isReg() &&
280         Inst.getOperand(0).getReg() == BC.MRI->getProgramCounter())
281       return Targets;
282 
283     const auto ICSP = BC.MIB->tryGetAnnotationAs<IndirectCallSiteProfile>(
284         Inst, "CallProfile");
285     if (ICSP) {
286       for (const IndirectCallProfile &CSP : ICSP.get()) {
287         Callsite Site(BF, CSP);
288         if (Site.isValid())
289           Targets.emplace_back(std::move(Site));
290       }
291     }
292   }
293 
294   // Sort by target count, number of indices in case of jump table, and
295   // mispredicts. We prioritize targets with high count, small number of indices
296   // and high mispredicts. Break ties by selecting targets with lower addresses.
297   std::stable_sort(Targets.begin(), Targets.end(),
298                    [](const Callsite &A, const Callsite &B) {
299                      if (A.Branches != B.Branches)
300                        return A.Branches > B.Branches;
301                      if (A.JTIndices.size() != B.JTIndices.size())
302                        return A.JTIndices.size() < B.JTIndices.size();
303                      if (A.Mispreds != B.Mispreds)
304                        return A.Mispreds > B.Mispreds;
305                      return A.To.Addr < B.To.Addr;
306                    });
307 
308   // Remove non-symbol targets
309   auto Last = std::remove_if(Targets.begin(), Targets.end(),
310                              [](const Callsite &CS) { return !CS.To.Sym; });
311   Targets.erase(Last, Targets.end());
312 
313   LLVM_DEBUG(if (BF.getJumpTable(Inst)) {
314     uint64_t TotalCount = 0;
315     uint64_t TotalMispreds = 0;
316     for (const Callsite &S : Targets) {
317       TotalCount += S.Branches;
318       TotalMispreds += S.Mispreds;
319     }
320     if (!TotalCount)
321       TotalCount = 1;
322     if (!TotalMispreds)
323       TotalMispreds = 1;
324 
325     dbgs() << "BOLT-INFO: ICP: jump table size = " << Targets.size()
326            << ", Count = " << TotalCount << ", Mispreds = " << TotalMispreds
327            << "\n";
328 
329     size_t I = 0;
330     for (const Callsite &S : Targets) {
331       dbgs() << "Count[" << I << "] = " << S.Branches << ", "
332              << format("%.1f", (100.0 * S.Branches) / TotalCount) << ", "
333              << "Mispreds[" << I << "] = " << S.Mispreds << ", "
334              << format("%.1f", (100.0 * S.Mispreds) / TotalMispreds) << "\n";
335       ++I;
336     }
337   });
338 
339   return Targets;
340 }
341 
342 IndirectCallPromotion::JumpTableInfoType
343 IndirectCallPromotion::maybeGetHotJumpTableTargets(BinaryBasicBlock &BB,
344                                                    MCInst &CallInst,
345                                                    MCInst *&TargetFetchInst,
346                                                    const JumpTable *JT) const {
347   assert(JT && "Can't get jump table addrs for non-jump tables.");
348 
349   BinaryFunction &Function = *BB.getFunction();
350   BinaryContext &BC = Function.getBinaryContext();
351 
352   if (!Function.hasMemoryProfile() || !opts::EliminateLoads)
353     return JumpTableInfoType();
354 
355   JumpTableInfoType HotTargets;
356   MCInst *MemLocInstr;
357   MCInst *PCRelBaseOut;
358   unsigned BaseReg, IndexReg;
359   int64_t DispValue;
360   const MCExpr *DispExpr;
361   MutableArrayRef<MCInst> Insts(&BB.front(), &CallInst);
362   const IndirectBranchType Type = BC.MIB->analyzeIndirectBranch(
363       CallInst, Insts.begin(), Insts.end(), BC.AsmInfo->getCodePointerSize(),
364       MemLocInstr, BaseReg, IndexReg, DispValue, DispExpr, PCRelBaseOut);
365 
366   assert(MemLocInstr && "There should always be a load for jump tables");
367   if (!MemLocInstr)
368     return JumpTableInfoType();
369 
370   LLVM_DEBUG({
371     dbgs() << "BOLT-INFO: ICP attempting to find memory profiling data for "
372            << "jump table in " << Function << " at @ "
373            << (&CallInst - &BB.front()) << "\n"
374            << "BOLT-INFO: ICP target fetch instructions:\n";
375     BC.printInstruction(dbgs(), *MemLocInstr, 0, &Function);
376     if (MemLocInstr != &CallInst)
377       BC.printInstruction(dbgs(), CallInst, 0, &Function);
378   });
379 
380   DEBUG_VERBOSE(1, {
381     dbgs() << "Jmp info: Type = " << (unsigned)Type << ", "
382            << "BaseReg = " << BC.MRI->getName(BaseReg) << ", "
383            << "IndexReg = " << BC.MRI->getName(IndexReg) << ", "
384            << "DispValue = " << Twine::utohexstr(DispValue) << ", "
385            << "DispExpr = " << DispExpr << ", "
386            << "MemLocInstr = ";
387     BC.printInstruction(dbgs(), *MemLocInstr, 0, &Function);
388     dbgs() << "\n";
389   });
390 
391   ++TotalIndexBasedCandidates;
392 
393   auto ErrorOrMemAccesssProfile =
394       BC.MIB->tryGetAnnotationAs<MemoryAccessProfile>(*MemLocInstr,
395                                                       "MemoryAccessProfile");
396   if (!ErrorOrMemAccesssProfile) {
397     DEBUG_VERBOSE(1, dbgs()
398                          << "BOLT-INFO: ICP no memory profiling data found\n");
399     return JumpTableInfoType();
400   }
401   MemoryAccessProfile &MemAccessProfile = ErrorOrMemAccesssProfile.get();
402 
403   uint64_t ArrayStart;
404   if (DispExpr) {
405     ErrorOr<uint64_t> DispValueOrError =
406         BC.getSymbolValue(*BC.MIB->getTargetSymbol(DispExpr));
407     assert(DispValueOrError && "global symbol needs a value");
408     ArrayStart = *DispValueOrError;
409   } else {
410     ArrayStart = static_cast<uint64_t>(DispValue);
411   }
412 
413   if (BaseReg == BC.MRI->getProgramCounter())
414     ArrayStart += Function.getAddress() + MemAccessProfile.NextInstrOffset;
415 
416   // This is a map of [symbol] -> [count, index] and is used to combine indices
417   // into the jump table since there may be multiple addresses that all have the
418   // same entry.
419   std::map<MCSymbol *, std::pair<uint64_t, uint64_t>> HotTargetMap;
420   const std::pair<size_t, size_t> Range = JT->getEntriesForAddress(ArrayStart);
421 
422   for (const AddressAccess &AccessInfo : MemAccessProfile.AddressAccessInfo) {
423     size_t Index;
424     // Mem data occasionally includes nullprs, ignore them.
425     if (!AccessInfo.MemoryObject && !AccessInfo.Offset)
426       continue;
427 
428     if (AccessInfo.Offset % JT->EntrySize != 0) // ignore bogus data
429       return JumpTableInfoType();
430 
431     if (AccessInfo.MemoryObject) {
432       // Deal with bad/stale data
433       if (!AccessInfo.MemoryObject->getName().startswith(
434               "JUMP_TABLE/" + Function.getOneName().str()))
435         return JumpTableInfoType();
436       Index =
437           (AccessInfo.Offset - (ArrayStart - JT->getAddress())) / JT->EntrySize;
438     } else {
439       Index = (AccessInfo.Offset - ArrayStart) / JT->EntrySize;
440     }
441 
442     // If Index is out of range it probably means the memory profiling data is
443     // wrong for this instruction, bail out.
444     if (Index >= Range.second) {
445       LLVM_DEBUG(dbgs() << "BOLT-INFO: Index out of range of " << Range.first
446                         << ", " << Range.second << "\n");
447       return JumpTableInfoType();
448     }
449 
450     // Make sure the hot index points at a legal label corresponding to a BB,
451     // e.g. not the end of function (unreachable) label.
452     if (!Function.getBasicBlockForLabel(JT->Entries[Index + Range.first])) {
453       LLVM_DEBUG({
454         dbgs() << "BOLT-INFO: hot index " << Index << " pointing at bogus "
455                << "label " << JT->Entries[Index + Range.first]->getName()
456                << " in jump table:\n";
457         JT->print(dbgs());
458         dbgs() << "HotTargetMap:\n";
459         for (std::pair<MCSymbol *const, std::pair<uint64_t, uint64_t>> &HT :
460              HotTargetMap)
461           dbgs() << "BOLT-INFO: " << HT.first->getName()
462                  << " = (count=" << HT.second.first
463                  << ", index=" << HT.second.second << ")\n";
464       });
465       return JumpTableInfoType();
466     }
467 
468     std::pair<uint64_t, uint64_t> &HotTarget =
469         HotTargetMap[JT->Entries[Index + Range.first]];
470     HotTarget.first += AccessInfo.Count;
471     HotTarget.second = Index;
472   }
473 
474   std::transform(
475       HotTargetMap.begin(), HotTargetMap.end(), std::back_inserter(HotTargets),
476       [](const std::pair<MCSymbol *, std::pair<uint64_t, uint64_t>> &A) {
477         return A.second;
478       });
479 
480   // Sort with highest counts first.
481   std::sort(HotTargets.rbegin(), HotTargets.rend());
482 
483   LLVM_DEBUG({
484     dbgs() << "BOLT-INFO: ICP jump table hot targets:\n";
485     for (const std::pair<uint64_t, uint64_t> &Target : HotTargets)
486       dbgs() << "BOLT-INFO:  Idx = " << Target.second << ", "
487              << "Count = " << Target.first << "\n";
488   });
489 
490   BC.MIB->getOrCreateAnnotationAs<uint16_t>(CallInst, "JTIndexReg") = IndexReg;
491 
492   TargetFetchInst = MemLocInstr;
493 
494   return HotTargets;
495 }
496 
497 IndirectCallPromotion::SymTargetsType
498 IndirectCallPromotion::findCallTargetSymbols(std::vector<Callsite> &Targets,
499                                              size_t &N, BinaryBasicBlock &BB,
500                                              MCInst &CallInst,
501                                              MCInst *&TargetFetchInst) const {
502   const JumpTable *JT = BB.getFunction()->getJumpTable(CallInst);
503   SymTargetsType SymTargets;
504 
505   if (!JT) {
506     for (size_t I = 0; I < N; ++I) {
507       assert(Targets[I].To.Sym && "All ICP targets must be to known symbols");
508       assert(Targets[I].JTIndices.empty() &&
509              "Can't have jump table indices for non-jump tables");
510       SymTargets.emplace_back(Targets[I].To.Sym, 0);
511     }
512     return SymTargets;
513   }
514 
515   // Use memory profile to select hot targets.
516   JumpTableInfoType HotTargets =
517       maybeGetHotJumpTableTargets(BB, CallInst, TargetFetchInst, JT);
518 
519   auto findTargetsIndex = [&](uint64_t JTIndex) {
520     for (size_t I = 0; I < Targets.size(); ++I)
521       if (llvm::is_contained(Targets[I].JTIndices, JTIndex))
522         return I;
523     LLVM_DEBUG(dbgs() << "BOLT-ERROR: Unable to find target index for hot jump "
524                       << " table entry in " << *BB.getFunction() << "\n");
525     llvm_unreachable("Hot indices must be referred to by at least one "
526                      "callsite");
527   };
528 
529   if (!HotTargets.empty()) {
530     if (opts::Verbosity >= 1)
531       for (size_t I = 0; I < HotTargets.size(); ++I)
532         outs() << "BOLT-INFO: HotTarget[" << I << "] = (" << HotTargets[I].first
533                << ", " << HotTargets[I].second << ")\n";
534 
535     // Recompute hottest targets, now discriminating which index is hot
536     // NOTE: This is a tradeoff. On one hand, we get index information. On the
537     // other hand, info coming from the memory profile is much less accurate
538     // than LBRs. So we may actually end up working with more coarse
539     // profile granularity in exchange for information about indices.
540     std::vector<Callsite> NewTargets;
541     std::map<const MCSymbol *, uint32_t> IndicesPerTarget;
542     uint64_t TotalMemAccesses = 0;
543     for (size_t I = 0; I < HotTargets.size(); ++I) {
544       const uint64_t TargetIndex = findTargetsIndex(HotTargets[I].second);
545       ++IndicesPerTarget[Targets[TargetIndex].To.Sym];
546       TotalMemAccesses += HotTargets[I].first;
547     }
548     uint64_t RemainingMemAccesses = TotalMemAccesses;
549     const size_t TopN =
550         opts::ICPJumpTablesTopN ? opts::ICPJumpTablesTopN : opts::ICPTopN;
551     size_t I = 0;
552     for (; I < HotTargets.size(); ++I) {
553       const uint64_t MemAccesses = HotTargets[I].first;
554       if (100 * MemAccesses <
555           TotalMemAccesses * opts::ICPJTTotalPercentThreshold)
556         break;
557       if (100 * MemAccesses <
558           RemainingMemAccesses * opts::ICPJTRemainingPercentThreshold)
559         break;
560       if (TopN && I >= TopN)
561         break;
562       RemainingMemAccesses -= MemAccesses;
563 
564       const uint64_t JTIndex = HotTargets[I].second;
565       Callsite &Target = Targets[findTargetsIndex(JTIndex)];
566 
567       NewTargets.push_back(Target);
568       std::vector<uint64_t>({JTIndex}).swap(NewTargets.back().JTIndices);
569       Target.JTIndices.erase(std::remove(Target.JTIndices.begin(),
570                                          Target.JTIndices.end(), JTIndex),
571                              Target.JTIndices.end());
572 
573       // Keep fixCFG counts sane if more indices use this same target later
574       assert(IndicesPerTarget[Target.To.Sym] > 0 && "wrong map");
575       NewTargets.back().Branches =
576           Target.Branches / IndicesPerTarget[Target.To.Sym];
577       NewTargets.back().Mispreds =
578           Target.Mispreds / IndicesPerTarget[Target.To.Sym];
579       assert(Target.Branches >= NewTargets.back().Branches);
580       assert(Target.Mispreds >= NewTargets.back().Mispreds);
581       Target.Branches -= NewTargets.back().Branches;
582       Target.Mispreds -= NewTargets.back().Mispreds;
583     }
584     std::copy(Targets.begin(), Targets.end(), std::back_inserter(NewTargets));
585     std::swap(NewTargets, Targets);
586     N = I;
587 
588     if (N == 0 && opts::Verbosity >= 1) {
589       outs() << "BOLT-INFO: ICP failed in " << *BB.getFunction() << " in "
590              << BB.getName() << ": failed to meet thresholds after memory "
591              << "profile data was loaded.\n";
592       return SymTargets;
593     }
594   }
595 
596   for (size_t I = 0, TgtIdx = 0; I < N; ++TgtIdx) {
597     Callsite &Target = Targets[TgtIdx];
598     assert(Target.To.Sym && "All ICP targets must be to known symbols");
599     assert(!Target.JTIndices.empty() && "Jump tables must have indices");
600     for (uint64_t Idx : Target.JTIndices) {
601       SymTargets.emplace_back(Target.To.Sym, Idx);
602       ++I;
603     }
604   }
605 
606   return SymTargets;
607 }
608 
609 IndirectCallPromotion::MethodInfoType IndirectCallPromotion::maybeGetVtableSyms(
610     BinaryBasicBlock &BB, MCInst &Inst,
611     const SymTargetsType &SymTargets) const {
612   BinaryFunction &Function = *BB.getFunction();
613   BinaryContext &BC = Function.getBinaryContext();
614   std::vector<std::pair<MCSymbol *, uint64_t>> VtableSyms;
615   std::vector<MCInst *> MethodFetchInsns;
616   unsigned VtableReg, MethodReg;
617   uint64_t MethodOffset;
618 
619   assert(!Function.getJumpTable(Inst) &&
620          "Can't get vtable addrs for jump tables.");
621 
622   if (!Function.hasMemoryProfile() || !opts::EliminateLoads)
623     return MethodInfoType();
624 
625   MutableArrayRef<MCInst> Insts(&BB.front(), &Inst + 1);
626   if (!BC.MIB->analyzeVirtualMethodCall(Insts.begin(), Insts.end(),
627                                         MethodFetchInsns, VtableReg, MethodReg,
628                                         MethodOffset)) {
629     DEBUG_VERBOSE(
630         1, dbgs() << "BOLT-INFO: ICP unable to analyze method call in "
631                   << Function << " at @ " << (&Inst - &BB.front()) << "\n");
632     return MethodInfoType();
633   }
634 
635   ++TotalMethodLoadEliminationCandidates;
636 
637   DEBUG_VERBOSE(1, {
638     dbgs() << "BOLT-INFO: ICP found virtual method call in " << Function
639            << " at @ " << (&Inst - &BB.front()) << "\n";
640     dbgs() << "BOLT-INFO: ICP method fetch instructions:\n";
641     for (MCInst *Inst : MethodFetchInsns)
642       BC.printInstruction(dbgs(), *Inst, 0, &Function);
643 
644     if (MethodFetchInsns.back() != &Inst)
645       BC.printInstruction(dbgs(), Inst, 0, &Function);
646   });
647 
648   // Try to get value profiling data for the method load instruction.
649   auto ErrorOrMemAccesssProfile =
650       BC.MIB->tryGetAnnotationAs<MemoryAccessProfile>(*MethodFetchInsns.back(),
651                                                       "MemoryAccessProfile");
652   if (!ErrorOrMemAccesssProfile) {
653     DEBUG_VERBOSE(1, dbgs()
654                          << "BOLT-INFO: ICP no memory profiling data found\n");
655     return MethodInfoType();
656   }
657   MemoryAccessProfile &MemAccessProfile = ErrorOrMemAccesssProfile.get();
658 
659   // Find the vtable that each method belongs to.
660   std::map<const MCSymbol *, uint64_t> MethodToVtable;
661 
662   for (const AddressAccess &AccessInfo : MemAccessProfile.AddressAccessInfo) {
663     uint64_t Address = AccessInfo.Offset;
664     if (AccessInfo.MemoryObject)
665       Address += AccessInfo.MemoryObject->getAddress();
666 
667     // Ignore bogus data.
668     if (!Address)
669       continue;
670 
671     const uint64_t VtableBase = Address - MethodOffset;
672 
673     DEBUG_VERBOSE(1, dbgs() << "BOLT-INFO: ICP vtable = "
674                             << Twine::utohexstr(VtableBase) << "+"
675                             << MethodOffset << "/" << AccessInfo.Count << "\n");
676 
677     if (ErrorOr<uint64_t> MethodAddr = BC.getPointerAtAddress(Address)) {
678       BinaryData *MethodBD = BC.getBinaryDataAtAddress(MethodAddr.get());
679       if (!MethodBD) // skip unknown methods
680         continue;
681       MCSymbol *MethodSym = MethodBD->getSymbol();
682       MethodToVtable[MethodSym] = VtableBase;
683       DEBUG_VERBOSE(1, {
684         const BinaryFunction *Method = BC.getFunctionForSymbol(MethodSym);
685         dbgs() << "BOLT-INFO: ICP found method = "
686                << Twine::utohexstr(MethodAddr.get()) << "/"
687                << (Method ? Method->getPrintName() : "") << "\n";
688       });
689     }
690   }
691 
692   // Find the vtable for each target symbol.
693   for (size_t I = 0; I < SymTargets.size(); ++I) {
694     auto Itr = MethodToVtable.find(SymTargets[I].first);
695     if (Itr != MethodToVtable.end()) {
696       if (BinaryData *BD = BC.getBinaryDataContainingAddress(Itr->second)) {
697         const uint64_t Addend = Itr->second - BD->getAddress();
698         VtableSyms.emplace_back(BD->getSymbol(), Addend);
699         continue;
700       }
701     }
702     // Give up if we can't find the vtable for a method.
703     DEBUG_VERBOSE(1, dbgs() << "BOLT-INFO: ICP can't find vtable for "
704                             << SymTargets[I].first->getName() << "\n");
705     return MethodInfoType();
706   }
707 
708   // Make sure the vtable reg is not clobbered by the argument passing code
709   if (VtableReg != MethodReg) {
710     for (MCInst *CurInst = MethodFetchInsns.front(); CurInst < &Inst;
711          ++CurInst) {
712       const MCInstrDesc &InstrInfo = BC.MII->get(CurInst->getOpcode());
713       if (InstrInfo.hasDefOfPhysReg(*CurInst, VtableReg, *BC.MRI))
714         return MethodInfoType();
715     }
716   }
717 
718   return MethodInfoType(VtableSyms, MethodFetchInsns);
719 }
720 
721 std::vector<std::unique_ptr<BinaryBasicBlock>>
722 IndirectCallPromotion::rewriteCall(
723     BinaryBasicBlock &IndCallBlock, const MCInst &CallInst,
724     MCPlusBuilder::BlocksVectorTy &&ICPcode,
725     const std::vector<MCInst *> &MethodFetchInsns) const {
726   BinaryFunction &Function = *IndCallBlock.getFunction();
727   MCPlusBuilder *MIB = Function.getBinaryContext().MIB.get();
728 
729   // Create new basic blocks with correct code in each one first.
730   std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs;
731   const bool IsTailCallOrJT =
732       (MIB->isTailCall(CallInst) || Function.getJumpTable(CallInst));
733 
734   // Move instructions from the tail of the original call block
735   // to the merge block.
736 
737   // Remember any pseudo instructions following a tail call.  These
738   // must be preserved and moved to the original block.
739   InstructionListType TailInsts;
740   const MCInst *TailInst = &CallInst;
741   if (IsTailCallOrJT)
742     while (TailInst + 1 < &(*IndCallBlock.end()) &&
743            MIB->isPseudo(*(TailInst + 1)))
744       TailInsts.push_back(*++TailInst);
745 
746   InstructionListType MovedInst = IndCallBlock.splitInstructions(&CallInst);
747   // Link new BBs to the original input offset of the BB where the indirect
748   // call site is, so we can map samples recorded in new BBs back to the
749   // original BB seen in the input binary (if using BAT)
750   const uint32_t OrigOffset = IndCallBlock.getInputOffset();
751 
752   IndCallBlock.eraseInstructions(MethodFetchInsns.begin(),
753                                  MethodFetchInsns.end());
754   if (IndCallBlock.empty() ||
755       (!MethodFetchInsns.empty() && MethodFetchInsns.back() == &CallInst))
756     IndCallBlock.addInstructions(ICPcode.front().second.begin(),
757                                  ICPcode.front().second.end());
758   else
759     IndCallBlock.replaceInstruction(std::prev(IndCallBlock.end()),
760                                     ICPcode.front().second);
761   IndCallBlock.addInstructions(TailInsts.begin(), TailInsts.end());
762 
763   for (auto Itr = ICPcode.begin() + 1; Itr != ICPcode.end(); ++Itr) {
764     MCSymbol *&Sym = Itr->first;
765     InstructionListType &Insts = Itr->second;
766     assert(Sym);
767     std::unique_ptr<BinaryBasicBlock> TBB =
768         Function.createBasicBlock(OrigOffset, Sym);
769     for (MCInst &Inst : Insts) // sanitize new instructions.
770       if (MIB->isCall(Inst))
771         MIB->removeAnnotation(Inst, "CallProfile");
772     TBB->addInstructions(Insts.begin(), Insts.end());
773     NewBBs.emplace_back(std::move(TBB));
774   }
775 
776   // Move tail of instructions from after the original call to
777   // the merge block.
778   if (!IsTailCallOrJT)
779     NewBBs.back()->addInstructions(MovedInst.begin(), MovedInst.end());
780 
781   return NewBBs;
782 }
783 
/// Wire the blocks produced by rewriteCall() into the CFG of
/// \p IndCallBlock's function and distribute execution counts over them.
///
/// Each target's branch/mispredict counts are split evenly over its
/// jump-table entries (at least one entry per target). BBI keeps those raw
/// per-entry counts (rounded up); ScaledBBI scales them to the known
/// execution count of \p IndCallBlock.
///
/// Jump tables: the last new block takes over all original successors of
/// \p IndCallBlock. \p IndCallBlock and every other new block get a "taken"
/// edge to their target block and a fall-through edge to the next new block.
/// The remaining indirect jump's branch info toward each promoted target is
/// reduced by the promoted count, clamped at zero.
///
/// Calls: unless \p IsTailCall, the last new block becomes the merge block.
/// It takes over all original successors and is moved to position 1 of
/// \p NewBBs before insertion.
///
/// \returns the merge block, or nullptr for tail calls and jump tables.
BinaryBasicBlock *
IndirectCallPromotion::fixCFG(BinaryBasicBlock &IndCallBlock,
                              const bool IsTailCall, const bool IsJumpTable,
                              IndirectCallPromotion::BasicBlocksVector &&NewBBs,
                              const std::vector<Callsite> &Targets) const {
  BinaryFunction &Function = *IndCallBlock.getFunction();
  using BinaryBranchInfo = BinaryBasicBlock::BinaryBranchInfo;
  BinaryBasicBlock *MergeBlock = nullptr;

  // Scale indirect call counts to the execution count of the original
  // basic block containing the indirect call.
  uint64_t TotalCount = IndCallBlock.getKnownExecutionCount();
  uint64_t TotalIndirectBranches = 0;
  for (const Callsite &Target : Targets)
    TotalIndirectBranches += Target.Branches;
  // Avoid dividing by zero when no target has recorded branches.
  if (TotalIndirectBranches == 0)
    TotalIndirectBranches = 1;
  // BBI: raw per-entry counts (ceil division over NumEntries).
  // ScaledBBI: per-entry counts proportional to TotalCount.
  BinaryBasicBlock::BranchInfoType BBI;
  BinaryBasicBlock::BranchInfoType ScaledBBI;
  for (const Callsite &Target : Targets) {
    const size_t NumEntries =
        std::max(static_cast<std::size_t>(1UL), Target.JTIndices.size());
    for (size_t I = 0; I < NumEntries; ++I) {
      BBI.push_back(
          BinaryBranchInfo{(Target.Branches + NumEntries - 1) / NumEntries,
                           (Target.Mispreds + NumEntries - 1) / NumEntries});
      ScaledBBI.push_back(
          BinaryBranchInfo{uint64_t(TotalCount * Target.Branches /
                                    (NumEntries * TotalIndirectBranches)),
                           uint64_t(TotalCount * Target.Mispreds /
                                    (NumEntries * TotalIndirectBranches))});
    }
  }

  if (IsJumpTable) {
    // The last new block keeps the (remaining) indirect jump and inherits
    // every original successor of the call block.
    BinaryBasicBlock *NewIndCallBlock = NewBBs.back().get();
    IndCallBlock.moveAllSuccessorsTo(NewIndCallBlock);

    // One target symbol per jump-table entry, in the same order as BBI.
    std::vector<MCSymbol *> SymTargets;
    for (const Callsite &Target : Targets) {
      const size_t NumEntries =
          std::max(static_cast<std::size_t>(1UL), Target.JTIndices.size());
      for (size_t I = 0; I < NumEntries; ++I)
        SymTargets.push_back(Target.To.Sym);
    }
    assert(SymTargets.size() > NewBBs.size() - 1 &&
           "There must be a target symbol associated with each new BB.");

    for (uint64_t I = 0; I < NewBBs.size(); ++I) {
      // Source of edge pair I: the original block first, then each new block.
      BinaryBasicBlock *SourceBB = I ? NewBBs[I - 1].get() : &IndCallBlock;
      SourceBB->setExecutionCount(TotalCount);

      BinaryBasicBlock *TargetBB =
          Function.getBasicBlockForLabel(SymTargets[I]);
      SourceBB->addSuccessor(TargetBB, ScaledBBI[I]); // taken

      TotalCount -= ScaledBBI[I].Count;
      SourceBB->addSuccessor(NewBBs[I].get(), TotalCount); // fall-through

      // Update branch info for the indirect jump.
      // Counts are unsigned, so subtraction is clamped at zero.
      BinaryBasicBlock::BinaryBranchInfo &BranchInfo =
          NewIndCallBlock->getBranchInfo(*TargetBB);
      if (BranchInfo.Count > BBI[I].Count)
        BranchInfo.Count -= BBI[I].Count;
      else
        BranchInfo.Count = 0;

      if (BranchInfo.MispredictedCount > BBI[I].MispredictedCount)
        BranchInfo.MispredictedCount -= BBI[I].MispredictedCount;
      else
        BranchInfo.MispredictedCount = 0;
    }
  } else {
    // An even number of new blocks is only expected for tail calls, whose
    // call block has no successors.
    assert(NewBBs.size() >= 2);
    assert(NewBBs.size() % 2 == 1 || IndCallBlock.succ_empty());
    assert(NewBBs.size() % 2 == 1 || IsTailCall);

    // Consume the next scaled count: subtract it from the running TotalCount
    // and advance to the following entry.
    auto ScaledBI = ScaledBBI.begin();
    auto updateCurrentBranchInfo = [&] {
      assert(ScaledBI != ScaledBBI.end());
      TotalCount -= ScaledBI->Count;
      ++ScaledBI;
    };

    if (!IsTailCall) {
      MergeBlock = NewBBs.back().get();
      IndCallBlock.moveAllSuccessorsTo(MergeBlock);
    }

    // Fix up successors and execution counts.
    updateCurrentBranchInfo();
    IndCallBlock.addSuccessor(NewBBs[1].get(), TotalCount);
    IndCallBlock.addSuccessor(NewBBs[0].get(), ScaledBBI[0]);

    // Even-indexed new blocks only get an edge to the merge block (if any).
    // Odd-indexed ones get two successors: NewBBs[I + 2] with the remaining
    // count and NewBBs[I + 1] with the next target's scaled count.
    // NOTE(review): presumably promoted-call vs. compare blocks; confirm
    // against MCPlusBuilder::indirectCallPromotion's block layout.
    const size_t Adj = IsTailCall ? 1 : 2;
    for (size_t I = 0; I < NewBBs.size() - Adj; ++I) {
      assert(TotalCount <= IndCallBlock.getExecutionCount() ||
             TotalCount <= uint64_t(TotalIndirectBranches));
      uint64_t ExecCount = ScaledBBI[(I + 1) / 2].Count;
      if (I % 2 == 0) {
        if (MergeBlock)
          NewBBs[I]->addSuccessor(MergeBlock, ScaledBBI[(I + 1) / 2].Count);
      } else {
        assert(I + 2 < NewBBs.size());
        updateCurrentBranchInfo();
        NewBBs[I]->addSuccessor(NewBBs[I + 2].get(), TotalCount);
        NewBBs[I]->addSuccessor(NewBBs[I + 1].get(), ScaledBBI[(I + 1) / 2]);
        ExecCount += TotalCount;
      }
      NewBBs[I]->setExecutionCount(ExecCount);
    }

    if (MergeBlock) {
      // Arrange for the MergeBlock to be the fallthrough for the first
      // promoted call block.
      std::unique_ptr<BinaryBasicBlock> MBPtr;
      std::swap(MBPtr, NewBBs.back());
      NewBBs.pop_back();
      NewBBs.emplace(NewBBs.begin() + 1, std::move(MBPtr));
      // TODO: is COUNT_FALLTHROUGH_EDGE the right thing here?
      NewBBs.back()->addSuccessor(MergeBlock, TotalCount); // uncond branch
    }
  }

  // Update the execution count.
  // The last block in NewBBs receives whatever count was not attributed to
  // a promoted target.
  NewBBs.back()->setExecutionCount(TotalCount);

  // Update BB and BB layout.
  Function.insertBasicBlocks(&IndCallBlock, std::move(NewBBs));
  assert(Function.validateCFG());

  return MergeBlock;
}
916 }
917 
918 size_t IndirectCallPromotion::canPromoteCallsite(
919     const BinaryBasicBlock &BB, const MCInst &Inst,
920     const std::vector<Callsite> &Targets, uint64_t NumCalls) {
921   BinaryFunction *BF = BB.getFunction();
922   const BinaryContext &BC = BF->getBinaryContext();
923 
924   if (BB.getKnownExecutionCount() < opts::ExecutionCountThreshold)
925     return 0;
926 
927   const bool IsJumpTable = BF->getJumpTable(Inst);
928 
929   auto computeStats = [&](size_t N) {
930     for (size_t I = 0; I < N; ++I)
931       if (IsJumpTable)
932         TotalNumFrequentJmps += Targets[I].Branches;
933       else
934         TotalNumFrequentCalls += Targets[I].Branches;
935   };
936 
937   // If we have no targets (or no calls), skip this callsite.
938   if (Targets.empty() || !NumCalls) {
939     if (opts::Verbosity >= 1) {
940       const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
941       outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx << " in "
942              << BB.getName() << ", calls = " << NumCalls
943              << ", targets empty or NumCalls == 0.\n";
944     }
945     return 0;
946   }
947 
948   size_t TopN = opts::ICPTopN;
949   if (IsJumpTable)
950     TopN = opts::ICPJumpTablesTopN ? opts::ICPJumpTablesTopN : TopN;
951   else
952     TopN = opts::ICPCallsTopN ? opts::ICPCallsTopN : TopN;
953 
954   const size_t TrialN = TopN ? std::min(TopN, Targets.size()) : Targets.size();
955 
956   if (opts::ICPTopCallsites > 0) {
957     if (!BC.MIB->hasAnnotation(Inst, "DoICP"))
958       return 0;
959   }
960 
961   // Pick the top N targets.
962   uint64_t TotalMispredictsTopN = 0;
963   size_t N = 0;
964 
965   if (opts::ICPUseMispredicts &&
966       (!IsJumpTable || opts::ICPJumpTablesByTarget)) {
967     // Count total number of mispredictions for (at most) the top N targets.
968     // We may choose a smaller N (TrialN vs. N) if the frequency threshold
969     // is exceeded by fewer targets.
970     double Threshold = double(opts::ICPMispredictThreshold);
971     for (size_t I = 0; I < TrialN && Threshold > 0; ++I, ++N) {
972       Threshold -= (100.0 * Targets[I].Mispreds) / NumCalls;
973       TotalMispredictsTopN += Targets[I].Mispreds;
974     }
975     computeStats(N);
976 
977     // Compute the misprediction frequency of the top N call targets.  If this
978     // frequency is greater than the threshold, we should try ICP on this
979     // callsite.
980     const double TopNFrequency = (100.0 * TotalMispredictsTopN) / NumCalls;
981     if (TopNFrequency == 0 || TopNFrequency < opts::ICPMispredictThreshold) {
982       if (opts::Verbosity >= 1) {
983         const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
984         outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
985                << " in " << BB.getName() << ", calls = " << NumCalls
986                << ", top N mis. frequency " << format("%.1f", TopNFrequency)
987                << "% < " << opts::ICPMispredictThreshold << "%\n";
988       }
989       return 0;
990     }
991   } else {
992     size_t MaxTargets = 0;
993 
994     // Count total number of calls for (at most) the top N targets.
995     // We may choose a smaller N (TrialN vs. N) if the frequency threshold
996     // is exceeded by fewer targets.
997     const unsigned TotalThreshold = IsJumpTable
998                                         ? opts::ICPJTTotalPercentThreshold
999                                         : opts::ICPCallsTotalPercentThreshold;
1000     const unsigned RemainingThreshold =
1001         IsJumpTable ? opts::ICPJTRemainingPercentThreshold
1002                     : opts::ICPCallsRemainingPercentThreshold;
1003     uint64_t NumRemainingCalls = NumCalls;
1004     for (size_t I = 0; I < TrialN; ++I, ++MaxTargets) {
1005       if (100 * Targets[I].Branches < NumCalls * TotalThreshold)
1006         break;
1007       if (100 * Targets[I].Branches < NumRemainingCalls * RemainingThreshold)
1008         break;
1009       if (N + (Targets[I].JTIndices.empty() ? 1 : Targets[I].JTIndices.size()) >
1010           TrialN)
1011         break;
1012       TotalMispredictsTopN += Targets[I].Mispreds;
1013       NumRemainingCalls -= Targets[I].Branches;
1014       N += Targets[I].JTIndices.empty() ? 1 : Targets[I].JTIndices.size();
1015     }
1016     computeStats(MaxTargets);
1017 
1018     // Don't check misprediction frequency for jump tables -- we don't really
1019     // care as long as we are saving loads from the jump table.
1020     if (!IsJumpTable || opts::ICPJumpTablesByTarget) {
1021       // Compute the misprediction frequency of the top N call targets.  If
1022       // this frequency is less than the threshold, we should skip ICP at
1023       // this callsite.
1024       const double TopNMispredictFrequency =
1025           (100.0 * TotalMispredictsTopN) / NumCalls;
1026 
1027       if (TopNMispredictFrequency < opts::ICPMispredictThreshold) {
1028         if (opts::Verbosity >= 1) {
1029           const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
1030           outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
1031                  << " in " << BB.getName() << ", calls = " << NumCalls
1032                  << ", top N mispredict frequency "
1033                  << format("%.1f", TopNMispredictFrequency) << "% < "
1034                  << opts::ICPMispredictThreshold << "%\n";
1035         }
1036         return 0;
1037       }
1038     }
1039   }
1040 
1041   // Filter functions that can have ICP applied (for debugging)
1042   if (!opts::ICPFuncsList.empty()) {
1043     for (std::string &Name : opts::ICPFuncsList)
1044       if (BF->hasName(Name))
1045         return N;
1046     return 0;
1047   }
1048 
1049   return N;
1050 }
1051 
1052 void IndirectCallPromotion::printCallsiteInfo(
1053     const BinaryBasicBlock &BB, const MCInst &Inst,
1054     const std::vector<Callsite> &Targets, const size_t N,
1055     uint64_t NumCalls) const {
1056   BinaryContext &BC = BB.getFunction()->getBinaryContext();
1057   const bool IsTailCall = BC.MIB->isTailCall(Inst);
1058   const bool IsJumpTable = BB.getFunction()->getJumpTable(Inst);
1059   const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
1060 
1061   outs() << "BOLT-INFO: ICP candidate branch info: " << *BB.getFunction()
1062          << " @ " << InstIdx << " in " << BB.getName()
1063          << " -> calls = " << NumCalls
1064          << (IsTailCall ? " (tail)" : (IsJumpTable ? " (jump table)" : ""))
1065          << "\n";
1066   for (size_t I = 0; I < N; I++) {
1067     const double Frequency = 100.0 * Targets[I].Branches / NumCalls;
1068     const double MisFrequency = 100.0 * Targets[I].Mispreds / NumCalls;
1069     outs() << "BOLT-INFO:   ";
1070     if (Targets[I].To.Sym)
1071       outs() << Targets[I].To.Sym->getName();
1072     else
1073       outs() << Targets[I].To.Addr;
1074     outs() << ", calls = " << Targets[I].Branches
1075            << ", mispreds = " << Targets[I].Mispreds
1076            << ", taken freq = " << format("%.1f", Frequency) << "%"
1077            << ", mis. freq = " << format("%.1f", MisFrequency) << "%";
1078     bool First = true;
1079     for (uint64_t JTIndex : Targets[I].JTIndices) {
1080       outs() << (First ? ", indices = " : ", ") << JTIndex;
1081       First = false;
1082     }
1083     outs() << "\n";
1084   }
1085 
1086   LLVM_DEBUG({
1087     dbgs() << "BOLT-INFO: ICP original call instruction:";
1088     BC.printInstruction(dbgs(), Inst, Targets[0].From.Addr, nullptr, true);
1089   });
1090 }
1091 
1092 void IndirectCallPromotion::runOnFunctions(BinaryContext &BC) {
1093   if (opts::ICP == ICP_NONE)
1094     return;
1095 
1096   auto &BFs = BC.getBinaryFunctions();
1097 
1098   const bool OptimizeCalls = (opts::ICP == ICP_CALLS || opts::ICP == ICP_ALL);
1099   const bool OptimizeJumpTables =
1100       (opts::ICP == ICP_JUMP_TABLES || opts::ICP == ICP_ALL);
1101 
1102   std::unique_ptr<RegAnalysis> RA;
1103   std::unique_ptr<BinaryFunctionCallGraph> CG;
1104   if (OptimizeJumpTables) {
1105     CG.reset(new BinaryFunctionCallGraph(buildCallGraph(BC)));
1106     RA.reset(new RegAnalysis(BC, &BFs, &*CG));
1107   }
1108 
1109   // If icp-top-callsites is enabled, compute the total number of indirect
1110   // calls and then optimize the hottest callsites that contribute to that
1111   // total.
1112   SetVector<BinaryFunction *> Functions;
1113   if (opts::ICPTopCallsites == 0) {
1114     for (auto &KV : BFs)
1115       Functions.insert(&KV.second);
1116   } else {
1117     using IndirectCallsite = std::tuple<uint64_t, MCInst *, BinaryFunction *>;
1118     std::vector<IndirectCallsite> IndirectCalls;
1119     size_t TotalIndirectCalls = 0;
1120 
1121     // Find all the indirect callsites.
1122     for (auto &BFIt : BFs) {
1123       BinaryFunction &Function = BFIt.second;
1124 
1125       if (!Function.isSimple() || Function.isIgnored() ||
1126           !Function.hasProfile())
1127         continue;
1128 
1129       const bool HasLayout = !Function.layout_empty();
1130 
1131       for (BinaryBasicBlock &BB : Function) {
1132         if (HasLayout && Function.isSplit() && BB.isCold())
1133           continue;
1134 
1135         for (MCInst &Inst : BB) {
1136           const bool IsJumpTable = Function.getJumpTable(Inst);
1137           const bool HasIndirectCallProfile =
1138               BC.MIB->hasAnnotation(Inst, "CallProfile");
1139           const bool IsDirectCall =
1140               (BC.MIB->isCall(Inst) && BC.MIB->getTargetSymbol(Inst, 0));
1141 
1142           if (!IsDirectCall &&
1143               ((HasIndirectCallProfile && !IsJumpTable && OptimizeCalls) ||
1144                (IsJumpTable && OptimizeJumpTables))) {
1145             uint64_t NumCalls = 0;
1146             for (const Callsite &BInfo : getCallTargets(BB, Inst))
1147               NumCalls += BInfo.Branches;
1148             IndirectCalls.push_back(
1149                 std::make_tuple(NumCalls, &Inst, &Function));
1150             TotalIndirectCalls += NumCalls;
1151           }
1152         }
1153       }
1154     }
1155 
1156     // Sort callsites by execution count.
1157     std::sort(IndirectCalls.rbegin(), IndirectCalls.rend());
1158 
1159     // Find callsites that contribute to the top "opts::ICPTopCallsites"%
1160     // number of calls.
1161     const float TopPerc = opts::ICPTopCallsites / 100.0f;
1162     int64_t MaxCalls = TotalIndirectCalls * TopPerc;
1163     uint64_t LastFreq = std::numeric_limits<uint64_t>::max();
1164     size_t Num = 0;
1165     for (const IndirectCallsite &IC : IndirectCalls) {
1166       const uint64_t CurFreq = std::get<0>(IC);
1167       // Once we decide to stop, include at least all branches that share the
1168       // same frequency of the last one to avoid non-deterministic behavior
1169       // (e.g. turning on/off ICP depending on the order of functions)
1170       if (MaxCalls <= 0 && CurFreq != LastFreq)
1171         break;
1172       MaxCalls -= CurFreq;
1173       LastFreq = CurFreq;
1174       BC.MIB->addAnnotation(*std::get<1>(IC), "DoICP", true);
1175       Functions.insert(std::get<2>(IC));
1176       ++Num;
1177     }
1178     outs() << "BOLT-INFO: ICP Total indirect calls = " << TotalIndirectCalls
1179            << ", " << Num << " callsites cover " << opts::ICPTopCallsites
1180            << "% of all indirect calls\n";
1181   }
1182 
1183   for (BinaryFunction *FuncPtr : Functions) {
1184     BinaryFunction &Function = *FuncPtr;
1185 
1186     if (!Function.isSimple() || Function.isIgnored() || !Function.hasProfile())
1187       continue;
1188 
1189     const bool HasLayout = !Function.layout_empty();
1190 
1191     // Total number of indirect calls issued from the current Function.
1192     // (a fraction of TotalIndirectCalls)
1193     uint64_t FuncTotalIndirectCalls = 0;
1194     uint64_t FuncTotalIndirectJmps = 0;
1195 
1196     std::vector<BinaryBasicBlock *> BBs;
1197     for (BinaryBasicBlock &BB : Function) {
1198       // Skip indirect calls in cold blocks.
1199       if (!HasLayout || !Function.isSplit() || !BB.isCold())
1200         BBs.push_back(&BB);
1201     }
1202     if (BBs.empty())
1203       continue;
1204 
1205     DataflowInfoManager Info(Function, RA.get(), nullptr);
1206     while (!BBs.empty()) {
1207       BinaryBasicBlock *BB = BBs.back();
1208       BBs.pop_back();
1209 
1210       for (unsigned Idx = 0; Idx < BB->size(); ++Idx) {
1211         MCInst &Inst = BB->getInstructionAtIndex(Idx);
1212         const ptrdiff_t InstIdx = &Inst - &(*BB->begin());
1213         const bool IsTailCall = BC.MIB->isTailCall(Inst);
1214         const bool HasIndirectCallProfile =
1215             BC.MIB->hasAnnotation(Inst, "CallProfile");
1216         const bool IsJumpTable = Function.getJumpTable(Inst);
1217 
1218         if (BC.MIB->isCall(Inst))
1219           TotalCalls += BB->getKnownExecutionCount();
1220 
1221         if (IsJumpTable && !OptimizeJumpTables)
1222           continue;
1223 
1224         if (!IsJumpTable && (!HasIndirectCallProfile || !OptimizeCalls))
1225           continue;
1226 
1227         // Ignore direct calls.
1228         if (BC.MIB->isCall(Inst) && BC.MIB->getTargetSymbol(Inst, 0))
1229           continue;
1230 
1231         assert((BC.MIB->isCall(Inst) || BC.MIB->isIndirectBranch(Inst)) &&
1232                "expected a call or an indirect jump instruction");
1233 
1234         if (IsJumpTable)
1235           ++TotalJumpTableCallsites;
1236         else
1237           ++TotalIndirectCallsites;
1238 
1239         std::vector<Callsite> Targets = getCallTargets(*BB, Inst);
1240 
1241         // Compute the total number of calls from this particular callsite.
1242         uint64_t NumCalls = 0;
1243         for (const Callsite &BInfo : Targets)
1244           NumCalls += BInfo.Branches;
1245         if (!IsJumpTable)
1246           FuncTotalIndirectCalls += NumCalls;
1247         else
1248           FuncTotalIndirectJmps += NumCalls;
1249 
1250         // If FLAGS regs is alive after this jmp site, do not try
1251         // promoting because we will clobber FLAGS.
1252         if (IsJumpTable) {
1253           ErrorOr<const BitVector &> State =
1254               Info.getLivenessAnalysis().getStateBefore(Inst);
1255           if (!State || (State && (*State)[BC.MIB->getFlagsReg()])) {
1256             if (opts::Verbosity >= 1)
1257               outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
1258                      << InstIdx << " in " << BB->getName()
1259                      << ", calls = " << NumCalls
1260                      << (State ? ", cannot clobber flags reg.\n"
1261                                : ", no liveness data available.\n");
1262             continue;
1263           }
1264         }
1265 
1266         // Should this callsite be optimized?  Return the number of targets
1267         // to use when promoting this call.  A value of zero means to skip
1268         // this callsite.
1269         size_t N = canPromoteCallsite(*BB, Inst, Targets, NumCalls);
1270 
1271         // If it is a jump table and it failed to meet our initial threshold,
1272         // proceed to findCallTargetSymbols -- it may reevaluate N if
1273         // memory profile is present
1274         if (!N && !IsJumpTable)
1275           continue;
1276 
1277         if (opts::Verbosity >= 1)
1278           printCallsiteInfo(*BB, Inst, Targets, N, NumCalls);
1279 
1280         // Find MCSymbols or absolute addresses for each call target.
1281         MCInst *TargetFetchInst = nullptr;
1282         const SymTargetsType SymTargets =
1283             findCallTargetSymbols(Targets, N, *BB, Inst, TargetFetchInst);
1284 
1285         // findCallTargetSymbols may have changed N if mem profile is available
1286         // for jump tables
1287         if (!N)
1288           continue;
1289 
1290         LLVM_DEBUG(printDecision(dbgs(), Targets, N));
1291 
1292         // If we can't resolve any of the target symbols, punt on this callsite.
1293         // TODO: can this ever happen?
1294         if (SymTargets.size() < N) {
1295           const size_t LastTarget = SymTargets.size();
1296           if (opts::Verbosity >= 1)
1297             outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
1298                    << InstIdx << " in " << BB->getName()
1299                    << ", calls = " << NumCalls
1300                    << ", ICP failed to find target symbol for "
1301                    << Targets[LastTarget].To.Sym->getName() << "\n";
1302           continue;
1303         }
1304 
1305         MethodInfoType MethodInfo;
1306 
1307         if (!IsJumpTable) {
1308           MethodInfo = maybeGetVtableSyms(*BB, Inst, SymTargets);
1309           TotalMethodLoadsEliminated += MethodInfo.first.empty() ? 0 : 1;
1310           LLVM_DEBUG(dbgs()
1311                      << "BOLT-INFO: ICP "
1312                      << (!MethodInfo.first.empty() ? "found" : "did not find")
1313                      << " vtables for all methods.\n");
1314         } else if (TargetFetchInst) {
1315           ++TotalIndexBasedJumps;
1316           MethodInfo.second.push_back(TargetFetchInst);
1317         }
1318 
1319         // Generate new promoted call code for this callsite.
1320         MCPlusBuilder::BlocksVectorTy ICPcode =
1321             (IsJumpTable && !opts::ICPJumpTablesByTarget)
1322                 ? BC.MIB->jumpTablePromotion(Inst, SymTargets,
1323                                              MethodInfo.second, BC.Ctx.get())
1324                 : BC.MIB->indirectCallPromotion(
1325                       Inst, SymTargets, MethodInfo.first, MethodInfo.second,
1326                       opts::ICPOldCodeSequence, BC.Ctx.get());
1327 
1328         if (ICPcode.empty()) {
1329           if (opts::Verbosity >= 1)
1330             outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
1331                    << InstIdx << " in " << BB->getName()
1332                    << ", calls = " << NumCalls
1333                    << ", unable to generate promoted call code.\n";
1334           continue;
1335         }
1336 
1337         LLVM_DEBUG({
1338           uint64_t Offset = Targets[0].From.Addr;
1339           dbgs() << "BOLT-INFO: ICP indirect call code:\n";
1340           for (const auto &entry : ICPcode) {
1341             const MCSymbol *const &Sym = entry.first;
1342             const InstructionListType &Insts = entry.second;
1343             if (Sym)
1344               dbgs() << Sym->getName() << ":\n";
1345             Offset = BC.printInstructions(dbgs(), Insts.begin(), Insts.end(),
1346                                           Offset);
1347           }
1348           dbgs() << "---------------------------------------------------\n";
1349         });
1350 
1351         // Rewrite the CFG with the newly generated ICP code.
1352         std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs =
1353             rewriteCall(*BB, Inst, std::move(ICPcode), MethodInfo.second);
1354 
1355         // Fix the CFG after inserting the new basic blocks.
1356         BinaryBasicBlock *MergeBlock =
1357             fixCFG(*BB, IsTailCall, IsJumpTable, std::move(NewBBs), Targets);
1358 
1359         // Since the tail of the original block was split off and it may contain
1360         // additional indirect calls, we must add the merge block to the set of
1361         // blocks to process.
1362         if (MergeBlock)
1363           BBs.push_back(MergeBlock);
1364 
1365         if (opts::Verbosity >= 1)
1366           outs() << "BOLT-INFO: ICP succeeded in " << Function << " @ "
1367                  << InstIdx << " in " << BB->getName()
1368                  << " -> calls = " << NumCalls << "\n";
1369 
1370         if (IsJumpTable)
1371           ++TotalOptimizedJumpTableCallsites;
1372         else
1373           ++TotalOptimizedIndirectCallsites;
1374 
1375         Modified.insert(&Function);
1376       }
1377     }
1378     TotalIndirectCalls += FuncTotalIndirectCalls;
1379     TotalIndirectJmps += FuncTotalIndirectJmps;
1380   }
1381 
1382   outs() << "BOLT-INFO: ICP total indirect callsites with profile = "
1383          << TotalIndirectCallsites << "\n"
1384          << "BOLT-INFO: ICP total jump table callsites = "
1385          << TotalJumpTableCallsites << "\n"
1386          << "BOLT-INFO: ICP total number of calls = " << TotalCalls << "\n"
1387          << "BOLT-INFO: ICP percentage of calls that are indirect = "
1388          << format("%.1f", (100.0 * TotalIndirectCalls) / TotalCalls) << "%\n"
1389          << "BOLT-INFO: ICP percentage of indirect calls that can be "
1390             "optimized = "
1391          << format("%.1f", (100.0 * TotalNumFrequentCalls) /
1392                                std::max<size_t>(TotalIndirectCalls, 1))
1393          << "%\n"
1394          << "BOLT-INFO: ICP percentage of indirect callsites that are "
1395             "optimized = "
1396          << format("%.1f", (100.0 * TotalOptimizedIndirectCallsites) /
1397                                std::max<uint64_t>(TotalIndirectCallsites, 1))
1398          << "%\n"
1399          << "BOLT-INFO: ICP number of method load elimination candidates = "
1400          << TotalMethodLoadEliminationCandidates << "\n"
1401          << "BOLT-INFO: ICP percentage of method calls candidates that have "
1402             "loads eliminated = "
1403          << format("%.1f", (100.0 * TotalMethodLoadsEliminated) /
1404                                std::max<uint64_t>(
1405                                    TotalMethodLoadEliminationCandidates, 1))
1406          << "%\n"
1407          << "BOLT-INFO: ICP percentage of indirect branches that are "
1408             "optimized = "
1409          << format("%.1f", (100.0 * TotalNumFrequentJmps) /
1410                                std::max<uint64_t>(TotalIndirectJmps, 1))
1411          << "%\n"
1412          << "BOLT-INFO: ICP percentage of jump table callsites that are "
1413          << "optimized = "
1414          << format("%.1f", (100.0 * TotalOptimizedJumpTableCallsites) /
1415                                std::max<uint64_t>(TotalJumpTableCallsites, 1))
1416          << "%\n"
1417          << "BOLT-INFO: ICP number of jump table callsites that can use hot "
1418          << "indices = " << TotalIndexBasedCandidates << "\n"
1419          << "BOLT-INFO: ICP percentage of jump table callsites that use hot "
1420             "indices = "
1421          << format("%.1f", (100.0 * TotalIndexBasedJumps) /
1422                                std::max<uint64_t>(TotalIndexBasedCandidates, 1))
1423          << "%\n";
1424 
1425   (void)verifyProfile;
1426 #ifndef NDEBUG
1427   verifyProfile(BFs);
1428 #endif
1429 }
1430 
1431 } // namespace bolt
1432 } // namespace llvm
1433