xref: /llvm-project/bolt/lib/Passes/LongJmp.cpp (revision a73b50ad0649d635433547ff51cd73d2ce9f085b)
1 //===- bolt/Passes/LongJmp.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the LongJmpPass class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "bolt/Passes/LongJmp.h"
14 
15 #define DEBUG_TYPE "longjmp"
16 
17 using namespace llvm;
18 
namespace opts {
// Defined in other BOLT translation units; declared here for use below.
extern cl::OptionCategory BoltOptCategory;
extern llvm::cl::opt<unsigned> AlignText;
extern cl::opt<unsigned> AlignFunctions;
extern cl::opt<bool> UseOldText;
extern cl::opt<bool> HotFunctionsAtEnd;

// When enabled (default), stubs that target the same function symbol may be
// reused across different callers, reducing the number of stubs emitted.
static cl::opt<bool>
GroupStubs("group-stubs",
  cl::desc("share stubs across functions"),
  cl::init(true),
  cl::ZeroOrMore,
  cl::cat(BoltOptCategory));
}
33 
34 namespace llvm {
35 namespace bolt {
36 
37 namespace {
38 constexpr unsigned ColdFragAlign = 16;
39 
40 void relaxStubToShortJmp(BinaryBasicBlock &StubBB, const MCSymbol *Tgt) {
41   const BinaryContext &BC = StubBB.getFunction()->getBinaryContext();
42   InstructionListType Seq;
43   BC.MIB->createShortJmp(Seq, Tgt, BC.Ctx.get());
44   StubBB.clear();
45   StubBB.addInstructions(Seq.begin(), Seq.end());
46 }
47 
48 void relaxStubToLongJmp(BinaryBasicBlock &StubBB, const MCSymbol *Tgt) {
49   const BinaryContext &BC = StubBB.getFunction()->getBinaryContext();
50   InstructionListType Seq;
51   BC.MIB->createLongJmp(Seq, Tgt, BC.Ctx.get());
52   StubBB.clear();
53   StubBB.addInstructions(Seq.begin(), Seq.end());
54 }
55 
56 BinaryBasicBlock *getBBAtHotColdSplitPoint(BinaryFunction &Func) {
57   if (!Func.isSplit() || Func.empty())
58     return nullptr;
59 
60   assert(!(*Func.begin()).isCold() && "Entry cannot be cold");
61   for (auto I = Func.layout_begin(), E = Func.layout_end(); I != E; ++I) {
62     auto Next = std::next(I);
63     if (Next != E && (*Next)->isCold())
64       return *I;
65   }
66   llvm_unreachable("No hot-colt split point found");
67 }
68 
69 bool shouldInsertStub(const BinaryContext &BC, const MCInst &Inst) {
70   return (BC.MIB->isBranch(Inst) || BC.MIB->isCall(Inst)) &&
71          !BC.MIB->isIndirectBranch(Inst) && !BC.MIB->isIndirectCall(Inst);
72 }
73 
74 } // end anonymous namespace
75 
/// Create a new stub basic block in \p SourceBB's function: a block containing
/// a single unconditional branch to \p TgtSym (converted to a tail call when
/// \p TgtIsFunc). The stub is registered in the per-function local stub maps
/// and, when stub grouping is enabled and the target is a function, in the
/// globally shared stub groups, keyed by target symbol and ordered by
/// \p AtAddress. Returns the new block (ownership goes to the caller, who must
/// insert it into the layout) together with its label.
std::pair<std::unique_ptr<BinaryBasicBlock>, MCSymbol *>
LongJmpPass::createNewStub(BinaryBasicBlock &SourceBB, const MCSymbol *TgtSym,
                           bool TgtIsFunc, uint64_t AtAddress) {
  BinaryFunction &Func = *SourceBB.getFunction();
  const BinaryContext &BC = Func.getBinaryContext();
  const bool IsCold = SourceBB.isCold();
  MCSymbol *StubSym = BC.Ctx->createNamedTempSymbol("Stub");
  std::unique_ptr<BinaryBasicBlock> StubBB = Func.createBasicBlock(0, StubSym);
  MCInst Inst;
  BC.MIB->createUncondBranch(Inst, TgtSym, BC.Ctx.get());
  if (TgtIsFunc)
    BC.MIB->convertJmpToTailCall(Inst);
  StubBB->addInstruction(Inst);
  StubBB->setExecutionCount(0);

  // Register this in stubs maps
  auto registerInMap = [&](StubGroupsTy &Map) {
    StubGroupTy &StubGroup = Map[TgtSym];
    // Keep each group sorted by address so lookups can binary-search for the
    // stub closest to a given location.
    StubGroup.insert(
        std::lower_bound(
            StubGroup.begin(), StubGroup.end(),
            std::make_pair(AtAddress, nullptr),
            [&](const std::pair<uint64_t, BinaryBasicBlock *> &LHS,
                const std::pair<uint64_t, BinaryBasicBlock *> &RHS) {
              return LHS.first < RHS.first;
            }),
        std::make_pair(AtAddress, StubBB.get()));
  };

  Stubs[&Func].insert(StubBB.get());
  // A fresh stub starts with the narrowest branch encoding; relaxStub() may
  // widen it later if the target turns out to be unreachable from here.
  StubBits[StubBB.get()] = BC.MIB->getUncondBranchEncodingSize();
  if (IsCold) {
    registerInMap(ColdLocalStubs[&Func]);
    if (opts::GroupStubs && TgtIsFunc)
      registerInMap(ColdStubGroups);
    ++NumColdStubs;
  } else {
    registerInMap(HotLocalStubs[&Func]);
    if (opts::GroupStubs && TgtIsFunc)
      registerInMap(HotStubGroups);
    ++NumHotStubs;
  }

  return std::make_pair(std::move(StubBB), StubSym);
}
121 
122 BinaryBasicBlock *LongJmpPass::lookupStubFromGroup(
123     const StubGroupsTy &StubGroups, const BinaryFunction &Func,
124     const MCInst &Inst, const MCSymbol *TgtSym, uint64_t DotAddress) const {
125   const BinaryContext &BC = Func.getBinaryContext();
126   auto CandidatesIter = StubGroups.find(TgtSym);
127   if (CandidatesIter == StubGroups.end())
128     return nullptr;
129   const StubGroupTy &Candidates = CandidatesIter->second;
130   if (Candidates.empty())
131     return nullptr;
132   auto Cand = std::lower_bound(
133       Candidates.begin(), Candidates.end(), std::make_pair(DotAddress, nullptr),
134       [&](const std::pair<uint64_t, BinaryBasicBlock *> &LHS,
135           const std::pair<uint64_t, BinaryBasicBlock *> &RHS) {
136         return LHS.first < RHS.first;
137       });
138   if (Cand == Candidates.end())
139     return nullptr;
140   if (Cand != Candidates.begin()) {
141     const StubTy *LeftCand = std::prev(Cand);
142     if (Cand->first - DotAddress > DotAddress - LeftCand->first)
143       Cand = LeftCand;
144   }
145   int BitsAvail = BC.MIB->getPCRelEncodingSize(Inst) - 1;
146   uint64_t Mask = ~((1ULL << BitsAvail) - 1);
147   uint64_t PCRelTgtAddress = Cand->first;
148   PCRelTgtAddress = DotAddress > PCRelTgtAddress ? DotAddress - PCRelTgtAddress
149                                                  : PCRelTgtAddress - DotAddress;
150   LLVM_DEBUG({
151     if (Candidates.size() > 1)
152       dbgs() << "Considering stub group with " << Candidates.size()
153              << " candidates. DotAddress is " << Twine::utohexstr(DotAddress)
154              << ", chosen candidate address is "
155              << Twine::utohexstr(Cand->first) << "\n";
156   });
157   return PCRelTgtAddress & Mask ? nullptr : Cand->second;
158 }
159 
160 BinaryBasicBlock *
161 LongJmpPass::lookupGlobalStub(const BinaryBasicBlock &SourceBB,
162                               const MCInst &Inst, const MCSymbol *TgtSym,
163                               uint64_t DotAddress) const {
164   const BinaryFunction &Func = *SourceBB.getFunction();
165   const StubGroupsTy &StubGroups =
166       SourceBB.isCold() ? ColdStubGroups : HotStubGroups;
167   return lookupStubFromGroup(StubGroups, Func, Inst, TgtSym, DotAddress);
168 }
169 
170 BinaryBasicBlock *LongJmpPass::lookupLocalStub(const BinaryBasicBlock &SourceBB,
171                                                const MCInst &Inst,
172                                                const MCSymbol *TgtSym,
173                                                uint64_t DotAddress) const {
174   const BinaryFunction &Func = *SourceBB.getFunction();
175   const DenseMap<const BinaryFunction *, StubGroupsTy> &StubGroups =
176       SourceBB.isCold() ? ColdLocalStubs : HotLocalStubs;
177   const auto Iter = StubGroups.find(&Func);
178   if (Iter == StubGroups.end())
179     return nullptr;
180   return lookupStubFromGroup(Iter->second, Func, Inst, TgtSym, DotAddress);
181 }
182 
/// Redirect branch/call \p Inst in \p BB to a stub that can reach its
/// (out-of-range) target. An existing local or shared stub is reused when one
/// is in range of \p DotAddress; otherwise a new stub is created at tentative
/// address \p StubCreationAddress. Returns the newly created stub block for
/// the caller to insert into the layout, or nullptr when an existing stub was
/// reused.
std::unique_ptr<BinaryBasicBlock>
LongJmpPass::replaceTargetWithStub(BinaryBasicBlock &BB, MCInst &Inst,
                                   uint64_t DotAddress,
                                   uint64_t StubCreationAddress) {
  const BinaryFunction &Func = *BB.getFunction();
  const BinaryContext &BC = Func.getBinaryContext();
  std::unique_ptr<BinaryBasicBlock> NewBB;
  const MCSymbol *TgtSym = BC.MIB->getTargetSymbol(Inst);
  assert(TgtSym && "getTargetSymbol failed");

  BinaryBasicBlock::BinaryBranchInfo BI{0, 0};
  BinaryBasicBlock *TgtBB = BB.getSuccessor(TgtSym, BI);
  auto LocalStubsIter = Stubs.find(&Func);

  // If already using stub and the stub is from another function, create a local
  // stub, since the foreign stub is now out of range
  if (!TgtBB) {
    auto SSIter = SharedStubs.find(TgtSym);
    if (SSIter != SharedStubs.end()) {
      // Retarget to the shared stub's real destination rather than chaining
      // stub-to-stub.
      TgtSym = BC.MIB->getTargetSymbol(*SSIter->second->begin());
      --NumSharedStubs;
    }
  } else if (LocalStubsIter != Stubs.end() &&
             LocalStubsIter->second.count(TgtBB)) {
    // If we are replacing a local stub (because it is now out of range),
    // use its target instead of creating a stub to jump to another stub
    TgtSym = BC.MIB->getTargetSymbol(*TgtBB->begin());
    TgtBB = BB.getSuccessor(TgtSym, BI);
  }

  BinaryBasicBlock *StubBB = lookupLocalStub(BB, Inst, TgtSym, DotAddress);
  // If not found, look it up in globally shared stub maps if it is a function
  // call (TgtBB is not set)
  if (!StubBB && !TgtBB) {
    StubBB = lookupGlobalStub(BB, Inst, TgtSym, DotAddress);
    if (StubBB) {
      SharedStubs[StubBB->getLabel()] = StubBB;
      ++NumSharedStubs;
    }
  }
  MCSymbol *StubSymbol = StubBB ? StubBB->getLabel() : nullptr;

  if (!StubBB) {
    // No reusable stub in range: create one at the requested address.
    std::tie(NewBB, StubSymbol) =
        createNewStub(BB, TgtSym, /*is func?*/ !TgtBB, StubCreationAddress);
    StubBB = NewBB.get();
  }

  // Local branch
  if (TgtBB) {
    // Move the edge's profile data onto the stub so counts stay consistent.
    uint64_t OrigCount = BI.Count;
    uint64_t OrigMispreds = BI.MispredictedCount;
    BB.replaceSuccessor(TgtBB, StubBB, OrigCount, OrigMispreds);
    StubBB->setExecutionCount(StubBB->getExecutionCount() + OrigCount);
    if (NewBB) {
      StubBB->addSuccessor(TgtBB, OrigCount, OrigMispreds);
      StubBB->setIsCold(BB.isCold());
    }
    // Call / tail call
  } else {
    StubBB->setExecutionCount(StubBB->getExecutionCount() +
                              BB.getExecutionCount());
    if (NewBB) {
      assert(TgtBB == nullptr);
      StubBB->setIsCold(BB.isCold());
      // Set as entry point because this block is valid but we have no preds
      StubBB->getFunction()->addEntryPoint(*StubBB);
    }
  }
  BC.MIB->replaceBranchTarget(Inst, StubSymbol, BC.Ctx.get());

  return NewBB;
}
256 
257 void LongJmpPass::updateStubGroups() {
258   auto update = [&](StubGroupsTy &StubGroups) {
259     for (auto &KeyVal : StubGroups) {
260       for (StubTy &Elem : KeyVal.second)
261         Elem.first = BBAddresses[Elem.second];
262       std::sort(KeyVal.second.begin(), KeyVal.second.end(),
263                 [&](const std::pair<uint64_t, BinaryBasicBlock *> &LHS,
264                     const std::pair<uint64_t, BinaryBasicBlock *> &RHS) {
265                   return LHS.first < RHS.first;
266                 });
267     }
268   };
269 
270   for (auto &KeyVal : HotLocalStubs)
271     update(KeyVal.second);
272   for (auto &KeyVal : ColdLocalStubs)
273     update(KeyVal.second);
274   update(HotStubGroups);
275   update(ColdStubGroups);
276 }
277 
278 void LongJmpPass::tentativeBBLayout(const BinaryFunction &Func) {
279   const BinaryContext &BC = Func.getBinaryContext();
280   uint64_t HotDot = HotAddresses[&Func];
281   uint64_t ColdDot = ColdAddresses[&Func];
282   bool Cold = false;
283   for (BinaryBasicBlock *BB : Func.layout()) {
284     if (Cold || BB->isCold()) {
285       Cold = true;
286       BBAddresses[BB] = ColdDot;
287       ColdDot += BC.computeCodeSize(BB->begin(), BB->end());
288     } else {
289       BBAddresses[BB] = HotDot;
290       HotDot += BC.computeCodeSize(BB->begin(), BB->end());
291     }
292   }
293 }
294 
/// Lay out the cold fragments of all split functions in \p SortedFunctions
/// starting at \p DotAddress, recording each fragment's tentative address in
/// ColdAddresses. Returns the first address past the cold region.
uint64_t LongJmpPass::tentativeLayoutRelocColdPart(
    const BinaryContext &BC, std::vector<BinaryFunction *> &SortedFunctions,
    uint64_t DotAddress) {
  DotAddress = alignTo(DotAddress, llvm::Align(opts::AlignFunctions));
  for (BinaryFunction *Func : SortedFunctions) {
    if (!Func->isSplit())
      continue;
    DotAddress = alignTo(DotAddress, BinaryFunction::MinAlign);
    uint64_t Pad =
        offsetToAlignment(DotAddress, llvm::Align(Func->getAlignment()));
    // Honor the function's preferred alignment only within its padding budget.
    if (Pad <= Func->getMaxColdAlignmentBytes())
      DotAddress += Pad;
    ColdAddresses[Func] = DotAddress;
    LLVM_DEBUG(dbgs() << Func->getPrintName() << " cold tentative: "
                      << Twine::utohexstr(DotAddress) << "\n");
    DotAddress += Func->estimateColdSize();
    // Reserve properly-aligned space for the function's constant island.
    DotAddress = alignTo(DotAddress, Func->getConstantIslandAlignment());
    DotAddress += Func->estimateConstantIslandSize();
  }
  return DotAddress;
}
316 
/// Lay out all functions in relocation mode starting at \p DotAddress: hot
/// parts in order, with all cold fragments placed at the hot/cold frontier,
/// followed by per-BB tentative addresses. Returns the first address past the
/// laid-out code.
uint64_t LongJmpPass::tentativeLayoutRelocMode(
    const BinaryContext &BC, std::vector<BinaryFunction *> &SortedFunctions,
    uint64_t DotAddress) {

  // Compute hot cold frontier
  // LastHotIndex is the position of the first function on the "cold" side of
  // the frontier; -1u (no break taken) means no frontier was found.
  uint32_t LastHotIndex = -1u;
  uint32_t CurrentIndex = 0;
  if (opts::HotFunctionsAtEnd) {
    // Hot functions are placed last: the frontier is the first function with
    // a valid (profile-derived) index.
    for (BinaryFunction *BF : SortedFunctions) {
      if (BF->hasValidIndex()) {
        LastHotIndex = CurrentIndex;
        break;
      }

      ++CurrentIndex;
    }
  } else {
    // Hot functions come first: the frontier is the first function without a
    // valid index.
    for (BinaryFunction *BF : SortedFunctions) {
      if (!BF->hasValidIndex()) {
        LastHotIndex = CurrentIndex;
        break;
      }

      ++CurrentIndex;
    }
  }

  // Hot
  CurrentIndex = 0;
  bool ColdLayoutDone = false;
  for (BinaryFunction *Func : SortedFunctions) {
    if (!BC.shouldEmit(*Func)) {
      // Functions we do not emit keep their input address.
      HotAddresses[Func] = Func->getAddress();
      continue;
    }

    // Upon crossing the frontier, lay out every cold fragment before
    // continuing with the remaining (hot) functions.
    if (!ColdLayoutDone && CurrentIndex >= LastHotIndex) {
      DotAddress =
          tentativeLayoutRelocColdPart(BC, SortedFunctions, DotAddress);
      ColdLayoutDone = true;
      if (opts::HotFunctionsAtEnd)
        DotAddress = alignTo(DotAddress, opts::AlignText);
    }

    DotAddress = alignTo(DotAddress, BinaryFunction::MinAlign);
    uint64_t Pad =
        offsetToAlignment(DotAddress, llvm::Align(Func->getAlignment()));
    // Honor the function's preferred alignment only within its padding budget.
    if (Pad <= Func->getMaxAlignmentBytes())
      DotAddress += Pad;
    HotAddresses[Func] = DotAddress;
    LLVM_DEBUG(dbgs() << Func->getPrintName() << " tentative: "
                      << Twine::utohexstr(DotAddress) << "\n");
    if (!Func->isSplit())
      DotAddress += Func->estimateSize();
    else
      DotAddress += Func->estimateHotSize();

    // Reserve properly-aligned space for the function's constant island.
    DotAddress = alignTo(DotAddress, Func->getConstantIslandAlignment());
    DotAddress += Func->estimateConstantIslandSize();
    ++CurrentIndex;
  }
  // BBs
  for (BinaryFunction *Func : SortedFunctions)
    tentativeBBLayout(*Func);

  return DotAddress;
}
384 
/// Compute tentative addresses for all functions and their basic blocks; these
/// estimates drive the branch-distance checks in relax(). In non-relocation
/// mode, functions keep their input addresses and only cold fragments are
/// given new addresses. In relocation mode, the whole text is laid out, either
/// inside the old .text section (when -use-old-text is set and the code fits)
/// or at the new layout start address.
void LongJmpPass::tentativeLayout(
    const BinaryContext &BC, std::vector<BinaryFunction *> &SortedFunctions) {
  uint64_t DotAddress = BC.LayoutStartAddress;

  if (!BC.HasRelocations) {
    for (BinaryFunction *Func : SortedFunctions) {
      HotAddresses[Func] = Func->getAddress();
      // Cold fragments are placed sequentially from the layout start address.
      DotAddress = alignTo(DotAddress, ColdFragAlign);
      ColdAddresses[Func] = DotAddress;
      if (Func->isSplit())
        DotAddress += Func->estimateColdSize();
      tentativeBBLayout(*Func);
    }

    return;
  }

  // Relocation mode
  uint64_t EstimatedTextSize = 0;
  if (opts::UseOldText) {
    // Dry run from address zero to estimate whether the code fits in the
    // original .text section.
    EstimatedTextSize = tentativeLayoutRelocMode(BC, SortedFunctions, 0);

    // Initial padding
    if (EstimatedTextSize <= BC.OldTextSectionSize) {
      DotAddress = BC.OldTextSectionAddress;
      uint64_t Pad =
          offsetToAlignment(DotAddress, llvm::Align(opts::AlignText));
      if (Pad + EstimatedTextSize <= BC.OldTextSectionSize) {
        DotAddress += Pad;
      }
    }
  }

  // Fall back to the new text area when old .text was not requested or the
  // code does not fit there.
  if (!EstimatedTextSize || EstimatedTextSize > BC.OldTextSectionSize)
    DotAddress = alignTo(BC.LayoutStartAddress, opts::AlignText);

  // Final pass with the chosen start address.
  tentativeLayoutRelocMode(BC, SortedFunctions, DotAddress);
}
423 
424 bool LongJmpPass::usesStub(const BinaryFunction &Func,
425                            const MCInst &Inst) const {
426   const MCSymbol *TgtSym = Func.getBinaryContext().MIB->getTargetSymbol(Inst);
427   const BinaryBasicBlock *TgtBB = Func.getBasicBlockForLabel(TgtSym);
428   auto Iter = Stubs.find(&Func);
429   if (Iter != Stubs.end())
430     return Iter->second.count(TgtBB);
431   return false;
432 }
433 
/// Resolve the tentative address of \p Target. When \p TgtBB is known, use the
/// tentative BB layout; otherwise use the owning function's hot-part address.
/// Symbols not mapped to a BinaryFunction, and secondary entry points
/// (EntryID != 0), fall back to BinaryContext's symbol resolution.
uint64_t LongJmpPass::getSymbolAddress(const BinaryContext &BC,
                                       const MCSymbol *Target,
                                       const BinaryBasicBlock *TgtBB) const {
  if (TgtBB) {
    auto Iter = BBAddresses.find(TgtBB);
    assert(Iter != BBAddresses.end() && "Unrecognized BB");
    return Iter->second;
  }
  uint64_t EntryID = 0;
  const BinaryFunction *TargetFunc = BC.getFunctionForSymbol(Target, &EntryID);
  auto Iter = HotAddresses.find(TargetFunc);
  if (Iter == HotAddresses.end() || (TargetFunc && EntryID)) {
    // Look at BinaryContext's resolution for this symbol - this is a symbol not
    // mapped to a BinaryFunction
    ErrorOr<uint64_t> ValueOrError = BC.getSymbolValue(*Target);
    assert(ValueOrError && "Unrecognized symbol");
    return *ValueOrError;
  }
  return Iter->second;
}
454 
/// Check whether the branch inside \p StubBB still reaches its target under
/// the current tentative layout and, if not, rewrite it with a wider sequence
/// (short jump first, then a full-range long jump). Returns true if the stub
/// was modified.
bool LongJmpPass::relaxStub(BinaryBasicBlock &StubBB) {
  const BinaryFunction &Func = *StubBB.getFunction();
  const BinaryContext &BC = Func.getBinaryContext();
  const int Bits = StubBits[&StubBB];
  // Already working with the largest range?
  if (Bits == static_cast<int>(BC.AsmInfo->getCodePointerSize() * 8))
    return false;

  // Function-local statics: initialized once, on the first call, from that
  // call's BinaryContext (every call targets the same architecture, so the
  // values are stable across calls).
  const static int RangeShortJmp = BC.MIB->getShortJmpEncodingSize();
  const static int RangeSingleInstr = BC.MIB->getUncondBranchEncodingSize();
  // A distance with any bit set in the mask is out of range for the
  // corresponding encoding.
  const static uint64_t ShortJmpMask = ~((1ULL << RangeShortJmp) - 1);
  const static uint64_t SingleInstrMask =
      ~((1ULL << (RangeSingleInstr - 1)) - 1);

  const MCSymbol *RealTargetSym = BC.MIB->getTargetSymbol(*StubBB.begin());
  const BinaryBasicBlock *TgtBB = Func.getBasicBlockForLabel(RealTargetSym);
  uint64_t TgtAddress = getSymbolAddress(BC, RealTargetSym, TgtBB);
  uint64_t DotAddress = BBAddresses[&StubBB];
  // Absolute distance between the stub and its real target.
  uint64_t PCRelTgtAddress = DotAddress > TgtAddress ? DotAddress - TgtAddress
                                                     : TgtAddress - DotAddress;
  // If it fits in one instruction, do not relax
  if (!(PCRelTgtAddress & SingleInstrMask))
    return false;

  // Fits short jmp
  if (!(PCRelTgtAddress & ShortJmpMask)) {
    if (Bits >= RangeShortJmp)
      return false;

    LLVM_DEBUG(dbgs() << "Relaxing stub to short jump. PCRelTgtAddress = "
                      << Twine::utohexstr(PCRelTgtAddress)
                      << " RealTargetSym = " << RealTargetSym->getName()
                      << "\n");
    relaxStubToShortJmp(StubBB, RealTargetSym);
    StubBits[&StubBB] = RangeShortJmp;
    return true;
  }

  // The long jmp uses absolute address on AArch64
  // So we could not use it for PIC binaries
  if (BC.isAArch64() && !BC.HasFixedLoadAddress) {
    errs() << "BOLT-ERROR: Unable to relax stub for PIC binary\n";
    exit(1);
  }

  LLVM_DEBUG(dbgs() << "Relaxing stub to long jump. PCRelTgtAddress = "
                    << Twine::utohexstr(PCRelTgtAddress)
                    << " RealTargetSym = " << RealTargetSym->getName() << "\n");
  relaxStubToLongJmp(StubBB, RealTargetSym);
  // Long jump covers the full pointer width: no further relaxation possible.
  StubBits[&StubBB] = static_cast<int>(BC.AsmInfo->getCodePointerSize() * 8);
  return true;
}
507 
508 bool LongJmpPass::needsStub(const BinaryBasicBlock &BB, const MCInst &Inst,
509                             uint64_t DotAddress) const {
510   const BinaryFunction &Func = *BB.getFunction();
511   const BinaryContext &BC = Func.getBinaryContext();
512   const MCSymbol *TgtSym = BC.MIB->getTargetSymbol(Inst);
513   assert(TgtSym && "getTargetSymbol failed");
514 
515   const BinaryBasicBlock *TgtBB = Func.getBasicBlockForLabel(TgtSym);
516   // Check for shared stubs from foreign functions
517   if (!TgtBB) {
518     auto SSIter = SharedStubs.find(TgtSym);
519     if (SSIter != SharedStubs.end())
520       TgtBB = SSIter->second;
521   }
522 
523   int BitsAvail = BC.MIB->getPCRelEncodingSize(Inst) - 1;
524   uint64_t Mask = ~((1ULL << BitsAvail) - 1);
525 
526   uint64_t PCRelTgtAddress = getSymbolAddress(BC, TgtSym, TgtBB);
527   PCRelTgtAddress = DotAddress > PCRelTgtAddress ? DotAddress - PCRelTgtAddress
528                                                  : PCRelTgtAddress - DotAddress;
529 
530   return PCRelTgtAddress & Mask;
531 }
532 
/// One relaxation iteration over \p Func: insert stubs for every direct
/// branch/call whose target is out of range under the current tentative
/// layout, then widen existing stubs that no longer reach their targets, and
/// finally splice newly created stub blocks into the layout. Returns true if
/// anything changed (the caller re-runs layout and relaxation to fixpoint).
bool LongJmpPass::relax(BinaryFunction &Func) {
  const BinaryContext &BC = Func.getBinaryContext();
  bool Modified = false;

  assert(BC.isAArch64() && "Unsupported arch");
  constexpr int InsnSize = 4; // AArch64
  std::vector<std::pair<BinaryBasicBlock *, std::unique_ptr<BinaryBasicBlock>>>
      Insertions;

  // Address just past the last hot block — candidate spot for branch stubs.
  BinaryBasicBlock *Frontier = getBBAtHotColdSplitPoint(Func);
  uint64_t FrontierAddress = Frontier ? BBAddresses[Frontier] : 0;
  if (FrontierAddress)
    FrontierAddress += Frontier->getNumNonPseudos() * InsnSize;

  // Add necessary stubs for branch targets we know we can't fit in the
  // instruction
  for (BinaryBasicBlock &BB : Func) {
    uint64_t DotAddress = BBAddresses[&BB];
    // Stubs themselves are relaxed on the next loop
    if (Stubs[&Func].count(&BB))
      continue;

    for (MCInst &Inst : BB) {
      if (BC.MIB->isPseudo(Inst))
        continue;

      if (!shouldInsertStub(BC, Inst)) {
        DotAddress += InsnSize;
        continue;
      }

      // Check and relax direct branch or call
      if (!needsStub(BB, Inst, DotAddress)) {
        DotAddress += InsnSize;
        continue;
      }
      Modified = true;

      // Insert stubs close to the patched BB if call, but far away from the
      // hot path if a branch, since this branch target is the cold region
      // (but first check that the far away stub will be in range).
      BinaryBasicBlock *InsertionPoint = &BB;
      if (Func.isSimple() && !BC.MIB->isCall(Inst) && FrontierAddress &&
          !BB.isCold()) {
        int BitsAvail = BC.MIB->getPCRelEncodingSize(Inst) - 1;
        uint64_t Mask = ~((1ULL << BitsAvail) - 1);
        assert(FrontierAddress > DotAddress &&
               "Hot code should be before the frontier");
        uint64_t PCRelTgt = FrontierAddress - DotAddress;
        if (!(PCRelTgt & Mask))
          InsertionPoint = Frontier;
      }
      // Always put stubs at the end of the function if non-simple. We can't
      // change the layout of non-simple functions because it has jump tables
      // that we do not control.
      if (!Func.isSimple())
        InsertionPoint = &*std::prev(Func.end());

      // Create a stub to handle a far-away target
      Insertions.emplace_back(InsertionPoint,
                              replaceTargetWithStub(BB, Inst, DotAddress,
                                                    InsertionPoint == Frontier
                                                        ? FrontierAddress
                                                        : DotAddress));

      DotAddress += InsnSize;
    }
  }

  // Relax stubs if necessary
  for (BinaryBasicBlock &BB : Func) {
    if (!Stubs[&Func].count(&BB) || !BB.isValid())
      continue;

    Modified |= relaxStub(BB);
  }

  // Splice the freshly created stub blocks (nullptr second => reused stub).
  for (std::pair<BinaryBasicBlock *, std::unique_ptr<BinaryBasicBlock>> &Elmt :
       Insertions) {
    if (!Elmt.second)
      continue;
    std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs;
    NewBBs.emplace_back(std::move(Elmt.second));
    Func.insertBasicBlocks(Elmt.first, std::move(NewBBs), true);
  }

  return Modified;
}
621 
622 void LongJmpPass::runOnFunctions(BinaryContext &BC) {
623   outs() << "BOLT-INFO: Starting stub-insertion pass\n";
624   std::vector<BinaryFunction *> Sorted = BC.getSortedFunctions();
625   bool Modified;
626   uint32_t Iterations = 0;
627   do {
628     ++Iterations;
629     Modified = false;
630     tentativeLayout(BC, Sorted);
631     updateStubGroups();
632     for (BinaryFunction *Func : Sorted) {
633       if (relax(*Func)) {
634         // Don't ruin non-simple functions, they can't afford to have the layout
635         // changed.
636         if (Func->isSimple())
637           Func->fixBranches();
638         Modified = true;
639       }
640     }
641   } while (Modified);
642   outs() << "BOLT-INFO: Inserted " << NumHotStubs
643          << " stubs in the hot area and " << NumColdStubs
644          << " stubs in the cold area. Shared " << NumSharedStubs
645          << " times, iterated " << Iterations << " times.\n";
646 }
647 } // namespace bolt
648 } // namespace llvm
649