xref: /llvm-project/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp (revision 5a9bdd85ee4d8527e2cedf44f3ce26ff414f9b6a)
1 //===- TestTransformDialectInterpreter.cpp --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a test pass that interprets Transform dialect operations in
10 // the module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestTransformDialectExtension.h"
15 #include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.h"
16 #include "mlir/Dialect/Transform/IR/TransformOps.h"
17 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/Pass/Pass.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 /// Simple pass that applies transform dialect ops directly contained in a
27 /// module.
28 
29 template <typename Derived>
30 class OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {};
31 
32 class TestTransformDialectInterpreterPass
33     : public transform::TransformInterpreterPassBase<
34           TestTransformDialectInterpreterPass, OpPassWrapper> {
35 public:
36   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
37       TestTransformDialectInterpreterPass)
38 
39   TestTransformDialectInterpreterPass() = default;
40   TestTransformDialectInterpreterPass(
41       const TestTransformDialectInterpreterPass &pass)
42       : TransformInterpreterPassBase(pass) {}
43 
44   StringRef getArgument() const override {
45     return "test-transform-dialect-interpreter";
46   }
47 
48   StringRef getDescription() const override {
49     return "apply transform dialect operations one by one";
50   }
51 
52   void getDependentDialects(DialectRegistry &registry) const override {
53     registry.insert<transform::TransformDialect>();
54   }
55 
56   void findOperationsByName(Operation *root, StringRef name,
57                             SmallVectorImpl<Operation *> &operations) {
58     root->walk([&](Operation *op) {
59       if (op->getName().getStringRef() == name) {
60         operations.push_back(op);
61       }
62     });
63   }
64 
65   void createParameterMapping(MLIRContext &context, ArrayRef<int> values,
66                               RaggedArray<transform::MappedValue> &result) {
67     SmallVector<transform::MappedValue> storage =
68         llvm::to_vector(llvm::map_range(values, [&](int v) {
69           Builder b(&context);
70           return transform::MappedValue(b.getI64IntegerAttr(v));
71         }));
72     result.push_back(std::move(storage));
73   }
74 
75   void
76   createOpResultMapping(Operation *root, StringRef name,
77                         RaggedArray<transform::MappedValue> &extraMapping) {
78     SmallVector<Operation *> operations;
79     findOperationsByName(root, name, operations);
80     SmallVector<Value> results;
81     for (Operation *op : operations)
82       llvm::append_range(results, op->getResults());
83     extraMapping.push_back(results);
84   }
85 
86   unsigned numberOfSetOptions(const Option<std::string> &ops,
87                               const ListOption<int> &params,
88                               const Option<std::string> &values) {
89     unsigned numSetValues = 0;
90     numSetValues += !ops.empty();
91     numSetValues += !params.empty();
92     numSetValues += !values.empty();
93     return numSetValues;
94   }
95 
96   std::optional<LogicalResult> constructTransformModule(OpBuilder &builder,
97                                                         Location loc) {
98     if (!testModuleGeneration)
99       return std::nullopt;
100 
101     builder.create<transform::SequenceOp>(
102         loc, TypeRange(), transform::FailurePropagationMode::Propagate,
103         builder.getType<transform::AnyOpType>(),
104         [](OpBuilder &b, Location nested, Value rootH) {
105           b.create<transform::DebugEmitRemarkAtOp>(nested, rootH,
106                                                    "remark from generated");
107           b.create<transform::YieldOp>(nested, ValueRange());
108         });
109     return success();
110   }
111 
112   void runOnOperation() override {
113     unsigned firstSetOptions =
114         numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams,
115                            bindFirstExtraToResultsOfOps);
116     unsigned secondSetOptions =
117         numberOfSetOptions(bindSecondExtraToOps, bindSecondExtraToParams,
118                            bindSecondExtraToResultsOfOps);
119     auto loc = UnknownLoc::get(&getContext());
120     if (firstSetOptions > 1) {
121       emitError(loc) << "cannot bind the first extra top-level argument to "
122                         "multiple entities";
123       return signalPassFailure();
124     }
125     if (secondSetOptions > 1) {
126       emitError(loc) << "cannot bind the second extra top-level argument to "
127                         "multiple entities";
128       return signalPassFailure();
129     }
130     if (firstSetOptions == 0 && secondSetOptions != 0) {
131       emitError(loc) << "cannot bind the second extra top-level argument "
132                         "without bindings the first";
133     }
134 
135     RaggedArray<transform::MappedValue> extraMapping;
136     if (!bindFirstExtraToOps.empty()) {
137       SmallVector<Operation *> operations;
138       findOperationsByName(getOperation(), bindFirstExtraToOps.getValue(),
139                            operations);
140       extraMapping.push_back(operations);
141     } else if (!bindFirstExtraToParams.empty()) {
142       createParameterMapping(getContext(), bindFirstExtraToParams,
143                              extraMapping);
144     } else if (!bindFirstExtraToResultsOfOps.empty()) {
145       createOpResultMapping(getOperation(), bindFirstExtraToResultsOfOps,
146                             extraMapping);
147     }
148 
149     if (!bindSecondExtraToOps.empty()) {
150       SmallVector<Operation *> operations;
151       findOperationsByName(getOperation(), bindSecondExtraToOps, operations);
152       extraMapping.push_back(operations);
153     } else if (!bindSecondExtraToParams.empty()) {
154       createParameterMapping(getContext(), bindSecondExtraToParams,
155                              extraMapping);
156     } else if (!bindSecondExtraToResultsOfOps.empty()) {
157       createOpResultMapping(getOperation(), bindSecondExtraToResultsOfOps,
158                             extraMapping);
159     }
160 
161     options = options.enableExpensiveChecks(enableExpensiveChecks);
162     options = options.enableEnforceSingleToplevelTransformOp(
163         enforceSingleToplevelTransformOp);
164     if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
165             getOperation(), getArgument(), getSharedTransformModule(),
166             getTransformLibraryModule(), extraMapping, options,
167             transformFileName, transformLibraryPaths, debugPayloadRootTag,
168             debugTransformRootTag, getBinaryName())))
169       return signalPassFailure();
170   }
171 
172   Option<bool> enableExpensiveChecks{
173       *this, "enable-expensive-checks", llvm::cl::init(false),
174       llvm::cl::desc("perform expensive checks to better report errors in the "
175                      "transform IR")};
176   Option<bool> enforceSingleToplevelTransformOp{
177       *this, "enforce-single-top-level-transform-op", llvm::cl::init(true),
178       llvm::cl::desc("Ensure that only a single top-level transform op is "
179                      "present in the IR.")};
180 
181   Option<std::string> bindFirstExtraToOps{
182       *this, "bind-first-extra-to-ops",
183       llvm::cl::desc("bind the first extra argument of the top-level op to "
184                      "payload operations of the given kind")};
185   ListOption<int> bindFirstExtraToParams{
186       *this, "bind-first-extra-to-params",
187       llvm::cl::desc("bind the first extra argument of the top-level op to "
188                      "the given integer parameters")};
189   Option<std::string> bindFirstExtraToResultsOfOps{
190       *this, "bind-first-extra-to-results-of-ops",
191       llvm::cl::desc("bind the first extra argument of the top-level op to "
192                      "results of payload operations of the given kind")};
193 
194   Option<std::string> bindSecondExtraToOps{
195       *this, "bind-second-extra-to-ops",
196       llvm::cl::desc("bind the second extra argument of the top-level op to "
197                      "payload operations of the given kind")};
198   ListOption<int> bindSecondExtraToParams{
199       *this, "bind-second-extra-to-params",
200       llvm::cl::desc("bind the second extra argument of the top-level op to "
201                      "the given integer parameters")};
202   Option<std::string> bindSecondExtraToResultsOfOps{
203       *this, "bind-second-extra-to-results-of-ops",
204       llvm::cl::desc("bind the second extra argument of the top-level op to "
205                      "results of payload operations of the given kind")};
206 
207   Option<std::string> transformFileName{
208       *this, "transform-file-name", llvm::cl::init(""),
209       llvm::cl::desc(
210           "Optional filename containing a transform dialect specification to "
211           "apply. If left empty, the IR is assumed to contain one top-level "
212           "transform dialect operation somewhere in the module.")};
213   Option<std::string> debugPayloadRootTag{
214       *this, "debug-payload-root-tag", llvm::cl::init(""),
215       llvm::cl::desc(
216           "Select the operation with 'transform.target_tag' attribute having "
217           "the given value as payload IR root. If empty select the pass anchor "
218           "operation as the payload IR root.")};
219   Option<std::string> debugTransformRootTag{
220       *this, "debug-transform-root-tag", llvm::cl::init(""),
221       llvm::cl::desc(
222           "Select the operation with 'transform.target_tag' attribute having "
223           "the given value as container IR for top-level transform ops. This "
224           "allows user control on what transformation to apply. If empty, "
225           "select the container of the top-level transform op.")};
226   ListOption<std::string> transformLibraryPaths{
227       *this, "transform-library-paths", llvm::cl::ZeroOrMore,
228       llvm::cl::desc("Optional paths to files with modules that should be "
229                      "merged into the transform module to provide the "
230                      "definitions of external named sequences.")};
231 
232   Option<bool> testModuleGeneration{
233       *this, "test-module-generation", llvm::cl::init(false),
234       llvm::cl::desc("test the generation of the transform module during pass "
235                      "initialization, overridden by parsing")};
236 };
237 
238 struct TestTransformDialectEraseSchedulePass
239     : public PassWrapper<TestTransformDialectEraseSchedulePass,
240                          OperationPass<ModuleOp>> {
241   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
242       TestTransformDialectEraseSchedulePass)
243 
244   StringRef getArgument() const final {
245     return "test-transform-dialect-erase-schedule";
246   }
247 
248   StringRef getDescription() const final {
249     return "erase transform dialect schedule from the IR";
250   }
251 
252   void runOnOperation() override {
253     getOperation()->walk<WalkOrder::PreOrder>([&](Operation *nestedOp) {
254       if (isa<transform::TransformOpInterface>(nestedOp)) {
255         nestedOp->erase();
256         return WalkResult::skip();
257       }
258       return WalkResult::advance();
259     });
260   }
261 };
262 } // namespace
263 
264 namespace mlir {
265 namespace test {
266 /// Registers the test pass for erasing transform dialect ops.
267 void registerTestTransformDialectEraseSchedulePass() {
268   PassRegistration<TestTransformDialectEraseSchedulePass> reg;
269 }
270 /// Registers the test pass for applying transform dialect ops.
271 void registerTestTransformDialectInterpreterPass() {
272   PassRegistration<TestTransformDialectInterpreterPass> reg;
273 }
274 } // namespace test
275 } // namespace mlir
276