102d9f4d1SDevajith //===- RegistryManager.cpp - Matcher registry -----------------------------===// 202d9f4d1SDevajith // 302d9f4d1SDevajith // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 402d9f4d1SDevajith // See https://llvm.org/LICENSE.txt for license information. 502d9f4d1SDevajith // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 602d9f4d1SDevajith // 702d9f4d1SDevajith //===----------------------------------------------------------------------===// 802d9f4d1SDevajith // 902d9f4d1SDevajith // Registry map populated at static initialization time. 1002d9f4d1SDevajith // 1102d9f4d1SDevajith //===----------------------------------------------------------------------===// 1202d9f4d1SDevajith 1302d9f4d1SDevajith #include "RegistryManager.h" 1402d9f4d1SDevajith #include "mlir/Query/Matcher/Registry.h" 1502d9f4d1SDevajith 1602d9f4d1SDevajith #include <set> 1702d9f4d1SDevajith #include <utility> 1802d9f4d1SDevajith 1902d9f4d1SDevajith namespace mlir::query::matcher { 2002d9f4d1SDevajith namespace { 2102d9f4d1SDevajith 2202d9f4d1SDevajith // This is needed because these matchers are defined as overloaded functions. 2302d9f4d1SDevajith using IsConstantOp = detail::constant_op_matcher(); 2402d9f4d1SDevajith using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef); 2502d9f4d1SDevajith using HasOpName = detail::NameOpMatcher(llvm::StringRef); 2602d9f4d1SDevajith 2702d9f4d1SDevajith // Enum to string for autocomplete. 2802d9f4d1SDevajith static std::string asArgString(ArgKind kind) { 2902d9f4d1SDevajith switch (kind) { 3002d9f4d1SDevajith case ArgKind::Matcher: 3102d9f4d1SDevajith return "Matcher"; 3202d9f4d1SDevajith case ArgKind::String: 3302d9f4d1SDevajith return "String"; 3402d9f4d1SDevajith } 3502d9f4d1SDevajith llvm_unreachable("Unhandled ArgKind"); 3602d9f4d1SDevajith } 3702d9f4d1SDevajith 3802d9f4d1SDevajith } // namespace 3902d9f4d1SDevajith 4002d9f4d1SDevajith void Registry::registerMatcherDescriptor( 4102d9f4d1SDevajith llvm::StringRef matcherName, 4202d9f4d1SDevajith std::unique_ptr<internal::MatcherDescriptor> callback) { 4302d9f4d1SDevajith assert(!constructorMap.contains(matcherName)); 4402d9f4d1SDevajith constructorMap[matcherName] = std::move(callback); 4502d9f4d1SDevajith } 4602d9f4d1SDevajith 4702d9f4d1SDevajith std::optional<MatcherCtor> 4802d9f4d1SDevajith RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName, 4902d9f4d1SDevajith const Registry &matcherRegistry) { 5002d9f4d1SDevajith auto it = matcherRegistry.constructors().find(matcherName); 5102d9f4d1SDevajith return it == matcherRegistry.constructors().end() 5202d9f4d1SDevajith ? std::optional<MatcherCtor>() 5302d9f4d1SDevajith : it->second.get(); 5402d9f4d1SDevajith } 5502d9f4d1SDevajith 5602d9f4d1SDevajith std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes( 5702d9f4d1SDevajith llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) { 5802d9f4d1SDevajith // Starting with the above seed of acceptable top-level matcher types, compute 5902d9f4d1SDevajith // the acceptable type set for the argument indicated by each context element. 6002d9f4d1SDevajith std::set<ArgKind> typeSet; 6102d9f4d1SDevajith typeSet.insert(ArgKind::Matcher); 6202d9f4d1SDevajith 6302d9f4d1SDevajith for (const auto &ctxEntry : context) { 6402d9f4d1SDevajith MatcherCtor ctor = ctxEntry.first; 6502d9f4d1SDevajith unsigned argNumber = ctxEntry.second; 6602d9f4d1SDevajith std::vector<ArgKind> nextTypeSet; 6702d9f4d1SDevajith 6802d9f4d1SDevajith if (argNumber < ctor->getNumArgs()) 6902d9f4d1SDevajith ctor->getArgKinds(argNumber, nextTypeSet); 7002d9f4d1SDevajith 7102d9f4d1SDevajith typeSet.insert(nextTypeSet.begin(), nextTypeSet.end()); 7202d9f4d1SDevajith } 7302d9f4d1SDevajith 7402d9f4d1SDevajith return std::vector<ArgKind>(typeSet.begin(), typeSet.end()); 7502d9f4d1SDevajith } 7602d9f4d1SDevajith 7702d9f4d1SDevajith std::vector<MatcherCompletion> 7802d9f4d1SDevajith RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes, 7902d9f4d1SDevajith const Registry &matcherRegistry) { 8002d9f4d1SDevajith std::vector<MatcherCompletion> completions; 8102d9f4d1SDevajith 8202d9f4d1SDevajith // Search the registry for acceptable matchers. 8302d9f4d1SDevajith for (const auto &m : matcherRegistry.constructors()) { 8402d9f4d1SDevajith const internal::MatcherDescriptor &matcher = *m.getValue(); 8502d9f4d1SDevajith llvm::StringRef name = m.getKey(); 8602d9f4d1SDevajith 8702d9f4d1SDevajith unsigned numArgs = matcher.getNumArgs(); 8802d9f4d1SDevajith std::vector<std::vector<ArgKind>> argKinds(numArgs); 8902d9f4d1SDevajith 9002d9f4d1SDevajith for (const ArgKind &kind : acceptedTypes) { 9102d9f4d1SDevajith if (kind != ArgKind::Matcher) 9202d9f4d1SDevajith continue; 9302d9f4d1SDevajith 9402d9f4d1SDevajith for (unsigned arg = 0; arg != numArgs; ++arg) 9502d9f4d1SDevajith matcher.getArgKinds(arg, argKinds[arg]); 9602d9f4d1SDevajith } 9702d9f4d1SDevajith 9802d9f4d1SDevajith std::string decl; 9902d9f4d1SDevajith llvm::raw_string_ostream os(decl); 10002d9f4d1SDevajith 10102d9f4d1SDevajith std::string typedText = std::string(name); 10202d9f4d1SDevajith os << "Matcher: " << name << "("; 10302d9f4d1SDevajith 10402d9f4d1SDevajith for (const std::vector<ArgKind> &arg : argKinds) { 10502d9f4d1SDevajith if (&arg != &argKinds[0]) 10602d9f4d1SDevajith os << ", "; 10702d9f4d1SDevajith 10802d9f4d1SDevajith bool firstArgKind = true; 10902d9f4d1SDevajith // Two steps. First all non-matchers, then matchers only. 11002d9f4d1SDevajith for (const ArgKind &argKind : arg) { 11102d9f4d1SDevajith if (!firstArgKind) 11202d9f4d1SDevajith os << "|"; 11302d9f4d1SDevajith 11402d9f4d1SDevajith firstArgKind = false; 11502d9f4d1SDevajith os << asArgString(argKind); 11602d9f4d1SDevajith } 11702d9f4d1SDevajith } 11802d9f4d1SDevajith 11902d9f4d1SDevajith os << ")"; 12002d9f4d1SDevajith typedText += "("; 12102d9f4d1SDevajith 12202d9f4d1SDevajith if (argKinds.empty()) 12302d9f4d1SDevajith typedText += ")"; 12402d9f4d1SDevajith else if (argKinds[0][0] == ArgKind::String) 12502d9f4d1SDevajith typedText += "\""; 12602d9f4d1SDevajith 127*884221edSJOE1994 completions.emplace_back(typedText, decl); 12802d9f4d1SDevajith } 12902d9f4d1SDevajith 13002d9f4d1SDevajith return completions; 13102d9f4d1SDevajith } 13202d9f4d1SDevajith 13302d9f4d1SDevajith VariantMatcher RegistryManager::constructMatcher( 13402d9f4d1SDevajith MatcherCtor ctor, internal::SourceRange nameRange, 13558b44c81SJacques Pienaar llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args, 13658b44c81SJacques Pienaar internal::Diagnostics *error) { 13758b44c81SJacques Pienaar VariantMatcher out = ctor->create(nameRange, args, error); 13858b44c81SJacques Pienaar if (functionName.empty() || out.isNull()) 13958b44c81SJacques Pienaar return out; 14058b44c81SJacques Pienaar 14158b44c81SJacques Pienaar if (std::optional<DynMatcher> result = out.getDynMatcher()) { 14258b44c81SJacques Pienaar result->setFunctionName(functionName); 14358b44c81SJacques Pienaar return VariantMatcher::SingleMatcher(*result); 14458b44c81SJacques Pienaar } 14558b44c81SJacques Pienaar 14658b44c81SJacques Pienaar error->addError(nameRange, internal::ErrorType::RegistryNotBindable); 14758b44c81SJacques Pienaar return {}; 14802d9f4d1SDevajith } 14902d9f4d1SDevajith 15002d9f4d1SDevajith } // namespace mlir::query::matcher 151