xref: /llvm-project/mlir/test/lib/Pass/TestPassManager.cpp (revision 5b21fd298cb4fc2042a95ffb9284b778f8504e04)
1 //===- TestPassManager.cpp - Test pass manager functionality --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Pass/PassManager.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 struct TestModulePass
20     : public PassWrapper<TestModulePass, OperationPass<ModuleOp>> {
21   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestModulePass)
22 
23   void runOnOperation() final {}
24   StringRef getArgument() const final { return "test-module-pass"; }
25   StringRef getDescription() const final {
26     return "Test a module pass in the pass manager";
27   }
28 };
29 struct TestFunctionPass
30     : public PassWrapper<TestFunctionPass, OperationPass<func::FuncOp>> {
31   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFunctionPass)
32 
33   void runOnOperation() final {}
34   StringRef getArgument() const final { return "test-function-pass"; }
35   StringRef getDescription() const final {
36     return "Test a function pass in the pass manager";
37   }
38 };
39 struct TestInterfacePass
40     : public PassWrapper<TestInterfacePass,
41                          InterfacePass<FunctionOpInterface>> {
42   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestInterfacePass)
43 
44   void runOnOperation() final {
45     getOperation()->emitRemark() << "Executing interface pass on operation";
46   }
47   StringRef getArgument() const final { return "test-interface-pass"; }
48   StringRef getDescription() const final {
49     return "Test an interface pass (running on FunctionOpInterface) in the "
50            "pass manager";
51   }
52 };
53 struct TestOptionsPass
54     : public PassWrapper<TestOptionsPass, OperationPass<func::FuncOp>> {
55   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsPass)
56 
57   enum Enum { Zero, One, Two };
58 
59   struct Options : public PassPipelineOptions<Options> {
60     ListOption<int> listOption{*this, "list",
61                                llvm::cl::desc("Example list option")};
62     ListOption<std::string> stringListOption{
63         *this, "string-list", llvm::cl::desc("Example string list option")};
64     Option<std::string> stringOption{*this, "string",
65                                      llvm::cl::desc("Example string option")};
66     Option<Enum> enumOption{
67         *this, "enum", llvm::cl::desc("Example enum option"),
68         llvm::cl::values(clEnumValN(0, "zero", "Example zero value"),
69                          clEnumValN(1, "one", "Example one value"),
70                          clEnumValN(2, "two", "Example two value"))};
71 
72     Options() = default;
73     Options(const Options &rhs) { *this = rhs; }
74     Options &operator=(const Options &rhs) {
75       copyOptionValuesFrom(rhs);
76       return *this;
77     }
78   };
79   TestOptionsPass() = default;
80   TestOptionsPass(const TestOptionsPass &) : PassWrapper() {}
81   TestOptionsPass(const Options &options) {
82     listOption = options.listOption;
83     stringOption = options.stringOption;
84     stringListOption = options.stringListOption;
85     enumOption = options.enumOption;
86   }
87 
88   void runOnOperation() final {}
89   StringRef getArgument() const final { return "test-options-pass"; }
90   StringRef getDescription() const final {
91     return "Test options parsing capabilities";
92   }
93 
94   ListOption<int> listOption{*this, "list",
95                              llvm::cl::desc("Example list option")};
96   ListOption<std::string> stringListOption{
97       *this, "string-list", llvm::cl::desc("Example string list option")};
98   Option<std::string> stringOption{*this, "string",
99                                    llvm::cl::desc("Example string option")};
100   Option<Enum> enumOption{
101       *this, "enum", llvm::cl::desc("Example enum option"),
102       llvm::cl::values(clEnumValN(0, "zero", "Example zero value"),
103                        clEnumValN(1, "one", "Example one value"),
104                        clEnumValN(2, "two", "Example two value"))};
105 };
106 
107 struct TestOptionsSuperPass
108     : public PassWrapper<TestOptionsSuperPass, OperationPass<func::FuncOp>> {
109   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsSuperPass)
110 
111   struct Options : public PassPipelineOptions<Options> {
112     ListOption<TestOptionsPass::Options> listOption{
113         *this, "super-list",
114         llvm::cl::desc("Example list of PassPipelineOptions option")};
115 
116     Options() = default;
117   };
118 
119   TestOptionsSuperPass() = default;
120   TestOptionsSuperPass(const TestOptionsSuperPass &) : PassWrapper() {}
121   TestOptionsSuperPass(const Options &options) {
122     listOption = options.listOption;
123   }
124 
125   void runOnOperation() final {}
126   StringRef getArgument() const final { return "test-options-super-pass"; }
127   StringRef getDescription() const final {
128     return "Test options of options parsing capabilities";
129   }
130 
131   ListOption<TestOptionsPass::Options> listOption{
132       *this, "list",
133       llvm::cl::desc("Example list of PassPipelineOptions option")};
134 };
135 
136 /// A test pass that always aborts to enable testing the crash recovery
137 /// mechanism of the pass manager.
138 struct TestCrashRecoveryPass
139     : public PassWrapper<TestCrashRecoveryPass, OperationPass<>> {
140   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCrashRecoveryPass)
141 
142   void runOnOperation() final { abort(); }
143   StringRef getArgument() const final { return "test-pass-crash"; }
144   StringRef getDescription() const final {
145     return "Test a pass in the pass manager that always crashes";
146   }
147 };
148 
149 /// A test pass that always fails to enable testing the failure recovery
150 /// mechanisms of the pass manager.
151 struct TestFailurePass : public PassWrapper<TestFailurePass, OperationPass<>> {
152   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFailurePass)
153 
154   TestFailurePass() = default;
155   TestFailurePass(const TestFailurePass &other) : PassWrapper(other) {}
156 
157   void runOnOperation() final {
158     signalPassFailure();
159     if (genDiagnostics)
160       mlir::emitError(getOperation()->getLoc(), "illegal operation");
161   }
162   StringRef getArgument() const final { return "test-pass-failure"; }
163   StringRef getDescription() const final {
164     return "Test a pass in the pass manager that always fails";
165   }
166 
167   Option<bool> genDiagnostics{*this, "gen-diagnostics",
168                               llvm::cl::desc("Generate a diagnostic message")};
169 };
170 
171 /// A test pass that creates an invalid operation in a function body.
172 struct TestInvalidIRPass
173     : public PassWrapper<TestInvalidIRPass,
174                          InterfacePass<FunctionOpInterface>> {
175   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestInvalidIRPass)
176 
177   TestInvalidIRPass() = default;
178   TestInvalidIRPass(const TestInvalidIRPass &other) : PassWrapper(other) {}
179 
180   StringRef getArgument() const final { return "test-pass-create-invalid-ir"; }
181   StringRef getDescription() const final {
182     return "Test pass that adds an invalid operation in a function body";
183   }
184   void getDependentDialects(DialectRegistry &registry) const final {
185     registry.insert<test::TestDialect>();
186   }
187   void runOnOperation() final {
188     if (signalFailure)
189       signalPassFailure();
190     if (!emitInvalidIR)
191       return;
192     OpBuilder b(getOperation().getFunctionBody());
193     OperationState state(b.getUnknownLoc(), "test.any_attr_of_i32_str");
194     b.create(state);
195   }
196   Option<bool> signalFailure{*this, "signal-pass-failure",
197                              llvm::cl::desc("Trigger a pass failure")};
198   Option<bool> emitInvalidIR{*this, "emit-invalid-ir", llvm::cl::init(true),
199                              llvm::cl::desc("Emit invalid IR")};
200 };
201 
202 /// A test pass that always fails to enable testing the failure recovery
203 /// mechanisms of the pass manager.
204 struct TestInvalidParentPass
205     : public PassWrapper<TestInvalidParentPass,
206                          InterfacePass<FunctionOpInterface>> {
207   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestInvalidParentPass)
208 
209   StringRef getArgument() const final { return "test-pass-invalid-parent"; }
210   StringRef getDescription() const final {
211     return "Test a pass in the pass manager that makes the parent operation "
212            "invalid";
213   }
214   void getDependentDialects(DialectRegistry &registry) const final {
215     registry.insert<test::TestDialect>();
216   }
217   void runOnOperation() final {
218     FunctionOpInterface op = getOperation();
219     OpBuilder b(op.getFunctionBody());
220     b.create<test::TestCallOp>(op.getLoc(), TypeRange(), "some_unknown_func",
221                                ValueRange());
222   }
223 };
224 
225 /// A test pass that contains a statistic.
226 struct TestStatisticPass
227     : public PassWrapper<TestStatisticPass, OperationPass<>> {
228   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStatisticPass)
229 
230   TestStatisticPass() = default;
231   TestStatisticPass(const TestStatisticPass &) : PassWrapper() {}
232   StringRef getArgument() const final { return "test-stats-pass"; }
233   StringRef getDescription() const final { return "Test pass statistics"; }
234 
235   // Use a couple of statistics to verify their ordering
236   // in the print out. The statistics are registered in the order
237   // of construction, so put "num-ops2" before "num-ops" and
238   // make sure that the order is reversed.
239   Statistic opCountDuplicate{this, "num-ops2",
240                              "Number of operations counted one more time"};
241   Statistic opCount{this, "num-ops", "Number of operations counted"};
242 
243   void runOnOperation() final {
244     getOperation()->walk([&](Operation *) { ++opCount; });
245     getOperation()->walk([&](Operation *) { ++opCountDuplicate; });
246   }
247 };
248 } // namespace
249 
250 static void testNestedPipeline(OpPassManager &pm) {
251   // Nest a module pipeline that contains:
252   /// A module pass.
253   auto &modulePM = pm.nest<ModuleOp>();
254   modulePM.addPass(std::make_unique<TestModulePass>());
255   /// A nested function pass.
256   auto &nestedFunctionPM = modulePM.nest<func::FuncOp>();
257   nestedFunctionPM.addPass(std::make_unique<TestFunctionPass>());
258 
259   // Nest a function pipeline that contains a single pass.
260   auto &functionPM = pm.nest<func::FuncOp>();
261   functionPM.addPass(std::make_unique<TestFunctionPass>());
262 }
263 
264 static void testNestedPipelineTextual(OpPassManager &pm) {
265   (void)parsePassPipeline("test-pm-nested-pipeline", pm);
266 }
267 
268 namespace mlir {
269 void registerPassManagerTestPass() {
270   PassRegistration<TestOptionsPass>();
271   PassRegistration<TestOptionsSuperPass>();
272 
273   PassRegistration<TestModulePass>();
274 
275   PassRegistration<TestFunctionPass>();
276 
277   PassRegistration<TestInterfacePass>();
278 
279   PassRegistration<TestCrashRecoveryPass>();
280   PassRegistration<TestFailurePass>();
281   PassRegistration<TestInvalidIRPass>();
282   PassRegistration<TestInvalidParentPass>();
283 
284   PassRegistration<TestStatisticPass>();
285 
286   PassPipelineRegistration<>("test-pm-nested-pipeline",
287                              "Test a nested pipeline in the pass manager",
288                              testNestedPipeline);
289   PassPipelineRegistration<>("test-textual-pm-nested-pipeline",
290                              "Test a nested pipeline in the pass manager",
291                              testNestedPipelineTextual);
292 
293   PassPipelineRegistration<TestOptionsPass::Options>
294       registerOptionsPassPipeline(
295           "test-options-pass-pipeline",
296           "Parses options using pass pipeline registration",
297           [](OpPassManager &pm, const TestOptionsPass::Options &options) {
298             pm.addPass(std::make_unique<TestOptionsPass>(options));
299           });
300 
301   PassPipelineRegistration<TestOptionsSuperPass::Options>
302       registerOptionsSuperPassPipeline(
303           "test-options-super-pass-pipeline",
304           "Parses options of PassPipelineOptions using pass pipeline "
305           "registration",
306           [](OpPassManager &pm, const TestOptionsSuperPass::Options &options) {
307             pm.addPass(std::make_unique<TestOptionsSuperPass>(options));
308           });
309 }
310 } // namespace mlir
311