xref: /llvm-project/mlir/test/lib/IR/TestMatchers.cpp (revision be0a7e9f27083ada6072fcc0711ffa5630daa5ec)
1 //===- TestMatchers.cpp - Pass to test matchers ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/StandardOps/IR/Ops.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/IR/Matchers.h"
12 #include "mlir/Pass/Pass.h"
13 
14 using namespace mlir;
15 
16 namespace {
17 /// This is a test pass for verifying matchers.
18 struct TestMatchers : public PassWrapper<TestMatchers, FunctionPass> {
19   void runOnFunction() override;
20   StringRef getArgument() const final { return "test-matchers"; }
21   StringRef getDescription() const final {
22     return "Test C++ pattern matchers.";
23   }
24 };
25 } // namespace
26 
27 // This could be done better but is not worth the variadic template trouble.
28 template <typename Matcher>
29 static unsigned countMatches(FuncOp f, Matcher &matcher) {
30   unsigned count = 0;
31   f.walk([&count, &matcher](Operation *op) {
32     if (matcher.match(op))
33       ++count;
34   });
35   return count;
36 }
37 
38 using mlir::matchers::m_Any;
39 using mlir::matchers::m_Val;
40 static void test1(FuncOp f) {
41   assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args");
42 
43   auto a = m_Val(f.getArgument(0));
44   auto b = m_Val(f.getArgument(1));
45   auto c = m_Val(f.getArgument(2));
46 
47   auto p0 = m_Op<arith::AddFOp>(); // using 0-arity matcher
48   llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0)
49                << " times\n";
50 
51   auto p1 = m_Op<arith::MulFOp>(); // using 0-arity matcher
52   llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1)
53                << " times\n";
54 
55   auto p2 = m_Op<arith::AddFOp>(m_Op<arith::AddFOp>(), m_Any());
56   llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2)
57                << " times\n";
58 
59   auto p3 = m_Op<arith::AddFOp>(m_Any(), m_Op<arith::AddFOp>());
60   llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3)
61                << " times\n";
62 
63   auto p4 = m_Op<arith::MulFOp>(m_Op<arith::AddFOp>(), m_Any());
64   llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4)
65                << " times\n";
66 
67   auto p5 = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>());
68   llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5)
69                << " times\n";
70 
71   auto p6 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Any());
72   llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6)
73                << " times\n";
74 
75   auto p7 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>());
76   llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7)
77                << " times\n";
78 
79   auto mul_of_mulmul =
80       m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>());
81   auto p8 = m_Op<arith::MulFOp>(mul_of_mulmul, mul_of_mulmul);
82   llvm::outs()
83       << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched "
84       << countMatches(f, p8) << " times\n";
85 
86   // clang-format off
87   auto mul_of_muladd = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::AddFOp>());
88   auto mul_of_anyadd = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>());
89   auto p9 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(
90                      mul_of_muladd, m_Op<arith::MulFOp>()),
91                    m_Op<arith::MulFOp>(mul_of_anyadd, mul_of_anyadd));
92   // clang-format on
93   llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, "
94                   "add(*)), mul(*, add(*)))) matched "
95                << countMatches(f, p9) << " times\n";
96 
97   auto p10 = m_Op<arith::AddFOp>(a, b);
98   llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10)
99                << " times\n";
100 
101   auto p11 = m_Op<arith::AddFOp>(a, c);
102   llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11)
103                << " times\n";
104 
105   auto p12 = m_Op<arith::AddFOp>(b, a);
106   llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12)
107                << " times\n";
108 
109   auto p13 = m_Op<arith::AddFOp>(c, a);
110   llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13)
111                << " times\n";
112 
113   auto p14 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(c, b));
114   llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14)
115                << " times\n";
116 
117   auto p15 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(b, c));
118   llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15)
119                << " times\n";
120 
121   auto mul_of_aany = m_Op<arith::MulFOp>(a, m_Any());
122   auto p16 = m_Op<arith::MulFOp>(mul_of_aany, m_Op<arith::AddFOp>(a, c));
123   llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched "
124                << countMatches(f, p16) << " times\n";
125 
126   auto p17 = m_Op<arith::MulFOp>(mul_of_aany, m_Op<arith::AddFOp>(c, b));
127   llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched "
128                << countMatches(f, p17) << " times\n";
129 }
130 
131 void test2(FuncOp f) {
132   auto a = m_Val(f.getArgument(0));
133   FloatAttr floatAttr;
134   auto p =
135       m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant(&floatAttr)));
136   auto p1 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant()));
137   // Last operation that is not the terminator.
138   Operation *lastOp = f.getBody().front().back().getPrevNode();
139   if (p.match(lastOp))
140     llvm::outs()
141         << "Pattern add(add(a, constant), a) matched and bound constant to: "
142         << floatAttr.getValueAsDouble() << "\n";
143   if (p1.match(lastOp))
144     llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
145 }
146 
147 void TestMatchers::runOnFunction() {
148   auto f = getFunction();
149   llvm::outs() << f.getName() << "\n";
150   if (f.getName() == "test1")
151     test1(f);
152   if (f.getName() == "test2")
153     test2(f);
154 }
155 
156 namespace mlir {
157 void registerTestMatchers() { PassRegistration<TestMatchers>(); }
158 } // namespace mlir
159