xref: /llvm-project/llvm/lib/Analysis/MLInlineAdvisor.cpp (revision f32e5bdcefcff80f4296f8f4abedc37dcda36d53)
//===- MLInlineAdvisor.cpp - machine learned InlineAdvisor ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the interface between the inliner and a learned model.
// It delegates model evaluation to either the AOT compiled model (the
// 'release' mode) or a runtime-loaded model (the 'development' case).
//
//===----------------------------------------------------------------------===//
14f29256a6SMircea Trofin #include "llvm/Analysis/MLInlineAdvisor.h"
15bdceefe9SMircea Trofin #include "llvm/ADT/SCCIterator.h"
1671c3a551Sserge-sans-paille #include "llvm/Analysis/AssumptionCache.h"
176037a698SMircea Trofin #include "llvm/Analysis/BlockFrequencyInfo.h"
18bdceefe9SMircea Trofin #include "llvm/Analysis/CallGraph.h"
19418121c3STarindu Jayatilaka #include "llvm/Analysis/FunctionPropertiesAnalysis.h"
20bdceefe9SMircea Trofin #include "llvm/Analysis/InlineCost.h"
21f29256a6SMircea Trofin #include "llvm/Analysis/InlineModelFeatureMaps.h"
225fd51fcbSMircea Trofin #include "llvm/Analysis/InteractiveModelRunner.h"
23248d55afSMircea Trofin #include "llvm/Analysis/LazyCallGraph.h"
24f46dd19bSMircea Trofin #include "llvm/Analysis/LoopInfo.h"
25bdceefe9SMircea Trofin #include "llvm/Analysis/MLModelRunner.h"
26bdceefe9SMircea Trofin #include "llvm/Analysis/OptimizationRemarkEmitter.h"
276037a698SMircea Trofin #include "llvm/Analysis/ProfileSummaryInfo.h"
285fd51fcbSMircea Trofin #include "llvm/Analysis/ReleaseModeModelRunner.h"
29bdceefe9SMircea Trofin #include "llvm/Analysis/TargetTransformInfo.h"
3022a1f998SMircea Trofin #include "llvm/IR/Dominators.h"
31bdceefe9SMircea Trofin #include "llvm/IR/InstIterator.h"
324169338eSNikita Popov #include "llvm/IR/Module.h"
33bdceefe9SMircea Trofin #include "llvm/IR/PassManager.h"
34bdceefe9SMircea Trofin #include "llvm/Support/CommandLine.h"
35f29256a6SMircea Trofin 
36bdceefe9SMircea Trofin using namespace llvm;
37bdceefe9SMircea Trofin 
385fd51fcbSMircea Trofin static cl::opt<std::string> InteractiveChannelBaseName(
395fd51fcbSMircea Trofin     "inliner-interactive-channel-base", cl::Hidden,
405fd51fcbSMircea Trofin     cl::desc(
415fd51fcbSMircea Trofin         "Base file path for the interactive mode. The incoming filename should "
425fd51fcbSMircea Trofin         "have the name <inliner-interactive-channel-base>.in, while the "
435fd51fcbSMircea Trofin         "outgoing name should be <inliner-interactive-channel-base>.out"));
44f3b5fca1SMircea Trofin static const std::string InclDefaultMsg =
45f3b5fca1SMircea Trofin     (Twine("In interactive mode, also send the default policy decision: ") +
46f3b5fca1SMircea Trofin      DefaultDecisionName + ".")
47f3b5fca1SMircea Trofin         .str();
48f3b5fca1SMircea Trofin static cl::opt<bool>
49f3b5fca1SMircea Trofin     InteractiveIncludeDefault("inliner-interactive-include-default", cl::Hidden,
50f3b5fca1SMircea Trofin                               cl::desc(InclDefaultMsg));
515fd51fcbSMircea Trofin 
526037a698SMircea Trofin enum class SkipMLPolicyCriteria { Never, IfCallerIsNotCold };
536037a698SMircea Trofin 
546037a698SMircea Trofin static cl::opt<SkipMLPolicyCriteria> SkipPolicy(
556037a698SMircea Trofin     "ml-inliner-skip-policy", cl::Hidden, cl::init(SkipMLPolicyCriteria::Never),
566037a698SMircea Trofin     cl::values(clEnumValN(SkipMLPolicyCriteria::Never, "never", "never"),
576037a698SMircea Trofin                clEnumValN(SkipMLPolicyCriteria::IfCallerIsNotCold,
586037a698SMircea Trofin                           "if-caller-not-cold", "if the caller is not cold")));
596037a698SMircea Trofin 
60313b1a82SMircea Trofin static cl::opt<std::string> ModelSelector("ml-inliner-model-selector",
61313b1a82SMircea Trofin                                           cl::Hidden, cl::init(""));
62313b1a82SMircea Trofin 
63b1af01feSMircea Trofin #if defined(LLVM_HAVE_TF_AOT_INLINERSIZEMODEL)
64db5aceb9SMircea Trofin // codegen-ed file
65db5aceb9SMircea Trofin #include "InlinerSizeModel.h" // NOLINT
665fd51fcbSMircea Trofin using CompiledModelType = llvm::InlinerSizeModel;
675fd51fcbSMircea Trofin #else
685fd51fcbSMircea Trofin using CompiledModelType = NoopSavedModelImpl;
695fd51fcbSMircea Trofin #endif
70db5aceb9SMircea Trofin 
71db5aceb9SMircea Trofin std::unique_ptr<InlineAdvisor>
72ab2e7666SMircea Trofin llvm::getReleaseModeAdvisor(Module &M, ModuleAnalysisManager &MAM,
73ab2e7666SMircea Trofin                             std::function<bool(CallBase &)> GetDefaultAdvice) {
745fd51fcbSMircea Trofin   if (!llvm::isEmbeddedModelEvaluatorValid<CompiledModelType>() &&
755fd51fcbSMircea Trofin       InteractiveChannelBaseName.empty())
765fd51fcbSMircea Trofin     return nullptr;
775fd51fcbSMircea Trofin   std::unique_ptr<MLModelRunner> AOTRunner;
785fd51fcbSMircea Trofin   if (InteractiveChannelBaseName.empty())
795fd51fcbSMircea Trofin     AOTRunner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>(
80313b1a82SMircea Trofin         M.getContext(), FeatureMap, DecisionName,
81313b1a82SMircea Trofin         EmbeddedModelRunnerOptions().setModelSelector(ModelSelector));
82ab2e7666SMircea Trofin   else {
83ab2e7666SMircea Trofin     auto Features = FeatureMap;
84ab2e7666SMircea Trofin     if (InteractiveIncludeDefault)
85ab2e7666SMircea Trofin       Features.push_back(DefaultDecisionSpec);
865fd51fcbSMircea Trofin     AOTRunner = std::make_unique<InteractiveModelRunner>(
87ab2e7666SMircea Trofin         M.getContext(), Features, InlineDecisionSpec,
885fd51fcbSMircea Trofin         InteractiveChannelBaseName + ".out",
895fd51fcbSMircea Trofin         InteractiveChannelBaseName + ".in");
90ab2e7666SMircea Trofin   }
91ab2e7666SMircea Trofin   return std::make_unique<MLInlineAdvisor>(M, MAM, std::move(AOTRunner),
92ab2e7666SMircea Trofin                                            GetDefaultAdvice);
93db5aceb9SMircea Trofin }
94db5aceb9SMircea Trofin 
95bdceefe9SMircea Trofin #define DEBUG_TYPE "inline-ml"
96bdceefe9SMircea Trofin 
97bdceefe9SMircea Trofin static cl::opt<float> SizeIncreaseThreshold(
98bdceefe9SMircea Trofin     "ml-advisor-size-increase-threshold", cl::Hidden,
99bdceefe9SMircea Trofin     cl::desc("Maximum factor by which expected native size may increase before "
100bdceefe9SMircea Trofin              "blocking any further inlining."),
101bdceefe9SMircea Trofin     cl::init(2.0));
102bdceefe9SMircea Trofin 
1037e7021caSMircea Trofin static cl::opt<bool> KeepFPICache(
1047e7021caSMircea Trofin     "ml-advisor-keep-fpi-cache", cl::Hidden,
1057e7021caSMircea Trofin     cl::desc(
1067e7021caSMircea Trofin         "For test - keep the ML Inline advisor's FunctionPropertiesInfo cache"),
1077e7021caSMircea Trofin     cl::init(false));
1087e7021caSMircea Trofin 
10999f00635SJacob Hegna // clang-format off
1105fd51fcbSMircea Trofin const std::vector<TensorSpec> llvm::FeatureMap{
111f9b3e341SJacob Hegna #define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE),
11299f00635SJacob Hegna // InlineCost features - these must come first
11399f00635SJacob Hegna   INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)
11499f00635SJacob Hegna 
11599f00635SJacob Hegna // Non-cost features
116bdceefe9SMircea Trofin   INLINE_FEATURE_ITERATOR(POPULATE_NAMES)
117bdceefe9SMircea Trofin #undef POPULATE_NAMES
118bdceefe9SMircea Trofin };
11999f00635SJacob Hegna // clang-format on
120bdceefe9SMircea Trofin 
121bdceefe9SMircea Trofin const char *const llvm::DecisionName = "inlining_decision";
1225fd51fcbSMircea Trofin const TensorSpec llvm::InlineDecisionSpec =
1235fd51fcbSMircea Trofin     TensorSpec::createSpec<int64_t>(DecisionName, {1});
124bdceefe9SMircea Trofin const char *const llvm::DefaultDecisionName = "inlining_default";
125ab2e7666SMircea Trofin const TensorSpec llvm::DefaultDecisionSpec =
126ab2e7666SMircea Trofin     TensorSpec::createSpec<int64_t>(DefaultDecisionName, {1});
127bdceefe9SMircea Trofin const char *const llvm::RewardName = "delta_size";
128bdceefe9SMircea Trofin 
129bdceefe9SMircea Trofin CallBase *getInlinableCS(Instruction &I) {
130bdceefe9SMircea Trofin   if (auto *CS = dyn_cast<CallBase>(&I))
131bdceefe9SMircea Trofin     if (Function *Callee = CS->getCalledFunction()) {
132bdceefe9SMircea Trofin       if (!Callee->isDeclaration()) {
133bdceefe9SMircea Trofin         return CS;
134bdceefe9SMircea Trofin       }
135bdceefe9SMircea Trofin     }
136bdceefe9SMircea Trofin   return nullptr;
137bdceefe9SMircea Trofin }
138bdceefe9SMircea Trofin 
139ab2e7666SMircea Trofin MLInlineAdvisor::MLInlineAdvisor(
140ab2e7666SMircea Trofin     Module &M, ModuleAnalysisManager &MAM,
141ab2e7666SMircea Trofin     std::unique_ptr<MLModelRunner> Runner,
142ab2e7666SMircea Trofin     std::function<bool(CallBase &)> GetDefaultAdvice)
143bdceefe9SMircea Trofin     : InlineAdvisor(
144ccec2cf1SMircea Trofin           M, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()),
145ab2e7666SMircea Trofin       ModelRunner(std::move(Runner)), GetDefaultAdvice(GetDefaultAdvice),
146248d55afSMircea Trofin       CG(MAM.getResult<LazyCallGraphAnalysis>(M)),
1476037a698SMircea Trofin       InitialIRSize(getModuleIRSize()), CurrentIRSize(InitialIRSize),
1486037a698SMircea Trofin       PSI(MAM.getResult<ProfileSummaryAnalysis>(M)) {
149bdceefe9SMircea Trofin   assert(ModelRunner);
1505fd51fcbSMircea Trofin   ModelRunner->switchContext("");
151bdceefe9SMircea Trofin   // Extract the 'call site height' feature - the position of a call site
152bdceefe9SMircea Trofin   // relative to the farthest statically reachable SCC node. We don't mutate
153bdceefe9SMircea Trofin   // this value while inlining happens. Empirically, this feature proved
154bdceefe9SMircea Trofin   // critical in behavioral cloning - i.e. training a model to mimic the manual
155bdceefe9SMircea Trofin   // heuristic's decisions - and, thus, equally important for training for
156bdceefe9SMircea Trofin   // improvement.
157248d55afSMircea Trofin   CallGraph CGraph(M);
158248d55afSMircea Trofin   for (auto I = scc_begin(&CGraph); !I.isAtEnd(); ++I) {
159bdceefe9SMircea Trofin     const std::vector<CallGraphNode *> &CGNodes = *I;
160bdceefe9SMircea Trofin     unsigned Level = 0;
161bdceefe9SMircea Trofin     for (auto *CGNode : CGNodes) {
162bdceefe9SMircea Trofin       Function *F = CGNode->getFunction();
163bdceefe9SMircea Trofin       if (!F || F->isDeclaration())
164bdceefe9SMircea Trofin         continue;
165bdceefe9SMircea Trofin       for (auto &I : instructions(F)) {
166bdceefe9SMircea Trofin         if (auto *CS = getInlinableCS(I)) {
167bdceefe9SMircea Trofin           auto *Called = CS->getCalledFunction();
168248d55afSMircea Trofin           auto Pos = FunctionLevels.find(&CG.get(*Called));
169bdceefe9SMircea Trofin           // In bottom up traversal, an inlinable callee is either in the
170bdceefe9SMircea Trofin           // same SCC, or to a function in a visited SCC. So not finding its
171bdceefe9SMircea Trofin           // level means we haven't visited it yet, meaning it's in this SCC.
172bdceefe9SMircea Trofin           if (Pos == FunctionLevels.end())
173bdceefe9SMircea Trofin             continue;
174bdceefe9SMircea Trofin           Level = std::max(Level, Pos->second + 1);
175bdceefe9SMircea Trofin         }
176bdceefe9SMircea Trofin       }
177bdceefe9SMircea Trofin     }
178bdceefe9SMircea Trofin     for (auto *CGNode : CGNodes) {
179bdceefe9SMircea Trofin       Function *F = CGNode->getFunction();
180bdceefe9SMircea Trofin       if (F && !F->isDeclaration())
181248d55afSMircea Trofin         FunctionLevels[&CG.get(*F)] = Level;
182bdceefe9SMircea Trofin     }
183bdceefe9SMircea Trofin   }
1843e8553aaSMircea Trofin   for (auto KVP : FunctionLevels) {
1853e8553aaSMircea Trofin     AllNodes.insert(KVP.first);
1863e8553aaSMircea Trofin     EdgeCount += getLocalCalls(KVP.first->getFunction());
1873e8553aaSMircea Trofin   }
1883e8553aaSMircea Trofin   NodeCount = AllNodes.size();
189bdceefe9SMircea Trofin }
190bdceefe9SMircea Trofin 
191248d55afSMircea Trofin unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const {
192248d55afSMircea Trofin   return CG.lookup(F) ? FunctionLevels.at(CG.lookup(F)) : 0;
193248d55afSMircea Trofin }
194248d55afSMircea Trofin 
1950555afd0SArthur Eubanks void MLInlineAdvisor::onPassEntry(LazyCallGraph::SCC *CurSCC) {
1960555afd0SArthur Eubanks   if (!CurSCC || ForceStop)
197a3a7826dSJin Xin Ng     return;
198f46dd19bSMircea Trofin   FPICache.clear();
199bdceefe9SMircea Trofin   // Function passes executed between InlinerPass runs may have changed the
200bdceefe9SMircea Trofin   // module-wide features.
2013e8553aaSMircea Trofin   // The cgscc pass manager rules are such that:
2023e8553aaSMircea Trofin   // - if a pass leads to merging SCCs, then the pipeline is restarted on the
2033e8553aaSMircea Trofin   // merged SCC
2043e8553aaSMircea Trofin   // - if a pass leads to splitting the SCC, then we continue with one of the
2053e8553aaSMircea Trofin   // splits
2063e8553aaSMircea Trofin   // This means that the NodesInLastSCC is a superset (not strict) of the nodes
2073e8553aaSMircea Trofin   // that subsequent passes would have processed
2083e8553aaSMircea Trofin   // - in addition, if new Nodes were created by a pass (e.g. CoroSplit),
2093e8553aaSMircea Trofin   // they'd be adjacent to Nodes in the last SCC. So we just need to check the
2103e8553aaSMircea Trofin   // boundary of Nodes in NodesInLastSCC for Nodes we haven't seen. We don't
2111b3fc405SMircea Trofin   // care about the nature of the Edge (call or ref). `FunctionLevels`-wise, we
2121b3fc405SMircea Trofin   // record them at the same level as the original node (this is a choice, may
2131b3fc405SMircea Trofin   // need revisiting).
21494471e6dSArthur Eubanks   // - nodes are only deleted at the end of a call graph walk where they are
21594471e6dSArthur Eubanks   // batch deleted, so we shouldn't see any dead nodes here.
2163e8553aaSMircea Trofin   while (!NodesInLastSCC.empty()) {
217aaff3fb6SJin Xin Ng     const auto *N = *NodesInLastSCC.begin();
21894471e6dSArthur Eubanks     assert(!N->isDead());
219aaff3fb6SJin Xin Ng     NodesInLastSCC.erase(N);
2203e8553aaSMircea Trofin     EdgeCount += getLocalCalls(N->getFunction());
2211b3fc405SMircea Trofin     const auto NLevel = FunctionLevels.at(N);
2223e8553aaSMircea Trofin     for (const auto &E : *(*N)) {
2233e8553aaSMircea Trofin       const auto *AdjNode = &E.getNode();
2243e8553aaSMircea Trofin       assert(!AdjNode->isDead() && !AdjNode->getFunction().isDeclaration());
2253e8553aaSMircea Trofin       auto I = AllNodes.insert(AdjNode);
22681f4fb65SArthur Eubanks       // We've discovered a new function.
2271b3fc405SMircea Trofin       if (I.second) {
22881f4fb65SArthur Eubanks         ++NodeCount;
229aaff3fb6SJin Xin Ng         NodesInLastSCC.insert(AdjNode);
2301b3fc405SMircea Trofin         FunctionLevels[AdjNode] = NLevel;
2311b3fc405SMircea Trofin       }
2323e8553aaSMircea Trofin     }
2333e8553aaSMircea Trofin   }
2343e8553aaSMircea Trofin 
2353e8553aaSMircea Trofin   EdgeCount -= EdgesOfLastSeenNodes;
2363e8553aaSMircea Trofin   EdgesOfLastSeenNodes = 0;
237aaff3fb6SJin Xin Ng 
238aaff3fb6SJin Xin Ng   // (Re)use NodesInLastSCC to remember the nodes in the SCC right now,
239aaff3fb6SJin Xin Ng   // in case the SCC is split before onPassExit and some nodes are split out
240aaff3fb6SJin Xin Ng   assert(NodesInLastSCC.empty());
2410555afd0SArthur Eubanks   for (const auto &N : *CurSCC)
242aaff3fb6SJin Xin Ng     NodesInLastSCC.insert(&N);
2433e8553aaSMircea Trofin }
2443e8553aaSMircea Trofin 
2450555afd0SArthur Eubanks void MLInlineAdvisor::onPassExit(LazyCallGraph::SCC *CurSCC) {
246f46dd19bSMircea Trofin   // No need to keep this around - function passes will invalidate it.
2477e7021caSMircea Trofin   if (!KeepFPICache)
248f46dd19bSMircea Trofin     FPICache.clear();
2490555afd0SArthur Eubanks   if (!CurSCC || ForceStop)
2503e8553aaSMircea Trofin     return;
2513e8553aaSMircea Trofin   // Keep track of the nodes and edges we last saw. Then, in onPassEntry,
2523e8553aaSMircea Trofin   // we update the node count and edge count from the subset of these nodes that
2533e8553aaSMircea Trofin   // survived.
2543e8553aaSMircea Trofin   EdgesOfLastSeenNodes = 0;
255aaff3fb6SJin Xin Ng 
256aaff3fb6SJin Xin Ng   // Check on nodes that were in SCC onPassEntry
25794471e6dSArthur Eubanks   for (const LazyCallGraph::Node *N : NodesInLastSCC) {
25894471e6dSArthur Eubanks     assert(!N->isDead());
25994471e6dSArthur Eubanks     EdgesOfLastSeenNodes += getLocalCalls(N->getFunction());
260aaff3fb6SJin Xin Ng   }
261aaff3fb6SJin Xin Ng 
262aaff3fb6SJin Xin Ng   // Check on nodes that may have got added to SCC
2630555afd0SArthur Eubanks   for (const auto &N : *CurSCC) {
2643e8553aaSMircea Trofin     assert(!N.isDead());
265aaff3fb6SJin Xin Ng     auto I = NodesInLastSCC.insert(&N);
266aaff3fb6SJin Xin Ng     if (I.second)
2673e8553aaSMircea Trofin       EdgesOfLastSeenNodes += getLocalCalls(N.getFunction());
2683e8553aaSMircea Trofin   }
269aaff3fb6SJin Xin Ng   assert(NodeCount >= NodesInLastSCC.size());
2703e8553aaSMircea Trofin   assert(EdgeCount >= EdgesOfLastSeenNodes);
271bdceefe9SMircea Trofin }
272bdceefe9SMircea Trofin 
273bdceefe9SMircea Trofin int64_t MLInlineAdvisor::getLocalCalls(Function &F) {
274f46dd19bSMircea Trofin   return getCachedFPI(F).DirectCallsToDefinedFunctions;
275bdceefe9SMircea Trofin }
276bdceefe9SMircea Trofin 
277bdceefe9SMircea Trofin // Update the internal state of the advisor, and force invalidate feature
278bdceefe9SMircea Trofin // analysis. Currently, we maintain minimal (and very simple) global state - the
279bdceefe9SMircea Trofin // number of functions and the number of static calls. We also keep track of the
280bdceefe9SMircea Trofin // total IR size in this module, to stop misbehaving policies at a certain bloat
281bdceefe9SMircea Trofin // factor (SizeIncreaseThreshold)
282bdceefe9SMircea Trofin void MLInlineAdvisor::onSuccessfulInlining(const MLInlineAdvice &Advice,
283bdceefe9SMircea Trofin                                            bool CalleeWasDeleted) {
284bdceefe9SMircea Trofin   assert(!ForceStop);
285bdceefe9SMircea Trofin   Function *Caller = Advice.getCaller();
286bdceefe9SMircea Trofin   Function *Callee = Advice.getCallee();
287bdceefe9SMircea Trofin   // The caller features aren't valid anymore.
2880d06b14fSMircea Trofin   {
2897ae92a69SMircea Trofin     PreservedAnalyses PA = PreservedAnalyses::all();
2907ae92a69SMircea Trofin     PA.abandon<FunctionPropertiesAnalysis>();
2917ae92a69SMircea Trofin     PA.abandon<LoopAnalysis>();
2920d06b14fSMircea Trofin     FAM.invalidate(*Caller, PA);
2930d06b14fSMircea Trofin   }
29422a1f998SMircea Trofin   Advice.updateCachedCallerFPI(FAM);
295bdceefe9SMircea Trofin   int64_t IRSizeAfter =
296bdceefe9SMircea Trofin       getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize);
297bdceefe9SMircea Trofin   CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize);
298bdceefe9SMircea Trofin   if (CurrentIRSize > SizeIncreaseThreshold * InitialIRSize)
299bdceefe9SMircea Trofin     ForceStop = true;
300bdceefe9SMircea Trofin 
301bdceefe9SMircea Trofin   // We can delta-update module-wide features. We know the inlining only changed
302bdceefe9SMircea Trofin   // the caller, and maybe the callee (by deleting the latter).
303bdceefe9SMircea Trofin   // Nodes are simple to update.
304bdceefe9SMircea Trofin   // For edges, we 'forget' the edges that the caller and callee used to have
305bdceefe9SMircea Trofin   // before inlining, and add back what they currently have together.
306bdceefe9SMircea Trofin   int64_t NewCallerAndCalleeEdges =
307f46dd19bSMircea Trofin       getCachedFPI(*Caller).DirectCallsToDefinedFunctions;
308bdceefe9SMircea Trofin 
30994471e6dSArthur Eubanks   // A dead function's node is not actually removed from the call graph until
31094471e6dSArthur Eubanks   // the end of the call graph walk, but the node no longer belongs to any valid
31194471e6dSArthur Eubanks   // SCC.
312ebdb6f4eSArthur Eubanks   if (CalleeWasDeleted) {
313bdceefe9SMircea Trofin     --NodeCount;
31494471e6dSArthur Eubanks     NodesInLastSCC.erase(CG.lookup(*Callee));
315ebdb6f4eSArthur Eubanks     DeadFunctions.insert(Callee);
316ebdb6f4eSArthur Eubanks   } else {
317418121c3STarindu Jayatilaka     NewCallerAndCalleeEdges +=
318f46dd19bSMircea Trofin         getCachedFPI(*Callee).DirectCallsToDefinedFunctions;
319ebdb6f4eSArthur Eubanks   }
320bdceefe9SMircea Trofin   EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges);
321bdceefe9SMircea Trofin   assert(CurrentIRSize >= 0 && EdgeCount >= 0 && NodeCount >= 0);
322bdceefe9SMircea Trofin }
323bdceefe9SMircea Trofin 
324bdceefe9SMircea Trofin int64_t MLInlineAdvisor::getModuleIRSize() const {
325bdceefe9SMircea Trofin   int64_t Ret = 0;
326248d55afSMircea Trofin   for (auto &F : M)
327bdceefe9SMircea Trofin     if (!F.isDeclaration())
328bdceefe9SMircea Trofin       Ret += getIRSize(F);
329bdceefe9SMircea Trofin   return Ret;
330bdceefe9SMircea Trofin }
331bdceefe9SMircea Trofin 
332f46dd19bSMircea Trofin FunctionPropertiesInfo &MLInlineAdvisor::getCachedFPI(Function &F) const {
333f46dd19bSMircea Trofin   auto InsertPair =
334f46dd19bSMircea Trofin       FPICache.insert(std::make_pair(&F, FunctionPropertiesInfo()));
335f46dd19bSMircea Trofin   if (!InsertPair.second)
336f46dd19bSMircea Trofin     return InsertPair.first->second;
337f46dd19bSMircea Trofin   InsertPair.first->second = FAM.getResult<FunctionPropertiesAnalysis>(F);
338f46dd19bSMircea Trofin   return InsertPair.first->second;
339f46dd19bSMircea Trofin }
340f46dd19bSMircea Trofin 
341e8049dc3SMircea Trofin std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
3427f24e574SMircea Trofin   if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB))
3437f24e574SMircea Trofin     return Skip;
3447f24e574SMircea Trofin 
345bdceefe9SMircea Trofin   auto &Caller = *CB.getCaller();
346bdceefe9SMircea Trofin   auto &Callee = *CB.getCalledFunction();
347bdceefe9SMircea Trofin 
348bdceefe9SMircea Trofin   auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
349bdceefe9SMircea Trofin     return FAM.getResult<AssumptionAnalysis>(F);
350bdceefe9SMircea Trofin   };
351bdceefe9SMircea Trofin   auto &TIR = FAM.getResult<TargetIRAnalysis>(Callee);
352bdceefe9SMircea Trofin   auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(Caller);
353bdceefe9SMircea Trofin 
3546037a698SMircea Trofin   if (SkipPolicy == SkipMLPolicyCriteria::IfCallerIsNotCold) {
3556037a698SMircea Trofin     if (!PSI.isFunctionEntryCold(&Caller))
3566037a698SMircea Trofin       return std::make_unique<InlineAdvice>(this, CB, ORE,
3576037a698SMircea Trofin                                             GetDefaultAdvice(CB));
3586037a698SMircea Trofin   }
359e8049dc3SMircea Trofin   auto MandatoryKind = InlineAdvisor::getMandatoryKind(CB, FAM, ORE);
360bdceefe9SMircea Trofin   // If this is a "never inline" case, there won't be any changes to internal
361bdceefe9SMircea Trofin   // state we need to track, so we can just return the base InlineAdvice, which
362bdceefe9SMircea Trofin   // will do nothing interesting.
363bdceefe9SMircea Trofin   // Same thing if this is a recursive case.
364e8049dc3SMircea Trofin   if (MandatoryKind == InlineAdvisor::MandatoryInliningKind::Never ||
365bdceefe9SMircea Trofin       &Caller == &Callee)
366e8049dc3SMircea Trofin     return getMandatoryAdvice(CB, false);
367bdceefe9SMircea Trofin 
3685fe10263SMircea Trofin   bool Mandatory =
369e8049dc3SMircea Trofin       MandatoryKind == InlineAdvisor::MandatoryInliningKind::Always;
370bdceefe9SMircea Trofin 
371bdceefe9SMircea Trofin   // If we need to stop, we won't want to track anymore any state changes, so
372bdceefe9SMircea Trofin   // we just return the base InlineAdvice, which acts as a noop.
373bdceefe9SMircea Trofin   if (ForceStop) {
374bdceefe9SMircea Trofin     ORE.emit([&] {
375bdceefe9SMircea Trofin       return OptimizationRemarkMissed(DEBUG_TYPE, "ForceStop", &CB)
376bdceefe9SMircea Trofin              << "Won't attempt inlining because module size grew too much.";
377bdceefe9SMircea Trofin     });
378bdceefe9SMircea Trofin     return std::make_unique<InlineAdvice>(this, CB, ORE, Mandatory);
379bdceefe9SMircea Trofin   }
380bdceefe9SMircea Trofin 
381bdceefe9SMircea Trofin   int CostEstimate = 0;
382bdceefe9SMircea Trofin   if (!Mandatory) {
383bdceefe9SMircea Trofin     auto IsCallSiteInlinable =
384bdceefe9SMircea Trofin         llvm::getInliningCostEstimate(CB, TIR, GetAssumptionCache);
385bdceefe9SMircea Trofin     if (!IsCallSiteInlinable) {
386bdceefe9SMircea Trofin       // We can't inline this for correctness reasons, so return the base
387bdceefe9SMircea Trofin       // InlineAdvice, as we don't care about tracking any state changes (which
388bdceefe9SMircea Trofin       // won't happen).
389bdceefe9SMircea Trofin       return std::make_unique<InlineAdvice>(this, CB, ORE, false);
390bdceefe9SMircea Trofin     }
391bdceefe9SMircea Trofin     CostEstimate = *IsCallSiteInlinable;
392bdceefe9SMircea Trofin   }
393bdceefe9SMircea Trofin 
39499f00635SJacob Hegna   const auto CostFeatures =
39599f00635SJacob Hegna       llvm::getInliningCostFeatures(CB, TIR, GetAssumptionCache);
39699f00635SJacob Hegna   if (!CostFeatures) {
39799f00635SJacob Hegna     return std::make_unique<InlineAdvice>(this, CB, ORE, false);
39899f00635SJacob Hegna   }
39999f00635SJacob Hegna 
400bdceefe9SMircea Trofin   if (Mandatory)
401e8049dc3SMircea Trofin     return getMandatoryAdvice(CB, true);
402bdceefe9SMircea Trofin 
403*f32e5bdcSMircea Trofin   auto NumCtantParams = 0;
404bdceefe9SMircea Trofin   for (auto I = CB.arg_begin(), E = CB.arg_end(); I != E; ++I) {
405*f32e5bdcSMircea Trofin     NumCtantParams += (isa<Constant>(*I));
406bdceefe9SMircea Trofin   }
407bdceefe9SMircea Trofin 
408f46dd19bSMircea Trofin   auto &CallerBefore = getCachedFPI(Caller);
409f46dd19bSMircea Trofin   auto &CalleeBefore = getCachedFPI(Callee);
410bdceefe9SMircea Trofin 
411f9b3e341SJacob Hegna   *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_basic_block_count) =
412059e0347SMircea Trofin       CalleeBefore.BasicBlockCount;
413f9b3e341SJacob Hegna   *ModelRunner->getTensor<int64_t>(FeatureIndex::callsite_height) =
414248d55afSMircea Trofin       getInitialFunctionLevel(Caller);
415f9b3e341SJacob Hegna   *ModelRunner->getTensor<int64_t>(FeatureIndex::node_count) = NodeCount;
416f9b3e341SJacob Hegna   *ModelRunner->getTensor<int64_t>(FeatureIndex::nr_ctant_params) =
417*f32e5bdcSMircea Trofin       NumCtantParams;
418f9b3e341SJacob Hegna   *ModelRunner->getTensor<int64_t>(FeatureIndex::edge_count) = EdgeCount;
419f9b3e341SJacob Hegna   *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_users) =
420059e0347SMircea Trofin       CallerBefore.Uses;
421059e0347SMircea Trofin   *ModelRunner->getTensor<int64_t>(
422f9b3e341SJacob Hegna       FeatureIndex::caller_conditionally_executed_blocks) =
423059e0347SMircea Trofin       CallerBefore.BlocksReachedFromConditionalInstruction;
424f9b3e341SJacob Hegna   *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_basic_block_count) =
425059e0347SMircea Trofin       CallerBefore.BasicBlockCount;
426059e0347SMircea Trofin   *ModelRunner->getTensor<int64_t>(
427f9b3e341SJacob Hegna       FeatureIndex::callee_conditionally_executed_blocks) =
428059e0347SMircea Trofin       CalleeBefore.BlocksReachedFromConditionalInstruction;
429f9b3e341SJacob Hegna   *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_users) =
430059e0347SMircea Trofin       CalleeBefore.Uses;
431f9b3e341SJacob Hegna   *ModelRunner->getTensor<int64_t>(FeatureIndex::cost_estimate) = CostEstimate;
432600ff287SMircea Trofin   *ModelRunner->getTensor<int64_t>(FeatureIndex::is_callee_avail_external) =
433600ff287SMircea Trofin       Callee.hasAvailableExternallyLinkage();
434600ff287SMircea Trofin   *ModelRunner->getTensor<int64_t>(FeatureIndex::is_caller_avail_external) =
435600ff287SMircea Trofin       Caller.hasAvailableExternallyLinkage();
43699f00635SJacob Hegna 
43799f00635SJacob Hegna   // Add the cost features
43899f00635SJacob Hegna   for (size_t I = 0;
43999f00635SJacob Hegna        I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) {
440059e0347SMircea Trofin     *ModelRunner->getTensor<int64_t>(inlineCostFeatureToMlFeature(
441059e0347SMircea Trofin         static_cast<InlineCostFeatureIndex>(I))) = CostFeatures->at(I);
44299f00635SJacob Hegna   }
443ab2e7666SMircea Trofin   // This one would have been set up to be right at the end.
444ab2e7666SMircea Trofin   if (!InteractiveChannelBaseName.empty() && InteractiveIncludeDefault)
445ab2e7666SMircea Trofin     *ModelRunner->getTensor<int64_t>(InlineCostFeatureIndex::NumberOfFeatures) =
446ab2e7666SMircea Trofin         GetDefaultAdvice(CB);
447bdceefe9SMircea Trofin   return getAdviceFromModel(CB, ORE);
448bdceefe9SMircea Trofin }
449bdceefe9SMircea Trofin 
450bdceefe9SMircea Trofin std::unique_ptr<MLInlineAdvice>
451bdceefe9SMircea Trofin MLInlineAdvisor::getAdviceFromModel(CallBase &CB,
452bdceefe9SMircea Trofin                                     OptimizationRemarkEmitter &ORE) {
453059e0347SMircea Trofin   return std::make_unique<MLInlineAdvice>(
454059e0347SMircea Trofin       this, CB, ORE, static_cast<bool>(ModelRunner->evaluate<int64_t>()));
455bdceefe9SMircea Trofin }
456bdceefe9SMircea Trofin 
4577f24e574SMircea Trofin std::unique_ptr<InlineAdvice>
4587f24e574SMircea Trofin MLInlineAdvisor::getSkipAdviceIfUnreachableCallsite(CallBase &CB) {
4597f24e574SMircea Trofin   if (!FAM.getResult<DominatorTreeAnalysis>(*CB.getCaller())
4607f24e574SMircea Trofin            .isReachableFromEntry(CB.getParent()))
4617f24e574SMircea Trofin     return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false);
4627f24e574SMircea Trofin   return nullptr;
4637f24e574SMircea Trofin }
4647f24e574SMircea Trofin 
465e8049dc3SMircea Trofin std::unique_ptr<InlineAdvice> MLInlineAdvisor::getMandatoryAdvice(CallBase &CB,
466e8049dc3SMircea Trofin                                                                   bool Advice) {
467e8049dc3SMircea Trofin   // Make sure we track inlinings in all cases - mandatory or not.
4687f24e574SMircea Trofin   if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB))
4697f24e574SMircea Trofin     return Skip;
470e8049dc3SMircea Trofin   if (Advice && !ForceStop)
471e8049dc3SMircea Trofin     return getMandatoryAdviceImpl(CB);
472e8049dc3SMircea Trofin 
473e8049dc3SMircea Trofin   // If this is a "never inline" case, there won't be any changes to internal
474e8049dc3SMircea Trofin   // state we need to track, so we can just return the base InlineAdvice, which
475e8049dc3SMircea Trofin   // will do nothing interesting.
476e8049dc3SMircea Trofin   // Same if we are forced to stop - we don't track anymore.
477e8049dc3SMircea Trofin   return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), Advice);
478e8049dc3SMircea Trofin }
479e8049dc3SMircea Trofin 
480bdceefe9SMircea Trofin std::unique_ptr<MLInlineAdvice>
481e8049dc3SMircea Trofin MLInlineAdvisor::getMandatoryAdviceImpl(CallBase &CB) {
482e8049dc3SMircea Trofin   return std::make_unique<MLInlineAdvice>(this, CB, getCallerORE(CB), true);
483bdceefe9SMircea Trofin }
484bdceefe9SMircea Trofin 
4857e7021caSMircea Trofin void MLInlineAdvisor::print(raw_ostream &OS) const {
4867e7021caSMircea Trofin   OS << "[MLInlineAdvisor] Nodes: " << NodeCount << " Edges: " << EdgeCount
487aaff3fb6SJin Xin Ng      << " EdgesOfLastSeenNodes: " << EdgesOfLastSeenNodes << "\n";
4887e7021caSMircea Trofin   OS << "[MLInlineAdvisor] FPI:\n";
4897e7021caSMircea Trofin   for (auto I : FPICache) {
4905617fb14SMircea Trofin     OS << I.first->getName() << ":\n";
4915617fb14SMircea Trofin     I.second.print(OS);
4927e7021caSMircea Trofin     OS << "\n";
4937e7021caSMircea Trofin   }
4947e7021caSMircea Trofin   OS << "\n";
4951b3fc405SMircea Trofin   OS << "[MLInlineAdvisor] FuncLevels:\n";
4961b3fc405SMircea Trofin   for (auto I : FunctionLevels)
497ebdb6f4eSArthur Eubanks     OS << (DeadFunctions.contains(&I.first->getFunction())
498ebdb6f4eSArthur Eubanks                ? "<deleted>"
499ebdb6f4eSArthur Eubanks                : I.first->getFunction().getName())
5001b3fc405SMircea Trofin        << " : " << I.second << "\n";
5011b3fc405SMircea Trofin 
5021b3fc405SMircea Trofin   OS << "\n";
5037e7021caSMircea Trofin }
5047e7021caSMircea Trofin 
505f46dd19bSMircea Trofin MLInlineAdvice::MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
506f46dd19bSMircea Trofin                                OptimizationRemarkEmitter &ORE,
507f46dd19bSMircea Trofin                                bool Recommendation)
508f46dd19bSMircea Trofin     : InlineAdvice(Advisor, CB, ORE, Recommendation),
509f46dd19bSMircea Trofin       CallerIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Caller)),
510f46dd19bSMircea Trofin       CalleeIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Callee)),
511f46dd19bSMircea Trofin       CallerAndCalleeEdges(Advisor->isForcedToStop()
512f46dd19bSMircea Trofin                                ? 0
513f46dd19bSMircea Trofin                                : (Advisor->getLocalCalls(*Caller) +
514f46dd19bSMircea Trofin                                   Advisor->getLocalCalls(*Callee))),
515f46dd19bSMircea Trofin       PreInlineCallerFPI(Advisor->getCachedFPI(*Caller)) {
516f46dd19bSMircea Trofin   if (Recommendation)
517f46dd19bSMircea Trofin     FPU.emplace(Advisor->getCachedFPI(*getCaller()), CB);
518f46dd19bSMircea Trofin }
519f46dd19bSMircea Trofin 
520bdceefe9SMircea Trofin void MLInlineAdvice::reportContextForRemark(
521bdceefe9SMircea Trofin     DiagnosticInfoOptimizationBase &OR) {
522bdceefe9SMircea Trofin   using namespace ore;
523bdceefe9SMircea Trofin   OR << NV("Callee", Callee->getName());
524bdceefe9SMircea Trofin   for (size_t I = 0; I < NumberOfFeatures; ++I)
525c35ad9eeSMircea Trofin     OR << NV(FeatureMap[I].name(),
526059e0347SMircea Trofin              *getAdvisor()->getModelRunner().getTensor<int64_t>(I));
527bdceefe9SMircea Trofin   OR << NV("ShouldInline", isInliningRecommended());
528bdceefe9SMircea Trofin }
529bdceefe9SMircea Trofin 
53022a1f998SMircea Trofin void MLInlineAdvice::updateCachedCallerFPI(FunctionAnalysisManager &FAM) const {
53122a1f998SMircea Trofin   FPU->finish(FAM);
532f46dd19bSMircea Trofin }
533f46dd19bSMircea Trofin 
534bdceefe9SMircea Trofin void MLInlineAdvice::recordInliningImpl() {
535bdceefe9SMircea Trofin   ORE.emit([&]() {
536bdceefe9SMircea Trofin     OptimizationRemark R(DEBUG_TYPE, "InliningSuccess", DLoc, Block);
537bdceefe9SMircea Trofin     reportContextForRemark(R);
538bdceefe9SMircea Trofin     return R;
539bdceefe9SMircea Trofin   });
540bdceefe9SMircea Trofin   getAdvisor()->onSuccessfulInlining(*this, /*CalleeWasDeleted*/ false);
541bdceefe9SMircea Trofin }
542bdceefe9SMircea Trofin 
543bdceefe9SMircea Trofin void MLInlineAdvice::recordInliningWithCalleeDeletedImpl() {
544bdceefe9SMircea Trofin   ORE.emit([&]() {
545bdceefe9SMircea Trofin     OptimizationRemark R(DEBUG_TYPE, "InliningSuccessWithCalleeDeleted", DLoc,
546bdceefe9SMircea Trofin                          Block);
547bdceefe9SMircea Trofin     reportContextForRemark(R);
548bdceefe9SMircea Trofin     return R;
549bdceefe9SMircea Trofin   });
550bdceefe9SMircea Trofin   getAdvisor()->onSuccessfulInlining(*this, /*CalleeWasDeleted*/ true);
551bdceefe9SMircea Trofin }
552bdceefe9SMircea Trofin 
553bdceefe9SMircea Trofin void MLInlineAdvice::recordUnsuccessfulInliningImpl(
554bdceefe9SMircea Trofin     const InlineResult &Result) {
555f46dd19bSMircea Trofin   getAdvisor()->getCachedFPI(*Caller) = PreInlineCallerFPI;
556bdceefe9SMircea Trofin   ORE.emit([&]() {
557bdceefe9SMircea Trofin     OptimizationRemarkMissed R(DEBUG_TYPE, "InliningAttemptedAndUnsuccessful",
558bdceefe9SMircea Trofin                                DLoc, Block);
559bdceefe9SMircea Trofin     reportContextForRemark(R);
560bdceefe9SMircea Trofin     return R;
561bdceefe9SMircea Trofin   });
562bdceefe9SMircea Trofin }
563bdceefe9SMircea Trofin void MLInlineAdvice::recordUnattemptedInliningImpl() {
564f46dd19bSMircea Trofin   assert(!FPU);
565bdceefe9SMircea Trofin   ORE.emit([&]() {
566bdceefe9SMircea Trofin     OptimizationRemarkMissed R(DEBUG_TYPE, "IniningNotAttempted", DLoc, Block);
567bdceefe9SMircea Trofin     reportContextForRemark(R);
568bdceefe9SMircea Trofin     return R;
569bdceefe9SMircea Trofin   });
570bdceefe9SMircea Trofin }
571