xref: /llvm-project/llvm/lib/Analysis/MLInlineAdvisor.cpp (revision f32e5bdcefcff80f4296f8f4abedc37dcda36d53)
1 //===- MLInlineAdvisor.cpp - machine learned InlineAdvisor ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the interface between the inliner and a learned model.
10 // It delegates model evaluation to either the AOT compiled model (the
11 // 'release' mode) or a runtime-loaded model (the 'development' case).
12 //
13 //===----------------------------------------------------------------------===//
14 #include "llvm/Analysis/MLInlineAdvisor.h"
15 #include "llvm/ADT/SCCIterator.h"
16 #include "llvm/Analysis/AssumptionCache.h"
17 #include "llvm/Analysis/BlockFrequencyInfo.h"
18 #include "llvm/Analysis/CallGraph.h"
19 #include "llvm/Analysis/FunctionPropertiesAnalysis.h"
20 #include "llvm/Analysis/InlineCost.h"
21 #include "llvm/Analysis/InlineModelFeatureMaps.h"
22 #include "llvm/Analysis/InteractiveModelRunner.h"
23 #include "llvm/Analysis/LazyCallGraph.h"
24 #include "llvm/Analysis/LoopInfo.h"
25 #include "llvm/Analysis/MLModelRunner.h"
26 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
27 #include "llvm/Analysis/ProfileSummaryInfo.h"
28 #include "llvm/Analysis/ReleaseModeModelRunner.h"
29 #include "llvm/Analysis/TargetTransformInfo.h"
30 #include "llvm/IR/Dominators.h"
31 #include "llvm/IR/InstIterator.h"
32 #include "llvm/IR/Module.h"
33 #include "llvm/IR/PassManager.h"
34 #include "llvm/Support/CommandLine.h"
35 
36 using namespace llvm;
37 
// Base path for the two named pipes used in interactive ("advisor in the
// loop") mode; getReleaseModeAdvisor appends ".out" / ".in". Empty (the
// default) disables interactive mode.
static cl::opt<std::string> InteractiveChannelBaseName(
    "inliner-interactive-channel-base", cl::Hidden,
    cl::desc(
        "Base file path for the interactive mode. The incoming filename should "
        "have the name <inliner-interactive-channel-base>.in, while the "
        "outgoing name should be <inliner-interactive-channel-base>.out"));
// Built once at static-init time so the concatenated description (which
// embeds DefaultDecisionName) outlives the cl::desc that references it.
static const std::string InclDefaultMsg =
    (Twine("In interactive mode, also send the default policy decision: ") +
     DefaultDecisionName + ".")
        .str();
// When set, the default (heuristic) decision is appended to the feature set
// sent over the interactive channel (see getReleaseModeAdvisor and the tail
// of getAdviceImpl).
static cl::opt<bool>
    InteractiveIncludeDefault("inliner-interactive-include-default", cl::Hidden,
                              cl::desc(InclDefaultMsg));
51 
// Criteria under which getAdviceImpl bypasses the ML policy and returns the
// default (heuristic) advice directly.
enum class SkipMLPolicyCriteria { Never, IfCallerIsNotCold };

static cl::opt<SkipMLPolicyCriteria> SkipPolicy(
    "ml-inliner-skip-policy", cl::Hidden, cl::init(SkipMLPolicyCriteria::Never),
    cl::values(clEnumValN(SkipMLPolicyCriteria::Never, "never", "never"),
               clEnumValN(SkipMLPolicyCriteria::IfCallerIsNotCold,
                          "if-caller-not-cold", "if the caller is not cold")));

// Forwarded to the release-mode runner via
// EmbeddedModelRunnerOptions::setModelSelector; presumably selects among
// multiple embedded model variants — see ReleaseModeModelRunner for details.
static cl::opt<std::string> ModelSelector("ml-inliner-model-selector",
                                          cl::Hidden, cl::init(""));
