1ade58a26SNicolas Vasilache //===- TestMatchers.cpp - Pass to test matchers ---------------------------===//
2ade58a26SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ade58a26SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
8ade58a26SNicolas Vasilache
9abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1065fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h"
11ade58a26SNicolas Vasilache #include "mlir/IR/Matchers.h"
12*34a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionInterfaces.h"
13ade58a26SNicolas Vasilache #include "mlir/Pass/Pass.h"
14ade58a26SNicolas Vasilache
15ade58a26SNicolas Vasilache using namespace mlir;
16ade58a26SNicolas Vasilache
17ade58a26SNicolas Vasilache namespace {
18ade58a26SNicolas Vasilache /// This is a test pass for verifying matchers.
1987d6bf37SRiver Riddle struct TestMatchers
2087d6bf37SRiver Riddle : public PassWrapper<TestMatchers, InterfacePass<FunctionOpInterface>> {
215e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMatchers)
225e50dd04SRiver Riddle
2341574554SRiver Riddle void runOnOperation() override;
getArgument__anon5c1ca04c0111::TestMatchers24b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-matchers"; }
getDescription__anon5c1ca04c0111::TestMatchers25b5e22e6dSMehdi Amini StringRef getDescription() const final {
26b5e22e6dSMehdi Amini return "Test C++ pattern matchers.";
27b5e22e6dSMehdi Amini }
28ade58a26SNicolas Vasilache };
29be0a7e9fSMehdi Amini } // namespace
30ade58a26SNicolas Vasilache
31ade58a26SNicolas Vasilache // This could be done better but is not worth the variadic template trouble.
32350dadaaSBenjamin Kramer template <typename Matcher>
countMatches(FunctionOpInterface f,Matcher & matcher)3387d6bf37SRiver Riddle static unsigned countMatches(FunctionOpInterface f, Matcher &matcher) {
34ade58a26SNicolas Vasilache unsigned count = 0;
35ade58a26SNicolas Vasilache f.walk([&count, &matcher](Operation *op) {
367b19bd54SNicolas Vasilache if (matcher.match(op))
37ade58a26SNicolas Vasilache ++count;
38ade58a26SNicolas Vasilache });
39ade58a26SNicolas Vasilache return count;
40ade58a26SNicolas Vasilache }
41ade58a26SNicolas Vasilache
427b19bd54SNicolas Vasilache using mlir::matchers::m_Any;
437b19bd54SNicolas Vasilache using mlir::matchers::m_Val;
test1(FunctionOpInterface f)4487d6bf37SRiver Riddle static void test1(FunctionOpInterface f) {
45ade58a26SNicolas Vasilache assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args");
46ade58a26SNicolas Vasilache
477b19bd54SNicolas Vasilache auto a = m_Val(f.getArgument(0));
487b19bd54SNicolas Vasilache auto b = m_Val(f.getArgument(1));
497b19bd54SNicolas Vasilache auto c = m_Val(f.getArgument(2));
50ade58a26SNicolas Vasilache
51a54f4eaeSMogball auto p0 = m_Op<arith::AddFOp>(); // using 0-arity matcher
52ade58a26SNicolas Vasilache llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0)
53ade58a26SNicolas Vasilache << " times\n";
54ade58a26SNicolas Vasilache
55a54f4eaeSMogball auto p1 = m_Op<arith::MulFOp>(); // using 0-arity matcher
56ade58a26SNicolas Vasilache llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1)
57ade58a26SNicolas Vasilache << " times\n";
58ade58a26SNicolas Vasilache
59a54f4eaeSMogball auto p2 = m_Op<arith::AddFOp>(m_Op<arith::AddFOp>(), m_Any());
60ade58a26SNicolas Vasilache llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2)
61ade58a26SNicolas Vasilache << " times\n";
62ade58a26SNicolas Vasilache
63a54f4eaeSMogball auto p3 = m_Op<arith::AddFOp>(m_Any(), m_Op<arith::AddFOp>());
64ade58a26SNicolas Vasilache llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3)
65ade58a26SNicolas Vasilache << " times\n";
66ade58a26SNicolas Vasilache
67a54f4eaeSMogball auto p4 = m_Op<arith::MulFOp>(m_Op<arith::AddFOp>(), m_Any());
68ade58a26SNicolas Vasilache llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4)
69ade58a26SNicolas Vasilache << " times\n";
70ade58a26SNicolas Vasilache
71a54f4eaeSMogball auto p5 = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>());
72ade58a26SNicolas Vasilache llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5)
73ade58a26SNicolas Vasilache << " times\n";
74ade58a26SNicolas Vasilache
75a54f4eaeSMogball auto p6 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Any());
76ade58a26SNicolas Vasilache llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6)
77ade58a26SNicolas Vasilache << " times\n";
78ade58a26SNicolas Vasilache
79a54f4eaeSMogball auto p7 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>());
80ade58a26SNicolas Vasilache llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7)
81ade58a26SNicolas Vasilache << " times\n";
82ade58a26SNicolas Vasilache
8302b6fb21SMehdi Amini auto mulOfMulmul =
84a54f4eaeSMogball m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>());
8502b6fb21SMehdi Amini auto p8 = m_Op<arith::MulFOp>(mulOfMulmul, mulOfMulmul);
86ade58a26SNicolas Vasilache llvm::outs()
87ade58a26SNicolas Vasilache << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched "
88ade58a26SNicolas Vasilache << countMatches(f, p8) << " times\n";
89ade58a26SNicolas Vasilache
90ade58a26SNicolas Vasilache // clang-format off
9102b6fb21SMehdi Amini auto mulOfMuladd = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::AddFOp>());
9202b6fb21SMehdi Amini auto mulOfAnyadd = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>());
93a54f4eaeSMogball auto p9 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(
9402b6fb21SMehdi Amini mulOfMuladd, m_Op<arith::MulFOp>()),
9502b6fb21SMehdi Amini m_Op<arith::MulFOp>(mulOfAnyadd, mulOfAnyadd));
96ade58a26SNicolas Vasilache // clang-format on
97ade58a26SNicolas Vasilache llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, "
98ade58a26SNicolas Vasilache "add(*)), mul(*, add(*)))) matched "
99ade58a26SNicolas Vasilache << countMatches(f, p9) << " times\n";
100ade58a26SNicolas Vasilache
101a54f4eaeSMogball auto p10 = m_Op<arith::AddFOp>(a, b);
102ade58a26SNicolas Vasilache llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10)
103ade58a26SNicolas Vasilache << " times\n";
104ade58a26SNicolas Vasilache
105a54f4eaeSMogball auto p11 = m_Op<arith::AddFOp>(a, c);
106ade58a26SNicolas Vasilache llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11)
107ade58a26SNicolas Vasilache << " times\n";
108ade58a26SNicolas Vasilache
109a54f4eaeSMogball auto p12 = m_Op<arith::AddFOp>(b, a);
110ade58a26SNicolas Vasilache llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12)
111ade58a26SNicolas Vasilache << " times\n";
112ade58a26SNicolas Vasilache
113a54f4eaeSMogball auto p13 = m_Op<arith::AddFOp>(c, a);
114ade58a26SNicolas Vasilache llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13)
115ade58a26SNicolas Vasilache << " times\n";
116ade58a26SNicolas Vasilache
117a54f4eaeSMogball auto p14 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(c, b));
118ade58a26SNicolas Vasilache llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14)
119ade58a26SNicolas Vasilache << " times\n";
120ade58a26SNicolas Vasilache
121a54f4eaeSMogball auto p15 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(b, c));
122ade58a26SNicolas Vasilache llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15)
123ade58a26SNicolas Vasilache << " times\n";
124ade58a26SNicolas Vasilache
12502b6fb21SMehdi Amini auto mulOfAany = m_Op<arith::MulFOp>(a, m_Any());
12602b6fb21SMehdi Amini auto p16 = m_Op<arith::MulFOp>(mulOfAany, m_Op<arith::AddFOp>(a, c));
127ade58a26SNicolas Vasilache llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched "
128ade58a26SNicolas Vasilache << countMatches(f, p16) << " times\n";
129ade58a26SNicolas Vasilache
13002b6fb21SMehdi Amini auto p17 = m_Op<arith::MulFOp>(mulOfAany, m_Op<arith::AddFOp>(c, b));
131ade58a26SNicolas Vasilache llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched "
132ade58a26SNicolas Vasilache << countMatches(f, p17) << " times\n";
133ade58a26SNicolas Vasilache }
134ade58a26SNicolas Vasilache
test2(FunctionOpInterface f)13587d6bf37SRiver Riddle void test2(FunctionOpInterface f) {
1367b19bd54SNicolas Vasilache auto a = m_Val(f.getArgument(0));
1377b19bd54SNicolas Vasilache FloatAttr floatAttr;
138a54f4eaeSMogball auto p =
139a54f4eaeSMogball m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant(&floatAttr)));
140a54f4eaeSMogball auto p1 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant()));
1417b19bd54SNicolas Vasilache // Last operation that is not the terminator.
142ecba7c58SRiver Riddle Operation *lastOp = f.getFunctionBody().front().back().getPrevNode();
1437b19bd54SNicolas Vasilache if (p.match(lastOp))
1447b19bd54SNicolas Vasilache llvm::outs()
1457b19bd54SNicolas Vasilache << "Pattern add(add(a, constant), a) matched and bound constant to: "
1467b19bd54SNicolas Vasilache << floatAttr.getValueAsDouble() << "\n";
14781e7922eSLorenzo Chelini if (p1.match(lastOp))
14881e7922eSLorenzo Chelini llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
1497b19bd54SNicolas Vasilache }
1507b19bd54SNicolas Vasilache
test3(FunctionOpInterface f)1518b9f8db5SDevajith V S void test3(FunctionOpInterface f) {
1528b9f8db5SDevajith V S arith::FastMathFlagsAttr fastMathAttr;
1538b9f8db5SDevajith V S auto p = m_Op<arith::MulFOp>(m_Any(),
1548b9f8db5SDevajith V S m_Op<arith::AddFOp>(m_Any(), m_Op("test.name")));
1558b9f8db5SDevajith V S auto p1 = m_Attr("fastmath", &fastMathAttr);
1568b9f8db5SDevajith V S
1578b9f8db5SDevajith V S // Last operation that is not the terminator.
1588b9f8db5SDevajith V S Operation *lastOp = f.getFunctionBody().front().back().getPrevNode();
1598b9f8db5SDevajith V S if (p.match(lastOp))
1608b9f8db5SDevajith V S llvm::outs() << "Pattern mul(*, add(*, m_Op(\"test.name\"))) matched\n";
1618b9f8db5SDevajith V S if (p1.match(lastOp))
1628b9f8db5SDevajith V S llvm::outs() << "Pattern m_Attr(\"fastmath\") matched and bound value to: "
1638b9f8db5SDevajith V S << fastMathAttr.getValue() << "\n";
1648b9f8db5SDevajith V S }
1658b9f8db5SDevajith V S
runOnOperation()16641574554SRiver Riddle void TestMatchers::runOnOperation() {
16741574554SRiver Riddle auto f = getOperation();
1687b19bd54SNicolas Vasilache llvm::outs() << f.getName() << "\n";
169ade58a26SNicolas Vasilache if (f.getName() == "test1")
170ade58a26SNicolas Vasilache test1(f);
1717b19bd54SNicolas Vasilache if (f.getName() == "test2")
1727b19bd54SNicolas Vasilache test2(f);
1738b9f8db5SDevajith V S if (f.getName() == "test3")
1748b9f8db5SDevajith V S test3(f);
175ade58a26SNicolas Vasilache }
176ade58a26SNicolas Vasilache
177c6477050SMehdi Amini namespace mlir {
registerTestMatchers()178b5e22e6dSMehdi Amini void registerTestMatchers() { PassRegistration<TestMatchers>(); }
179c6477050SMehdi Amini } // namespace mlir
180