xref: /llvm-project/mlir/lib/Pass/IRPrinting.cpp (revision c5ea7b8338e6947b5219f95a60702fac1da633ee)
1 //===- IRPrinting.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/IR/SymbolTable.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "mlir/Support/FileUtilities.h"
13 #include "llvm/ADT/StringExtras.h"
14 #include "llvm/Support/FileSystem.h"
15 #include "llvm/Support/FormatVariadic.h"
16 #include "llvm/Support/Path.h"
17 #include "llvm/Support/ToolOutputFile.h"
18 
19 using namespace mlir;
20 using namespace mlir::detail;
21 
22 namespace {
23 //===----------------------------------------------------------------------===//
24 // IRPrinter
25 //===----------------------------------------------------------------------===//
26 
27 class IRPrinterInstrumentation : public PassInstrumentation {
28 public:
29   IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config)
30       : config(std::move(config)) {}
31 
32 private:
33   /// Instrumentation hooks.
34   void runBeforePass(Pass *pass, Operation *op) override;
35   void runAfterPass(Pass *pass, Operation *op) override;
36   void runAfterPassFailed(Pass *pass, Operation *op) override;
37 
38   /// Configuration to use.
39   std::unique_ptr<PassManager::IRPrinterConfig> config;
40 
41   /// The following is a set of fingerprints for operations that are currently
42   /// being operated on in a pass. This field is only used when the
43   /// configuration asked for change detection.
44   DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints;
45 };
46 } // namespace
47 
48 static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
49                     OpPrintingFlags flags) {
50   // Otherwise, check to see if we are not printing at module scope.
51   if (!printModuleScope)
52     return op->print(out << " //----- //\n",
53                      op->getBlock() ? flags.useLocalScope() : flags);
54 
55   // Otherwise, we are printing at module scope.
56   out << " ('" << op->getName() << "' operation";
57   if (auto symbolName =
58           op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()))
59     out << ": @" << symbolName.getValue();
60   out << ") //----- //\n";
61 
62   // Find the top-level operation.
63   auto *topLevelOp = op;
64   while (auto *parentOp = topLevelOp->getParentOp())
65     topLevelOp = parentOp;
66   topLevelOp->print(out, flags);
67 }
68 
69 /// Instrumentation hooks.
70 void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) {
71   if (isa<OpToOpPassAdaptor>(pass))
72     return;
73   // If the config asked to detect changes, record the current fingerprint.
74   if (config->shouldPrintAfterOnlyOnChange())
75     beforePassFingerPrints.try_emplace(pass, op);
76 
77   config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) {
78     out << "// -----// IR Dump Before " << pass->getName() << " ("
79         << pass->getArgument() << ")";
80     printIR(op, config->shouldPrintAtModuleScope(), out,
81             config->getOpPrintingFlags());
82     out << "\n\n";
83   });
84 }
85 
86 void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
87   if (isa<OpToOpPassAdaptor>(pass))
88     return;
89 
90   // Check to see if we are only printing on failure.
91   if (config->shouldPrintAfterOnlyOnFailure())
92     return;
93 
94   // If the config asked to detect changes, compare the current fingerprint with
95   // the previous.
96   if (config->shouldPrintAfterOnlyOnChange()) {
97     auto fingerPrintIt = beforePassFingerPrints.find(pass);
98     assert(fingerPrintIt != beforePassFingerPrints.end() &&
99            "expected valid fingerprint");
100     // If the fingerprints are the same, we don't print the IR.
101     if (fingerPrintIt->second == OperationFingerPrint(op)) {
102       beforePassFingerPrints.erase(fingerPrintIt);
103       return;
104     }
105     beforePassFingerPrints.erase(fingerPrintIt);
106   }
107 
108   config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
109     out << "// -----// IR Dump After " << pass->getName() << " ("
110         << pass->getArgument() << ")";
111     printIR(op, config->shouldPrintAtModuleScope(), out,
112             config->getOpPrintingFlags());
113     out << "\n\n";
114   });
115 }
116 
117 void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
118   if (isa<OpToOpPassAdaptor>(pass))
119     return;
120   if (config->shouldPrintAfterOnlyOnChange())
121     beforePassFingerPrints.erase(pass);
122 
123   config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
124     out << formatv("// -----// IR Dump After {0} Failed ({1})", pass->getName(),
125                    pass->getArgument());
126     printIR(op, config->shouldPrintAtModuleScope(), out,
127             config->getOpPrintingFlags());
128     out << "\n\n";
129   });
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // IRPrinterConfig
134 //===----------------------------------------------------------------------===//
135 
136 /// Initialize the configuration.
137 PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope,
138                                               bool printAfterOnlyOnChange,
139                                               bool printAfterOnlyOnFailure,
140                                               OpPrintingFlags opPrintingFlags)
141     : printModuleScope(printModuleScope),
142       printAfterOnlyOnChange(printAfterOnlyOnChange),
143       printAfterOnlyOnFailure(printAfterOnlyOnFailure),
144       opPrintingFlags(opPrintingFlags) {}
145 PassManager::IRPrinterConfig::~IRPrinterConfig() = default;
146 
147 /// A hook that may be overridden by a derived config that checks if the IR
148 /// of 'operation' should be dumped *before* the pass 'pass' has been
149 /// executed. If the IR should be dumped, 'printCallback' should be invoked
150 /// with the stream to dump into.
151 void PassManager::IRPrinterConfig::printBeforeIfEnabled(
152     Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
153   // By default, never print.
154 }
155 
156 /// A hook that may be overridden by a derived config that checks if the IR
157 /// of 'operation' should be dumped *after* the pass 'pass' has been
158 /// executed. If the IR should be dumped, 'printCallback' should be invoked
159 /// with the stream to dump into.
160 void PassManager::IRPrinterConfig::printAfterIfEnabled(
161     Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
162   // By default, never print.
163 }
164 
165 //===----------------------------------------------------------------------===//
166 // PassManager
167 //===----------------------------------------------------------------------===//
168 
169 namespace {
170 /// Simple wrapper config that allows for the simpler interface defined above.
171 struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
172   BasicIRPrinterConfig(
173       std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
174       std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
175       bool printModuleScope, bool printAfterOnlyOnChange,
176       bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags,
177       raw_ostream &out)
178       : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange,
179                         printAfterOnlyOnFailure, opPrintingFlags),
180         shouldPrintBeforePass(std::move(shouldPrintBeforePass)),
181         shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) {
182     assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) &&
183            "expected at least one valid filter function");
184   }
185 
186   void printBeforeIfEnabled(Pass *pass, Operation *operation,
187                             PrintCallbackFn printCallback) final {
188     if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation))
189       printCallback(out);
190   }
191 
192   void printAfterIfEnabled(Pass *pass, Operation *operation,
193                            PrintCallbackFn printCallback) final {
194     if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation))
195       printCallback(out);
196   }
197 
198   /// Filter functions for before and after pass execution.
199   std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
200   std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;
201 
202   /// The stream to output to.
203   raw_ostream &out;
204 };
205 } // namespace
206 
207 /// Return pairs of (sanitized op name, symbol name) for `op` and all parent
208 /// operations. Op names are sanitized by replacing periods with underscores.
209 /// The pairs are returned in order of outer-most to inner-most (ancestors of
210 /// `op` first, `op` last). This information is used to construct the directory
211 /// tree for the `FileTreeIRPrinterConfig` below.
212 /// The counter for `op` will be incremented by this call.
213 static std::pair<SmallVector<std::pair<std::string, StringRef>>, std::string>
214 getOpAndSymbolNames(Operation *op, StringRef passName,
215                     llvm::DenseMap<Operation *, unsigned> &counters) {
216   SmallVector<std::pair<std::string, StringRef>> pathElements;
217   SmallVector<unsigned> countPrefix;
218 
219   Operation *iter = op;
220   ++counters.try_emplace(op, -1).first->second;
221   while (iter) {
222     countPrefix.push_back(counters[iter]);
223     StringAttr symbolName =
224         iter->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
225     std::string opName =
226         llvm::join(llvm::split(iter->getName().getStringRef().str(), '.'), "_");
227     pathElements.emplace_back(opName, symbolName ? symbolName.strref()
228                                                  : "no-symbol-name");
229     iter = iter->getParentOp();
230   }
231   // Return in the order of top level (module) down to `op`.
232   std::reverse(countPrefix.begin(), countPrefix.end());
233   std::reverse(pathElements.begin(), pathElements.end());
234 
235   std::string passFileName = llvm::formatv(
236       "{0:$[_]}_{1}.mlir",
237       llvm::make_range(countPrefix.begin(), countPrefix.end()), passName);
238 
239   return {pathElements, passFileName};
240 }
241 
242 static LogicalResult createDirectoryOrPrintErr(llvm::StringRef dirPath) {
243   if (std::error_code ec =
244           llvm::sys::fs::create_directory(dirPath, /*IgnoreExisting=*/true)) {
245     llvm::errs() << "Error while creating directory " << dirPath << ": "
246                  << ec.message() << "\n";
247     return failure();
248   }
249   return success();
250 }
251 
252 /// Creates  directories (if required) and opens an output file for the
253 /// FileTreeIRPrinterConfig.
254 static std::unique_ptr<llvm::ToolOutputFile>
255 createTreePrinterOutputPath(Operation *op, llvm::StringRef passArgument,
256                             llvm::StringRef rootDir,
257                             llvm::DenseMap<Operation *, unsigned> &counters) {
258   // Create the path. We will create a tree rooted at the given 'rootDir'
259   // directory. The root directory will contain folders with the names of
260   // modules. Sub-directories within those folders mirror the nesting
261   // structure of the pass manager, using symbol names for directory names.
262   auto [opAndSymbolNames, fileName] =
263       getOpAndSymbolNames(op, passArgument, counters);
264 
265   // Create all the directories, starting at the root. Abort early if we fail to
266   // create any directory.
267   llvm::SmallString<128> path(rootDir);
268   if (failed(createDirectoryOrPrintErr(path)))
269     return nullptr;
270 
271   for (const auto &[opName, symbolName] : opAndSymbolNames) {
272     llvm::sys::path::append(path, opName + "_" + symbolName);
273     if (failed(createDirectoryOrPrintErr(path)))
274       return nullptr;
275   }
276 
277   // Open output file.
278   llvm::sys::path::append(path, fileName);
279   std::string error;
280   std::unique_ptr<llvm::ToolOutputFile> file = openOutputFile(path, &error);
281   if (!file) {
282     llvm::errs() << "Error opening output file " << path << ": " << error
283                  << "\n";
284     return nullptr;
285   }
286   return file;
287 }
288 
289 namespace {
290 /// A configuration that prints the IR before/after each pass to a set of files
291 /// in the specified directory. The files are organized into subdirectories that
292 /// mirror the nesting structure of the IR.
293 struct FileTreeIRPrinterConfig : public PassManager::IRPrinterConfig {
294   FileTreeIRPrinterConfig(
295       std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
296       std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
297       bool printModuleScope, bool printAfterOnlyOnChange,
298       bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags,
299       llvm::StringRef treeDir)
300       : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange,
301                         printAfterOnlyOnFailure, opPrintingFlags),
302         shouldPrintBeforePass(std::move(shouldPrintBeforePass)),
303         shouldPrintAfterPass(std::move(shouldPrintAfterPass)),
304         treeDir(treeDir) {
305     assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) &&
306            "expected at least one valid filter function");
307   }
308 
309   void printBeforeIfEnabled(Pass *pass, Operation *operation,
310                             PrintCallbackFn printCallback) final {
311     if (!shouldPrintBeforePass || !shouldPrintBeforePass(pass, operation))
312       return;
313     std::unique_ptr<llvm::ToolOutputFile> file = createTreePrinterOutputPath(
314         operation, pass->getArgument(), treeDir, counters);
315     if (!file)
316       return;
317     printCallback(file->os());
318     file->keep();
319   }
320 
321   void printAfterIfEnabled(Pass *pass, Operation *operation,
322                            PrintCallbackFn printCallback) final {
323     if (!shouldPrintAfterPass || !shouldPrintAfterPass(pass, operation))
324       return;
325     std::unique_ptr<llvm::ToolOutputFile> file = createTreePrinterOutputPath(
326         operation, pass->getArgument(), treeDir, counters);
327     if (!file)
328       return;
329     printCallback(file->os());
330     file->keep();
331   }
332 
333   /// Filter functions for before and after pass execution.
334   std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
335   std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;
336 
337   /// Directory that should be used as the root of the file tree.
338   std::string treeDir;
339 
340   /// Counters used for labeling the prefix. Every op which could be targeted by
341   /// a pass gets its own counter.
342   llvm::DenseMap<Operation *, unsigned> counters;
343 };
344 
345 } // namespace
346 
347 /// Add an instrumentation to print the IR before and after pass execution,
348 /// using the provided configuration.
349 void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) {
350   if (config->shouldPrintAtModuleScope() &&
351       getContext()->isMultithreadingEnabled())
352     llvm::report_fatal_error("IR printing can't be setup on a pass-manager "
353                              "without disabling multi-threading first.");
354   addInstrumentation(
355       std::make_unique<IRPrinterInstrumentation>(std::move(config)));
356 }
357 
358 /// Add an instrumentation to print the IR before and after pass execution.
359 void PassManager::enableIRPrinting(
360     std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
361     std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
362     bool printModuleScope, bool printAfterOnlyOnChange,
363     bool printAfterOnlyOnFailure, raw_ostream &out,
364     OpPrintingFlags opPrintingFlags) {
365   enableIRPrinting(std::make_unique<BasicIRPrinterConfig>(
366       std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass),
367       printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure,
368       opPrintingFlags, out));
369 }
370 
371 /// Add an instrumentation to print the IR before and after pass execution.
372 void PassManager::enableIRPrintingToFileTree(
373     std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
374     std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
375     bool printModuleScope, bool printAfterOnlyOnChange,
376     bool printAfterOnlyOnFailure, StringRef printTreeDir,
377     OpPrintingFlags opPrintingFlags) {
378   enableIRPrinting(std::make_unique<FileTreeIRPrinterConfig>(
379       std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass),
380       printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure,
381       opPrintingFlags, printTreeDir));
382 }
383