11bf08709SNicolas Vasilache //===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===//
21bf08709SNicolas Vasilache //
31bf08709SNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41bf08709SNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information.
51bf08709SNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61bf08709SNicolas Vasilache //
71bf08709SNicolas Vasilache //===----------------------------------------------------------------------===//
81bf08709SNicolas Vasilache
91bf08709SNicolas Vasilache #include "mlir/Dialect/Transform/IR/TransformDialect.h"
10*5a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
111bf08709SNicolas Vasilache #include "mlir/Dialect/Transform/Transforms/Passes.h"
121bf08709SNicolas Vasilache #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
131bf08709SNicolas Vasilache
141bf08709SNicolas Vasilache using namespace mlir;
151bf08709SNicolas Vasilache
161bf08709SNicolas Vasilache namespace mlir {
171bf08709SNicolas Vasilache namespace transform {
181bf08709SNicolas Vasilache #define GEN_PASS_DEF_INTERPRETERPASS
191bf08709SNicolas Vasilache #include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
201bf08709SNicolas Vasilache } // namespace transform
211bf08709SNicolas Vasilache } // namespace mlir
221bf08709SNicolas Vasilache
23e4384149SOleksandr "Alex" Zinenko /// Returns the payload operation to be used as payload root:
24e4384149SOleksandr "Alex" Zinenko /// - the operation nested under `passRoot` that has the given tag attribute,
25e4384149SOleksandr "Alex" Zinenko /// must be unique;
26e4384149SOleksandr "Alex" Zinenko /// - the `passRoot` itself if the tag is empty.
findPayloadRoot(Operation * passRoot,StringRef tag)27e4384149SOleksandr "Alex" Zinenko static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
28e4384149SOleksandr "Alex" Zinenko // Fast return.
29e4384149SOleksandr "Alex" Zinenko if (tag.empty())
30e4384149SOleksandr "Alex" Zinenko return passRoot;
31e4384149SOleksandr "Alex" Zinenko
32e4384149SOleksandr "Alex" Zinenko // Walk to do a lookup.
33e4384149SOleksandr "Alex" Zinenko Operation *target = nullptr;
34e4384149SOleksandr "Alex" Zinenko auto tagAttrName = StringAttr::get(
35e4384149SOleksandr "Alex" Zinenko passRoot->getContext(), transform::TransformDialect::kTargetTagAttrName);
36e4384149SOleksandr "Alex" Zinenko WalkResult walkResult = passRoot->walk([&](Operation *op) {
37e4384149SOleksandr "Alex" Zinenko auto attr = op->getAttrOfType<StringAttr>(tagAttrName);
38e4384149SOleksandr "Alex" Zinenko if (!attr || attr.getValue() != tag)
39e4384149SOleksandr "Alex" Zinenko return WalkResult::advance();
40e4384149SOleksandr "Alex" Zinenko
41e4384149SOleksandr "Alex" Zinenko if (!target) {
42e4384149SOleksandr "Alex" Zinenko target = op;
43e4384149SOleksandr "Alex" Zinenko return WalkResult::advance();
44e4384149SOleksandr "Alex" Zinenko }
45e4384149SOleksandr "Alex" Zinenko
46e4384149SOleksandr "Alex" Zinenko InFlightDiagnostic diag = op->emitError()
47e4384149SOleksandr "Alex" Zinenko << "repeated operation with the target tag '"
48e4384149SOleksandr "Alex" Zinenko << tag << "'";
49e4384149SOleksandr "Alex" Zinenko diag.attachNote(target->getLoc()) << "previously seen operation";
50e4384149SOleksandr "Alex" Zinenko return WalkResult::interrupt();
51e4384149SOleksandr "Alex" Zinenko });
52e4384149SOleksandr "Alex" Zinenko
535468f884SOleksandr "Alex" Zinenko if (!target) {
545468f884SOleksandr "Alex" Zinenko passRoot->emitError()
555468f884SOleksandr "Alex" Zinenko << "could not find the operation with transform.target_tag=\"" << tag
565468f884SOleksandr "Alex" Zinenko << "\" attribute";
575468f884SOleksandr "Alex" Zinenko return nullptr;
585468f884SOleksandr "Alex" Zinenko }
595468f884SOleksandr "Alex" Zinenko
60e4384149SOleksandr "Alex" Zinenko return walkResult.wasInterrupted() ? nullptr : target;
61e4384149SOleksandr "Alex" Zinenko }
62e4384149SOleksandr "Alex" Zinenko
631bf08709SNicolas Vasilache namespace {
641bf08709SNicolas Vasilache class InterpreterPass
651bf08709SNicolas Vasilache : public transform::impl::InterpreterPassBase<InterpreterPass> {
665468f884SOleksandr "Alex" Zinenko // Parses the pass arguments to bind trailing arguments of the entry point.
675468f884SOleksandr "Alex" Zinenko std::optional<RaggedArray<transform::MappedValue>>
parseArguments(Operation * payloadRoot)685468f884SOleksandr "Alex" Zinenko parseArguments(Operation *payloadRoot) {
695468f884SOleksandr "Alex" Zinenko MLIRContext *context = payloadRoot->getContext();
705468f884SOleksandr "Alex" Zinenko
715468f884SOleksandr "Alex" Zinenko SmallVector<SmallVector<transform::MappedValue>, 2> trailingBindings;
725468f884SOleksandr "Alex" Zinenko trailingBindings.resize(debugBindTrailingArgs.size());
735468f884SOleksandr "Alex" Zinenko
745468f884SOleksandr "Alex" Zinenko // Construct lists of op names to match.
755468f884SOleksandr "Alex" Zinenko SmallVector<std::optional<OperationName>> debugBindNames;
765468f884SOleksandr "Alex" Zinenko debugBindNames.reserve(debugBindTrailingArgs.size());
775468f884SOleksandr "Alex" Zinenko for (auto &&[position, nameString] :
785468f884SOleksandr "Alex" Zinenko llvm::enumerate(debugBindTrailingArgs)) {
795468f884SOleksandr "Alex" Zinenko StringRef name = nameString;
805468f884SOleksandr "Alex" Zinenko
815468f884SOleksandr "Alex" Zinenko // Parse the integer literals.
825468f884SOleksandr "Alex" Zinenko if (name.starts_with("#")) {
835468f884SOleksandr "Alex" Zinenko debugBindNames.push_back(std::nullopt);
845468f884SOleksandr "Alex" Zinenko StringRef lhs = "";
855468f884SOleksandr "Alex" Zinenko StringRef rhs = name.drop_front();
865468f884SOleksandr "Alex" Zinenko do {
875468f884SOleksandr "Alex" Zinenko std::tie(lhs, rhs) = rhs.split(';');
885468f884SOleksandr "Alex" Zinenko int64_t value;
895468f884SOleksandr "Alex" Zinenko if (lhs.getAsInteger(10, value)) {
905468f884SOleksandr "Alex" Zinenko emitError(UnknownLoc::get(context))
915468f884SOleksandr "Alex" Zinenko << "couldn't parse integer pass argument " << name;
925468f884SOleksandr "Alex" Zinenko return std::nullopt;
935468f884SOleksandr "Alex" Zinenko }
945468f884SOleksandr "Alex" Zinenko trailingBindings[position].push_back(
955468f884SOleksandr "Alex" Zinenko Builder(context).getI64IntegerAttr(value));
965468f884SOleksandr "Alex" Zinenko } while (!rhs.empty());
975468f884SOleksandr "Alex" Zinenko } else if (name.starts_with("^")) {
985468f884SOleksandr "Alex" Zinenko debugBindNames.emplace_back(OperationName(name.drop_front(), context));
995468f884SOleksandr "Alex" Zinenko } else {
1005468f884SOleksandr "Alex" Zinenko debugBindNames.emplace_back(OperationName(name, context));
1015468f884SOleksandr "Alex" Zinenko }
1025468f884SOleksandr "Alex" Zinenko }
1035468f884SOleksandr "Alex" Zinenko
1045468f884SOleksandr "Alex" Zinenko // Collect operations or results for extra bindings.
1055468f884SOleksandr "Alex" Zinenko payloadRoot->walk([&](Operation *payload) {
1065468f884SOleksandr "Alex" Zinenko for (auto &&[position, name] : llvm::enumerate(debugBindNames)) {
1075468f884SOleksandr "Alex" Zinenko if (!name || payload->getName() != *name)
1085468f884SOleksandr "Alex" Zinenko continue;
1095468f884SOleksandr "Alex" Zinenko
1105468f884SOleksandr "Alex" Zinenko if (StringRef(*std::next(debugBindTrailingArgs.begin(), position))
1115468f884SOleksandr "Alex" Zinenko .starts_with("^")) {
1125468f884SOleksandr "Alex" Zinenko llvm::append_range(trailingBindings[position], payload->getResults());
1135468f884SOleksandr "Alex" Zinenko } else {
1145468f884SOleksandr "Alex" Zinenko trailingBindings[position].push_back(payload);
1155468f884SOleksandr "Alex" Zinenko }
1165468f884SOleksandr "Alex" Zinenko }
1175468f884SOleksandr "Alex" Zinenko });
1185468f884SOleksandr "Alex" Zinenko
1195468f884SOleksandr "Alex" Zinenko RaggedArray<transform::MappedValue> bindings;
1205468f884SOleksandr "Alex" Zinenko bindings.push_back(ArrayRef<Operation *>{payloadRoot});
1215468f884SOleksandr "Alex" Zinenko for (SmallVector<transform::MappedValue> &trailing : trailingBindings)
1225468f884SOleksandr "Alex" Zinenko bindings.push_back(std::move(trailing));
1235468f884SOleksandr "Alex" Zinenko return bindings;
1245468f884SOleksandr "Alex" Zinenko }
1255468f884SOleksandr "Alex" Zinenko
1261bf08709SNicolas Vasilache public:
1271bf08709SNicolas Vasilache using Base::Base;
1281bf08709SNicolas Vasilache
runOnOperation()1291bf08709SNicolas Vasilache void runOnOperation() override {
13022e3bf4eSIngo Müller MLIRContext *context = &getContext();
13122e3bf4eSIngo Müller ModuleOp transformModule =
13222e3bf4eSIngo Müller transform::detail::getPreloadedTransformModule(context);
133e4384149SOleksandr "Alex" Zinenko Operation *payloadRoot =
134e4384149SOleksandr "Alex" Zinenko findPayloadRoot(getOperation(), debugPayloadRootTag);
135b33b91a2SOleksandr "Alex" Zinenko if (!payloadRoot)
136b33b91a2SOleksandr "Alex" Zinenko return signalPassFailure();
137b33b91a2SOleksandr "Alex" Zinenko
138e4384149SOleksandr "Alex" Zinenko Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
139e4384149SOleksandr "Alex" Zinenko getOperation(), transformModule, entryPoint);
1405468f884SOleksandr "Alex" Zinenko if (!transformEntryPoint)
1411bf08709SNicolas Vasilache return signalPassFailure();
1421bf08709SNicolas Vasilache
1435468f884SOleksandr "Alex" Zinenko std::optional<RaggedArray<transform::MappedValue>> bindings =
1445468f884SOleksandr "Alex" Zinenko parseArguments(payloadRoot);
1455468f884SOleksandr "Alex" Zinenko if (!bindings)
1465468f884SOleksandr "Alex" Zinenko return signalPassFailure();
147e4384149SOleksandr "Alex" Zinenko if (failed(transform::applyTransformNamedSequence(
1485468f884SOleksandr "Alex" Zinenko *bindings,
149b33b91a2SOleksandr "Alex" Zinenko cast<transform::TransformOpInterface>(transformEntryPoint),
150b33b91a2SOleksandr "Alex" Zinenko transformModule,
151e4384149SOleksandr "Alex" Zinenko options.enableExpensiveChecks(!disableExpensiveChecks)))) {
152e4384149SOleksandr "Alex" Zinenko return signalPassFailure();
153e4384149SOleksandr "Alex" Zinenko }
154e4384149SOleksandr "Alex" Zinenko }
155e4384149SOleksandr "Alex" Zinenko
1561bf08709SNicolas Vasilache private:
1571bf08709SNicolas Vasilache /// Transform interpreter options.
1581bf08709SNicolas Vasilache transform::TransformOptions options;
1591bf08709SNicolas Vasilache };
1601bf08709SNicolas Vasilache } // namespace
161