xref: /llvm-project/mlir/lib/Debug/ExecutionContext.cpp (revision 1020150e7a6f6d6f833c232125c5ab817c03c76b)
1fa51c175SMehdi Amini //===- ExecutionContext.cpp - Debug Execution Context Support -------------===//
2fa51c175SMehdi Amini //
3fa51c175SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fa51c175SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
5fa51c175SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fa51c175SMehdi Amini //
7fa51c175SMehdi Amini //===----------------------------------------------------------------------===//
8fa51c175SMehdi Amini 
9fa51c175SMehdi Amini #include "mlir/Debug/ExecutionContext.h"
10fa51c175SMehdi Amini 
11fa51c175SMehdi Amini #include "llvm/ADT/ScopeExit.h"
12*1020150eSMehdi Amini #include "llvm/Support/FormatVariadic.h"
13fa51c175SMehdi Amini 
14fa51c175SMehdi Amini #include <cstddef>
15fa51c175SMehdi Amini 
16fa51c175SMehdi Amini using namespace mlir;
17fa51c175SMehdi Amini using namespace mlir::tracing;
18fa51c175SMehdi Amini 
19fa51c175SMehdi Amini //===----------------------------------------------------------------------===//
20*1020150eSMehdi Amini // ActionActiveStack
21*1020150eSMehdi Amini //===----------------------------------------------------------------------===//
22*1020150eSMehdi Amini 
print(raw_ostream & os,bool withContext) const23*1020150eSMehdi Amini void ActionActiveStack::print(raw_ostream &os, bool withContext) const {
24*1020150eSMehdi Amini   os << "ActionActiveStack depth " << getDepth() << "\n";
25*1020150eSMehdi Amini   const ActionActiveStack *current = this;
26*1020150eSMehdi Amini   int count = 0;
27*1020150eSMehdi Amini   while (current) {
28*1020150eSMehdi Amini     llvm::errs() << llvm::formatv("#{0,3}: ", count++);
29*1020150eSMehdi Amini     current->action.print(llvm::errs());
30*1020150eSMehdi Amini     llvm::errs() << "\n";
31*1020150eSMehdi Amini     ArrayRef<IRUnit> context = current->action.getContextIRUnits();
32*1020150eSMehdi Amini     if (withContext && !context.empty()) {
33*1020150eSMehdi Amini       llvm::errs() << "Context:\n";
34*1020150eSMehdi Amini       llvm::interleave(
35*1020150eSMehdi Amini           current->action.getContextIRUnits(),
36*1020150eSMehdi Amini           [&](const IRUnit &unit) {
37*1020150eSMehdi Amini             llvm::errs() << "  - ";
38*1020150eSMehdi Amini             unit.print(llvm::errs());
39*1020150eSMehdi Amini           },
40*1020150eSMehdi Amini           [&]() { llvm::errs() << "\n"; });
41*1020150eSMehdi Amini       llvm::errs() << "\n";
42*1020150eSMehdi Amini     }
43*1020150eSMehdi Amini     current = current->parent;
44*1020150eSMehdi Amini   }
45*1020150eSMehdi Amini }
46*1020150eSMehdi Amini 
47*1020150eSMehdi Amini //===----------------------------------------------------------------------===//
48fa51c175SMehdi Amini // ExecutionContext
49fa51c175SMehdi Amini //===----------------------------------------------------------------------===//
50fa51c175SMehdi Amini 
51*1020150eSMehdi Amini static const LLVM_THREAD_LOCAL ActionActiveStack *actionStack = nullptr;
52fa51c175SMehdi Amini 
registerObserver(Observer * observer)53fa51c175SMehdi Amini void ExecutionContext::registerObserver(Observer *observer) {
54fa51c175SMehdi Amini   observers.push_back(observer);
55fa51c175SMehdi Amini }
56fa51c175SMehdi Amini 
operator ()(llvm::function_ref<void ()> transform,const Action & action)57fa51c175SMehdi Amini void ExecutionContext::operator()(llvm::function_ref<void()> transform,
58fa51c175SMehdi Amini                                   const Action &action) {
59fa51c175SMehdi Amini   // Update the top of the stack with the current action.
60fa51c175SMehdi Amini   int depth = 0;
61fa51c175SMehdi Amini   if (actionStack)
62fa51c175SMehdi Amini     depth = actionStack->getDepth() + 1;
63fa51c175SMehdi Amini   ActionActiveStack info{actionStack, action, depth};
64fa51c175SMehdi Amini   actionStack = &info;
65fa51c175SMehdi Amini   auto raii = llvm::make_scope_exit([&]() { actionStack = info.getParent(); });
66fa51c175SMehdi Amini   Breakpoint *breakpoint = nullptr;
67fa51c175SMehdi Amini 
68fa51c175SMehdi Amini   // Invoke the callback here and handles control requests here.
69fa51c175SMehdi Amini   auto handleUserInput = [&]() -> bool {
70fa51c175SMehdi Amini     if (!onBreakpointControlExecutionCallback)
71fa51c175SMehdi Amini       return true;
72fa51c175SMehdi Amini     auto todoNext = onBreakpointControlExecutionCallback(actionStack);
73fa51c175SMehdi Amini     switch (todoNext) {
74fa51c175SMehdi Amini     case ExecutionContext::Apply:
75fa51c175SMehdi Amini       depthToBreak = std::nullopt;
76fa51c175SMehdi Amini       return true;
77fa51c175SMehdi Amini     case ExecutionContext::Skip:
78fa51c175SMehdi Amini       depthToBreak = std::nullopt;
79fa51c175SMehdi Amini       return false;
80fa51c175SMehdi Amini     case ExecutionContext::Step:
81fa51c175SMehdi Amini       depthToBreak = depth + 1;
82fa51c175SMehdi Amini       return true;
83fa51c175SMehdi Amini     case ExecutionContext::Next:
84fa51c175SMehdi Amini       depthToBreak = depth;
85fa51c175SMehdi Amini       return true;
86fa51c175SMehdi Amini     case ExecutionContext::Finish:
87fa51c175SMehdi Amini       depthToBreak = depth - 1;
88fa51c175SMehdi Amini       return true;
89fa51c175SMehdi Amini     }
90fa51c175SMehdi Amini     llvm::report_fatal_error("Unknown control request");
91fa51c175SMehdi Amini   };
92fa51c175SMehdi Amini 
93fa51c175SMehdi Amini   // Try to find a breakpoint that would hit on this action.
94fa51c175SMehdi Amini   // Right now there is no way to collect them all, we stop at the first one.
95fa51c175SMehdi Amini   for (auto *breakpointManager : breakpoints) {
96fa51c175SMehdi Amini     breakpoint = breakpointManager->match(action);
97fa51c175SMehdi Amini     if (breakpoint)
98fa51c175SMehdi Amini       break;
99fa51c175SMehdi Amini   }
100*1020150eSMehdi Amini   info.setBreakpoint(breakpoint);
101fa51c175SMehdi Amini 
102fa51c175SMehdi Amini   bool shouldExecuteAction = true;
103fa51c175SMehdi Amini   // If we have a breakpoint, or if `depthToBreak` was previously set and the
104fa51c175SMehdi Amini   // current depth matches, we invoke the user-provided callback.
105fa51c175SMehdi Amini   if (breakpoint || (depthToBreak && depth <= depthToBreak))
106fa51c175SMehdi Amini     shouldExecuteAction = handleUserInput();
107fa51c175SMehdi Amini 
108fa51c175SMehdi Amini   // Notify the observers about the current action.
109fa51c175SMehdi Amini   for (auto *observer : observers)
110fa51c175SMehdi Amini     observer->beforeExecute(actionStack, breakpoint, shouldExecuteAction);
111fa51c175SMehdi Amini 
112fa51c175SMehdi Amini   if (shouldExecuteAction) {
113fa51c175SMehdi Amini     // Execute the action here.
114fa51c175SMehdi Amini     transform();
115fa51c175SMehdi Amini 
116fa51c175SMehdi Amini     // Notify the observers about completion of the action.
117fa51c175SMehdi Amini     for (auto *observer : observers)
118fa51c175SMehdi Amini       observer->afterExecute(actionStack);
119fa51c175SMehdi Amini   }
120fa51c175SMehdi Amini 
121fa51c175SMehdi Amini   if (depthToBreak && depth <= depthToBreak)
122fa51c175SMehdi Amini     handleUserInput();
123fa51c175SMehdi Amini }
124