1 //===- TestIRVisitors.cpp - Pass to test the IR visitors ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/BuiltinOps.h"
10 #include "mlir/IR/Iterators.h"
11 #include "mlir/Interfaces/FunctionInterfaces.h"
12 #include "mlir/Pass/Pass.h"
13
14 using namespace mlir;
15
printRegion(Region * region)16 static void printRegion(Region *region) {
17 llvm::outs() << "region " << region->getRegionNumber() << " from operation '"
18 << region->getParentOp()->getName() << "'";
19 }
20
printBlock(Block * block)21 static void printBlock(Block *block) {
22 llvm::outs() << "block ";
23 block->printAsOperand(llvm::outs(), /*printType=*/false);
24 llvm::outs() << " from ";
25 printRegion(block->getParent());
26 }
27
printOperation(Operation * op)28 static void printOperation(Operation *op) {
29 llvm::outs() << "op '" << op->getName() << "'";
30 }
31
32 /// Tests pure callbacks.
testPureCallbacks(Operation * op)33 static void testPureCallbacks(Operation *op) {
34 auto opPure = [](Operation *op) {
35 llvm::outs() << "Visiting ";
36 printOperation(op);
37 llvm::outs() << "\n";
38 };
39 auto blockPure = [](Block *block) {
40 llvm::outs() << "Visiting ";
41 printBlock(block);
42 llvm::outs() << "\n";
43 };
44 auto regionPure = [](Region *region) {
45 llvm::outs() << "Visiting ";
46 printRegion(region);
47 llvm::outs() << "\n";
48 };
49
50 llvm::outs() << "Op pre-order visits"
51 << "\n";
52 op->walk<WalkOrder::PreOrder>(opPure);
53 llvm::outs() << "Block pre-order visits"
54 << "\n";
55 op->walk<WalkOrder::PreOrder>(blockPure);
56 llvm::outs() << "Region pre-order visits"
57 << "\n";
58 op->walk<WalkOrder::PreOrder>(regionPure);
59
60 llvm::outs() << "Op post-order visits"
61 << "\n";
62 op->walk<WalkOrder::PostOrder>(opPure);
63 llvm::outs() << "Block post-order visits"
64 << "\n";
65 op->walk<WalkOrder::PostOrder>(blockPure);
66 llvm::outs() << "Region post-order visits"
67 << "\n";
68 op->walk<WalkOrder::PostOrder>(regionPure);
69
70 llvm::outs() << "Op reverse post-order visits"
71 << "\n";
72 op->walk<WalkOrder::PostOrder, ReverseIterator>(opPure);
73 llvm::outs() << "Block reverse post-order visits"
74 << "\n";
75 op->walk<WalkOrder::PostOrder, ReverseIterator>(blockPure);
76 llvm::outs() << "Region reverse post-order visits"
77 << "\n";
78 op->walk<WalkOrder::PostOrder, ReverseIterator>(regionPure);
79
80 // This test case tests "NoGraphRegions = true", so start the walk with
81 // functions.
82 op->walk([&](FunctionOpInterface funcOp) {
83 llvm::outs() << "Op forward dominance post-order visits"
84 << "\n";
85 funcOp->walk<WalkOrder::PostOrder,
86 ForwardDominanceIterator</*NoGraphRegions=*/true>>(opPure);
87 llvm::outs() << "Block forward dominance post-order visits"
88 << "\n";
89 funcOp->walk<WalkOrder::PostOrder,
90 ForwardDominanceIterator</*NoGraphRegions=*/true>>(blockPure);
91 llvm::outs() << "Region forward dominance post-order visits"
92 << "\n";
93 funcOp->walk<WalkOrder::PostOrder,
94 ForwardDominanceIterator</*NoGraphRegions=*/true>>(regionPure);
95
96 llvm::outs() << "Op reverse dominance post-order visits"
97 << "\n";
98 funcOp->walk<WalkOrder::PostOrder,
99 ReverseDominanceIterator</*NoGraphRegions=*/true>>(opPure);
100 llvm::outs() << "Block reverse dominance post-order visits"
101 << "\n";
102 funcOp->walk<WalkOrder::PostOrder,
103 ReverseDominanceIterator</*NoGraphRegions=*/true>>(blockPure);
104 llvm::outs() << "Region reverse dominance post-order visits"
105 << "\n";
106 funcOp->walk<WalkOrder::PostOrder,
107 ReverseDominanceIterator</*NoGraphRegions=*/true>>(regionPure);
108 });
109 }
110
111 /// Tests erasure callbacks that skip the walk.
testSkipErasureCallbacks(Operation * op)112 static void testSkipErasureCallbacks(Operation *op) {
113 auto skipOpErasure = [](Operation *op) {
114 // Do not erase module and module children operations. Otherwise, there
115 // wouldn't be too much to test in pre-order.
116 if (isa<ModuleOp>(op) || isa<ModuleOp>(op->getParentOp()))
117 return WalkResult::advance();
118
119 llvm::outs() << "Erasing ";
120 printOperation(op);
121 llvm::outs() << "\n";
122 op->dropAllUses();
123 op->erase();
124 return WalkResult::skip();
125 };
126 auto skipBlockErasure = [](Block *block) {
127 // Do not erase module and module children blocks. Otherwise there wouldn't
128 // be too much to test in pre-order.
129 Operation *parentOp = block->getParentOp();
130 if (isa<ModuleOp>(parentOp) || isa<ModuleOp>(parentOp->getParentOp()))
131 return WalkResult::advance();
132
133 if (block->use_empty()) {
134 llvm::outs() << "Erasing ";
135 printBlock(block);
136 llvm::outs() << "\n";
137 block->erase();
138 return WalkResult::skip();
139 }
140 llvm::outs() << "Cannot erase ";
141 printBlock(block);
142 llvm::outs() << ", still has uses\n";
143 return WalkResult::advance();
144
145 };
146
147 llvm::outs() << "Op pre-order erasures (skip)"
148 << "\n";
149 Operation *cloned = op->clone();
150 cloned->walk<WalkOrder::PreOrder>(skipOpErasure);
151 cloned->erase();
152
153 llvm::outs() << "Block pre-order erasures (skip)"
154 << "\n";
155 cloned = op->clone();
156 cloned->walk<WalkOrder::PreOrder>(skipBlockErasure);
157 cloned->erase();
158
159 llvm::outs() << "Op post-order erasures (skip)"
160 << "\n";
161 cloned = op->clone();
162 cloned->walk<WalkOrder::PostOrder>(skipOpErasure);
163 cloned->erase();
164
165 llvm::outs() << "Block post-order erasures (skip)"
166 << "\n";
167 cloned = op->clone();
168 cloned->walk<WalkOrder::PostOrder>(skipBlockErasure);
169 cloned->erase();
170 }
171
172 /// Tests callbacks that erase the op or block but don't return 'Skip'. This
173 /// callbacks are only valid in post-order.
testNoSkipErasureCallbacks(Operation * op)174 static void testNoSkipErasureCallbacks(Operation *op) {
175 auto noSkipOpErasure = [](Operation *op) {
176 llvm::outs() << "Erasing ";
177 printOperation(op);
178 llvm::outs() << "\n";
179 op->dropAllUses();
180 op->erase();
181 };
182 auto noSkipBlockErasure = [](Block *block) {
183 if (block->use_empty()) {
184 llvm::outs() << "Erasing ";
185 printBlock(block);
186 llvm::outs() << "\n";
187 block->erase();
188 } else {
189 llvm::outs() << "Cannot erase ";
190 printBlock(block);
191 llvm::outs() << ", still has uses\n";
192 }
193 };
194
195 llvm::outs() << "Op post-order erasures (no skip)"
196 << "\n";
197 Operation *cloned = op->clone();
198 cloned->walk<WalkOrder::PostOrder>(noSkipOpErasure);
199
200 llvm::outs() << "Block post-order erasures (no skip)"
201 << "\n";
202 cloned = op->clone();
203 cloned->walk<WalkOrder::PostOrder>(noSkipBlockErasure);
204 cloned->erase();
205 }
206
207 /// Invoke region/block walks on regions/blocks.
testBlockAndRegionWalkers(Operation * op)208 static void testBlockAndRegionWalkers(Operation *op) {
209 auto blockPure = [](Block *block) {
210 llvm::outs() << "Visiting ";
211 printBlock(block);
212 llvm::outs() << "\n";
213 };
214 auto regionPure = [](Region *region) {
215 llvm::outs() << "Visiting ";
216 printRegion(region);
217 llvm::outs() << "\n";
218 };
219
220 llvm::outs() << "Invoke block pre-order visits on blocks\n";
221 op->walk([&](Operation *op) {
222 if (!op->hasAttr("walk_blocks"))
223 return;
224 for (Region ®ion : op->getRegions()) {
225 for (Block &block : region.getBlocks()) {
226 block.walk<WalkOrder::PreOrder>(blockPure);
227 }
228 }
229 });
230
231 llvm::outs() << "Invoke block post-order visits on blocks\n";
232 op->walk([&](Operation *op) {
233 if (!op->hasAttr("walk_blocks"))
234 return;
235 for (Region ®ion : op->getRegions()) {
236 for (Block &block : region.getBlocks()) {
237 block.walk<WalkOrder::PostOrder>(blockPure);
238 }
239 }
240 });
241
242 llvm::outs() << "Invoke region pre-order visits on region\n";
243 op->walk([&](Operation *op) {
244 if (!op->hasAttr("walk_regions"))
245 return;
246 for (Region ®ion : op->getRegions()) {
247 region.walk<WalkOrder::PreOrder>(regionPure);
248 }
249 });
250
251 llvm::outs() << "Invoke region post-order visits on region\n";
252 op->walk([&](Operation *op) {
253 if (!op->hasAttr("walk_regions"))
254 return;
255 for (Region ®ion : op->getRegions()) {
256 region.walk<WalkOrder::PostOrder>(regionPure);
257 }
258 });
259 }
260
261 namespace {
262 /// This pass exercises the different configurations of the IR visitors.
263 struct TestIRVisitorsPass
264 : public PassWrapper<TestIRVisitorsPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon0b4d01b80f11::TestIRVisitorsPass265 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIRVisitorsPass)
266
267 StringRef getArgument() const final { return "test-ir-visitors"; }
getDescription__anon0b4d01b80f11::TestIRVisitorsPass268 StringRef getDescription() const final { return "Test various visitors."; }
runOnOperation__anon0b4d01b80f11::TestIRVisitorsPass269 void runOnOperation() override {
270 Operation *op = getOperation();
271 testPureCallbacks(op);
272 testBlockAndRegionWalkers(op);
273 testSkipErasureCallbacks(op);
274 testNoSkipErasureCallbacks(op);
275 }
276 };
277 } // namespace
278
279 namespace mlir {
280 namespace test {
registerTestIRVisitorsPass()281 void registerTestIRVisitorsPass() { PassRegistration<TestIRVisitorsPass>(); }
282 } // namespace test
283 } // namespace mlir
284