xref: /llvm-project/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp (revision 5a9bdd85ee4d8527e2cedf44f3ce26ff414f9b6a)
1 //===- TransformInterpreterUtils.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Lightweight transform dialect interpreter utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
14 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
15 #include "mlir/Dialect/Transform/IR/TransformOps.h"
16 #include "mlir/Dialect/Transform/IR/Utils.h"
17 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/Verifier.h"
20 #include "mlir/IR/Visitors.h"
21 #include "mlir/Interfaces/FunctionInterfaces.h"
22 #include "mlir/Parser/Parser.h"
23 #include "mlir/Support/FileUtilities.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/Casting.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/FileSystem.h"
28 #include "llvm/Support/SourceMgr.h"
29 #include "llvm/Support/raw_ostream.h"
30 
31 using namespace mlir;
32 
33 #define DEBUG_TYPE "transform-dialect-interpreter-utils"
34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
35 
36 /// Expands the given list of `paths` to a list of `.mlir` files.
37 ///
38 /// Each entry in `paths` may either be a regular file, in which case it ends up
39 /// in the result list, or a directory, in which case all (regular) `.mlir`
40 /// files in that directory are added. Any other file types lead to a failure.
expandPathsToMLIRFiles(ArrayRef<std::string> paths,MLIRContext * context,SmallVectorImpl<std::string> & fileNames)41 LogicalResult transform::detail::expandPathsToMLIRFiles(
42     ArrayRef<std::string> paths, MLIRContext *context,
43     SmallVectorImpl<std::string> &fileNames) {
44   for (const std::string &path : paths) {
45     auto loc = FileLineColLoc::get(context, path, 0, 0);
46 
47     if (llvm::sys::fs::is_regular_file(path)) {
48       LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
49       fileNames.push_back(path);
50       continue;
51     }
52 
53     if (!llvm::sys::fs::is_directory(path)) {
54       return emitError(loc)
55              << "'" << path << "' is neither a file nor a directory";
56     }
57 
58     LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
59 
60     std::error_code ec;
61     for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
62          it != itEnd && !ec; it.increment(ec)) {
63       const std::string &fileName = it->path();
64 
65       if (it->type() != llvm::sys::fs::file_type::regular_file &&
66           it->type() != llvm::sys::fs::file_type::symlink_file) {
67         LLVM_DEBUG(DBGS() << "  Skipping non-regular file '" << fileName
68                           << "'\n");
69         continue;
70       }
71 
72       if (!StringRef(fileName).ends_with(".mlir")) {
73         LLVM_DEBUG(DBGS() << "  Skipping '" << fileName
74                           << "' because it does not end with '.mlir'\n");
75         continue;
76       }
77 
78       LLVM_DEBUG(DBGS() << "  Adding '" << fileName << "' to list of files\n");
79       fileNames.push_back(fileName);
80     }
81 
82     if (ec)
83       return emitError(loc) << "error while opening files in '" << path
84                             << "': " << ec.message();
85   }
86 
87   return success();
88 }
89 
parseTransformModuleFromFile(MLIRContext * context,llvm::StringRef transformFileName,OwningOpRef<ModuleOp> & transformModule)90 LogicalResult transform::detail::parseTransformModuleFromFile(
91     MLIRContext *context, llvm::StringRef transformFileName,
92     OwningOpRef<ModuleOp> &transformModule) {
93   if (transformFileName.empty()) {
94     LLVM_DEBUG(
95         DBGS() << "no transform file name specified, assuming the transform "
96                   "module is embedded in the IR next to the top-level\n");
97     return success();
98   }
99   // Parse transformFileName content into a ModuleOp.
100   std::string errorMessage;
101   auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
102   if (!memoryBuffer) {
103     return emitError(FileLineColLoc::get(
104                StringAttr::get(context, transformFileName), 0, 0))
105            << "failed to open transform file: " << errorMessage;
106   }
107   // Tell sourceMgr about this buffer, the parser will pick it up.
108   llvm::SourceMgr sourceMgr;
109   sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
110   transformModule =
111       OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
112   if (!transformModule) {
113     // Failed to parse the transform module.
114     // Don't need to emit an error here as the parsing should have already done
115     // that.
116     return failure();
117   }
118   return mlir::verify(*transformModule);
119 }
120 
getPreloadedTransformModule(MLIRContext * context)121 ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
122   return context->getOrLoadDialect<transform::TransformDialect>()
123       ->getLibraryModule();
124 }
125 
126 transform::TransformOpInterface
findTransformEntryPoint(Operation * root,ModuleOp module,StringRef entryPoint)127 transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
128                                            StringRef entryPoint) {
129   SmallVector<Operation *, 2> l{root};
130   if (module)
131     l.push_back(module);
132   for (Operation *op : l) {
133     transform::TransformOpInterface transform = nullptr;
134     op->walk<WalkOrder::PreOrder>(
135         [&](transform::NamedSequenceOp namedSequenceOp) {
136           if (namedSequenceOp.getSymName() == entryPoint) {
137             transform = cast<transform::TransformOpInterface>(
138                 namedSequenceOp.getOperation());
139             return WalkResult::interrupt();
140           }
141           return WalkResult::advance();
142         });
143     if (transform)
144       return transform;
145   }
146   auto diag = root->emitError()
147               << "could not find a nested named sequence with name: "
148               << entryPoint;
149   return nullptr;
150 }
151 
assembleTransformLibraryFromPaths(MLIRContext * context,ArrayRef<std::string> transformLibraryPaths,OwningOpRef<ModuleOp> & transformModule)152 LogicalResult transform::detail::assembleTransformLibraryFromPaths(
153     MLIRContext *context, ArrayRef<std::string> transformLibraryPaths,
154     OwningOpRef<ModuleOp> &transformModule) {
155   // Assemble list of library files.
156   SmallVector<std::string> libraryFileNames;
157   if (failed(detail::expandPathsToMLIRFiles(transformLibraryPaths, context,
158                                             libraryFileNames)))
159     return failure();
160 
161   // Parse modules from library files.
162   SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
163   for (const std::string &libraryFileName : libraryFileNames) {
164     OwningOpRef<ModuleOp> parsedLibrary;
165     if (failed(transform::detail::parseTransformModuleFromFile(
166             context, libraryFileName, parsedLibrary)))
167       return failure();
168     parsedLibraries.push_back(std::move(parsedLibrary));
169   }
170 
171   // Merge parsed libraries into one module.
172   auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
173   OwningOpRef<ModuleOp> mergedParsedLibraries =
174       ModuleOp::create(loc, "__transform");
175   {
176     mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
177                                          UnitAttr::get(context));
178     // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
179     for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
180       if (failed(transform::detail::mergeSymbolsInto(
181               mergedParsedLibraries.get(), std::move(parsedLibrary))))
182         return parsedLibrary->emitError()
183                << "failed to merge symbols into shared library module";
184     }
185   }
186 
187   transformModule = std::move(mergedParsedLibraries);
188   return success();
189 }
190 
applyTransformNamedSequence(Operation * payload,Operation * transformRoot,ModuleOp transformModule,const TransformOptions & options)191 LogicalResult transform::applyTransformNamedSequence(
192     Operation *payload, Operation *transformRoot, ModuleOp transformModule,
193     const TransformOptions &options) {
194   RaggedArray<MappedValue> bindings;
195   bindings.push_back(ArrayRef<Operation *>{payload});
196   return applyTransformNamedSequence(bindings,
197                                      cast<TransformOpInterface>(transformRoot),
198                                      transformModule, options);
199 }
200 
applyTransformNamedSequence(RaggedArray<MappedValue> bindings,TransformOpInterface transformRoot,ModuleOp transformModule,const TransformOptions & options)201 LogicalResult transform::applyTransformNamedSequence(
202     RaggedArray<MappedValue> bindings, TransformOpInterface transformRoot,
203     ModuleOp transformModule, const TransformOptions &options) {
204   if (bindings.empty()) {
205     return transformRoot.emitError()
206            << "expected at least one binding for the root";
207   }
208   if (bindings.at(0).size() != 1) {
209     return transformRoot.emitError()
210            << "expected one payload to be bound to the first argument, got "
211            << bindings.at(0).size();
212   }
213   auto *payloadRoot = bindings.at(0).front().dyn_cast<Operation *>();
214   if (!payloadRoot) {
215     return transformRoot->emitError() << "expected the object bound to the "
216                                          "first argument to be an operation";
217   }
218 
219   bindings.removeFront();
220 
221   // `transformModule` may not be modified.
222   if (transformModule && !transformModule->isAncestor(transformRoot)) {
223     OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());
224     if (failed(detail::mergeSymbolsInto(
225             SymbolTable::getNearestSymbolTable(transformRoot),
226             std::move(clonedTransformModule)))) {
227       return payloadRoot->emitError() << "failed to merge symbols";
228     }
229   }
230 
231   LLVM_DEBUG(DBGS() << "Apply\n" << *transformRoot << "\n");
232   LLVM_DEBUG(DBGS() << "To\n" << *payloadRoot << "\n");
233 
234   return applyTransforms(payloadRoot, transformRoot, bindings, options,
235                          /*enforceToplevelTransformOp=*/false);
236 }
237