xref: /llvm-project/mlir/test/lib/IR/TestMatchers.cpp (revision c64770506b89a2376fe13080bc3b72789e6c752d)
1 //===- TestMatchers.cpp - Pass to test matchers ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/StandardOps/Ops.h"
10 #include "mlir/IR/Function.h"
11 #include "mlir/IR/Matchers.h"
12 #include "mlir/Pass/Pass.h"
13 
14 using namespace mlir;
15 
16 namespace {
17 /// This is a test pass for verifying matchers.
18 struct TestMatchers : public FunctionPass<TestMatchers> {
19   void runOnFunction() override;
20 };
21 } // end anonymous namespace
22 
23 // This could be done better but is not worth the variadic template trouble.
24 template <typename Matcher> unsigned countMatches(FuncOp f, Matcher &matcher) {
25   unsigned count = 0;
26   f.walk([&count, &matcher](Operation *op) {
27     if (matcher.match(op))
28       ++count;
29   });
30   return count;
31 }
32 
33 using mlir::matchers::m_Any;
34 using mlir::matchers::m_Val;
35 static void test1(FuncOp f) {
36   assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args");
37 
38   auto a = m_Val(f.getArgument(0));
39   auto b = m_Val(f.getArgument(1));
40   auto c = m_Val(f.getArgument(2));
41 
42   auto p0 = m_Op<AddFOp>(); // using 0-arity matcher
43   llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0)
44                << " times\n";
45 
46   auto p1 = m_Op<MulFOp>(); // using 0-arity matcher
47   llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1)
48                << " times\n";
49 
50   auto p2 = m_Op<AddFOp>(m_Op<AddFOp>(), m_Any());
51   llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2)
52                << " times\n";
53 
54   auto p3 = m_Op<AddFOp>(m_Any(), m_Op<AddFOp>());
55   llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3)
56                << " times\n";
57 
58   auto p4 = m_Op<MulFOp>(m_Op<AddFOp>(), m_Any());
59   llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4)
60                << " times\n";
61 
62   auto p5 = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>());
63   llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5)
64                << " times\n";
65 
66   auto p6 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Any());
67   llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6)
68                << " times\n";
69 
70   auto p7 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>());
71   llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7)
72                << " times\n";
73 
74   auto mul_of_mulmul = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>());
75   auto p8 = m_Op<MulFOp>(mul_of_mulmul, mul_of_mulmul);
76   llvm::outs()
77       << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched "
78       << countMatches(f, p8) << " times\n";
79 
80   // clang-format off
81   auto mul_of_muladd = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<AddFOp>());
82   auto mul_of_anyadd = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>());
83   auto p9 = m_Op<MulFOp>(m_Op<MulFOp>(
84                      mul_of_muladd, m_Op<MulFOp>()),
85                    m_Op<MulFOp>(mul_of_anyadd, mul_of_anyadd));
86   // clang-format on
87   llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, "
88                   "add(*)), mul(*, add(*)))) matched "
89                << countMatches(f, p9) << " times\n";
90 
91   auto p10 = m_Op<AddFOp>(a, b);
92   llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10)
93                << " times\n";
94 
95   auto p11 = m_Op<AddFOp>(a, c);
96   llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11)
97                << " times\n";
98 
99   auto p12 = m_Op<AddFOp>(b, a);
100   llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12)
101                << " times\n";
102 
103   auto p13 = m_Op<AddFOp>(c, a);
104   llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13)
105                << " times\n";
106 
107   auto p14 = m_Op<MulFOp>(a, m_Op<AddFOp>(c, b));
108   llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14)
109                << " times\n";
110 
111   auto p15 = m_Op<MulFOp>(a, m_Op<AddFOp>(b, c));
112   llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15)
113                << " times\n";
114 
115   auto mul_of_aany = m_Op<MulFOp>(a, m_Any());
116   auto p16 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(a, c));
117   llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched "
118                << countMatches(f, p16) << " times\n";
119 
120   auto p17 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(c, b));
121   llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched "
122                << countMatches(f, p17) << " times\n";
123 }
124 
125 void test2(FuncOp f) {
126   auto a = m_Val(f.getArgument(0));
127   FloatAttr floatAttr;
128   auto p = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant(&floatAttr)));
129   auto p1 = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant()));
130   // Last operation that is not the terminator.
131   Operation *lastOp = f.getBody().front().back().getPrevNode();
132   if (p.match(lastOp))
133     llvm::outs()
134         << "Pattern add(add(a, constant), a) matched and bound constant to: "
135         << floatAttr.getValueAsDouble() << "\n";
136   if (p1.match(lastOp))
137     llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
138 }
139 
140 void TestMatchers::runOnFunction() {
141   auto f = getFunction();
142   llvm::outs() << f.getName() << "\n";
143   if (f.getName() == "test1")
144     test1(f);
145   if (f.getName() == "test2")
146     test2(f);
147 }
148 
149 namespace mlir {
150 void registerTestMatchers() {
151   PassRegistration<TestMatchers>("test-matchers", "Test C++ pattern matchers.");
152 }
153 } // namespace mlir
154