xref: /llvm-project/mlir/lib/Debug/DebuggerExecutionContextHook.cpp (revision d5746d73cedcf7a593dc4b4f2ce2465e2d45750b)
1 //===- DebuggerExecutionContextHook.cpp - Debugger Support ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Debug/DebuggerExecutionContextHook.h"
10 
11 #include "mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h"
12 #include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
13 
14 using namespace mlir;
15 using namespace mlir::tracing;
16 
17 namespace {
18 /// This structure tracks the state of the interactive debugger.
19 struct DebuggerState {
20   /// This variable keeps track of the current control option. This is set by
21   /// the debugger when control is handed over to it.
22   ExecutionContext::Control debuggerControl = ExecutionContext::Apply;
23 
24   /// The breakpoint manager that allows the debugger to set breakpoints on
25   /// action tags.
26   TagBreakpointManager tagBreakpointManager;
27 
28   /// The breakpoint manager that allows the debugger to set breakpoints on
29   /// FileLineColLoc locations.
30   FileLineColLocBreakpointManager fileLineColLocBreakpointManager;
31 
32   /// Map of breakpoint IDs to breakpoint objects.
33   DenseMap<unsigned, Breakpoint *> breakpointIdsMap;
34 
35   /// The current stack of actiive actions.
36   const tracing::ActionActiveStack *actionActiveStack;
37 
38   /// This is a "cursor" in the IR, it is used for the debugger to navigate the
39   /// IR associated to the actions.
40   IRUnit cursor;
41 };
42 } // namespace
43 
44 static DebuggerState &getGlobalDebuggerState() {
45   static LLVM_THREAD_LOCAL DebuggerState debuggerState;
46   return debuggerState;
47 }
48 
49 extern "C" {
50 void mlirDebuggerSetControl(int controlOption) {
51   getGlobalDebuggerState().debuggerControl =
52       static_cast<ExecutionContext::Control>(controlOption);
53 }
54 
55 void mlirDebuggerPrintContext() {
56   DebuggerState &state = getGlobalDebuggerState();
57   if (!state.actionActiveStack) {
58     llvm::outs() << "No active action.\n";
59     return;
60   }
61   const ArrayRef<IRUnit> &units =
62       state.actionActiveStack->getAction().getContextIRUnits();
63   llvm::outs() << units.size() << " available IRUnits:\n";
64   for (const IRUnit &unit : units) {
65     llvm::outs() << "  - ";
66     unit.print(
67         llvm::outs(),
68         OpPrintingFlags().useLocalScope().skipRegions().enableDebugInfo());
69     llvm::outs() << "\n";
70   }
71 }
72 
73 void mlirDebuggerPrintActionBacktrace(bool withContext) {
74   DebuggerState &state = getGlobalDebuggerState();
75   if (!state.actionActiveStack) {
76     llvm::outs() << "No active action.\n";
77     return;
78   }
79   state.actionActiveStack->print(llvm::outs(), withContext);
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // Cursor Management
84 //===----------------------------------------------------------------------===//
85 
86 void mlirDebuggerCursorPrint(bool withRegion) {
87   auto &state = getGlobalDebuggerState();
88   if (!state.cursor) {
89     llvm::outs() << "No active MLIR cursor, select from the context first\n";
90     return;
91   }
92   state.cursor.print(llvm::outs(), OpPrintingFlags()
93                                        .skipRegions(!withRegion)
94                                        .useLocalScope()
95                                        .enableDebugInfo());
96   llvm::outs() << "\n";
97 }
98 
99 void mlirDebuggerCursorSelectIRUnitFromContext(int index) {
100   auto &state = getGlobalDebuggerState();
101   if (!state.actionActiveStack) {
102     llvm::outs() << "No active MLIR Action stack\n";
103     return;
104   }
105   ArrayRef<IRUnit> units =
106       state.actionActiveStack->getAction().getContextIRUnits();
107   if (index < 0 || index >= static_cast<int>(units.size())) {
108     llvm::outs() << "Index invalid, bounds: [0, " << units.size()
109                  << "] but got " << index << "\n";
110     return;
111   }
112   state.cursor = units[index];
113   state.cursor.print(llvm::outs());
114   llvm::outs() << "\n";
115 }
116 
117 void mlirDebuggerCursorSelectParentIRUnit() {
118   auto &state = getGlobalDebuggerState();
119   if (!state.cursor) {
120     llvm::outs() << "No active MLIR cursor, select from the context first\n";
121     return;
122   }
123   IRUnit *unit = &state.cursor;
124   if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
125     state.cursor = op->getBlock();
126   } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
127     state.cursor = region->getParentOp();
128   } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
129     state.cursor = block->getParent();
130   } else {
131     llvm::outs() << "Current cursor is not a valid IRUnit";
132     return;
133   }
134   state.cursor.print(llvm::outs());
135   llvm::outs() << "\n";
136 }
137 
138 void mlirDebuggerCursorSelectChildIRUnit(int index) {
139   auto &state = getGlobalDebuggerState();
140   if (!state.cursor) {
141     llvm::outs() << "No active MLIR cursor, select from the context first\n";
142     return;
143   }
144   IRUnit *unit = &state.cursor;
145   if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
146     if (index < 0 || index >= static_cast<int>(op->getNumRegions())) {
147       llvm::outs() << "Index invalid, op has " << op->getNumRegions()
148                    << " but got " << index << "\n";
149       return;
150     }
151     state.cursor = &op->getRegion(index);
152   } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
153     auto block = region->begin();
154     int count = 0;
155     while (block != region->end() && count != index) {
156       ++block;
157       ++count;
158     }
159 
160     if (block == region->end()) {
161       llvm::outs() << "Index invalid, region has " << count << " block but got "
162                    << index << "\n";
163       return;
164     }
165     state.cursor = &*block;
166   } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
167     auto op = block->begin();
168     int count = 0;
169     while (op != block->end() && count != index) {
170       ++op;
171       ++count;
172     }
173 
174     if (op == block->end()) {
175       llvm::outs() << "Index invalid, block has " << count
176                    << "operations but got " << index << "\n";
177       return;
178     }
179     state.cursor = &*op;
180   } else {
181     llvm::outs() << "Current cursor is not a valid IRUnit";
182     return;
183   }
184   state.cursor.print(llvm::outs());
185   llvm::outs() << "\n";
186 }
187 
188 void mlirDebuggerCursorSelectPreviousIRUnit() {
189   auto &state = getGlobalDebuggerState();
190   if (!state.cursor) {
191     llvm::outs() << "No active MLIR cursor, select from the context first\n";
192     return;
193   }
194   IRUnit *unit = &state.cursor;
195   if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
196     Operation *previous = op->getPrevNode();
197     if (!previous) {
198       llvm::outs() << "No previous operation in the current block\n";
199       return;
200     }
201     state.cursor = previous;
202   } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
203     llvm::outs() << "Has region\n";
204     Operation *parent = region->getParentOp();
205     if (!parent) {
206       llvm::outs() << "No parent operation for the current region\n";
207       return;
208     }
209     if (region->getRegionNumber() == 0) {
210       llvm::outs() << "No previous region in the current operation\n";
211       return;
212     }
213     state.cursor =
214         &region->getParentOp()->getRegion(region->getRegionNumber() - 1);
215   } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
216     Block *previous = block->getPrevNode();
217     if (!previous) {
218       llvm::outs() << "No previous block in the current region\n";
219       return;
220     }
221     state.cursor = previous;
222   } else {
223     llvm::outs() << "Current cursor is not a valid IRUnit";
224     return;
225   }
226   state.cursor.print(llvm::outs());
227   llvm::outs() << "\n";
228 }
229 
230 void mlirDebuggerCursorSelectNextIRUnit() {
231   auto &state = getGlobalDebuggerState();
232   if (!state.cursor) {
233     llvm::outs() << "No active MLIR cursor, select from the context first\n";
234     return;
235   }
236   IRUnit *unit = &state.cursor;
237   if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
238     Operation *next = op->getNextNode();
239     if (!next) {
240       llvm::outs() << "No next operation in the current block\n";
241       return;
242     }
243     state.cursor = next;
244   } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
245     Operation *parent = region->getParentOp();
246     if (!parent) {
247       llvm::outs() << "No parent operation for the current region\n";
248       return;
249     }
250     if (region->getRegionNumber() == parent->getNumRegions() - 1) {
251       llvm::outs() << "No next region in the current operation\n";
252       return;
253     }
254     state.cursor =
255         &region->getParentOp()->getRegion(region->getRegionNumber() + 1);
256   } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
257     Block *next = block->getNextNode();
258     if (!next) {
259       llvm::outs() << "No next block in the current region\n";
260       return;
261     }
262     state.cursor = next;
263   } else {
264     llvm::outs() << "Current cursor is not a valid IRUnit";
265     return;
266   }
267   state.cursor.print(llvm::outs());
268   llvm::outs() << "\n";
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // Breakpoint Management
273 //===----------------------------------------------------------------------===//
274 
275 void mlirDebuggerEnableBreakpoint(BreakpointHandle breakpoint) {
276   reinterpret_cast<Breakpoint *>(breakpoint)->enable();
277 }
278 
279 void mlirDebuggerDisableBreakpoint(BreakpointHandle breakpoint) {
280   reinterpret_cast<Breakpoint *>(breakpoint)->disable();
281 }
282 
283 BreakpointHandle mlirDebuggerAddTagBreakpoint(const char *tag) {
284   DebuggerState &state = getGlobalDebuggerState();
285   Breakpoint *breakpoint =
286       state.tagBreakpointManager.addBreakpoint(StringRef(tag, strlen(tag)));
287   int breakpointId = state.breakpointIdsMap.size() + 1;
288   state.breakpointIdsMap[breakpointId] = breakpoint;
289   return reinterpret_cast<BreakpointHandle>(breakpoint);
290 }
291 
/// Placeholder for breakpoints on rewrite patterns: currently a no-op that
/// ignores `patternNameInfo`.
void mlirDebuggerAddRewritePatternBreakpoint(const char *patternNameInfo) {
  (void)patternNameInfo;
}
293 
294 void mlirDebuggerAddFileLineColLocBreakpoint(const char *file, int line,
295                                              int col) {
296   getGlobalDebuggerState().fileLineColLocBreakpointManager.addBreakpoint(
297       StringRef(file, strlen(file)), line, col);
298 }
299 
300 } // extern "C"
301 
/// Hook that an attached debugger is supposed to trap on.
/// `debuggerCallBackFunction` calls it each time control is handed to the
/// debugger, which may then call the `mlirDebugger*` entry points (for example
/// `mlirDebuggerSetControl()`) before resuming.
///
/// The function is marked no-inline and has a side effect on a `volatile`
/// thread-local, presumably so the optimizer can neither inline the call nor
/// reduce the body to nothing.
LLVM_ATTRIBUTE_NOINLINE void mlirDebuggerBreakpointHook() {
  static LLVM_THREAD_LOCAL void *volatile sink;
  // Store the address of `sink` into itself: a volatile write the compiler
  // must keep.
  sink = static_cast<void *>(const_cast<void **>(&sink));
}
306 
/// Take the address of every `extern "C"` debugger entry point and store it
/// into a `volatile` sink. Nothing else in the program calls these functions;
/// the debugger is expected to invoke them (e.g. `mlirDebuggerSetControl()`).
/// Referencing them from code reachable through `debuggerCallBackFunction`
/// keeps the linker from discarding them.
static void preventLinkerDeadCodeElim() {
  static void *volatile sink;
  // The initializer of a function-local static runs only once, so the
  // volatile stores below happen on the first call only.
  static bool initialized = [&]() {
    sink = (void *)mlirDebuggerSetControl;
    sink = (void *)mlirDebuggerEnableBreakpoint;
    sink = (void *)mlirDebuggerDisableBreakpoint;
    sink = (void *)mlirDebuggerPrintContext;
    sink = (void *)mlirDebuggerPrintActionBacktrace;
    sink = (void *)mlirDebuggerCursorPrint;
    sink = (void *)mlirDebuggerCursorSelectIRUnitFromContext;
    sink = (void *)mlirDebuggerCursorSelectParentIRUnit;
    sink = (void *)mlirDebuggerCursorSelectChildIRUnit;
    sink = (void *)mlirDebuggerCursorSelectPreviousIRUnit;
    sink = (void *)mlirDebuggerCursorSelectNextIRUnit;
    sink = (void *)mlirDebuggerAddTagBreakpoint;
    sink = (void *)mlirDebuggerAddRewritePatternBreakpoint;
    sink = (void *)mlirDebuggerAddFileLineColLocBreakpoint;
    // Final store so the sink does not keep pointing at an entry point.
    sink = static_cast<void *>(const_cast<void **>(&sink));
    return true;
  }();
  // Silence the unused-variable warning; only the initializer matters.
  (void)initialized;
}
329 
330 static tracing::ExecutionContext::Control
331 debuggerCallBackFunction(const tracing::ActionActiveStack *actionStack) {
332   preventLinkerDeadCodeElim();
333   // Invoke the breakpoint hook, the debugger is supposed to trap this.
334   // The debugger controls the execution from there by invoking
335   // `mlirDebuggerSetControl()`.
336   auto &state = getGlobalDebuggerState();
337   state.actionActiveStack = actionStack;
338   getGlobalDebuggerState().debuggerControl = ExecutionContext::Apply;
339   actionStack->getAction().print(llvm::outs());
340   llvm::outs() << "\n";
341   mlirDebuggerBreakpointHook();
342   return getGlobalDebuggerState().debuggerControl;
343 }
344 
345 namespace {
346 /// Manage the stack of actions that are currently active.
347 class DebuggerObserver : public ExecutionContext::Observer {
348   void beforeExecute(const ActionActiveStack *action, Breakpoint *breakpoint,
349                      bool willExecute) override {
350     auto &state = getGlobalDebuggerState();
351     state.actionActiveStack = action;
352   }
353   void afterExecute(const ActionActiveStack *action) override {
354     auto &state = getGlobalDebuggerState();
355     state.actionActiveStack = action->getParent();
356     state.cursor = nullptr;
357   }
358 };
359 } // namespace
360 
361 void mlir::setupDebuggerExecutionContextHook(
362     tracing::ExecutionContext &executionContext) {
363   executionContext.setCallback(debuggerCallBackFunction);
364   DebuggerState &state = getGlobalDebuggerState();
365   static DebuggerObserver observer;
366   executionContext.registerObserver(&observer);
367   executionContext.addBreakpointManager(&state.fileLineColLocBreakpointManager);
368   executionContext.addBreakpointManager(&state.tagBreakpointManager);
369 }
370