62 
#if defined(LLVM_HAVE_TF_AOT_INLINERSIZEMODEL)
// codegen-ed file
#include "InlinerSizeModel.h" // NOLINT
using CompiledModelType = llvm::InlinerSizeModel;
#else
// No AOT-compiled model available in this build; the no-op stand-in is what
// makes isEmbeddedModelEvaluatorValid<> fail in getReleaseModeAdvisor below.
using CompiledModelType = NoopSavedModelImpl;
#endif
70 
71 std::unique_ptr<InlineAdvisor>
72 llvm::getReleaseModeAdvisor(Module &M, ModuleAnalysisManager &MAM,
73                             std::function<bool(CallBase &)> GetDefaultAdvice) {
74   if (!llvm::isEmbeddedModelEvaluatorValid<CompiledModelType>() &&
75       InteractiveChannelBaseName.empty())
76     return nullptr;
77   std::unique_ptr<MLModelRunner> AOTRunner;
78   if (InteractiveChannelBaseName.empty())
79     AOTRunner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>(
80         M.getContext(), FeatureMap, DecisionName,
81         EmbeddedModelRunnerOptions().setModelSelector(ModelSelector));
82   else {
83     auto Features = FeatureMap;
84     if (InteractiveIncludeDefault)
85       Features.push_back(DefaultDecisionSpec);
86     AOTRunner = std::make_unique<InteractiveModelRunner>(
87         M.getContext(), Features, InlineDecisionSpec,
88         InteractiveChannelBaseName + ".out",
89         InteractiveChannelBaseName + ".in");
90   }
91   return std::make_unique<MLInlineAdvisor>(M, MAM, std::move(AOTRunner),
92                                            GetDefaultAdvice);
93 }
94 
#define DEBUG_TYPE "inline-ml"

// Once the running native-size estimate (CurrentIRSize) exceeds this multiple
// of the module's initial size, the advisor sets ForceStop and refuses
// further inlining (see onSuccessfulInlining / getAdviceImpl).
static cl::opt<float> SizeIncreaseThreshold(
    "ml-advisor-size-increase-threshold", cl::Hidden,
    cl::desc("Maximum factor by which expected native size may increase before "
             "blocking any further inlining."),
    cl::init(2.0));

// Testing hook: when set, onPassExit keeps the FunctionPropertiesInfo cache
// instead of clearing it, so its contents can be printed/inspected.
static cl::opt<bool> KeepFPICache(
    "ml-advisor-keep-fpi-cache", cl::Hidden,
    cl::desc(
        "For test - keep the ML Inline advisor's FunctionPropertiesInfo cache"),
    cl::init(false));
108 
// clang-format off
// Master list of the model's input tensors, one TensorSpec per feature. The
// InlineCost-derived features are placed first: getAdviceImpl indexes them
// through inlineCostFeatureToMlFeature, which maps InlineCostFeatureIndex
// onto the leading entries of this vector.
const std::vector<TensorSpec> llvm::FeatureMap{
#define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE),
// InlineCost features - these must come first
  INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)

// Non-cost features
  INLINE_FEATURE_ITERATOR(POPULATE_NAMES)
#undef POPULATE_NAMES
};
// clang-format on

// Names/specs of the decision output tensor, the optional default-decision
// feature (see getReleaseModeAdvisor), and the reward signal. RewardName is
// not referenced elsewhere in this file — presumably consumed by training
// tooling; confirm before renaming.
const char *const llvm::DecisionName = "inlining_decision";
const TensorSpec llvm::InlineDecisionSpec =
    TensorSpec::createSpec<int64_t>(DecisionName, {1});
const char *const llvm::DefaultDecisionName = "inlining_default";
const TensorSpec llvm::DefaultDecisionSpec =
    TensorSpec::createSpec<int64_t>(DefaultDecisionName, {1});
