xref: /llvm-project/mlir/test/lib/Pass/TestDynamicPipeline.cpp (revision 36d3efea15e6202edd64b05de38d8379e2baddb2)
//===------ TestDynamicPipeline.cpp --- dynamic pipeline test pass --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to test the dynamic pipeline feature.
//
//===----------------------------------------------------------------------===//
12 
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassManager.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 
21 class TestDynamicPipelinePass
22     : public PassWrapper<TestDynamicPipelinePass, OperationPass<>> {
23 public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDynamicPipelinePass)24   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDynamicPipelinePass)
25 
26   StringRef getArgument() const final { return "test-dynamic-pipeline"; }
getDescription() const27   StringRef getDescription() const final {
28     return "Tests the dynamic pipeline feature by applying "
29            "a pipeline on a selected set of functions";
30   }
getDependentDialects(DialectRegistry & registry) const31   void getDependentDialects(DialectRegistry &registry) const override {
32     OpPassManager pm(ModuleOp::getOperationName(),
33                      OpPassManager::Nesting::Implicit);
34     (void)parsePassPipeline(pipeline, pm, llvm::errs());
35     pm.getDependentDialects(registry);
36   }
37 
38   TestDynamicPipelinePass() = default;
TestDynamicPipelinePass(const TestDynamicPipelinePass &)39   TestDynamicPipelinePass(const TestDynamicPipelinePass &) {}
40 
runOnOperation()41   void runOnOperation() override {
42     Operation *currentOp = getOperation();
43 
44     llvm::errs() << "Dynamic execute '" << pipeline << "' on "
45                  << currentOp->getName() << "\n";
46     if (pipeline.empty()) {
47       llvm::errs() << "Empty pipeline\n";
48       return;
49     }
50     auto symbolOp = dyn_cast<SymbolOpInterface>(currentOp);
51     if (!symbolOp) {
52       currentOp->emitWarning()
53           << "Ignoring because not implementing SymbolOpInterface\n";
54       return;
55     }
56 
57     auto opName = symbolOp.getName();
58     if (!opNames.empty() && !llvm::is_contained(opNames, opName)) {
59       llvm::errs() << "dynamic-pipeline skip op name: " << opName << "\n";
60       return;
61     }
62     OpPassManager pm(currentOp->getName().getIdentifier(),
63                      OpPassManager::Nesting::Implicit);
64     (void)parsePassPipeline(pipeline, pm, llvm::errs());
65 
66     // Check that running on the parent operation always immediately fails.
67     if (runOnParent) {
68       if (currentOp->getParentOp())
69         if (!failed(runPipeline(pm, currentOp->getParentOp())))
70           signalPassFailure();
71       return;
72     }
73 
74     if (runOnNestedOp) {
75       llvm::errs() << "Run on nested op\n";
76       currentOp->walk([&](Operation *op) {
77         if (op == currentOp || !op->hasTrait<OpTrait::IsIsolatedFromAbove>() ||
78             op->getName() != currentOp->getName())
79           return;
80         llvm::errs() << "Run on " << *op << "\n";
81         // Run on the current operation
82         if (failed(runPipeline(pm, op)))
83           signalPassFailure();
84       });
85     } else {
86       // Run on the current operation
87       if (failed(runPipeline(pm, currentOp)))
88         signalPassFailure();
89     }
90   }
91 
92   Option<bool> runOnNestedOp{
93       *this, "run-on-nested-operations",
94       llvm::cl::desc("This will apply the pipeline on nested operations under "
95                      "the visited operation.")};
96   Option<bool> runOnParent{
97       *this, "run-on-parent",
98       llvm::cl::desc("This will apply the pipeline on the parent operation if "
99                      "it exist, this is expected to fail.")};
100   Option<std::string> pipeline{
101       *this, "dynamic-pipeline",
102       llvm::cl::desc("The pipeline description that "
103                      "will run on the filtered function.")};
104   ListOption<std::string> opNames{
105       *this, "op-name",
106       llvm::cl::desc("List of function name to apply the pipeline to")};
107 };
108 } // namespace
109 
110 namespace mlir {
111 namespace test {
registerTestDynamicPipelinePass()112 void registerTestDynamicPipelinePass() {
113   PassRegistration<TestDynamicPipelinePass>();
114 }
115 } // namespace test
116 } // namespace mlir
117