xref: /llvm-project/mlir/test/lib/IR/TestMatchers.cpp (revision 7b19bd5411a68399db4bcf3c2804a67f1d0b3a62)
1 //===- TestMatchers.cpp - Pass to test matchers ---------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/Dialect/StandardOps/Ops.h"
19 #include "mlir/IR/Function.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/Pass/Pass.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 /// This is a test pass for verifying matchers.
27 struct TestMatchers : public FunctionPass<TestMatchers> {
28   void runOnFunction() override;
29 };
30 } // end anonymous namespace
31 
32 // This could be done better but is not worth the variadic template trouble.
33 template <typename Matcher> unsigned countMatches(FuncOp f, Matcher &matcher) {
34   unsigned count = 0;
35   f.walk([&count, &matcher](Operation *op) {
36     if (matcher.match(op))
37       ++count;
38   });
39   return count;
40 }
41 
42 using mlir::matchers::m_Any;
43 using mlir::matchers::m_Val;
44 static void test1(FuncOp f) {
45   assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args");
46 
47   auto a = m_Val(f.getArgument(0));
48   auto b = m_Val(f.getArgument(1));
49   auto c = m_Val(f.getArgument(2));
50 
51   auto p0 = m_Op<AddFOp>(); // using 0-arity matcher
52   llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0)
53                << " times\n";
54 
55   auto p1 = m_Op<MulFOp>(); // using 0-arity matcher
56   llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1)
57                << " times\n";
58 
59   auto p2 = m_Op<AddFOp>(m_Op<AddFOp>(), m_Any());
60   llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2)
61                << " times\n";
62 
63   auto p3 = m_Op<AddFOp>(m_Any(), m_Op<AddFOp>());
64   llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3)
65                << " times\n";
66 
67   auto p4 = m_Op<MulFOp>(m_Op<AddFOp>(), m_Any());
68   llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4)
69                << " times\n";
70 
71   auto p5 = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>());
72   llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5)
73                << " times\n";
74 
75   auto p6 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Any());
76   llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6)
77                << " times\n";
78 
79   auto p7 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>());
80   llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7)
81                << " times\n";
82 
83   auto mul_of_mulmul = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>());
84   auto p8 = m_Op<MulFOp>(mul_of_mulmul, mul_of_mulmul);
85   llvm::outs()
86       << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched "
87       << countMatches(f, p8) << " times\n";
88 
89   // clang-format off
90   auto mul_of_muladd = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<AddFOp>());
91   auto mul_of_anyadd = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>());
92   auto p9 = m_Op<MulFOp>(m_Op<MulFOp>(
93                      mul_of_muladd, m_Op<MulFOp>()),
94                    m_Op<MulFOp>(mul_of_anyadd, mul_of_anyadd));
95   // clang-format on
96   llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, "
97                   "add(*)), mul(*, add(*)))) matched "
98                << countMatches(f, p9) << " times\n";
99 
100   auto p10 = m_Op<AddFOp>(a, b);
101   llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10)
102                << " times\n";
103 
104   auto p11 = m_Op<AddFOp>(a, c);
105   llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11)
106                << " times\n";
107 
108   auto p12 = m_Op<AddFOp>(b, a);
109   llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12)
110                << " times\n";
111 
112   auto p13 = m_Op<AddFOp>(c, a);
113   llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13)
114                << " times\n";
115 
116   auto p14 = m_Op<MulFOp>(a, m_Op<AddFOp>(c, b));
117   llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14)
118                << " times\n";
119 
120   auto p15 = m_Op<MulFOp>(a, m_Op<AddFOp>(b, c));
121   llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15)
122                << " times\n";
123 
124   auto mul_of_aany = m_Op<MulFOp>(a, m_Any());
125   auto p16 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(a, c));
126   llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched "
127                << countMatches(f, p16) << " times\n";
128 
129   auto p17 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(c, b));
130   llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched "
131                << countMatches(f, p17) << " times\n";
132 }
133 
134 void test2(FuncOp f) {
135   auto a = m_Val(f.getArgument(0));
136   FloatAttr floatAttr;
137   auto p = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant(&floatAttr)));
138   // Last operation that is not the terminator.
139   Operation *lastOp = f.getBody().front().back().getPrevNode();
140   if (p.match(lastOp))
141     llvm::outs()
142         << "Pattern add(add(a, constant), a) matched and bound constant to: "
143         << floatAttr.getValueAsDouble() << "\n";
144 }
145 
146 void TestMatchers::runOnFunction() {
147   auto f = getFunction();
148   llvm::outs() << f.getName() << "\n";
149   if (f.getName() == "test1")
150     test1(f);
151   if (f.getName() == "test2")
152     test2(f);
153 }
154 
155 static PassRegistration<TestMatchers> pass("test-matchers",
156                                            "Test C++ pattern matchers.");
157