1 //===- TestMatchers.cpp - Pass to test matchers ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/IR/BuiltinOps.h" 11 #include "mlir/IR/FunctionInterfaces.h" 12 #include "mlir/IR/Matchers.h" 13 #include "mlir/Pass/Pass.h" 14 15 using namespace mlir; 16 17 namespace { 18 /// This is a test pass for verifying matchers. 19 struct TestMatchers 20 : public PassWrapper<TestMatchers, InterfacePass<FunctionOpInterface>> { 21 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMatchers) 22 23 void runOnOperation() override; 24 StringRef getArgument() const final { return "test-matchers"; } 25 StringRef getDescription() const final { 26 return "Test C++ pattern matchers."; 27 } 28 }; 29 } // namespace 30 31 // This could be done better but is not worth the variadic template trouble. 32 template <typename Matcher> 33 static unsigned countMatches(FunctionOpInterface f, Matcher &matcher) { 34 unsigned count = 0; 35 f.walk([&count, &matcher](Operation *op) { 36 if (matcher.match(op)) 37 ++count; 38 }); 39 return count; 40 } 41 42 using mlir::matchers::m_Any; 43 using mlir::matchers::m_Val; 44 static void test1(FunctionOpInterface f) { 45 assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args"); 46 47 auto a = m_Val(f.getArgument(0)); 48 auto b = m_Val(f.getArgument(1)); 49 auto c = m_Val(f.getArgument(2)); 50 51 auto p0 = m_Op<arith::AddFOp>(); // using 0-arity matcher 52 llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0) 53 << " times\n"; 54 55 auto p1 = m_Op<arith::MulFOp>(); // using 0-arity matcher 56 llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1) 57 << " times\n"; 58 59 auto p2 = m_Op<arith::AddFOp>(m_Op<arith::AddFOp>(), m_Any()); 60 llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2) 61 << " times\n"; 62 63 auto p3 = m_Op<arith::AddFOp>(m_Any(), m_Op<arith::AddFOp>()); 64 llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3) 65 << " times\n"; 66 67 auto p4 = m_Op<arith::MulFOp>(m_Op<arith::AddFOp>(), m_Any()); 68 llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4) 69 << " times\n"; 70 71 auto p5 = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>()); 72 llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5) 73 << " times\n"; 74 75 auto p6 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Any()); 76 llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6) 77 << " times\n"; 78 79 auto p7 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>()); 80 llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7) 81 << " times\n"; 82 83 auto mulOfMulmul = 84 m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>()); 85 auto p8 = m_Op<arith::MulFOp>(mulOfMulmul, mulOfMulmul); 86 llvm::outs() 87 << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched " 88 << countMatches(f, p8) << " times\n"; 89 90 // clang-format off 91 auto mulOfMuladd = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::AddFOp>()); 92 auto mulOfAnyadd = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>()); 93 auto p9 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>( 94 mulOfMuladd, m_Op<arith::MulFOp>()), 95 m_Op<arith::MulFOp>(mulOfAnyadd, mulOfAnyadd)); 96 // clang-format on 97 llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, " 98 "add(*)), mul(*, add(*)))) matched " 99 << countMatches(f, p9) << " times\n"; 100 101 auto p10 = m_Op<arith::AddFOp>(a, b); 102 llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10) 103 << " times\n"; 104 105 auto p11 = m_Op<arith::AddFOp>(a, c); 106 llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11) 107 << " times\n"; 108 109 auto p12 = m_Op<arith::AddFOp>(b, a); 110 llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12) 111 << " times\n"; 112 113 auto p13 = m_Op<arith::AddFOp>(c, a); 114 llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13) 115 << " times\n"; 116 117 auto p14 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(c, b)); 118 llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14) 119 << " times\n"; 120 121 auto p15 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(b, c)); 122 llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15) 123 << " times\n"; 124 125 auto mulOfAany = m_Op<arith::MulFOp>(a, m_Any()); 126 auto p16 = m_Op<arith::MulFOp>(mulOfAany, m_Op<arith::AddFOp>(a, c)); 127 llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched " 128 << countMatches(f, p16) << " times\n"; 129 130 auto p17 = m_Op<arith::MulFOp>(mulOfAany, m_Op<arith::AddFOp>(c, b)); 131 llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched " 132 << countMatches(f, p17) << " times\n"; 133 } 134 135 void test2(FunctionOpInterface f) { 136 auto a = m_Val(f.getArgument(0)); 137 FloatAttr floatAttr; 138 auto p = 139 m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant(&floatAttr))); 140 auto p1 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant())); 141 // Last operation that is not the terminator. 142 Operation *lastOp = f.getBody().front().back().getPrevNode(); 143 if (p.match(lastOp)) 144 llvm::outs() 145 << "Pattern add(add(a, constant), a) matched and bound constant to: " 146 << floatAttr.getValueAsDouble() << "\n"; 147 if (p1.match(lastOp)) 148 llvm::outs() << "Pattern add(add(a, constant), a) matched\n"; 149 } 150 151 void TestMatchers::runOnOperation() { 152 auto f = getOperation(); 153 llvm::outs() << f.getName() << "\n"; 154 if (f.getName() == "test1") 155 test1(f); 156 if (f.getName() == "test2") 157 test2(f); 158 } 159 160 namespace mlir { 161 void registerTestMatchers() { PassRegistration<TestMatchers>(); } 162 } // namespace mlir 163