xref: /llvm-project/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp (revision 5a9bdd85ee4d8527e2cedf44f3ce26ff414f9b6a)
1ef8c26b7SNicolas Vasilache //===- TransformInterpreterUtils.cpp --------------------------------------===//
2ef8c26b7SNicolas Vasilache //
3ef8c26b7SNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ef8c26b7SNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information.
5ef8c26b7SNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ef8c26b7SNicolas Vasilache //
7ef8c26b7SNicolas Vasilache //===----------------------------------------------------------------------===//
8ef8c26b7SNicolas Vasilache //
9ef8c26b7SNicolas Vasilache // Lightweight transform dialect interpreter utilities.
10ef8c26b7SNicolas Vasilache //
11ef8c26b7SNicolas Vasilache //===----------------------------------------------------------------------===//
12ef8c26b7SNicolas Vasilache 
13ef8c26b7SNicolas Vasilache #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
14ef8c26b7SNicolas Vasilache #include "mlir/Dialect/Transform/IR/TransformDialect.h"
15ef8c26b7SNicolas Vasilache #include "mlir/Dialect/Transform/IR/TransformOps.h"
1699c15eb4SIngo Müller #include "mlir/Dialect/Transform/IR/Utils.h"
17*5a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18ef8c26b7SNicolas Vasilache #include "mlir/IR/BuiltinOps.h"
19ef8c26b7SNicolas Vasilache #include "mlir/IR/Verifier.h"
20ef8c26b7SNicolas Vasilache #include "mlir/IR/Visitors.h"
21ef8c26b7SNicolas Vasilache #include "mlir/Interfaces/FunctionInterfaces.h"
22ef8c26b7SNicolas Vasilache #include "mlir/Parser/Parser.h"
23ef8c26b7SNicolas Vasilache #include "mlir/Support/FileUtilities.h"
24ef8c26b7SNicolas Vasilache #include "llvm/ADT/StringRef.h"
25ef8c26b7SNicolas Vasilache #include "llvm/Support/Casting.h"
26ef8c26b7SNicolas Vasilache #include "llvm/Support/Debug.h"
271bf08709SNicolas Vasilache #include "llvm/Support/FileSystem.h"
28ef8c26b7SNicolas Vasilache #include "llvm/Support/SourceMgr.h"
29ef8c26b7SNicolas Vasilache #include "llvm/Support/raw_ostream.h"
30ef8c26b7SNicolas Vasilache 
31ef8c26b7SNicolas Vasilache using namespace mlir;
32ef8c26b7SNicolas Vasilache 
33ef8c26b7SNicolas Vasilache #define DEBUG_TYPE "transform-dialect-interpreter-utils"
34ef8c26b7SNicolas Vasilache #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
35ef8c26b7SNicolas Vasilache 
361bf08709SNicolas Vasilache /// Expands the given list of `paths` to a list of `.mlir` files.
371bf08709SNicolas Vasilache ///
381bf08709SNicolas Vasilache /// Each entry in `paths` may either be a regular file, in which case it ends up
391bf08709SNicolas Vasilache /// in the result list, or a directory, in which case all (regular) `.mlir`
401bf08709SNicolas Vasilache /// files in that directory are added. Any other file types lead to a failure.
expandPathsToMLIRFiles(ArrayRef<std::string> paths,MLIRContext * context,SmallVectorImpl<std::string> & fileNames)411bf08709SNicolas Vasilache LogicalResult transform::detail::expandPathsToMLIRFiles(
4299c15eb4SIngo Müller     ArrayRef<std::string> paths, MLIRContext *context,
431bf08709SNicolas Vasilache     SmallVectorImpl<std::string> &fileNames) {
441bf08709SNicolas Vasilache   for (const std::string &path : paths) {
451bf08709SNicolas Vasilache     auto loc = FileLineColLoc::get(context, path, 0, 0);
461bf08709SNicolas Vasilache 
471bf08709SNicolas Vasilache     if (llvm::sys::fs::is_regular_file(path)) {
481bf08709SNicolas Vasilache       LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
491bf08709SNicolas Vasilache       fileNames.push_back(path);
501bf08709SNicolas Vasilache       continue;
511bf08709SNicolas Vasilache     }
521bf08709SNicolas Vasilache 
531bf08709SNicolas Vasilache     if (!llvm::sys::fs::is_directory(path)) {
541bf08709SNicolas Vasilache       return emitError(loc)
551bf08709SNicolas Vasilache              << "'" << path << "' is neither a file nor a directory";
561bf08709SNicolas Vasilache     }
571bf08709SNicolas Vasilache 
581bf08709SNicolas Vasilache     LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
591bf08709SNicolas Vasilache 
601bf08709SNicolas Vasilache     std::error_code ec;
611bf08709SNicolas Vasilache     for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
621bf08709SNicolas Vasilache          it != itEnd && !ec; it.increment(ec)) {
631bf08709SNicolas Vasilache       const std::string &fileName = it->path();
641bf08709SNicolas Vasilache 
653517b67eSIngo Müller       if (it->type() != llvm::sys::fs::file_type::regular_file &&
663517b67eSIngo Müller           it->type() != llvm::sys::fs::file_type::symlink_file) {
671bf08709SNicolas Vasilache         LLVM_DEBUG(DBGS() << "  Skipping non-regular file '" << fileName
681bf08709SNicolas Vasilache                           << "'\n");
691bf08709SNicolas Vasilache         continue;
701bf08709SNicolas Vasilache       }
711bf08709SNicolas Vasilache 
7288d319a2SKazu Hirata       if (!StringRef(fileName).ends_with(".mlir")) {
731bf08709SNicolas Vasilache         LLVM_DEBUG(DBGS() << "  Skipping '" << fileName
741bf08709SNicolas Vasilache                           << "' because it does not end with '.mlir'\n");
751bf08709SNicolas Vasilache         continue;
761bf08709SNicolas Vasilache       }
771bf08709SNicolas Vasilache 
781bf08709SNicolas Vasilache       LLVM_DEBUG(DBGS() << "  Adding '" << fileName << "' to list of files\n");
791bf08709SNicolas Vasilache       fileNames.push_back(fileName);
801bf08709SNicolas Vasilache     }
811bf08709SNicolas Vasilache 
821bf08709SNicolas Vasilache     if (ec)
831bf08709SNicolas Vasilache       return emitError(loc) << "error while opening files in '" << path
841bf08709SNicolas Vasilache                             << "': " << ec.message();
851bf08709SNicolas Vasilache   }
861bf08709SNicolas Vasilache 
871bf08709SNicolas Vasilache   return success();
881bf08709SNicolas Vasilache }
891bf08709SNicolas Vasilache 
parseTransformModuleFromFile(MLIRContext * context,llvm::StringRef transformFileName,OwningOpRef<ModuleOp> & transformModule)90ef8c26b7SNicolas Vasilache LogicalResult transform::detail::parseTransformModuleFromFile(
91ef8c26b7SNicolas Vasilache     MLIRContext *context, llvm::StringRef transformFileName,
92ef8c26b7SNicolas Vasilache     OwningOpRef<ModuleOp> &transformModule) {
93ef8c26b7SNicolas Vasilache   if (transformFileName.empty()) {
94ef8c26b7SNicolas Vasilache     LLVM_DEBUG(
95ef8c26b7SNicolas Vasilache         DBGS() << "no transform file name specified, assuming the transform "
96ef8c26b7SNicolas Vasilache                   "module is embedded in the IR next to the top-level\n");
97ef8c26b7SNicolas Vasilache     return success();
98ef8c26b7SNicolas Vasilache   }
99ef8c26b7SNicolas Vasilache   // Parse transformFileName content into a ModuleOp.
100ef8c26b7SNicolas Vasilache   std::string errorMessage;
101ef8c26b7SNicolas Vasilache   auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
102ef8c26b7SNicolas Vasilache   if (!memoryBuffer) {
103ef8c26b7SNicolas Vasilache     return emitError(FileLineColLoc::get(
104ef8c26b7SNicolas Vasilache                StringAttr::get(context, transformFileName), 0, 0))
105ef8c26b7SNicolas Vasilache            << "failed to open transform file: " << errorMessage;
106ef8c26b7SNicolas Vasilache   }
107ef8c26b7SNicolas Vasilache   // Tell sourceMgr about this buffer, the parser will pick it up.
108ef8c26b7SNicolas Vasilache   llvm::SourceMgr sourceMgr;
109ef8c26b7SNicolas Vasilache   sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
110ef8c26b7SNicolas Vasilache   transformModule =
111ef8c26b7SNicolas Vasilache       OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
112282d5014SKunwar Grover   if (!transformModule) {
113282d5014SKunwar Grover     // Failed to parse the transform module.
114282d5014SKunwar Grover     // Don't need to emit an error here as the parsing should have already done
115282d5014SKunwar Grover     // that.
116282d5014SKunwar Grover     return failure();
117282d5014SKunwar Grover   }
118ef8c26b7SNicolas Vasilache   return mlir::verify(*transformModule);
119ef8c26b7SNicolas Vasilache }
120ef8c26b7SNicolas Vasilache 
getPreloadedTransformModule(MLIRContext * context)121ef8c26b7SNicolas Vasilache ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
12299c15eb4SIngo Müller   return context->getOrLoadDialect<transform::TransformDialect>()
12399c15eb4SIngo Müller       ->getLibraryModule();
124ef8c26b7SNicolas Vasilache }
125ef8c26b7SNicolas Vasilache 
126ef8c26b7SNicolas Vasilache transform::TransformOpInterface
findTransformEntryPoint(Operation * root,ModuleOp module,StringRef entryPoint)127ef8c26b7SNicolas Vasilache transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
128ef8c26b7SNicolas Vasilache                                            StringRef entryPoint) {
129ef8c26b7SNicolas Vasilache   SmallVector<Operation *, 2> l{root};
130ef8c26b7SNicolas Vasilache   if (module)
131ef8c26b7SNicolas Vasilache     l.push_back(module);
132ef8c26b7SNicolas Vasilache   for (Operation *op : l) {
133ef8c26b7SNicolas Vasilache     transform::TransformOpInterface transform = nullptr;
134ef8c26b7SNicolas Vasilache     op->walk<WalkOrder::PreOrder>(
135ef8c26b7SNicolas Vasilache         [&](transform::NamedSequenceOp namedSequenceOp) {
136ef8c26b7SNicolas Vasilache           if (namedSequenceOp.getSymName() == entryPoint) {
137ef8c26b7SNicolas Vasilache             transform = cast<transform::TransformOpInterface>(
138ef8c26b7SNicolas Vasilache                 namedSequenceOp.getOperation());
139ef8c26b7SNicolas Vasilache             return WalkResult::interrupt();
140ef8c26b7SNicolas Vasilache           }
141ef8c26b7SNicolas Vasilache           return WalkResult::advance();
142ef8c26b7SNicolas Vasilache         });
143ef8c26b7SNicolas Vasilache     if (transform)
144ef8c26b7SNicolas Vasilache       return transform;
145ef8c26b7SNicolas Vasilache   }
146ef8c26b7SNicolas Vasilache   auto diag = root->emitError()
147ef8c26b7SNicolas Vasilache               << "could not find a nested named sequence with name: "
148ef8c26b7SNicolas Vasilache               << entryPoint;
149ef8c26b7SNicolas Vasilache   return nullptr;
150ef8c26b7SNicolas Vasilache }
151ef8c26b7SNicolas Vasilache 
assembleTransformLibraryFromPaths(MLIRContext * context,ArrayRef<std::string> transformLibraryPaths,OwningOpRef<ModuleOp> & transformModule)1521bf08709SNicolas Vasilache LogicalResult transform::detail::assembleTransformLibraryFromPaths(
1531bf08709SNicolas Vasilache     MLIRContext *context, ArrayRef<std::string> transformLibraryPaths,
1541bf08709SNicolas Vasilache     OwningOpRef<ModuleOp> &transformModule) {
1551bf08709SNicolas Vasilache   // Assemble list of library files.
1561bf08709SNicolas Vasilache   SmallVector<std::string> libraryFileNames;
1571bf08709SNicolas Vasilache   if (failed(detail::expandPathsToMLIRFiles(transformLibraryPaths, context,
1581bf08709SNicolas Vasilache                                             libraryFileNames)))
1591bf08709SNicolas Vasilache     return failure();
1601bf08709SNicolas Vasilache 
1611bf08709SNicolas Vasilache   // Parse modules from library files.
1621bf08709SNicolas Vasilache   SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
1631bf08709SNicolas Vasilache   for (const std::string &libraryFileName : libraryFileNames) {
1641bf08709SNicolas Vasilache     OwningOpRef<ModuleOp> parsedLibrary;
1651bf08709SNicolas Vasilache     if (failed(transform::detail::parseTransformModuleFromFile(
1661bf08709SNicolas Vasilache             context, libraryFileName, parsedLibrary)))
1671bf08709SNicolas Vasilache       return failure();
1681bf08709SNicolas Vasilache     parsedLibraries.push_back(std::move(parsedLibrary));
1691bf08709SNicolas Vasilache   }
1701bf08709SNicolas Vasilache 
1711bf08709SNicolas Vasilache   // Merge parsed libraries into one module.
1721bf08709SNicolas Vasilache   auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
1731bf08709SNicolas Vasilache   OwningOpRef<ModuleOp> mergedParsedLibraries =
1741bf08709SNicolas Vasilache       ModuleOp::create(loc, "__transform");
1751bf08709SNicolas Vasilache   {
1761bf08709SNicolas Vasilache     mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
1771bf08709SNicolas Vasilache                                          UnitAttr::get(context));
1781bf08709SNicolas Vasilache     // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
1791bf08709SNicolas Vasilache     for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
1801bf08709SNicolas Vasilache       if (failed(transform::detail::mergeSymbolsInto(
1811bf08709SNicolas Vasilache               mergedParsedLibraries.get(), std::move(parsedLibrary))))
182f07718b7SIngo Müller         return parsedLibrary->emitError()
183f07718b7SIngo Müller                << "failed to merge symbols into shared library module";
1841bf08709SNicolas Vasilache     }
1851bf08709SNicolas Vasilache   }
1861bf08709SNicolas Vasilache 
1871bf08709SNicolas Vasilache   transformModule = std::move(mergedParsedLibraries);
1881bf08709SNicolas Vasilache   return success();
1891bf08709SNicolas Vasilache }
1901bf08709SNicolas Vasilache 
applyTransformNamedSequence(Operation * payload,Operation * transformRoot,ModuleOp transformModule,const TransformOptions & options)191ef8c26b7SNicolas Vasilache LogicalResult transform::applyTransformNamedSequence(
192e4384149SOleksandr "Alex" Zinenko     Operation *payload, Operation *transformRoot, ModuleOp transformModule,
193e4384149SOleksandr "Alex" Zinenko     const TransformOptions &options) {
194b33b91a2SOleksandr "Alex" Zinenko   RaggedArray<MappedValue> bindings;
195b33b91a2SOleksandr "Alex" Zinenko   bindings.push_back(ArrayRef<Operation *>{payload});
196b33b91a2SOleksandr "Alex" Zinenko   return applyTransformNamedSequence(bindings,
197b33b91a2SOleksandr "Alex" Zinenko                                      cast<TransformOpInterface>(transformRoot),
198b33b91a2SOleksandr "Alex" Zinenko                                      transformModule, options);
199b33b91a2SOleksandr "Alex" Zinenko }
200b33b91a2SOleksandr "Alex" Zinenko 
applyTransformNamedSequence(RaggedArray<MappedValue> bindings,TransformOpInterface transformRoot,ModuleOp transformModule,const TransformOptions & options)201b33b91a2SOleksandr "Alex" Zinenko LogicalResult transform::applyTransformNamedSequence(
202b33b91a2SOleksandr "Alex" Zinenko     RaggedArray<MappedValue> bindings, TransformOpInterface transformRoot,
203b33b91a2SOleksandr "Alex" Zinenko     ModuleOp transformModule, const TransformOptions &options) {
204b33b91a2SOleksandr "Alex" Zinenko   if (bindings.empty()) {
205b33b91a2SOleksandr "Alex" Zinenko     return transformRoot.emitError()
206b33b91a2SOleksandr "Alex" Zinenko            << "expected at least one binding for the root";
207b33b91a2SOleksandr "Alex" Zinenko   }
208b33b91a2SOleksandr "Alex" Zinenko   if (bindings.at(0).size() != 1) {
209b33b91a2SOleksandr "Alex" Zinenko     return transformRoot.emitError()
210b33b91a2SOleksandr "Alex" Zinenko            << "expected one payload to be bound to the first argument, got "
211b33b91a2SOleksandr "Alex" Zinenko            << bindings.at(0).size();
212b33b91a2SOleksandr "Alex" Zinenko   }
213b33b91a2SOleksandr "Alex" Zinenko   auto *payloadRoot = bindings.at(0).front().dyn_cast<Operation *>();
214b33b91a2SOleksandr "Alex" Zinenko   if (!payloadRoot) {
215b33b91a2SOleksandr "Alex" Zinenko     return transformRoot->emitError() << "expected the object bound to the "
216b33b91a2SOleksandr "Alex" Zinenko                                          "first argument to be an operation";
217b33b91a2SOleksandr "Alex" Zinenko   }
218b33b91a2SOleksandr "Alex" Zinenko 
219b33b91a2SOleksandr "Alex" Zinenko   bindings.removeFront();
220b33b91a2SOleksandr "Alex" Zinenko 
221ef8c26b7SNicolas Vasilache   // `transformModule` may not be modified.
222ef8c26b7SNicolas Vasilache   if (transformModule && !transformModule->isAncestor(transformRoot)) {
2231bf08709SNicolas Vasilache     OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());
224ef8c26b7SNicolas Vasilache     if (failed(detail::mergeSymbolsInto(
225ef8c26b7SNicolas Vasilache             SymbolTable::getNearestSymbolTable(transformRoot),
2261bf08709SNicolas Vasilache             std::move(clonedTransformModule)))) {
227b33b91a2SOleksandr "Alex" Zinenko       return payloadRoot->emitError() << "failed to merge symbols";
228ef8c26b7SNicolas Vasilache     }
2291bf08709SNicolas Vasilache   }
2301bf08709SNicolas Vasilache 
2311bf08709SNicolas Vasilache   LLVM_DEBUG(DBGS() << "Apply\n" << *transformRoot << "\n");
232b33b91a2SOleksandr "Alex" Zinenko   LLVM_DEBUG(DBGS() << "To\n" << *payloadRoot << "\n");
233ef8c26b7SNicolas Vasilache 
234b33b91a2SOleksandr "Alex" Zinenko   return applyTransforms(payloadRoot, transformRoot, bindings, options,
235ef8c26b7SNicolas Vasilache                          /*enforceToplevelTransformOp=*/false);
236ef8c26b7SNicolas Vasilache }
237