xref: /llvm-project/mlir/test/lib/IR/TestAffineWalk.cpp (revision a5757c5b65f1894de16f549212b1c37793312703)
1 //===- TestAffineWalk.cpp - Pass to test affine walks
2 //----------------------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "mlir/Pass/Pass.h"
11 
12 #include "mlir/IR/BuiltinOps.h"
13 
14 using namespace mlir;
15 
16 namespace {
17 /// A test pass for verifying walk interrupts.
18 struct TestAffineWalk
19     : public PassWrapper<TestAffineWalk, OperationPass<ModuleOp>> {
20   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAffineWalk)
21 
22   void runOnOperation() override;
getArgument__anon633d728d0111::TestAffineWalk23   StringRef getArgument() const final { return "test-affine-walk"; }
getDescription__anon633d728d0111::TestAffineWalk24   StringRef getDescription() const final { return "Test affine walk method."; }
25 };
26 } // namespace
27 
28 /// Emits a remark for the first `map`'s result expression that contains a
29 /// mod expression.
checkMod(AffineMap map,Location loc)30 static void checkMod(AffineMap map, Location loc) {
31   for (AffineExpr e : map.getResults()) {
32     e.walk([&](AffineExpr s) {
33       if (s.getKind() == mlir::AffineExprKind::Mod) {
34         emitRemark(loc, "mod expression: ");
35         return WalkResult::interrupt();
36       }
37       return WalkResult::advance();
38     });
39   }
40 }
41 
runOnOperation()42 void TestAffineWalk::runOnOperation() {
43   auto m = getOperation();
44   // Test whether the walk is being correctly interrupted.
45   m.walk([](Operation *op) {
46     for (NamedAttribute attr : op->getAttrs()) {
47       auto mapAttr = dyn_cast<AffineMapAttr>(attr.getValue());
48       if (!mapAttr)
49         return;
50       checkMod(mapAttr.getAffineMap(), op->getLoc());
51     }
52   });
53 }
54 
55 namespace mlir {
registerTestAffineWalk()56 void registerTestAffineWalk() { PassRegistration<TestAffineWalk>(); }
57 } // namespace mlir
58