xref: /llvm-project/mlir/lib/Query/Matcher/RegistryManager.cpp (revision 884221eddb9d395830704fac79fd04008e02e368)
102d9f4d1SDevajith //===- RegistryManager.cpp - Matcher registry -----------------------------===//
202d9f4d1SDevajith //
302d9f4d1SDevajith // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
402d9f4d1SDevajith // See https://llvm.org/LICENSE.txt for license information.
502d9f4d1SDevajith // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
602d9f4d1SDevajith //
702d9f4d1SDevajith //===----------------------------------------------------------------------===//
802d9f4d1SDevajith //
902d9f4d1SDevajith // Registry map populated at static initialization time.
1002d9f4d1SDevajith //
1102d9f4d1SDevajith //===----------------------------------------------------------------------===//
1202d9f4d1SDevajith 
1302d9f4d1SDevajith #include "RegistryManager.h"
1402d9f4d1SDevajith #include "mlir/Query/Matcher/Registry.h"
1502d9f4d1SDevajith 
1602d9f4d1SDevajith #include <set>
1702d9f4d1SDevajith #include <utility>
1802d9f4d1SDevajith 
1902d9f4d1SDevajith namespace mlir::query::matcher {
2002d9f4d1SDevajith namespace {
2102d9f4d1SDevajith 
2202d9f4d1SDevajith // This is needed because these matchers are defined as overloaded functions.
2302d9f4d1SDevajith using IsConstantOp = detail::constant_op_matcher();
2402d9f4d1SDevajith using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef);
2502d9f4d1SDevajith using HasOpName = detail::NameOpMatcher(llvm::StringRef);
2602d9f4d1SDevajith 
2702d9f4d1SDevajith // Enum to string for autocomplete.
2802d9f4d1SDevajith static std::string asArgString(ArgKind kind) {
2902d9f4d1SDevajith   switch (kind) {
3002d9f4d1SDevajith   case ArgKind::Matcher:
3102d9f4d1SDevajith     return "Matcher";
3202d9f4d1SDevajith   case ArgKind::String:
3302d9f4d1SDevajith     return "String";
3402d9f4d1SDevajith   }
3502d9f4d1SDevajith   llvm_unreachable("Unhandled ArgKind");
3602d9f4d1SDevajith }
3702d9f4d1SDevajith 
3802d9f4d1SDevajith } // namespace
3902d9f4d1SDevajith 
4002d9f4d1SDevajith void Registry::registerMatcherDescriptor(
4102d9f4d1SDevajith     llvm::StringRef matcherName,
4202d9f4d1SDevajith     std::unique_ptr<internal::MatcherDescriptor> callback) {
4302d9f4d1SDevajith   assert(!constructorMap.contains(matcherName));
4402d9f4d1SDevajith   constructorMap[matcherName] = std::move(callback);
4502d9f4d1SDevajith }
4602d9f4d1SDevajith 
4702d9f4d1SDevajith std::optional<MatcherCtor>
4802d9f4d1SDevajith RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName,
4902d9f4d1SDevajith                                    const Registry &matcherRegistry) {
5002d9f4d1SDevajith   auto it = matcherRegistry.constructors().find(matcherName);
5102d9f4d1SDevajith   return it == matcherRegistry.constructors().end()
5202d9f4d1SDevajith              ? std::optional<MatcherCtor>()
5302d9f4d1SDevajith              : it->second.get();
5402d9f4d1SDevajith }
5502d9f4d1SDevajith 
5602d9f4d1SDevajith std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
5702d9f4d1SDevajith     llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
5802d9f4d1SDevajith   // Starting with the above seed of acceptable top-level matcher types, compute
5902d9f4d1SDevajith   // the acceptable type set for the argument indicated by each context element.
6002d9f4d1SDevajith   std::set<ArgKind> typeSet;
6102d9f4d1SDevajith   typeSet.insert(ArgKind::Matcher);
6202d9f4d1SDevajith 
6302d9f4d1SDevajith   for (const auto &ctxEntry : context) {
6402d9f4d1SDevajith     MatcherCtor ctor = ctxEntry.first;
6502d9f4d1SDevajith     unsigned argNumber = ctxEntry.second;
6602d9f4d1SDevajith     std::vector<ArgKind> nextTypeSet;
6702d9f4d1SDevajith 
6802d9f4d1SDevajith     if (argNumber < ctor->getNumArgs())
6902d9f4d1SDevajith       ctor->getArgKinds(argNumber, nextTypeSet);
7002d9f4d1SDevajith 
7102d9f4d1SDevajith     typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
7202d9f4d1SDevajith   }
7302d9f4d1SDevajith 
7402d9f4d1SDevajith   return std::vector<ArgKind>(typeSet.begin(), typeSet.end());
7502d9f4d1SDevajith }
7602d9f4d1SDevajith 
7702d9f4d1SDevajith std::vector<MatcherCompletion>
7802d9f4d1SDevajith RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
7902d9f4d1SDevajith                                        const Registry &matcherRegistry) {
8002d9f4d1SDevajith   std::vector<MatcherCompletion> completions;
8102d9f4d1SDevajith 
8202d9f4d1SDevajith   // Search the registry for acceptable matchers.
8302d9f4d1SDevajith   for (const auto &m : matcherRegistry.constructors()) {
8402d9f4d1SDevajith     const internal::MatcherDescriptor &matcher = *m.getValue();
8502d9f4d1SDevajith     llvm::StringRef name = m.getKey();
8602d9f4d1SDevajith 
8702d9f4d1SDevajith     unsigned numArgs = matcher.getNumArgs();
8802d9f4d1SDevajith     std::vector<std::vector<ArgKind>> argKinds(numArgs);
8902d9f4d1SDevajith 
9002d9f4d1SDevajith     for (const ArgKind &kind : acceptedTypes) {
9102d9f4d1SDevajith       if (kind != ArgKind::Matcher)
9202d9f4d1SDevajith         continue;
9302d9f4d1SDevajith 
9402d9f4d1SDevajith       for (unsigned arg = 0; arg != numArgs; ++arg)
9502d9f4d1SDevajith         matcher.getArgKinds(arg, argKinds[arg]);
9602d9f4d1SDevajith     }
9702d9f4d1SDevajith 
9802d9f4d1SDevajith     std::string decl;
9902d9f4d1SDevajith     llvm::raw_string_ostream os(decl);
10002d9f4d1SDevajith 
10102d9f4d1SDevajith     std::string typedText = std::string(name);
10202d9f4d1SDevajith     os << "Matcher: " << name << "(";
10302d9f4d1SDevajith 
10402d9f4d1SDevajith     for (const std::vector<ArgKind> &arg : argKinds) {
10502d9f4d1SDevajith       if (&arg != &argKinds[0])
10602d9f4d1SDevajith         os << ", ";
10702d9f4d1SDevajith 
10802d9f4d1SDevajith       bool firstArgKind = true;
10902d9f4d1SDevajith       // Two steps. First all non-matchers, then matchers only.
11002d9f4d1SDevajith       for (const ArgKind &argKind : arg) {
11102d9f4d1SDevajith         if (!firstArgKind)
11202d9f4d1SDevajith           os << "|";
11302d9f4d1SDevajith 
11402d9f4d1SDevajith         firstArgKind = false;
11502d9f4d1SDevajith         os << asArgString(argKind);
11602d9f4d1SDevajith       }
11702d9f4d1SDevajith     }
11802d9f4d1SDevajith 
11902d9f4d1SDevajith     os << ")";
12002d9f4d1SDevajith     typedText += "(";
12102d9f4d1SDevajith 
12202d9f4d1SDevajith     if (argKinds.empty())
12302d9f4d1SDevajith       typedText += ")";
12402d9f4d1SDevajith     else if (argKinds[0][0] == ArgKind::String)
12502d9f4d1SDevajith       typedText += "\"";
12602d9f4d1SDevajith 
127*884221edSJOE1994     completions.emplace_back(typedText, decl);
12802d9f4d1SDevajith   }
12902d9f4d1SDevajith 
13002d9f4d1SDevajith   return completions;
13102d9f4d1SDevajith }
13202d9f4d1SDevajith 
13302d9f4d1SDevajith VariantMatcher RegistryManager::constructMatcher(
13402d9f4d1SDevajith     MatcherCtor ctor, internal::SourceRange nameRange,
13558b44c81SJacques Pienaar     llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
13658b44c81SJacques Pienaar     internal::Diagnostics *error) {
13758b44c81SJacques Pienaar   VariantMatcher out = ctor->create(nameRange, args, error);
13858b44c81SJacques Pienaar   if (functionName.empty() || out.isNull())
13958b44c81SJacques Pienaar     return out;
14058b44c81SJacques Pienaar 
14158b44c81SJacques Pienaar   if (std::optional<DynMatcher> result = out.getDynMatcher()) {
14258b44c81SJacques Pienaar     result->setFunctionName(functionName);
14358b44c81SJacques Pienaar     return VariantMatcher::SingleMatcher(*result);
14458b44c81SJacques Pienaar   }
14558b44c81SJacques Pienaar 
14658b44c81SJacques Pienaar   error->addError(nameRange, internal::ErrorType::RegistryNotBindable);
14758b44c81SJacques Pienaar   return {};
14802d9f4d1SDevajith }
14902d9f4d1SDevajith 
15002d9f4d1SDevajith } // namespace mlir::query::matcher
151