1 //===- RegistryManager.cpp - Matcher registry -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Registry map populated at static initialization time. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "RegistryManager.h" 14 #include "mlir/Query/Matcher/Registry.h" 15 16 #include <set> 17 #include <utility> 18 19 namespace mlir::query::matcher { 20 namespace { 21 22 // This is needed because these matchers are defined as overloaded functions. 23 using IsConstantOp = detail::constant_op_matcher(); 24 using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef); 25 using HasOpName = detail::NameOpMatcher(llvm::StringRef); 26 27 // Enum to string for autocomplete. 28 static std::string asArgString(ArgKind kind) { 29 switch (kind) { 30 case ArgKind::Matcher: 31 return "Matcher"; 32 case ArgKind::String: 33 return "String"; 34 } 35 llvm_unreachable("Unhandled ArgKind"); 36 } 37 38 } // namespace 39 40 void Registry::registerMatcherDescriptor( 41 llvm::StringRef matcherName, 42 std::unique_ptr<internal::MatcherDescriptor> callback) { 43 assert(!constructorMap.contains(matcherName)); 44 constructorMap[matcherName] = std::move(callback); 45 } 46 47 std::optional<MatcherCtor> 48 RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName, 49 const Registry &matcherRegistry) { 50 auto it = matcherRegistry.constructors().find(matcherName); 51 return it == matcherRegistry.constructors().end() 52 ? std::optional<MatcherCtor>() 53 : it->second.get(); 54 } 55 56 std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes( 57 llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) { 58 // Starting with the above seed of acceptable top-level matcher types, compute 59 // the acceptable type set for the argument indicated by each context element. 60 std::set<ArgKind> typeSet; 61 typeSet.insert(ArgKind::Matcher); 62 63 for (const auto &ctxEntry : context) { 64 MatcherCtor ctor = ctxEntry.first; 65 unsigned argNumber = ctxEntry.second; 66 std::vector<ArgKind> nextTypeSet; 67 68 if (argNumber < ctor->getNumArgs()) 69 ctor->getArgKinds(argNumber, nextTypeSet); 70 71 typeSet.insert(nextTypeSet.begin(), nextTypeSet.end()); 72 } 73 74 return std::vector<ArgKind>(typeSet.begin(), typeSet.end()); 75 } 76 77 std::vector<MatcherCompletion> 78 RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes, 79 const Registry &matcherRegistry) { 80 std::vector<MatcherCompletion> completions; 81 82 // Search the registry for acceptable matchers. 83 for (const auto &m : matcherRegistry.constructors()) { 84 const internal::MatcherDescriptor &matcher = *m.getValue(); 85 llvm::StringRef name = m.getKey(); 86 87 unsigned numArgs = matcher.getNumArgs(); 88 std::vector<std::vector<ArgKind>> argKinds(numArgs); 89 90 for (const ArgKind &kind : acceptedTypes) { 91 if (kind != ArgKind::Matcher) 92 continue; 93 94 for (unsigned arg = 0; arg != numArgs; ++arg) 95 matcher.getArgKinds(arg, argKinds[arg]); 96 } 97 98 std::string decl; 99 llvm::raw_string_ostream os(decl); 100 101 std::string typedText = std::string(name); 102 os << "Matcher: " << name << "("; 103 104 for (const std::vector<ArgKind> &arg : argKinds) { 105 if (&arg != &argKinds[0]) 106 os << ", "; 107 108 bool firstArgKind = true; 109 // Two steps. First all non-matchers, then matchers only. 110 for (const ArgKind &argKind : arg) { 111 if (!firstArgKind) 112 os << "|"; 113 114 firstArgKind = false; 115 os << asArgString(argKind); 116 } 117 } 118 119 os << ")"; 120 typedText += "("; 121 122 if (argKinds.empty()) 123 typedText += ")"; 124 else if (argKinds[0][0] == ArgKind::String) 125 typedText += "\""; 126 127 completions.emplace_back(typedText, decl); 128 } 129 130 return completions; 131 } 132 133 VariantMatcher RegistryManager::constructMatcher( 134 MatcherCtor ctor, internal::SourceRange nameRange, 135 llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args, 136 internal::Diagnostics *error) { 137 VariantMatcher out = ctor->create(nameRange, args, error); 138 if (functionName.empty() || out.isNull()) 139 return out; 140 141 if (std::optional<DynMatcher> result = out.getDynMatcher()) { 142 result->setFunctionName(functionName); 143 return VariantMatcher::SingleMatcher(*result); 144 } 145 146 error->addError(nameRange, internal::ErrorType::RegistryNotBindable); 147 return {}; 148 } 149 150 } // namespace mlir::query::matcher 151