xref: /llvm-project/mlir/lib/Debug/Observers/ActionProfiler.cpp (revision 884221eddb9d395830704fac79fd04008e02e368)
1 //===- ActionProfiler.cpp -  Profiling Actions *- C++ -*-=====================//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Debug/Observers/ActionProfiler.h"
10 #include "mlir/Debug/BreakpointManager.h"
11 #include "mlir/IR/Action.h"
12 #include "mlir/Rewrite/PatternApplicator.h"
13 #include "llvm/Support/Casting.h"
14 #include "llvm/Support/Threading.h"
15 #include "llvm/Support/raw_ostream.h"
16 #include <chrono>
17 
18 using namespace mlir;
19 using namespace mlir::tracing;
20 
21 //===----------------------------------------------------------------------===//
22 // ActionProfiler
23 //===----------------------------------------------------------------------===//
24 void ActionProfiler::beforeExecute(const ActionActiveStack *action,
25                                    Breakpoint *breakpoint, bool willExecute) {
26   print(action, "B"); // begin event.
27 }
28 
29 void ActionProfiler::afterExecute(const ActionActiveStack *action) {
30   print(action, "E"); // end event.
31 }
32 
33 // Print an event in JSON format.
34 void ActionProfiler::print(const ActionActiveStack *action,
35                            llvm::StringRef phase) {
36   // Create the event.
37   std::string str;
38   llvm::raw_string_ostream event(str);
39   event << "{";
40   event << R"("name": ")" << action->getAction().getTag() << "\", ";
41   event << R"("cat": "PERF", )";
42   event << R"("ph": ")" << phase << "\", ";
43   event << R"("pid": 0, )";
44   event << R"("tid": )" << llvm::get_threadid() << ", ";
45   auto ts = std::chrono::steady_clock::now() - startTime;
46   event << R"("ts": )"
47         << std::chrono::duration_cast<std::chrono::microseconds>(ts).count();
48   if (phase == "B") {
49     event << R"(, "args": {)";
50     event << R"("desc": ")";
51     action->getAction().print(event);
52     event << "\"}";
53   }
54   event << "}";
55 
56   // Print the event.
57   std::lock_guard<std::mutex> guard(mutex);
58   if (printComma)
59     os << ",\n";
60   printComma = true;
61   os << str;
62   os.flush();
63 }
64