1 //===--- Marshallers.h - Generic matcher function marshallers ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains function templates and classes to wrap matcher construct 10 // functions. It provides a collection of template function and classes that 11 // present a generic marshalling layer on top of matcher construct functions. 12 // The registry uses these to export all marshaller constructors with a uniform 13 // interface. This mechanism takes inspiration from clang-query. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H 18 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H 19 20 #include "ErrorBuilder.h" 21 #include "VariantValue.h" 22 #include "llvm/ADT/ArrayRef.h" 23 #include "llvm/ADT/StringRef.h" 24 25 namespace mlir::query::matcher::internal { 26 27 // Helper template class for jumping from argument type to the correct is/get 28 // functions in VariantValue. This is used for verifying and extracting the 29 // matcher arguments. 30 template <class T> 31 struct ArgTypeTraits; 32 template <class T> 33 struct ArgTypeTraits<const T &> : public ArgTypeTraits<T> {}; 34 35 template <> 36 struct ArgTypeTraits<llvm::StringRef> { 37 38 static bool hasCorrectType(const VariantValue &value) { 39 return value.isString(); 40 } 41 42 static const llvm::StringRef &get(const VariantValue &value) { 43 return value.getString(); 44 } 45 46 static ArgKind getKind() { return ArgKind::String; } 47 48 static std::optional<std::string> getBestGuess(const VariantValue &) { 49 return std::nullopt; 50 } 51 }; 52 53 template <> 54 struct ArgTypeTraits<DynMatcher> { 55 56 static bool hasCorrectType(const VariantValue &value) { 57 return value.isMatcher(); 58 } 59 60 static DynMatcher get(const VariantValue &value) { 61 return *value.getMatcher().getDynMatcher(); 62 } 63 64 static ArgKind getKind() { return ArgKind::Matcher; } 65 66 static std::optional<std::string> getBestGuess(const VariantValue &) { 67 return std::nullopt; 68 } 69 }; 70 71 // Interface for generic matcher descriptor. 72 // Offers a create() method that constructs the matcher from the provided 73 // arguments. 74 class MatcherDescriptor { 75 public: 76 virtual ~MatcherDescriptor() = default; 77 virtual VariantMatcher create(SourceRange nameRange, 78 const llvm::ArrayRef<ParserValue> args, 79 Diagnostics *error) const = 0; 80 81 // Returns the number of arguments accepted by the matcher. 82 virtual unsigned getNumArgs() const = 0; 83 84 // Append the set of argument types accepted for argument 'argNo' to 85 // 'argKinds'. 86 virtual void getArgKinds(unsigned argNo, 87 std::vector<ArgKind> &argKinds) const = 0; 88 }; 89 90 class FixedArgCountMatcherDescriptor : public MatcherDescriptor { 91 public: 92 using MarshallerType = VariantMatcher (*)(void (*matcherFunc)(), 93 llvm::StringRef matcherName, 94 SourceRange nameRange, 95 llvm::ArrayRef<ParserValue> args, 96 Diagnostics *error); 97 98 // Marshaller Function to unpack the arguments and call Func. Func is the 99 // Matcher construct function. This is the function that the matcher 100 // expressions would use to create the matcher. 101 FixedArgCountMatcherDescriptor(MarshallerType marshaller, 102 void (*matcherFunc)(), 103 llvm::StringRef matcherName, 104 llvm::ArrayRef<ArgKind> argKinds) 105 : marshaller(marshaller), matcherFunc(matcherFunc), 106 matcherName(matcherName), argKinds(argKinds.begin(), argKinds.end()) {} 107 108 VariantMatcher create(SourceRange nameRange, llvm::ArrayRef<ParserValue> args, 109 Diagnostics *error) const override { 110 return marshaller(matcherFunc, matcherName, nameRange, args, error); 111 } 112 113 unsigned getNumArgs() const override { return argKinds.size(); } 114 115 void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override { 116 kinds.push_back(argKinds[argNo]); 117 } 118 119 private: 120 const MarshallerType marshaller; 121 void (*const matcherFunc)(); 122 const llvm::StringRef matcherName; 123 const std::vector<ArgKind> argKinds; 124 }; 125 126 // Helper function to check if argument count matches expected count 127 inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount, 128 llvm::ArrayRef<ParserValue> args, 129 Diagnostics *error) { 130 if (args.size() != expectedArgCount) { 131 addError(error, nameRange, ErrorType::RegistryWrongArgCount, 132 {llvm::Twine(expectedArgCount), llvm::Twine(args.size())}); 133 return false; 134 } 135 return true; 136 } 137 138 // Helper function for checking argument type 139 template <typename ArgType, size_t Index> 140 inline bool checkArgTypeAtIndex(llvm::StringRef matcherName, 141 llvm::ArrayRef<ParserValue> args, 142 Diagnostics *error) { 143 if (!ArgTypeTraits<ArgType>::hasCorrectType(args[Index].value)) { 144 addError(error, args[Index].range, ErrorType::RegistryWrongArgType, 145 {llvm::Twine(matcherName), llvm::Twine(Index + 1)}); 146 return false; 147 } 148 return true; 149 } 150 151 // Marshaller function for fixed number of arguments 152 template <typename ReturnType, typename... ArgTypes, size_t... Is> 153 static VariantMatcher 154 matcherMarshallFixedImpl(void (*matcherFunc)(), llvm::StringRef matcherName, 155 SourceRange nameRange, 156 llvm::ArrayRef<ParserValue> args, Diagnostics *error, 157 std::index_sequence<Is...>) { 158 using FuncType = ReturnType (*)(ArgTypes...); 159 160 // Check if the argument count matches the expected count 161 if (!checkArgCount(nameRange, sizeof...(ArgTypes), args, error)) 162 return VariantMatcher(); 163 164 // Check if each argument at the corresponding index has the correct type 165 if ((... && checkArgTypeAtIndex<ArgTypes, Is>(matcherName, args, error))) { 166 ReturnType fnPointer = reinterpret_cast<FuncType>(matcherFunc)( 167 ArgTypeTraits<ArgTypes>::get(args[Is].value)...); 168 return VariantMatcher::SingleMatcher( 169 *DynMatcher::constructDynMatcherFromMatcherFn(fnPointer)); 170 } 171 172 return VariantMatcher(); 173 } 174 175 template <typename ReturnType, typename... ArgTypes> 176 static VariantMatcher 177 matcherMarshallFixed(void (*matcherFunc)(), llvm::StringRef matcherName, 178 SourceRange nameRange, llvm::ArrayRef<ParserValue> args, 179 Diagnostics *error) { 180 return matcherMarshallFixedImpl<ReturnType, ArgTypes...>( 181 matcherFunc, matcherName, nameRange, args, error, 182 std::index_sequence_for<ArgTypes...>{}); 183 } 184 185 // Fixed number of arguments overload 186 template <typename ReturnType, typename... ArgTypes> 187 std::unique_ptr<MatcherDescriptor> 188 makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...), 189 llvm::StringRef matcherName) { 190 // Create a vector of argument kinds 191 std::vector<ArgKind> argKinds = {ArgTypeTraits<ArgTypes>::getKind()...}; 192 return std::make_unique<FixedArgCountMatcherDescriptor>( 193 matcherMarshallFixed<ReturnType, ArgTypes...>, 194 reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds); 195 } 196 197 } // namespace mlir::query::matcher::internal 198 199 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H 200