15ffd83dbSDimitry Andric //===- MLInlineAdvisor.cpp - machine learned InlineAdvisor ----------------===// 25ffd83dbSDimitry Andric // 35ffd83dbSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 45ffd83dbSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 55ffd83dbSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65ffd83dbSDimitry Andric // 75ffd83dbSDimitry Andric //===----------------------------------------------------------------------===// 85ffd83dbSDimitry Andric // 95ffd83dbSDimitry Andric // This file implements the interface between the inliner and a learned model. 105ffd83dbSDimitry Andric // It delegates model evaluation to either the AOT compiled model (the 115ffd83dbSDimitry Andric // 'release' mode) or a runtime-loaded model (the 'development' case). 125ffd83dbSDimitry Andric // 135ffd83dbSDimitry Andric //===----------------------------------------------------------------------===// 1404eeddc0SDimitry Andric #include "llvm/Analysis/MLInlineAdvisor.h" 155ffd83dbSDimitry Andric #include "llvm/ADT/SCCIterator.h" 1681ad6265SDimitry Andric #include "llvm/Analysis/AssumptionCache.h" 17*0fca6ea1SDimitry Andric #include "llvm/Analysis/BlockFrequencyInfo.h" 185ffd83dbSDimitry Andric #include "llvm/Analysis/CallGraph.h" 19e8d8bef9SDimitry Andric #include "llvm/Analysis/FunctionPropertiesAnalysis.h" 205ffd83dbSDimitry Andric #include "llvm/Analysis/InlineCost.h" 2104eeddc0SDimitry Andric #include "llvm/Analysis/InlineModelFeatureMaps.h" 2206c3fb27SDimitry Andric #include "llvm/Analysis/InteractiveModelRunner.h" 2304eeddc0SDimitry Andric #include "llvm/Analysis/LazyCallGraph.h" 2481ad6265SDimitry Andric #include "llvm/Analysis/LoopInfo.h" 255ffd83dbSDimitry Andric #include "llvm/Analysis/MLModelRunner.h" 265ffd83dbSDimitry Andric #include "llvm/Analysis/OptimizationRemarkEmitter.h" 27*0fca6ea1SDimitry Andric #include "llvm/Analysis/ProfileSummaryInfo.h" 2806c3fb27SDimitry Andric #include "llvm/Analysis/ReleaseModeModelRunner.h" 295ffd83dbSDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h" 3081ad6265SDimitry Andric #include "llvm/IR/Dominators.h" 315ffd83dbSDimitry Andric #include "llvm/IR/InstIterator.h" 32*0fca6ea1SDimitry Andric #include "llvm/IR/Module.h" 335ffd83dbSDimitry Andric #include "llvm/IR/PassManager.h" 345ffd83dbSDimitry Andric #include "llvm/Support/CommandLine.h" 3504eeddc0SDimitry Andric 365ffd83dbSDimitry Andric using namespace llvm; 375ffd83dbSDimitry Andric 3806c3fb27SDimitry Andric static cl::opt<std::string> InteractiveChannelBaseName( 3906c3fb27SDimitry Andric "inliner-interactive-channel-base", cl::Hidden, 4006c3fb27SDimitry Andric cl::desc( 4106c3fb27SDimitry Andric "Base file path for the interactive mode. The incoming filename should " 4206c3fb27SDimitry Andric "have the name <inliner-interactive-channel-base>.in, while the " 4306c3fb27SDimitry Andric "outgoing name should be <inliner-interactive-channel-base>.out")); 4406c3fb27SDimitry Andric static const std::string InclDefaultMsg = 4506c3fb27SDimitry Andric (Twine("In interactive mode, also send the default policy decision: ") + 4606c3fb27SDimitry Andric DefaultDecisionName + ".") 4706c3fb27SDimitry Andric .str(); 4806c3fb27SDimitry Andric static cl::opt<bool> 4906c3fb27SDimitry Andric InteractiveIncludeDefault("inliner-interactive-include-default", cl::Hidden, 5006c3fb27SDimitry Andric cl::desc(InclDefaultMsg)); 5106c3fb27SDimitry Andric 52*0fca6ea1SDimitry Andric enum class SkipMLPolicyCriteria { Never, IfCallerIsNotCold }; 53*0fca6ea1SDimitry Andric 54*0fca6ea1SDimitry Andric static cl::opt<SkipMLPolicyCriteria> SkipPolicy( 55*0fca6ea1SDimitry Andric "ml-inliner-skip-policy", cl::Hidden, cl::init(SkipMLPolicyCriteria::Never), 56*0fca6ea1SDimitry Andric cl::values(clEnumValN(SkipMLPolicyCriteria::Never, "never", "never"), 57*0fca6ea1SDimitry Andric clEnumValN(SkipMLPolicyCriteria::IfCallerIsNotCold, 58*0fca6ea1SDimitry Andric "if-caller-not-cold", "if the caller is not cold"))); 59*0fca6ea1SDimitry Andric 60*0fca6ea1SDimitry Andric static cl::opt<std::string> ModelSelector("ml-inliner-model-selector", 61*0fca6ea1SDimitry Andric cl::Hidden, cl::init("")); 62*0fca6ea1SDimitry Andric 6304eeddc0SDimitry Andric #if defined(LLVM_HAVE_TF_AOT_INLINERSIZEMODEL) 640eae32dcSDimitry Andric // codegen-ed file 650eae32dcSDimitry Andric #include "InlinerSizeModel.h" // NOLINT 6606c3fb27SDimitry Andric using CompiledModelType = llvm::InlinerSizeModel; 6706c3fb27SDimitry Andric #else 6806c3fb27SDimitry Andric using CompiledModelType = NoopSavedModelImpl; 6906c3fb27SDimitry Andric #endif 700eae32dcSDimitry Andric 710eae32dcSDimitry Andric std::unique_ptr<InlineAdvisor> 7206c3fb27SDimitry Andric llvm::getReleaseModeAdvisor(Module &M, ModuleAnalysisManager &MAM, 7306c3fb27SDimitry Andric std::function<bool(CallBase &)> GetDefaultAdvice) { 7406c3fb27SDimitry Andric if (!llvm::isEmbeddedModelEvaluatorValid<CompiledModelType>() && 7506c3fb27SDimitry Andric InteractiveChannelBaseName.empty()) 7606c3fb27SDimitry Andric return nullptr; 7706c3fb27SDimitry Andric std::unique_ptr<MLModelRunner> AOTRunner; 7806c3fb27SDimitry Andric if (InteractiveChannelBaseName.empty()) 7906c3fb27SDimitry Andric AOTRunner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>( 80*0fca6ea1SDimitry Andric M.getContext(), FeatureMap, DecisionName, 81*0fca6ea1SDimitry Andric EmbeddedModelRunnerOptions().setModelSelector(ModelSelector)); 8206c3fb27SDimitry Andric else { 8306c3fb27SDimitry Andric auto Features = FeatureMap; 8406c3fb27SDimitry Andric if (InteractiveIncludeDefault) 8506c3fb27SDimitry Andric Features.push_back(DefaultDecisionSpec); 8606c3fb27SDimitry Andric AOTRunner = std::make_unique<InteractiveModelRunner>( 8706c3fb27SDimitry Andric M.getContext(), Features, InlineDecisionSpec, 8806c3fb27SDimitry Andric InteractiveChannelBaseName + ".out", 8906c3fb27SDimitry Andric InteractiveChannelBaseName + ".in"); 900eae32dcSDimitry Andric } 9106c3fb27SDimitry Andric return std::make_unique<MLInlineAdvisor>(M, MAM, std::move(AOTRunner), 9206c3fb27SDimitry Andric GetDefaultAdvice); 9306c3fb27SDimitry Andric } 940eae32dcSDimitry Andric 955ffd83dbSDimitry Andric #define DEBUG_TYPE "inline-ml" 965ffd83dbSDimitry Andric 975ffd83dbSDimitry Andric static cl::opt<float> SizeIncreaseThreshold( 985ffd83dbSDimitry Andric "ml-advisor-size-increase-threshold", cl::Hidden, 995ffd83dbSDimitry Andric cl::desc("Maximum factor by which expected native size may increase before " 1005ffd83dbSDimitry Andric "blocking any further inlining."), 1015ffd83dbSDimitry Andric cl::init(2.0)); 1025ffd83dbSDimitry Andric 10381ad6265SDimitry Andric static cl::opt<bool> KeepFPICache( 10481ad6265SDimitry Andric "ml-advisor-keep-fpi-cache", cl::Hidden, 10581ad6265SDimitry Andric cl::desc( 10681ad6265SDimitry Andric "For test - keep the ML Inline advisor's FunctionPropertiesInfo cache"), 10781ad6265SDimitry Andric cl::init(false)); 10881ad6265SDimitry Andric 109fe6060f1SDimitry Andric // clang-format off 11006c3fb27SDimitry Andric const std::vector<TensorSpec> llvm::FeatureMap{ 11106c3fb27SDimitry Andric #define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE), 112fe6060f1SDimitry Andric // InlineCost features - these must come first 113fe6060f1SDimitry Andric INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES) 114fe6060f1SDimitry Andric 115fe6060f1SDimitry Andric // Non-cost features 1165ffd83dbSDimitry Andric INLINE_FEATURE_ITERATOR(POPULATE_NAMES) 1175ffd83dbSDimitry Andric #undef POPULATE_NAMES 1185ffd83dbSDimitry Andric }; 119fe6060f1SDimitry Andric // clang-format on 1205ffd83dbSDimitry Andric 1215ffd83dbSDimitry Andric const char *const llvm::DecisionName = "inlining_decision"; 12206c3fb27SDimitry Andric const TensorSpec llvm::InlineDecisionSpec = 12306c3fb27SDimitry Andric TensorSpec::createSpec<int64_t>(DecisionName, {1}); 1245ffd83dbSDimitry Andric const char *const llvm::DefaultDecisionName = "inlining_default"; 12506c3fb27SDimitry Andric const TensorSpec llvm::DefaultDecisionSpec = 12606c3fb27SDimitry Andric TensorSpec::createSpec<int64_t>(DefaultDecisionName, {1}); 1275ffd83dbSDimitry Andric const char *const llvm::RewardName = "delta_size"; 1285ffd83dbSDimitry Andric 1295ffd83dbSDimitry Andric CallBase *getInlinableCS(Instruction &I) { 1305ffd83dbSDimitry Andric if (auto *CS = dyn_cast<CallBase>(&I)) 1315ffd83dbSDimitry Andric if (Function *Callee = CS->getCalledFunction()) { 1325ffd83dbSDimitry Andric if (!Callee->isDeclaration()) { 1335ffd83dbSDimitry Andric return CS; 1345ffd83dbSDimitry Andric } 1355ffd83dbSDimitry Andric } 1365ffd83dbSDimitry Andric return nullptr; 1375ffd83dbSDimitry Andric } 1385ffd83dbSDimitry Andric 13906c3fb27SDimitry Andric MLInlineAdvisor::MLInlineAdvisor( 14006c3fb27SDimitry Andric Module &M, ModuleAnalysisManager &MAM, 14106c3fb27SDimitry Andric std::unique_ptr<MLModelRunner> Runner, 14206c3fb27SDimitry Andric std::function<bool(CallBase &)> GetDefaultAdvice) 1435ffd83dbSDimitry Andric : InlineAdvisor( 144e8d8bef9SDimitry Andric M, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()), 14506c3fb27SDimitry Andric ModelRunner(std::move(Runner)), GetDefaultAdvice(GetDefaultAdvice), 14604eeddc0SDimitry Andric CG(MAM.getResult<LazyCallGraphAnalysis>(M)), 147*0fca6ea1SDimitry Andric InitialIRSize(getModuleIRSize()), CurrentIRSize(InitialIRSize), 148*0fca6ea1SDimitry Andric PSI(MAM.getResult<ProfileSummaryAnalysis>(M)) { 1495ffd83dbSDimitry Andric assert(ModelRunner); 15006c3fb27SDimitry Andric ModelRunner->switchContext(""); 1515ffd83dbSDimitry Andric // Extract the 'call site height' feature - the position of a call site 1525ffd83dbSDimitry Andric // relative to the farthest statically reachable SCC node. We don't mutate 1535ffd83dbSDimitry Andric // this value while inlining happens. Empirically, this feature proved 1545ffd83dbSDimitry Andric // critical in behavioral cloning - i.e. training a model to mimic the manual 1555ffd83dbSDimitry Andric // heuristic's decisions - and, thus, equally important for training for 1565ffd83dbSDimitry Andric // improvement. 15704eeddc0SDimitry Andric CallGraph CGraph(M); 15804eeddc0SDimitry Andric for (auto I = scc_begin(&CGraph); !I.isAtEnd(); ++I) { 1595ffd83dbSDimitry Andric const std::vector<CallGraphNode *> &CGNodes = *I; 1605ffd83dbSDimitry Andric unsigned Level = 0; 1615ffd83dbSDimitry Andric for (auto *CGNode : CGNodes) { 1625ffd83dbSDimitry Andric Function *F = CGNode->getFunction(); 1635ffd83dbSDimitry Andric if (!F || F->isDeclaration()) 1645ffd83dbSDimitry Andric continue; 1655ffd83dbSDimitry Andric for (auto &I : instructions(F)) { 1665ffd83dbSDimitry Andric if (auto *CS = getInlinableCS(I)) { 1675ffd83dbSDimitry Andric auto *Called = CS->getCalledFunction(); 16804eeddc0SDimitry Andric auto Pos = FunctionLevels.find(&CG.get(*Called)); 1695ffd83dbSDimitry Andric // In bottom up traversal, an inlinable callee is either in the 1705ffd83dbSDimitry Andric // same SCC, or to a function in a visited SCC. So not finding its 1715ffd83dbSDimitry Andric // level means we haven't visited it yet, meaning it's in this SCC. 1725ffd83dbSDimitry Andric if (Pos == FunctionLevels.end()) 1735ffd83dbSDimitry Andric continue; 1745ffd83dbSDimitry Andric Level = std::max(Level, Pos->second + 1); 1755ffd83dbSDimitry Andric } 1765ffd83dbSDimitry Andric } 1775ffd83dbSDimitry Andric } 1785ffd83dbSDimitry Andric for (auto *CGNode : CGNodes) { 1795ffd83dbSDimitry Andric Function *F = CGNode->getFunction(); 1805ffd83dbSDimitry Andric if (F && !F->isDeclaration()) 18104eeddc0SDimitry Andric FunctionLevels[&CG.get(*F)] = Level; 1825ffd83dbSDimitry Andric } 1835ffd83dbSDimitry Andric } 18404eeddc0SDimitry Andric for (auto KVP : FunctionLevels) { 18504eeddc0SDimitry Andric AllNodes.insert(KVP.first); 18604eeddc0SDimitry Andric EdgeCount += getLocalCalls(KVP.first->getFunction()); 18704eeddc0SDimitry Andric } 18804eeddc0SDimitry Andric NodeCount = AllNodes.size(); 18904eeddc0SDimitry Andric } 19004eeddc0SDimitry Andric 19104eeddc0SDimitry Andric unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const { 19204eeddc0SDimitry Andric return CG.lookup(F) ? FunctionLevels.at(CG.lookup(F)) : 0; 1935ffd83dbSDimitry Andric } 1945ffd83dbSDimitry Andric 195*0fca6ea1SDimitry Andric void MLInlineAdvisor::onPassEntry(LazyCallGraph::SCC *CurSCC) { 196*0fca6ea1SDimitry Andric if (!CurSCC || ForceStop) 19781ad6265SDimitry Andric return; 19881ad6265SDimitry Andric FPICache.clear(); 1995ffd83dbSDimitry Andric // Function passes executed between InlinerPass runs may have changed the 2005ffd83dbSDimitry Andric // module-wide features. 20104eeddc0SDimitry Andric // The cgscc pass manager rules are such that: 20204eeddc0SDimitry Andric // - if a pass leads to merging SCCs, then the pipeline is restarted on the 20304eeddc0SDimitry Andric // merged SCC 20404eeddc0SDimitry Andric // - if a pass leads to splitting the SCC, then we continue with one of the 20504eeddc0SDimitry Andric // splits 20604eeddc0SDimitry Andric // This means that the NodesInLastSCC is a superset (not strict) of the nodes 20704eeddc0SDimitry Andric // that subsequent passes would have processed 20804eeddc0SDimitry Andric // - in addition, if new Nodes were created by a pass (e.g. CoroSplit), 20904eeddc0SDimitry Andric // they'd be adjacent to Nodes in the last SCC. So we just need to check the 21004eeddc0SDimitry Andric // boundary of Nodes in NodesInLastSCC for Nodes we haven't seen. We don't 2115f757f3fSDimitry Andric // care about the nature of the Edge (call or ref). `FunctionLevels`-wise, we 2125f757f3fSDimitry Andric // record them at the same level as the original node (this is a choice, may 2135f757f3fSDimitry Andric // need revisiting). 214*0fca6ea1SDimitry Andric // - nodes are only deleted at the end of a call graph walk where they are 215*0fca6ea1SDimitry Andric // batch deleted, so we shouldn't see any dead nodes here. 21604eeddc0SDimitry Andric while (!NodesInLastSCC.empty()) { 21781ad6265SDimitry Andric const auto *N = *NodesInLastSCC.begin(); 218*0fca6ea1SDimitry Andric assert(!N->isDead()); 21981ad6265SDimitry Andric NodesInLastSCC.erase(N); 22004eeddc0SDimitry Andric EdgeCount += getLocalCalls(N->getFunction()); 2215f757f3fSDimitry Andric const auto NLevel = FunctionLevels.at(N); 22204eeddc0SDimitry Andric for (const auto &E : *(*N)) { 22304eeddc0SDimitry Andric const auto *AdjNode = &E.getNode(); 22404eeddc0SDimitry Andric assert(!AdjNode->isDead() && !AdjNode->getFunction().isDeclaration()); 22504eeddc0SDimitry Andric auto I = AllNodes.insert(AdjNode); 226*0fca6ea1SDimitry Andric // We've discovered a new function. 2275f757f3fSDimitry Andric if (I.second) { 228*0fca6ea1SDimitry Andric ++NodeCount; 22981ad6265SDimitry Andric NodesInLastSCC.insert(AdjNode); 2305f757f3fSDimitry Andric FunctionLevels[AdjNode] = NLevel; 2315f757f3fSDimitry Andric } 23204eeddc0SDimitry Andric } 23304eeddc0SDimitry Andric } 23404eeddc0SDimitry Andric 23504eeddc0SDimitry Andric EdgeCount -= EdgesOfLastSeenNodes; 23604eeddc0SDimitry Andric EdgesOfLastSeenNodes = 0; 23781ad6265SDimitry Andric 23881ad6265SDimitry Andric // (Re)use NodesInLastSCC to remember the nodes in the SCC right now, 23981ad6265SDimitry Andric // in case the SCC is split before onPassExit and some nodes are split out 24081ad6265SDimitry Andric assert(NodesInLastSCC.empty()); 241*0fca6ea1SDimitry Andric for (const auto &N : *CurSCC) 24281ad6265SDimitry Andric NodesInLastSCC.insert(&N); 24304eeddc0SDimitry Andric } 24404eeddc0SDimitry Andric 245*0fca6ea1SDimitry Andric void MLInlineAdvisor::onPassExit(LazyCallGraph::SCC *CurSCC) { 24681ad6265SDimitry Andric // No need to keep this around - function passes will invalidate it. 24781ad6265SDimitry Andric if (!KeepFPICache) 24881ad6265SDimitry Andric FPICache.clear(); 249*0fca6ea1SDimitry Andric if (!CurSCC || ForceStop) 25004eeddc0SDimitry Andric return; 25104eeddc0SDimitry Andric // Keep track of the nodes and edges we last saw. Then, in onPassEntry, 25204eeddc0SDimitry Andric // we update the node count and edge count from the subset of these nodes that 25304eeddc0SDimitry Andric // survived. 25404eeddc0SDimitry Andric EdgesOfLastSeenNodes = 0; 25581ad6265SDimitry Andric 25681ad6265SDimitry Andric // Check on nodes that were in SCC onPassEntry 257*0fca6ea1SDimitry Andric for (const LazyCallGraph::Node *N : NodesInLastSCC) { 258*0fca6ea1SDimitry Andric assert(!N->isDead()); 259*0fca6ea1SDimitry Andric EdgesOfLastSeenNodes += getLocalCalls(N->getFunction()); 26081ad6265SDimitry Andric } 26181ad6265SDimitry Andric 26281ad6265SDimitry Andric // Check on nodes that may have got added to SCC 263*0fca6ea1SDimitry Andric for (const auto &N : *CurSCC) { 26404eeddc0SDimitry Andric assert(!N.isDead()); 26581ad6265SDimitry Andric auto I = NodesInLastSCC.insert(&N); 26681ad6265SDimitry Andric if (I.second) 26704eeddc0SDimitry Andric EdgesOfLastSeenNodes += getLocalCalls(N.getFunction()); 26804eeddc0SDimitry Andric } 26981ad6265SDimitry Andric assert(NodeCount >= NodesInLastSCC.size()); 27004eeddc0SDimitry Andric assert(EdgeCount >= EdgesOfLastSeenNodes); 2715ffd83dbSDimitry Andric } 2725ffd83dbSDimitry Andric 2735ffd83dbSDimitry Andric int64_t MLInlineAdvisor::getLocalCalls(Function &F) { 27481ad6265SDimitry Andric return getCachedFPI(F).DirectCallsToDefinedFunctions; 2755ffd83dbSDimitry Andric } 2765ffd83dbSDimitry Andric 2775ffd83dbSDimitry Andric // Update the internal state of the advisor, and force invalidate feature 2785ffd83dbSDimitry Andric // analysis. Currently, we maintain minimal (and very simple) global state - the 2795ffd83dbSDimitry Andric // number of functions and the number of static calls. We also keep track of the 2805ffd83dbSDimitry Andric // total IR size in this module, to stop misbehaving policies at a certain bloat 2815ffd83dbSDimitry Andric // factor (SizeIncreaseThreshold) 2825ffd83dbSDimitry Andric void MLInlineAdvisor::onSuccessfulInlining(const MLInlineAdvice &Advice, 2835ffd83dbSDimitry Andric bool CalleeWasDeleted) { 2845ffd83dbSDimitry Andric assert(!ForceStop); 2855ffd83dbSDimitry Andric Function *Caller = Advice.getCaller(); 2865ffd83dbSDimitry Andric Function *Callee = Advice.getCallee(); 2875ffd83dbSDimitry Andric // The caller features aren't valid anymore. 288fe6060f1SDimitry Andric { 289fe6060f1SDimitry Andric PreservedAnalyses PA = PreservedAnalyses::all(); 290fe6060f1SDimitry Andric PA.abandon<FunctionPropertiesAnalysis>(); 29181ad6265SDimitry Andric PA.abandon<DominatorTreeAnalysis>(); 29281ad6265SDimitry Andric PA.abandon<LoopAnalysis>(); 293fe6060f1SDimitry Andric FAM.invalidate(*Caller, PA); 294fe6060f1SDimitry Andric } 29581ad6265SDimitry Andric Advice.updateCachedCallerFPI(FAM); 2965ffd83dbSDimitry Andric int64_t IRSizeAfter = 2975ffd83dbSDimitry Andric getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize); 2985ffd83dbSDimitry Andric CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize); 2995ffd83dbSDimitry Andric if (CurrentIRSize > SizeIncreaseThreshold * InitialIRSize) 3005ffd83dbSDimitry Andric ForceStop = true; 3015ffd83dbSDimitry Andric 3025ffd83dbSDimitry Andric // We can delta-update module-wide features. We know the inlining only changed 3035ffd83dbSDimitry Andric // the caller, and maybe the callee (by deleting the latter). 3045ffd83dbSDimitry Andric // Nodes are simple to update. 3055ffd83dbSDimitry Andric // For edges, we 'forget' the edges that the caller and callee used to have 3065ffd83dbSDimitry Andric // before inlining, and add back what they currently have together. 3075ffd83dbSDimitry Andric int64_t NewCallerAndCalleeEdges = 30881ad6265SDimitry Andric getCachedFPI(*Caller).DirectCallsToDefinedFunctions; 3095ffd83dbSDimitry Andric 310*0fca6ea1SDimitry Andric // A dead function's node is not actually removed from the call graph until 311*0fca6ea1SDimitry Andric // the end of the call graph walk, but the node no longer belongs to any valid 312*0fca6ea1SDimitry Andric // SCC. 313*0fca6ea1SDimitry Andric if (CalleeWasDeleted) { 3145ffd83dbSDimitry Andric --NodeCount; 315*0fca6ea1SDimitry Andric NodesInLastSCC.erase(CG.lookup(*Callee)); 316*0fca6ea1SDimitry Andric DeadFunctions.insert(Callee); 317*0fca6ea1SDimitry Andric } else { 318e8d8bef9SDimitry Andric NewCallerAndCalleeEdges += 31981ad6265SDimitry Andric getCachedFPI(*Callee).DirectCallsToDefinedFunctions; 320*0fca6ea1SDimitry Andric } 3215ffd83dbSDimitry Andric EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges); 3225ffd83dbSDimitry Andric assert(CurrentIRSize >= 0 && EdgeCount >= 0 && NodeCount >= 0); 3235ffd83dbSDimitry Andric } 3245ffd83dbSDimitry Andric 3255ffd83dbSDimitry Andric int64_t MLInlineAdvisor::getModuleIRSize() const { 3265ffd83dbSDimitry Andric int64_t Ret = 0; 32704eeddc0SDimitry Andric for (auto &F : M) 3285ffd83dbSDimitry Andric if (!F.isDeclaration()) 3295ffd83dbSDimitry Andric Ret += getIRSize(F); 3305ffd83dbSDimitry Andric return Ret; 3315ffd83dbSDimitry Andric } 3325ffd83dbSDimitry Andric 33381ad6265SDimitry Andric FunctionPropertiesInfo &MLInlineAdvisor::getCachedFPI(Function &F) const { 33481ad6265SDimitry Andric auto InsertPair = 33581ad6265SDimitry Andric FPICache.insert(std::make_pair(&F, FunctionPropertiesInfo())); 33681ad6265SDimitry Andric if (!InsertPair.second) 33781ad6265SDimitry Andric return InsertPair.first->second; 33881ad6265SDimitry Andric InsertPair.first->second = FAM.getResult<FunctionPropertiesAnalysis>(F); 33981ad6265SDimitry Andric return InsertPair.first->second; 34081ad6265SDimitry Andric } 34181ad6265SDimitry Andric 342e8d8bef9SDimitry Andric std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) { 34381ad6265SDimitry Andric if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB)) 34481ad6265SDimitry Andric return Skip; 34581ad6265SDimitry Andric 3465ffd83dbSDimitry Andric auto &Caller = *CB.getCaller(); 3475ffd83dbSDimitry Andric auto &Callee = *CB.getCalledFunction(); 3485ffd83dbSDimitry Andric 3495ffd83dbSDimitry Andric auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { 3505ffd83dbSDimitry Andric return FAM.getResult<AssumptionAnalysis>(F); 3515ffd83dbSDimitry Andric }; 3525ffd83dbSDimitry Andric auto &TIR = FAM.getResult<TargetIRAnalysis>(Callee); 3535ffd83dbSDimitry Andric auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(Caller); 3545ffd83dbSDimitry Andric 355*0fca6ea1SDimitry Andric if (SkipPolicy == SkipMLPolicyCriteria::IfCallerIsNotCold) { 356*0fca6ea1SDimitry Andric if (!PSI.isFunctionEntryCold(&Caller)) 357*0fca6ea1SDimitry Andric return std::make_unique<InlineAdvice>(this, CB, ORE, 358*0fca6ea1SDimitry Andric GetDefaultAdvice(CB)); 359*0fca6ea1SDimitry Andric } 360e8d8bef9SDimitry Andric auto MandatoryKind = InlineAdvisor::getMandatoryKind(CB, FAM, ORE); 3615ffd83dbSDimitry Andric // If this is a "never inline" case, there won't be any changes to internal 3625ffd83dbSDimitry Andric // state we need to track, so we can just return the base InlineAdvice, which 3635ffd83dbSDimitry Andric // will do nothing interesting. 3645ffd83dbSDimitry Andric // Same thing if this is a recursive case. 365e8d8bef9SDimitry Andric if (MandatoryKind == InlineAdvisor::MandatoryInliningKind::Never || 3665ffd83dbSDimitry Andric &Caller == &Callee) 367e8d8bef9SDimitry Andric return getMandatoryAdvice(CB, false); 3685ffd83dbSDimitry Andric 369e8d8bef9SDimitry Andric bool Mandatory = 370e8d8bef9SDimitry Andric MandatoryKind == InlineAdvisor::MandatoryInliningKind::Always; 3715ffd83dbSDimitry Andric 3725ffd83dbSDimitry Andric // If we need to stop, we won't want to track anymore any state changes, so 3735ffd83dbSDimitry Andric // we just return the base InlineAdvice, which acts as a noop. 3745ffd83dbSDimitry Andric if (ForceStop) { 3755ffd83dbSDimitry Andric ORE.emit([&] { 3765ffd83dbSDimitry Andric return OptimizationRemarkMissed(DEBUG_TYPE, "ForceStop", &CB) 3775ffd83dbSDimitry Andric << "Won't attempt inlining because module size grew too much."; 3785ffd83dbSDimitry Andric }); 3795ffd83dbSDimitry Andric return std::make_unique<InlineAdvice>(this, CB, ORE, Mandatory); 3805ffd83dbSDimitry Andric } 3815ffd83dbSDimitry Andric 3825ffd83dbSDimitry Andric int CostEstimate = 0; 3835ffd83dbSDimitry Andric if (!Mandatory) { 3845ffd83dbSDimitry Andric auto IsCallSiteInlinable = 3855ffd83dbSDimitry Andric llvm::getInliningCostEstimate(CB, TIR, GetAssumptionCache); 3865ffd83dbSDimitry Andric if (!IsCallSiteInlinable) { 3875ffd83dbSDimitry Andric // We can't inline this for correctness reasons, so return the base 3885ffd83dbSDimitry Andric // InlineAdvice, as we don't care about tracking any state changes (which 3895ffd83dbSDimitry Andric // won't happen). 3905ffd83dbSDimitry Andric return std::make_unique<InlineAdvice>(this, CB, ORE, false); 3915ffd83dbSDimitry Andric } 3925ffd83dbSDimitry Andric CostEstimate = *IsCallSiteInlinable; 3935ffd83dbSDimitry Andric } 3945ffd83dbSDimitry Andric 395fe6060f1SDimitry Andric const auto CostFeatures = 396fe6060f1SDimitry Andric llvm::getInliningCostFeatures(CB, TIR, GetAssumptionCache); 397fe6060f1SDimitry Andric if (!CostFeatures) { 398fe6060f1SDimitry Andric return std::make_unique<InlineAdvice>(this, CB, ORE, false); 399fe6060f1SDimitry Andric } 400fe6060f1SDimitry Andric 4015ffd83dbSDimitry Andric if (Mandatory) 402e8d8bef9SDimitry Andric return getMandatoryAdvice(CB, true); 4035ffd83dbSDimitry Andric 4045ffd83dbSDimitry Andric auto NrCtantParams = 0; 4055ffd83dbSDimitry Andric for (auto I = CB.arg_begin(), E = CB.arg_end(); I != E; ++I) { 4065ffd83dbSDimitry Andric NrCtantParams += (isa<Constant>(*I)); 4075ffd83dbSDimitry Andric } 4085ffd83dbSDimitry Andric 40981ad6265SDimitry Andric auto &CallerBefore = getCachedFPI(Caller); 41081ad6265SDimitry Andric auto &CalleeBefore = getCachedFPI(Callee); 4115ffd83dbSDimitry Andric 41206c3fb27SDimitry Andric *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_basic_block_count) = 4130eae32dcSDimitry Andric CalleeBefore.BasicBlockCount; 41406c3fb27SDimitry Andric *ModelRunner->getTensor<int64_t>(FeatureIndex::callsite_height) = 41504eeddc0SDimitry Andric getInitialFunctionLevel(Caller); 41606c3fb27SDimitry Andric *ModelRunner->getTensor<int64_t>(FeatureIndex::node_count) = NodeCount; 41706c3fb27SDimitry Andric *ModelRunner->getTensor<int64_t>(FeatureIndex::nr_ctant_params) = 41806c3fb27SDimitry Andric NrCtantParams; 41906c3fb27SDimitry Andric *ModelRunner->getTensor<int64_t>(FeatureIndex::edge_count) = EdgeCount; 42006c3fb27SDimitry Andric *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_users) = 4210eae32dcSDimitry Andric CallerBefore.Uses; 4220eae32dcSDimitry Andric *ModelRunner->getTensor<int64_t>( 42306c3fb27SDimitry Andric FeatureIndex::caller_conditionally_executed_blocks) = 4240eae32dcSDimitry Andric CallerBefore.BlocksReachedFromConditionalInstruction; 42506c3fb27SDimitry Andric *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_basic_block_count) = 4260eae32dcSDimitry Andric CallerBefore.BasicBlockCount; 4270eae32dcSDimitry Andric *ModelRunner->getTensor<int64_t>( 42806c3fb27SDimitry Andric FeatureIndex::callee_conditionally_executed_blocks) = 4290eae32dcSDimitry Andric CalleeBefore.BlocksReachedFromConditionalInstruction; 43006c3fb27SDimitry Andric *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_users) = 4310eae32dcSDimitry Andric CalleeBefore.Uses; 43206c3fb27SDimitry Andric *ModelRunner->getTensor<int64_t>(FeatureIndex::cost_estimate) = CostEstimate; 433*0fca6ea1SDimitry Andric *ModelRunner->getTensor<int64_t>(FeatureIndex::is_callee_avail_external) = 434*0fca6ea1SDimitry Andric Callee.hasAvailableExternallyLinkage(); 435*0fca6ea1SDimitry Andric *ModelRunner->getTensor<int64_t>(FeatureIndex::is_caller_avail_external) = 436*0fca6ea1SDimitry Andric Caller.hasAvailableExternallyLinkage(); 437fe6060f1SDimitry Andric 438fe6060f1SDimitry Andric // Add the cost features 439fe6060f1SDimitry Andric for (size_t I = 0; 440fe6060f1SDimitry Andric I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) { 4410eae32dcSDimitry Andric *ModelRunner->getTensor<int64_t>(inlineCostFeatureToMlFeature( 4420eae32dcSDimitry Andric static_cast<InlineCostFeatureIndex>(I))) = CostFeatures->at(I); 443fe6060f1SDimitry Andric } 44406c3fb27SDimitry Andric // This one would have been set up to be right at the end. 44506c3fb27SDimitry Andric if (!InteractiveChannelBaseName.empty() && InteractiveIncludeDefault) 44606c3fb27SDimitry Andric *ModelRunner->getTensor<int64_t>(InlineCostFeatureIndex::NumberOfFeatures) = 44706c3fb27SDimitry Andric GetDefaultAdvice(CB); 4485ffd83dbSDimitry Andric return getAdviceFromModel(CB, ORE); 4495ffd83dbSDimitry Andric } 4505ffd83dbSDimitry Andric 4515ffd83dbSDimitry Andric std::unique_ptr<MLInlineAdvice> 4525ffd83dbSDimitry Andric MLInlineAdvisor::getAdviceFromModel(CallBase &CB, 4535ffd83dbSDimitry Andric OptimizationRemarkEmitter &ORE) { 4540eae32dcSDimitry Andric return std::make_unique<MLInlineAdvice>( 4550eae32dcSDimitry Andric this, CB, ORE, static_cast<bool>(ModelRunner->evaluate<int64_t>())); 4565ffd83dbSDimitry Andric } 4575ffd83dbSDimitry Andric 45881ad6265SDimitry Andric std::unique_ptr<InlineAdvice> 45981ad6265SDimitry Andric MLInlineAdvisor::getSkipAdviceIfUnreachableCallsite(CallBase &CB) { 46081ad6265SDimitry Andric if (!FAM.getResult<DominatorTreeAnalysis>(*CB.getCaller()) 46181ad6265SDimitry Andric .isReachableFromEntry(CB.getParent())) 46281ad6265SDimitry Andric return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false); 46381ad6265SDimitry Andric return nullptr; 46481ad6265SDimitry Andric } 46581ad6265SDimitry Andric 466e8d8bef9SDimitry Andric std::unique_ptr<InlineAdvice> MLInlineAdvisor::getMandatoryAdvice(CallBase &CB, 467e8d8bef9SDimitry Andric bool Advice) { 468e8d8bef9SDimitry Andric // Make sure we track inlinings in all cases - mandatory or not. 46981ad6265SDimitry Andric if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB)) 47081ad6265SDimitry Andric return Skip; 471e8d8bef9SDimitry Andric if (Advice && !ForceStop) 472e8d8bef9SDimitry Andric return getMandatoryAdviceImpl(CB); 473e8d8bef9SDimitry Andric 474e8d8bef9SDimitry Andric // If this is a "never inline" case, there won't be any changes to internal 475e8d8bef9SDimitry Andric // state we need to track, so we can just return the base InlineAdvice, which 476e8d8bef9SDimitry Andric // will do nothing interesting. 477e8d8bef9SDimitry Andric // Same if we are forced to stop - we don't track anymore. 478e8d8bef9SDimitry Andric return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), Advice); 479e8d8bef9SDimitry Andric } 480e8d8bef9SDimitry Andric 4815ffd83dbSDimitry Andric std::unique_ptr<MLInlineAdvice> 482e8d8bef9SDimitry Andric MLInlineAdvisor::getMandatoryAdviceImpl(CallBase &CB) { 483e8d8bef9SDimitry Andric return std::make_unique<MLInlineAdvice>(this, CB, getCallerORE(CB), true); 4845ffd83dbSDimitry Andric } 4855ffd83dbSDimitry Andric 48681ad6265SDimitry Andric void MLInlineAdvisor::print(raw_ostream &OS) const { 48781ad6265SDimitry Andric OS << "[MLInlineAdvisor] Nodes: " << NodeCount << " Edges: " << EdgeCount 48881ad6265SDimitry Andric << " EdgesOfLastSeenNodes: " << EdgesOfLastSeenNodes << "\n"; 48981ad6265SDimitry Andric OS << "[MLInlineAdvisor] FPI:\n"; 49081ad6265SDimitry Andric for (auto I : FPICache) { 491bdd1243dSDimitry Andric OS << I.first->getName() << ":\n"; 492bdd1243dSDimitry Andric I.second.print(OS); 49381ad6265SDimitry Andric OS << "\n"; 49481ad6265SDimitry Andric } 49581ad6265SDimitry Andric OS << "\n"; 4965f757f3fSDimitry Andric OS << "[MLInlineAdvisor] FuncLevels:\n"; 4975f757f3fSDimitry Andric for (auto I : FunctionLevels) 498*0fca6ea1SDimitry Andric OS << (DeadFunctions.contains(&I.first->getFunction()) 499*0fca6ea1SDimitry Andric ? "<deleted>" 500*0fca6ea1SDimitry Andric : I.first->getFunction().getName()) 5015f757f3fSDimitry Andric << " : " << I.second << "\n"; 5025f757f3fSDimitry Andric 5035f757f3fSDimitry Andric OS << "\n"; 50481ad6265SDimitry Andric } 50581ad6265SDimitry Andric 50681ad6265SDimitry Andric MLInlineAdvice::MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB, 50781ad6265SDimitry Andric OptimizationRemarkEmitter &ORE, 50881ad6265SDimitry Andric bool Recommendation) 50981ad6265SDimitry Andric : InlineAdvice(Advisor, CB, ORE, Recommendation), 51081ad6265SDimitry Andric CallerIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Caller)), 51181ad6265SDimitry Andric CalleeIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Callee)), 51281ad6265SDimitry Andric CallerAndCalleeEdges(Advisor->isForcedToStop() 51381ad6265SDimitry Andric ? 0 51481ad6265SDimitry Andric : (Advisor->getLocalCalls(*Caller) + 51581ad6265SDimitry Andric Advisor->getLocalCalls(*Callee))), 51681ad6265SDimitry Andric PreInlineCallerFPI(Advisor->getCachedFPI(*Caller)) { 51781ad6265SDimitry Andric if (Recommendation) 51881ad6265SDimitry Andric FPU.emplace(Advisor->getCachedFPI(*getCaller()), CB); 51981ad6265SDimitry Andric } 52081ad6265SDimitry Andric 5215ffd83dbSDimitry Andric void MLInlineAdvice::reportContextForRemark( 5225ffd83dbSDimitry Andric DiagnosticInfoOptimizationBase &OR) { 5235ffd83dbSDimitry Andric using namespace ore; 5245ffd83dbSDimitry Andric OR << NV("Callee", Callee->getName()); 5255ffd83dbSDimitry Andric for (size_t I = 0; I < NumberOfFeatures; ++I) 52681ad6265SDimitry Andric OR << NV(FeatureMap[I].name(), 5270eae32dcSDimitry Andric *getAdvisor()->getModelRunner().getTensor<int64_t>(I)); 5285ffd83dbSDimitry Andric OR << NV("ShouldInline", isInliningRecommended()); 5295ffd83dbSDimitry Andric } 5305ffd83dbSDimitry Andric 53181ad6265SDimitry Andric void MLInlineAdvice::updateCachedCallerFPI(FunctionAnalysisManager &FAM) const { 53281ad6265SDimitry Andric FPU->finish(FAM); 53381ad6265SDimitry Andric } 53481ad6265SDimitry Andric 5355ffd83dbSDimitry Andric void MLInlineAdvice::recordInliningImpl() { 5365ffd83dbSDimitry Andric ORE.emit([&]() { 5375ffd83dbSDimitry Andric OptimizationRemark R(DEBUG_TYPE, "InliningSuccess", DLoc, Block); 5385ffd83dbSDimitry Andric reportContextForRemark(R); 5395ffd83dbSDimitry Andric return R; 5405ffd83dbSDimitry Andric }); 5415ffd83dbSDimitry Andric getAdvisor()->onSuccessfulInlining(*this, /*CalleeWasDeleted*/ false); 5425ffd83dbSDimitry Andric } 5435ffd83dbSDimitry Andric 5445ffd83dbSDimitry Andric void MLInlineAdvice::recordInliningWithCalleeDeletedImpl() { 5455ffd83dbSDimitry Andric ORE.emit([&]() { 5465ffd83dbSDimitry Andric OptimizationRemark R(DEBUG_TYPE, "InliningSuccessWithCalleeDeleted", DLoc, 5475ffd83dbSDimitry Andric Block); 5485ffd83dbSDimitry Andric reportContextForRemark(R); 5495ffd83dbSDimitry Andric return R; 5505ffd83dbSDimitry Andric }); 5515ffd83dbSDimitry Andric getAdvisor()->onSuccessfulInlining(*this, /*CalleeWasDeleted*/ true); 5525ffd83dbSDimitry Andric } 5535ffd83dbSDimitry Andric 5545ffd83dbSDimitry Andric void MLInlineAdvice::recordUnsuccessfulInliningImpl( 5555ffd83dbSDimitry Andric const InlineResult &Result) { 55681ad6265SDimitry Andric getAdvisor()->getCachedFPI(*Caller) = PreInlineCallerFPI; 5575ffd83dbSDimitry Andric ORE.emit([&]() { 5585ffd83dbSDimitry Andric OptimizationRemarkMissed R(DEBUG_TYPE, "InliningAttemptedAndUnsuccessful", 5595ffd83dbSDimitry Andric DLoc, Block); 5605ffd83dbSDimitry Andric reportContextForRemark(R); 5615ffd83dbSDimitry Andric return R; 5625ffd83dbSDimitry Andric }); 5635ffd83dbSDimitry Andric } 5645ffd83dbSDimitry Andric void MLInlineAdvice::recordUnattemptedInliningImpl() { 56581ad6265SDimitry Andric assert(!FPU); 5665ffd83dbSDimitry Andric ORE.emit([&]() { 5675ffd83dbSDimitry Andric OptimizationRemarkMissed R(DEBUG_TYPE, "IniningNotAttempted", DLoc, Block); 5685ffd83dbSDimitry Andric reportContextForRemark(R); 5695ffd83dbSDimitry Andric return R; 5705ffd83dbSDimitry Andric }); 5715ffd83dbSDimitry Andric } 572