xref: /llvm-project/mlir/test/lib/IR/TestMatchers.cpp (revision 34a35a8b244243f5a4ad5d531007bccfeaa0b02e)
//===- TestMatchers.cpp - Pass to test matchers ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8ade58a26SNicolas Vasilache 
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
14ade58a26SNicolas Vasilache 
15ade58a26SNicolas Vasilache using namespace mlir;
16ade58a26SNicolas Vasilache 
17ade58a26SNicolas Vasilache namespace {
18ade58a26SNicolas Vasilache /// This is a test pass for verifying matchers.
1987d6bf37SRiver Riddle struct TestMatchers
2087d6bf37SRiver Riddle     : public PassWrapper<TestMatchers, InterfacePass<FunctionOpInterface>> {
215e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMatchers)
225e50dd04SRiver Riddle 
2341574554SRiver Riddle   void runOnOperation() override;
getArgument__anon5c1ca04c0111::TestMatchers24b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-matchers"; }
getDescription__anon5c1ca04c0111::TestMatchers25b5e22e6dSMehdi Amini   StringRef getDescription() const final {
26b5e22e6dSMehdi Amini     return "Test C++ pattern matchers.";
27b5e22e6dSMehdi Amini   }
28ade58a26SNicolas Vasilache };
29be0a7e9fSMehdi Amini } // namespace
30ade58a26SNicolas Vasilache 
31ade58a26SNicolas Vasilache // This could be done better but is not worth the variadic template trouble.
32350dadaaSBenjamin Kramer template <typename Matcher>
countMatches(FunctionOpInterface f,Matcher & matcher)3387d6bf37SRiver Riddle static unsigned countMatches(FunctionOpInterface f, Matcher &matcher) {
34ade58a26SNicolas Vasilache   unsigned count = 0;
35ade58a26SNicolas Vasilache   f.walk([&count, &matcher](Operation *op) {
367b19bd54SNicolas Vasilache     if (matcher.match(op))
37ade58a26SNicolas Vasilache       ++count;
38ade58a26SNicolas Vasilache   });
39ade58a26SNicolas Vasilache   return count;
40ade58a26SNicolas Vasilache }
41ade58a26SNicolas Vasilache 
427b19bd54SNicolas Vasilache using mlir::matchers::m_Any;
437b19bd54SNicolas Vasilache using mlir::matchers::m_Val;
test1(FunctionOpInterface f)4487d6bf37SRiver Riddle static void test1(FunctionOpInterface f) {
45ade58a26SNicolas Vasilache   assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args");
46ade58a26SNicolas Vasilache 
477b19bd54SNicolas Vasilache   auto a = m_Val(f.getArgument(0));
487b19bd54SNicolas Vasilache   auto b = m_Val(f.getArgument(1));
497b19bd54SNicolas Vasilache   auto c = m_Val(f.getArgument(2));
50ade58a26SNicolas Vasilache 
51a54f4eaeSMogball   auto p0 = m_Op<arith::AddFOp>(); // using 0-arity matcher
52ade58a26SNicolas Vasilache   llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0)
53ade58a26SNicolas Vasilache                << " times\n";
54ade58a26SNicolas Vasilache 
55a54f4eaeSMogball   auto p1 = m_Op<arith::MulFOp>(); // using 0-arity matcher
56ade58a26SNicolas Vasilache   llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1)
57ade58a26SNicolas Vasilache                << " times\n";
58ade58a26SNicolas Vasilache 
59a54f4eaeSMogball   auto p2 = m_Op<arith::AddFOp>(m_Op<arith::AddFOp>(), m_Any());
60ade58a26SNicolas Vasilache   llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2)
61ade58a26SNicolas Vasilache                << " times\n";
62ade58a26SNicolas Vasilache 
63a54f4eaeSMogball   auto p3 = m_Op<arith::AddFOp>(m_Any(), m_Op<arith::AddFOp>());
64ade58a26SNicolas Vasilache   llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3)
65ade58a26SNicolas Vasilache                << " times\n";
66ade58a26SNicolas Vasilache 
67a54f4eaeSMogball   auto p4 = m_Op<arith::MulFOp>(m_Op<arith::AddFOp>(), m_Any());
68ade58a26SNicolas Vasilache   llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4)
69ade58a26SNicolas Vasilache                << " times\n";
70ade58a26SNicolas Vasilache 
71a54f4eaeSMogball   auto p5 = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>());
72ade58a26SNicolas Vasilache   llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5)
73ade58a26SNicolas Vasilache                << " times\n";
74ade58a26SNicolas Vasilache 
75a54f4eaeSMogball   auto p6 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Any());
76ade58a26SNicolas Vasilache   llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6)
77ade58a26SNicolas Vasilache                << " times\n";
78ade58a26SNicolas Vasilache 
79a54f4eaeSMogball   auto p7 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>());
80ade58a26SNicolas Vasilache   llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7)
81ade58a26SNicolas Vasilache                << " times\n";
82ade58a26SNicolas Vasilache 
8302b6fb21SMehdi Amini   auto mulOfMulmul =
84a54f4eaeSMogball       m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>());
8502b6fb21SMehdi Amini   auto p8 = m_Op<arith::MulFOp>(mulOfMulmul, mulOfMulmul);
86ade58a26SNicolas Vasilache   llvm::outs()
87ade58a26SNicolas Vasilache       << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched "
88ade58a26SNicolas Vasilache       << countMatches(f, p8) << " times\n";
89ade58a26SNicolas Vasilache 
90ade58a26SNicolas Vasilache   // clang-format off
9102b6fb21SMehdi Amini   auto mulOfMuladd = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::AddFOp>());
9202b6fb21SMehdi Amini   auto mulOfAnyadd = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>());
93a54f4eaeSMogball   auto p9 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(
9402b6fb21SMehdi Amini                      mulOfMuladd, m_Op<arith::MulFOp>()),
9502b6fb21SMehdi Amini                    m_Op<arith::MulFOp>(mulOfAnyadd, mulOfAnyadd));
96ade58a26SNicolas Vasilache   // clang-format on
97ade58a26SNicolas Vasilache   llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, "
98ade58a26SNicolas Vasilache                   "add(*)), mul(*, add(*)))) matched "
99ade58a26SNicolas Vasilache                << countMatches(f, p9) << " times\n";
100ade58a26SNicolas Vasilache 
101a54f4eaeSMogball   auto p10 = m_Op<arith::AddFOp>(a, b);
102ade58a26SNicolas Vasilache   llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10)
103ade58a26SNicolas Vasilache                << " times\n";
104ade58a26SNicolas Vasilache 
105a54f4eaeSMogball   auto p11 = m_Op<arith::AddFOp>(a, c);
106ade58a26SNicolas Vasilache   llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11)
107ade58a26SNicolas Vasilache                << " times\n";
108ade58a26SNicolas Vasilache 
109a54f4eaeSMogball   auto p12 = m_Op<arith::AddFOp>(b, a);
110ade58a26SNicolas Vasilache   llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12)
111ade58a26SNicolas Vasilache                << " times\n";
112ade58a26SNicolas Vasilache 
113a54f4eaeSMogball   auto p13 = m_Op<arith::AddFOp>(c, a);
114ade58a26SNicolas Vasilache   llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13)
115ade58a26SNicolas Vasilache                << " times\n";
116ade58a26SNicolas Vasilache 
117a54f4eaeSMogball   auto p14 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(c, b));
118ade58a26SNicolas Vasilache   llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14)
119ade58a26SNicolas Vasilache                << " times\n";
120ade58a26SNicolas Vasilache 
121a54f4eaeSMogball   auto p15 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(b, c));
122ade58a26SNicolas Vasilache   llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15)
123ade58a26SNicolas Vasilache                << " times\n";
124ade58a26SNicolas Vasilache 
12502b6fb21SMehdi Amini   auto mulOfAany = m_Op<arith::MulFOp>(a, m_Any());
12602b6fb21SMehdi Amini   auto p16 = m_Op<arith::MulFOp>(mulOfAany, m_Op<arith::AddFOp>(a, c));
127ade58a26SNicolas Vasilache   llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched "
128ade58a26SNicolas Vasilache                << countMatches(f, p16) << " times\n";
129ade58a26SNicolas Vasilache 
13002b6fb21SMehdi Amini   auto p17 = m_Op<arith::MulFOp>(mulOfAany, m_Op<arith::AddFOp>(c, b));
131ade58a26SNicolas Vasilache   llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched "
132ade58a26SNicolas Vasilache                << countMatches(f, p17) << " times\n";
133ade58a26SNicolas Vasilache }
134ade58a26SNicolas Vasilache 
test2(FunctionOpInterface f)13587d6bf37SRiver Riddle void test2(FunctionOpInterface f) {
1367b19bd54SNicolas Vasilache   auto a = m_Val(f.getArgument(0));
1377b19bd54SNicolas Vasilache   FloatAttr floatAttr;
138a54f4eaeSMogball   auto p =
139a54f4eaeSMogball       m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant(&floatAttr)));
140a54f4eaeSMogball   auto p1 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant()));
1417b19bd54SNicolas Vasilache   // Last operation that is not the terminator.
142ecba7c58SRiver Riddle   Operation *lastOp = f.getFunctionBody().front().back().getPrevNode();
1437b19bd54SNicolas Vasilache   if (p.match(lastOp))
1447b19bd54SNicolas Vasilache     llvm::outs()
1457b19bd54SNicolas Vasilache         << "Pattern add(add(a, constant), a) matched and bound constant to: "
1467b19bd54SNicolas Vasilache         << floatAttr.getValueAsDouble() << "\n";
14781e7922eSLorenzo Chelini   if (p1.match(lastOp))
14881e7922eSLorenzo Chelini     llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
1497b19bd54SNicolas Vasilache }
1507b19bd54SNicolas Vasilache 
test3(FunctionOpInterface f)1518b9f8db5SDevajith V S void test3(FunctionOpInterface f) {
1528b9f8db5SDevajith V S   arith::FastMathFlagsAttr fastMathAttr;
1538b9f8db5SDevajith V S   auto p = m_Op<arith::MulFOp>(m_Any(),
1548b9f8db5SDevajith V S                                m_Op<arith::AddFOp>(m_Any(), m_Op("test.name")));
1558b9f8db5SDevajith V S   auto p1 = m_Attr("fastmath", &fastMathAttr);
1568b9f8db5SDevajith V S 
1578b9f8db5SDevajith V S   // Last operation that is not the terminator.
1588b9f8db5SDevajith V S   Operation *lastOp = f.getFunctionBody().front().back().getPrevNode();
1598b9f8db5SDevajith V S   if (p.match(lastOp))
1608b9f8db5SDevajith V S     llvm::outs() << "Pattern mul(*, add(*, m_Op(\"test.name\"))) matched\n";
1618b9f8db5SDevajith V S   if (p1.match(lastOp))
1628b9f8db5SDevajith V S     llvm::outs() << "Pattern m_Attr(\"fastmath\") matched and bound value to: "
1638b9f8db5SDevajith V S                  << fastMathAttr.getValue() << "\n";
1648b9f8db5SDevajith V S }
1658b9f8db5SDevajith V S 
runOnOperation()16641574554SRiver Riddle void TestMatchers::runOnOperation() {
16741574554SRiver Riddle   auto f = getOperation();
1687b19bd54SNicolas Vasilache   llvm::outs() << f.getName() << "\n";
169ade58a26SNicolas Vasilache   if (f.getName() == "test1")
170ade58a26SNicolas Vasilache     test1(f);
1717b19bd54SNicolas Vasilache   if (f.getName() == "test2")
1727b19bd54SNicolas Vasilache     test2(f);
1738b9f8db5SDevajith V S   if (f.getName() == "test3")
1748b9f8db5SDevajith V S     test3(f);
175ade58a26SNicolas Vasilache }
176ade58a26SNicolas Vasilache 
177c6477050SMehdi Amini namespace mlir {
registerTestMatchers()178b5e22e6dSMehdi Amini void registerTestMatchers() { PassRegistration<TestMatchers>(); }
179c6477050SMehdi Amini } // namespace mlir
180