//===- MlirOptMain.cpp - MLIR Optimizer Driver ----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This is a utility that runs an optimization pass and prints the result back // out. It is designed to support unit testing. // //===----------------------------------------------------------------------===// #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/Debug/CLOptionsSetup.h" #include "mlir/Debug/Counter.h" #include "mlir/Debug/DebuggerExecutionContextHook.h" #include "mlir/Debug/ExecutionContext.h" #include "mlir/Debug/Observers/ActionLogging.h" #include "mlir/Dialect/IRDL/IR/IRDL.h" #include "mlir/Dialect/IRDL/IRDLLoading.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/Timing.h" #include "mlir/Support/ToolUtilities.h" #include "mlir/Tools/ParseUtilities.h" #include "mlir/Tools/Plugins/DialectPlugin.h" #include "mlir/Tools/Plugins/PassPlugin.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileUtilities.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Process.h" #include "llvm/Support/Regex.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/StringSaver.h" #include "llvm/Support/ThreadPool.h" #include "llvm/Support/ToolOutputFile.h" using namespace mlir; using namespace llvm; namespace { class BytecodeVersionParser : public cl::parser> { public: BytecodeVersionParser(cl::Option &o) : cl::parser>(o) {} bool parse(cl::Option &o, StringRef /*argName*/, StringRef arg, std::optional &v) { long long w; if (getAsSignedInteger(arg, 10, w)) return o.error("Invalid argument '" + arg + "', only integer is supported."); v = w; return false; } }; /// This class is intended to manage the handling of command line options for /// creating a *-opt config. This is a singleton. struct MlirOptMainConfigCLOptions : public MlirOptMainConfig { MlirOptMainConfigCLOptions() { // These options are static but all uses ExternalStorage to initialize the // members of the parent class. This is unusual but since this class is a // singleton it basically attaches command line option to the singleton // members. static cl::opt allowUnregisteredDialects( "allow-unregistered-dialect", cl::desc("Allow operation with no registered dialects"), cl::location(allowUnregisteredDialectsFlag), cl::init(false)); static cl::opt dumpPassPipeline( "dump-pass-pipeline", cl::desc("Print the pipeline that will be run"), cl::location(dumpPassPipelineFlag), cl::init(false)); static cl::opt emitBytecode( "emit-bytecode", cl::desc("Emit bytecode when generating output"), cl::location(emitBytecodeFlag), cl::init(false)); static cl::opt elideResourcesFromBytecode( "elide-resource-data-from-bytecode", cl::desc("Elide resources when generating bytecode"), cl::location(elideResourceDataFromBytecodeFlag), cl::init(false)); static cl::opt, /*ExternalStorage=*/true, BytecodeVersionParser> bytecodeVersion( "emit-bytecode-version", cl::desc("Use specified bytecode when generating output"), cl::location(emitBytecodeVersion), cl::init(std::nullopt)); static cl::opt irdlFile( "irdl-file", cl::desc("IRDL file to register before processing the input"), cl::location(irdlFileFlag), cl::init(""), cl::value_desc("filename")); static cl::opt diagnosticVerbosityLevel( "mlir-diagnostic-verbosity-level", cl::desc("Choose level of diagnostic information"), cl::location(diagnosticVerbosityLevelFlag), cl::init(VerbosityLevel::ErrorsWarningsAndRemarks), cl::values( clEnumValN(VerbosityLevel::ErrorsOnly, "errors", "Errors only"), clEnumValN(VerbosityLevel::ErrorsAndWarnings, "warnings", "Errors and warnings"), clEnumValN(VerbosityLevel::ErrorsWarningsAndRemarks, "remarks", "Errors, warnings and remarks"))); static cl::opt disableDiagnosticNotes( "mlir-disable-diagnostic-notes", cl::desc("Disable diagnostic notes."), cl::location(disableDiagnosticNotesFlag), cl::init(false)); static cl::opt enableDebuggerHook( "mlir-enable-debugger-hook", cl::desc("Enable Debugger hook for debugging MLIR Actions"), cl::location(enableDebuggerActionHookFlag), cl::init(false)); static cl::opt explicitModule( "no-implicit-module", cl::desc("Disable implicit addition of a top-level module op during " "parsing"), cl::location(useExplicitModuleFlag), cl::init(false)); static cl::opt listPasses( "list-passes", cl::desc("Print the list of registered passes and exit"), cl::location(listPassesFlag), cl::init(false)); static cl::opt runReproducer( "run-reproducer", cl::desc("Run the pipeline stored in the reproducer"), cl::location(runReproducerFlag), cl::init(false)); static cl::opt showDialects( "show-dialects", cl::desc("Print the list of registered dialects and exit"), cl::location(showDialectsFlag), cl::init(false)); static cl::opt splitInputFile{ "split-input-file", llvm::cl::ValueOptional, cl::callback([&](const std::string &str) { // Implicit value: use default marker if flag was used without value. if (str.empty()) splitInputFile.setValue(kDefaultSplitMarker); }), cl::desc("Split the input file into chunks using the given or " "default marker and process each chunk independently"), cl::location(splitInputFileFlag), cl::init("")}; static cl::opt outputSplitMarker( "output-split-marker", cl::desc("Split marker to use for merging the ouput"), cl::location(outputSplitMarkerFlag), cl::init(kDefaultSplitMarker)); static cl::opt verifyDiagnostics( "verify-diagnostics", cl::desc("Check that emitted diagnostics match " "expected-* lines on the corresponding line"), cl::location(verifyDiagnosticsFlag), cl::init(false)); static cl::opt verifyPasses( "verify-each", cl::desc("Run the verifier after each transformation pass"), cl::location(verifyPassesFlag), cl::init(true)); static cl::opt disableVerifyOnParsing( "mlir-very-unsafe-disable-verifier-on-parsing", cl::desc("Disable the verifier on parsing (very unsafe)"), cl::location(disableVerifierOnParsingFlag), cl::init(false)); static cl::opt verifyRoundtrip( "verify-roundtrip", cl::desc("Round-trip the IR after parsing and ensure it succeeds"), cl::location(verifyRoundtripFlag), cl::init(false)); static cl::list passPlugins( "load-pass-plugin", cl::desc("Load passes from plugin library")); static cl::opt generateReproducerFile( "mlir-generate-reproducer", llvm::cl::desc( "Generate an mlir reproducer at the provided filename" " (no crash required)"), cl::location(generateReproducerFileFlag), cl::init(""), cl::value_desc("filename")); /// Set the callback to load a pass plugin. passPlugins.setCallback([&](const std::string &pluginPath) { auto plugin = PassPlugin::load(pluginPath); if (!plugin) { errs() << "Failed to load passes from '" << pluginPath << "'. Request ignored.\n"; return; } plugin.get().registerPassRegistryCallbacks(); }); static cl::list dialectPlugins( "load-dialect-plugin", cl::desc("Load dialects from plugin library")); this->dialectPlugins = std::addressof(dialectPlugins); static PassPipelineCLParser passPipeline("", "Compiler passes to run", "p"); setPassPipelineParser(passPipeline); } /// Set the callback to load a dialect plugin. void setDialectPluginsCallback(DialectRegistry ®istry); /// Pointer to static dialectPlugins variable in constructor, needed by /// setDialectPluginsCallback(DialectRegistry&). cl::list *dialectPlugins = nullptr; }; /// A scoped diagnostic handler that suppresses certain diagnostics based on /// the verbosity level and whether the diagnostic is a note. class DiagnosticFilter : public ScopedDiagnosticHandler { public: DiagnosticFilter(MLIRContext *ctx, VerbosityLevel verbosityLevel, bool showNotes = true) : ScopedDiagnosticHandler(ctx) { setHandler([verbosityLevel, showNotes](Diagnostic &diag) { auto severity = diag.getSeverity(); switch (severity) { case DiagnosticSeverity::Error: // failure indicates that the error is not handled by the filter and // goes through to the default handler. Therefore, the error can be // successfully printed. return failure(); case DiagnosticSeverity::Warning: if (verbosityLevel == VerbosityLevel::ErrorsOnly) return success(); else return failure(); case DiagnosticSeverity::Remark: if (verbosityLevel == VerbosityLevel::ErrorsOnly || verbosityLevel == VerbosityLevel::ErrorsAndWarnings) return success(); else return failure(); case DiagnosticSeverity::Note: if (showNotes) return failure(); else return success(); } llvm_unreachable("Unknown diagnostic severity"); }); } }; } // namespace ManagedStatic clOptionsConfig; void MlirOptMainConfig::registerCLOptions(DialectRegistry ®istry) { clOptionsConfig->setDialectPluginsCallback(registry); tracing::DebugConfig::registerCLOptions(); } MlirOptMainConfig MlirOptMainConfig::createFromCLOptions() { clOptionsConfig->setDebugConfig(tracing::DebugConfig::createFromCLOptions()); return *clOptionsConfig; } MlirOptMainConfig &MlirOptMainConfig::setPassPipelineParser( const PassPipelineCLParser &passPipeline) { passPipelineCallback = [&](PassManager &pm) { auto errorHandler = [&](const Twine &msg) { emitError(UnknownLoc::get(pm.getContext())) << msg; return failure(); }; if (failed(passPipeline.addToPipeline(pm, errorHandler))) return failure(); if (this->shouldDumpPassPipeline()) { pm.dump(); llvm::errs() << "\n"; } return success(); }; return *this; } void MlirOptMainConfigCLOptions::setDialectPluginsCallback( DialectRegistry ®istry) { dialectPlugins->setCallback([&](const std::string &pluginPath) { auto plugin = DialectPlugin::load(pluginPath); if (!plugin) { errs() << "Failed to load dialect plugin from '" << pluginPath << "'. Request ignored.\n"; return; }; plugin.get().registerDialectRegistryCallbacks(registry); }); } LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) { DialectRegistry registry; registry.insert(); ctx.appendDialectRegistry(registry); // Set up the input file. std::string errorMessage; std::unique_ptr file = openInputFile(irdlFile, &errorMessage); if (!file) { emitError(UnknownLoc::get(&ctx)) << errorMessage; return failure(); } // Give the buffer to the source manager. // This will be picked up by the parser. SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx); // Parse the input file. OwningOpRef module(parseSourceFile(sourceMgr, &ctx)); if (!module) return failure(); // Load IRDL dialects. return irdl::loadDialects(module.get()); } // Return success if the module can correctly round-trip. This intended to test // that the custom printers/parsers are complete. static LogicalResult doVerifyRoundTrip(Operation *op, const MlirOptMainConfig &config, bool useBytecode) { // We use a new context to avoid resource handle renaming issue in the diff. MLIRContext roundtripContext; OwningOpRef roundtripModule; roundtripContext.appendDialectRegistry( op->getContext()->getDialectRegistry()); if (op->getContext()->allowsUnregisteredDialects()) roundtripContext.allowUnregisteredDialects(); StringRef irdlFile = config.getIrdlFile(); if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, roundtripContext))) return failure(); std::string testType = (useBytecode) ? "bytecode" : "textual"; // Print a first time with custom format (or bytecode) and parse it back to // the roundtripModule. { std::string buffer; llvm::raw_string_ostream ostream(buffer); if (useBytecode) { if (failed(writeBytecodeToFile(op, ostream))) { op->emitOpError() << "failed to write bytecode, cannot verify round-trip.\n"; return failure(); } } else { op->print(ostream, OpPrintingFlags().printGenericOpForm().enableDebugInfo()); } FallbackAsmResourceMap fallbackResourceMap; ParserConfig parseConfig(&roundtripContext, config.shouldVerifyOnParsing(), &fallbackResourceMap); roundtripModule = parseSourceString(buffer, parseConfig); if (!roundtripModule) { op->emitOpError() << "failed to parse " << testType << " content back, cannot verify round-trip.\n"; return failure(); } } // Print in the generic form for the reference module and the round-tripped // one and compare the outputs. std::string reference, roundtrip; { llvm::raw_string_ostream ostreamref(reference); op->print(ostreamref, OpPrintingFlags().printGenericOpForm().enableDebugInfo()); llvm::raw_string_ostream ostreamrndtrip(roundtrip); roundtripModule.get()->print( ostreamrndtrip, OpPrintingFlags().printGenericOpForm().enableDebugInfo()); } if (reference != roundtrip) { // TODO implement a diff. return op->emitOpError() << testType << " roundTrip testing roundtripped module differs " "from reference:\n<<<<<>>>>roundtripped\n"; } return success(); } static LogicalResult doVerifyRoundTrip(Operation *op, const MlirOptMainConfig &config) { auto txtStatus = doVerifyRoundTrip(op, config, /*useBytecode=*/false); auto bcStatus = doVerifyRoundTrip(op, config, /*useBytecode=*/true); return success(succeeded(txtStatus) && succeeded(bcStatus)); } /// Perform the actions on the input file indicated by the command line flags /// within the specified context. /// /// This typically parses the main source file, runs zero or more optimization /// passes, then prints the output. /// static LogicalResult performActions(raw_ostream &os, const std::shared_ptr &sourceMgr, MLIRContext *context, const MlirOptMainConfig &config) { DefaultTimingManager tm; applyDefaultTimingManagerCLOptions(tm); TimingScope timing = tm.getRootScope(); // Disable multi-threading when parsing the input file. This removes the // unnecessary/costly context synchronization when parsing. bool wasThreadingEnabled = context->isMultithreadingEnabled(); context->disableMultithreading(); // Prepare the parser config, and attach any useful/necessary resource // handlers. Unhandled external resources are treated as passthrough, i.e. // they are not processed and will be emitted directly to the output // untouched. PassReproducerOptions reproOptions; FallbackAsmResourceMap fallbackResourceMap; ParserConfig parseConfig(context, config.shouldVerifyOnParsing(), &fallbackResourceMap); if (config.shouldRunReproducer()) reproOptions.attachResourceParser(parseConfig); // Parse the input file and reset the context threading state. TimingScope parserTiming = timing.nest("Parser"); OwningOpRef op = parseSourceFileForTool( sourceMgr, parseConfig, !config.shouldUseExplicitModule()); parserTiming.stop(); if (!op) return failure(); // Perform round-trip verification if requested if (config.shouldVerifyRoundtrip() && failed(doVerifyRoundTrip(op.get(), config))) return failure(); context->enableMultithreading(wasThreadingEnabled); // Prepare the pass manager, applying command-line and reproducer options. PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit); pm.enableVerifier(config.shouldVerifyPasses()); if (failed(applyPassManagerCLOptions(pm))) return failure(); pm.enableTiming(timing); if (config.shouldRunReproducer() && failed(reproOptions.apply(pm))) return failure(); if (failed(config.setupPassPipeline(pm))) return failure(); // Run the pipeline. if (failed(pm.run(*op))) return failure(); // Generate reproducers if requested if (!config.getReproducerFilename().empty()) { StringRef anchorName = pm.getAnyOpAnchorName(); const auto &passes = pm.getPasses(); makeReproducer(anchorName, passes, op.get(), config.getReproducerFilename()); } // Print the output. TimingScope outputTiming = timing.nest("Output"); if (config.shouldEmitBytecode()) { BytecodeWriterConfig writerConfig(fallbackResourceMap); if (auto v = config.bytecodeVersionToEmit()) writerConfig.setDesiredBytecodeVersion(*v); if (config.shouldElideResourceDataFromBytecode()) writerConfig.setElideResourceDataFlag(); return writeBytecodeToFile(op.get(), os, writerConfig); } if (config.bytecodeVersionToEmit().has_value()) return emitError(UnknownLoc::get(pm.getContext())) << "bytecode version while not emitting bytecode"; AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr, &fallbackResourceMap); op.get()->print(os, asmState); os << '\n'; return success(); } /// Parses the memory buffer. If successfully, run a series of passes against /// it and print the result. static LogicalResult processBuffer(raw_ostream &os, std::unique_ptr ownedBuffer, const MlirOptMainConfig &config, DialectRegistry ®istry, llvm::ThreadPoolInterface *threadPool) { // Tell sourceMgr about this buffer, which is what the parser will pick up. auto sourceMgr = std::make_shared(); sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); // Create a context just for the current buffer. Disable threading on creation // since we'll inject the thread-pool separately. MLIRContext context(registry, MLIRContext::Threading::DISABLED); if (threadPool) context.setThreadPool(*threadPool); StringRef irdlFile = config.getIrdlFile(); if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, context))) return failure(); // Parse the input file. context.allowUnregisteredDialects(config.shouldAllowUnregisteredDialects()); if (config.shouldVerifyDiagnostics()) context.printOpOnDiagnostic(false); tracing::InstallDebugHandler installDebugHandler(context, config.getDebugConfig()); // If we are in verify diagnostics mode then we have a lot of work to do, // otherwise just perform the actions without worrying about it. if (!config.shouldVerifyDiagnostics()) { SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context); DiagnosticFilter diagnosticFilter(&context, config.getDiagnosticVerbosityLevel(), config.shouldShowNotes()); return performActions(os, sourceMgr, &context, config); } SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context); // Do any processing requested by command line flags. We don't care whether // these actions succeed or fail, we only care what diagnostics they produce // and whether they match our expectations. (void)performActions(os, sourceMgr, &context, config); // Verify the diagnostic handler to make sure that each of the diagnostics // matched. return sourceMgrHandler.verify(); } std::pair mlir::registerAndParseCLIOptions(int argc, char **argv, llvm::StringRef toolName, DialectRegistry ®istry) { static cl::opt inputFilename( cl::Positional, cl::desc(""), cl::init("-")); static cl::opt outputFilename("o", cl::desc("Output filename"), cl::value_desc("filename"), cl::init("-")); // Register any command line options. MlirOptMainConfig::registerCLOptions(registry); registerAsmPrinterCLOptions(); registerMLIRContextCLOptions(); registerPassManagerCLOptions(); registerDefaultTimingManagerCLOptions(); tracing::DebugCounter::registerCLOptions(); // Build the list of dialects as a header for the --help message. std::string helpHeader = (toolName + "\nAvailable Dialects: ").str(); { llvm::raw_string_ostream os(helpHeader); interleaveComma(registry.getDialectNames(), os, [&](auto name) { os << name; }); } // Parse pass names in main to ensure static initialization completed. cl::ParseCommandLineOptions(argc, argv, helpHeader); return std::make_pair(inputFilename.getValue(), outputFilename.getValue()); } static LogicalResult printRegisteredDialects(DialectRegistry ®istry) { llvm::outs() << "Available Dialects: "; interleave(registry.getDialectNames(), llvm::outs(), ","); llvm::outs() << "\n"; return success(); } static LogicalResult printRegisteredPassesAndReturn() { mlir::printRegisteredPasses(); return success(); } LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream, std::unique_ptr buffer, DialectRegistry ®istry, const MlirOptMainConfig &config) { if (config.shouldShowDialects()) return printRegisteredDialects(registry); if (config.shouldListPasses()) return printRegisteredPassesAndReturn(); // The split-input-file mode is a very specific mode that slices the file // up into small pieces and checks each independently. // We use an explicit threadpool to avoid creating and joining/destroying // threads for each of the split. ThreadPoolInterface *threadPool = nullptr; // Create a temporary context for the sake of checking if // --mlir-disable-threading was passed on the command line. // We use the thread-pool this context is creating, and avoid // creating any thread when disabled. MLIRContext threadPoolCtx; if (threadPoolCtx.isMultithreadingEnabled()) threadPool = &threadPoolCtx.getThreadPool(); auto chunkFn = [&](std::unique_ptr chunkBuffer, raw_ostream &os) { return processBuffer(os, std::move(chunkBuffer), config, registry, threadPool); }; return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream, config.inputSplitMarker(), config.outputSplitMarker()); } LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef inputFilename, llvm::StringRef outputFilename, DialectRegistry ®istry) { InitLLVM y(argc, argv); MlirOptMainConfig config = MlirOptMainConfig::createFromCLOptions(); if (config.shouldShowDialects()) return printRegisteredDialects(registry); if (config.shouldListPasses()) return printRegisteredPassesAndReturn(); // When reading from stdin and the input is a tty, it is often a user mistake // and the process "appears to be stuck". Print a message to let the user know // about it! if (inputFilename == "-" && sys::Process::FileDescriptorIsDisplayed(fileno(stdin))) llvm::errs() << "(processing input from stdin now, hit ctrl-c/ctrl-d to " "interrupt)\n"; // Set up the input file. std::string errorMessage; auto file = openInputFile(inputFilename, &errorMessage); if (!file) { llvm::errs() << errorMessage << "\n"; return failure(); } auto output = openOutputFile(outputFilename, &errorMessage); if (!output) { llvm::errs() << errorMessage << "\n"; return failure(); } if (failed(MlirOptMain(output->os(), std::move(file), registry, config))) return failure(); // Keep the output file if the invocation of MlirOptMain was successful. output->keep(); return success(); } LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName, DialectRegistry ®istry) { // Register and parse command line options. std::string inputFilename, outputFilename; std::tie(inputFilename, outputFilename) = registerAndParseCLIOptions(argc, argv, toolName, registry); return MlirOptMain(argc, argv, inputFilename, outputFilename, registry); }