xref: /llvm-project/llvm/include/llvm/Analysis/MLInlineAdvisor.h (revision 0da2ba811ac8a01509bc533428941fb9519c0715)
1 //===- MLInlineAdvisor.h - ML - based InlineAdvisor factories ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H
10 #define LLVM_ANALYSIS_MLINLINEADVISOR_H
11 
12 #include "llvm/Analysis/FunctionPropertiesAnalysis.h"
13 #include "llvm/Analysis/InlineAdvisor.h"
14 #include "llvm/Analysis/LazyCallGraph.h"
15 #include "llvm/Analysis/MLModelRunner.h"
16 #include "llvm/IR/PassManager.h"
17 
18 #include <map>
19 #include <memory>
20 #include <optional>
21 
22 namespace llvm {
23 class DiagnosticInfoOptimizationBase;
24 class Module;
25 class MLInlineAdvice;
26 class ProfileSummaryInfo;
27 
28 class MLInlineAdvisor : public InlineAdvisor {
29 public:
30   MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM,
31                   std::unique_ptr<MLModelRunner> ModelRunner,
32                   std::function<bool(CallBase &)> GetDefaultAdvice);
33 
34   virtual ~MLInlineAdvisor() = default;
35 
36   void onPassEntry(LazyCallGraph::SCC *SCC) override;
37   void onPassExit(LazyCallGraph::SCC *SCC) override;
38 
39   int64_t getIRSize(Function &F) const {
40     return getCachedFPI(F).TotalInstructionCount;
41   }
42   void onSuccessfulInlining(const MLInlineAdvice &Advice,
43                             bool CalleeWasDeleted);
44 
45   bool isForcedToStop() const { return ForceStop; }
46   int64_t getLocalCalls(Function &F);
47   const MLModelRunner &getModelRunner() const { return *ModelRunner; }
48   FunctionPropertiesInfo &getCachedFPI(Function &) const;
49 
50 protected:
51   std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override;
52 
53   std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB,
54                                                    bool Advice) override;
55 
56   virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB);
57 
58   virtual std::unique_ptr<MLInlineAdvice>
59   getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE);
60 
61   // Get the initial 'level' of the function, or 0 if the function has been
62   // introduced afterwards.
63   // TODO: should we keep this updated?
64   unsigned getInitialFunctionLevel(const Function &F) const;
65 
66   std::unique_ptr<MLModelRunner> ModelRunner;
67   std::function<bool(CallBase &)> GetDefaultAdvice;
68 
69 private:
70   int64_t getModuleIRSize() const;
71   std::unique_ptr<InlineAdvice>
72   getSkipAdviceIfUnreachableCallsite(CallBase &CB);
73   void print(raw_ostream &OS) const override;
74 
75   // Using std::map to benefit from its iterator / reference non-invalidating
76   // semantics, which make it easy to use `getCachedFPI` results from multiple
77   // calls without needing to copy to avoid invalidation effects.
78   mutable std::map<const Function *, FunctionPropertiesInfo> FPICache;
79 
80   LazyCallGraph &CG;
81 
82   int64_t NodeCount = 0;
83   int64_t EdgeCount = 0;
84   int64_t EdgesOfLastSeenNodes = 0;
85 
86   std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels;
87   const int32_t InitialIRSize = 0;
88   int32_t CurrentIRSize = 0;
89   llvm::SmallPtrSet<const LazyCallGraph::Node *, 1> NodesInLastSCC;
90   DenseSet<const LazyCallGraph::Node *> AllNodes;
91   DenseSet<Function *> DeadFunctions;
92   bool ForceStop = false;
93   ProfileSummaryInfo &PSI;
94 };
95 
96 /// InlineAdvice that tracks changes post inlining. For that reason, it only
97 /// overrides the "successful inlining" extension points.
98 class MLInlineAdvice : public InlineAdvice {
99 public:
100   MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
101                  OptimizationRemarkEmitter &ORE, bool Recommendation);
102   virtual ~MLInlineAdvice() = default;
103 
104   void recordInliningImpl() override;
105   void recordInliningWithCalleeDeletedImpl() override;
106   void recordUnsuccessfulInliningImpl(const InlineResult &Result) override;
107   void recordUnattemptedInliningImpl() override;
108 
109   Function *getCaller() const { return Caller; }
110   Function *getCallee() const { return Callee; }
111 
112   const int64_t CallerIRSize;
113   const int64_t CalleeIRSize;
114   const int64_t CallerAndCalleeEdges;
115   void updateCachedCallerFPI(FunctionAnalysisManager &FAM) const;
116 
117 private:
118   void reportContextForRemark(DiagnosticInfoOptimizationBase &OR);
119   MLInlineAdvisor *getAdvisor() const {
120     return static_cast<MLInlineAdvisor *>(Advisor);
121   };
122   // Make a copy of the FPI of the caller right before inlining. If inlining
123   // fails, we can just update the cache with that value.
124   const FunctionPropertiesInfo PreInlineCallerFPI;
125   std::optional<FunctionPropertiesUpdater> FPU;
126 };
127 
128 } // namespace llvm
129 
130 #endif // LLVM_ANALYSIS_MLINLINEADVISOR_H
131