1 //===- TestMatchers.cpp - Pass to test matchers ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/StandardOps/Ops.h" 10 #include "mlir/IR/Function.h" 11 #include "mlir/IR/Matchers.h" 12 #include "mlir/Pass/Pass.h" 13 14 using namespace mlir; 15 16 namespace { 17 /// This is a test pass for verifying matchers. 18 struct TestMatchers : public FunctionPass<TestMatchers> { 19 void runOnFunction() override; 20 }; 21 } // end anonymous namespace 22 23 // This could be done better but is not worth the variadic template trouble. 24 template <typename Matcher> unsigned countMatches(FuncOp f, Matcher &matcher) { 25 unsigned count = 0; 26 f.walk([&count, &matcher](Operation *op) { 27 if (matcher.match(op)) 28 ++count; 29 }); 30 return count; 31 } 32 33 using mlir::matchers::m_Any; 34 using mlir::matchers::m_Val; 35 static void test1(FuncOp f) { 36 assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args"); 37 38 auto a = m_Val(f.getArgument(0)); 39 auto b = m_Val(f.getArgument(1)); 40 auto c = m_Val(f.getArgument(2)); 41 42 auto p0 = m_Op<AddFOp>(); // using 0-arity matcher 43 llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0) 44 << " times\n"; 45 46 auto p1 = m_Op<MulFOp>(); // using 0-arity matcher 47 llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1) 48 << " times\n"; 49 50 auto p2 = m_Op<AddFOp>(m_Op<AddFOp>(), m_Any()); 51 llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2) 52 << " times\n"; 53 54 auto p3 = m_Op<AddFOp>(m_Any(), m_Op<AddFOp>()); 55 llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3) 56 << " times\n"; 57 58 auto p4 = m_Op<MulFOp>(m_Op<AddFOp>(), m_Any()); 59 llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4) 60 << " times\n"; 61 62 auto p5 = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>()); 63 llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5) 64 << " times\n"; 65 66 auto p6 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Any()); 67 llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6) 68 << " times\n"; 69 70 auto p7 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>()); 71 llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7) 72 << " times\n"; 73 74 auto mul_of_mulmul = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>()); 75 auto p8 = m_Op<MulFOp>(mul_of_mulmul, mul_of_mulmul); 76 llvm::outs() 77 << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched " 78 << countMatches(f, p8) << " times\n"; 79 80 // clang-format off 81 auto mul_of_muladd = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<AddFOp>()); 82 auto mul_of_anyadd = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>()); 83 auto p9 = m_Op<MulFOp>(m_Op<MulFOp>( 84 mul_of_muladd, m_Op<MulFOp>()), 85 m_Op<MulFOp>(mul_of_anyadd, mul_of_anyadd)); 86 // clang-format on 87 llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, " 88 "add(*)), mul(*, add(*)))) matched " 89 << countMatches(f, p9) << " times\n"; 90 91 auto p10 = m_Op<AddFOp>(a, b); 92 llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10) 93 << " times\n"; 94 95 auto p11 = m_Op<AddFOp>(a, c); 96 llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11) 97 << " times\n"; 98 99 auto p12 = m_Op<AddFOp>(b, a); 100 llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12) 101 << " times\n"; 102 103 auto p13 = m_Op<AddFOp>(c, a); 104 llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13) 105 << " times\n"; 106 107 auto p14 = m_Op<MulFOp>(a, m_Op<AddFOp>(c, b)); 108 llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14) 109 << " times\n"; 110 111 auto p15 = m_Op<MulFOp>(a, m_Op<AddFOp>(b, c)); 112 llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15) 113 << " times\n"; 114 115 auto mul_of_aany = m_Op<MulFOp>(a, m_Any()); 116 auto p16 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(a, c)); 117 llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched " 118 << countMatches(f, p16) << " times\n"; 119 120 auto p17 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(c, b)); 121 llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched " 122 << countMatches(f, p17) << " times\n"; 123 } 124 125 void test2(FuncOp f) { 126 auto a = m_Val(f.getArgument(0)); 127 FloatAttr floatAttr; 128 auto p = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant(&floatAttr))); 129 auto p1 = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant())); 130 // Last operation that is not the terminator. 131 Operation *lastOp = f.getBody().front().back().getPrevNode(); 132 if (p.match(lastOp)) 133 llvm::outs() 134 << "Pattern add(add(a, constant), a) matched and bound constant to: " 135 << floatAttr.getValueAsDouble() << "\n"; 136 if (p1.match(lastOp)) 137 llvm::outs() << "Pattern add(add(a, constant), a) matched\n"; 138 } 139 140 void TestMatchers::runOnFunction() { 141 auto f = getFunction(); 142 llvm::outs() << f.getName() << "\n"; 143 if (f.getName() == "test1") 144 test1(f); 145 if (f.getName() == "test2") 146 test2(f); 147 } 148 149 namespace mlir { 150 void registerTestMatchers() { 151 PassRegistration<TestMatchers>("test-matchers", "Test C++ pattern matchers."); 152 } 153 } // namespace mlir 154