xref: /llvm-project/mlir/test/lib/IR/TestMatchers.cpp (revision 34a35a8b244243f5a4ad5d531007bccfeaa0b02e)
1 //===- TestMatchers.cpp - Pass to test matchers ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/IR/Matchers.h"
12 #include "mlir/Interfaces/FunctionInterfaces.h"
13 #include "mlir/Pass/Pass.h"
14 
15 using namespace mlir;
16 
17 namespace {
18 /// This is a test pass for verifying matchers.
19 struct TestMatchers
20     : public PassWrapper<TestMatchers, InterfacePass<FunctionOpInterface>> {
21   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMatchers)
22 
23   void runOnOperation() override;
getArgument__anon5c1ca04c0111::TestMatchers24   StringRef getArgument() const final { return "test-matchers"; }
getDescription__anon5c1ca04c0111::TestMatchers25   StringRef getDescription() const final {
26     return "Test C++ pattern matchers.";
27   }
28 };
29 } // namespace
30 
31 // This could be done better but is not worth the variadic template trouble.
32 template <typename Matcher>
countMatches(FunctionOpInterface f,Matcher & matcher)33 static unsigned countMatches(FunctionOpInterface f, Matcher &matcher) {
34   unsigned count = 0;
35   f.walk([&count, &matcher](Operation *op) {
36     if (matcher.match(op))
37       ++count;
38   });
39   return count;
40 }
41 
42 using mlir::matchers::m_Any;
43 using mlir::matchers::m_Val;
test1(FunctionOpInterface f)44 static void test1(FunctionOpInterface f) {
45   assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args");
46 
47   auto a = m_Val(f.getArgument(0));
48   auto b = m_Val(f.getArgument(1));
49   auto c = m_Val(f.getArgument(2));
50 
51   auto p0 = m_Op<arith::AddFOp>(); // using 0-arity matcher
52   llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0)
53                << " times\n";
54 
55   auto p1 = m_Op<arith::MulFOp>(); // using 0-arity matcher
56   llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1)
57                << " times\n";
58 
59   auto p2 = m_Op<arith::AddFOp>(m_Op<arith::AddFOp>(), m_Any());
60   llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2)
61                << " times\n";
62 
63   auto p3 = m_Op<arith::AddFOp>(m_Any(), m_Op<arith::AddFOp>());
64   llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3)
65                << " times\n";
66 
67   auto p4 = m_Op<arith::MulFOp>(m_Op<arith::AddFOp>(), m_Any());
68   llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4)
69                << " times\n";
70 
71   auto p5 = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>());
72   llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5)
73                << " times\n";
74 
75   auto p6 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Any());
76   llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6)
77                << " times\n";
78 
79   auto p7 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>());
80   llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7)
81                << " times\n";
82 
83   auto mulOfMulmul =
84       m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>());
85   auto p8 = m_Op<arith::MulFOp>(mulOfMulmul, mulOfMulmul);
86   llvm::outs()
87       << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched "
88       << countMatches(f, p8) << " times\n";
89 
90   // clang-format off
91   auto mulOfMuladd = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::AddFOp>());
92   auto mulOfAnyadd = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>());
93   auto p9 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(
94                      mulOfMuladd, m_Op<arith::MulFOp>()),
95                    m_Op<arith::MulFOp>(mulOfAnyadd, mulOfAnyadd));
96   // clang-format on
97   llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, "
98                   "add(*)), mul(*, add(*)))) matched "
99                << countMatches(f, p9) << " times\n";
100 
101   auto p10 = m_Op<arith::AddFOp>(a, b);
102   llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10)
103                << " times\n";
104 
105   auto p11 = m_Op<arith::AddFOp>(a, c);
106   llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11)
107                << " times\n";
108 
109   auto p12 = m_Op<arith::AddFOp>(b, a);
110   llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12)
111                << " times\n";
112 
113   auto p13 = m_Op<arith::AddFOp>(c, a);
114   llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13)
115                << " times\n";
116 
117   auto p14 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(c, b));
118   llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14)
119                << " times\n";
120 
121   auto p15 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(b, c));
122   llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15)
123                << " times\n";
124 
125   auto mulOfAany = m_Op<arith::MulFOp>(a, m_Any());
126   auto p16 = m_Op<arith::MulFOp>(mulOfAany, m_Op<arith::AddFOp>(a, c));
127   llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched "
128                << countMatches(f, p16) << " times\n";
129 
130   auto p17 = m_Op<arith::MulFOp>(mulOfAany, m_Op<arith::AddFOp>(c, b));
131   llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched "
132                << countMatches(f, p17) << " times\n";
133 }
134 
test2(FunctionOpInterface f)135 void test2(FunctionOpInterface f) {
136   auto a = m_Val(f.getArgument(0));
137   FloatAttr floatAttr;
138   auto p =
139       m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant(&floatAttr)));
140   auto p1 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant()));
141   // Last operation that is not the terminator.
142   Operation *lastOp = f.getFunctionBody().front().back().getPrevNode();
143   if (p.match(lastOp))
144     llvm::outs()
145         << "Pattern add(add(a, constant), a) matched and bound constant to: "
146         << floatAttr.getValueAsDouble() << "\n";
147   if (p1.match(lastOp))
148     llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
149 }
150 
test3(FunctionOpInterface f)151 void test3(FunctionOpInterface f) {
152   arith::FastMathFlagsAttr fastMathAttr;
153   auto p = m_Op<arith::MulFOp>(m_Any(),
154                                m_Op<arith::AddFOp>(m_Any(), m_Op("test.name")));
155   auto p1 = m_Attr("fastmath", &fastMathAttr);
156 
157   // Last operation that is not the terminator.
158   Operation *lastOp = f.getFunctionBody().front().back().getPrevNode();
159   if (p.match(lastOp))
160     llvm::outs() << "Pattern mul(*, add(*, m_Op(\"test.name\"))) matched\n";
161   if (p1.match(lastOp))
162     llvm::outs() << "Pattern m_Attr(\"fastmath\") matched and bound value to: "
163                  << fastMathAttr.getValue() << "\n";
164 }
165 
runOnOperation()166 void TestMatchers::runOnOperation() {
167   auto f = getOperation();
168   llvm::outs() << f.getName() << "\n";
169   if (f.getName() == "test1")
170     test1(f);
171   if (f.getName() == "test2")
172     test2(f);
173   if (f.getName() == "test3")
174     test3(f);
175 }
176 
177 namespace mlir {
registerTestMatchers()178 void registerTestMatchers() { PassRegistration<TestMatchers>(); }
179 } // namespace mlir
180