1 //===- MlirOptMain.cpp - MLIR Optimizer Driver ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This is a utility that runs an optimization pass and prints the result back 10 // out. It is designed to support unit testing. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Tools/mlir-opt/MlirOptMain.h" 15 #include "mlir/Bytecode/BytecodeWriter.h" 16 #include "mlir/Debug/CLOptionsSetup.h" 17 #include "mlir/Debug/Counter.h" 18 #include "mlir/Debug/DebuggerExecutionContextHook.h" 19 #include "mlir/Debug/ExecutionContext.h" 20 #include "mlir/Debug/Observers/ActionLogging.h" 21 #include "mlir/Dialect/IRDL/IR/IRDL.h" 22 #include "mlir/Dialect/IRDL/IRDLLoading.h" 23 #include "mlir/IR/AsmState.h" 24 #include "mlir/IR/Attributes.h" 25 #include "mlir/IR/BuiltinOps.h" 26 #include "mlir/IR/Diagnostics.h" 27 #include "mlir/IR/Dialect.h" 28 #include "mlir/IR/Location.h" 29 #include "mlir/IR/MLIRContext.h" 30 #include "mlir/Parser/Parser.h" 31 #include "mlir/Pass/Pass.h" 32 #include "mlir/Pass/PassManager.h" 33 #include "mlir/Pass/PassRegistry.h" 34 #include "mlir/Support/FileUtilities.h" 35 #include "mlir/Support/Timing.h" 36 #include "mlir/Support/ToolUtilities.h" 37 #include "mlir/Tools/ParseUtilities.h" 38 #include "mlir/Tools/Plugins/DialectPlugin.h" 39 #include "mlir/Tools/Plugins/PassPlugin.h" 40 #include "llvm/ADT/StringRef.h" 41 #include "llvm/Support/CommandLine.h" 42 #include "llvm/Support/FileUtilities.h" 43 #include "llvm/Support/InitLLVM.h" 44 #include "llvm/Support/LogicalResult.h" 45 #include "llvm/Support/ManagedStatic.h" 46 #include "llvm/Support/Process.h" 47 #include "llvm/Support/Regex.h" 48 #include "llvm/Support/SourceMgr.h" 49 #include "llvm/Support/StringSaver.h" 50 #include "llvm/Support/ThreadPool.h" 51 #include "llvm/Support/ToolOutputFile.h" 52 53 using namespace mlir; 54 using namespace llvm; 55 56 namespace { 57 class BytecodeVersionParser : public cl::parser<std::optional<int64_t>> { 58 public: 59 BytecodeVersionParser(cl::Option &o) 60 : cl::parser<std::optional<int64_t>>(o) {} 61 62 bool parse(cl::Option &o, StringRef /*argName*/, StringRef arg, 63 std::optional<int64_t> &v) { 64 long long w; 65 if (getAsSignedInteger(arg, 10, w)) 66 return o.error("Invalid argument '" + arg + 67 "', only integer is supported."); 68 v = w; 69 return false; 70 } 71 }; 72 73 /// This class is intended to manage the handling of command line options for 74 /// creating a *-opt config. This is a singleton. 75 struct MlirOptMainConfigCLOptions : public MlirOptMainConfig { 76 MlirOptMainConfigCLOptions() { 77 // These options are static but all uses ExternalStorage to initialize the 78 // members of the parent class. This is unusual but since this class is a 79 // singleton it basically attaches command line option to the singleton 80 // members. 81 82 static cl::opt<bool, /*ExternalStorage=*/true> allowUnregisteredDialects( 83 "allow-unregistered-dialect", 84 cl::desc("Allow operation with no registered dialects"), 85 cl::location(allowUnregisteredDialectsFlag), cl::init(false)); 86 87 static cl::opt<bool, /*ExternalStorage=*/true> dumpPassPipeline( 88 "dump-pass-pipeline", cl::desc("Print the pipeline that will be run"), 89 cl::location(dumpPassPipelineFlag), cl::init(false)); 90 91 static cl::opt<bool, /*ExternalStorage=*/true> emitBytecode( 92 "emit-bytecode", cl::desc("Emit bytecode when generating output"), 93 cl::location(emitBytecodeFlag), cl::init(false)); 94 95 static cl::opt<bool, /*ExternalStorage=*/true> elideResourcesFromBytecode( 96 "elide-resource-data-from-bytecode", 97 cl::desc("Elide resources when generating bytecode"), 98 cl::location(elideResourceDataFromBytecodeFlag), cl::init(false)); 99 100 static cl::opt<std::optional<int64_t>, /*ExternalStorage=*/true, 101 BytecodeVersionParser> 102 bytecodeVersion( 103 "emit-bytecode-version", 104 cl::desc("Use specified bytecode when generating output"), 105 cl::location(emitBytecodeVersion), cl::init(std::nullopt)); 106 107 static cl::opt<std::string, /*ExternalStorage=*/true> irdlFile( 108 "irdl-file", 109 cl::desc("IRDL file to register before processing the input"), 110 cl::location(irdlFileFlag), cl::init(""), cl::value_desc("filename")); 111 112 static cl::opt<VerbosityLevel, /*ExternalStorage=*/true> 113 diagnosticVerbosityLevel( 114 "mlir-diagnostic-verbosity-level", 115 cl::desc("Choose level of diagnostic information"), 116 cl::location(diagnosticVerbosityLevelFlag), 117 cl::init(VerbosityLevel::ErrorsWarningsAndRemarks), 118 cl::values( 119 clEnumValN(VerbosityLevel::ErrorsOnly, "errors", "Errors only"), 120 clEnumValN(VerbosityLevel::ErrorsAndWarnings, "warnings", 121 "Errors and warnings"), 122 clEnumValN(VerbosityLevel::ErrorsWarningsAndRemarks, "remarks", 123 "Errors, warnings and remarks"))); 124 125 static cl::opt<bool, /*ExternalStorage=*/true> disableDiagnosticNotes( 126 "mlir-disable-diagnostic-notes", cl::desc("Disable diagnostic notes."), 127 cl::location(disableDiagnosticNotesFlag), cl::init(false)); 128 129 static cl::opt<bool, /*ExternalStorage=*/true> enableDebuggerHook( 130 "mlir-enable-debugger-hook", 131 cl::desc("Enable Debugger hook for debugging MLIR Actions"), 132 cl::location(enableDebuggerActionHookFlag), cl::init(false)); 133 134 static cl::opt<bool, /*ExternalStorage=*/true> explicitModule( 135 "no-implicit-module", 136 cl::desc("Disable implicit addition of a top-level module op during " 137 "parsing"), 138 cl::location(useExplicitModuleFlag), cl::init(false)); 139 140 static cl::opt<bool, /*ExternalStorage=*/true> listPasses( 141 "list-passes", cl::desc("Print the list of registered passes and exit"), 142 cl::location(listPassesFlag), cl::init(false)); 143 144 static cl::opt<bool, /*ExternalStorage=*/true> runReproducer( 145 "run-reproducer", cl::desc("Run the pipeline stored in the reproducer"), 146 cl::location(runReproducerFlag), cl::init(false)); 147 148 static cl::opt<bool, /*ExternalStorage=*/true> showDialects( 149 "show-dialects", 150 cl::desc("Print the list of registered dialects and exit"), 151 cl::location(showDialectsFlag), cl::init(false)); 152 153 static cl::opt<std::string, /*ExternalStorage=*/true> splitInputFile{ 154 "split-input-file", 155 llvm::cl::ValueOptional, 156 cl::callback([&](const std::string &str) { 157 // Implicit value: use default marker if flag was used without value. 158 if (str.empty()) 159 splitInputFile.setValue(kDefaultSplitMarker); 160 }), 161 cl::desc("Split the input file into chunks using the given or " 162 "default marker and process each chunk independently"), 163 cl::location(splitInputFileFlag), 164 cl::init("")}; 165 166 static cl::opt<std::string, /*ExternalStorage=*/true> outputSplitMarker( 167 "output-split-marker", 168 cl::desc("Split marker to use for merging the ouput"), 169 cl::location(outputSplitMarkerFlag), cl::init(kDefaultSplitMarker)); 170 171 static cl::opt<bool, /*ExternalStorage=*/true> verifyDiagnostics( 172 "verify-diagnostics", 173 cl::desc("Check that emitted diagnostics match " 174 "expected-* lines on the corresponding line"), 175 cl::location(verifyDiagnosticsFlag), cl::init(false)); 176 177 static cl::opt<bool, /*ExternalStorage=*/true> verifyPasses( 178 "verify-each", 179 cl::desc("Run the verifier after each transformation pass"), 180 cl::location(verifyPassesFlag), cl::init(true)); 181 182 static cl::opt<bool, /*ExternalStorage=*/true> disableVerifyOnParsing( 183 "mlir-very-unsafe-disable-verifier-on-parsing", 184 cl::desc("Disable the verifier on parsing (very unsafe)"), 185 cl::location(disableVerifierOnParsingFlag), cl::init(false)); 186 187 static cl::opt<bool, /*ExternalStorage=*/true> verifyRoundtrip( 188 "verify-roundtrip", 189 cl::desc("Round-trip the IR after parsing and ensure it succeeds"), 190 cl::location(verifyRoundtripFlag), cl::init(false)); 191 192 static cl::list<std::string> passPlugins( 193 "load-pass-plugin", cl::desc("Load passes from plugin library")); 194 195 static cl::opt<std::string, /*ExternalStorage=*/true> 196 generateReproducerFile( 197 "mlir-generate-reproducer", 198 llvm::cl::desc( 199 "Generate an mlir reproducer at the provided filename" 200 " (no crash required)"), 201 cl::location(generateReproducerFileFlag), cl::init(""), 202 cl::value_desc("filename")); 203 204 /// Set the callback to load a pass plugin. 205 passPlugins.setCallback([&](const std::string &pluginPath) { 206 auto plugin = PassPlugin::load(pluginPath); 207 if (!plugin) { 208 errs() << "Failed to load passes from '" << pluginPath 209 << "'. Request ignored.\n"; 210 return; 211 } 212 plugin.get().registerPassRegistryCallbacks(); 213 }); 214 215 static cl::list<std::string> dialectPlugins( 216 "load-dialect-plugin", cl::desc("Load dialects from plugin library")); 217 this->dialectPlugins = std::addressof(dialectPlugins); 218 219 static PassPipelineCLParser passPipeline("", "Compiler passes to run", "p"); 220 setPassPipelineParser(passPipeline); 221 } 222 223 /// Set the callback to load a dialect plugin. 224 void setDialectPluginsCallback(DialectRegistry ®istry); 225 226 /// Pointer to static dialectPlugins variable in constructor, needed by 227 /// setDialectPluginsCallback(DialectRegistry&). 228 cl::list<std::string> *dialectPlugins = nullptr; 229 }; 230 231 /// A scoped diagnostic handler that suppresses certain diagnostics based on 232 /// the verbosity level and whether the diagnostic is a note. 233 class DiagnosticFilter : public ScopedDiagnosticHandler { 234 public: 235 DiagnosticFilter(MLIRContext *ctx, VerbosityLevel verbosityLevel, 236 bool showNotes = true) 237 : ScopedDiagnosticHandler(ctx) { 238 setHandler([verbosityLevel, showNotes](Diagnostic &diag) { 239 auto severity = diag.getSeverity(); 240 switch (severity) { 241 case DiagnosticSeverity::Error: 242 // failure indicates that the error is not handled by the filter and 243 // goes through to the default handler. Therefore, the error can be 244 // successfully printed. 245 return failure(); 246 case DiagnosticSeverity::Warning: 247 if (verbosityLevel == VerbosityLevel::ErrorsOnly) 248 return success(); 249 else 250 return failure(); 251 case DiagnosticSeverity::Remark: 252 if (verbosityLevel == VerbosityLevel::ErrorsOnly || 253 verbosityLevel == VerbosityLevel::ErrorsAndWarnings) 254 return success(); 255 else 256 return failure(); 257 case DiagnosticSeverity::Note: 258 if (showNotes) 259 return failure(); 260 else 261 return success(); 262 } 263 llvm_unreachable("Unknown diagnostic severity"); 264 }); 265 } 266 }; 267 } // namespace 268 269 ManagedStatic<MlirOptMainConfigCLOptions> clOptionsConfig; 270 271 void MlirOptMainConfig::registerCLOptions(DialectRegistry ®istry) { 272 clOptionsConfig->setDialectPluginsCallback(registry); 273 tracing::DebugConfig::registerCLOptions(); 274 } 275 276 MlirOptMainConfig MlirOptMainConfig::createFromCLOptions() { 277 clOptionsConfig->setDebugConfig(tracing::DebugConfig::createFromCLOptions()); 278 return *clOptionsConfig; 279 } 280 281 MlirOptMainConfig &MlirOptMainConfig::setPassPipelineParser( 282 const PassPipelineCLParser &passPipeline) { 283 passPipelineCallback = [&](PassManager &pm) { 284 auto errorHandler = [&](const Twine &msg) { 285 emitError(UnknownLoc::get(pm.getContext())) << msg; 286 return failure(); 287 }; 288 if (failed(passPipeline.addToPipeline(pm, errorHandler))) 289 return failure(); 290 if (this->shouldDumpPassPipeline()) { 291 292 pm.dump(); 293 llvm::errs() << "\n"; 294 } 295 return success(); 296 }; 297 return *this; 298 } 299 300 void MlirOptMainConfigCLOptions::setDialectPluginsCallback( 301 DialectRegistry ®istry) { 302 dialectPlugins->setCallback([&](const std::string &pluginPath) { 303 auto plugin = DialectPlugin::load(pluginPath); 304 if (!plugin) { 305 errs() << "Failed to load dialect plugin from '" << pluginPath 306 << "'. Request ignored.\n"; 307 return; 308 }; 309 plugin.get().registerDialectRegistryCallbacks(registry); 310 }); 311 } 312 313 LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) { 314 DialectRegistry registry; 315 registry.insert<irdl::IRDLDialect>(); 316 ctx.appendDialectRegistry(registry); 317 318 // Set up the input file. 319 std::string errorMessage; 320 std::unique_ptr<MemoryBuffer> file = openInputFile(irdlFile, &errorMessage); 321 if (!file) { 322 emitError(UnknownLoc::get(&ctx)) << errorMessage; 323 return failure(); 324 } 325 326 // Give the buffer to the source manager. 327 // This will be picked up by the parser. 328 SourceMgr sourceMgr; 329 sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); 330 331 SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx); 332 333 // Parse the input file. 334 OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, &ctx)); 335 if (!module) 336 return failure(); 337 338 // Load IRDL dialects. 339 return irdl::loadDialects(module.get()); 340 } 341 342 // Return success if the module can correctly round-trip. This intended to test 343 // that the custom printers/parsers are complete. 344 static LogicalResult doVerifyRoundTrip(Operation *op, 345 const MlirOptMainConfig &config, 346 bool useBytecode) { 347 // We use a new context to avoid resource handle renaming issue in the diff. 348 MLIRContext roundtripContext; 349 OwningOpRef<Operation *> roundtripModule; 350 roundtripContext.appendDialectRegistry( 351 op->getContext()->getDialectRegistry()); 352 if (op->getContext()->allowsUnregisteredDialects()) 353 roundtripContext.allowUnregisteredDialects(); 354 StringRef irdlFile = config.getIrdlFile(); 355 if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, roundtripContext))) 356 return failure(); 357 358 std::string testType = (useBytecode) ? "bytecode" : "textual"; 359 // Print a first time with custom format (or bytecode) and parse it back to 360 // the roundtripModule. 361 { 362 std::string buffer; 363 llvm::raw_string_ostream ostream(buffer); 364 if (useBytecode) { 365 if (failed(writeBytecodeToFile(op, ostream))) { 366 op->emitOpError() 367 << "failed to write bytecode, cannot verify round-trip.\n"; 368 return failure(); 369 } 370 } else { 371 op->print(ostream, 372 OpPrintingFlags().printGenericOpForm().enableDebugInfo()); 373 } 374 FallbackAsmResourceMap fallbackResourceMap; 375 ParserConfig parseConfig(&roundtripContext, config.shouldVerifyOnParsing(), 376 &fallbackResourceMap); 377 roundtripModule = parseSourceString<Operation *>(buffer, parseConfig); 378 if (!roundtripModule) { 379 op->emitOpError() << "failed to parse " << testType 380 << " content back, cannot verify round-trip.\n"; 381 return failure(); 382 } 383 } 384 385 // Print in the generic form for the reference module and the round-tripped 386 // one and compare the outputs. 387 std::string reference, roundtrip; 388 { 389 llvm::raw_string_ostream ostreamref(reference); 390 op->print(ostreamref, 391 OpPrintingFlags().printGenericOpForm().enableDebugInfo()); 392 llvm::raw_string_ostream ostreamrndtrip(roundtrip); 393 roundtripModule.get()->print( 394 ostreamrndtrip, 395 OpPrintingFlags().printGenericOpForm().enableDebugInfo()); 396 } 397 if (reference != roundtrip) { 398 // TODO implement a diff. 399 return op->emitOpError() 400 << testType 401 << " roundTrip testing roundtripped module differs " 402 "from reference:\n<<<<<<Reference\n" 403 << reference << "\n=====\n" 404 << roundtrip << "\n>>>>>roundtripped\n"; 405 } 406 407 return success(); 408 } 409 410 static LogicalResult doVerifyRoundTrip(Operation *op, 411 const MlirOptMainConfig &config) { 412 auto txtStatus = doVerifyRoundTrip(op, config, /*useBytecode=*/false); 413 auto bcStatus = doVerifyRoundTrip(op, config, /*useBytecode=*/true); 414 return success(succeeded(txtStatus) && succeeded(bcStatus)); 415 } 416 417 /// Perform the actions on the input file indicated by the command line flags 418 /// within the specified context. 419 /// 420 /// This typically parses the main source file, runs zero or more optimization 421 /// passes, then prints the output. 422 /// 423 static LogicalResult 424 performActions(raw_ostream &os, 425 const std::shared_ptr<llvm::SourceMgr> &sourceMgr, 426 MLIRContext *context, const MlirOptMainConfig &config) { 427 DefaultTimingManager tm; 428 applyDefaultTimingManagerCLOptions(tm); 429 TimingScope timing = tm.getRootScope(); 430 431 // Disable multi-threading when parsing the input file. This removes the 432 // unnecessary/costly context synchronization when parsing. 433 bool wasThreadingEnabled = context->isMultithreadingEnabled(); 434 context->disableMultithreading(); 435 436 // Prepare the parser config, and attach any useful/necessary resource 437 // handlers. Unhandled external resources are treated as passthrough, i.e. 438 // they are not processed and will be emitted directly to the output 439 // untouched. 440 PassReproducerOptions reproOptions; 441 FallbackAsmResourceMap fallbackResourceMap; 442 ParserConfig parseConfig(context, config.shouldVerifyOnParsing(), 443 &fallbackResourceMap); 444 if (config.shouldRunReproducer()) 445 reproOptions.attachResourceParser(parseConfig); 446 447 // Parse the input file and reset the context threading state. 448 TimingScope parserTiming = timing.nest("Parser"); 449 OwningOpRef<Operation *> op = parseSourceFileForTool( 450 sourceMgr, parseConfig, !config.shouldUseExplicitModule()); 451 parserTiming.stop(); 452 if (!op) 453 return failure(); 454 455 // Perform round-trip verification if requested 456 if (config.shouldVerifyRoundtrip() && 457 failed(doVerifyRoundTrip(op.get(), config))) 458 return failure(); 459 460 context->enableMultithreading(wasThreadingEnabled); 461 462 // Prepare the pass manager, applying command-line and reproducer options. 463 PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit); 464 pm.enableVerifier(config.shouldVerifyPasses()); 465 if (failed(applyPassManagerCLOptions(pm))) 466 return failure(); 467 pm.enableTiming(timing); 468 if (config.shouldRunReproducer() && failed(reproOptions.apply(pm))) 469 return failure(); 470 if (failed(config.setupPassPipeline(pm))) 471 return failure(); 472 473 // Run the pipeline. 474 if (failed(pm.run(*op))) 475 return failure(); 476 477 // Generate reproducers if requested 478 if (!config.getReproducerFilename().empty()) { 479 StringRef anchorName = pm.getAnyOpAnchorName(); 480 const auto &passes = pm.getPasses(); 481 makeReproducer(anchorName, passes, op.get(), 482 config.getReproducerFilename()); 483 } 484 485 // Print the output. 486 TimingScope outputTiming = timing.nest("Output"); 487 if (config.shouldEmitBytecode()) { 488 BytecodeWriterConfig writerConfig(fallbackResourceMap); 489 if (auto v = config.bytecodeVersionToEmit()) 490 writerConfig.setDesiredBytecodeVersion(*v); 491 if (config.shouldElideResourceDataFromBytecode()) 492 writerConfig.setElideResourceDataFlag(); 493 return writeBytecodeToFile(op.get(), os, writerConfig); 494 } 495 496 if (config.bytecodeVersionToEmit().has_value()) 497 return emitError(UnknownLoc::get(pm.getContext())) 498 << "bytecode version while not emitting bytecode"; 499 AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr, 500 &fallbackResourceMap); 501 op.get()->print(os, asmState); 502 os << '\n'; 503 return success(); 504 } 505 506 /// Parses the memory buffer. If successfully, run a series of passes against 507 /// it and print the result. 508 static LogicalResult processBuffer(raw_ostream &os, 509 std::unique_ptr<MemoryBuffer> ownedBuffer, 510 const MlirOptMainConfig &config, 511 DialectRegistry ®istry, 512 llvm::ThreadPoolInterface *threadPool) { 513 // Tell sourceMgr about this buffer, which is what the parser will pick up. 514 auto sourceMgr = std::make_shared<SourceMgr>(); 515 sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); 516 517 // Create a context just for the current buffer. Disable threading on creation 518 // since we'll inject the thread-pool separately. 519 MLIRContext context(registry, MLIRContext::Threading::DISABLED); 520 if (threadPool) 521 context.setThreadPool(*threadPool); 522 523 StringRef irdlFile = config.getIrdlFile(); 524 if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, context))) 525 return failure(); 526 527 // Parse the input file. 528 context.allowUnregisteredDialects(config.shouldAllowUnregisteredDialects()); 529 if (config.shouldVerifyDiagnostics()) 530 context.printOpOnDiagnostic(false); 531 532 tracing::InstallDebugHandler installDebugHandler(context, 533 config.getDebugConfig()); 534 535 // If we are in verify diagnostics mode then we have a lot of work to do, 536 // otherwise just perform the actions without worrying about it. 537 if (!config.shouldVerifyDiagnostics()) { 538 SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context); 539 DiagnosticFilter diagnosticFilter(&context, 540 config.getDiagnosticVerbosityLevel(), 541 config.shouldShowNotes()); 542 return performActions(os, sourceMgr, &context, config); 543 } 544 545 SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context); 546 547 // Do any processing requested by command line flags. We don't care whether 548 // these actions succeed or fail, we only care what diagnostics they produce 549 // and whether they match our expectations. 550 (void)performActions(os, sourceMgr, &context, config); 551 552 // Verify the diagnostic handler to make sure that each of the diagnostics 553 // matched. 554 return sourceMgrHandler.verify(); 555 } 556 557 std::pair<std::string, std::string> 558 mlir::registerAndParseCLIOptions(int argc, char **argv, 559 llvm::StringRef toolName, 560 DialectRegistry ®istry) { 561 static cl::opt<std::string> inputFilename( 562 cl::Positional, cl::desc("<input file>"), cl::init("-")); 563 564 static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"), 565 cl::value_desc("filename"), 566 cl::init("-")); 567 // Register any command line options. 568 MlirOptMainConfig::registerCLOptions(registry); 569 registerAsmPrinterCLOptions(); 570 registerMLIRContextCLOptions(); 571 registerPassManagerCLOptions(); 572 registerDefaultTimingManagerCLOptions(); 573 tracing::DebugCounter::registerCLOptions(); 574 575 // Build the list of dialects as a header for the --help message. 576 std::string helpHeader = (toolName + "\nAvailable Dialects: ").str(); 577 { 578 llvm::raw_string_ostream os(helpHeader); 579 interleaveComma(registry.getDialectNames(), os, 580 [&](auto name) { os << name; }); 581 } 582 // Parse pass names in main to ensure static initialization completed. 583 cl::ParseCommandLineOptions(argc, argv, helpHeader); 584 return std::make_pair(inputFilename.getValue(), outputFilename.getValue()); 585 } 586 587 static LogicalResult printRegisteredDialects(DialectRegistry ®istry) { 588 llvm::outs() << "Available Dialects: "; 589 interleave(registry.getDialectNames(), llvm::outs(), ","); 590 llvm::outs() << "\n"; 591 return success(); 592 } 593 594 static LogicalResult printRegisteredPassesAndReturn() { 595 mlir::printRegisteredPasses(); 596 return success(); 597 } 598 599 LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream, 600 std::unique_ptr<llvm::MemoryBuffer> buffer, 601 DialectRegistry ®istry, 602 const MlirOptMainConfig &config) { 603 if (config.shouldShowDialects()) 604 return printRegisteredDialects(registry); 605 606 if (config.shouldListPasses()) 607 return printRegisteredPassesAndReturn(); 608 609 // The split-input-file mode is a very specific mode that slices the file 610 // up into small pieces and checks each independently. 611 // We use an explicit threadpool to avoid creating and joining/destroying 612 // threads for each of the split. 613 ThreadPoolInterface *threadPool = nullptr; 614 615 // Create a temporary context for the sake of checking if 616 // --mlir-disable-threading was passed on the command line. 617 // We use the thread-pool this context is creating, and avoid 618 // creating any thread when disabled. 619 MLIRContext threadPoolCtx; 620 if (threadPoolCtx.isMultithreadingEnabled()) 621 threadPool = &threadPoolCtx.getThreadPool(); 622 623 auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer, 624 raw_ostream &os) { 625 return processBuffer(os, std::move(chunkBuffer), config, registry, 626 threadPool); 627 }; 628 return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream, 629 config.inputSplitMarker(), 630 config.outputSplitMarker()); 631 } 632 633 LogicalResult mlir::MlirOptMain(int argc, char **argv, 634 llvm::StringRef inputFilename, 635 llvm::StringRef outputFilename, 636 DialectRegistry ®istry) { 637 638 InitLLVM y(argc, argv); 639 640 MlirOptMainConfig config = MlirOptMainConfig::createFromCLOptions(); 641 642 if (config.shouldShowDialects()) 643 return printRegisteredDialects(registry); 644 645 if (config.shouldListPasses()) 646 return printRegisteredPassesAndReturn(); 647 648 // When reading from stdin and the input is a tty, it is often a user mistake 649 // and the process "appears to be stuck". Print a message to let the user know 650 // about it! 651 if (inputFilename == "-" && 652 sys::Process::FileDescriptorIsDisplayed(fileno(stdin))) 653 llvm::errs() << "(processing input from stdin now, hit ctrl-c/ctrl-d to " 654 "interrupt)\n"; 655 656 // Set up the input file. 657 std::string errorMessage; 658 auto file = openInputFile(inputFilename, &errorMessage); 659 if (!file) { 660 llvm::errs() << errorMessage << "\n"; 661 return failure(); 662 } 663 664 auto output = openOutputFile(outputFilename, &errorMessage); 665 if (!output) { 666 llvm::errs() << errorMessage << "\n"; 667 return failure(); 668 } 669 if (failed(MlirOptMain(output->os(), std::move(file), registry, config))) 670 return failure(); 671 672 // Keep the output file if the invocation of MlirOptMain was successful. 673 output->keep(); 674 return success(); 675 } 676 677 LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName, 678 DialectRegistry ®istry) { 679 680 // Register and parse command line options. 681 std::string inputFilename, outputFilename; 682 std::tie(inputFilename, outputFilename) = 683 registerAndParseCLIOptions(argc, argv, toolName, registry); 684 685 return MlirOptMain(argc, argv, inputFilename, outputFilename, registry); 686 } 687