xref: /llvm-project/llvm/include/llvm/IR/PassInstrumentation.h (revision 597ccb800829af69ebc18cd7c75d878c8d21de6e)
1 //===- llvm/IR/PassInstrumentation.h ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 ///
10 /// This file defines the Pass Instrumentation classes that provide
11 /// instrumentation points into the pass execution by PassManager.
12 ///
13 /// There are two main classes:
14 ///   - PassInstrumentation provides a set of instrumentation points for
15 ///     pass managers to call on.
16 ///
17 ///   - PassInstrumentationCallbacks registers callbacks and provides access
18 ///     to them for PassInstrumentation.
19 ///
/// A PassInstrumentation object is used as the result of
/// PassInstrumentationAnalysis (so it is intended to be cheap to copy).
22 ///
23 /// Intended scheme of use for Pass Instrumentation is as follows:
///    - register instrumentation callbacks in a PassInstrumentationCallbacks
///      instance. PassBuilder provides a helper for that.
26 ///
27 ///    - register PassInstrumentationAnalysis with all the PassManagers.
28 ///      PassBuilder handles that automatically when registering analyses.
29 ///
30 ///    - Pass Manager requests PassInstrumentationAnalysis from analysis manager
31 ///      and gets PassInstrumentation as its result.
32 ///
///    - Pass Manager invokes PassInstrumentation entry points appropriately,
///      passing a StringRef identification ("name") of the pass currently
///      being executed and the IRUnit it works on. Other naming schemes may be
///      provided in the future; currently it is just the name() of the pass.
37 ///
38 ///    - PassInstrumentation wraps address of IRUnit into llvm::Any and passes
39 ///      control to all the registered callbacks. Note that we specifically wrap
40 ///      'const IRUnitT*' so as to avoid any accidental changes to IR in
41 ///      instrumenting callbacks.
42 ///
///    - Some instrumentation points (BeforePass) can control whether a pass
///      is executed. For those callbacks, returning false means the pass will
///      not be executed (only optional, i.e. non-required, passes can be
///      skipped this way).
46 ///
47 //===----------------------------------------------------------------------===//
48 
49 #ifndef LLVM_IR_PASSINSTRUMENTATION_H
50 #define LLVM_IR_PASSINSTRUMENTATION_H
51 
#include "llvm/ADT/Any.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/Compiler.h"
#include <cassert>
#include <type_traits>
#include <vector>
60 
61 namespace llvm {
62 
63 class PreservedAnalyses;
64 class StringRef;
65 class Module;
66 class Loop;
67 class Function;
68 
69 extern template struct LLVM_TEMPLATE_ABI Any::TypeId<const Module *>;
70 extern template struct LLVM_TEMPLATE_ABI Any::TypeId<const Function *>;
71 extern template struct LLVM_TEMPLATE_ABI Any::TypeId<const Loop *>;
72 
73 /// This class manages callbacks registration, as well as provides a way for
74 /// PassInstrumentation to pass control to the registered callbacks.
75 class PassInstrumentationCallbacks {
76 public:
77   // Before/After callbacks accept IRUnits whenever appropriate, so they need
78   // to take them as constant pointers, wrapped with llvm::Any.
79   // For the case when IRUnit has been invalidated there is a different
80   // callback to use - AfterPassInvalidated.
81   // We call all BeforePassFuncs to determine if a pass should run or not.
82   // BeforeNonSkippedPassFuncs are called only if the pass should run.
83   // TODO: currently AfterPassInvalidated does not accept IRUnit, since passing
84   // already invalidated IRUnit is unsafe. There are ways to handle invalidated
85   // IRUnits in a safe way, and we might pursue that as soon as there is a
86   // useful instrumentation that needs it.
87   using BeforePassFunc = bool(StringRef, Any);
88   using BeforeSkippedPassFunc = void(StringRef, Any);
89   using BeforeNonSkippedPassFunc = void(StringRef, Any);
90   using AfterPassFunc = void(StringRef, Any, const PreservedAnalyses &);
91   using AfterPassInvalidatedFunc = void(StringRef, const PreservedAnalyses &);
92   using BeforeAnalysisFunc = void(StringRef, Any);
93   using AfterAnalysisFunc = void(StringRef, Any);
94   using AnalysisInvalidatedFunc = void(StringRef, Any);
95   using AnalysesClearedFunc = void(StringRef);
96 
97 public:
98   PassInstrumentationCallbacks() = default;
99 
100   /// Copying PassInstrumentationCallbacks is not intended.
101   PassInstrumentationCallbacks(const PassInstrumentationCallbacks &) = delete;
102   void operator=(const PassInstrumentationCallbacks &) = delete;
103 
104   template <typename CallableT>
105   void registerShouldRunOptionalPassCallback(CallableT C) {
106     ShouldRunOptionalPassCallbacks.emplace_back(std::move(C));
107   }
108 
109   template <typename CallableT>
110   void registerBeforeSkippedPassCallback(CallableT C) {
111     BeforeSkippedPassCallbacks.emplace_back(std::move(C));
112   }
113 
114   template <typename CallableT>
115   void registerBeforeNonSkippedPassCallback(CallableT C) {
116     BeforeNonSkippedPassCallbacks.emplace_back(std::move(C));
117   }
118 
119   template <typename CallableT>
120   void registerAfterPassCallback(CallableT C, bool ToFront = false) {
121     if (ToFront)
122       AfterPassCallbacks.insert(AfterPassCallbacks.begin(), std::move(C));
123     else
124       AfterPassCallbacks.emplace_back(std::move(C));
125   }
126 
127   template <typename CallableT>
128   void registerAfterPassInvalidatedCallback(CallableT C, bool ToFront = false) {
129     if (ToFront)
130       AfterPassInvalidatedCallbacks.insert(
131           AfterPassInvalidatedCallbacks.begin(), std::move(C));
132     else
133       AfterPassInvalidatedCallbacks.emplace_back(std::move(C));
134   }
135 
136   template <typename CallableT>
137   void registerBeforeAnalysisCallback(CallableT C) {
138     BeforeAnalysisCallbacks.emplace_back(std::move(C));
139   }
140 
141   template <typename CallableT>
142   void registerAfterAnalysisCallback(CallableT C, bool ToFront = false) {
143     if (ToFront)
144       AfterAnalysisCallbacks.insert(AfterAnalysisCallbacks.begin(),
145                                     std::move(C));
146     else
147       AfterAnalysisCallbacks.emplace_back(std::move(C));
148   }
149 
150   template <typename CallableT>
151   void registerAnalysisInvalidatedCallback(CallableT C) {
152     AnalysisInvalidatedCallbacks.emplace_back(std::move(C));
153   }
154 
155   template <typename CallableT>
156   void registerAnalysesClearedCallback(CallableT C) {
157     AnalysesClearedCallbacks.emplace_back(std::move(C));
158   }
159 
160   template <typename CallableT>
161   void registerClassToPassNameCallback(CallableT C) {
162     ClassToPassNameCallbacks.emplace_back(std::move(C));
163   }
164 
165   /// Add a class name to pass name mapping for use by pass instrumentation.
166   void addClassToPassName(StringRef ClassName, StringRef PassName);
167   /// Get the pass name for a given pass class name.
168   StringRef getPassNameForClassName(StringRef ClassName);
169 
170 private:
171   friend class PassInstrumentation;
172 
173   /// These are only run on passes that are not required. They return false when
174   /// an optional pass should be skipped.
175   SmallVector<llvm::unique_function<BeforePassFunc>, 4>
176       ShouldRunOptionalPassCallbacks;
177   /// These are run on passes that are skipped.
178   SmallVector<llvm::unique_function<BeforeSkippedPassFunc>, 4>
179       BeforeSkippedPassCallbacks;
180   /// These are run on passes that are about to be run.
181   SmallVector<llvm::unique_function<BeforeNonSkippedPassFunc>, 4>
182       BeforeNonSkippedPassCallbacks;
183   /// These are run on passes that have just run.
184   SmallVector<llvm::unique_function<AfterPassFunc>, 4> AfterPassCallbacks;
185   /// These are run on passes that have just run on invalidated IR.
186   SmallVector<llvm::unique_function<AfterPassInvalidatedFunc>, 4>
187       AfterPassInvalidatedCallbacks;
188   /// These are run on analyses that are about to be run.
189   SmallVector<llvm::unique_function<BeforeAnalysisFunc>, 4>
190       BeforeAnalysisCallbacks;
191   /// These are run on analyses that have been run.
192   SmallVector<llvm::unique_function<AfterAnalysisFunc>, 4>
193       AfterAnalysisCallbacks;
194   /// These are run on analyses that have been invalidated.
195   SmallVector<llvm::unique_function<AnalysisInvalidatedFunc>, 4>
196       AnalysisInvalidatedCallbacks;
197   /// These are run on analyses that have been cleared.
198   SmallVector<llvm::unique_function<AnalysesClearedFunc>, 4>
199       AnalysesClearedCallbacks;
200 
201   SmallVector<llvm::unique_function<void ()>, 4> ClassToPassNameCallbacks;
202   DenseMap<StringRef, std::string> ClassToPassName;
203 };
204 
205 /// This class provides instrumentation entry points for the Pass Manager,
206 /// doing calls to callbacks registered in PassInstrumentationCallbacks.
207 class PassInstrumentation {
208   PassInstrumentationCallbacks *Callbacks;
209 
210   // Template argument PassT of PassInstrumentation::runBeforePass could be two
211   // kinds: (1) a regular pass inherited from PassInfoMixin (happen when
212   // creating a adaptor pass for a regular pass); (2) a type-erased PassConcept
213   // created from (1). Here we want to make case (1) skippable unconditionally
214   // since they are regular passes. We call PassConcept::isRequired to decide
215   // for case (2).
216   template <typename PassT>
217   using has_required_t = decltype(std::declval<PassT &>().isRequired());
218 
219   template <typename PassT>
220   static std::enable_if_t<is_detected<has_required_t, PassT>::value, bool>
221   isRequired(const PassT &Pass) {
222     return Pass.isRequired();
223   }
224   template <typename PassT>
225   static std::enable_if_t<!is_detected<has_required_t, PassT>::value, bool>
226   isRequired(const PassT &Pass) {
227     return false;
228   }
229 
230 public:
231   /// Callbacks object is not owned by PassInstrumentation, its life-time
232   /// should at least match the life-time of corresponding
233   /// PassInstrumentationAnalysis (which usually is till the end of current
234   /// compilation).
235   PassInstrumentation(PassInstrumentationCallbacks *CB = nullptr)
236       : Callbacks(CB) {}
237 
238   /// BeforePass instrumentation point - takes \p Pass instance to be executed
239   /// and constant reference to IR it operates on. \Returns true if pass is
240   /// allowed to be executed. These are only run on optional pass since required
241   /// passes must always be run. This allows these callbacks to print info when
242   /// they want to skip a pass.
243   template <typename IRUnitT, typename PassT>
244   bool runBeforePass(const PassT &Pass, const IRUnitT &IR) const {
245     if (!Callbacks)
246       return true;
247 
248     bool ShouldRun = true;
249     if (!isRequired(Pass)) {
250       for (auto &C : Callbacks->ShouldRunOptionalPassCallbacks)
251         ShouldRun &= C(Pass.name(), llvm::Any(&IR));
252     }
253 
254     if (ShouldRun) {
255       for (auto &C : Callbacks->BeforeNonSkippedPassCallbacks)
256         C(Pass.name(), llvm::Any(&IR));
257     } else {
258       for (auto &C : Callbacks->BeforeSkippedPassCallbacks)
259         C(Pass.name(), llvm::Any(&IR));
260     }
261 
262     return ShouldRun;
263   }
264 
265   /// AfterPass instrumentation point - takes \p Pass instance that has
266   /// just been executed and constant reference to \p IR it operates on.
267   /// \p IR is guaranteed to be valid at this point.
268   template <typename IRUnitT, typename PassT>
269   void runAfterPass(const PassT &Pass, const IRUnitT &IR,
270                     const PreservedAnalyses &PA) const {
271     if (Callbacks)
272       for (auto &C : Callbacks->AfterPassCallbacks)
273         C(Pass.name(), llvm::Any(&IR), PA);
274   }
275 
276   /// AfterPassInvalidated instrumentation point - takes \p Pass instance
277   /// that has just been executed. For use when IR has been invalidated
278   /// by \p Pass execution.
279   template <typename IRUnitT, typename PassT>
280   void runAfterPassInvalidated(const PassT &Pass,
281                                const PreservedAnalyses &PA) const {
282     if (Callbacks)
283       for (auto &C : Callbacks->AfterPassInvalidatedCallbacks)
284         C(Pass.name(), PA);
285   }
286 
287   /// BeforeAnalysis instrumentation point - takes \p Analysis instance
288   /// to be executed and constant reference to IR it operates on.
289   template <typename IRUnitT, typename PassT>
290   void runBeforeAnalysis(const PassT &Analysis, const IRUnitT &IR) const {
291     if (Callbacks)
292       for (auto &C : Callbacks->BeforeAnalysisCallbacks)
293         C(Analysis.name(), llvm::Any(&IR));
294   }
295 
296   /// AfterAnalysis instrumentation point - takes \p Analysis instance
297   /// that has just been executed and constant reference to IR it operated on.
298   template <typename IRUnitT, typename PassT>
299   void runAfterAnalysis(const PassT &Analysis, const IRUnitT &IR) const {
300     if (Callbacks)
301       for (auto &C : Callbacks->AfterAnalysisCallbacks)
302         C(Analysis.name(), llvm::Any(&IR));
303   }
304 
305   /// AnalysisInvalidated instrumentation point - takes \p Analysis instance
306   /// that has just been invalidated and constant reference to IR it operated
307   /// on.
308   template <typename IRUnitT, typename PassT>
309   void runAnalysisInvalidated(const PassT &Analysis, const IRUnitT &IR) const {
310     if (Callbacks)
311       for (auto &C : Callbacks->AnalysisInvalidatedCallbacks)
312         C(Analysis.name(), llvm::Any(&IR));
313   }
314 
315   /// AnalysesCleared instrumentation point - takes name of IR that analyses
316   /// operated on.
317   void runAnalysesCleared(StringRef Name) const {
318     if (Callbacks)
319       for (auto &C : Callbacks->AnalysesClearedCallbacks)
320         C(Name);
321   }
322 
323   /// Handle invalidation from the pass manager when PassInstrumentation
324   /// is used as the result of PassInstrumentationAnalysis.
325   ///
326   /// On attempt to invalidate just return false. There is nothing to become
327   /// invalid here.
328   template <typename IRUnitT, typename... ExtraArgsT>
329   bool invalidate(IRUnitT &, const class llvm::PreservedAnalyses &,
330                   ExtraArgsT...) {
331     return false;
332   }
333 
334   template <typename CallableT>
335   void pushBeforeNonSkippedPassCallback(CallableT C) {
336     if (Callbacks)
337       Callbacks->BeforeNonSkippedPassCallbacks.emplace_back(std::move(C));
338   }
339   void popBeforeNonSkippedPassCallback() {
340     if (Callbacks)
341       Callbacks->BeforeNonSkippedPassCallbacks.pop_back();
342   }
343 
344   /// Get the pass name for a given pass class name.
345   StringRef getPassNameForClassName(StringRef ClassName) const {
346     if (Callbacks)
347       return Callbacks->getPassNameForClassName(ClassName);
348     return {};
349   }
350 };
351 
/// Tests the pass identifier \p PassID against the names listed in
/// \p Specials (declared here, defined out of line).
/// NOTE(review): the exact matching rule (exact name, prefix, suffix or
/// substring match) is not visible in this header — confirm against the
/// definition before relying on it.
bool isSpecialPass(StringRef PassID, const std::vector<StringRef> &Specials);
353 
354 /// Pseudo-analysis pass that exposes the \c PassInstrumentation to pass
355 /// managers.
356 class PassInstrumentationAnalysis
357     : public AnalysisInfoMixin<PassInstrumentationAnalysis> {
358   friend AnalysisInfoMixin<PassInstrumentationAnalysis>;
359   static AnalysisKey Key;
360 
361   PassInstrumentationCallbacks *Callbacks;
362 
363 public:
364   /// PassInstrumentationCallbacks object is shared, owned by something else,
365   /// not this analysis.
366   PassInstrumentationAnalysis(PassInstrumentationCallbacks *Callbacks = nullptr)
367       : Callbacks(Callbacks) {}
368 
369   using Result = PassInstrumentation;
370 
371   template <typename IRUnitT, typename AnalysisManagerT, typename... ExtraArgTs>
372   Result run(IRUnitT &, AnalysisManagerT &, ExtraArgTs &&...) {
373     return PassInstrumentation(Callbacks);
374   }
375 };
376 
377 
378 } // namespace llvm
379 
380 #endif
381