1 //===- mlir-transform-opt.cpp -----------------------------------*- C++ -*-===// 2 // 3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 10 #include "mlir/Dialect/Transform/IR/Utils.h" 11 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" 12 #include "mlir/IR/AsmState.h" 13 #include "mlir/IR/BuiltinOps.h" 14 #include "mlir/IR/Diagnostics.h" 15 #include "mlir/IR/DialectRegistry.h" 16 #include "mlir/IR/MLIRContext.h" 17 #include "mlir/InitAllDialects.h" 18 #include "mlir/InitAllExtensions.h" 19 #include "mlir/InitAllPasses.h" 20 #include "mlir/Parser/Parser.h" 21 #include "mlir/Support/FileUtilities.h" 22 #include "mlir/Tools/mlir-opt/MlirOptMain.h" 23 #include "llvm/Support/CommandLine.h" 24 #include "llvm/Support/InitLLVM.h" 25 #include "llvm/Support/SourceMgr.h" 26 #include "llvm/Support/ToolOutputFile.h" 27 #include <cstdlib> 28 29 namespace { 30 31 using namespace llvm; 32 33 /// Structure containing command line options for the tool, these will get 34 /// initialized when an instance is created. 35 struct MlirTransformOptCLOptions { 36 cl::opt<bool> allowUnregisteredDialects{ 37 "allow-unregistered-dialect", 38 cl::desc("Allow operations coming from an unregistered dialect"), 39 cl::init(false)}; 40 41 cl::opt<bool> verifyDiagnostics{ 42 "verify-diagnostics", 43 cl::desc("Check that emitted diagnostics match expected-* lines " 44 "on the corresponding line"), 45 cl::init(false)}; 46 47 cl::opt<std::string> payloadFilename{cl::Positional, cl::desc("<input file>"), 48 cl::init("-")}; 49 50 cl::opt<std::string> outputFilename{"o", cl::desc("Output filename"), 51 cl::value_desc("filename"), 52 cl::init("-")}; 53 54 cl::opt<std::string> transformMainFilename{ 55 "transform", 56 cl::desc("File containing entry point of the transform script, if " 57 "different from the input file"), 58 cl::value_desc("filename"), cl::init("")}; 59 60 cl::list<std::string> transformLibraryFilenames{ 61 "transform-library", cl::desc("File(s) containing definitions of " 62 "additional transform script symbols")}; 63 64 cl::opt<std::string> transformEntryPoint{ 65 "transform-entry-point", 66 cl::desc("Name of the entry point transform symbol"), 67 cl::init(mlir::transform::TransformDialect::kTransformEntryPointSymbolName 68 .str())}; 69 70 cl::opt<bool> disableExpensiveChecks{ 71 "disable-expensive-checks", 72 cl::desc("Disables potentially expensive checks in the transform " 73 "interpreter, providing more speed at the expense of " 74 "potential memory problems and silent corruptions"), 75 cl::init(false)}; 76 77 cl::opt<bool> dumpLibraryModule{ 78 "dump-library-module", 79 cl::desc("Prints the combined library module before the output"), 80 cl::init(false)}; 81 }; 82 } // namespace 83 84 /// "Managed" static instance of the command-line options structure. This makes 85 /// them locally-scoped and explicitly initialized/deinitialized. While this is 86 /// not strictly necessary in the tool source file that is not being used as a 87 /// library (where the options would pollute the global list of options), it is 88 /// good practice to follow this. 89 static llvm::ManagedStatic<MlirTransformOptCLOptions> clOptions; 90 91 /// Explicitly registers command-line options. 92 static void registerCLOptions() { *clOptions; } 93 94 namespace { 95 /// A wrapper class for source managers diagnostic. This provides both unique 96 /// ownership and virtual function-like overload for a pair of 97 /// inheritance-related classes that do not use virtual functions. 98 class DiagnosticHandlerWrapper { 99 public: 100 /// Kind of the diagnostic handler to use. 101 enum class Kind { EmitDiagnostics, VerifyDiagnostics }; 102 103 /// Constructs the diagnostic handler of the specified kind of the given 104 /// source manager and context. 105 DiagnosticHandlerWrapper(Kind kind, llvm::SourceMgr &mgr, 106 mlir::MLIRContext *context) { 107 if (kind == Kind::EmitDiagnostics) 108 handler = new mlir::SourceMgrDiagnosticHandler(mgr, context); 109 else 110 handler = new mlir::SourceMgrDiagnosticVerifierHandler(mgr, context); 111 } 112 113 /// This object is non-copyable but movable. 114 DiagnosticHandlerWrapper(const DiagnosticHandlerWrapper &) = delete; 115 DiagnosticHandlerWrapper(DiagnosticHandlerWrapper &&other) = default; 116 DiagnosticHandlerWrapper & 117 operator=(const DiagnosticHandlerWrapper &) = delete; 118 DiagnosticHandlerWrapper &operator=(DiagnosticHandlerWrapper &&) = default; 119 120 /// Verifies the captured "expected-*" diagnostics if required. 121 llvm::LogicalResult verify() const { 122 if (auto *ptr = 123 handler.dyn_cast<mlir::SourceMgrDiagnosticVerifierHandler *>()) { 124 return ptr->verify(); 125 } 126 return mlir::success(); 127 } 128 129 /// Destructs the object of the same type as allocated. 130 ~DiagnosticHandlerWrapper() { 131 if (auto *ptr = handler.dyn_cast<mlir::SourceMgrDiagnosticHandler *>()) { 132 delete ptr; 133 } else { 134 delete cast<mlir::SourceMgrDiagnosticVerifierHandler *>(handler); 135 } 136 } 137 138 private: 139 /// Internal storage is a type-safe union. 140 llvm::PointerUnion<mlir::SourceMgrDiagnosticHandler *, 141 mlir::SourceMgrDiagnosticVerifierHandler *> 142 handler; 143 }; 144 145 /// MLIR has deeply rooted expectations that the LLVM source manager contains 146 /// exactly one buffer, until at least the lexer level. This class wraps 147 /// multiple LLVM source managers each managing a buffer to match MLIR's 148 /// expectations while still providing a centralized handling mechanism. 149 class TransformSourceMgr { 150 public: 151 /// Constructs the source manager indicating whether diagnostic messages will 152 /// be verified later on. 153 explicit TransformSourceMgr(bool verifyDiagnostics) 154 : verifyDiagnostics(verifyDiagnostics) {} 155 156 /// Deconstructs the source manager. Note that `checkResults` must have been 157 /// called on this instance before deconstructing it. 158 ~TransformSourceMgr() { 159 assert(resultChecked && "must check the result of diagnostic handlers by " 160 "running TransformSourceMgr::checkResult"); 161 } 162 163 /// Parses the given buffer and creates the top-level operation of the kind 164 /// specified as template argument in the given context. Additional parsing 165 /// options may be provided. 166 template <typename OpTy = mlir::Operation *> 167 mlir::OwningOpRef<OpTy> parseBuffer(std::unique_ptr<MemoryBuffer> buffer, 168 mlir::MLIRContext &context, 169 const mlir::ParserConfig &config) { 170 // Create a single-buffer LLVM source manager. Note that `unique_ptr` allows 171 // the code below to capture a reference to the source manager in such a way 172 // that it is not invalidated when the vector contents is eventually 173 // reallocated. 174 llvm::SourceMgr &mgr = 175 *sourceMgrs.emplace_back(std::make_unique<llvm::SourceMgr>()); 176 mgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); 177 178 // Choose the type of diagnostic handler depending on whether diagnostic 179 // verification needs to happen and store it. 180 if (verifyDiagnostics) { 181 diagHandlers.emplace_back( 182 DiagnosticHandlerWrapper::Kind::VerifyDiagnostics, mgr, &context); 183 } else { 184 diagHandlers.emplace_back(DiagnosticHandlerWrapper::Kind::EmitDiagnostics, 185 mgr, &context); 186 } 187 188 // Defer to MLIR's parser. 189 return mlir::parseSourceFile<OpTy>(mgr, config); 190 } 191 192 /// If diagnostic message verification has been requested upon construction of 193 /// this source manager, performs the verification, reports errors and returns 194 /// the result of the verification. Otherwise passes through the given value. 195 llvm::LogicalResult checkResult(llvm::LogicalResult result) { 196 resultChecked = true; 197 if (!verifyDiagnostics) 198 return result; 199 200 return mlir::failure(llvm::any_of(diagHandlers, [](const auto &handler) { 201 return mlir::failed(handler.verify()); 202 })); 203 } 204 205 private: 206 /// Indicates whether diagnostic message verification is requested. 207 const bool verifyDiagnostics; 208 209 /// Indicates that diagnostic message verification has taken place, and the 210 /// deconstruction is therefore safe. 211 bool resultChecked = false; 212 213 /// Storage for per-buffer source managers and diagnostic handlers. These are 214 /// wrapped into unique pointers in order to make it safe to capture 215 /// references to these objects: if the vector is reallocated, the unique 216 /// pointer objects are moved by the pointer addresses won't change. Also, for 217 /// handlers, this allows to store the pointer to the base class. 218 SmallVector<std::unique_ptr<llvm::SourceMgr>> sourceMgrs; 219 SmallVector<DiagnosticHandlerWrapper> diagHandlers; 220 }; 221 } // namespace 222 223 /// Trivial wrapper around `applyTransforms` that doesn't support extra mapping 224 /// and doesn't enforce the entry point transform ops being top-level. 225 static llvm::LogicalResult 226 applyTransforms(mlir::Operation *payloadRoot, 227 mlir::transform::TransformOpInterface transformRoot, 228 const mlir::transform::TransformOptions &options) { 229 return applyTransforms(payloadRoot, transformRoot, {}, options, 230 /*enforceToplevelTransformOp=*/false); 231 } 232 233 /// Applies transforms indicated in the transform dialect script to the input 234 /// buffer. The transform script may be embedded in the input buffer or as a 235 /// separate buffer. The transform script may have external symbols, the 236 /// definitions of which must be provided in transform library buffers. If the 237 /// application is successful, prints the transformed input buffer into the 238 /// given output stream. Additional configuration options are derived from 239 /// command-line options. 240 static llvm::LogicalResult processPayloadBuffer( 241 raw_ostream &os, std::unique_ptr<MemoryBuffer> inputBuffer, 242 std::unique_ptr<llvm::MemoryBuffer> transformBuffer, 243 MutableArrayRef<std::unique_ptr<MemoryBuffer>> transformLibraries, 244 mlir::DialectRegistry ®istry) { 245 246 // Initialize the MLIR context, and various configurations. 247 mlir::MLIRContext context(registry, mlir::MLIRContext::Threading::DISABLED); 248 context.allowUnregisteredDialects(clOptions->allowUnregisteredDialects); 249 mlir::ParserConfig config(&context); 250 TransformSourceMgr sourceMgr( 251 /*verifyDiagnostics=*/clOptions->verifyDiagnostics); 252 253 // Parse the input buffer that will be used as transform payload. 254 mlir::OwningOpRef<mlir::Operation *> payloadRoot = 255 sourceMgr.parseBuffer(std::move(inputBuffer), context, config); 256 if (!payloadRoot) 257 return sourceMgr.checkResult(mlir::failure()); 258 259 // Identify the module containing the transform script entry point. This may 260 // be the same module as the input or a separate module. In the former case, 261 // make a copy of the module so it can be modified freely. Modification may 262 // happen in the script itself (at which point it could be rewriting itself 263 // during interpretation, leading to tricky memory errors) or by embedding 264 // library modules in the script. 265 mlir::OwningOpRef<mlir::ModuleOp> transformRoot; 266 if (transformBuffer) { 267 transformRoot = sourceMgr.parseBuffer<mlir::ModuleOp>( 268 std::move(transformBuffer), context, config); 269 if (!transformRoot) 270 return sourceMgr.checkResult(mlir::failure()); 271 } else { 272 transformRoot = cast<mlir::ModuleOp>(payloadRoot->clone()); 273 } 274 275 // Parse and merge the libraries into the main transform module. 276 for (auto &&transformLibrary : transformLibraries) { 277 mlir::OwningOpRef<mlir::ModuleOp> libraryModule = 278 sourceMgr.parseBuffer<mlir::ModuleOp>(std::move(transformLibrary), 279 context, config); 280 281 if (!libraryModule || 282 mlir::failed(mlir::transform::detail::mergeSymbolsInto( 283 *transformRoot, std::move(libraryModule)))) 284 return sourceMgr.checkResult(mlir::failure()); 285 } 286 287 // If requested, dump the combined transform module. 288 if (clOptions->dumpLibraryModule) 289 transformRoot->dump(); 290 291 // Find the entry point symbol. Even if it had originally been in the payload 292 // module, it was cloned into the transform module so only look there. 293 mlir::transform::TransformOpInterface entryPoint = 294 mlir::transform::detail::findTransformEntryPoint( 295 *transformRoot, mlir::ModuleOp(), clOptions->transformEntryPoint); 296 if (!entryPoint) 297 return sourceMgr.checkResult(mlir::failure()); 298 299 // Apply the requested transformations. 300 mlir::transform::TransformOptions transformOptions; 301 transformOptions.enableExpensiveChecks(!clOptions->disableExpensiveChecks); 302 if (mlir::failed(applyTransforms(*payloadRoot, entryPoint, transformOptions))) 303 return sourceMgr.checkResult(mlir::failure()); 304 305 // Print the transformed result and check the captured diagnostics if 306 // requested. 307 payloadRoot->print(os); 308 return sourceMgr.checkResult(mlir::success()); 309 } 310 311 /// Tool entry point. 312 static llvm::LogicalResult runMain(int argc, char **argv) { 313 // Register all upstream dialects and extensions. Specific uses are advised 314 // not to register all dialects indiscriminately but rather hand-pick what is 315 // necessary for their use case. 316 mlir::DialectRegistry registry; 317 mlir::registerAllDialects(registry); 318 mlir::registerAllExtensions(registry); 319 mlir::registerAllPasses(); 320 321 // Explicitly register the transform dialect. This is not strictly necessary 322 // since it has been already registered as part of the upstream dialect list, 323 // but useful for example purposes for cases when dialects to register are 324 // hand-picked. The transform dialect must be registered. 325 registry.insert<mlir::transform::TransformDialect>(); 326 327 // Register various command-line options. Note that the LLVM initializer 328 // object is a RAII that ensures correct deconstruction of command-line option 329 // objects inside ManagedStatic. 330 llvm::InitLLVM y(argc, argv); 331 mlir::registerAsmPrinterCLOptions(); 332 mlir::registerMLIRContextCLOptions(); 333 registerCLOptions(); 334 llvm::cl::ParseCommandLineOptions(argc, argv, 335 "Minimal Transform dialect driver\n"); 336 337 // Try opening the main input file. 338 std::string errorMessage; 339 std::unique_ptr<llvm::MemoryBuffer> payloadFile = 340 mlir::openInputFile(clOptions->payloadFilename, &errorMessage); 341 if (!payloadFile) { 342 llvm::errs() << errorMessage << "\n"; 343 return mlir::failure(); 344 } 345 346 // Try opening the output file. 347 std::unique_ptr<llvm::ToolOutputFile> outputFile = 348 mlir::openOutputFile(clOptions->outputFilename, &errorMessage); 349 if (!outputFile) { 350 llvm::errs() << errorMessage << "\n"; 351 return mlir::failure(); 352 } 353 354 // Try opening the main transform file if provided. 355 std::unique_ptr<llvm::MemoryBuffer> transformRootFile; 356 if (!clOptions->transformMainFilename.empty()) { 357 if (clOptions->transformMainFilename == clOptions->payloadFilename) { 358 llvm::errs() << "warning: " << clOptions->payloadFilename 359 << " is provided as both payload and transform file\n"; 360 } else { 361 transformRootFile = 362 mlir::openInputFile(clOptions->transformMainFilename, &errorMessage); 363 if (!transformRootFile) { 364 llvm::errs() << errorMessage << "\n"; 365 return mlir::failure(); 366 } 367 } 368 } 369 370 // Try opening transform library files if provided. 371 SmallVector<std::unique_ptr<llvm::MemoryBuffer>> transformLibraries; 372 transformLibraries.reserve(clOptions->transformLibraryFilenames.size()); 373 for (llvm::StringRef filename : clOptions->transformLibraryFilenames) { 374 transformLibraries.emplace_back( 375 mlir::openInputFile(filename, &errorMessage)); 376 if (!transformLibraries.back()) { 377 llvm::errs() << errorMessage << "\n"; 378 return mlir::failure(); 379 } 380 } 381 382 return processPayloadBuffer(outputFile->os(), std::move(payloadFile), 383 std::move(transformRootFile), transformLibraries, 384 registry); 385 } 386 387 int main(int argc, char **argv) { 388 return mlir::asMainReturnCode(runMain(argc, argv)); 389 } 390