1 //===- TestMatchers.cpp - Pass to test matchers ---------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Dialect/StandardOps/Ops.h" 19 #include "mlir/IR/Function.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/Pass/Pass.h" 22 23 using namespace mlir; 24 25 namespace { 26 /// This is a test pass for verifying matchers. 27 struct TestMatchers : public FunctionPass<TestMatchers> { 28 void runOnFunction() override; 29 }; 30 } // end anonymous namespace 31 32 // This could be done better but is not worth the variadic template trouble. 33 template <typename Matcher> unsigned countMatches(FuncOp f, Matcher &matcher) { 34 unsigned count = 0; 35 f.walk([&count, &matcher](Operation *op) { 36 if (matcher.match(op)) 37 ++count; 38 }); 39 return count; 40 } 41 42 using mlir::matchers::m_Any; 43 using mlir::matchers::m_Val; 44 static void test1(FuncOp f) { 45 assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args"); 46 47 auto a = m_Val(f.getArgument(0)); 48 auto b = m_Val(f.getArgument(1)); 49 auto c = m_Val(f.getArgument(2)); 50 51 auto p0 = m_Op<AddFOp>(); // using 0-arity matcher 52 llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0) 53 << " times\n"; 54 55 auto p1 = m_Op<MulFOp>(); // using 0-arity matcher 56 llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1) 57 << " times\n"; 58 59 auto p2 = m_Op<AddFOp>(m_Op<AddFOp>(), m_Any()); 60 llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2) 61 << " times\n"; 62 63 auto p3 = m_Op<AddFOp>(m_Any(), m_Op<AddFOp>()); 64 llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3) 65 << " times\n"; 66 67 auto p4 = m_Op<MulFOp>(m_Op<AddFOp>(), m_Any()); 68 llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4) 69 << " times\n"; 70 71 auto p5 = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>()); 72 llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5) 73 << " times\n"; 74 75 auto p6 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Any()); 76 llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6) 77 << " times\n"; 78 79 auto p7 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>()); 80 llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7) 81 << " times\n"; 82 83 auto mul_of_mulmul = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>()); 84 auto p8 = m_Op<MulFOp>(mul_of_mulmul, mul_of_mulmul); 85 llvm::outs() 86 << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched " 87 << countMatches(f, p8) << " times\n"; 88 89 // clang-format off 90 auto mul_of_muladd = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<AddFOp>()); 91 auto mul_of_anyadd = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>()); 92 auto p9 = m_Op<MulFOp>(m_Op<MulFOp>( 93 mul_of_muladd, m_Op<MulFOp>()), 94 m_Op<MulFOp>(mul_of_anyadd, mul_of_anyadd)); 95 // clang-format on 96 llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, " 97 "add(*)), mul(*, add(*)))) matched " 98 << countMatches(f, p9) << " times\n"; 99 100 auto p10 = m_Op<AddFOp>(a, b); 101 llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10) 102 << " times\n"; 103 104 auto p11 = m_Op<AddFOp>(a, c); 105 llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11) 106 << " times\n"; 107 108 auto p12 = m_Op<AddFOp>(b, a); 109 llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12) 110 << " times\n"; 111 112 auto p13 = m_Op<AddFOp>(c, a); 113 llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13) 114 << " times\n"; 115 116 auto p14 = m_Op<MulFOp>(a, m_Op<AddFOp>(c, b)); 117 llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14) 118 << " times\n"; 119 120 auto p15 = m_Op<MulFOp>(a, m_Op<AddFOp>(b, c)); 121 llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15) 122 << " times\n"; 123 124 auto mul_of_aany = m_Op<MulFOp>(a, m_Any()); 125 auto p16 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(a, c)); 126 llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched " 127 << countMatches(f, p16) << " times\n"; 128 129 auto p17 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(c, b)); 130 llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched " 131 << countMatches(f, p17) << " times\n"; 132 } 133 134 void test2(FuncOp f) { 135 auto a = m_Val(f.getArgument(0)); 136 FloatAttr floatAttr; 137 auto p = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant(&floatAttr))); 138 // Last operation that is not the terminator. 139 Operation *lastOp = f.getBody().front().back().getPrevNode(); 140 if (p.match(lastOp)) 141 llvm::outs() 142 << "Pattern add(add(a, constant), a) matched and bound constant to: " 143 << floatAttr.getValueAsDouble() << "\n"; 144 } 145 146 void TestMatchers::runOnFunction() { 147 auto f = getFunction(); 148 llvm::outs() << f.getName() << "\n"; 149 if (f.getName() == "test1") 150 test1(f); 151 if (f.getName() == "test2") 152 test2(f); 153 } 154 155 static PassRegistration<TestMatchers> pass("test-matchers", 156 "Test C++ pattern matchers."); 157