1 //===- IRPrinting.cpp -----------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/IR/SymbolTable.h" 11 #include "mlir/Pass/PassManager.h" 12 #include "mlir/Support/FileUtilities.h" 13 #include "llvm/ADT/StringExtras.h" 14 #include "llvm/Support/FileSystem.h" 15 #include "llvm/Support/FormatVariadic.h" 16 #include "llvm/Support/Path.h" 17 #include "llvm/Support/ToolOutputFile.h" 18 19 using namespace mlir; 20 using namespace mlir::detail; 21 22 namespace { 23 //===----------------------------------------------------------------------===// 24 // IRPrinter 25 //===----------------------------------------------------------------------===// 26 27 class IRPrinterInstrumentation : public PassInstrumentation { 28 public: 29 IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config) 30 : config(std::move(config)) {} 31 32 private: 33 /// Instrumentation hooks. 34 void runBeforePass(Pass *pass, Operation *op) override; 35 void runAfterPass(Pass *pass, Operation *op) override; 36 void runAfterPassFailed(Pass *pass, Operation *op) override; 37 38 /// Configuration to use. 39 std::unique_ptr<PassManager::IRPrinterConfig> config; 40 41 /// The following is a set of fingerprints for operations that are currently 42 /// being operated on in a pass. This field is only used when the 43 /// configuration asked for change detection. 44 DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints; 45 }; 46 } // namespace 47 48 static void printIR(Operation *op, bool printModuleScope, raw_ostream &out, 49 OpPrintingFlags flags) { 50 // Otherwise, check to see if we are not printing at module scope. 51 if (!printModuleScope) 52 return op->print(out << " //----- //\n", 53 op->getBlock() ? flags.useLocalScope() : flags); 54 55 // Otherwise, we are printing at module scope. 56 out << " ('" << op->getName() << "' operation"; 57 if (auto symbolName = 58 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())) 59 out << ": @" << symbolName.getValue(); 60 out << ") //----- //\n"; 61 62 // Find the top-level operation. 63 auto *topLevelOp = op; 64 while (auto *parentOp = topLevelOp->getParentOp()) 65 topLevelOp = parentOp; 66 topLevelOp->print(out, flags); 67 } 68 69 /// Instrumentation hooks. 70 void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) { 71 if (isa<OpToOpPassAdaptor>(pass)) 72 return; 73 // If the config asked to detect changes, record the current fingerprint. 74 if (config->shouldPrintAfterOnlyOnChange()) 75 beforePassFingerPrints.try_emplace(pass, op); 76 77 config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) { 78 out << "// -----// IR Dump Before " << pass->getName() << " (" 79 << pass->getArgument() << ")"; 80 printIR(op, config->shouldPrintAtModuleScope(), out, 81 config->getOpPrintingFlags()); 82 out << "\n\n"; 83 }); 84 } 85 86 void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { 87 if (isa<OpToOpPassAdaptor>(pass)) 88 return; 89 90 // Check to see if we are only printing on failure. 91 if (config->shouldPrintAfterOnlyOnFailure()) 92 return; 93 94 // If the config asked to detect changes, compare the current fingerprint with 95 // the previous. 96 if (config->shouldPrintAfterOnlyOnChange()) { 97 auto fingerPrintIt = beforePassFingerPrints.find(pass); 98 assert(fingerPrintIt != beforePassFingerPrints.end() && 99 "expected valid fingerprint"); 100 // If the fingerprints are the same, we don't print the IR. 101 if (fingerPrintIt->second == OperationFingerPrint(op)) { 102 beforePassFingerPrints.erase(fingerPrintIt); 103 return; 104 } 105 beforePassFingerPrints.erase(fingerPrintIt); 106 } 107 108 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { 109 out << "// -----// IR Dump After " << pass->getName() << " (" 110 << pass->getArgument() << ")"; 111 printIR(op, config->shouldPrintAtModuleScope(), out, 112 config->getOpPrintingFlags()); 113 out << "\n\n"; 114 }); 115 } 116 117 void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { 118 if (isa<OpToOpPassAdaptor>(pass)) 119 return; 120 if (config->shouldPrintAfterOnlyOnChange()) 121 beforePassFingerPrints.erase(pass); 122 123 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { 124 out << formatv("// -----// IR Dump After {0} Failed ({1})", pass->getName(), 125 pass->getArgument()); 126 printIR(op, config->shouldPrintAtModuleScope(), out, 127 config->getOpPrintingFlags()); 128 out << "\n\n"; 129 }); 130 } 131 132 //===----------------------------------------------------------------------===// 133 // IRPrinterConfig 134 //===----------------------------------------------------------------------===// 135 136 /// Initialize the configuration. 137 PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope, 138 bool printAfterOnlyOnChange, 139 bool printAfterOnlyOnFailure, 140 OpPrintingFlags opPrintingFlags) 141 : printModuleScope(printModuleScope), 142 printAfterOnlyOnChange(printAfterOnlyOnChange), 143 printAfterOnlyOnFailure(printAfterOnlyOnFailure), 144 opPrintingFlags(opPrintingFlags) {} 145 PassManager::IRPrinterConfig::~IRPrinterConfig() = default; 146 147 /// A hook that may be overridden by a derived config that checks if the IR 148 /// of 'operation' should be dumped *before* the pass 'pass' has been 149 /// executed. If the IR should be dumped, 'printCallback' should be invoked 150 /// with the stream to dump into. 151 void PassManager::IRPrinterConfig::printBeforeIfEnabled( 152 Pass *pass, Operation *operation, PrintCallbackFn printCallback) { 153 // By default, never print. 154 } 155 156 /// A hook that may be overridden by a derived config that checks if the IR 157 /// of 'operation' should be dumped *after* the pass 'pass' has been 158 /// executed. If the IR should be dumped, 'printCallback' should be invoked 159 /// with the stream to dump into. 160 void PassManager::IRPrinterConfig::printAfterIfEnabled( 161 Pass *pass, Operation *operation, PrintCallbackFn printCallback) { 162 // By default, never print. 163 } 164 165 //===----------------------------------------------------------------------===// 166 // PassManager 167 //===----------------------------------------------------------------------===// 168 169 namespace { 170 /// Simple wrapper config that allows for the simpler interface defined above. 171 struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig { 172 BasicIRPrinterConfig( 173 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, 174 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, 175 bool printModuleScope, bool printAfterOnlyOnChange, 176 bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags, 177 raw_ostream &out) 178 : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange, 179 printAfterOnlyOnFailure, opPrintingFlags), 180 shouldPrintBeforePass(std::move(shouldPrintBeforePass)), 181 shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) { 182 assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) && 183 "expected at least one valid filter function"); 184 } 185 186 void printBeforeIfEnabled(Pass *pass, Operation *operation, 187 PrintCallbackFn printCallback) final { 188 if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation)) 189 printCallback(out); 190 } 191 192 void printAfterIfEnabled(Pass *pass, Operation *operation, 193 PrintCallbackFn printCallback) final { 194 if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation)) 195 printCallback(out); 196 } 197 198 /// Filter functions for before and after pass execution. 199 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass; 200 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass; 201 202 /// The stream to output to. 203 raw_ostream &out; 204 }; 205 } // namespace 206 207 /// Return pairs of (sanitized op name, symbol name) for `op` and all parent 208 /// operations. Op names are sanitized by replacing periods with underscores. 209 /// The pairs are returned in order of outer-most to inner-most (ancestors of 210 /// `op` first, `op` last). This information is used to construct the directory 211 /// tree for the `FileTreeIRPrinterConfig` below. 212 /// The counter for `op` will be incremented by this call. 213 static std::pair<SmallVector<std::pair<std::string, StringRef>>, std::string> 214 getOpAndSymbolNames(Operation *op, StringRef passName, 215 llvm::DenseMap<Operation *, unsigned> &counters) { 216 SmallVector<std::pair<std::string, StringRef>> pathElements; 217 SmallVector<unsigned> countPrefix; 218 219 Operation *iter = op; 220 ++counters.try_emplace(op, -1).first->second; 221 while (iter) { 222 countPrefix.push_back(counters[iter]); 223 StringAttr symbolName = 224 iter->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); 225 std::string opName = 226 llvm::join(llvm::split(iter->getName().getStringRef().str(), '.'), "_"); 227 pathElements.emplace_back(opName, symbolName ? symbolName.strref() 228 : "no-symbol-name"); 229 iter = iter->getParentOp(); 230 } 231 // Return in the order of top level (module) down to `op`. 232 std::reverse(countPrefix.begin(), countPrefix.end()); 233 std::reverse(pathElements.begin(), pathElements.end()); 234 235 std::string passFileName = llvm::formatv( 236 "{0:$[_]}_{1}.mlir", 237 llvm::make_range(countPrefix.begin(), countPrefix.end()), passName); 238 239 return {pathElements, passFileName}; 240 } 241 242 static LogicalResult createDirectoryOrPrintErr(llvm::StringRef dirPath) { 243 if (std::error_code ec = 244 llvm::sys::fs::create_directory(dirPath, /*IgnoreExisting=*/true)) { 245 llvm::errs() << "Error while creating directory " << dirPath << ": " 246 << ec.message() << "\n"; 247 return failure(); 248 } 249 return success(); 250 } 251 252 /// Creates directories (if required) and opens an output file for the 253 /// FileTreeIRPrinterConfig. 254 static std::unique_ptr<llvm::ToolOutputFile> 255 createTreePrinterOutputPath(Operation *op, llvm::StringRef passArgument, 256 llvm::StringRef rootDir, 257 llvm::DenseMap<Operation *, unsigned> &counters) { 258 // Create the path. We will create a tree rooted at the given 'rootDir' 259 // directory. The root directory will contain folders with the names of 260 // modules. Sub-directories within those folders mirror the nesting 261 // structure of the pass manager, using symbol names for directory names. 262 auto [opAndSymbolNames, fileName] = 263 getOpAndSymbolNames(op, passArgument, counters); 264 265 // Create all the directories, starting at the root. Abort early if we fail to 266 // create any directory. 267 llvm::SmallString<128> path(rootDir); 268 if (failed(createDirectoryOrPrintErr(path))) 269 return nullptr; 270 271 for (const auto &[opName, symbolName] : opAndSymbolNames) { 272 llvm::sys::path::append(path, opName + "_" + symbolName); 273 if (failed(createDirectoryOrPrintErr(path))) 274 return nullptr; 275 } 276 277 // Open output file. 278 llvm::sys::path::append(path, fileName); 279 std::string error; 280 std::unique_ptr<llvm::ToolOutputFile> file = openOutputFile(path, &error); 281 if (!file) { 282 llvm::errs() << "Error opening output file " << path << ": " << error 283 << "\n"; 284 return nullptr; 285 } 286 return file; 287 } 288 289 namespace { 290 /// A configuration that prints the IR before/after each pass to a set of files 291 /// in the specified directory. The files are organized into subdirectories that 292 /// mirror the nesting structure of the IR. 293 struct FileTreeIRPrinterConfig : public PassManager::IRPrinterConfig { 294 FileTreeIRPrinterConfig( 295 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, 296 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, 297 bool printModuleScope, bool printAfterOnlyOnChange, 298 bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags, 299 llvm::StringRef treeDir) 300 : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange, 301 printAfterOnlyOnFailure, opPrintingFlags), 302 shouldPrintBeforePass(std::move(shouldPrintBeforePass)), 303 shouldPrintAfterPass(std::move(shouldPrintAfterPass)), 304 treeDir(treeDir) { 305 assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) && 306 "expected at least one valid filter function"); 307 } 308 309 void printBeforeIfEnabled(Pass *pass, Operation *operation, 310 PrintCallbackFn printCallback) final { 311 if (!shouldPrintBeforePass || !shouldPrintBeforePass(pass, operation)) 312 return; 313 std::unique_ptr<llvm::ToolOutputFile> file = createTreePrinterOutputPath( 314 operation, pass->getArgument(), treeDir, counters); 315 if (!file) 316 return; 317 printCallback(file->os()); 318 file->keep(); 319 } 320 321 void printAfterIfEnabled(Pass *pass, Operation *operation, 322 PrintCallbackFn printCallback) final { 323 if (!shouldPrintAfterPass || !shouldPrintAfterPass(pass, operation)) 324 return; 325 std::unique_ptr<llvm::ToolOutputFile> file = createTreePrinterOutputPath( 326 operation, pass->getArgument(), treeDir, counters); 327 if (!file) 328 return; 329 printCallback(file->os()); 330 file->keep(); 331 } 332 333 /// Filter functions for before and after pass execution. 334 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass; 335 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass; 336 337 /// Directory that should be used as the root of the file tree. 338 std::string treeDir; 339 340 /// Counters used for labeling the prefix. Every op which could be targeted by 341 /// a pass gets its own counter. 342 llvm::DenseMap<Operation *, unsigned> counters; 343 }; 344 345 } // namespace 346 347 /// Add an instrumentation to print the IR before and after pass execution, 348 /// using the provided configuration. 349 void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) { 350 if (config->shouldPrintAtModuleScope() && 351 getContext()->isMultithreadingEnabled()) 352 llvm::report_fatal_error("IR printing can't be setup on a pass-manager " 353 "without disabling multi-threading first."); 354 addInstrumentation( 355 std::make_unique<IRPrinterInstrumentation>(std::move(config))); 356 } 357 358 /// Add an instrumentation to print the IR before and after pass execution. 359 void PassManager::enableIRPrinting( 360 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, 361 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, 362 bool printModuleScope, bool printAfterOnlyOnChange, 363 bool printAfterOnlyOnFailure, raw_ostream &out, 364 OpPrintingFlags opPrintingFlags) { 365 enableIRPrinting(std::make_unique<BasicIRPrinterConfig>( 366 std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass), 367 printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, 368 opPrintingFlags, out)); 369 } 370 371 /// Add an instrumentation to print the IR before and after pass execution. 372 void PassManager::enableIRPrintingToFileTree( 373 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, 374 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, 375 bool printModuleScope, bool printAfterOnlyOnChange, 376 bool printAfterOnlyOnFailure, StringRef printTreeDir, 377 OpPrintingFlags opPrintingFlags) { 378 enableIRPrinting(std::make_unique<FileTreeIRPrinterConfig>( 379 std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass), 380 printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, 381 opPrintingFlags, printTreeDir)); 382 } 383