1 //===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "TestOps.h"
10 #include "mlir/Pass/Pass.h"
11
12 using namespace mlir;
13
getStageDescription(const WalkStage & stage)14 static std::string getStageDescription(const WalkStage &stage) {
15 if (stage.isBeforeAllRegions())
16 return "before all regions";
17 if (stage.isAfterAllRegions())
18 return "after all regions";
19 return "before region #" + std::to_string(stage.getNextRegion());
20 }
21
22 namespace {
23 /// This pass exercises generic visitor with void callbacks and prints the order
24 /// and stage in which operations are visited.
25 struct TestGenericIRVisitorPass
26 : public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRVisitorPass27 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGenericIRVisitorPass)
28
29 StringRef getArgument() const final { return "test-generic-ir-visitors"; }
getDescription__anon059fcff50111::TestGenericIRVisitorPass30 StringRef getDescription() const final { return "Test generic IR visitors."; }
runOnOperation__anon059fcff50111::TestGenericIRVisitorPass31 void runOnOperation() override {
32 Operation *outerOp = getOperation();
33 int stepNo = 0;
34 outerOp->walk([&](Operation *op, const WalkStage &stage) {
35 llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
36 << getStageDescription(stage) << "\n";
37 });
38
39 // Exercise static inference of operation type.
40 outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
41 llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
42 << getStageDescription(stage) << "\n";
43 });
44 }
45 };
46
47 /// This pass exercises the generic visitor with non-void callbacks and prints
48 /// the order and stage in which operations are visited. It will interrupt the
49 /// walk based on attributes peesent in the IR.
50 struct TestGenericIRVisitorInterruptPass
51 : public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRVisitorInterruptPass52 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
53 TestGenericIRVisitorInterruptPass)
54
55 StringRef getArgument() const final {
56 return "test-generic-ir-visitors-interrupt";
57 }
getDescription__anon059fcff50111::TestGenericIRVisitorInterruptPass58 StringRef getDescription() const final {
59 return "Test generic IR visitors with interrupts.";
60 }
runOnOperation__anon059fcff50111::TestGenericIRVisitorInterruptPass61 void runOnOperation() override {
62 Operation *outerOp = getOperation();
63 int stepNo = 0;
64
65 auto walker = [&](Operation *op, const WalkStage &stage) {
66 if (auto interruptBeforeAall =
67 op->getAttrOfType<BoolAttr>("interrupt_before_all"))
68 if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions())
69 return WalkResult::interrupt();
70
71 if (auto interruptAfterAll =
72 op->getAttrOfType<BoolAttr>("interrupt_after_all"))
73 if (interruptAfterAll.getValue() && stage.isAfterAllRegions())
74 return WalkResult::interrupt();
75
76 if (auto interruptAfterRegion =
77 op->getAttrOfType<IntegerAttr>("interrupt_after_region"))
78 if (stage.isAfterRegion(
79 static_cast<int>(interruptAfterRegion.getInt())))
80 return WalkResult::interrupt();
81
82 if (auto skipBeforeAall = op->getAttrOfType<BoolAttr>("skip_before_all"))
83 if (skipBeforeAall.getValue() && stage.isBeforeAllRegions())
84 return WalkResult::skip();
85
86 if (auto skipAfterAll = op->getAttrOfType<BoolAttr>("skip_after_all"))
87 if (skipAfterAll.getValue() && stage.isAfterAllRegions())
88 return WalkResult::skip();
89
90 if (auto skipAfterRegion =
91 op->getAttrOfType<IntegerAttr>("skip_after_region"))
92 if (stage.isAfterRegion(static_cast<int>(skipAfterRegion.getInt())))
93 return WalkResult::skip();
94
95 llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
96 << getStageDescription(stage) << "\n";
97 return WalkResult::advance();
98 };
99
100 // Interrupt the walk based on attributes on the operation.
101 auto result = outerOp->walk(walker);
102
103 if (result.wasInterrupted())
104 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
105
106 // Exercise static inference of operation type.
107 result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
108 return walker(op, stage);
109 });
110
111 if (result.wasInterrupted())
112 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
113 }
114 };
115
116 struct TestGenericIRBlockVisitorInterruptPass
117 : public PassWrapper<TestGenericIRBlockVisitorInterruptPass,
118 OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRBlockVisitorInterruptPass119 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
120 TestGenericIRBlockVisitorInterruptPass)
121
122 StringRef getArgument() const final {
123 return "test-generic-ir-block-visitors-interrupt";
124 }
getDescription__anon059fcff50111::TestGenericIRBlockVisitorInterruptPass125 StringRef getDescription() const final {
126 return "Test generic IR visitors with interrupts, starting with Blocks.";
127 }
128
runOnOperation__anon059fcff50111::TestGenericIRBlockVisitorInterruptPass129 void runOnOperation() override {
130 int stepNo = 0;
131
132 auto walker = [&](Block *block) {
133 for (Operation &op : *block)
134 if (op.getAttrOfType<BoolAttr>("interrupt"))
135 return WalkResult::interrupt();
136
137 llvm::outs() << "step " << stepNo++ << "\n";
138 return WalkResult::advance();
139 };
140
141 auto result = getOperation()->walk(walker);
142 if (result.wasInterrupted())
143 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
144 }
145 };
146
147 struct TestGenericIRRegionVisitorInterruptPass
148 : public PassWrapper<TestGenericIRRegionVisitorInterruptPass,
149 OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRRegionVisitorInterruptPass150 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
151 TestGenericIRRegionVisitorInterruptPass)
152
153 StringRef getArgument() const final {
154 return "test-generic-ir-region-visitors-interrupt";
155 }
getDescription__anon059fcff50111::TestGenericIRRegionVisitorInterruptPass156 StringRef getDescription() const final {
157 return "Test generic IR visitors with interrupts, starting with Regions.";
158 }
159
runOnOperation__anon059fcff50111::TestGenericIRRegionVisitorInterruptPass160 void runOnOperation() override {
161 int stepNo = 0;
162
163 auto walker = [&](Region *region) {
164 for (Operation &op : region->getOps())
165 if (op.getAttrOfType<BoolAttr>("interrupt"))
166 return WalkResult::interrupt();
167
168 llvm::outs() << "step " << stepNo++ << "\n";
169 return WalkResult::advance();
170 };
171
172 auto result = getOperation()->walk(walker);
173 if (result.wasInterrupted())
174 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
175 }
176 };
177
178 } // namespace
179
180 namespace mlir {
181 namespace test {
registerTestGenericIRVisitorsPass()182 void registerTestGenericIRVisitorsPass() {
183 PassRegistration<TestGenericIRVisitorPass>();
184 PassRegistration<TestGenericIRVisitorInterruptPass>();
185 PassRegistration<TestGenericIRBlockVisitorInterruptPass>();
186 PassRegistration<TestGenericIRRegionVisitorInterruptPass>();
187 }
188
189 } // namespace test
190 } // namespace mlir
191