xref: /llvm-project/mlir/include/mlir/Query/Matcher/Marshallers.h (revision 02d9f4d1f128e17e04ab6e602d3c9b9942612428)
1 //===--- Marshallers.h - Generic matcher function marshallers ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains function templates and classes to wrap matcher construct
// functions. It provides a collection of template functions and classes that
11 // present a generic marshalling layer on top of matcher construct functions.
12 // The registry uses these to export all marshaller constructors with a uniform
13 // interface. This mechanism takes inspiration from clang-query.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
18 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
19 
20 #include "ErrorBuilder.h"
21 #include "VariantValue.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/StringRef.h"
24 
25 namespace mlir::query::matcher::internal {
26 
// ArgTypeTraits<T> maps a matcher-function parameter type T onto the
// VariantValue accessors used to validate and extract an argument of that
// type. The specializations below expose:
//   - hasCorrectType(value): whether `value` holds something usable as a T,
//   - get(value):            the extracted argument,
//   - getKind():             the ArgKind reported for T,
//   - getBestGuess(value):   an optional hint string.
// The primary template is only declared, never defined, so using an
// unsupported parameter type is a compile-time error.
template <typename T>
struct ArgTypeTraits;

// Parameters taken by const reference reuse the traits of the referred type
// (struct inheritance is public by default).
template <typename T>
struct ArgTypeTraits<const T &> : ArgTypeTraits<T> {};
34 
35 template <>
36 struct ArgTypeTraits<llvm::StringRef> {
37 
38   static bool hasCorrectType(const VariantValue &value) {
39     return value.isString();
40   }
41 
42   static const llvm::StringRef &get(const VariantValue &value) {
43     return value.getString();
44   }
45 
46   static ArgKind getKind() { return ArgKind::String; }
47 
48   static std::optional<std::string> getBestGuess(const VariantValue &) {
49     return std::nullopt;
50   }
51 };
52 
53 template <>
54 struct ArgTypeTraits<DynMatcher> {
55 
56   static bool hasCorrectType(const VariantValue &value) {
57     return value.isMatcher();
58   }
59 
60   static DynMatcher get(const VariantValue &value) {
61     return *value.getMatcher().getDynMatcher();
62   }
63 
64   static ArgKind getKind() { return ArgKind::Matcher; }
65 
66   static std::optional<std::string> getBestGuess(const VariantValue &) {
67     return std::nullopt;
68   }
69 };
70 
71 // Interface for generic matcher descriptor.
72 // Offers a create() method that constructs the matcher from the provided
73 // arguments.
74 class MatcherDescriptor {
75 public:
76   virtual ~MatcherDescriptor() = default;
77   virtual VariantMatcher create(SourceRange nameRange,
78                                 const llvm::ArrayRef<ParserValue> args,
79                                 Diagnostics *error) const = 0;
80 
81   // Returns the number of arguments accepted by the matcher.
82   virtual unsigned getNumArgs() const = 0;
83 
84   // Append the set of argument types accepted for argument 'argNo' to
85   // 'argKinds'.
86   virtual void getArgKinds(unsigned argNo,
87                            std::vector<ArgKind> &argKinds) const = 0;
88 };
89 
90 class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
91 public:
92   using MarshallerType = VariantMatcher (*)(void (*matcherFunc)(),
93                                             llvm::StringRef matcherName,
94                                             SourceRange nameRange,
95                                             llvm::ArrayRef<ParserValue> args,
96                                             Diagnostics *error);
97 
98   // Marshaller Function to unpack the arguments and call Func. Func is the
99   // Matcher construct function. This is the function that the matcher
100   // expressions would use to create the matcher.
101   FixedArgCountMatcherDescriptor(MarshallerType marshaller,
102                                  void (*matcherFunc)(),
103                                  llvm::StringRef matcherName,
104                                  llvm::ArrayRef<ArgKind> argKinds)
105       : marshaller(marshaller), matcherFunc(matcherFunc),
106         matcherName(matcherName), argKinds(argKinds.begin(), argKinds.end()) {}
107 
108   VariantMatcher create(SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
109                         Diagnostics *error) const override {
110     return marshaller(matcherFunc, matcherName, nameRange, args, error);
111   }
112 
113   unsigned getNumArgs() const override { return argKinds.size(); }
114 
115   void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
116     kinds.push_back(argKinds[argNo]);
117   }
118 
119 private:
120   const MarshallerType marshaller;
121   void (*const matcherFunc)();
122   const llvm::StringRef matcherName;
123   const std::vector<ArgKind> argKinds;
124 };
125 
126 // Helper function to check if argument count matches expected count
127 inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
128                           llvm::ArrayRef<ParserValue> args,
129                           Diagnostics *error) {
130   if (args.size() != expectedArgCount) {
131     addError(error, nameRange, ErrorType::RegistryWrongArgCount,
132              {llvm::Twine(expectedArgCount), llvm::Twine(args.size())});
133     return false;
134   }
135   return true;
136 }
137 
138 // Helper function for checking argument type
139 template <typename ArgType, size_t Index>
140 inline bool checkArgTypeAtIndex(llvm::StringRef matcherName,
141                                 llvm::ArrayRef<ParserValue> args,
142                                 Diagnostics *error) {
143   if (!ArgTypeTraits<ArgType>::hasCorrectType(args[Index].value)) {
144     addError(error, args[Index].range, ErrorType::RegistryWrongArgType,
145              {llvm::Twine(matcherName), llvm::Twine(Index + 1)});
146     return false;
147   }
148   return true;
149 }
150 
151 // Marshaller function for fixed number of arguments
152 template <typename ReturnType, typename... ArgTypes, size_t... Is>
153 static VariantMatcher
154 matcherMarshallFixedImpl(void (*matcherFunc)(), llvm::StringRef matcherName,
155                          SourceRange nameRange,
156                          llvm::ArrayRef<ParserValue> args, Diagnostics *error,
157                          std::index_sequence<Is...>) {
158   using FuncType = ReturnType (*)(ArgTypes...);
159 
160   // Check if the argument count matches the expected count
161   if (!checkArgCount(nameRange, sizeof...(ArgTypes), args, error))
162     return VariantMatcher();
163 
164   // Check if each argument at the corresponding index has the correct type
165   if ((... && checkArgTypeAtIndex<ArgTypes, Is>(matcherName, args, error))) {
166     ReturnType fnPointer = reinterpret_cast<FuncType>(matcherFunc)(
167         ArgTypeTraits<ArgTypes>::get(args[Is].value)...);
168     return VariantMatcher::SingleMatcher(
169         *DynMatcher::constructDynMatcherFromMatcherFn(fnPointer));
170   }
171 
172   return VariantMatcher();
173 }
174 
175 template <typename ReturnType, typename... ArgTypes>
176 static VariantMatcher
177 matcherMarshallFixed(void (*matcherFunc)(), llvm::StringRef matcherName,
178                      SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
179                      Diagnostics *error) {
180   return matcherMarshallFixedImpl<ReturnType, ArgTypes...>(
181       matcherFunc, matcherName, nameRange, args, error,
182       std::index_sequence_for<ArgTypes...>{});
183 }
184 
185 // Fixed number of arguments overload
186 template <typename ReturnType, typename... ArgTypes>
187 std::unique_ptr<MatcherDescriptor>
188 makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
189                         llvm::StringRef matcherName) {
190   // Create a vector of argument kinds
191   std::vector<ArgKind> argKinds = {ArgTypeTraits<ArgTypes>::getKind()...};
192   return std::make_unique<FixedArgCountMatcherDescriptor>(
193       matcherMarshallFixed<ReturnType, ArgTypes...>,
194       reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
195 }
196 
197 } // namespace mlir::query::matcher::internal
198 
199 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
200