xref: /llvm-project/mlir/include/mlir/Debug/BreakpointManagers/TagBreakpointManager.h (revision fa51c1753a274fbb7a71d8fe91fd4e5caf2fa4d3)
1 //===- TagBreakpointManager.h - Simple breakpoint Support -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DEBUG_BREAKPOINTMANAGERS_TAGBREAKPOINTMANAGER_H
10 #define MLIR_DEBUG_BREAKPOINTMANAGERS_TAGBREAKPOINTMANAGER_H
11 
12 #include "mlir/Debug/BreakpointManager.h"
13 #include "mlir/Debug/ExecutionContext.h"
14 #include "mlir/IR/Action.h"
15 #include "llvm/ADT/MapVector.h"
16 
17 namespace mlir {
18 namespace tracing {
19 
20 /// Simple breakpoint matching an action "tag".
21 class TagBreakpoint : public BreakpointBase<TagBreakpoint> {
22 public:
TagBreakpoint(StringRef tag)23   TagBreakpoint(StringRef tag) : tag(tag) {}
24 
print(raw_ostream & os)25   void print(raw_ostream &os) const override { os << "Tag: `" << tag << '`'; }
26 
27 private:
28   /// A tag to associate the TagBreakpoint with.
29   std::string tag;
30 
31   /// Allow access to `tag`.
32   friend class TagBreakpointManager;
33 };
34 
35 /// This is a manager to store a collection of breakpoints that trigger
36 /// on tags.
37 class TagBreakpointManager
38     : public BreakpointManagerBase<TagBreakpointManager> {
39 public:
match(const Action & action)40   Breakpoint *match(const Action &action) const override {
41     auto it = breakpoints.find(action.getTag());
42     if (it != breakpoints.end() && it->second->isEnabled())
43       return it->second.get();
44     return {};
45   }
46 
47   /// Add a breakpoint to the manager for the given tag and return it.
48   /// If a breakpoint already exists for the given tag, return the existing
49   /// instance.
addBreakpoint(StringRef tag)50   TagBreakpoint *addBreakpoint(StringRef tag) {
51     auto result = breakpoints.insert({tag, nullptr});
52     auto &it = result.first;
53     if (result.second)
54       it->second = std::make_unique<TagBreakpoint>(tag.str());
55     return it->second.get();
56   }
57 
58 private:
59   llvm::StringMap<std::unique_ptr<TagBreakpoint>> breakpoints;
60 };
61 
62 } // namespace tracing
63 } // namespace mlir
64 
65 #endif // MLIR_DEBUG_BREAKPOINTMANAGERS_TAGBREAKPOINTMANAGER_H
66