xref: /llvm-project/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp (revision 92c69468405283a9e084ea7bd18c451680d4777f)
1 //===- TestTransformDialectInterpreter.cpp --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a test pass that interprets Transform dialect operations in
10 // the module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
15 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/Pass/Pass.h"
19 
20 using namespace mlir;
21 
22 namespace {
23 /// Simple pass that applies transform dialect ops directly contained in a
24 /// module.
25 
26 template <typename Derived>
27 class OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {};
28 
29 class TestTransformDialectInterpreterPass
30     : public transform::TransformInterpreterPassBase<
31           TestTransformDialectInterpreterPass, OpPassWrapper> {
32 public:
33   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
34       TestTransformDialectInterpreterPass)
35 
36   TestTransformDialectInterpreterPass() = default;
37   TestTransformDialectInterpreterPass(
38       const TestTransformDialectInterpreterPass &pass)
39       : TransformInterpreterPassBase(pass) {}
40 
41   StringRef getArgument() const override {
42     return "test-transform-dialect-interpreter";
43   }
44 
45   StringRef getDescription() const override {
46     return "apply transform dialect operations one by one";
47   }
48 
49   void findOperationsByName(Operation *root, StringRef name,
50                             SmallVectorImpl<Operation *> &operations) {
51     root->walk([&](Operation *op) {
52       if (op->getName().getStringRef() == name) {
53         operations.push_back(op);
54       }
55     });
56   }
57 
58   void createParameterMapping(MLIRContext &context, ArrayRef<int> values,
59                               RaggedArray<transform::MappedValue> &result) {
60     SmallVector<transform::MappedValue> storage =
61         llvm::to_vector(llvm::map_range(values, [&](int v) {
62           Builder b(&context);
63           return transform::MappedValue(b.getI64IntegerAttr(v));
64         }));
65     result.push_back(std::move(storage));
66   }
67 
68   void
69   createOpResultMapping(Operation *root, StringRef name,
70                         RaggedArray<transform::MappedValue> &extraMapping) {
71     SmallVector<Operation *> operations;
72     findOperationsByName(root, name, operations);
73     SmallVector<Value> results;
74     for (Operation *op : operations)
75       llvm::append_range(results, op->getResults());
76     extraMapping.push_back(results);
77   }
78 
79   unsigned numberOfSetOptions(const Option<std::string> &ops,
80                               const ListOption<int> &params,
81                               const Option<std::string> &values) {
82     unsigned numSetValues = 0;
83     numSetValues += !ops.empty();
84     numSetValues += !params.empty();
85     numSetValues += !values.empty();
86     return numSetValues;
87   }
88 
89   void runOnOperation() override {
90     unsigned firstSetOptions =
91         numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams,
92                            bindFirstExtraToResultsOfOps);
93     unsigned secondSetOptions =
94         numberOfSetOptions(bindSecondExtraToOps, bindSecondExtraToParams,
95                            bindSecondExtraToResultsOfOps);
96     auto loc = UnknownLoc::get(&getContext());
97     if (firstSetOptions > 1) {
98       emitError(loc) << "cannot bind the first extra top-level argument to "
99                         "multiple entities";
100       return signalPassFailure();
101     }
102     if (secondSetOptions > 1) {
103       emitError(loc) << "cannot bind the second extra top-level argument to "
104                         "multiple entities";
105       return signalPassFailure();
106     }
107     if (firstSetOptions == 0 && secondSetOptions != 0) {
108       emitError(loc) << "cannot bind the second extra top-level argument "
109                         "without bindings the first";
110     }
111 
112     RaggedArray<transform::MappedValue> extraMapping;
113     if (!bindFirstExtraToOps.empty()) {
114       SmallVector<Operation *> operations;
115       findOperationsByName(getOperation(), bindFirstExtraToOps.getValue(),
116                            operations);
117       extraMapping.push_back(operations);
118     } else if (!bindFirstExtraToParams.empty()) {
119       createParameterMapping(getContext(), bindFirstExtraToParams,
120                              extraMapping);
121     } else if (!bindFirstExtraToResultsOfOps.empty()) {
122       createOpResultMapping(getOperation(), bindFirstExtraToResultsOfOps,
123                             extraMapping);
124     }
125 
126     if (!bindSecondExtraToOps.empty()) {
127       SmallVector<Operation *> operations;
128       findOperationsByName(getOperation(), bindSecondExtraToOps, operations);
129       extraMapping.push_back(operations);
130     } else if (!bindSecondExtraToParams.empty()) {
131       createParameterMapping(getContext(), bindSecondExtraToParams,
132                              extraMapping);
133     } else if (!bindSecondExtraToResultsOfOps.empty()) {
134       createOpResultMapping(getOperation(), bindSecondExtraToResultsOfOps,
135                             extraMapping);
136     }
137 
138     options = options.enableExpensiveChecks(enableExpensiveChecks);
139     if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
140             getOperation(), getArgument(), getSharedTransformModule(),
141             getTransformLibraryModule(), extraMapping, options,
142             transformFileName, transformLibraryFileName, debugPayloadRootTag,
143             debugTransformRootTag, getBinaryName())))
144       return signalPassFailure();
145   }
146 
147   Option<bool> enableExpensiveChecks{
148       *this, "enable-expensive-checks", llvm::cl::init(false),
149       llvm::cl::desc("perform expensive checks to better report errors in the "
150                      "transform IR")};
151 
152   Option<std::string> bindFirstExtraToOps{
153       *this, "bind-first-extra-to-ops",
154       llvm::cl::desc("bind the first extra argument of the top-level op to "
155                      "payload operations of the given kind")};
156   ListOption<int> bindFirstExtraToParams{
157       *this, "bind-first-extra-to-params",
158       llvm::cl::desc("bind the first extra argument of the top-level op to "
159                      "the given integer parameters")};
160   Option<std::string> bindFirstExtraToResultsOfOps{
161       *this, "bind-first-extra-to-results-of-ops",
162       llvm::cl::desc("bind the first extra argument of the top-level op to "
163                      "results of payload operations of the given kind")};
164 
165   Option<std::string> bindSecondExtraToOps{
166       *this, "bind-second-extra-to-ops",
167       llvm::cl::desc("bind the second extra argument of the top-level op to "
168                      "payload operations of the given kind")};
169   ListOption<int> bindSecondExtraToParams{
170       *this, "bind-second-extra-to-params",
171       llvm::cl::desc("bind the second extra argument of the top-level op to "
172                      "the given integer parameters")};
173   Option<std::string> bindSecondExtraToResultsOfOps{
174       *this, "bind-second-extra-to-results-of-ops",
175       llvm::cl::desc("bind the second extra argument of the top-level op to "
176                      "results of payload operations of the given kind")};
177 
178   Option<std::string> transformFileName{
179       *this, "transform-file-name", llvm::cl::init(""),
180       llvm::cl::desc(
181           "Optional filename containing a transform dialect specification to "
182           "apply. If left empty, the IR is assumed to contain one top-level "
183           "transform dialect operation somewhere in the module.")};
184   Option<std::string> debugPayloadRootTag{
185       *this, "debug-payload-root-tag", llvm::cl::init(""),
186       llvm::cl::desc(
187           "Select the operation with 'transform.target_tag' attribute having "
188           "the given value as payload IR root. If empty select the pass anchor "
189           "operation as the payload IR root.")};
190   Option<std::string> debugTransformRootTag{
191       *this, "debug-transform-root-tag", llvm::cl::init(""),
192       llvm::cl::desc(
193           "Select the operation with 'transform.target_tag' attribute having "
194           "the given value as container IR for top-level transform ops. This "
195           "allows user control on what transformation to apply. If empty, "
196           "select the container of the top-level transform op.")};
197   Option<std::string> transformLibraryFileName{
198       *this, "transform-library-file-name", llvm::cl::init(""),
199       llvm::cl::desc(
200           "Optional name of the file containing transform dialect symbol "
201           "definitions to be injected into the transform module.")};
202 };
203 
204 struct TestTransformDialectEraseSchedulePass
205     : public PassWrapper<TestTransformDialectEraseSchedulePass,
206                          OperationPass<ModuleOp>> {
207   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
208       TestTransformDialectEraseSchedulePass)
209 
210   StringRef getArgument() const final {
211     return "test-transform-dialect-erase-schedule";
212   }
213 
214   StringRef getDescription() const final {
215     return "erase transform dialect schedule from the IR";
216   }
217 
218   void runOnOperation() override {
219     getOperation()->walk<WalkOrder::PreOrder>([&](Operation *nestedOp) {
220       if (isa<transform::TransformOpInterface>(nestedOp)) {
221         nestedOp->erase();
222         return WalkResult::skip();
223       }
224       return WalkResult::advance();
225     });
226   }
227 };
228 } // namespace
229 
230 namespace mlir {
231 namespace test {
232 /// Registers the test pass for erasing transform dialect ops.
233 void registerTestTransformDialectEraseSchedulePass() {
234   PassRegistration<TestTransformDialectEraseSchedulePass> reg;
235 }
236 /// Registers the test pass for applying transform dialect ops.
237 void registerTestTransformDialectInterpreterPass() {
238   PassRegistration<TestTransformDialectInterpreterPass> reg;
239 }
240 } // namespace test
241 } // namespace mlir
242