xref: /llvm-project/mlir/test/lib/IR/TestVisitorsGeneric.cpp (revision e95e94adc6bb748de015ac3053e7f0786b65f351)
1 //===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestOps.h"
10 #include "mlir/Pass/Pass.h"
11 
12 using namespace mlir;
13 
getStageDescription(const WalkStage & stage)14 static std::string getStageDescription(const WalkStage &stage) {
15   if (stage.isBeforeAllRegions())
16     return "before all regions";
17   if (stage.isAfterAllRegions())
18     return "after all regions";
19   return "before region #" + std::to_string(stage.getNextRegion());
20 }
21 
22 namespace {
23 /// This pass exercises generic visitor with void callbacks and prints the order
24 /// and stage in which operations are visited.
25 struct TestGenericIRVisitorPass
26     : public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRVisitorPass27   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGenericIRVisitorPass)
28 
29   StringRef getArgument() const final { return "test-generic-ir-visitors"; }
getDescription__anon059fcff50111::TestGenericIRVisitorPass30   StringRef getDescription() const final { return "Test generic IR visitors."; }
runOnOperation__anon059fcff50111::TestGenericIRVisitorPass31   void runOnOperation() override {
32     Operation *outerOp = getOperation();
33     int stepNo = 0;
34     outerOp->walk([&](Operation *op, const WalkStage &stage) {
35       llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
36                    << getStageDescription(stage) << "\n";
37     });
38 
39     // Exercise static inference of operation type.
40     outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
41       llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
42                    << getStageDescription(stage) << "\n";
43     });
44   }
45 };
46 
47 /// This pass exercises the generic visitor with non-void callbacks and prints
48 /// the order and stage in which operations are visited. It will interrupt the
49 /// walk based on attributes peesent in the IR.
50 struct TestGenericIRVisitorInterruptPass
51     : public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRVisitorInterruptPass52   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
53       TestGenericIRVisitorInterruptPass)
54 
55   StringRef getArgument() const final {
56     return "test-generic-ir-visitors-interrupt";
57   }
getDescription__anon059fcff50111::TestGenericIRVisitorInterruptPass58   StringRef getDescription() const final {
59     return "Test generic IR visitors with interrupts.";
60   }
runOnOperation__anon059fcff50111::TestGenericIRVisitorInterruptPass61   void runOnOperation() override {
62     Operation *outerOp = getOperation();
63     int stepNo = 0;
64 
65     auto walker = [&](Operation *op, const WalkStage &stage) {
66       if (auto interruptBeforeAall =
67               op->getAttrOfType<BoolAttr>("interrupt_before_all"))
68         if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions())
69           return WalkResult::interrupt();
70 
71       if (auto interruptAfterAll =
72               op->getAttrOfType<BoolAttr>("interrupt_after_all"))
73         if (interruptAfterAll.getValue() && stage.isAfterAllRegions())
74           return WalkResult::interrupt();
75 
76       if (auto interruptAfterRegion =
77               op->getAttrOfType<IntegerAttr>("interrupt_after_region"))
78         if (stage.isAfterRegion(
79                 static_cast<int>(interruptAfterRegion.getInt())))
80           return WalkResult::interrupt();
81 
82       if (auto skipBeforeAall = op->getAttrOfType<BoolAttr>("skip_before_all"))
83         if (skipBeforeAall.getValue() && stage.isBeforeAllRegions())
84           return WalkResult::skip();
85 
86       if (auto skipAfterAll = op->getAttrOfType<BoolAttr>("skip_after_all"))
87         if (skipAfterAll.getValue() && stage.isAfterAllRegions())
88           return WalkResult::skip();
89 
90       if (auto skipAfterRegion =
91               op->getAttrOfType<IntegerAttr>("skip_after_region"))
92         if (stage.isAfterRegion(static_cast<int>(skipAfterRegion.getInt())))
93           return WalkResult::skip();
94 
95       llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
96                    << getStageDescription(stage) << "\n";
97       return WalkResult::advance();
98     };
99 
100     // Interrupt the walk based on attributes on the operation.
101     auto result = outerOp->walk(walker);
102 
103     if (result.wasInterrupted())
104       llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
105 
106     // Exercise static inference of operation type.
107     result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
108       return walker(op, stage);
109     });
110 
111     if (result.wasInterrupted())
112       llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
113   }
114 };
115 
116 struct TestGenericIRBlockVisitorInterruptPass
117     : public PassWrapper<TestGenericIRBlockVisitorInterruptPass,
118                          OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRBlockVisitorInterruptPass119   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
120       TestGenericIRBlockVisitorInterruptPass)
121 
122   StringRef getArgument() const final {
123     return "test-generic-ir-block-visitors-interrupt";
124   }
getDescription__anon059fcff50111::TestGenericIRBlockVisitorInterruptPass125   StringRef getDescription() const final {
126     return "Test generic IR visitors with interrupts, starting with Blocks.";
127   }
128 
runOnOperation__anon059fcff50111::TestGenericIRBlockVisitorInterruptPass129   void runOnOperation() override {
130     int stepNo = 0;
131 
132     auto walker = [&](Block *block) {
133       for (Operation &op : *block)
134         if (op.getAttrOfType<BoolAttr>("interrupt"))
135           return WalkResult::interrupt();
136 
137       llvm::outs() << "step " << stepNo++ << "\n";
138       return WalkResult::advance();
139     };
140 
141     auto result = getOperation()->walk(walker);
142     if (result.wasInterrupted())
143       llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
144   }
145 };
146 
147 struct TestGenericIRRegionVisitorInterruptPass
148     : public PassWrapper<TestGenericIRRegionVisitorInterruptPass,
149                          OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRRegionVisitorInterruptPass150   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
151       TestGenericIRRegionVisitorInterruptPass)
152 
153   StringRef getArgument() const final {
154     return "test-generic-ir-region-visitors-interrupt";
155   }
getDescription__anon059fcff50111::TestGenericIRRegionVisitorInterruptPass156   StringRef getDescription() const final {
157     return "Test generic IR visitors with interrupts, starting with Regions.";
158   }
159 
runOnOperation__anon059fcff50111::TestGenericIRRegionVisitorInterruptPass160   void runOnOperation() override {
161     int stepNo = 0;
162 
163     auto walker = [&](Region *region) {
164       for (Operation &op : region->getOps())
165         if (op.getAttrOfType<BoolAttr>("interrupt"))
166           return WalkResult::interrupt();
167 
168       llvm::outs() << "step " << stepNo++ << "\n";
169       return WalkResult::advance();
170     };
171 
172     auto result = getOperation()->walk(walker);
173     if (result.wasInterrupted())
174       llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
175   }
176 };
177 
178 } // namespace
179 
180 namespace mlir {
181 namespace test {
registerTestGenericIRVisitorsPass()182 void registerTestGenericIRVisitorsPass() {
183   PassRegistration<TestGenericIRVisitorPass>();
184   PassRegistration<TestGenericIRVisitorInterruptPass>();
185   PassRegistration<TestGenericIRBlockVisitorInterruptPass>();
186   PassRegistration<TestGenericIRRegionVisitorInterruptPass>();
187 }
188 
189 } // namespace test
190 } // namespace mlir
191