1 //===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Implements the base layer of the matcher framework. 10 // 11 // Matchers are methods that return a Matcher which provides a method 12 // match(Operation *op) 13 // 14 // The matcher functions are defined in include/mlir/IR/Matchers.h. 15 // This file contains the wrapper classes needed to construct matchers for 16 // mlir-query. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H 21 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H 22 23 #include "mlir/IR/Matchers.h" 24 #include "llvm/ADT/IntrusiveRefCntPtr.h" 25 26 namespace mlir::query::matcher { 27 28 // Generic interface for matchers on an MLIR operation. 29 class MatcherInterface 30 : public llvm::ThreadSafeRefCountedBase<MatcherInterface> { 31 public: 32 virtual ~MatcherInterface() = default; 33 34 virtual bool match(Operation *op) = 0; 35 }; 36 37 // MatcherFnImpl takes a matcher function object and implements 38 // MatcherInterface. 39 template <typename MatcherFn> 40 class MatcherFnImpl : public MatcherInterface { 41 public: MatcherFnImpl(MatcherFn & matcherFn)42 MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {} match(Operation * op)43 bool match(Operation *op) override { return matcherFn.match(op); } 44 45 private: 46 MatcherFn matcherFn; 47 }; 48 49 // Matcher wraps a MatcherInterface implementation and provides a match() 50 // method that redirects calls to the underlying implementation. 51 class DynMatcher { 52 public: 53 // Takes ownership of the provided implementation pointer. DynMatcher(MatcherInterface * implementation)54 DynMatcher(MatcherInterface *implementation) 55 : implementation(implementation) {} 56 57 template <typename MatcherFn> 58 static std::unique_ptr<DynMatcher> constructDynMatcherFromMatcherFn(MatcherFn & matcherFn)59 constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) { 60 auto impl = std::make_unique<MatcherFnImpl<MatcherFn>>(matcherFn); 61 return std::make_unique<DynMatcher>(impl.release()); 62 } 63 match(Operation * op)64 bool match(Operation *op) const { return implementation->match(op); } 65 setFunctionName(StringRef name)66 void setFunctionName(StringRef name) { functionName = name.str(); }; 67 hasFunctionName()68 bool hasFunctionName() const { return !functionName.empty(); }; 69 getFunctionName()70 StringRef getFunctionName() const { return functionName; }; 71 72 private: 73 llvm::IntrusiveRefCntPtr<MatcherInterface> implementation; 74 std::string functionName; 75 }; 76 77 } // namespace mlir::query::matcher 78 79 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H 80