const char *const llvm::RewardName = "delta_size";
128 
129 CallBase *getInlinableCS(Instruction &I) {
130   if (auto *CS = dyn_cast<CallBase>(&I))
131     if (Function *Callee = CS->getCalledFunction()) {
132       if (!Callee->isDeclaration()) {
133         return CS;
134       }
135     }
136   return nullptr;
137 }
138 
// Constructor: wires up the model runner and seeds the module-wide state
// (function levels, node count, edge count, IR-size budget) that getAdviceImpl
// feeds to the model.
MLInlineAdvisor::MLInlineAdvisor(
    Module &M, ModuleAnalysisManager &MAM,
    std::unique_ptr<MLModelRunner> Runner,
    std::function<bool(CallBase &)> GetDefaultAdvice)
    : InlineAdvisor(
          M, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()),
      ModelRunner(std::move(Runner)), GetDefaultAdvice(GetDefaultAdvice),
      CG(MAM.getResult<LazyCallGraphAnalysis>(M)),
      InitialIRSize(getModuleIRSize()), CurrentIRSize(InitialIRSize),
      PSI(MAM.getResult<ProfileSummaryAnalysis>(M)) {
  assert(ModelRunner);
  ModelRunner->switchContext("");
  // Extract the 'call site height' feature - the position of a call site
  // relative to the farthest statically reachable SCC node. We don't mutate
  // this value while inlining happens. Empirically, this feature proved
  // critical in behavioral cloning - i.e. training a model to mimic the manual
  // heuristic's decisions - and, thus, equally important for training for
  // improvement.
  //
  // Levels are computed over the (non-lazy) CallGraph in bottom-up SCC order:
  // an SCC's level is one more than the highest level among callees already
  // visited, and every function in the SCC shares that level.
  CallGraph CGraph(M);
  for (auto I = scc_begin(&CGraph); !I.isAtEnd(); ++I) {
    const std::vector<CallGraphNode *> &CGNodes = *I;
    unsigned Level = 0;
    for (auto *CGNode : CGNodes) {
      Function *F = CGNode->getFunction();
      if (!F || F->isDeclaration())
        continue;
      for (auto &I : instructions(F)) {
        if (auto *CS = getInlinableCS(I)) {
          auto *Called = CS->getCalledFunction();
          auto Pos = FunctionLevels.find(&CG.get(*Called));
          // In bottom up traversal, an inlinable callee is either in the
          // same SCC, or to a function in a visited SCC. So not finding its
          // level means we haven't visited it yet, meaning it's in this SCC.
          if (Pos == FunctionLevels.end())
            continue;
          Level = std::max(Level, Pos->second + 1);
        }
      }
    }
    // All functions in this SCC get the same level.
    for (auto *CGNode : CGNodes) {
      Function *F = CGNode->getFunction();
      if (F && !F->isDeclaration())
        FunctionLevels[&CG.get(*F)] = Level;
    }
  }
  // Seed module-wide counters: each defined function is a node, and each
  // intra-module direct call contributes an edge.
  for (auto KVP : FunctionLevels) {
    AllNodes.insert(KVP.first);
    EdgeCount += getLocalCalls(KVP.first->getFunction());
  }
  NodeCount = AllNodes.size();
}
190 
191 unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const {
192   return CG.lookup(F) ? FunctionLevels.at(CG.lookup(F)) : 0;
193 }
194 
// Re-synchronize node/edge counts at the start of an inliner run over CurSCC,
// accounting for what function passes and SCC mutations did since onPassExit.
void MLInlineAdvisor::onPassEntry(LazyCallGraph::SCC *CurSCC) {
  if (!CurSCC || ForceStop)
    return;
  // Function passes executed since the last run may have changed any cached
  // FunctionPropertiesInfo; recompute lazily via getCachedFPI.
  FPICache.clear();
  // Function passes executed between InlinerPass runs may have changed the
  // module-wide features.
  // The cgscc pass manager rules are such that:
  // - if a pass leads to merging SCCs, then the pipeline is restarted on the
  // merged SCC
  // - if a pass leads to splitting the SCC, then we continue with one of the
  // splits
  // This means that the NodesInLastSCC is a superset (not strict) of the nodes
  // that subsequent passes would have processed
  // - in addition, if new Nodes were created by a pass (e.g. CoroSplit),
  // they'd be adjacent to Nodes in the last SCC. So we just need to check the
  // boundary of Nodes in NodesInLastSCC for Nodes we haven't seen. We don't
  // care about the nature of the Edge (call or ref). `FunctionLevels`-wise, we
  // record them at the same level as the original node (this is a choice, may
  // need revisiting).
  // - nodes are only deleted at the end of a call graph walk where they are
  // batch deleted, so we shouldn't see any dead nodes here.
  while (!NodesInLastSCC.empty()) {
    const auto *N = *NodesInLastSCC.begin();
    assert(!N->isDead());
    NodesInLastSCC.erase(N);
    // Re-count this surviving node's current outgoing local calls; the stale
    // count captured at onPassExit is subtracted wholesale below.
    EdgeCount += getLocalCalls(N->getFunction());
    const auto NLevel = FunctionLevels.at(N);
    for (const auto &E : *(*N)) {
      const auto *AdjNode = &E.getNode();
      assert(!AdjNode->isDead() && !AdjNode->getFunction().isDeclaration());
      auto I = AllNodes.insert(AdjNode);
      // We've discovered a new function.
      if (I.second) {
        ++NodeCount;
        NodesInLastSCC.insert(AdjNode);
        FunctionLevels[AdjNode] = NLevel;
      }
    }
  }

  // Replace the snapshot taken at the previous onPassExit with the fresh
  // per-node counts accumulated in the loop above.
  EdgeCount -= EdgesOfLastSeenNodes;
  EdgesOfLastSeenNodes = 0;

  // (Re)use NodesInLastSCC to remember the nodes in the SCC right now,
  // in case the SCC is split before onPassExit and some nodes are split out
  assert(NodesInLastSCC.empty());
  for (const auto &N : *CurSCC)
    NodesInLastSCC.insert(&N);
}
244 
// Snapshot, at the end of an inliner run, the nodes of CurSCC and their edge
// counts, so that the next onPassEntry can reconcile module-wide state.
void MLInlineAdvisor::onPassExit(LazyCallGraph::SCC *CurSCC) {
  // No need to keep this around - function passes will invalidate it.
  if (!KeepFPICache)
    FPICache.clear();
  if (!CurSCC || ForceStop)
    return;
  // Keep track of the nodes and edges we last saw. Then, in onPassEntry,
  // we update the node count and edge count from the subset of these nodes that
  // survived.
  EdgesOfLastSeenNodes = 0;

  // Check on nodes that were in SCC onPassEntry
  for (const LazyCallGraph::Node *N : NodesInLastSCC) {
    assert(!N->isDead());
    EdgesOfLastSeenNodes += getLocalCalls(N->getFunction());
  }

  // Check on nodes that may have got added to SCC
  for (const auto &N : *CurSCC) {
    assert(!N.isDead());
    // insert() returns {iterator, inserted}; only count a node's edges once.
    auto I = NodesInLastSCC.insert(&N);
    if (I.second)
      EdgesOfLastSeenNodes += getLocalCalls(N.getFunction());
  }
  assert(NodeCount >= NodesInLastSCC.size());
  assert(EdgeCount >= EdgesOfLastSeenNodes);
}
272 
273 int64_t MLInlineAdvisor::getLocalCalls(Function &F) {
274   return getCachedFPI(F).DirectCallsToDefinedFunctions;
275 }
276 
// Update the internal state of the advisor, and force invalidate feature
// analysis. Currently, we maintain minimal (and very simple) global state - the
// number of functions and the number of static calls. We also keep track of the
// total IR size in this module, to stop misbehaving policies at a certain bloat
// factor (SizeIncreaseThreshold)
void MLInlineAdvisor::onSuccessfulInlining(const MLInlineAdvice &Advice,
                                           bool CalleeWasDeleted) {
  assert(!ForceStop);
  Function *Caller = Advice.getCaller();
  Function *Callee = Advice.getCallee();
  // The caller features aren't valid anymore.
  {
    PreservedAnalyses PA = PreservedAnalyses::all();
    PA.abandon<FunctionPropertiesAnalysis>();
    PA.abandon<LoopAnalysis>();
    FAM.invalidate(*Caller, PA);
  }
  // Apply the incremental FunctionPropertiesUpdater set up when the advice
  // was created, refreshing the caller's cached FPI.
  Advice.updateCachedCallerFPI(FAM);
  // Post-inline size is the caller's new size plus, when the callee survived,
  // its (pre-inline) size; delta against the advice's snapshots feeds the
  // running module total.
  int64_t IRSizeAfter =
      getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize);
  CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize);
  if (CurrentIRSize > SizeIncreaseThreshold * InitialIRSize)
    ForceStop = true;

  // We can delta-update module-wide features. We know the inlining only changed
  // the caller, and maybe the callee (by deleting the latter).
  // Nodes are simple to update.
  // For edges, we 'forget' the edges that the caller and callee used to have
  // before inlining, and add back what they currently have together.
  int64_t NewCallerAndCalleeEdges =
      getCachedFPI(*Caller).DirectCallsToDefinedFunctions;

  // A dead function's node is not actually removed from the call graph until
  // the end of the call graph walk, but the node no longer belongs to any valid
  // SCC.
  if (CalleeWasDeleted) {
    --NodeCount;
    NodesInLastSCC.erase(CG.lookup(*Callee));
    DeadFunctions.insert(Callee);
  } else {
    NewCallerAndCalleeEdges +=
        getCachedFPI(*Callee).DirectCallsToDefinedFunctions;
  }
  EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges);
  assert(CurrentIRSize >= 0 && EdgeCount >= 0 && NodeCount >= 0);
}
323 
324 int64_t MLInlineAdvisor::getModuleIRSize() const {
325   int64_t Ret = 0;
326   for (auto &F : M)
327     if (!F.isDeclaration())
328       Ret += getIRSize(F);
329   return Ret;
330 }
331 
332 FunctionPropertiesInfo &MLInlineAdvisor::getCachedFPI(Function &F) const {
333   auto InsertPair =
334       FPICache.insert(std::make_pair(&F, FunctionPropertiesInfo()));
335   if (!InsertPair.second)
336     return InsertPair.first->second;
337   InsertPair.first->second = FAM.getResult<FunctionPropertiesAnalysis>(F);
338   return InsertPair.first->second;
339 }
340 
// Main entry point: decide whether to inline CB. The ordering of the early
// exits below is load-bearing - unreachable sites, skip-policy, mandatory
// kinds, forced stop, and correctness checks are all resolved before the
// model is consulted.
std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
  // Call sites in unreachable blocks never execute; skip them outright.
  if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB))
    return Skip;

  auto &Caller = *CB.getCaller();
  auto &Callee = *CB.getCalledFunction();

  auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
    return FAM.getResult<AssumptionAnalysis>(F);
  };
  auto &TIR = FAM.getResult<TargetIRAnalysis>(Callee);
  auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(Caller);

  // Optional bypass: for non-cold callers, defer to the default (heuristic)
  // advice instead of the model.
  if (SkipPolicy == SkipMLPolicyCriteria::IfCallerIsNotCold) {
    if (!PSI.isFunctionEntryCold(&Caller))
      return std::make_unique<InlineAdvice>(this, CB, ORE,
                                            GetDefaultAdvice(CB));
  }
  auto MandatoryKind = InlineAdvisor::getMandatoryKind(CB, FAM, ORE);
  // If this is a "never inline" case, there won't be any changes to internal
  // state we need to track, so we can just return the base InlineAdvice, which
  // will do nothing interesting.
  // Same thing if this is a recursive case.
  if (MandatoryKind == InlineAdvisor::MandatoryInliningKind::Never ||
      &Caller == &Callee)
    return getMandatoryAdvice(CB, false);

  bool Mandatory =
      MandatoryKind == InlineAdvisor::MandatoryInliningKind::Always;

  // If we need to stop, we won't want to track anymore any state changes, so
  // we just return the base InlineAdvice, which acts as a noop.
  if (ForceStop) {
    ORE.emit([&] {
      return OptimizationRemarkMissed(DEBUG_TYPE, "ForceStop", &CB)
             << "Won't attempt inlining because module size grew too much.";
    });
    return std::make_unique<InlineAdvice>(this, CB, ORE, Mandatory);
  }

  int CostEstimate = 0;
  if (!Mandatory) {
    auto IsCallSiteInlinable =
        llvm::getInliningCostEstimate(CB, TIR, GetAssumptionCache);
    if (!IsCallSiteInlinable) {
      // We can't inline this for correctness reasons, so return the base
      // InlineAdvice, as we don't care about tracking any state changes (which
      // won't happen).
      return std::make_unique<InlineAdvice>(this, CB, ORE, false);
    }
    CostEstimate = *IsCallSiteInlinable;
  }

  // Note: cost features are computed (and their failure treated as "don't
  // inline") even for mandatory sites, before the mandatory shortcut below.
  const auto CostFeatures =
      llvm::getInliningCostFeatures(CB, TIR, GetAssumptionCache);
  if (!CostFeatures) {
    return std::make_unique<InlineAdvice>(this, CB, ORE, false);
  }

  if (Mandatory)
    return getMandatoryAdvice(CB, true);

  // Count constant actual arguments - one of the model's input features.
  auto NumCtantParams = 0;
  for (auto I = CB.arg_begin(), E = CB.arg_end(); I != E; ++I) {
    NumCtantParams += (isa<Constant>(*I));
  }

  auto &CallerBefore = getCachedFPI(Caller);
  auto &CalleeBefore = getCachedFPI(Callee);

  // Populate the model's input tensors (see FeatureMap for the full list).
  *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_basic_block_count) =
      CalleeBefore.BasicBlockCount;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::callsite_height) =
      getInitialFunctionLevel(Caller);
  *ModelRunner->getTensor<int64_t>(FeatureIndex::node_count) = NodeCount;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::nr_ctant_params) =
      NumCtantParams;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::edge_count) = EdgeCount;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_users) =
      CallerBefore.Uses;
  *ModelRunner->getTensor<int64_t>(
      FeatureIndex::caller_conditionally_executed_blocks) =
      CallerBefore.BlocksReachedFromConditionalInstruction;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_basic_block_count) =
      CallerBefore.BasicBlockCount;
  *ModelRunner->getTensor<int64_t>(
      FeatureIndex::callee_conditionally_executed_blocks) =
      CalleeBefore.BlocksReachedFromConditionalInstruction;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_users) =
      CalleeBefore.Uses;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::cost_estimate) = CostEstimate;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::is_callee_avail_external) =
      Callee.hasAvailableExternallyLinkage();
  *ModelRunner->getTensor<int64_t>(FeatureIndex::is_caller_avail_external) =
      Caller.hasAvailableExternallyLinkage();

  // Add the cost features
  for (size_t I = 0;
       I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) {
    *ModelRunner->getTensor<int64_t>(inlineCostFeatureToMlFeature(
        static_cast<InlineCostFeatureIndex>(I))) = CostFeatures->at(I);
  }
  // This one would have been set up to be right at the end.
  // (getReleaseModeAdvisor appends DefaultDecisionSpec as the last feature
  // when both interactive mode and -inliner-interactive-include-default are
  // on, so its index equals NumberOfFeatures.)
  if (!InteractiveChannelBaseName.empty() && InteractiveIncludeDefault)
    *ModelRunner->getTensor<int64_t>(InlineCostFeatureIndex::NumberOfFeatures) =
        GetDefaultAdvice(CB);
  return getAdviceFromModel(CB, ORE);
}
449 
450 std::unique_ptr<MLInlineAdvice>
451 MLInlineAdvisor::getAdviceFromModel(CallBase &CB,
452                                     OptimizationRemarkEmitter &ORE) {
453   return std::make_unique<MLInlineAdvice>(
454       this, CB, ORE, static_cast<bool>(ModelRunner->evaluate<int64_t>()));
455 }
456 
457 std::unique_ptr<InlineAdvice>
458 MLInlineAdvisor::getSkipAdviceIfUnreachableCallsite(CallBase &CB) {
459   if (!FAM.getResult<DominatorTreeAnalysis>(*CB.getCaller())
460            .isReachableFromEntry(CB.getParent()))
461     return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false);
462   return nullptr;
463 }
464 
465 std::unique_ptr<InlineAdvice> MLInlineAdvisor::getMandatoryAdvice(CallBase &CB,
466                                                                   bool Advice) {
467   // Make sure we track inlinings in all cases - mandatory or not.
468   if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB))
469     return Skip;
470   if (Advice && !ForceStop)
471     return getMandatoryAdviceImpl(CB);
472 
473   // If this is a "never inline" case, there won't be any changes to internal
474   // state we need to track, so we can just return the base InlineAdvice, which
475   // will do nothing interesting.
476   // Same if we are forced to stop - we don't track anymore.
477   return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), Advice);
478 }
479 
480 std::unique_ptr<MLInlineAdvice>
481 MLInlineAdvisor::getMandatoryAdviceImpl(CallBase &CB) {
482   return std::make_unique<MLInlineAdvice>(this, CB, getCallerORE(CB), true);
483 }
484 
485 void MLInlineAdvisor::print(raw_ostream &OS) const {
486   OS << "[MLInlineAdvisor] Nodes: " << NodeCount << " Edges: " << EdgeCount
487      << " EdgesOfLastSeenNodes: " << EdgesOfLastSeenNodes << "\n";
488   OS << "[MLInlineAdvisor] FPI:\n";
489   for (auto I : FPICache) {
490     OS << I.first->getName() << ":\n";
491     I.second.print(OS);
492     OS << "\n";
493   }
494   OS << "\n";
495   OS << "[MLInlineAdvisor] FuncLevels:\n";
496   for (auto I : FunctionLevels)
497     OS << (DeadFunctions.contains(&I.first->getFunction())
498                ? "<deleted>"
499                : I.first->getFunction().getName())
500        << " : " << I.second << "\n";
501 
502   OS << "\n";
503 }
504 
// Snapshot pre-inline caller/callee state so onSuccessfulInlining can
// delta-update module-wide counters later. When the advisor has force-stopped
// the snapshots are zeroed - onSuccessfulInlining asserts !ForceStop, so they
// are never consumed in that case and the computation can be skipped.
MLInlineAdvice::MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
                               OptimizationRemarkEmitter &ORE,
                               bool Recommendation)
    : InlineAdvice(Advisor, CB, ORE, Recommendation),
      CallerIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Caller)),
      CalleeIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Callee)),
      CallerAndCalleeEdges(Advisor->isForcedToStop()
                               ? 0
                               : (Advisor->getLocalCalls(*Caller) +
                                  Advisor->getLocalCalls(*Callee))),
      PreInlineCallerFPI(Advisor->getCachedFPI(*Caller)) {
  // Only a positive recommendation can lead to an actual inlining, so only
  // then arm the incremental FunctionPropertiesUpdater for the caller.
  if (Recommendation)
    FPU.emplace(Advisor->getCachedFPI(*getCaller()), CB);
}
519 
520 void MLInlineAdvice::reportContextForRemark(
521     DiagnosticInfoOptimizationBase &OR) {
522   using namespace ore;
523   OR << NV("Callee", Callee->getName());
524   for (size_t I = 0; I < NumberOfFeatures; ++I)
525     OR << NV(FeatureMap[I].name(),
526              *getAdvisor()->getModelRunner().getTensor<int64_t>(I));
527   OR << NV("ShouldInline", isInliningRecommended());
528 }
529 
// Finalize the incremental FunctionPropertiesUpdater for the caller.
// Pre-condition: FPU was armed in the constructor (i.e. inlining was
// recommended); called from MLInlineAdvisor::onSuccessfulInlining.
void MLInlineAdvice::updateCachedCallerFPI(FunctionAnalysisManager &FAM) const {
  FPU->finish(FAM);
}
533 
534 void MLInlineAdvice::recordInliningImpl() {
535   ORE.emit([&]() {
536     OptimizationRemark R(DEBUG_TYPE, "InliningSuccess", DLoc, Block);
537     reportContextForRemark(R);
538     return R;
539   });
540   getAdvisor()->onSuccessfulInlining(*this, /*CalleeWasDeleted*/ false);
541 }
542 
543 void MLInlineAdvice::recordInliningWithCalleeDeletedImpl() {
544   ORE.emit([&]() {
545     OptimizationRemark R(DEBUG_TYPE, "InliningSuccessWithCalleeDeleted", DLoc,
546                          Block);
547     reportContextForRemark(R);
548     return R;
549   });
550   getAdvisor()->onSuccessfulInlining(*this, /*CalleeWasDeleted*/ true);
551 }
552 
553 void MLInlineAdvice::recordUnsuccessfulInliningImpl(
554     const InlineResult &Result) {
555   getAdvisor()->getCachedFPI(*Caller) = PreInlineCallerFPI;
556   ORE.emit([&]() {
557     OptimizationRemarkMissed R(DEBUG_TYPE, "InliningAttemptedAndUnsuccessful",
558                                DLoc, Block);
559     reportContextForRemark(R);
560     return R;
561   });
562 }
563 void MLInlineAdvice::recordUnattemptedInliningImpl() {
564   assert(!FPU);
565   ORE.emit([&]() {
566     OptimizationRemarkMissed R(DEBUG_TYPE, "IniningNotAttempted", DLoc, Block);
567     reportContextForRemark(R);
568     return R;
569   });
570 }
571