xref: /llvm-project/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp (revision 5a9bdd85ee4d8527e2cedf44f3ce26ff414f9b6a)
11bf08709SNicolas Vasilache //===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===//
21bf08709SNicolas Vasilache //
31bf08709SNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41bf08709SNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information.
51bf08709SNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61bf08709SNicolas Vasilache //
71bf08709SNicolas Vasilache //===----------------------------------------------------------------------===//
81bf08709SNicolas Vasilache 
91bf08709SNicolas Vasilache #include "mlir/Dialect/Transform/IR/TransformDialect.h"
10*5a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
111bf08709SNicolas Vasilache #include "mlir/Dialect/Transform/Transforms/Passes.h"
121bf08709SNicolas Vasilache #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
131bf08709SNicolas Vasilache 
141bf08709SNicolas Vasilache using namespace mlir;
151bf08709SNicolas Vasilache 
161bf08709SNicolas Vasilache namespace mlir {
171bf08709SNicolas Vasilache namespace transform {
181bf08709SNicolas Vasilache #define GEN_PASS_DEF_INTERPRETERPASS
191bf08709SNicolas Vasilache #include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
201bf08709SNicolas Vasilache } // namespace transform
211bf08709SNicolas Vasilache } // namespace mlir
221bf08709SNicolas Vasilache 
23e4384149SOleksandr "Alex" Zinenko /// Returns the payload operation to be used as payload root:
24e4384149SOleksandr "Alex" Zinenko ///   - the operation nested under `passRoot` that has the given tag attribute,
25e4384149SOleksandr "Alex" Zinenko ///     must be unique;
26e4384149SOleksandr "Alex" Zinenko ///   - the `passRoot` itself if the tag is empty.
findPayloadRoot(Operation * passRoot,StringRef tag)27e4384149SOleksandr "Alex" Zinenko static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
28e4384149SOleksandr "Alex" Zinenko   // Fast return.
29e4384149SOleksandr "Alex" Zinenko   if (tag.empty())
30e4384149SOleksandr "Alex" Zinenko     return passRoot;
31e4384149SOleksandr "Alex" Zinenko 
32e4384149SOleksandr "Alex" Zinenko   // Walk to do a lookup.
33e4384149SOleksandr "Alex" Zinenko   Operation *target = nullptr;
34e4384149SOleksandr "Alex" Zinenko   auto tagAttrName = StringAttr::get(
35e4384149SOleksandr "Alex" Zinenko       passRoot->getContext(), transform::TransformDialect::kTargetTagAttrName);
36e4384149SOleksandr "Alex" Zinenko   WalkResult walkResult = passRoot->walk([&](Operation *op) {
37e4384149SOleksandr "Alex" Zinenko     auto attr = op->getAttrOfType<StringAttr>(tagAttrName);
38e4384149SOleksandr "Alex" Zinenko     if (!attr || attr.getValue() != tag)
39e4384149SOleksandr "Alex" Zinenko       return WalkResult::advance();
40e4384149SOleksandr "Alex" Zinenko 
41e4384149SOleksandr "Alex" Zinenko     if (!target) {
42e4384149SOleksandr "Alex" Zinenko       target = op;
43e4384149SOleksandr "Alex" Zinenko       return WalkResult::advance();
44e4384149SOleksandr "Alex" Zinenko     }
45e4384149SOleksandr "Alex" Zinenko 
46e4384149SOleksandr "Alex" Zinenko     InFlightDiagnostic diag = op->emitError()
47e4384149SOleksandr "Alex" Zinenko                               << "repeated operation with the target tag '"
48e4384149SOleksandr "Alex" Zinenko                               << tag << "'";
49e4384149SOleksandr "Alex" Zinenko     diag.attachNote(target->getLoc()) << "previously seen operation";
50e4384149SOleksandr "Alex" Zinenko     return WalkResult::interrupt();
51e4384149SOleksandr "Alex" Zinenko   });
52e4384149SOleksandr "Alex" Zinenko 
535468f884SOleksandr "Alex" Zinenko   if (!target) {
545468f884SOleksandr "Alex" Zinenko     passRoot->emitError()
555468f884SOleksandr "Alex" Zinenko         << "could not find the operation with transform.target_tag=\"" << tag
565468f884SOleksandr "Alex" Zinenko         << "\" attribute";
575468f884SOleksandr "Alex" Zinenko     return nullptr;
585468f884SOleksandr "Alex" Zinenko   }
595468f884SOleksandr "Alex" Zinenko 
60e4384149SOleksandr "Alex" Zinenko   return walkResult.wasInterrupted() ? nullptr : target;
61e4384149SOleksandr "Alex" Zinenko }
62e4384149SOleksandr "Alex" Zinenko 
631bf08709SNicolas Vasilache namespace {
641bf08709SNicolas Vasilache class InterpreterPass
651bf08709SNicolas Vasilache     : public transform::impl::InterpreterPassBase<InterpreterPass> {
665468f884SOleksandr "Alex" Zinenko   // Parses the pass arguments to bind trailing arguments of the entry point.
675468f884SOleksandr "Alex" Zinenko   std::optional<RaggedArray<transform::MappedValue>>
parseArguments(Operation * payloadRoot)685468f884SOleksandr "Alex" Zinenko   parseArguments(Operation *payloadRoot) {
695468f884SOleksandr "Alex" Zinenko     MLIRContext *context = payloadRoot->getContext();
705468f884SOleksandr "Alex" Zinenko 
715468f884SOleksandr "Alex" Zinenko     SmallVector<SmallVector<transform::MappedValue>, 2> trailingBindings;
725468f884SOleksandr "Alex" Zinenko     trailingBindings.resize(debugBindTrailingArgs.size());
735468f884SOleksandr "Alex" Zinenko 
745468f884SOleksandr "Alex" Zinenko     // Construct lists of op names to match.
755468f884SOleksandr "Alex" Zinenko     SmallVector<std::optional<OperationName>> debugBindNames;
765468f884SOleksandr "Alex" Zinenko     debugBindNames.reserve(debugBindTrailingArgs.size());
775468f884SOleksandr "Alex" Zinenko     for (auto &&[position, nameString] :
785468f884SOleksandr "Alex" Zinenko          llvm::enumerate(debugBindTrailingArgs)) {
795468f884SOleksandr "Alex" Zinenko       StringRef name = nameString;
805468f884SOleksandr "Alex" Zinenko 
815468f884SOleksandr "Alex" Zinenko       // Parse the integer literals.
825468f884SOleksandr "Alex" Zinenko       if (name.starts_with("#")) {
835468f884SOleksandr "Alex" Zinenko         debugBindNames.push_back(std::nullopt);
845468f884SOleksandr "Alex" Zinenko         StringRef lhs = "";
855468f884SOleksandr "Alex" Zinenko         StringRef rhs = name.drop_front();
865468f884SOleksandr "Alex" Zinenko         do {
875468f884SOleksandr "Alex" Zinenko           std::tie(lhs, rhs) = rhs.split(';');
885468f884SOleksandr "Alex" Zinenko           int64_t value;
895468f884SOleksandr "Alex" Zinenko           if (lhs.getAsInteger(10, value)) {
905468f884SOleksandr "Alex" Zinenko             emitError(UnknownLoc::get(context))
915468f884SOleksandr "Alex" Zinenko                 << "couldn't parse integer pass argument " << name;
925468f884SOleksandr "Alex" Zinenko             return std::nullopt;
935468f884SOleksandr "Alex" Zinenko           }
945468f884SOleksandr "Alex" Zinenko           trailingBindings[position].push_back(
955468f884SOleksandr "Alex" Zinenko               Builder(context).getI64IntegerAttr(value));
965468f884SOleksandr "Alex" Zinenko         } while (!rhs.empty());
975468f884SOleksandr "Alex" Zinenko       } else if (name.starts_with("^")) {
985468f884SOleksandr "Alex" Zinenko         debugBindNames.emplace_back(OperationName(name.drop_front(), context));
995468f884SOleksandr "Alex" Zinenko       } else {
1005468f884SOleksandr "Alex" Zinenko         debugBindNames.emplace_back(OperationName(name, context));
1015468f884SOleksandr "Alex" Zinenko       }
1025468f884SOleksandr "Alex" Zinenko     }
1035468f884SOleksandr "Alex" Zinenko 
1045468f884SOleksandr "Alex" Zinenko     // Collect operations or results for extra bindings.
1055468f884SOleksandr "Alex" Zinenko     payloadRoot->walk([&](Operation *payload) {
1065468f884SOleksandr "Alex" Zinenko       for (auto &&[position, name] : llvm::enumerate(debugBindNames)) {
1075468f884SOleksandr "Alex" Zinenko         if (!name || payload->getName() != *name)
1085468f884SOleksandr "Alex" Zinenko           continue;
1095468f884SOleksandr "Alex" Zinenko 
1105468f884SOleksandr "Alex" Zinenko         if (StringRef(*std::next(debugBindTrailingArgs.begin(), position))
1115468f884SOleksandr "Alex" Zinenko                 .starts_with("^")) {
1125468f884SOleksandr "Alex" Zinenko           llvm::append_range(trailingBindings[position], payload->getResults());
1135468f884SOleksandr "Alex" Zinenko         } else {
1145468f884SOleksandr "Alex" Zinenko           trailingBindings[position].push_back(payload);
1155468f884SOleksandr "Alex" Zinenko         }
1165468f884SOleksandr "Alex" Zinenko       }
1175468f884SOleksandr "Alex" Zinenko     });
1185468f884SOleksandr "Alex" Zinenko 
1195468f884SOleksandr "Alex" Zinenko     RaggedArray<transform::MappedValue> bindings;
1205468f884SOleksandr "Alex" Zinenko     bindings.push_back(ArrayRef<Operation *>{payloadRoot});
1215468f884SOleksandr "Alex" Zinenko     for (SmallVector<transform::MappedValue> &trailing : trailingBindings)
1225468f884SOleksandr "Alex" Zinenko       bindings.push_back(std::move(trailing));
1235468f884SOleksandr "Alex" Zinenko     return bindings;
1245468f884SOleksandr "Alex" Zinenko   }
1255468f884SOleksandr "Alex" Zinenko 
1261bf08709SNicolas Vasilache public:
1271bf08709SNicolas Vasilache   using Base::Base;
1281bf08709SNicolas Vasilache 
runOnOperation()1291bf08709SNicolas Vasilache   void runOnOperation() override {
13022e3bf4eSIngo Müller     MLIRContext *context = &getContext();
13122e3bf4eSIngo Müller     ModuleOp transformModule =
13222e3bf4eSIngo Müller         transform::detail::getPreloadedTransformModule(context);
133e4384149SOleksandr "Alex" Zinenko     Operation *payloadRoot =
134e4384149SOleksandr "Alex" Zinenko         findPayloadRoot(getOperation(), debugPayloadRootTag);
135b33b91a2SOleksandr "Alex" Zinenko     if (!payloadRoot)
136b33b91a2SOleksandr "Alex" Zinenko       return signalPassFailure();
137b33b91a2SOleksandr "Alex" Zinenko 
138e4384149SOleksandr "Alex" Zinenko     Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
139e4384149SOleksandr "Alex" Zinenko         getOperation(), transformModule, entryPoint);
1405468f884SOleksandr "Alex" Zinenko     if (!transformEntryPoint)
1411bf08709SNicolas Vasilache       return signalPassFailure();
1421bf08709SNicolas Vasilache 
1435468f884SOleksandr "Alex" Zinenko     std::optional<RaggedArray<transform::MappedValue>> bindings =
1445468f884SOleksandr "Alex" Zinenko         parseArguments(payloadRoot);
1455468f884SOleksandr "Alex" Zinenko     if (!bindings)
1465468f884SOleksandr "Alex" Zinenko       return signalPassFailure();
147e4384149SOleksandr "Alex" Zinenko     if (failed(transform::applyTransformNamedSequence(
1485468f884SOleksandr "Alex" Zinenko             *bindings,
149b33b91a2SOleksandr "Alex" Zinenko             cast<transform::TransformOpInterface>(transformEntryPoint),
150b33b91a2SOleksandr "Alex" Zinenko             transformModule,
151e4384149SOleksandr "Alex" Zinenko             options.enableExpensiveChecks(!disableExpensiveChecks)))) {
152e4384149SOleksandr "Alex" Zinenko       return signalPassFailure();
153e4384149SOleksandr "Alex" Zinenko     }
154e4384149SOleksandr "Alex" Zinenko   }
155e4384149SOleksandr "Alex" Zinenko 
1561bf08709SNicolas Vasilache private:
1571bf08709SNicolas Vasilache   /// Transform interpreter options.
1581bf08709SNicolas Vasilache   transform::TransformOptions options;
1591bf08709SNicolas Vasilache };
1601bf08709SNicolas Vasilache } // namespace
161