xref: /llvm-project/mlir/unittests/Pass/PassManagerTest.cpp (revision 46708a5bcba28955b2ddeddf5c0e64398223642b)
19a4b30cfSRahul Joshi //===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
29a4b30cfSRahul Joshi //
39a4b30cfSRahul Joshi // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49a4b30cfSRahul Joshi // See https://llvm.org/LICENSE.txt for license information.
59a4b30cfSRahul Joshi // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69a4b30cfSRahul Joshi //
79a4b30cfSRahul Joshi //===----------------------------------------------------------------------===//
89a4b30cfSRahul Joshi 
99a4b30cfSRahul Joshi #include "mlir/Pass/PassManager.h"
10*46708a5bSAman LaChapelle #include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
11*46708a5bSAman LaChapelle #include "mlir/Debug/ExecutionContext.h"
1236550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
139a4b30cfSRahul Joshi #include "mlir/IR/Builders.h"
1465fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h"
15809b4403SMehdi Amini #include "mlir/IR/Diagnostics.h"
169a4b30cfSRahul Joshi #include "mlir/Pass/Pass.h"
179a4b30cfSRahul Joshi #include "gtest/gtest.h"
189a4b30cfSRahul Joshi 
19e5639b3fSMehdi Amini #include <memory>
20e5639b3fSMehdi Amini 
219a4b30cfSRahul Joshi using namespace mlir;
229a4b30cfSRahul Joshi using namespace mlir::detail;
239a4b30cfSRahul Joshi 
249a4b30cfSRahul Joshi namespace {
259a4b30cfSRahul Joshi /// Analysis that operates on any operation.
269a4b30cfSRahul Joshi struct GenericAnalysis {
275e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GenericAnalysis)
285e50dd04SRiver Riddle 
GenericAnalysis__anon06cb5f700111::GenericAnalysis2958ceae95SRiver Riddle   GenericAnalysis(Operation *op) : isFunc(isa<func::FuncOp>(op)) {}
309a4b30cfSRahul Joshi   const bool isFunc;
319a4b30cfSRahul Joshi };
329a4b30cfSRahul Joshi 
339a4b30cfSRahul Joshi /// Analysis that operates on a specific operation.
349a4b30cfSRahul Joshi struct OpSpecificAnalysis {
355e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpSpecificAnalysis)
365e50dd04SRiver Riddle 
OpSpecificAnalysis__anon06cb5f700111::OpSpecificAnalysis3758ceae95SRiver Riddle   OpSpecificAnalysis(func::FuncOp op) : isSecret(op.getName() == "secret") {}
389a4b30cfSRahul Joshi   const bool isSecret;
399a4b30cfSRahul Joshi };
409a4b30cfSRahul Joshi 
4158ceae95SRiver Riddle /// Simple pass to annotate a func::FuncOp with the results of analysis.
429a4b30cfSRahul Joshi struct AnnotateFunctionPass
4358ceae95SRiver Riddle     : public PassWrapper<AnnotateFunctionPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon06cb5f700111::AnnotateFunctionPass445e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnnotateFunctionPass)
455e50dd04SRiver Riddle 
469a4b30cfSRahul Joshi   void runOnOperation() override {
4758ceae95SRiver Riddle     func::FuncOp op = getOperation();
48a1eb1544SChristian Sigg     Builder builder(op->getParentOfType<ModuleOp>());
499a4b30cfSRahul Joshi 
509a4b30cfSRahul Joshi     auto &ga = getAnalysis<GenericAnalysis>();
519a4b30cfSRahul Joshi     auto &sa = getAnalysis<OpSpecificAnalysis>();
529a4b30cfSRahul Joshi 
53a1eb1544SChristian Sigg     op->setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
54a1eb1544SChristian Sigg     op->setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
559a4b30cfSRahul Joshi   }
569a4b30cfSRahul Joshi };
579a4b30cfSRahul Joshi 
TEST(PassManagerTest, OpSpecificAnalysis) {
  MLIRContext context;
  context.loadDialect<func::FuncDialect>();
  Builder builder(&context);

  // Create a module with 2 functions.
  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
  for (StringRef name : {"secret", "not_secret"}) {
    auto func = func::FuncOp::create(
        builder.getUnknownLoc(), name,
        builder.getFunctionType(std::nullopt, std::nullopt));
    func.setPrivate();
    module->push_back(func);
  }

  // Instantiate and run our pass.
  auto pm = PassManager::on<ModuleOp>(&context);
  pm.addNestedPass<func::FuncOp>(std::make_unique<AnnotateFunctionPass>());
  LogicalResult result = pm.run(module.get());
  EXPECT_TRUE(succeeded(result));

  // Verify that each function got annotated with expected attributes. The
  // attributes are fetched with dyn_cast_or_null so that a missing attribute
  // (a null Attribute) produces a gtest failure; isa<>/cast<> on a null
  // Attribute is not supported by the casting machinery and may assert.
  unsigned numFuncs = 0;
  for (func::FuncOp func : module->getOps<func::FuncOp>()) {
    ++numFuncs;

    auto isFuncAttr =
        dyn_cast_or_null<BoolAttr>(func->getDiscardableAttr("isFunc"));
    ASSERT_TRUE(isFuncAttr);
    EXPECT_TRUE(isFuncAttr.getValue());

    bool isSecret = func.getName() == "secret";
    auto isSecretAttr =
        dyn_cast_or_null<BoolAttr>(func->getDiscardableAttr("isSecret"));
    ASSERT_TRUE(isSecretAttr);
    EXPECT_EQ(isSecretAttr.getValue(), isSecret);
  }
  // Guard against the verification loop above passing vacuously.
  EXPECT_EQ(numFuncs, 2u);
}
909a4b30cfSRahul Joshi 
91*46708a5bSAman LaChapelle /// Simple pass to annotate a func::FuncOp with a single attribute `didProcess`.
92*46708a5bSAman LaChapelle struct AddAttrFunctionPass
93*46708a5bSAman LaChapelle     : public PassWrapper<AddAttrFunctionPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon06cb5f700111::AddAttrFunctionPass94*46708a5bSAman LaChapelle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddAttrFunctionPass)
95*46708a5bSAman LaChapelle 
96*46708a5bSAman LaChapelle   void runOnOperation() override {
97*46708a5bSAman LaChapelle     func::FuncOp op = getOperation();
98*46708a5bSAman LaChapelle     Builder builder(op->getParentOfType<ModuleOp>());
99*46708a5bSAman LaChapelle     if (op->hasAttr("didProcess"))
100*46708a5bSAman LaChapelle       op->setAttr("didProcessAgain", builder.getUnitAttr());
101*46708a5bSAman LaChapelle 
102*46708a5bSAman LaChapelle     // We always want to set this one.
103*46708a5bSAman LaChapelle     op->setAttr("didProcess", builder.getUnitAttr());
104*46708a5bSAman LaChapelle   }
105*46708a5bSAman LaChapelle };
106*46708a5bSAman LaChapelle 
107*46708a5bSAman LaChapelle /// Simple pass to annotate a func::FuncOp with a single attribute
108*46708a5bSAman LaChapelle /// `didProcess2`.
109*46708a5bSAman LaChapelle struct AddSecondAttrFunctionPass
110*46708a5bSAman LaChapelle     : public PassWrapper<AddSecondAttrFunctionPass,
111*46708a5bSAman LaChapelle                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon06cb5f700111::AddSecondAttrFunctionPass112*46708a5bSAman LaChapelle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSecondAttrFunctionPass)
113*46708a5bSAman LaChapelle 
114*46708a5bSAman LaChapelle   void runOnOperation() override {
115*46708a5bSAman LaChapelle     func::FuncOp op = getOperation();
116*46708a5bSAman LaChapelle     Builder builder(op->getParentOfType<ModuleOp>());
117*46708a5bSAman LaChapelle     op->setAttr("didProcess2", builder.getUnitAttr());
118*46708a5bSAman LaChapelle   }
119*46708a5bSAman LaChapelle };
120*46708a5bSAman LaChapelle 
TEST(PassManagerTest,ExecutionAction)121*46708a5bSAman LaChapelle TEST(PassManagerTest, ExecutionAction) {
122*46708a5bSAman LaChapelle   MLIRContext context;
123*46708a5bSAman LaChapelle   context.loadDialect<func::FuncDialect>();
124*46708a5bSAman LaChapelle   Builder builder(&context);
125*46708a5bSAman LaChapelle 
126*46708a5bSAman LaChapelle   // Create a module with 2 functions.
127*46708a5bSAman LaChapelle   OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
128*46708a5bSAman LaChapelle   auto f =
129*46708a5bSAman LaChapelle       func::FuncOp::create(builder.getUnknownLoc(), "process_me_once",
130*46708a5bSAman LaChapelle                            builder.getFunctionType(std::nullopt, std::nullopt));
131*46708a5bSAman LaChapelle   f.setPrivate();
132*46708a5bSAman LaChapelle   module->push_back(f);
133*46708a5bSAman LaChapelle 
134*46708a5bSAman LaChapelle   // Instantiate our passes.
135*46708a5bSAman LaChapelle   auto pm = PassManager::on<ModuleOp>(&context);
136*46708a5bSAman LaChapelle   auto pass = std::make_unique<AddAttrFunctionPass>();
137*46708a5bSAman LaChapelle   auto *passPtr = pass.get();
138*46708a5bSAman LaChapelle   pm.addNestedPass<func::FuncOp>(std::move(pass));
139*46708a5bSAman LaChapelle   pm.addNestedPass<func::FuncOp>(std::make_unique<AddSecondAttrFunctionPass>());
140*46708a5bSAman LaChapelle   // Duplicate the first pass to ensure that we *only* run the *first* pass, not
141*46708a5bSAman LaChapelle   // all instances of this pass kind. Notice that this pass (and the test as a
142*46708a5bSAman LaChapelle   // whole) are built to ensure that we can run just a single pass out of a
143*46708a5bSAman LaChapelle   // pipeline that may contain duplicates.
144*46708a5bSAman LaChapelle   pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
145*46708a5bSAman LaChapelle 
146*46708a5bSAman LaChapelle   // Use the action manager to only hit the first pass, not the second one.
147*46708a5bSAman LaChapelle   auto onBreakpoint = [&](const tracing::ActionActiveStack *backtrace)
148*46708a5bSAman LaChapelle       -> tracing::ExecutionContext::Control {
149*46708a5bSAman LaChapelle     // Not a PassExecutionAction, apply the action.
150*46708a5bSAman LaChapelle     auto *passExec = dyn_cast<PassExecutionAction>(&backtrace->getAction());
151*46708a5bSAman LaChapelle     if (!passExec)
152*46708a5bSAman LaChapelle       return tracing::ExecutionContext::Next;
153*46708a5bSAman LaChapelle 
154*46708a5bSAman LaChapelle     // If this isn't a function, apply the action.
155*46708a5bSAman LaChapelle     if (!isa<func::FuncOp>(passExec->getOp()))
156*46708a5bSAman LaChapelle       return tracing::ExecutionContext::Next;
157*46708a5bSAman LaChapelle 
158*46708a5bSAman LaChapelle     // Only apply the first function pass. Not all instances of the first pass,
159*46708a5bSAman LaChapelle     // only the first pass.
160*46708a5bSAman LaChapelle     if (passExec->getPass().getThreadingSiblingOrThis() == passPtr)
161*46708a5bSAman LaChapelle       return tracing::ExecutionContext::Next;
162*46708a5bSAman LaChapelle 
163*46708a5bSAman LaChapelle     // Do not apply any other passes in the pass manager.
164*46708a5bSAman LaChapelle     return tracing::ExecutionContext::Skip;
165*46708a5bSAman LaChapelle   };
166*46708a5bSAman LaChapelle 
167*46708a5bSAman LaChapelle   // Set up our breakpoint manager.
168*46708a5bSAman LaChapelle   tracing::TagBreakpointManager simpleManager;
169*46708a5bSAman LaChapelle   tracing::ExecutionContext executionCtx(onBreakpoint);
170*46708a5bSAman LaChapelle   executionCtx.addBreakpointManager(&simpleManager);
171*46708a5bSAman LaChapelle   simpleManager.addBreakpoint(PassExecutionAction::tag);
172*46708a5bSAman LaChapelle 
173*46708a5bSAman LaChapelle   // Register the execution context in the MLIRContext.
174*46708a5bSAman LaChapelle   context.registerActionHandler(executionCtx);
175*46708a5bSAman LaChapelle 
176*46708a5bSAman LaChapelle   // Run the pass manager, expecting our handler to be called.
177*46708a5bSAman LaChapelle   LogicalResult result = pm.run(module.get());
178*46708a5bSAman LaChapelle   EXPECT_TRUE(succeeded(result));
179*46708a5bSAman LaChapelle 
180*46708a5bSAman LaChapelle   // Verify that each function got annotated with `didProcess` and *not*
181*46708a5bSAman LaChapelle   // `didProcess2`.
182*46708a5bSAman LaChapelle   for (func::FuncOp func : module->getOps<func::FuncOp>()) {
183*46708a5bSAman LaChapelle     ASSERT_TRUE(func->getDiscardableAttr("didProcess"));
184*46708a5bSAman LaChapelle     ASSERT_FALSE(func->getDiscardableAttr("didProcess2"));
185*46708a5bSAman LaChapelle     ASSERT_FALSE(func->getDiscardableAttr("didProcessAgain"));
186*46708a5bSAman LaChapelle   }
187*46708a5bSAman LaChapelle }
188*46708a5bSAman LaChapelle 
1891284dc34SMehdi Amini namespace {
1901284dc34SMehdi Amini struct InvalidPass : Pass {
1915e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass)
1925e50dd04SRiver Riddle 
InvalidPass__anon06cb5f700111::__anon06cb5f700311::InvalidPass1931284dc34SMehdi Amini   InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
getName__anon06cb5f700111::__anon06cb5f700311::InvalidPass1941284dc34SMehdi Amini   StringRef getName() const override { return "Invalid Pass"; }
runOnOperation__anon06cb5f700111::__anon06cb5f700311::InvalidPass1951284dc34SMehdi Amini   void runOnOperation() override {}
canScheduleOn__anon06cb5f700111::__anon06cb5f700311::InvalidPass1969c9a4317SRiver Riddle   bool canScheduleOn(RegisteredOperationName opName) const override {
1979c9a4317SRiver Riddle     return true;
1989c9a4317SRiver Riddle   }
1991284dc34SMehdi Amini 
2001284dc34SMehdi Amini   /// A clone method to create a copy of this pass.
clonePass__anon06cb5f700111::__anon06cb5f700311::InvalidPass2011284dc34SMehdi Amini   std::unique_ptr<Pass> clonePass() const override {
2021284dc34SMehdi Amini     return std::make_unique<InvalidPass>(
2031284dc34SMehdi Amini         *static_cast<const InvalidPass *>(this));
2041284dc34SMehdi Amini   }
2051284dc34SMehdi Amini };
206be0a7e9fSMehdi Amini } // namespace
2071284dc34SMehdi Amini 
TEST(PassManagerTest, InvalidPass) {
  MLIRContext context;
  context.allowUnregisteredDialects();
  Location loc = UnknownLoc::get(&context);

  // Build a module whose body holds a single unregistered "invalid_op".
  OwningOpRef<ModuleOp> module(ModuleOp::create(loc));
  OpBuilder builder(&module->getBodyRegion());
  OperationState invalidOpState(loc, "invalid_op");
  builder.insert(Operation::create(invalidOpState));

  // Keep the most recently emitted diagnostic so it can be inspected below.
  std::unique_ptr<Diagnostic> lastDiagnostic;
  context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
    lastDiagnostic = std::make_unique<Diagnostic>(std::move(diag));
  });

  // Nesting the pass under the unregistered "invalid_op" must make the run
  // fail with a diagnostic.
  auto pm = PassManager::on<ModuleOp>(&context);
  pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>());
  EXPECT_TRUE(failed(pm.run(module.get())));
  ASSERT_TRUE(lastDiagnostic != nullptr);
  EXPECT_EQ(
      lastDiagnostic->str(),
      "'invalid_op' op trying to schedule a pass on an unregistered operation");

  // After clearing, the pass manager holds no passes and the run succeeds.
  pm.clear();
  EXPECT_TRUE(succeeded(pm.run(module.get())));

  // Adding the pass directly at the top level is a fatal error.
  ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()),
               "Can't add pass 'Invalid Pass' restricted to 'invalid_op' on a "
               "PassManager intended to run on 'builtin.module', did you "
               "intend to nest?");
}
2481284dc34SMehdi Amini 
249809b4403SMehdi Amini /// Simple pass to annotate a func::FuncOp with the results of analysis.
250809b4403SMehdi Amini struct InitializeCheckingPass
251809b4403SMehdi Amini     : public PassWrapper<InitializeCheckingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon06cb5f700111::InitializeCheckingPass252809b4403SMehdi Amini   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InitializeCheckingPass)
253809b4403SMehdi Amini   LogicalResult initialize(MLIRContext *ctx) final {
254809b4403SMehdi Amini     initialized = true;
255809b4403SMehdi Amini     return success();
256809b4403SMehdi Amini   }
257809b4403SMehdi Amini   bool initialized = false;
258809b4403SMehdi Amini 
runOnOperation__anon06cb5f700111::InitializeCheckingPass259809b4403SMehdi Amini   void runOnOperation() override {
260809b4403SMehdi Amini     if (!initialized) {
261809b4403SMehdi Amini       getOperation()->emitError() << "Pass isn't initialized!";
262809b4403SMehdi Amini       signalPassFailure();
263809b4403SMehdi Amini     }
264809b4403SMehdi Amini   }
265809b4403SMehdi Amini };
266809b4403SMehdi Amini 
TEST(PassManagerTest, PassInitialization) {
  MLIRContext context;
  context.allowUnregisteredDialects();

  // An empty module suffices: the pass only checks its own initialization.
  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));

  // A single instance must be initialized before it runs.
  auto pm = PassManager::on<ModuleOp>(&context);
  pm.addPass(std::make_unique<InitializeCheckingPass>());
  LogicalResult firstRun = pm.run(module.get());
  EXPECT_TRUE(succeeded(firstRun));

  // A pass appended after an earlier run must also get initialized before the
  // next run.
  pm.addPass(std::make_unique<InitializeCheckingPass>());
  LogicalResult secondRun = pm.run(module.get());
  EXPECT_TRUE(succeeded(secondRun));
}
283809b4403SMehdi Amini 
284be0a7e9fSMehdi Amini } // namespace
285