xref: /llvm-project/mlir/include/mlir/Pass/PassInstrumentation.h (revision 0a81ace0047a2de93e71c82cdf0977fc989660df)
1 //===- PassInstrumentation.h ------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_PASS_PASSINSTRUMENTATION_H_
10 #define MLIR_PASS_PASSINSTRUMENTATION_H_
11 
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/Support/TypeID.h"
14 #include <optional>
15 
16 namespace mlir {
17 class OperationName;
18 class Operation;
19 class Pass;
20 
21 namespace detail {
22 struct PassInstrumentorImpl;
23 } // namespace detail
24 
25 /// PassInstrumentation provides several entry points into the pass manager
26 /// infrastructure. Instrumentations should be added directly to a PassManager
27 /// before running a pipeline.
28 class PassInstrumentation {
29 public:
30   /// This struct represents information related to the parent pass of pipeline.
31   /// It includes information that allows for effectively linking pipelines that
32   /// run on different threads.
33   struct PipelineParentInfo {
34     /// The thread of the parent pass that the current pipeline was spawned
35     /// from. Note: This is acquired from llvm::get_threadid().
36     uint64_t parentThreadID;
37 
38     /// The pass that spawned this pipeline.
39     Pass *parentPass;
40   };
41 
42   virtual ~PassInstrumentation() = 0;
43 
44   /// A callback to run before a pass pipeline is executed. This function takes
45   /// the name of the operation type being operated on, or std::nullopt if the
46   /// pipeline is op-agnostic, and information related to the parent that
47   /// spawned this pipeline.
48   virtual void runBeforePipeline(std::optional<OperationName> name,
49                                  const PipelineParentInfo &parentInfo);
50 
51   /// A callback to run after a pass pipeline has executed. This function takes
52   /// the name of the operation type being operated on, or std::nullopt if the
53   /// pipeline is op-agnostic, and information related to the parent that
54   /// spawned this pipeline.
55   virtual void runAfterPipeline(std::optional<OperationName> name,
56                                 const PipelineParentInfo &parentInfo);
57 
58   /// A callback to run before a pass is executed. This function takes a pointer
59   /// to the pass to be executed, as well as the current operation being
60   /// operated on.
runBeforePass(Pass * pass,Operation * op)61   virtual void runBeforePass(Pass *pass, Operation *op) {}
62 
63   /// A callback to run after a pass is successfully executed. This function
64   /// takes a pointer to the pass to be executed, as well as the current
65   /// operation being operated on.
runAfterPass(Pass * pass,Operation * op)66   virtual void runAfterPass(Pass *pass, Operation *op) {}
67 
68   /// A callback to run when a pass execution fails. This function takes a
69   /// pointer to the pass that was being executed, as well as the current
70   /// operation being operated on. Note that the operation may be in an invalid
71   /// state.
runAfterPassFailed(Pass * pass,Operation * op)72   virtual void runAfterPassFailed(Pass *pass, Operation *op) {}
73 
74   /// A callback to run before an analysis is computed. This function takes the
75   /// name of the analysis to be computed, its TypeID, as well as the
76   /// current operation being analyzed.
runBeforeAnalysis(StringRef name,TypeID id,Operation * op)77   virtual void runBeforeAnalysis(StringRef name, TypeID id, Operation *op) {}
78 
79   /// A callback to run before an analysis is computed. This function takes the
80   /// name of the analysis that was computed, its TypeID, as well as the
81   /// current operation being analyzed.
runAfterAnalysis(StringRef name,TypeID id,Operation * op)82   virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {}
83 };
84 
85 /// This class holds a collection of PassInstrumentation objects, and invokes
86 /// their respective call backs.
87 class PassInstrumentor {
88 public:
89   PassInstrumentor();
90   PassInstrumentor(PassInstrumentor &&) = delete;
91   PassInstrumentor(const PassInstrumentor &) = delete;
92   ~PassInstrumentor();
93 
94   /// See PassInstrumentation::runBeforePipeline for details.
95   void
96   runBeforePipeline(std::optional<OperationName> name,
97                     const PassInstrumentation::PipelineParentInfo &parentInfo);
98 
99   /// See PassInstrumentation::runAfterPipeline for details.
100   void
101   runAfterPipeline(std::optional<OperationName> name,
102                    const PassInstrumentation::PipelineParentInfo &parentInfo);
103 
104   /// See PassInstrumentation::runBeforePass for details.
105   void runBeforePass(Pass *pass, Operation *op);
106 
107   /// See PassInstrumentation::runAfterPass for details.
108   void runAfterPass(Pass *pass, Operation *op);
109 
110   /// See PassInstrumentation::runAfterPassFailed for details.
111   void runAfterPassFailed(Pass *pass, Operation *op);
112 
113   /// See PassInstrumentation::runBeforeAnalysis for details.
114   void runBeforeAnalysis(StringRef name, TypeID id, Operation *op);
115 
116   /// See PassInstrumentation::runAfterAnalysis for details.
117   void runAfterAnalysis(StringRef name, TypeID id, Operation *op);
118 
119   /// Add the given instrumentation to the collection.
120   void addInstrumentation(std::unique_ptr<PassInstrumentation> pi);
121 
122 private:
123   std::unique_ptr<detail::PassInstrumentorImpl> impl;
124 };
125 
126 } // namespace mlir
127 
128 namespace llvm {
129 template <>
130 struct DenseMapInfo<mlir::PassInstrumentation::PipelineParentInfo> {
131   using T = mlir::PassInstrumentation::PipelineParentInfo;
132   using PairInfo = DenseMapInfo<std::pair<uint64_t, void *>>;
133 
134   static T getEmptyKey() {
135     auto pair = PairInfo::getEmptyKey();
136     return {pair.first, reinterpret_cast<mlir::Pass *>(pair.second)};
137   }
138   static T getTombstoneKey() {
139     auto pair = PairInfo::getTombstoneKey();
140     return {pair.first, reinterpret_cast<mlir::Pass *>(pair.second)};
141   }
142   static unsigned getHashValue(T val) {
143     return PairInfo::getHashValue({val.parentThreadID, val.parentPass});
144   }
145   static bool isEqual(T lhs, T rhs) {
146     return lhs.parentThreadID == rhs.parentThreadID &&
147            lhs.parentPass == rhs.parentPass;
148   }
149 };
150 } // namespace llvm
151 
152 #endif // MLIR_PASS_PASSINSTRUMENTATION_H_
153