1 //===- ExecutionContext.h - Execution Context Support *- C++ -*-=============// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_TRACING_EXECUTIONCONTEXT_H 10 #define MLIR_TRACING_EXECUTIONCONTEXT_H 11 12 #include "mlir/Debug/BreakpointManager.h" 13 #include "mlir/IR/Action.h" 14 #include "llvm/ADT/SmallVector.h" 15 16 namespace mlir { 17 namespace tracing { 18 19 /// This class is used to keep track of the active actions in the stack. 20 /// It provides the current action but also access to the parent entry in the 21 /// stack. This allows to keep track of the nested nature in which actions may 22 /// be executed. 23 struct ActionActiveStack { 24 public: ActionActiveStackActionActiveStack25 ActionActiveStack(const ActionActiveStack *parent, const Action &action, 26 int depth) 27 : parent(parent), action(action), depth(depth) {} getParentActionActiveStack28 const ActionActiveStack *getParent() const { return parent; } getActionActionActiveStack29 const Action &getAction() const { return action; } getDepthActionActiveStack30 int getDepth() const { return depth; } 31 void print(raw_ostream &os, bool withContext) const; dumpActionActiveStack32 void dump() const { 33 print(llvm::errs(), /*withContext=*/true); 34 llvm::errs() << "\n"; 35 } getBreakpointActionActiveStack36 Breakpoint *getBreakpoint() const { return breakpoint; } setBreakpointActionActiveStack37 void setBreakpoint(Breakpoint *breakpoint) { this->breakpoint = breakpoint; } 38 39 private: 40 Breakpoint *breakpoint = nullptr; 41 const ActionActiveStack *parent; 42 const Action &action; 43 int depth; 44 }; 45 46 /// The ExecutionContext is the main orchestration of the infrastructure, it 47 /// acts as a handler in the MLIRContext for executing an Action. When an action 48 /// is dispatched, it'll query its set of Breakpoints managers for a breakpoint 49 /// matching this action. If a breakpoint is hit, it passes the action and the 50 /// breakpoint information to a callback. The callback is responsible for 51 /// controlling the execution of the action through an enum value it returns. 52 /// Optionally, observers can be registered to be notified before and after the 53 /// callback is executed. 54 class ExecutionContext { 55 public: 56 /// Enum that allows the client of the context to control the execution of the 57 /// action. 58 /// - Apply: The action is executed. 59 /// - Skip: The action is skipped. 60 /// - Step: The action is executed and the execution is paused before the next 61 /// action, including for nested actions encountered before the 62 /// current action finishes. 63 /// - Next: The action is executed and the execution is paused after the 64 /// current action finishes before the next action. 65 /// - Finish: The action is executed and the execution is paused only when we 66 /// reach the parent/enclosing operation. If there are no enclosing 67 /// operation, the execution continues without stopping. 68 enum Control { Apply = 1, Skip = 2, Step = 3, Next = 4, Finish = 5 }; 69 70 /// The type of the callback that is used to control the execution. 71 /// The callback is passed the current action. 72 using CallbackTy = function_ref<Control(const ActionActiveStack *)>; 73 74 /// Create an ExecutionContext with a callback that is used to control the 75 /// execution. ExecutionContext(CallbackTy callback)76 ExecutionContext(CallbackTy callback) { setCallback(callback); } 77 ExecutionContext() = default; 78 79 /// Set the callback that is used to control the execution. setCallback(CallbackTy callback)80 void setCallback(CallbackTy callback) { 81 onBreakpointControlExecutionCallback = callback; 82 } 83 84 /// This abstract class defines the interface used to observe an Action 85 /// execution. It allows to be notified before and after the callback is 86 /// processed, but can't affect the execution. 87 struct Observer { 88 virtual ~Observer() = default; 89 /// This method is called before the Action is executed 90 /// If a breakpoint was hit, it is passed as an argument to the callback. 91 /// The `willExecute` argument indicates whether the action will be executed 92 /// or not. 93 /// Note that this method will be called from multiple threads concurrently 94 /// when MLIR multi-threading is enabled. beforeExecuteObserver95 virtual void beforeExecute(const ActionActiveStack *action, 96 Breakpoint *breakpoint, bool willExecute) {} 97 98 /// This method is called after the Action is executed, if it was executed. 99 /// It is not called if the action is skipped. 100 /// Note that this method will be called from multiple threads concurrently 101 /// when MLIR multi-threading is enabled. afterExecuteObserver102 virtual void afterExecute(const ActionActiveStack *action) {} 103 }; 104 105 /// Register a new `Observer` on this context. It'll be notified before and 106 /// after executing an action. Note that this method is not thread-safe: it 107 /// isn't supported to add a new observer while actions may be executed. 108 void registerObserver(Observer *observer); 109 110 /// Register a new `BreakpointManager` on this context. It'll have a chance to 111 /// match an action before it gets executed. Note that this method is not 112 /// thread-safe: it isn't supported to add a new manager while actions may be 113 /// executed. addBreakpointManager(BreakpointManager * manager)114 void addBreakpointManager(BreakpointManager *manager) { 115 breakpoints.push_back(manager); 116 } 117 118 /// Process the given action. This is the operator called by MLIRContext on 119 /// `executeAction()`. 120 void operator()(function_ref<void()> transform, const Action &action); 121 122 private: 123 /// Callback that is executed when a breakpoint is hit and allows the client 124 /// to control the execution. 125 CallbackTy onBreakpointControlExecutionCallback; 126 127 /// Next point to stop execution as describe by `Control` enum. 128 /// This is handle by indicating at which levels of depth the next 129 /// break should happen. 130 std::optional<int> depthToBreak; 131 132 /// Observers that are notified before and after the callback is executed. 133 SmallVector<Observer *> observers; 134 135 /// The list of managers that are queried for breakpoints. 136 SmallVector<BreakpointManager *> breakpoints; 137 }; 138 139 } // namespace tracing 140 } // namespace mlir 141 142 #endif // MLIR_TRACING_EXECUTIONCONTEXT_H 143