xref: /llvm-project/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp (revision 9ccf01fbf711f4f9b0aa82d5732d778ea169fa88)
1 //===- TestTransformDialectInterpreter.cpp --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a test pass that interprets Transform dialect operations in
10 // the module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestTransformDialectExtension.h"
15 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
16 #include "mlir/Dialect/Transform/IR/TransformOps.h"
17 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/Pass/Pass.h"
21 
22 using namespace mlir;
23 
24 namespace {
/// Simple pass that applies transform dialect ops directly contained in a
/// module. The pass itself is TestTransformDialectInterpreterPass, defined
/// further down in this namespace.
27 
28 template <typename Derived>
29 class OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {};
30 
31 class TestTransformDialectInterpreterPass
32     : public transform::TransformInterpreterPassBase<
33           TestTransformDialectInterpreterPass, OpPassWrapper> {
34 public:
35   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
36       TestTransformDialectInterpreterPass)
37 
38   TestTransformDialectInterpreterPass() = default;
39   TestTransformDialectInterpreterPass(
40       const TestTransformDialectInterpreterPass &pass)
41       : TransformInterpreterPassBase(pass) {}
42 
43   StringRef getArgument() const override {
44     return "test-transform-dialect-interpreter";
45   }
46 
47   StringRef getDescription() const override {
48     return "apply transform dialect operations one by one";
49   }
50 
51   void getDependentDialects(DialectRegistry &registry) const override {
52     registry.insert<transform::TransformDialect>();
53   }
54 
55   void findOperationsByName(Operation *root, StringRef name,
56                             SmallVectorImpl<Operation *> &operations) {
57     root->walk([&](Operation *op) {
58       if (op->getName().getStringRef() == name) {
59         operations.push_back(op);
60       }
61     });
62   }
63 
64   void createParameterMapping(MLIRContext &context, ArrayRef<int> values,
65                               RaggedArray<transform::MappedValue> &result) {
66     SmallVector<transform::MappedValue> storage =
67         llvm::to_vector(llvm::map_range(values, [&](int v) {
68           Builder b(&context);
69           return transform::MappedValue(b.getI64IntegerAttr(v));
70         }));
71     result.push_back(std::move(storage));
72   }
73 
74   void
75   createOpResultMapping(Operation *root, StringRef name,
76                         RaggedArray<transform::MappedValue> &extraMapping) {
77     SmallVector<Operation *> operations;
78     findOperationsByName(root, name, operations);
79     SmallVector<Value> results;
80     for (Operation *op : operations)
81       llvm::append_range(results, op->getResults());
82     extraMapping.push_back(results);
83   }
84 
85   unsigned numberOfSetOptions(const Option<std::string> &ops,
86                               const ListOption<int> &params,
87                               const Option<std::string> &values) {
88     unsigned numSetValues = 0;
89     numSetValues += !ops.empty();
90     numSetValues += !params.empty();
91     numSetValues += !values.empty();
92     return numSetValues;
93   }
94 
95   std::optional<LogicalResult> constructTransformModule(OpBuilder &builder,
96                                                         Location loc) {
97     if (!testModuleGeneration)
98       return std::nullopt;
99 
100     builder.create<transform::SequenceOp>(
101         loc, TypeRange(), transform::FailurePropagationMode::Propagate,
102         builder.getType<transform::AnyOpType>(),
103         [](OpBuilder &b, Location nested, Value rootH) {
104           b.create<mlir::test::TestPrintRemarkAtOperandOp>(
105               nested, rootH, "remark from generated");
106           b.create<transform::YieldOp>(nested, ValueRange());
107         });
108     return success();
109   }
110 
111   void runOnOperation() override {
112     unsigned firstSetOptions =
113         numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams,
114                            bindFirstExtraToResultsOfOps);
115     unsigned secondSetOptions =
116         numberOfSetOptions(bindSecondExtraToOps, bindSecondExtraToParams,
117                            bindSecondExtraToResultsOfOps);
118     auto loc = UnknownLoc::get(&getContext());
119     if (firstSetOptions > 1) {
120       emitError(loc) << "cannot bind the first extra top-level argument to "
121                         "multiple entities";
122       return signalPassFailure();
123     }
124     if (secondSetOptions > 1) {
125       emitError(loc) << "cannot bind the second extra top-level argument to "
126                         "multiple entities";
127       return signalPassFailure();
128     }
129     if (firstSetOptions == 0 && secondSetOptions != 0) {
130       emitError(loc) << "cannot bind the second extra top-level argument "
131                         "without bindings the first";
132     }
133 
134     RaggedArray<transform::MappedValue> extraMapping;
135     if (!bindFirstExtraToOps.empty()) {
136       SmallVector<Operation *> operations;
137       findOperationsByName(getOperation(), bindFirstExtraToOps.getValue(),
138                            operations);
139       extraMapping.push_back(operations);
140     } else if (!bindFirstExtraToParams.empty()) {
141       createParameterMapping(getContext(), bindFirstExtraToParams,
142                              extraMapping);
143     } else if (!bindFirstExtraToResultsOfOps.empty()) {
144       createOpResultMapping(getOperation(), bindFirstExtraToResultsOfOps,
145                             extraMapping);
146     }
147 
148     if (!bindSecondExtraToOps.empty()) {
149       SmallVector<Operation *> operations;
150       findOperationsByName(getOperation(), bindSecondExtraToOps, operations);
151       extraMapping.push_back(operations);
152     } else if (!bindSecondExtraToParams.empty()) {
153       createParameterMapping(getContext(), bindSecondExtraToParams,
154                              extraMapping);
155     } else if (!bindSecondExtraToResultsOfOps.empty()) {
156       createOpResultMapping(getOperation(), bindSecondExtraToResultsOfOps,
157                             extraMapping);
158     }
159 
160     options = options.enableExpensiveChecks(enableExpensiveChecks);
161     options = options.enableEnforceSingleToplevelTransformOp(
162         enforceSingleToplevelTransformOp);
163     if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
164             getOperation(), getArgument(), getSharedTransformModule(),
165             getTransformLibraryModule(), extraMapping, options,
166             transformFileName, transformLibraryPaths, debugPayloadRootTag,
167             debugTransformRootTag, getBinaryName())))
168       return signalPassFailure();
169   }
170 
171   Option<bool> enableExpensiveChecks{
172       *this, "enable-expensive-checks", llvm::cl::init(false),
173       llvm::cl::desc("perform expensive checks to better report errors in the "
174                      "transform IR")};
175   Option<bool> enforceSingleToplevelTransformOp{
176       *this, "enforce-single-top-level-transform-op", llvm::cl::init(true),
177       llvm::cl::desc("Ensure that only a single top-level transform op is "
178                      "present in the IR.")};
179 
180   Option<std::string> bindFirstExtraToOps{
181       *this, "bind-first-extra-to-ops",
182       llvm::cl::desc("bind the first extra argument of the top-level op to "
183                      "payload operations of the given kind")};
184   ListOption<int> bindFirstExtraToParams{
185       *this, "bind-first-extra-to-params",
186       llvm::cl::desc("bind the first extra argument of the top-level op to "
187                      "the given integer parameters")};
188   Option<std::string> bindFirstExtraToResultsOfOps{
189       *this, "bind-first-extra-to-results-of-ops",
190       llvm::cl::desc("bind the first extra argument of the top-level op to "
191                      "results of payload operations of the given kind")};
192 
193   Option<std::string> bindSecondExtraToOps{
194       *this, "bind-second-extra-to-ops",
195       llvm::cl::desc("bind the second extra argument of the top-level op to "
196                      "payload operations of the given kind")};
197   ListOption<int> bindSecondExtraToParams{
198       *this, "bind-second-extra-to-params",
199       llvm::cl::desc("bind the second extra argument of the top-level op to "
200                      "the given integer parameters")};
201   Option<std::string> bindSecondExtraToResultsOfOps{
202       *this, "bind-second-extra-to-results-of-ops",
203       llvm::cl::desc("bind the second extra argument of the top-level op to "
204                      "results of payload operations of the given kind")};
205 
206   Option<std::string> transformFileName{
207       *this, "transform-file-name", llvm::cl::init(""),
208       llvm::cl::desc(
209           "Optional filename containing a transform dialect specification to "
210           "apply. If left empty, the IR is assumed to contain one top-level "
211           "transform dialect operation somewhere in the module.")};
212   Option<std::string> debugPayloadRootTag{
213       *this, "debug-payload-root-tag", llvm::cl::init(""),
214       llvm::cl::desc(
215           "Select the operation with 'transform.target_tag' attribute having "
216           "the given value as payload IR root. If empty select the pass anchor "
217           "operation as the payload IR root.")};
218   Option<std::string> debugTransformRootTag{
219       *this, "debug-transform-root-tag", llvm::cl::init(""),
220       llvm::cl::desc(
221           "Select the operation with 'transform.target_tag' attribute having "
222           "the given value as container IR for top-level transform ops. This "
223           "allows user control on what transformation to apply. If empty, "
224           "select the container of the top-level transform op.")};
225   ListOption<std::string> transformLibraryPaths{
226       *this, "transform-library-paths", llvm::cl::ZeroOrMore,
227       llvm::cl::desc("Optional paths to files with modules that should be "
228                      "merged into the transform module to provide the "
229                      "definitions of external named sequences.")};
230 
231   Option<bool> testModuleGeneration{
232       *this, "test-module-generation", llvm::cl::init(false),
233       llvm::cl::desc("test the generation of the transform module during pass "
234                      "initialization, overridden by parsing")};
235 };
236 
237 struct TestTransformDialectEraseSchedulePass
238     : public PassWrapper<TestTransformDialectEraseSchedulePass,
239                          OperationPass<ModuleOp>> {
240   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
241       TestTransformDialectEraseSchedulePass)
242 
243   StringRef getArgument() const final {
244     return "test-transform-dialect-erase-schedule";
245   }
246 
247   StringRef getDescription() const final {
248     return "erase transform dialect schedule from the IR";
249   }
250 
251   void runOnOperation() override {
252     getOperation()->walk<WalkOrder::PreOrder>([&](Operation *nestedOp) {
253       if (isa<transform::TransformOpInterface>(nestedOp)) {
254         nestedOp->erase();
255         return WalkResult::skip();
256       }
257       return WalkResult::advance();
258     });
259   }
260 };
261 } // namespace
262 
263 namespace mlir {
264 namespace test {
265 /// Registers the test pass for erasing transform dialect ops.
266 void registerTestTransformDialectEraseSchedulePass() {
267   PassRegistration<TestTransformDialectEraseSchedulePass> reg;
268 }
269 /// Registers the test pass for applying transform dialect ops.
270 void registerTestTransformDialectInterpreterPass() {
271   PassRegistration<TestTransformDialectInterpreterPass> reg;
272 }
273 } // namespace test
274 } // namespace mlir
275