xref: /llvm-project/mlir/include/mlir/Debug/ExecutionContext.h (revision 1020150e7a6f6d6f833c232125c5ab817c03c76b)
//===- ExecutionContext.h - Execution Context Support -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_TRACING_EXECUTIONCONTEXT_H
10 #define MLIR_TRACING_EXECUTIONCONTEXT_H
11 
12 #include "mlir/Debug/BreakpointManager.h"
13 #include "mlir/IR/Action.h"
14 #include "llvm/ADT/SmallVector.h"
15 
16 namespace mlir {
17 namespace tracing {
18 
19 /// This class is used to keep track of the active actions in the stack.
20 /// It provides the current action but also access to the parent entry in the
21 /// stack. This allows to keep track of the nested nature in which actions may
22 /// be executed.
23 struct ActionActiveStack {
24 public:
ActionActiveStackActionActiveStack25   ActionActiveStack(const ActionActiveStack *parent, const Action &action,
26                     int depth)
27       : parent(parent), action(action), depth(depth) {}
getParentActionActiveStack28   const ActionActiveStack *getParent() const { return parent; }
getActionActionActiveStack29   const Action &getAction() const { return action; }
getDepthActionActiveStack30   int getDepth() const { return depth; }
31   void print(raw_ostream &os, bool withContext) const;
dumpActionActiveStack32   void dump() const {
33     print(llvm::errs(), /*withContext=*/true);
34     llvm::errs() << "\n";
35   }
getBreakpointActionActiveStack36   Breakpoint *getBreakpoint() const { return breakpoint; }
setBreakpointActionActiveStack37   void setBreakpoint(Breakpoint *breakpoint) { this->breakpoint = breakpoint; }
38 
39 private:
40   Breakpoint *breakpoint = nullptr;
41   const ActionActiveStack *parent;
42   const Action &action;
43   int depth;
44 };
45 
46 /// The ExecutionContext is the main orchestration of the infrastructure, it
47 /// acts as a handler in the MLIRContext for executing an Action. When an action
48 /// is dispatched, it'll query its set of Breakpoints managers for a breakpoint
49 /// matching this action. If a breakpoint is hit, it passes the action and the
50 /// breakpoint information to a callback. The callback is responsible for
51 /// controlling the execution of the action through an enum value it returns.
52 /// Optionally, observers can be registered to be notified before and after the
53 /// callback is executed.
54 class ExecutionContext {
55 public:
56   /// Enum that allows the client of the context to control the execution of the
57   /// action.
58   /// - Apply: The action is executed.
59   /// - Skip: The action is skipped.
60   /// - Step: The action is executed and the execution is paused before the next
61   ///         action, including for nested actions encountered before the
62   ///         current action finishes.
63   /// - Next: The action is executed and the execution is paused after the
64   ///         current action finishes before the next action.
65   /// - Finish: The action is executed and the execution is paused only when we
66   ///           reach the parent/enclosing operation. If there are no enclosing
67   ///           operation, the execution continues without stopping.
68   enum Control { Apply = 1, Skip = 2, Step = 3, Next = 4, Finish = 5 };
69 
70   /// The type of the callback that is used to control the execution.
71   /// The callback is passed the current action.
72   using CallbackTy = function_ref<Control(const ActionActiveStack *)>;
73 
74   /// Create an ExecutionContext with a callback that is used to control the
75   /// execution.
ExecutionContext(CallbackTy callback)76   ExecutionContext(CallbackTy callback) { setCallback(callback); }
77   ExecutionContext() = default;
78 
79   /// Set the callback that is used to control the execution.
setCallback(CallbackTy callback)80   void setCallback(CallbackTy callback) {
81     onBreakpointControlExecutionCallback = callback;
82   }
83 
84   /// This abstract class defines the interface used to observe an Action
85   /// execution. It allows to be notified before and after the callback is
86   /// processed, but can't affect the execution.
87   struct Observer {
88     virtual ~Observer() = default;
89     /// This method is called before the Action is executed
90     /// If a breakpoint was hit, it is passed as an argument to the callback.
91     /// The `willExecute` argument indicates whether the action will be executed
92     /// or not.
93     /// Note that this method will be called from multiple threads concurrently
94     /// when MLIR multi-threading is enabled.
beforeExecuteObserver95     virtual void beforeExecute(const ActionActiveStack *action,
96                                Breakpoint *breakpoint, bool willExecute) {}
97 
98     /// This method is called after the Action is executed, if it was executed.
99     /// It is not called if the action is skipped.
100     /// Note that this method will be called from multiple threads concurrently
101     /// when MLIR multi-threading is enabled.
afterExecuteObserver102     virtual void afterExecute(const ActionActiveStack *action) {}
103   };
104 
105   /// Register a new `Observer` on this context. It'll be notified before and
106   /// after executing an action. Note that this method is not thread-safe: it
107   /// isn't supported to add a new observer while actions may be executed.
108   void registerObserver(Observer *observer);
109 
110   /// Register a new `BreakpointManager` on this context. It'll have a chance to
111   /// match an action before it gets executed. Note that this method is not
112   /// thread-safe: it isn't supported to add a new manager while actions may be
113   /// executed.
addBreakpointManager(BreakpointManager * manager)114   void addBreakpointManager(BreakpointManager *manager) {
115     breakpoints.push_back(manager);
116   }
117 
118   /// Process the given action. This is the operator called by MLIRContext on
119   /// `executeAction()`.
120   void operator()(function_ref<void()> transform, const Action &action);
121 
122 private:
123   /// Callback that is executed when a breakpoint is hit and allows the client
124   /// to control the execution.
125   CallbackTy onBreakpointControlExecutionCallback;
126 
127   /// Next point to stop execution as describe by `Control` enum.
128   /// This is handle by indicating at which levels of depth the next
129   /// break should happen.
130   std::optional<int> depthToBreak;
131 
132   /// Observers that are notified before and after the callback is executed.
133   SmallVector<Observer *> observers;
134 
135   /// The list of managers that are queried for breakpoints.
136   SmallVector<BreakpointManager *> breakpoints;
137 };
138 
139 } // namespace tracing
140 } // namespace mlir
141 
142 #endif // MLIR_TRACING_EXECUTIONCONTEXT_H
143