xref: /llvm-project/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp (revision 5a9bdd85ee4d8527e2cedf44f3ce26ff414f9b6a)
1 //===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
10 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
11 #include "mlir/Dialect/Transform/Transforms/Passes.h"
12 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
13 
14 using namespace mlir;
15 
16 namespace mlir {
17 namespace transform {
18 #define GEN_PASS_DEF_INTERPRETERPASS
19 #include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
20 } // namespace transform
21 } // namespace mlir
22 
23 /// Returns the payload operation to be used as payload root:
24 ///   - the operation nested under `passRoot` that has the given tag attribute,
25 ///     must be unique;
26 ///   - the `passRoot` itself if the tag is empty.
findPayloadRoot(Operation * passRoot,StringRef tag)27 static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
28   // Fast return.
29   if (tag.empty())
30     return passRoot;
31 
32   // Walk to do a lookup.
33   Operation *target = nullptr;
34   auto tagAttrName = StringAttr::get(
35       passRoot->getContext(), transform::TransformDialect::kTargetTagAttrName);
36   WalkResult walkResult = passRoot->walk([&](Operation *op) {
37     auto attr = op->getAttrOfType<StringAttr>(tagAttrName);
38     if (!attr || attr.getValue() != tag)
39       return WalkResult::advance();
40 
41     if (!target) {
42       target = op;
43       return WalkResult::advance();
44     }
45 
46     InFlightDiagnostic diag = op->emitError()
47                               << "repeated operation with the target tag '"
48                               << tag << "'";
49     diag.attachNote(target->getLoc()) << "previously seen operation";
50     return WalkResult::interrupt();
51   });
52 
53   if (!target) {
54     passRoot->emitError()
55         << "could not find the operation with transform.target_tag=\"" << tag
56         << "\" attribute";
57     return nullptr;
58   }
59 
60   return walkResult.wasInterrupted() ? nullptr : target;
61 }
62 
63 namespace {
64 class InterpreterPass
65     : public transform::impl::InterpreterPassBase<InterpreterPass> {
66   // Parses the pass arguments to bind trailing arguments of the entry point.
67   std::optional<RaggedArray<transform::MappedValue>>
parseArguments(Operation * payloadRoot)68   parseArguments(Operation *payloadRoot) {
69     MLIRContext *context = payloadRoot->getContext();
70 
71     SmallVector<SmallVector<transform::MappedValue>, 2> trailingBindings;
72     trailingBindings.resize(debugBindTrailingArgs.size());
73 
74     // Construct lists of op names to match.
75     SmallVector<std::optional<OperationName>> debugBindNames;
76     debugBindNames.reserve(debugBindTrailingArgs.size());
77     for (auto &&[position, nameString] :
78          llvm::enumerate(debugBindTrailingArgs)) {
79       StringRef name = nameString;
80 
81       // Parse the integer literals.
82       if (name.starts_with("#")) {
83         debugBindNames.push_back(std::nullopt);
84         StringRef lhs = "";
85         StringRef rhs = name.drop_front();
86         do {
87           std::tie(lhs, rhs) = rhs.split(';');
88           int64_t value;
89           if (lhs.getAsInteger(10, value)) {
90             emitError(UnknownLoc::get(context))
91                 << "couldn't parse integer pass argument " << name;
92             return std::nullopt;
93           }
94           trailingBindings[position].push_back(
95               Builder(context).getI64IntegerAttr(value));
96         } while (!rhs.empty());
97       } else if (name.starts_with("^")) {
98         debugBindNames.emplace_back(OperationName(name.drop_front(), context));
99       } else {
100         debugBindNames.emplace_back(OperationName(name, context));
101       }
102     }
103 
104     // Collect operations or results for extra bindings.
105     payloadRoot->walk([&](Operation *payload) {
106       for (auto &&[position, name] : llvm::enumerate(debugBindNames)) {
107         if (!name || payload->getName() != *name)
108           continue;
109 
110         if (StringRef(*std::next(debugBindTrailingArgs.begin(), position))
111                 .starts_with("^")) {
112           llvm::append_range(trailingBindings[position], payload->getResults());
113         } else {
114           trailingBindings[position].push_back(payload);
115         }
116       }
117     });
118 
119     RaggedArray<transform::MappedValue> bindings;
120     bindings.push_back(ArrayRef<Operation *>{payloadRoot});
121     for (SmallVector<transform::MappedValue> &trailing : trailingBindings)
122       bindings.push_back(std::move(trailing));
123     return bindings;
124   }
125 
126 public:
127   using Base::Base;
128 
runOnOperation()129   void runOnOperation() override {
130     MLIRContext *context = &getContext();
131     ModuleOp transformModule =
132         transform::detail::getPreloadedTransformModule(context);
133     Operation *payloadRoot =
134         findPayloadRoot(getOperation(), debugPayloadRootTag);
135     if (!payloadRoot)
136       return signalPassFailure();
137 
138     Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
139         getOperation(), transformModule, entryPoint);
140     if (!transformEntryPoint)
141       return signalPassFailure();
142 
143     std::optional<RaggedArray<transform::MappedValue>> bindings =
144         parseArguments(payloadRoot);
145     if (!bindings)
146       return signalPassFailure();
147     if (failed(transform::applyTransformNamedSequence(
148             *bindings,
149             cast<transform::TransformOpInterface>(transformEntryPoint),
150             transformModule,
151             options.enableExpensiveChecks(!disableExpensiveChecks)))) {
152       return signalPassFailure();
153     }
154   }
155 
156 private:
157   /// Transform interpreter options.
158   transform::TransformOptions options;
159 };
160 } // namespace
161