xref: /llvm-project/mlir/test/lib/IR/TestVisitors.cpp (revision 6b72c37958c3d8aa0cfa48c4ad6509dd6ea37749)
1 //===- TestIRVisitors.cpp - Pass to test the IR visitors ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinOps.h"
10 #include "mlir/IR/Iterators.h"
11 #include "mlir/Interfaces/FunctionInterfaces.h"
12 #include "mlir/Pass/Pass.h"
13 
14 using namespace mlir;
15 
printRegion(Region * region)16 static void printRegion(Region *region) {
17   llvm::outs() << "region " << region->getRegionNumber() << " from operation '"
18                << region->getParentOp()->getName() << "'";
19 }
20 
printBlock(Block * block)21 static void printBlock(Block *block) {
22   llvm::outs() << "block ";
23   block->printAsOperand(llvm::outs(), /*printType=*/false);
24   llvm::outs() << " from ";
25   printRegion(block->getParent());
26 }
27 
printOperation(Operation * op)28 static void printOperation(Operation *op) {
29   llvm::outs() << "op '" << op->getName() << "'";
30 }
31 
32 /// Tests pure callbacks.
testPureCallbacks(Operation * op)33 static void testPureCallbacks(Operation *op) {
34   auto opPure = [](Operation *op) {
35     llvm::outs() << "Visiting ";
36     printOperation(op);
37     llvm::outs() << "\n";
38   };
39   auto blockPure = [](Block *block) {
40     llvm::outs() << "Visiting ";
41     printBlock(block);
42     llvm::outs() << "\n";
43   };
44   auto regionPure = [](Region *region) {
45     llvm::outs() << "Visiting ";
46     printRegion(region);
47     llvm::outs() << "\n";
48   };
49 
50   llvm::outs() << "Op pre-order visits"
51                << "\n";
52   op->walk<WalkOrder::PreOrder>(opPure);
53   llvm::outs() << "Block pre-order visits"
54                << "\n";
55   op->walk<WalkOrder::PreOrder>(blockPure);
56   llvm::outs() << "Region pre-order visits"
57                << "\n";
58   op->walk<WalkOrder::PreOrder>(regionPure);
59 
60   llvm::outs() << "Op post-order visits"
61                << "\n";
62   op->walk<WalkOrder::PostOrder>(opPure);
63   llvm::outs() << "Block post-order visits"
64                << "\n";
65   op->walk<WalkOrder::PostOrder>(blockPure);
66   llvm::outs() << "Region post-order visits"
67                << "\n";
68   op->walk<WalkOrder::PostOrder>(regionPure);
69 
70   llvm::outs() << "Op reverse post-order visits"
71                << "\n";
72   op->walk<WalkOrder::PostOrder, ReverseIterator>(opPure);
73   llvm::outs() << "Block reverse post-order visits"
74                << "\n";
75   op->walk<WalkOrder::PostOrder, ReverseIterator>(blockPure);
76   llvm::outs() << "Region reverse post-order visits"
77                << "\n";
78   op->walk<WalkOrder::PostOrder, ReverseIterator>(regionPure);
79 
80   // This test case tests "NoGraphRegions = true", so start the walk with
81   // functions.
82   op->walk([&](FunctionOpInterface funcOp) {
83     llvm::outs() << "Op forward dominance post-order visits"
84                  << "\n";
85     funcOp->walk<WalkOrder::PostOrder,
86                  ForwardDominanceIterator</*NoGraphRegions=*/true>>(opPure);
87     llvm::outs() << "Block forward dominance post-order visits"
88                  << "\n";
89     funcOp->walk<WalkOrder::PostOrder,
90                  ForwardDominanceIterator</*NoGraphRegions=*/true>>(blockPure);
91     llvm::outs() << "Region forward dominance post-order visits"
92                  << "\n";
93     funcOp->walk<WalkOrder::PostOrder,
94                  ForwardDominanceIterator</*NoGraphRegions=*/true>>(regionPure);
95 
96     llvm::outs() << "Op reverse dominance post-order visits"
97                  << "\n";
98     funcOp->walk<WalkOrder::PostOrder,
99                  ReverseDominanceIterator</*NoGraphRegions=*/true>>(opPure);
100     llvm::outs() << "Block reverse dominance post-order visits"
101                  << "\n";
102     funcOp->walk<WalkOrder::PostOrder,
103                  ReverseDominanceIterator</*NoGraphRegions=*/true>>(blockPure);
104     llvm::outs() << "Region reverse dominance post-order visits"
105                  << "\n";
106     funcOp->walk<WalkOrder::PostOrder,
107                  ReverseDominanceIterator</*NoGraphRegions=*/true>>(regionPure);
108   });
109 }
110 
111 /// Tests erasure callbacks that skip the walk.
testSkipErasureCallbacks(Operation * op)112 static void testSkipErasureCallbacks(Operation *op) {
113   auto skipOpErasure = [](Operation *op) {
114     // Do not erase module and module children operations. Otherwise, there
115     // wouldn't be too much to test in pre-order.
116     if (isa<ModuleOp>(op) || isa<ModuleOp>(op->getParentOp()))
117       return WalkResult::advance();
118 
119     llvm::outs() << "Erasing ";
120     printOperation(op);
121     llvm::outs() << "\n";
122     op->dropAllUses();
123     op->erase();
124     return WalkResult::skip();
125   };
126   auto skipBlockErasure = [](Block *block) {
127     // Do not erase module and module children blocks. Otherwise there wouldn't
128     // be too much to test in pre-order.
129     Operation *parentOp = block->getParentOp();
130     if (isa<ModuleOp>(parentOp) || isa<ModuleOp>(parentOp->getParentOp()))
131       return WalkResult::advance();
132 
133     if (block->use_empty()) {
134       llvm::outs() << "Erasing ";
135       printBlock(block);
136       llvm::outs() << "\n";
137       block->erase();
138       return WalkResult::skip();
139     }
140     llvm::outs() << "Cannot erase ";
141     printBlock(block);
142     llvm::outs() << ", still has uses\n";
143     return WalkResult::advance();
144 
145   };
146 
147   llvm::outs() << "Op pre-order erasures (skip)"
148                << "\n";
149   Operation *cloned = op->clone();
150   cloned->walk<WalkOrder::PreOrder>(skipOpErasure);
151   cloned->erase();
152 
153   llvm::outs() << "Block pre-order erasures (skip)"
154                << "\n";
155   cloned = op->clone();
156   cloned->walk<WalkOrder::PreOrder>(skipBlockErasure);
157   cloned->erase();
158 
159   llvm::outs() << "Op post-order erasures (skip)"
160                << "\n";
161   cloned = op->clone();
162   cloned->walk<WalkOrder::PostOrder>(skipOpErasure);
163   cloned->erase();
164 
165   llvm::outs() << "Block post-order erasures (skip)"
166                << "\n";
167   cloned = op->clone();
168   cloned->walk<WalkOrder::PostOrder>(skipBlockErasure);
169   cloned->erase();
170 }
171 
172 /// Tests callbacks that erase the op or block but don't return 'Skip'. This
173 /// callbacks are only valid in post-order.
testNoSkipErasureCallbacks(Operation * op)174 static void testNoSkipErasureCallbacks(Operation *op) {
175   auto noSkipOpErasure = [](Operation *op) {
176     llvm::outs() << "Erasing ";
177     printOperation(op);
178     llvm::outs() << "\n";
179     op->dropAllUses();
180     op->erase();
181   };
182   auto noSkipBlockErasure = [](Block *block) {
183     if (block->use_empty()) {
184       llvm::outs() << "Erasing ";
185       printBlock(block);
186       llvm::outs() << "\n";
187       block->erase();
188     } else {
189       llvm::outs() << "Cannot erase ";
190       printBlock(block);
191       llvm::outs() << ", still has uses\n";
192     }
193   };
194 
195   llvm::outs() << "Op post-order erasures (no skip)"
196                << "\n";
197   Operation *cloned = op->clone();
198   cloned->walk<WalkOrder::PostOrder>(noSkipOpErasure);
199 
200   llvm::outs() << "Block post-order erasures (no skip)"
201                << "\n";
202   cloned = op->clone();
203   cloned->walk<WalkOrder::PostOrder>(noSkipBlockErasure);
204   cloned->erase();
205 }
206 
207 /// Invoke region/block walks on regions/blocks.
testBlockAndRegionWalkers(Operation * op)208 static void testBlockAndRegionWalkers(Operation *op) {
209   auto blockPure = [](Block *block) {
210     llvm::outs() << "Visiting ";
211     printBlock(block);
212     llvm::outs() << "\n";
213   };
214   auto regionPure = [](Region *region) {
215     llvm::outs() << "Visiting ";
216     printRegion(region);
217     llvm::outs() << "\n";
218   };
219 
220   llvm::outs() << "Invoke block pre-order visits on blocks\n";
221   op->walk([&](Operation *op) {
222     if (!op->hasAttr("walk_blocks"))
223       return;
224     for (Region &region : op->getRegions()) {
225       for (Block &block : region.getBlocks()) {
226         block.walk<WalkOrder::PreOrder>(blockPure);
227       }
228     }
229   });
230 
231   llvm::outs() << "Invoke block post-order visits on blocks\n";
232   op->walk([&](Operation *op) {
233     if (!op->hasAttr("walk_blocks"))
234       return;
235     for (Region &region : op->getRegions()) {
236       for (Block &block : region.getBlocks()) {
237         block.walk<WalkOrder::PostOrder>(blockPure);
238       }
239     }
240   });
241 
242   llvm::outs() << "Invoke region pre-order visits on region\n";
243   op->walk([&](Operation *op) {
244     if (!op->hasAttr("walk_regions"))
245       return;
246     for (Region &region : op->getRegions()) {
247       region.walk<WalkOrder::PreOrder>(regionPure);
248     }
249   });
250 
251   llvm::outs() << "Invoke region post-order visits on region\n";
252   op->walk([&](Operation *op) {
253     if (!op->hasAttr("walk_regions"))
254       return;
255     for (Region &region : op->getRegions()) {
256       region.walk<WalkOrder::PostOrder>(regionPure);
257     }
258   });
259 }
260 
261 namespace {
262 /// This pass exercises the different configurations of the IR visitors.
263 struct TestIRVisitorsPass
264     : public PassWrapper<TestIRVisitorsPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon0b4d01b80f11::TestIRVisitorsPass265   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIRVisitorsPass)
266 
267   StringRef getArgument() const final { return "test-ir-visitors"; }
getDescription__anon0b4d01b80f11::TestIRVisitorsPass268   StringRef getDescription() const final { return "Test various visitors."; }
runOnOperation__anon0b4d01b80f11::TestIRVisitorsPass269   void runOnOperation() override {
270     Operation *op = getOperation();
271     testPureCallbacks(op);
272     testBlockAndRegionWalkers(op);
273     testSkipErasureCallbacks(op);
274     testNoSkipErasureCallbacks(op);
275   }
276 };
277 } // namespace
278 
279 namespace mlir {
280 namespace test {
registerTestIRVisitorsPass()281 void registerTestIRVisitorsPass() { PassRegistration<TestIRVisitorsPass>(); }
282 } // namespace test
283 } // namespace mlir
284