19a4b30cfSRahul Joshi //===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
29a4b30cfSRahul Joshi //
39a4b30cfSRahul Joshi // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49a4b30cfSRahul Joshi // See https://llvm.org/LICENSE.txt for license information.
59a4b30cfSRahul Joshi // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69a4b30cfSRahul Joshi //
79a4b30cfSRahul Joshi //===----------------------------------------------------------------------===//
89a4b30cfSRahul Joshi
99a4b30cfSRahul Joshi #include "mlir/Pass/PassManager.h"
10*46708a5bSAman LaChapelle #include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
11*46708a5bSAman LaChapelle #include "mlir/Debug/ExecutionContext.h"
1236550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
139a4b30cfSRahul Joshi #include "mlir/IR/Builders.h"
1465fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h"
15809b4403SMehdi Amini #include "mlir/IR/Diagnostics.h"
169a4b30cfSRahul Joshi #include "mlir/Pass/Pass.h"
179a4b30cfSRahul Joshi #include "gtest/gtest.h"
189a4b30cfSRahul Joshi
19e5639b3fSMehdi Amini #include <memory>
20e5639b3fSMehdi Amini
219a4b30cfSRahul Joshi using namespace mlir;
229a4b30cfSRahul Joshi using namespace mlir::detail;
239a4b30cfSRahul Joshi
249a4b30cfSRahul Joshi namespace {
259a4b30cfSRahul Joshi /// Analysis that operates on any operation.
269a4b30cfSRahul Joshi struct GenericAnalysis {
275e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GenericAnalysis)
285e50dd04SRiver Riddle
GenericAnalysis__anon06cb5f700111::GenericAnalysis2958ceae95SRiver Riddle GenericAnalysis(Operation *op) : isFunc(isa<func::FuncOp>(op)) {}
309a4b30cfSRahul Joshi const bool isFunc;
319a4b30cfSRahul Joshi };
329a4b30cfSRahul Joshi
339a4b30cfSRahul Joshi /// Analysis that operates on a specific operation.
349a4b30cfSRahul Joshi struct OpSpecificAnalysis {
355e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpSpecificAnalysis)
365e50dd04SRiver Riddle
OpSpecificAnalysis__anon06cb5f700111::OpSpecificAnalysis3758ceae95SRiver Riddle OpSpecificAnalysis(func::FuncOp op) : isSecret(op.getName() == "secret") {}
389a4b30cfSRahul Joshi const bool isSecret;
399a4b30cfSRahul Joshi };
409a4b30cfSRahul Joshi
4158ceae95SRiver Riddle /// Simple pass to annotate a func::FuncOp with the results of analysis.
429a4b30cfSRahul Joshi struct AnnotateFunctionPass
4358ceae95SRiver Riddle : public PassWrapper<AnnotateFunctionPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon06cb5f700111::AnnotateFunctionPass445e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnnotateFunctionPass)
455e50dd04SRiver Riddle
469a4b30cfSRahul Joshi void runOnOperation() override {
4758ceae95SRiver Riddle func::FuncOp op = getOperation();
48a1eb1544SChristian Sigg Builder builder(op->getParentOfType<ModuleOp>());
499a4b30cfSRahul Joshi
509a4b30cfSRahul Joshi auto &ga = getAnalysis<GenericAnalysis>();
519a4b30cfSRahul Joshi auto &sa = getAnalysis<OpSpecificAnalysis>();
529a4b30cfSRahul Joshi
53a1eb1544SChristian Sigg op->setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
54a1eb1544SChristian Sigg op->setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
559a4b30cfSRahul Joshi }
569a4b30cfSRahul Joshi };
579a4b30cfSRahul Joshi
TEST(PassManagerTest, OpSpecificAnalysis) {
  MLIRContext context;
  context.loadDialect<func::FuncDialect>();
  Builder builder(&context);

  // Populate a fresh module with two private, empty functions: one named
  // "secret" and one named "not_secret".
  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
  for (StringRef funcName : {"secret", "not_secret"}) {
    func::FuncOp fn = func::FuncOp::create(
        builder.getUnknownLoc(), funcName,
        builder.getFunctionType(std::nullopt, std::nullopt));
    fn.setPrivate();
    module->push_back(fn);
  }

  // Schedule the annotating pass on every nested func::FuncOp and run it.
  auto pm = PassManager::on<ModuleOp>(&context);
  pm.addNestedPass<func::FuncOp>(std::make_unique<AnnotateFunctionPass>());
  EXPECT_TRUE(succeeded(pm.run(module.get())));

  // Each function must carry boolean `isFunc` and `isSecret` attributes;
  // `isFunc` is always true and `isSecret` is true only for "secret".
  for (func::FuncOp fn : module->getOps<func::FuncOp>()) {
    Attribute isFuncAttr = fn->getDiscardableAttr("isFunc");
    ASSERT_TRUE(isa<BoolAttr>(isFuncAttr));
    EXPECT_TRUE(cast<BoolAttr>(isFuncAttr).getValue());

    Attribute isSecretAttr = fn->getDiscardableAttr("isSecret");
    ASSERT_TRUE(isa<BoolAttr>(isSecretAttr));
    EXPECT_EQ(cast<BoolAttr>(isSecretAttr).getValue(),
              fn.getName() == "secret");
  }
}
909a4b30cfSRahul Joshi
91*46708a5bSAman LaChapelle /// Simple pass to annotate a func::FuncOp with a single attribute `didProcess`.
92*46708a5bSAman LaChapelle struct AddAttrFunctionPass
93*46708a5bSAman LaChapelle : public PassWrapper<AddAttrFunctionPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon06cb5f700111::AddAttrFunctionPass94*46708a5bSAman LaChapelle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddAttrFunctionPass)
95*46708a5bSAman LaChapelle
96*46708a5bSAman LaChapelle void runOnOperation() override {
97*46708a5bSAman LaChapelle func::FuncOp op = getOperation();
98*46708a5bSAman LaChapelle Builder builder(op->getParentOfType<ModuleOp>());
99*46708a5bSAman LaChapelle if (op->hasAttr("didProcess"))
100*46708a5bSAman LaChapelle op->setAttr("didProcessAgain", builder.getUnitAttr());
101*46708a5bSAman LaChapelle
102*46708a5bSAman LaChapelle // We always want to set this one.
103*46708a5bSAman LaChapelle op->setAttr("didProcess", builder.getUnitAttr());
104*46708a5bSAman LaChapelle }
105*46708a5bSAman LaChapelle };
106*46708a5bSAman LaChapelle
107*46708a5bSAman LaChapelle /// Simple pass to annotate a func::FuncOp with a single attribute
108*46708a5bSAman LaChapelle /// `didProcess2`.
109*46708a5bSAman LaChapelle struct AddSecondAttrFunctionPass
110*46708a5bSAman LaChapelle : public PassWrapper<AddSecondAttrFunctionPass,
111*46708a5bSAman LaChapelle OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon06cb5f700111::AddSecondAttrFunctionPass112*46708a5bSAman LaChapelle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSecondAttrFunctionPass)
113*46708a5bSAman LaChapelle
114*46708a5bSAman LaChapelle void runOnOperation() override {
115*46708a5bSAman LaChapelle func::FuncOp op = getOperation();
116*46708a5bSAman LaChapelle Builder builder(op->getParentOfType<ModuleOp>());
117*46708a5bSAman LaChapelle op->setAttr("didProcess2", builder.getUnitAttr());
118*46708a5bSAman LaChapelle }
119*46708a5bSAman LaChapelle };
120*46708a5bSAman LaChapelle
TEST(PassManagerTest, ExecutionAction) {
  MLIRContext context;
  context.loadDialect<func::FuncDialect>();
  Builder builder(&context);

  // Create a module with a single function.
  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
  auto f =
      func::FuncOp::create(builder.getUnknownLoc(), "process_me_once",
                           builder.getFunctionType(std::nullopt, std::nullopt));
  f.setPrivate();
  module->push_back(f);

  // Instantiate our passes.
  auto pm = PassManager::on<ModuleOp>(&context);
  auto pass = std::make_unique<AddAttrFunctionPass>();
  auto *passPtr = pass.get();
  pm.addNestedPass<func::FuncOp>(std::move(pass));
  pm.addNestedPass<func::FuncOp>(std::make_unique<AddSecondAttrFunctionPass>());
  // Duplicate the first pass to ensure that we *only* run the *first* pass, not
  // all instances of this pass kind. Notice that this pass (and the test as a
  // whole) are built to ensure that we can run just a single pass out of a
  // pipeline that may contain duplicates.
  pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());

  // Use the action manager to only hit the first pass, not the second one.
  auto onBreakpoint = [&](const tracing::ActionActiveStack *backtrace)
      -> tracing::ExecutionContext::Control {
    // Not a PassExecutionAction, apply the action.
    auto *passExec = dyn_cast<PassExecutionAction>(&backtrace->getAction());
    if (!passExec)
      return tracing::ExecutionContext::Next;

    // If this isn't a function, apply the action.
    if (!isa<func::FuncOp>(passExec->getOp()))
      return tracing::ExecutionContext::Next;

    // Only apply the first function pass. Not all instances of the first pass,
    // only the first pass.
    if (passExec->getPass().getThreadingSiblingOrThis() == passPtr)
      return tracing::ExecutionContext::Next;

    // Do not apply any other passes in the pass manager.
    return tracing::ExecutionContext::Skip;
  };

  // Set up our breakpoint manager.
  tracing::TagBreakpointManager simpleManager;
  tracing::ExecutionContext executionCtx(onBreakpoint);
  executionCtx.addBreakpointManager(&simpleManager);
  simpleManager.addBreakpoint(PassExecutionAction::tag);

  // Register the execution context in the MLIRContext.
  context.registerActionHandler(executionCtx);

  // Run the pass manager, expecting our handler to be called.
  LogicalResult result = pm.run(module.get());
  EXPECT_TRUE(succeeded(result));

  // Verify that the function got annotated with `didProcess` only: neither the
  // second pass (`didProcess2`) nor the duplicate first pass
  // (`didProcessAgain`) may have run.
  unsigned numFuncs = 0;
  for (func::FuncOp func : module->getOps<func::FuncOp>()) {
    ++numFuncs;
    ASSERT_TRUE(func->getDiscardableAttr("didProcess"));
    ASSERT_FALSE(func->getDiscardableAttr("didProcess2"));
    ASSERT_FALSE(func->getDiscardableAttr("didProcessAgain"));
  }
  // Guard against the checks above passing vacuously.
  EXPECT_EQ(numFuncs, 1u);
}
188*46708a5bSAman LaChapelle
1891284dc34SMehdi Amini namespace {
1901284dc34SMehdi Amini struct InvalidPass : Pass {
1915e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass)
1925e50dd04SRiver Riddle
InvalidPass__anon06cb5f700111::__anon06cb5f700311::InvalidPass1931284dc34SMehdi Amini InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
getName__anon06cb5f700111::__anon06cb5f700311::InvalidPass1941284dc34SMehdi Amini StringRef getName() const override { return "Invalid Pass"; }
runOnOperation__anon06cb5f700111::__anon06cb5f700311::InvalidPass1951284dc34SMehdi Amini void runOnOperation() override {}
canScheduleOn__anon06cb5f700111::__anon06cb5f700311::InvalidPass1969c9a4317SRiver Riddle bool canScheduleOn(RegisteredOperationName opName) const override {
1979c9a4317SRiver Riddle return true;
1989c9a4317SRiver Riddle }
1991284dc34SMehdi Amini
2001284dc34SMehdi Amini /// A clone method to create a copy of this pass.
clonePass__anon06cb5f700111::__anon06cb5f700311::InvalidPass2011284dc34SMehdi Amini std::unique_ptr<Pass> clonePass() const override {
2021284dc34SMehdi Amini return std::make_unique<InvalidPass>(
2031284dc34SMehdi Amini *static_cast<const InvalidPass *>(this));
2041284dc34SMehdi Amini }
2051284dc34SMehdi Amini };
206be0a7e9fSMehdi Amini } // namespace
2071284dc34SMehdi Amini
TEST(PassManagerTest, InvalidPass) {
  MLIRContext context;
  context.allowUnregisteredDialects();

  // Build a module whose body holds one unregistered "invalid_op" operation.
  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
  OpBuilder builder(&module->getBodyRegion());
  OperationState invalidOpState(UnknownLoc::get(&context), "invalid_op");
  builder.insert(Operation::create(invalidOpState));

  // Keep the most recently reported diagnostic so its text can be checked.
  std::unique_ptr<Diagnostic> lastDiag;
  context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
    lastDiag = std::make_unique<Diagnostic>(std::move(diag));
  });

  // Nesting the pass under "invalid_op" must make the run fail and report the
  // unregistered operation.
  auto pm = PassManager::on<ModuleOp>(&context);
  pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>());
  EXPECT_TRUE(failed(pm.run(module.get())));
  ASSERT_TRUE(lastDiag != nullptr);
  EXPECT_EQ(
      lastDiag->str(),
      "'invalid_op' op trying to schedule a pass on an unregistered operation");

  // Once cleared, the pass manager holds no passes and the run succeeds.
  pm.clear();
  EXPECT_TRUE(succeeded(pm.run(module.get())));

  // Adding the op-restricted pass directly at module level is a fatal error.
  ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()),
               "Can't add pass 'Invalid Pass' restricted to 'invalid_op' on a "
               "PassManager intended to run on 'builtin.module', did you "
               "intend to nest?");
}
2481284dc34SMehdi Amini
249809b4403SMehdi Amini /// Simple pass to annotate a func::FuncOp with the results of analysis.
250809b4403SMehdi Amini struct InitializeCheckingPass
251809b4403SMehdi Amini : public PassWrapper<InitializeCheckingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon06cb5f700111::InitializeCheckingPass252809b4403SMehdi Amini MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InitializeCheckingPass)
253809b4403SMehdi Amini LogicalResult initialize(MLIRContext *ctx) final {
254809b4403SMehdi Amini initialized = true;
255809b4403SMehdi Amini return success();
256809b4403SMehdi Amini }
257809b4403SMehdi Amini bool initialized = false;
258809b4403SMehdi Amini
runOnOperation__anon06cb5f700111::InitializeCheckingPass259809b4403SMehdi Amini void runOnOperation() override {
260809b4403SMehdi Amini if (!initialized) {
261809b4403SMehdi Amini getOperation()->emitError() << "Pass isn't initialized!";
262809b4403SMehdi Amini signalPassFailure();
263809b4403SMehdi Amini }
264809b4403SMehdi Amini }
265809b4403SMehdi Amini };
266809b4403SMehdi Amini
TEST(PassManagerTest, PassInitialization) {
  MLIRContext context;
  context.allowUnregisteredDialects();

  // An empty module is sufficient for this check.
  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
  auto pm = PassManager::on<ModuleOp>(&context);

  // Append one more InitializeCheckingPass to the pipeline, then run it.
  auto appendCheckerAndRun = [&] {
    pm.addPass(std::make_unique<InitializeCheckingPass>());
    return pm.run(module.get());
  };

  // A single instance must be initialized before it runs.
  EXPECT_TRUE(succeeded(appendCheckerAndRun()));

  // An instance appended after a previous run must be initialized as well.
  EXPECT_TRUE(succeeded(appendCheckerAndRun()));
}
283809b4403SMehdi Amini
284be0a7e9fSMehdi Amini } // namespace
285