xref: /llvm-project/mlir/lib/Debug/ExecutionContext.cpp (revision 1020150e7a6f6d6f833c232125c5ab817c03c76b)
1 //===- ExecutionContext.cpp - Debug Execution Context Support -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Debug/ExecutionContext.h"
10 
11 #include "llvm/ADT/ScopeExit.h"
12 #include "llvm/Support/FormatVariadic.h"
13 
14 #include <cstddef>
15 
16 using namespace mlir;
17 using namespace mlir::tracing;
18 
19 //===----------------------------------------------------------------------===//
20 // ActionActiveStack
21 //===----------------------------------------------------------------------===//
22 
print(raw_ostream & os,bool withContext) const23 void ActionActiveStack::print(raw_ostream &os, bool withContext) const {
24   os << "ActionActiveStack depth " << getDepth() << "\n";
25   const ActionActiveStack *current = this;
26   int count = 0;
27   while (current) {
28     llvm::errs() << llvm::formatv("#{0,3}: ", count++);
29     current->action.print(llvm::errs());
30     llvm::errs() << "\n";
31     ArrayRef<IRUnit> context = current->action.getContextIRUnits();
32     if (withContext && !context.empty()) {
33       llvm::errs() << "Context:\n";
34       llvm::interleave(
35           current->action.getContextIRUnits(),
36           [&](const IRUnit &unit) {
37             llvm::errs() << "  - ";
38             unit.print(llvm::errs());
39           },
40           [&]() { llvm::errs() << "\n"; });
41       llvm::errs() << "\n";
42     }
43     current = current->parent;
44   }
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // ExecutionContext
49 //===----------------------------------------------------------------------===//
50 
51 static const LLVM_THREAD_LOCAL ActionActiveStack *actionStack = nullptr;
52 
registerObserver(Observer * observer)53 void ExecutionContext::registerObserver(Observer *observer) {
54   observers.push_back(observer);
55 }
56 
operator ()(llvm::function_ref<void ()> transform,const Action & action)57 void ExecutionContext::operator()(llvm::function_ref<void()> transform,
58                                   const Action &action) {
59   // Update the top of the stack with the current action.
60   int depth = 0;
61   if (actionStack)
62     depth = actionStack->getDepth() + 1;
63   ActionActiveStack info{actionStack, action, depth};
64   actionStack = &info;
65   auto raii = llvm::make_scope_exit([&]() { actionStack = info.getParent(); });
66   Breakpoint *breakpoint = nullptr;
67 
68   // Invoke the callback here and handles control requests here.
69   auto handleUserInput = [&]() -> bool {
70     if (!onBreakpointControlExecutionCallback)
71       return true;
72     auto todoNext = onBreakpointControlExecutionCallback(actionStack);
73     switch (todoNext) {
74     case ExecutionContext::Apply:
75       depthToBreak = std::nullopt;
76       return true;
77     case ExecutionContext::Skip:
78       depthToBreak = std::nullopt;
79       return false;
80     case ExecutionContext::Step:
81       depthToBreak = depth + 1;
82       return true;
83     case ExecutionContext::Next:
84       depthToBreak = depth;
85       return true;
86     case ExecutionContext::Finish:
87       depthToBreak = depth - 1;
88       return true;
89     }
90     llvm::report_fatal_error("Unknown control request");
91   };
92 
93   // Try to find a breakpoint that would hit on this action.
94   // Right now there is no way to collect them all, we stop at the first one.
95   for (auto *breakpointManager : breakpoints) {
96     breakpoint = breakpointManager->match(action);
97     if (breakpoint)
98       break;
99   }
100   info.setBreakpoint(breakpoint);
101 
102   bool shouldExecuteAction = true;
103   // If we have a breakpoint, or if `depthToBreak` was previously set and the
104   // current depth matches, we invoke the user-provided callback.
105   if (breakpoint || (depthToBreak && depth <= depthToBreak))
106     shouldExecuteAction = handleUserInput();
107 
108   // Notify the observers about the current action.
109   for (auto *observer : observers)
110     observer->beforeExecute(actionStack, breakpoint, shouldExecuteAction);
111 
112   if (shouldExecuteAction) {
113     // Execute the action here.
114     transform();
115 
116     // Notify the observers about completion of the action.
117     for (auto *observer : observers)
118       observer->afterExecute(actionStack);
119   }
120 
121   if (depthToBreak && depth <= depthToBreak)
122     handleUserInput();
123 }
124