xref: /llvm-project/mlir/unittests/Pass/PassManagerTest.cpp (revision 46708a5bcba28955b2ddeddf5c0e64398223642b)
1 //===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Pass/PassManager.h"
10 #include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
11 #include "mlir/Debug/ExecutionContext.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/Pass/Pass.h"
17 #include "gtest/gtest.h"
18 
19 #include <memory>
20 
21 using namespace mlir;
22 using namespace mlir::detail;
23 
24 namespace {
25 /// Analysis that operates on any operation.
26 struct GenericAnalysis {
27   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GenericAnalysis)
28 
GenericAnalysis__anon06cb5f700111::GenericAnalysis29   GenericAnalysis(Operation *op) : isFunc(isa<func::FuncOp>(op)) {}
30   const bool isFunc;
31 };
32 
33 /// Analysis that operates on a specific operation.
34 struct OpSpecificAnalysis {
35   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpSpecificAnalysis)
36 
OpSpecificAnalysis__anon06cb5f700111::OpSpecificAnalysis37   OpSpecificAnalysis(func::FuncOp op) : isSecret(op.getName() == "secret") {}
38   const bool isSecret;
39 };
40 
41 /// Simple pass to annotate a func::FuncOp with the results of analysis.
42 struct AnnotateFunctionPass
43     : public PassWrapper<AnnotateFunctionPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon06cb5f700111::AnnotateFunctionPass44   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnnotateFunctionPass)
45 
46   void runOnOperation() override {
47     func::FuncOp op = getOperation();
48     Builder builder(op->getParentOfType<ModuleOp>());
49 
50     auto &ga = getAnalysis<GenericAnalysis>();
51     auto &sa = getAnalysis<OpSpecificAnalysis>();
52 
53     op->setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
54     op->setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
55   }
56 };
57 
TEST(PassManagerTest, OpSpecificAnalysis) {
  MLIRContext context;
  context.loadDialect<func::FuncDialect>();
  Builder builder(&context);

  // Populate a module with two private, empty functions. Only one of them is
  // named "secret".
  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
  for (StringRef fnName : {"secret", "not_secret"}) {
    FunctionType fnType = builder.getFunctionType(std::nullopt, std::nullopt);
    auto fn = func::FuncOp::create(builder.getUnknownLoc(), fnName, fnType);
    fn.setPrivate();
    module->push_back(fn);
  }

  // Schedule the annotation pass on every nested function and run it.
  auto pm = PassManager::on<ModuleOp>(&context);
  pm.addNestedPass<func::FuncOp>(std::make_unique<AnnotateFunctionPass>());
  EXPECT_TRUE(succeeded(pm.run(module.get())));

  // Every function must carry both bool attributes. `isFunc` is always true;
  // `isSecret` is true only for the function named "secret".
  for (func::FuncOp fn : module->getOps<func::FuncOp>()) {
    Attribute isFuncAttr = fn->getDiscardableAttr("isFunc");
    ASSERT_TRUE(isa<BoolAttr>(isFuncAttr));
    EXPECT_TRUE(cast<BoolAttr>(isFuncAttr).getValue());

    const bool expectSecret = fn.getName() == "secret";
    Attribute isSecretAttr = fn->getDiscardableAttr("isSecret");
    ASSERT_TRUE(isa<BoolAttr>(isSecretAttr));
    EXPECT_EQ(cast<BoolAttr>(isSecretAttr).getValue(), expectSecret);
  }
}
90 
91 /// Simple pass to annotate a func::FuncOp with a single attribute `didProcess`.
92 struct AddAttrFunctionPass
93     : public PassWrapper<AddAttrFunctionPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon06cb5f700111::AddAttrFunctionPass94   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddAttrFunctionPass)
95 
96   void runOnOperation() override {
97     func::FuncOp op = getOperation();
98     Builder builder(op->getParentOfType<ModuleOp>());
99     if (op->hasAttr("didProcess"))
100       op->setAttr("didProcessAgain", builder.getUnitAttr());
101 
102     // We always want to set this one.
103     op->setAttr("didProcess", builder.getUnitAttr());
104   }
105 };
106 
107 /// Simple pass to annotate a func::FuncOp with a single attribute
108 /// `didProcess2`.
109 struct AddSecondAttrFunctionPass
110     : public PassWrapper<AddSecondAttrFunctionPass,
111                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon06cb5f700111::AddSecondAttrFunctionPass112   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSecondAttrFunctionPass)
113 
114   void runOnOperation() override {
115     func::FuncOp op = getOperation();
116     Builder builder(op->getParentOfType<ModuleOp>());
117     op->setAttr("didProcess2", builder.getUnitAttr());
118   }
119 };
120 
TEST(PassManagerTest,ExecutionAction)121 TEST(PassManagerTest, ExecutionAction) {
122   MLIRContext context;
123   context.loadDialect<func::FuncDialect>();
124   Builder builder(&context);
125 
126   // Create a module with 2 functions.
127   OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
128   auto f =
129       func::FuncOp::create(builder.getUnknownLoc(), "process_me_once",
130                            builder.getFunctionType(std::nullopt, std::nullopt));
131   f.setPrivate();
132   module->push_back(f);
133 
134   // Instantiate our passes.
135   auto pm = PassManager::on<ModuleOp>(&context);
136   auto pass = std::make_unique<AddAttrFunctionPass>();
137   auto *passPtr = pass.get();
138   pm.addNestedPass<func::FuncOp>(std::move(pass));
139   pm.addNestedPass<func::FuncOp>(std::make_unique<AddSecondAttrFunctionPass>());
140   // Duplicate the first pass to ensure that we *only* run the *first* pass, not
141   // all instances of this pass kind. Notice that this pass (and the test as a
142   // whole) are built to ensure that we can run just a single pass out of a
143   // pipeline that may contain duplicates.
144   pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
145 
146   // Use the action manager to only hit the first pass, not the second one.
147   auto onBreakpoint = [&](const tracing::ActionActiveStack *backtrace)
148       -> tracing::ExecutionContext::Control {
149     // Not a PassExecutionAction, apply the action.
150     auto *passExec = dyn_cast<PassExecutionAction>(&backtrace->getAction());
151     if (!passExec)
152       return tracing::ExecutionContext::Next;
153 
154     // If this isn't a function, apply the action.
155     if (!isa<func::FuncOp>(passExec->getOp()))
156       return tracing::ExecutionContext::Next;
157 
158     // Only apply the first function pass. Not all instances of the first pass,
159     // only the first pass.
160     if (passExec->getPass().getThreadingSiblingOrThis() == passPtr)
161       return tracing::ExecutionContext::Next;
162 
163     // Do not apply any other passes in the pass manager.
164     return tracing::ExecutionContext::Skip;
165   };
166 
167   // Set up our breakpoint manager.
168   tracing::TagBreakpointManager simpleManager;
169   tracing::ExecutionContext executionCtx(onBreakpoint);
170   executionCtx.addBreakpointManager(&simpleManager);
171   simpleManager.addBreakpoint(PassExecutionAction::tag);
172 
173   // Register the execution context in the MLIRContext.
174   context.registerActionHandler(executionCtx);
175 
176   // Run the pass manager, expecting our handler to be called.
177   LogicalResult result = pm.run(module.get());
178   EXPECT_TRUE(succeeded(result));
179 
180   // Verify that each function got annotated with `didProcess` and *not*
181   // `didProcess2`.
182   for (func::FuncOp func : module->getOps<func::FuncOp>()) {
183     ASSERT_TRUE(func->getDiscardableAttr("didProcess"));
184     ASSERT_FALSE(func->getDiscardableAttr("didProcess2"));
185     ASSERT_FALSE(func->getDiscardableAttr("didProcessAgain"));
186   }
187 }
188 
189 namespace {
190 struct InvalidPass : Pass {
191   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass)
192 
InvalidPass__anon06cb5f700111::__anon06cb5f700311::InvalidPass193   InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
getName__anon06cb5f700111::__anon06cb5f700311::InvalidPass194   StringRef getName() const override { return "Invalid Pass"; }
runOnOperation__anon06cb5f700111::__anon06cb5f700311::InvalidPass195   void runOnOperation() override {}
canScheduleOn__anon06cb5f700111::__anon06cb5f700311::InvalidPass196   bool canScheduleOn(RegisteredOperationName opName) const override {
197     return true;
198   }
199 
200   /// A clone method to create a copy of this pass.
clonePass__anon06cb5f700111::__anon06cb5f700311::InvalidPass201   std::unique_ptr<Pass> clonePass() const override {
202     return std::make_unique<InvalidPass>(
203         *static_cast<const InvalidPass *>(this));
204   }
205 };
206 } // namespace
207 
TEST(PassManagerTest, InvalidPass) {
  MLIRContext ctx;
  ctx.allowUnregisteredDialects();
  Location unknownLoc = UnknownLoc::get(&ctx);

  // Build a module whose body holds one unregistered "invalid_op" operation.
  OwningOpRef<ModuleOp> module(ModuleOp::create(unknownLoc));
  OpBuilder opBuilder(&module->getBodyRegion());
  OperationState invalidOpState(unknownLoc, "invalid_op");
  opBuilder.insert(Operation::create(invalidOpState));

  // Keep the most recently emitted diagnostic so it can be inspected below.
  std::unique_ptr<Diagnostic> capturedDiag;
  ctx.getDiagEngine().registerHandler([&](Diagnostic &diag) {
    capturedDiag = std::make_unique<Diagnostic>(std::move(diag));
  });

  // A pipeline nested on the unregistered op must fail and emit a diagnostic.
  auto pm = PassManager::on<ModuleOp>(&ctx);
  pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>());
  EXPECT_TRUE(failed(pm.run(module.get())));
  ASSERT_TRUE(capturedDiag != nullptr);
  EXPECT_EQ(
      capturedDiag->str(),
      "'invalid_op' op trying to schedule a pass on an unregistered operation");

  // Once the pass manager is cleared, nothing is scheduled and the run
  // succeeds.
  pm.clear();
  EXPECT_TRUE(succeeded(pm.run(module.get())));

  // Adding the pass directly at the top level is a fatal error.
  ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()),
               "Can't add pass 'Invalid Pass' restricted to 'invalid_op' on a "
               "PassManager intended to run on 'builtin.module', did you "
               "intend to nest?");
}
248 
249 /// Simple pass to annotate a func::FuncOp with the results of analysis.
250 struct InitializeCheckingPass
251     : public PassWrapper<InitializeCheckingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon06cb5f700111::InitializeCheckingPass252   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InitializeCheckingPass)
253   LogicalResult initialize(MLIRContext *ctx) final {
254     initialized = true;
255     return success();
256   }
257   bool initialized = false;
258 
runOnOperation__anon06cb5f700111::InitializeCheckingPass259   void runOnOperation() override {
260     if (!initialized) {
261       getOperation()->emitError() << "Pass isn't initialized!";
262       signalPassFailure();
263     }
264   }
265 };
266 
TEST(PassManagerTest, PassInitialization) {
  MLIRContext context;
  context.allowUnregisteredDialects();

  // An empty module is enough: the pass only inspects its own initialization
  // state.
  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
  auto pm = PassManager::on<ModuleOp>(&context);

  // Run a pipeline containing a single initialization-checking pass.
  pm.addPass(std::make_unique<InitializeCheckingPass>());
  LogicalResult firstRun = pm.run(module.get());
  EXPECT_TRUE(succeeded(firstRun));

  // After the pipeline grows by another instance, that new instance must also
  // be initialized before it runs.
  pm.addPass(std::make_unique<InitializeCheckingPass>());
  LogicalResult secondRun = pm.run(module.get());
  EXPECT_TRUE(succeeded(secondRun));
}
283 
284 } // namespace
285