xref: /llvm-project/bolt/lib/Passes/Inliner.cpp (revision ee4282259d5993dfa0b7b8937541dd6ccaadf3d5)
1 //===- bolt/Passes/Inliner.cpp - Inlining pass for low-level binary IR ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Inliner class used for inlining binary functions.
10 //
// The current inliner has limited callee support
12 // (see Inliner::getInliningInfo() for the most up-to-date details):
13 //
14 //  * No exception handling
15 //  * No jump tables
16 //  * Single entry point
17 //  * CFI update not supported - breaks unwinding
18 //  * Regular Call Sites:
19 //    - only leaf functions (or callees with only tail calls)
20 //      * no invokes (they can't be tail calls)
21 //    - no direct use of %rsp
22 //  * Tail Call Sites:
23 //    - since the stack is unmodified, the regular call limitations are lifted
24 //
25 //===----------------------------------------------------------------------===//
26 
27 #include "bolt/Passes/Inliner.h"
28 #include "bolt/Core/MCPlus.h"
29 #include "llvm/Support/CommandLine.h"
30 
31 #define DEBUG_TYPE "bolt-inliner"
32 
33 using namespace llvm;
34 
35 namespace opts {
36 
extern cl::OptionCategory BoltOptCategory;

// When set, inlined blocks are rescaled with adjustExecutionCount() and the
// inlined call site's count is subtracted from the callee's profile.
static cl::opt<bool>
    AdjustProfile("inline-ap",
                  cl::desc("adjust function profile after inlining"),
                  cl::cat(BoltOptCategory));

// Functions named here are always considered: they bypass the candidate checks
// in getInliningInfo() and the size heuristics in inlineCallsInFunction()
// (see mustConsider()).
static cl::list<std::string>
ForceInlineFunctions("force-inline",
  cl::CommaSeparated,
  cl::desc("list of functions to always consider for inlining"),
  cl::value_desc("func1,func2,func3,..."),
  cl::Hidden,
  cl::cat(BoltOptCategory));

// Disables the size limit for every candidate; also implies
// -inline-small-functions (see syncOptions()).
static cl::opt<bool> InlineAll("inline-all", cl::desc("inline all functions"),
                               cl::cat(BoltOptCategory));

// Allows inlining of leaf callees that contain CFI instructions. Enabled by
// default and forced on by -inline-ignore-cfi (see syncOptions()).
static cl::opt<bool> InlineIgnoreLeafCFI(
    "inline-ignore-leaf-cfi",
    cl::desc("inline leaf functions with CFI programs (can break unwinding)"),
    cl::init(true), cl::ReallyHidden, cl::cat(BoltOptCategory));

// Additionally allows inlining of non-leaf callees that contain CFI.
static cl::opt<bool> InlineIgnoreCFI(
    "inline-ignore-cfi",
    cl::desc(
        "inline functions with CFI programs (can break exception handling)"),
    cl::ReallyHidden, cl::cat(BoltOptCategory));

// Stop once this many call sites have been inlined; 0 (default) means no limit.
static cl::opt<unsigned>
    InlineLimit("inline-limit",
                cl::desc("maximum number of call sites to inline"), cl::init(0),
                cl::Hidden, cl::cat(BoltOptCategory));

// Upper bound on the number of passes over all functions in runOnFunctions();
// at least one pass always runs.
static cl::opt<unsigned>
    InlineMaxIters("inline-max-iters",
                   cl::desc("maximum number of inline iterations"), cl::init(3),
                   cl::Hidden, cl::cat(BoltOptCategory));

// Inline a callee when its estimated size change is at most
// -inline-small-functions-bytes.
static cl::opt<bool> InlineSmallFunctions(
    "inline-small-functions",
    cl::desc("inline functions if increase in size is less than defined by "
             "-inline-small-functions-bytes"),
    cl::cat(BoltOptCategory));

// Size threshold (in bytes) used by -inline-small-functions.
static cl::opt<unsigned> InlineSmallFunctionsBytes(
    "inline-small-functions-bytes",
    cl::desc("max number of bytes for the function to be considered small for "
             "inlining purposes"),
    cl::init(4), cl::Hidden, cl::cat(BoltOptCategory));

// Turns the whole pass off, regardless of the other options
// (checked first in inliningEnabled()).
static cl::opt<bool> NoInline(
    "no-inline",
    cl::desc("disable all inlining (overrides other inlining options)"),
    cl::cat(BoltOptCategory));
92 
93 /// This function returns true if any of inlining options are specified and the
94 /// inlining pass should be executed. Whenever a new inlining option is added,
95 /// this function should reflect the change.
96 bool inliningEnabled() {
97   return !NoInline &&
98          (InlineAll || InlineSmallFunctions || !ForceInlineFunctions.empty());
99 }
100 
101 bool mustConsider(const llvm::bolt::BinaryFunction &Function) {
102   for (std::string &Name : opts::ForceInlineFunctions)
103     if (Function.hasName(Name))
104       return true;
105   return false;
106 }
107 
108 void syncOptions() {
109   if (opts::InlineIgnoreCFI)
110     opts::InlineIgnoreLeafCFI = true;
111 
112   if (opts::InlineAll)
113     opts::InlineSmallFunctions = true;
114 }
115 
116 } // namespace opts
117 
118 namespace llvm {
119 namespace bolt {
120 
// Lazily computed encoded sizes of a direct call and of a tail call for the
// current target. Being static members they start at 0, which the getters
// below treat as "not computed yet".
uint64_t Inliner::SizeOfCallInst;
uint64_t Inliner::SizeOfTailCallInst;
123 
124 uint64_t Inliner::getSizeOfCallInst(const BinaryContext &BC) {
125   if (SizeOfCallInst)
126     return SizeOfCallInst;
127 
128   MCInst Inst;
129   BC.MIB->createCall(Inst, BC.Ctx->createNamedTempSymbol(), BC.Ctx.get());
130   SizeOfCallInst = BC.computeInstructionSize(Inst);
131 
132   return SizeOfCallInst;
133 }
134 
135 uint64_t Inliner::getSizeOfTailCallInst(const BinaryContext &BC) {
136   if (SizeOfTailCallInst)
137     return SizeOfTailCallInst;
138 
139   MCInst Inst;
140   BC.MIB->createTailCall(Inst, BC.Ctx->createNamedTempSymbol(), BC.Ctx.get());
141   SizeOfTailCallInst = BC.computeInstructionSize(Inst);
142 
143   return SizeOfTailCallInst;
144 }
145 
146 InliningInfo getInliningInfo(const BinaryFunction &BF) {
147   const BinaryContext &BC = BF.getBinaryContext();
148   bool DirectSP = false;
149   bool HasCFI = false;
150   bool IsLeaf = true;
151 
152   // Perform necessary checks unless the option overrides it.
153   if (!opts::mustConsider(BF)) {
154     if (BF.hasSDTMarker())
155       return INL_NONE;
156 
157     if (BF.hasEHRanges())
158       return INL_NONE;
159 
160     if (BF.isMultiEntry())
161       return INL_NONE;
162 
163     if (BF.hasJumpTables())
164       return INL_NONE;
165 
166     const MCPhysReg SPReg = BC.MIB->getStackPointer();
167     for (const BinaryBasicBlock &BB : BF) {
168       for (const MCInst &Inst : BB) {
169         // Tail calls are marked as implicitly using the stack pointer and they
170         // could be inlined.
171         if (BC.MIB->isTailCall(Inst))
172           break;
173 
174         if (BC.MIB->isCFI(Inst)) {
175           HasCFI = true;
176           continue;
177         }
178 
179         if (BC.MIB->isCall(Inst))
180           IsLeaf = false;
181 
182         // Push/pop instructions are straightforward to handle.
183         if (BC.MIB->isPush(Inst) || BC.MIB->isPop(Inst))
184           continue;
185 
186         DirectSP |= BC.MIB->hasDefOfPhysReg(Inst, SPReg) ||
187                     BC.MIB->hasUseOfPhysReg(Inst, SPReg);
188       }
189     }
190   }
191 
192   if (HasCFI) {
193     if (!opts::InlineIgnoreLeafCFI)
194       return INL_NONE;
195 
196     if (!IsLeaf && !opts::InlineIgnoreCFI)
197       return INL_NONE;
198   }
199 
200   InliningInfo Info(DirectSP ? INL_TAILCALL : INL_ANY);
201 
202   size_t Size = BF.estimateSize();
203 
204   Info.SizeAfterInlining = Size;
205   Info.SizeAfterTailCallInlining = Size;
206 
207   // Handle special case of the known size reduction.
208   if (BF.size() == 1) {
209     // For a regular call the last return instruction could be removed
210     // (or converted to a branch).
211     const MCInst *LastInst = BF.back().getLastNonPseudoInstr();
212     if (LastInst && BC.MIB->isReturn(*LastInst) &&
213         !BC.MIB->isTailCall(*LastInst)) {
214       const uint64_t RetInstSize = BC.computeInstructionSize(*LastInst);
215       assert(Size >= RetInstSize);
216       Info.SizeAfterInlining -= RetInstSize;
217     }
218   }
219 
220   return Info;
221 }
222 
223 void Inliner::findInliningCandidates(BinaryContext &BC) {
224   for (const auto &BFI : BC.getBinaryFunctions()) {
225     const BinaryFunction &Function = BFI.second;
226     if (!shouldOptimize(Function))
227       continue;
228     const InliningInfo InlInfo = getInliningInfo(Function);
229     if (InlInfo.Type != INL_NONE)
230       InliningCandidates[&Function] = InlInfo;
231   }
232 }
233 
/// Inline \p Callee at the call instruction \p CallInst of \p CallerBB.
///
/// CFG construction:
///  * The callee's entry block is merged into the call-site block.
///  * If other callee blocks branch to the entry block (it has predecessors)
///    and the call is not the first instruction, CallerBB is first split at
///    the call.
///  * The remaining callee blocks are cloned as new blocks of the caller.
///
/// Instruction handling:
///  * Regular (non-tail) call site: callee returns are dropped, and blocks
///    without successors are linked to the block following the call. Callee
///    calls are emitted as regular calls (tail calls are converted). At invoke
///    call sites, these calls receive the call site's landing pad and
///    GNU_args_size.
///  * Tail call site: callee instructions are copied unchanged apart from
///    annotation stripping and branch retargeting.
///
/// Profile handling:
///  * Cloned blocks take the callee's block counts, or the call-site count
///    when the callee has no valid profile.
///  * If the caller has a valid profile and the callee has more than one
///    block, counts are then scaled by (call-site count / callee count).
///  * The call-site block keeps its original count.
///
/// \returns the block and iterator at which the caller should resume scanning
/// for further call sites: the start of the block after the call when one
/// exists, otherwise a position inside/after the merged call-site block.
std::pair<BinaryBasicBlock *, BinaryBasicBlock::iterator>
Inliner::inlineCall(BinaryBasicBlock &CallerBB,
                    BinaryBasicBlock::iterator CallInst,
                    const BinaryFunction &Callee) {
  BinaryFunction &CallerFunction = *CallerBB.getFunction();
  BinaryContext &BC = CallerFunction.getBinaryContext();
  auto &MIB = *BC.MIB;

  assert(MIB.isCall(*CallInst) && "can only inline a call or a tail call");
  assert(!Callee.isMultiEntry() &&
         "cannot inline function with multiple entries");
  assert(!Callee.hasJumpTables() &&
         "cannot inline function with jump table(s)");

  // Get information about the call site.
  const bool CSIsInvoke = BC.MIB->isInvoke(*CallInst);
  const bool CSIsTailCall = BC.MIB->isTailCall(*CallInst);
  const int64_t CSGNUArgsSize = BC.MIB->getGnuArgsSize(*CallInst);
  const std::optional<MCPlus::MCLandingPad> CSEHInfo =
      BC.MIB->getEHInfo(*CallInst);

  // Split basic block at the call site if there will be more incoming edges
  // coming from the callee.
  BinaryBasicBlock *FirstInlinedBB = &CallerBB;
  if (Callee.front().pred_size() && CallInst != CallerBB.begin()) {
    FirstInlinedBB = CallerBB.splitAt(CallInst);
    CallInst = FirstInlinedBB->begin();
  }

  // Split basic block after the call instruction unless the callee is trivial
  // (i.e. consists of a single basic block). If necessary, obtain a basic block
  // for return instructions in the callee to redirect to.
  // NOTE(review): if the call ends the block and the block has no successor,
  // NextBB stays null and callee return blocks get no outgoing edge — confirm
  // this only happens for call sites that never return.
  BinaryBasicBlock *NextBB = nullptr;
  if (Callee.size() > 1) {
    if (std::next(CallInst) != FirstInlinedBB->end())
      NextBB = FirstInlinedBB->splitAt(std::next(CallInst));
    else
      NextBB = FirstInlinedBB->getSuccessor();
  }
  // The call-site block now flows into the inlined body; NextBB is reached
  // only via the inlined return blocks (edges added below).
  if (NextBB)
    FirstInlinedBB->removeSuccessor(NextBB);

  // Remove the call instruction.
  auto InsertII = FirstInlinedBB->eraseInstruction(CallInst);

  // Ratio used to scale callee block counts to this call site's frequency.
  double ProfileRatio = 0;
  if (uint64_t CalleeExecCount = Callee.getKnownExecutionCount())
    ProfileRatio =
        (double)FirstInlinedBB->getKnownExecutionCount() / CalleeExecCount;

  // Save execution count of the first block as we don't want it to change
  // later due to profile adjustment rounding errors.
  const uint64_t FirstInlinedBBCount = FirstInlinedBB->getKnownExecutionCount();

  // Copy basic blocks and maintain a map from their origin.
  std::unordered_map<const BinaryBasicBlock *, BinaryBasicBlock *> InlinedBBMap;
  InlinedBBMap[&Callee.front()] = FirstInlinedBB;
  for (const BinaryBasicBlock &BB : llvm::drop_begin(Callee)) {
    BinaryBasicBlock *InlinedBB = CallerFunction.addBasicBlock();
    InlinedBBMap[&BB] = InlinedBB;
    InlinedBB->setCFIState(FirstInlinedBB->getCFIState());
    if (Callee.hasValidProfile())
      InlinedBB->setExecutionCount(BB.getKnownExecutionCount());
    else
      InlinedBB->setExecutionCount(FirstInlinedBBCount);
  }

  // Copy over instructions and edges.
  for (const BinaryBasicBlock &BB : Callee) {
    BinaryBasicBlock *InlinedBB = InlinedBBMap[&BB];

    // The entry block continues where the call was erased; cloned blocks are
    // filled from their beginning.
    if (InlinedBB != FirstInlinedBB)
      InsertII = InlinedBB->begin();

    // Copy over instructions making any necessary mods.
    for (MCInst Inst : BB) {
      if (MIB.isPseudo(Inst))
        continue;

      MIB.stripAnnotations(Inst, /*KeepTC=*/BC.isX86() || BC.isAArch64());

      // Fix branch target. Strictly speaking, we don't have to do this as
      // targets of direct branches will be fixed later and don't matter
      // in the CFG state. However, disassembly may look misleading, and
      // hence we do the fixing.
      if (MIB.isBranch(Inst) && !MIB.isTailCall(Inst)) {
        assert(!MIB.isIndirectBranch(Inst) &&
               "unexpected indirect branch in callee");
        const BinaryBasicBlock *TargetBB =
            Callee.getBasicBlockForLabel(MIB.getTargetSymbol(Inst));
        assert(TargetBB && "cannot find target block in callee");
        MIB.replaceBranchTarget(Inst, InlinedBBMap[TargetBB]->getLabel(),
                                BC.Ctx.get());
      }

      // At a tail call site everything is copied verbatim; at a regular call
      // site only calls and returns need special treatment.
      if (CSIsTailCall || (!MIB.isCall(Inst) && !MIB.isReturn(Inst))) {
        InsertII =
            std::next(InlinedBB->insertInstruction(InsertII, std::move(Inst)));
        continue;
      }

      // Handle special instructions for a non-tail call site.
      if (!MIB.isCall(Inst)) {
        // Returns are removed.
        break;
      }

      MIB.convertTailCallToCall(Inst);

      // Propagate EH-related info to call instructions.
      if (CSIsInvoke) {
        MIB.addEHInfo(Inst, *CSEHInfo);
        if (CSGNUArgsSize >= 0)
          MIB.addGnuArgsSize(Inst, CSGNUArgsSize);
      }

      InsertII =
          std::next(InlinedBB->insertInstruction(InsertII, std::move(Inst)));
    }

    // Add CFG edges to the basic blocks of the inlined instance.
    std::vector<BinaryBasicBlock *> Successors(BB.succ_size());
    llvm::transform(BB.successors(), Successors.begin(),
                    [&InlinedBBMap](const BinaryBasicBlock *BB) {
                      auto It = InlinedBBMap.find(BB);
                      assert(It != InlinedBBMap.end());
                      return It->second;
                    });

    // Keep the callee's branch profile only if both sides have one.
    if (CallerFunction.hasValidProfile() && Callee.hasValidProfile())
      InlinedBB->addSuccessors(Successors.begin(), Successors.end(),
                               BB.branch_info_begin(), BB.branch_info_end());
    else
      InlinedBB->addSuccessors(Successors.begin(), Successors.end());

    if (!CSIsTailCall && BB.succ_size() == 0 && NextBB) {
      // Either it's a return block or the last instruction never returns.
      InlinedBB->addSuccessor(NextBB, InlinedBB->getExecutionCount());
    }

    // Scale profiling info for blocks and edges after inlining.
    if (CallerFunction.hasValidProfile() && Callee.size() > 1) {
      if (opts::AdjustProfile)
        InlinedBB->adjustExecutionCount(ProfileRatio);
      else
        InlinedBB->setExecutionCount(InlinedBB->getKnownExecutionCount() *
                                     ProfileRatio);
    }
  }

  // Restore the original execution count of the first inlined basic block.
  FirstInlinedBB->setExecutionCount(FirstInlinedBBCount);

  CallerFunction.recomputeLandingPads();

  // Resume scanning right after the inlined code.
  if (NextBB)
    return std::make_pair(NextBB, NextBB->begin());

  // Single-block callee: the inlined instructions sit in FirstInlinedBB,
  // followed by the caller's original instructions after the call.
  if (Callee.size() == 1)
    return std::make_pair(FirstInlinedBB, InsertII);

  return std::make_pair(FirstInlinedBB, FirstInlinedBB->end());
}
397 
398 bool Inliner::inlineCallsInFunction(BinaryFunction &Function) {
399   BinaryContext &BC = Function.getBinaryContext();
400   std::vector<BinaryBasicBlock *> Blocks(Function.getLayout().block_begin(),
401                                          Function.getLayout().block_end());
402   llvm::sort(
403       Blocks, [](const BinaryBasicBlock *BB1, const BinaryBasicBlock *BB2) {
404         return BB1->getKnownExecutionCount() > BB2->getKnownExecutionCount();
405       });
406 
407   bool DidInlining = false;
408   for (BinaryBasicBlock *BB : Blocks) {
409     for (auto InstIt = BB->begin(); InstIt != BB->end();) {
410       MCInst &Inst = *InstIt;
411       if (!BC.MIB->isCall(Inst) || MCPlus::getNumPrimeOperands(Inst) != 1 ||
412           !Inst.getOperand(0).isExpr()) {
413         ++InstIt;
414         continue;
415       }
416 
417       const MCSymbol *TargetSymbol = BC.MIB->getTargetSymbol(Inst);
418       assert(TargetSymbol && "target symbol expected for direct call");
419 
420       // Don't inline calls to a secondary entry point in a target function.
421       uint64_t EntryID = 0;
422       BinaryFunction *TargetFunction =
423           BC.getFunctionForSymbol(TargetSymbol, &EntryID);
424       if (!TargetFunction || EntryID != 0) {
425         ++InstIt;
426         continue;
427       }
428 
429       // Don't do recursive inlining.
430       if (TargetFunction == &Function) {
431         ++InstIt;
432         continue;
433       }
434 
435       auto IInfo = InliningCandidates.find(TargetFunction);
436       if (IInfo == InliningCandidates.end()) {
437         ++InstIt;
438         continue;
439       }
440 
441       const bool IsTailCall = BC.MIB->isTailCall(Inst);
442       if (!IsTailCall && IInfo->second.Type == INL_TAILCALL) {
443         ++InstIt;
444         continue;
445       }
446 
447       int64_t SizeAfterInlining;
448       if (IsTailCall)
449         SizeAfterInlining =
450             IInfo->second.SizeAfterTailCallInlining - getSizeOfTailCallInst(BC);
451       else
452         SizeAfterInlining =
453             IInfo->second.SizeAfterInlining - getSizeOfCallInst(BC);
454 
455       if (!opts::InlineAll && !opts::mustConsider(*TargetFunction)) {
456         if (!opts::InlineSmallFunctions ||
457             SizeAfterInlining > opts::InlineSmallFunctionsBytes) {
458           ++InstIt;
459           continue;
460         }
461       }
462 
463       LLVM_DEBUG(dbgs() << "BOLT-DEBUG: inlining call to " << *TargetFunction
464                         << " in " << Function << " : " << BB->getName()
465                         << ". Count: " << BB->getKnownExecutionCount()
466                         << ". Size change: " << SizeAfterInlining
467                         << " bytes.\n");
468 
469       std::tie(BB, InstIt) = inlineCall(*BB, InstIt, *TargetFunction);
470 
471       DidInlining = true;
472       TotalInlinedBytes += SizeAfterInlining;
473 
474       ++NumInlinedCallSites;
475       NumInlinedDynamicCalls += BB->getExecutionCount();
476 
477       // Subtract basic block execution count from the callee execution count.
478       if (opts::AdjustProfile)
479         TargetFunction->adjustExecutionCount(BB->getKnownExecutionCount());
480 
481       // Check if the caller inlining status has to be adjusted.
482       if (IInfo->second.Type == INL_TAILCALL) {
483         auto CallerIInfo = InliningCandidates.find(&Function);
484         if (CallerIInfo != InliningCandidates.end() &&
485             CallerIInfo->second.Type == INL_ANY) {
486           LLVM_DEBUG(dbgs() << "adjusting inlining status for function "
487                             << Function << '\n');
488           CallerIInfo->second.Type = INL_TAILCALL;
489         }
490       }
491 
492       if (NumInlinedCallSites == opts::InlineLimit)
493         return true;
494     }
495   }
496 
497   return DidInlining;
498 }
499 
500 Error Inliner::runOnFunctions(BinaryContext &BC) {
501   opts::syncOptions();
502 
503   if (!opts::inliningEnabled())
504     return Error::success();
505 
506   bool InlinedOnce;
507   unsigned NumIters = 0;
508   do {
509     if (opts::InlineLimit && NumInlinedCallSites >= opts::InlineLimit)
510       break;
511 
512     InlinedOnce = false;
513 
514     InliningCandidates.clear();
515     findInliningCandidates(BC);
516 
517     std::vector<BinaryFunction *> ConsideredFunctions;
518     for (auto &BFI : BC.getBinaryFunctions()) {
519       BinaryFunction &Function = BFI.second;
520       if (!shouldOptimize(Function))
521         continue;
522       ConsideredFunctions.push_back(&Function);
523     }
524     llvm::sort(ConsideredFunctions, [](const BinaryFunction *A,
525                                        const BinaryFunction *B) {
526       return B->getKnownExecutionCount() < A->getKnownExecutionCount();
527     });
528     for (BinaryFunction *Function : ConsideredFunctions) {
529       if (opts::InlineLimit && NumInlinedCallSites >= opts::InlineLimit)
530         break;
531 
532       const bool DidInline = inlineCallsInFunction(*Function);
533 
534       if (DidInline)
535         Modified.insert(Function);
536 
537       InlinedOnce |= DidInline;
538     }
539 
540     ++NumIters;
541   } while (InlinedOnce && NumIters < opts::InlineMaxIters);
542 
543   if (NumInlinedCallSites)
544     BC.outs() << "BOLT-INFO: inlined " << NumInlinedDynamicCalls << " calls at "
545               << NumInlinedCallSites << " call sites in " << NumIters
546               << " iteration(s). Change in binary size: " << TotalInlinedBytes
547               << " bytes.\n";
548   return Error::success();
549 }
550 
551 } // namespace bolt
552 } // namespace llvm
553