xref: /llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp (revision c8b15157d70c57489b9aba939065c01c3f697ddb)
1 //===- MlirOptMain.cpp - MLIR Optimizer Driver ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a utility that runs an optimization pass and prints the result back
10 // out. It is designed to support unit testing.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
15 #include "mlir/Bytecode/BytecodeWriter.h"
16 #include "mlir/Debug/CLOptionsSetup.h"
17 #include "mlir/Debug/Counter.h"
18 #include "mlir/Debug/DebuggerExecutionContextHook.h"
19 #include "mlir/Debug/ExecutionContext.h"
20 #include "mlir/Debug/Observers/ActionLogging.h"
21 #include "mlir/Dialect/IRDL/IR/IRDL.h"
22 #include "mlir/Dialect/IRDL/IRDLLoading.h"
23 #include "mlir/IR/AsmState.h"
24 #include "mlir/IR/Attributes.h"
25 #include "mlir/IR/BuiltinOps.h"
26 #include "mlir/IR/Diagnostics.h"
27 #include "mlir/IR/Dialect.h"
28 #include "mlir/IR/Location.h"
29 #include "mlir/IR/MLIRContext.h"
30 #include "mlir/Parser/Parser.h"
31 #include "mlir/Pass/Pass.h"
32 #include "mlir/Pass/PassManager.h"
33 #include "mlir/Pass/PassRegistry.h"
34 #include "mlir/Support/FileUtilities.h"
35 #include "mlir/Support/Timing.h"
36 #include "mlir/Support/ToolUtilities.h"
37 #include "mlir/Tools/ParseUtilities.h"
38 #include "mlir/Tools/Plugins/DialectPlugin.h"
39 #include "mlir/Tools/Plugins/PassPlugin.h"
40 #include "llvm/ADT/StringRef.h"
41 #include "llvm/Support/CommandLine.h"
42 #include "llvm/Support/FileUtilities.h"
43 #include "llvm/Support/InitLLVM.h"
44 #include "llvm/Support/LogicalResult.h"
45 #include "llvm/Support/ManagedStatic.h"
46 #include "llvm/Support/Process.h"
47 #include "llvm/Support/Regex.h"
48 #include "llvm/Support/SourceMgr.h"
49 #include "llvm/Support/StringSaver.h"
50 #include "llvm/Support/ThreadPool.h"
51 #include "llvm/Support/ToolOutputFile.h"
52 
53 using namespace mlir;
54 using namespace llvm;
55 
56 namespace {
57 class BytecodeVersionParser : public cl::parser<std::optional<int64_t>> {
58 public:
59   BytecodeVersionParser(cl::Option &o)
60       : cl::parser<std::optional<int64_t>>(o) {}
61 
62   bool parse(cl::Option &o, StringRef /*argName*/, StringRef arg,
63              std::optional<int64_t> &v) {
64     long long w;
65     if (getAsSignedInteger(arg, 10, w))
66       return o.error("Invalid argument '" + arg +
67                      "', only integer is supported.");
68     v = w;
69     return false;
70   }
71 };
72 
73 /// This class is intended to manage the handling of command line options for
74 /// creating a *-opt config. This is a singleton.
75 struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
76   MlirOptMainConfigCLOptions() {
77     // These options are static but all uses ExternalStorage to initialize the
78     // members of the parent class. This is unusual but since this class is a
79     // singleton it basically attaches command line option to the singleton
80     // members.
81 
82     static cl::opt<bool, /*ExternalStorage=*/true> allowUnregisteredDialects(
83         "allow-unregistered-dialect",
84         cl::desc("Allow operation with no registered dialects"),
85         cl::location(allowUnregisteredDialectsFlag), cl::init(false));
86 
87     static cl::opt<bool, /*ExternalStorage=*/true> dumpPassPipeline(
88         "dump-pass-pipeline", cl::desc("Print the pipeline that will be run"),
89         cl::location(dumpPassPipelineFlag), cl::init(false));
90 
91     static cl::opt<bool, /*ExternalStorage=*/true> emitBytecode(
92         "emit-bytecode", cl::desc("Emit bytecode when generating output"),
93         cl::location(emitBytecodeFlag), cl::init(false));
94 
95     static cl::opt<bool, /*ExternalStorage=*/true> elideResourcesFromBytecode(
96         "elide-resource-data-from-bytecode",
97         cl::desc("Elide resources when generating bytecode"),
98         cl::location(elideResourceDataFromBytecodeFlag), cl::init(false));
99 
100     static cl::opt<std::optional<int64_t>, /*ExternalStorage=*/true,
101                    BytecodeVersionParser>
102         bytecodeVersion(
103             "emit-bytecode-version",
104             cl::desc("Use specified bytecode when generating output"),
105             cl::location(emitBytecodeVersion), cl::init(std::nullopt));
106 
107     static cl::opt<std::string, /*ExternalStorage=*/true> irdlFile(
108         "irdl-file",
109         cl::desc("IRDL file to register before processing the input"),
110         cl::location(irdlFileFlag), cl::init(""), cl::value_desc("filename"));
111 
112     static cl::opt<VerbosityLevel, /*ExternalStorage=*/true>
113         diagnosticVerbosityLevel(
114             "mlir-diagnostic-verbosity-level",
115             cl::desc("Choose level of diagnostic information"),
116             cl::location(diagnosticVerbosityLevelFlag),
117             cl::init(VerbosityLevel::ErrorsWarningsAndRemarks),
118             cl::values(
119                 clEnumValN(VerbosityLevel::ErrorsOnly, "errors", "Errors only"),
120                 clEnumValN(VerbosityLevel::ErrorsAndWarnings, "warnings",
121                            "Errors and warnings"),
122                 clEnumValN(VerbosityLevel::ErrorsWarningsAndRemarks, "remarks",
123                            "Errors, warnings and remarks")));
124 
125     static cl::opt<bool, /*ExternalStorage=*/true> disableDiagnosticNotes(
126         "mlir-disable-diagnostic-notes", cl::desc("Disable diagnostic notes."),
127         cl::location(disableDiagnosticNotesFlag), cl::init(false));
128 
129     static cl::opt<bool, /*ExternalStorage=*/true> enableDebuggerHook(
130         "mlir-enable-debugger-hook",
131         cl::desc("Enable Debugger hook for debugging MLIR Actions"),
132         cl::location(enableDebuggerActionHookFlag), cl::init(false));
133 
134     static cl::opt<bool, /*ExternalStorage=*/true> explicitModule(
135         "no-implicit-module",
136         cl::desc("Disable implicit addition of a top-level module op during "
137                  "parsing"),
138         cl::location(useExplicitModuleFlag), cl::init(false));
139 
140     static cl::opt<bool, /*ExternalStorage=*/true> listPasses(
141         "list-passes", cl::desc("Print the list of registered passes and exit"),
142         cl::location(listPassesFlag), cl::init(false));
143 
144     static cl::opt<bool, /*ExternalStorage=*/true> runReproducer(
145         "run-reproducer", cl::desc("Run the pipeline stored in the reproducer"),
146         cl::location(runReproducerFlag), cl::init(false));
147 
148     static cl::opt<bool, /*ExternalStorage=*/true> showDialects(
149         "show-dialects",
150         cl::desc("Print the list of registered dialects and exit"),
151         cl::location(showDialectsFlag), cl::init(false));
152 
153     static cl::opt<std::string, /*ExternalStorage=*/true> splitInputFile{
154         "split-input-file",
155         llvm::cl::ValueOptional,
156         cl::callback([&](const std::string &str) {
157           // Implicit value: use default marker if flag was used without value.
158           if (str.empty())
159             splitInputFile.setValue(kDefaultSplitMarker);
160         }),
161         cl::desc("Split the input file into chunks using the given or "
162                  "default marker and process each chunk independently"),
163         cl::location(splitInputFileFlag),
164         cl::init("")};
165 
166     static cl::opt<std::string, /*ExternalStorage=*/true> outputSplitMarker(
167         "output-split-marker",
168         cl::desc("Split marker to use for merging the ouput"),
169         cl::location(outputSplitMarkerFlag), cl::init(kDefaultSplitMarker));
170 
171     static cl::opt<bool, /*ExternalStorage=*/true> verifyDiagnostics(
172         "verify-diagnostics",
173         cl::desc("Check that emitted diagnostics match "
174                  "expected-* lines on the corresponding line"),
175         cl::location(verifyDiagnosticsFlag), cl::init(false));
176 
177     static cl::opt<bool, /*ExternalStorage=*/true> verifyPasses(
178         "verify-each",
179         cl::desc("Run the verifier after each transformation pass"),
180         cl::location(verifyPassesFlag), cl::init(true));
181 
182     static cl::opt<bool, /*ExternalStorage=*/true> disableVerifyOnParsing(
183         "mlir-very-unsafe-disable-verifier-on-parsing",
184         cl::desc("Disable the verifier on parsing (very unsafe)"),
185         cl::location(disableVerifierOnParsingFlag), cl::init(false));
186 
187     static cl::opt<bool, /*ExternalStorage=*/true> verifyRoundtrip(
188         "verify-roundtrip",
189         cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
190         cl::location(verifyRoundtripFlag), cl::init(false));
191 
192     static cl::list<std::string> passPlugins(
193         "load-pass-plugin", cl::desc("Load passes from plugin library"));
194 
195     static cl::opt<std::string, /*ExternalStorage=*/true>
196         generateReproducerFile(
197             "mlir-generate-reproducer",
198             llvm::cl::desc(
199                 "Generate an mlir reproducer at the provided filename"
200                 " (no crash required)"),
201             cl::location(generateReproducerFileFlag), cl::init(""),
202             cl::value_desc("filename"));
203 
204     /// Set the callback to load a pass plugin.
205     passPlugins.setCallback([&](const std::string &pluginPath) {
206       auto plugin = PassPlugin::load(pluginPath);
207       if (!plugin) {
208         errs() << "Failed to load passes from '" << pluginPath
209                << "'. Request ignored.\n";
210         return;
211       }
212       plugin.get().registerPassRegistryCallbacks();
213     });
214 
215     static cl::list<std::string> dialectPlugins(
216         "load-dialect-plugin", cl::desc("Load dialects from plugin library"));
217     this->dialectPlugins = std::addressof(dialectPlugins);
218 
219     static PassPipelineCLParser passPipeline("", "Compiler passes to run", "p");
220     setPassPipelineParser(passPipeline);
221   }
222 
223   /// Set the callback to load a dialect plugin.
224   void setDialectPluginsCallback(DialectRegistry &registry);
225 
226   /// Pointer to static dialectPlugins variable in constructor, needed by
227   /// setDialectPluginsCallback(DialectRegistry&).
228   cl::list<std::string> *dialectPlugins = nullptr;
229 };
230 
231 /// A scoped diagnostic handler that suppresses certain diagnostics based on
232 /// the verbosity level and whether the diagnostic is a note.
233 class DiagnosticFilter : public ScopedDiagnosticHandler {
234 public:
235   DiagnosticFilter(MLIRContext *ctx, VerbosityLevel verbosityLevel,
236                    bool showNotes = true)
237       : ScopedDiagnosticHandler(ctx) {
238     setHandler([verbosityLevel, showNotes](Diagnostic &diag) {
239       auto severity = diag.getSeverity();
240       switch (severity) {
241       case DiagnosticSeverity::Error:
242         // failure indicates that the error is not handled by the filter and
243         // goes through to the default handler. Therefore, the error can be
244         // successfully printed.
245         return failure();
246       case DiagnosticSeverity::Warning:
247         if (verbosityLevel == VerbosityLevel::ErrorsOnly)
248           return success();
249         else
250           return failure();
251       case DiagnosticSeverity::Remark:
252         if (verbosityLevel == VerbosityLevel::ErrorsOnly ||
253             verbosityLevel == VerbosityLevel::ErrorsAndWarnings)
254           return success();
255         else
256           return failure();
257       case DiagnosticSeverity::Note:
258         if (showNotes)
259           return failure();
260         else
261           return success();
262       }
263       llvm_unreachable("Unknown diagnostic severity");
264     });
265   }
266 };
267 } // namespace
268 
269 ManagedStatic<MlirOptMainConfigCLOptions> clOptionsConfig;
270 
271 void MlirOptMainConfig::registerCLOptions(DialectRegistry &registry) {
272   clOptionsConfig->setDialectPluginsCallback(registry);
273   tracing::DebugConfig::registerCLOptions();
274 }
275 
276 MlirOptMainConfig MlirOptMainConfig::createFromCLOptions() {
277   clOptionsConfig->setDebugConfig(tracing::DebugConfig::createFromCLOptions());
278   return *clOptionsConfig;
279 }
280 
281 MlirOptMainConfig &MlirOptMainConfig::setPassPipelineParser(
282     const PassPipelineCLParser &passPipeline) {
283   passPipelineCallback = [&](PassManager &pm) {
284     auto errorHandler = [&](const Twine &msg) {
285       emitError(UnknownLoc::get(pm.getContext())) << msg;
286       return failure();
287     };
288     if (failed(passPipeline.addToPipeline(pm, errorHandler)))
289       return failure();
290     if (this->shouldDumpPassPipeline()) {
291 
292       pm.dump();
293       llvm::errs() << "\n";
294     }
295     return success();
296   };
297   return *this;
298 }
299 
300 void MlirOptMainConfigCLOptions::setDialectPluginsCallback(
301     DialectRegistry &registry) {
302   dialectPlugins->setCallback([&](const std::string &pluginPath) {
303     auto plugin = DialectPlugin::load(pluginPath);
304     if (!plugin) {
305       errs() << "Failed to load dialect plugin from '" << pluginPath
306              << "'. Request ignored.\n";
307       return;
308     };
309     plugin.get().registerDialectRegistryCallbacks(registry);
310   });
311 }
312 
313 LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) {
314   DialectRegistry registry;
315   registry.insert<irdl::IRDLDialect>();
316   ctx.appendDialectRegistry(registry);
317 
318   // Set up the input file.
319   std::string errorMessage;
320   std::unique_ptr<MemoryBuffer> file = openInputFile(irdlFile, &errorMessage);
321   if (!file) {
322     emitError(UnknownLoc::get(&ctx)) << errorMessage;
323     return failure();
324   }
325 
326   // Give the buffer to the source manager.
327   // This will be picked up by the parser.
328   SourceMgr sourceMgr;
329   sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
330 
331   SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx);
332 
333   // Parse the input file.
334   OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, &ctx));
335   if (!module)
336     return failure();
337 
338   // Load IRDL dialects.
339   return irdl::loadDialects(module.get());
340 }
341 
342 // Return success if the module can correctly round-trip. This intended to test
343 // that the custom printers/parsers are complete.
344 static LogicalResult doVerifyRoundTrip(Operation *op,
345                                        const MlirOptMainConfig &config,
346                                        bool useBytecode) {
347   // We use a new context to avoid resource handle renaming issue in the diff.
348   MLIRContext roundtripContext;
349   OwningOpRef<Operation *> roundtripModule;
350   roundtripContext.appendDialectRegistry(
351       op->getContext()->getDialectRegistry());
352   if (op->getContext()->allowsUnregisteredDialects())
353     roundtripContext.allowUnregisteredDialects();
354   StringRef irdlFile = config.getIrdlFile();
355   if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, roundtripContext)))
356     return failure();
357 
358   std::string testType = (useBytecode) ? "bytecode" : "textual";
359   // Print a first time with custom format (or bytecode) and parse it back to
360   // the roundtripModule.
361   {
362     std::string buffer;
363     llvm::raw_string_ostream ostream(buffer);
364     if (useBytecode) {
365       if (failed(writeBytecodeToFile(op, ostream))) {
366         op->emitOpError()
367             << "failed to write bytecode, cannot verify round-trip.\n";
368         return failure();
369       }
370     } else {
371       op->print(ostream,
372                 OpPrintingFlags().printGenericOpForm().enableDebugInfo());
373     }
374     FallbackAsmResourceMap fallbackResourceMap;
375     ParserConfig parseConfig(&roundtripContext, config.shouldVerifyOnParsing(),
376                              &fallbackResourceMap);
377     roundtripModule = parseSourceString<Operation *>(buffer, parseConfig);
378     if (!roundtripModule) {
379       op->emitOpError() << "failed to parse " << testType
380                         << " content back, cannot verify round-trip.\n";
381       return failure();
382     }
383   }
384 
385   // Print in the generic form for the reference module and the round-tripped
386   // one and compare the outputs.
387   std::string reference, roundtrip;
388   {
389     llvm::raw_string_ostream ostreamref(reference);
390     op->print(ostreamref,
391               OpPrintingFlags().printGenericOpForm().enableDebugInfo());
392     llvm::raw_string_ostream ostreamrndtrip(roundtrip);
393     roundtripModule.get()->print(
394         ostreamrndtrip,
395         OpPrintingFlags().printGenericOpForm().enableDebugInfo());
396   }
397   if (reference != roundtrip) {
398     // TODO implement a diff.
399     return op->emitOpError()
400            << testType
401            << " roundTrip testing roundtripped module differs "
402               "from reference:\n<<<<<<Reference\n"
403            << reference << "\n=====\n"
404            << roundtrip << "\n>>>>>roundtripped\n";
405   }
406 
407   return success();
408 }
409 
410 static LogicalResult doVerifyRoundTrip(Operation *op,
411                                        const MlirOptMainConfig &config) {
412   auto txtStatus = doVerifyRoundTrip(op, config, /*useBytecode=*/false);
413   auto bcStatus = doVerifyRoundTrip(op, config, /*useBytecode=*/true);
414   return success(succeeded(txtStatus) && succeeded(bcStatus));
415 }
416 
417 /// Perform the actions on the input file indicated by the command line flags
418 /// within the specified context.
419 ///
420 /// This typically parses the main source file, runs zero or more optimization
421 /// passes, then prints the output.
422 ///
423 static LogicalResult
424 performActions(raw_ostream &os,
425                const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
426                MLIRContext *context, const MlirOptMainConfig &config) {
427   DefaultTimingManager tm;
428   applyDefaultTimingManagerCLOptions(tm);
429   TimingScope timing = tm.getRootScope();
430 
431   // Disable multi-threading when parsing the input file. This removes the
432   // unnecessary/costly context synchronization when parsing.
433   bool wasThreadingEnabled = context->isMultithreadingEnabled();
434   context->disableMultithreading();
435 
436   // Prepare the parser config, and attach any useful/necessary resource
437   // handlers. Unhandled external resources are treated as passthrough, i.e.
438   // they are not processed and will be emitted directly to the output
439   // untouched.
440   PassReproducerOptions reproOptions;
441   FallbackAsmResourceMap fallbackResourceMap;
442   ParserConfig parseConfig(context, config.shouldVerifyOnParsing(),
443                            &fallbackResourceMap);
444   if (config.shouldRunReproducer())
445     reproOptions.attachResourceParser(parseConfig);
446 
447   // Parse the input file and reset the context threading state.
448   TimingScope parserTiming = timing.nest("Parser");
449   OwningOpRef<Operation *> op = parseSourceFileForTool(
450       sourceMgr, parseConfig, !config.shouldUseExplicitModule());
451   parserTiming.stop();
452   if (!op)
453     return failure();
454 
455   // Perform round-trip verification if requested
456   if (config.shouldVerifyRoundtrip() &&
457       failed(doVerifyRoundTrip(op.get(), config)))
458     return failure();
459 
460   context->enableMultithreading(wasThreadingEnabled);
461 
462   // Prepare the pass manager, applying command-line and reproducer options.
463   PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit);
464   pm.enableVerifier(config.shouldVerifyPasses());
465   if (failed(applyPassManagerCLOptions(pm)))
466     return failure();
467   pm.enableTiming(timing);
468   if (config.shouldRunReproducer() && failed(reproOptions.apply(pm)))
469     return failure();
470   if (failed(config.setupPassPipeline(pm)))
471     return failure();
472 
473   // Run the pipeline.
474   if (failed(pm.run(*op)))
475     return failure();
476 
477   // Generate reproducers if requested
478   if (!config.getReproducerFilename().empty()) {
479     StringRef anchorName = pm.getAnyOpAnchorName();
480     const auto &passes = pm.getPasses();
481     makeReproducer(anchorName, passes, op.get(),
482                    config.getReproducerFilename());
483   }
484 
485   // Print the output.
486   TimingScope outputTiming = timing.nest("Output");
487   if (config.shouldEmitBytecode()) {
488     BytecodeWriterConfig writerConfig(fallbackResourceMap);
489     if (auto v = config.bytecodeVersionToEmit())
490       writerConfig.setDesiredBytecodeVersion(*v);
491     if (config.shouldElideResourceDataFromBytecode())
492       writerConfig.setElideResourceDataFlag();
493     return writeBytecodeToFile(op.get(), os, writerConfig);
494   }
495 
496   if (config.bytecodeVersionToEmit().has_value())
497     return emitError(UnknownLoc::get(pm.getContext()))
498            << "bytecode version while not emitting bytecode";
499   AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
500                     &fallbackResourceMap);
501   op.get()->print(os, asmState);
502   os << '\n';
503   return success();
504 }
505 
506 /// Parses the memory buffer.  If successfully, run a series of passes against
507 /// it and print the result.
508 static LogicalResult processBuffer(raw_ostream &os,
509                                    std::unique_ptr<MemoryBuffer> ownedBuffer,
510                                    const MlirOptMainConfig &config,
511                                    DialectRegistry &registry,
512                                    llvm::ThreadPoolInterface *threadPool) {
513   // Tell sourceMgr about this buffer, which is what the parser will pick up.
514   auto sourceMgr = std::make_shared<SourceMgr>();
515   sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
516 
517   // Create a context just for the current buffer. Disable threading on creation
518   // since we'll inject the thread-pool separately.
519   MLIRContext context(registry, MLIRContext::Threading::DISABLED);
520   if (threadPool)
521     context.setThreadPool(*threadPool);
522 
523   StringRef irdlFile = config.getIrdlFile();
524   if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, context)))
525     return failure();
526 
527   // Parse the input file.
528   context.allowUnregisteredDialects(config.shouldAllowUnregisteredDialects());
529   if (config.shouldVerifyDiagnostics())
530     context.printOpOnDiagnostic(false);
531 
532   tracing::InstallDebugHandler installDebugHandler(context,
533                                                    config.getDebugConfig());
534 
535   // If we are in verify diagnostics mode then we have a lot of work to do,
536   // otherwise just perform the actions without worrying about it.
537   if (!config.shouldVerifyDiagnostics()) {
538     SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
539     DiagnosticFilter diagnosticFilter(&context,
540                                       config.getDiagnosticVerbosityLevel(),
541                                       config.shouldShowNotes());
542     return performActions(os, sourceMgr, &context, config);
543   }
544 
545   SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context);
546 
547   // Do any processing requested by command line flags.  We don't care whether
548   // these actions succeed or fail, we only care what diagnostics they produce
549   // and whether they match our expectations.
550   (void)performActions(os, sourceMgr, &context, config);
551 
552   // Verify the diagnostic handler to make sure that each of the diagnostics
553   // matched.
554   return sourceMgrHandler.verify();
555 }
556 
557 std::pair<std::string, std::string>
558 mlir::registerAndParseCLIOptions(int argc, char **argv,
559                                  llvm::StringRef toolName,
560                                  DialectRegistry &registry) {
561   static cl::opt<std::string> inputFilename(
562       cl::Positional, cl::desc("<input file>"), cl::init("-"));
563 
564   static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
565                                              cl::value_desc("filename"),
566                                              cl::init("-"));
567   // Register any command line options.
568   MlirOptMainConfig::registerCLOptions(registry);
569   registerAsmPrinterCLOptions();
570   registerMLIRContextCLOptions();
571   registerPassManagerCLOptions();
572   registerDefaultTimingManagerCLOptions();
573   tracing::DebugCounter::registerCLOptions();
574 
575   // Build the list of dialects as a header for the --help message.
576   std::string helpHeader = (toolName + "\nAvailable Dialects: ").str();
577   {
578     llvm::raw_string_ostream os(helpHeader);
579     interleaveComma(registry.getDialectNames(), os,
580                     [&](auto name) { os << name; });
581   }
582   // Parse pass names in main to ensure static initialization completed.
583   cl::ParseCommandLineOptions(argc, argv, helpHeader);
584   return std::make_pair(inputFilename.getValue(), outputFilename.getValue());
585 }
586 
587 static LogicalResult printRegisteredDialects(DialectRegistry &registry) {
588   llvm::outs() << "Available Dialects: ";
589   interleave(registry.getDialectNames(), llvm::outs(), ",");
590   llvm::outs() << "\n";
591   return success();
592 }
593 
594 static LogicalResult printRegisteredPassesAndReturn() {
595   mlir::printRegisteredPasses();
596   return success();
597 }
598 
599 LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream,
600                                 std::unique_ptr<llvm::MemoryBuffer> buffer,
601                                 DialectRegistry &registry,
602                                 const MlirOptMainConfig &config) {
603   if (config.shouldShowDialects())
604     return printRegisteredDialects(registry);
605 
606   if (config.shouldListPasses())
607     return printRegisteredPassesAndReturn();
608 
609   // The split-input-file mode is a very specific mode that slices the file
610   // up into small pieces and checks each independently.
611   // We use an explicit threadpool to avoid creating and joining/destroying
612   // threads for each of the split.
613   ThreadPoolInterface *threadPool = nullptr;
614 
615   // Create a temporary context for the sake of checking if
616   // --mlir-disable-threading was passed on the command line.
617   // We use the thread-pool this context is creating, and avoid
618   // creating any thread when disabled.
619   MLIRContext threadPoolCtx;
620   if (threadPoolCtx.isMultithreadingEnabled())
621     threadPool = &threadPoolCtx.getThreadPool();
622 
623   auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer,
624                      raw_ostream &os) {
625     return processBuffer(os, std::move(chunkBuffer), config, registry,
626                          threadPool);
627   };
628   return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
629                                config.inputSplitMarker(),
630                                config.outputSplitMarker());
631 }
632 
633 LogicalResult mlir::MlirOptMain(int argc, char **argv,
634                                 llvm::StringRef inputFilename,
635                                 llvm::StringRef outputFilename,
636                                 DialectRegistry &registry) {
637 
638   InitLLVM y(argc, argv);
639 
640   MlirOptMainConfig config = MlirOptMainConfig::createFromCLOptions();
641 
642   if (config.shouldShowDialects())
643     return printRegisteredDialects(registry);
644 
645   if (config.shouldListPasses())
646     return printRegisteredPassesAndReturn();
647 
648   // When reading from stdin and the input is a tty, it is often a user mistake
649   // and the process "appears to be stuck". Print a message to let the user know
650   // about it!
651   if (inputFilename == "-" &&
652       sys::Process::FileDescriptorIsDisplayed(fileno(stdin)))
653     llvm::errs() << "(processing input from stdin now, hit ctrl-c/ctrl-d to "
654                     "interrupt)\n";
655 
656   // Set up the input file.
657   std::string errorMessage;
658   auto file = openInputFile(inputFilename, &errorMessage);
659   if (!file) {
660     llvm::errs() << errorMessage << "\n";
661     return failure();
662   }
663 
664   auto output = openOutputFile(outputFilename, &errorMessage);
665   if (!output) {
666     llvm::errs() << errorMessage << "\n";
667     return failure();
668   }
669   if (failed(MlirOptMain(output->os(), std::move(file), registry, config)))
670     return failure();
671 
672   // Keep the output file if the invocation of MlirOptMain was successful.
673   output->keep();
674   return success();
675 }
676 
677 LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
678                                 DialectRegistry &registry) {
679 
680   // Register and parse command line options.
681   std::string inputFilename, outputFilename;
682   std::tie(inputFilename, outputFilename) =
683       registerAndParseCLIOptions(argc, argv, toolName, registry);
684 
685   return MlirOptMain(argc, argv, inputFilename, outputFilename, registry);
686 }
687