xref: /llvm-project/mlir/include/mlir/Query/Matcher/MatchersInternal.h (revision 58b44c8102afb0e76d1cb70d4a5d089f70d2f657)
1 //===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implements the base layer of the matcher framework.
10 //
// Matchers are functions that return a matcher object providing a
// `bool match(Operation *op)` method.
13 //
14 // The matcher functions are defined in include/mlir/IR/Matchers.h.
15 // This file contains the wrapper classes needed to construct matchers for
16 // mlir-query.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
21 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
22 
23 #include "mlir/IR/Matchers.h"
24 #include "llvm/ADT/IntrusiveRefCntPtr.h"
25 
26 namespace mlir::query::matcher {
27 
28 // Generic interface for matchers on an MLIR operation.
29 class MatcherInterface
30     : public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
31 public:
32   virtual ~MatcherInterface() = default;
33 
34   virtual bool match(Operation *op) = 0;
35 };
36 
37 // MatcherFnImpl takes a matcher function object and implements
38 // MatcherInterface.
39 template <typename MatcherFn>
40 class MatcherFnImpl : public MatcherInterface {
41 public:
MatcherFnImpl(MatcherFn & matcherFn)42   MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
match(Operation * op)43   bool match(Operation *op) override { return matcherFn.match(op); }
44 
45 private:
46   MatcherFn matcherFn;
47 };
48 
49 // Matcher wraps a MatcherInterface implementation and provides a match()
50 // method that redirects calls to the underlying implementation.
51 class DynMatcher {
52 public:
53   // Takes ownership of the provided implementation pointer.
DynMatcher(MatcherInterface * implementation)54   DynMatcher(MatcherInterface *implementation)
55       : implementation(implementation) {}
56 
57   template <typename MatcherFn>
58   static std::unique_ptr<DynMatcher>
constructDynMatcherFromMatcherFn(MatcherFn & matcherFn)59   constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
60     auto impl = std::make_unique<MatcherFnImpl<MatcherFn>>(matcherFn);
61     return std::make_unique<DynMatcher>(impl.release());
62   }
63 
match(Operation * op)64   bool match(Operation *op) const { return implementation->match(op); }
65 
setFunctionName(StringRef name)66   void setFunctionName(StringRef name) { functionName = name.str(); };
67 
hasFunctionName()68   bool hasFunctionName() const { return !functionName.empty(); };
69 
getFunctionName()70   StringRef getFunctionName() const { return functionName; };
71 
72 private:
73   llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
74   std::string functionName;
75 };
76 
77 } // namespace mlir::query::matcher
78 
79 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
80