xref: /llvm-project/mlir/examples/transform-opt/mlir-transform-opt.cpp (revision 4f4e2abb1a5ff1225d32410fd02b732d077aa056)
1 //===- mlir-transform-opt.cpp -----------------------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
10 #include "mlir/Dialect/Transform/IR/Utils.h"
11 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
12 #include "mlir/IR/AsmState.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/Diagnostics.h"
15 #include "mlir/IR/DialectRegistry.h"
16 #include "mlir/IR/MLIRContext.h"
17 #include "mlir/InitAllDialects.h"
18 #include "mlir/InitAllExtensions.h"
19 #include "mlir/InitAllPasses.h"
20 #include "mlir/Parser/Parser.h"
21 #include "mlir/Support/FileUtilities.h"
22 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/InitLLVM.h"
25 #include "llvm/Support/SourceMgr.h"
26 #include "llvm/Support/ToolOutputFile.h"
27 #include <cstdlib>
28 
29 namespace {
30 
31 using namespace llvm;
32 
33 /// Structure containing command line options for the tool, these will get
34 /// initialized when an instance is created.
35 struct MlirTransformOptCLOptions {
36   cl::opt<bool> allowUnregisteredDialects{
37       "allow-unregistered-dialect",
38       cl::desc("Allow operations coming from an unregistered dialect"),
39       cl::init(false)};
40 
41   cl::opt<bool> verifyDiagnostics{
42       "verify-diagnostics",
43       cl::desc("Check that emitted diagnostics match expected-* lines "
44                "on the corresponding line"),
45       cl::init(false)};
46 
47   cl::opt<std::string> payloadFilename{cl::Positional, cl::desc("<input file>"),
48                                        cl::init("-")};
49 
50   cl::opt<std::string> outputFilename{"o", cl::desc("Output filename"),
51                                       cl::value_desc("filename"),
52                                       cl::init("-")};
53 
54   cl::opt<std::string> transformMainFilename{
55       "transform",
56       cl::desc("File containing entry point of the transform script, if "
57                "different from the input file"),
58       cl::value_desc("filename"), cl::init("")};
59 
60   cl::list<std::string> transformLibraryFilenames{
61       "transform-library", cl::desc("File(s) containing definitions of "
62                                     "additional transform script symbols")};
63 
64   cl::opt<std::string> transformEntryPoint{
65       "transform-entry-point",
66       cl::desc("Name of the entry point transform symbol"),
67       cl::init(mlir::transform::TransformDialect::kTransformEntryPointSymbolName
68                    .str())};
69 
70   cl::opt<bool> disableExpensiveChecks{
71       "disable-expensive-checks",
72       cl::desc("Disables potentially expensive checks in the transform "
73                "interpreter, providing more speed at the expense of "
74                "potential memory problems and silent corruptions"),
75       cl::init(false)};
76 
77   cl::opt<bool> dumpLibraryModule{
78       "dump-library-module",
79       cl::desc("Prints the combined library module before the output"),
80       cl::init(false)};
81 };
82 } // namespace
83 
84 /// "Managed" static instance of the command-line options structure. This makes
85 /// them locally-scoped and explicitly initialized/deinitialized. While this is
86 /// not strictly necessary in the tool source file that is not being used as a
87 /// library (where the options would pollute the global list of options), it is
88 /// good practice to follow this.
89 static llvm::ManagedStatic<MlirTransformOptCLOptions> clOptions;
90 
91 /// Explicitly registers command-line options.
92 static void registerCLOptions() { *clOptions; }
93 
94 namespace {
95 /// A wrapper class for source managers diagnostic. This provides both unique
96 /// ownership and virtual function-like overload for a pair of
97 /// inheritance-related classes that do not use virtual functions.
98 class DiagnosticHandlerWrapper {
99 public:
100   /// Kind of the diagnostic handler to use.
101   enum class Kind { EmitDiagnostics, VerifyDiagnostics };
102 
103   /// Constructs the diagnostic handler of the specified kind of the given
104   /// source manager and context.
105   DiagnosticHandlerWrapper(Kind kind, llvm::SourceMgr &mgr,
106                            mlir::MLIRContext *context) {
107     if (kind == Kind::EmitDiagnostics)
108       handler = new mlir::SourceMgrDiagnosticHandler(mgr, context);
109     else
110       handler = new mlir::SourceMgrDiagnosticVerifierHandler(mgr, context);
111   }
112 
113   /// This object is non-copyable but movable.
114   DiagnosticHandlerWrapper(const DiagnosticHandlerWrapper &) = delete;
115   DiagnosticHandlerWrapper(DiagnosticHandlerWrapper &&other) = default;
116   DiagnosticHandlerWrapper &
117   operator=(const DiagnosticHandlerWrapper &) = delete;
118   DiagnosticHandlerWrapper &operator=(DiagnosticHandlerWrapper &&) = default;
119 
120   /// Verifies the captured "expected-*" diagnostics if required.
121   llvm::LogicalResult verify() const {
122     if (auto *ptr =
123             handler.dyn_cast<mlir::SourceMgrDiagnosticVerifierHandler *>()) {
124       return ptr->verify();
125     }
126     return mlir::success();
127   }
128 
129   /// Destructs the object of the same type as allocated.
130   ~DiagnosticHandlerWrapper() {
131     if (auto *ptr = handler.dyn_cast<mlir::SourceMgrDiagnosticHandler *>()) {
132       delete ptr;
133     } else {
134       delete cast<mlir::SourceMgrDiagnosticVerifierHandler *>(handler);
135     }
136   }
137 
138 private:
139   /// Internal storage is a type-safe union.
140   llvm::PointerUnion<mlir::SourceMgrDiagnosticHandler *,
141                      mlir::SourceMgrDiagnosticVerifierHandler *>
142       handler;
143 };
144 
145 /// MLIR has deeply rooted expectations that the LLVM source manager contains
146 /// exactly one buffer, until at least the lexer level. This class wraps
147 /// multiple LLVM source managers each managing a buffer to match MLIR's
148 /// expectations while still providing a centralized handling mechanism.
149 class TransformSourceMgr {
150 public:
151   /// Constructs the source manager indicating whether diagnostic messages will
152   /// be verified later on.
153   explicit TransformSourceMgr(bool verifyDiagnostics)
154       : verifyDiagnostics(verifyDiagnostics) {}
155 
156   /// Deconstructs the source manager. Note that `checkResults` must have been
157   /// called on this instance before deconstructing it.
158   ~TransformSourceMgr() {
159     assert(resultChecked && "must check the result of diagnostic handlers by "
160                             "running TransformSourceMgr::checkResult");
161   }
162 
163   /// Parses the given buffer and creates the top-level operation of the kind
164   /// specified as template argument in the given context. Additional parsing
165   /// options may be provided.
166   template <typename OpTy = mlir::Operation *>
167   mlir::OwningOpRef<OpTy> parseBuffer(std::unique_ptr<MemoryBuffer> buffer,
168                                       mlir::MLIRContext &context,
169                                       const mlir::ParserConfig &config) {
170     // Create a single-buffer LLVM source manager. Note that `unique_ptr` allows
171     // the code below to capture a reference to the source manager in such a way
172     // that it is not invalidated when the vector contents is eventually
173     // reallocated.
174     llvm::SourceMgr &mgr =
175         *sourceMgrs.emplace_back(std::make_unique<llvm::SourceMgr>());
176     mgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
177 
178     // Choose the type of diagnostic handler depending on whether diagnostic
179     // verification needs to happen and store it.
180     if (verifyDiagnostics) {
181       diagHandlers.emplace_back(
182           DiagnosticHandlerWrapper::Kind::VerifyDiagnostics, mgr, &context);
183     } else {
184       diagHandlers.emplace_back(DiagnosticHandlerWrapper::Kind::EmitDiagnostics,
185                                 mgr, &context);
186     }
187 
188     // Defer to MLIR's parser.
189     return mlir::parseSourceFile<OpTy>(mgr, config);
190   }
191 
192   /// If diagnostic message verification has been requested upon construction of
193   /// this source manager, performs the verification, reports errors and returns
194   /// the result of the verification. Otherwise passes through the given value.
195   llvm::LogicalResult checkResult(llvm::LogicalResult result) {
196     resultChecked = true;
197     if (!verifyDiagnostics)
198       return result;
199 
200     return mlir::failure(llvm::any_of(diagHandlers, [](const auto &handler) {
201       return mlir::failed(handler.verify());
202     }));
203   }
204 
205 private:
206   /// Indicates whether diagnostic message verification is requested.
207   const bool verifyDiagnostics;
208 
209   /// Indicates that diagnostic message verification has taken place, and the
210   /// deconstruction is therefore safe.
211   bool resultChecked = false;
212 
213   /// Storage for per-buffer source managers and diagnostic handlers. These are
214   /// wrapped into unique pointers in order to make it safe to capture
215   /// references to these objects: if the vector is reallocated, the unique
216   /// pointer objects are moved by the pointer addresses won't change. Also, for
217   /// handlers, this allows to store the pointer to the base class.
218   SmallVector<std::unique_ptr<llvm::SourceMgr>> sourceMgrs;
219   SmallVector<DiagnosticHandlerWrapper> diagHandlers;
220 };
221 } // namespace
222 
223 /// Trivial wrapper around `applyTransforms` that doesn't support extra mapping
224 /// and doesn't enforce the entry point transform ops being top-level.
225 static llvm::LogicalResult
226 applyTransforms(mlir::Operation *payloadRoot,
227                 mlir::transform::TransformOpInterface transformRoot,
228                 const mlir::transform::TransformOptions &options) {
229   return applyTransforms(payloadRoot, transformRoot, {}, options,
230                          /*enforceToplevelTransformOp=*/false);
231 }
232 
233 /// Applies transforms indicated in the transform dialect script to the input
234 /// buffer. The transform script may be embedded in the input buffer or as a
235 /// separate buffer. The transform script may have external symbols, the
236 /// definitions of which must be provided in transform library buffers. If the
237 /// application is successful, prints the transformed input buffer into the
238 /// given output stream. Additional configuration options are derived from
239 /// command-line options.
240 static llvm::LogicalResult processPayloadBuffer(
241     raw_ostream &os, std::unique_ptr<MemoryBuffer> inputBuffer,
242     std::unique_ptr<llvm::MemoryBuffer> transformBuffer,
243     MutableArrayRef<std::unique_ptr<MemoryBuffer>> transformLibraries,
244     mlir::DialectRegistry &registry) {
245 
246   // Initialize the MLIR context, and various configurations.
247   mlir::MLIRContext context(registry, mlir::MLIRContext::Threading::DISABLED);
248   context.allowUnregisteredDialects(clOptions->allowUnregisteredDialects);
249   mlir::ParserConfig config(&context);
250   TransformSourceMgr sourceMgr(
251       /*verifyDiagnostics=*/clOptions->verifyDiagnostics);
252 
253   // Parse the input buffer that will be used as transform payload.
254   mlir::OwningOpRef<mlir::Operation *> payloadRoot =
255       sourceMgr.parseBuffer(std::move(inputBuffer), context, config);
256   if (!payloadRoot)
257     return sourceMgr.checkResult(mlir::failure());
258 
259   // Identify the module containing the transform script entry point. This may
260   // be the same module as the input or a separate module. In the former case,
261   // make a copy of the module so it can be modified freely. Modification may
262   // happen in the script itself (at which point it could be rewriting itself
263   // during interpretation, leading to tricky memory errors) or by embedding
264   // library modules in the script.
265   mlir::OwningOpRef<mlir::ModuleOp> transformRoot;
266   if (transformBuffer) {
267     transformRoot = sourceMgr.parseBuffer<mlir::ModuleOp>(
268         std::move(transformBuffer), context, config);
269     if (!transformRoot)
270       return sourceMgr.checkResult(mlir::failure());
271   } else {
272     transformRoot = cast<mlir::ModuleOp>(payloadRoot->clone());
273   }
274 
275   // Parse and merge the libraries into the main transform module.
276   for (auto &&transformLibrary : transformLibraries) {
277     mlir::OwningOpRef<mlir::ModuleOp> libraryModule =
278         sourceMgr.parseBuffer<mlir::ModuleOp>(std::move(transformLibrary),
279                                               context, config);
280 
281     if (!libraryModule ||
282         mlir::failed(mlir::transform::detail::mergeSymbolsInto(
283             *transformRoot, std::move(libraryModule))))
284       return sourceMgr.checkResult(mlir::failure());
285   }
286 
287   // If requested, dump the combined transform module.
288   if (clOptions->dumpLibraryModule)
289     transformRoot->dump();
290 
291   // Find the entry point symbol. Even if it had originally been in the payload
292   // module, it was cloned into the transform module so only look there.
293   mlir::transform::TransformOpInterface entryPoint =
294       mlir::transform::detail::findTransformEntryPoint(
295           *transformRoot, mlir::ModuleOp(), clOptions->transformEntryPoint);
296   if (!entryPoint)
297     return sourceMgr.checkResult(mlir::failure());
298 
299   // Apply the requested transformations.
300   mlir::transform::TransformOptions transformOptions;
301   transformOptions.enableExpensiveChecks(!clOptions->disableExpensiveChecks);
302   if (mlir::failed(applyTransforms(*payloadRoot, entryPoint, transformOptions)))
303     return sourceMgr.checkResult(mlir::failure());
304 
305   // Print the transformed result and check the captured diagnostics if
306   // requested.
307   payloadRoot->print(os);
308   return sourceMgr.checkResult(mlir::success());
309 }
310 
311 /// Tool entry point.
312 static llvm::LogicalResult runMain(int argc, char **argv) {
313   // Register all upstream dialects and extensions. Specific uses are advised
314   // not to register all dialects indiscriminately but rather hand-pick what is
315   // necessary for their use case.
316   mlir::DialectRegistry registry;
317   mlir::registerAllDialects(registry);
318   mlir::registerAllExtensions(registry);
319   mlir::registerAllPasses();
320 
321   // Explicitly register the transform dialect. This is not strictly necessary
322   // since it has been already registered as part of the upstream dialect list,
323   // but useful for example purposes for cases when dialects to register are
324   // hand-picked. The transform dialect must be registered.
325   registry.insert<mlir::transform::TransformDialect>();
326 
327   // Register various command-line options. Note that the LLVM initializer
328   // object is a RAII that ensures correct deconstruction of command-line option
329   // objects inside ManagedStatic.
330   llvm::InitLLVM y(argc, argv);
331   mlir::registerAsmPrinterCLOptions();
332   mlir::registerMLIRContextCLOptions();
333   registerCLOptions();
334   llvm::cl::ParseCommandLineOptions(argc, argv,
335                                     "Minimal Transform dialect driver\n");
336 
337   // Try opening the main input file.
338   std::string errorMessage;
339   std::unique_ptr<llvm::MemoryBuffer> payloadFile =
340       mlir::openInputFile(clOptions->payloadFilename, &errorMessage);
341   if (!payloadFile) {
342     llvm::errs() << errorMessage << "\n";
343     return mlir::failure();
344   }
345 
346   // Try opening the output file.
347   std::unique_ptr<llvm::ToolOutputFile> outputFile =
348       mlir::openOutputFile(clOptions->outputFilename, &errorMessage);
349   if (!outputFile) {
350     llvm::errs() << errorMessage << "\n";
351     return mlir::failure();
352   }
353 
354   // Try opening the main transform file if provided.
355   std::unique_ptr<llvm::MemoryBuffer> transformRootFile;
356   if (!clOptions->transformMainFilename.empty()) {
357     if (clOptions->transformMainFilename == clOptions->payloadFilename) {
358       llvm::errs() << "warning: " << clOptions->payloadFilename
359                    << " is provided as both payload and transform file\n";
360     } else {
361       transformRootFile =
362           mlir::openInputFile(clOptions->transformMainFilename, &errorMessage);
363       if (!transformRootFile) {
364         llvm::errs() << errorMessage << "\n";
365         return mlir::failure();
366       }
367     }
368   }
369 
370   // Try opening transform library files if provided.
371   SmallVector<std::unique_ptr<llvm::MemoryBuffer>> transformLibraries;
372   transformLibraries.reserve(clOptions->transformLibraryFilenames.size());
373   for (llvm::StringRef filename : clOptions->transformLibraryFilenames) {
374     transformLibraries.emplace_back(
375         mlir::openInputFile(filename, &errorMessage));
376     if (!transformLibraries.back()) {
377       llvm::errs() << errorMessage << "\n";
378       return mlir::failure();
379     }
380   }
381 
382   return processPayloadBuffer(outputFile->os(), std::move(payloadFile),
383                               std::move(transformRootFile), transformLibraries,
384                               registry);
385 }
386 
387 int main(int argc, char **argv) {
388   return mlir::asMainReturnCode(runMain(argc, argv));
389 }
390