xref: /llvm-project/mlir/lib/Query/Matcher/RegistryManager.cpp (revision 884221eddb9d395830704fac79fd04008e02e368)
1 //===- RegistryManager.cpp - Matcher registry -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Registry map populated at static initialization time.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "RegistryManager.h"
14 #include "mlir/Query/Matcher/Registry.h"
15 
16 #include <set>
17 #include <utility>
18 
19 namespace mlir::query::matcher {
20 namespace {
21 
// Function-type aliases, one per matcher factory signature (return type plus
// parameter list). This is needed because these matchers are defined as
// overloaded functions, so the bare function name alone does not identify a
// single overload.
// NOTE(review): no use of these aliases is visible in this file; presumably
// they select an overload where the matchers are registered -- confirm.
using IsConstantOp = detail::constant_op_matcher();
using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef);
using HasOpName = detail::NameOpMatcher(llvm::StringRef);
26 
27 // Enum to string for autocomplete.
28 static std::string asArgString(ArgKind kind) {
29   switch (kind) {
30   case ArgKind::Matcher:
31     return "Matcher";
32   case ArgKind::String:
33     return "String";
34   }
35   llvm_unreachable("Unhandled ArgKind");
36 }
37 
38 } // namespace
39 
40 void Registry::registerMatcherDescriptor(
41     llvm::StringRef matcherName,
42     std::unique_ptr<internal::MatcherDescriptor> callback) {
43   assert(!constructorMap.contains(matcherName));
44   constructorMap[matcherName] = std::move(callback);
45 }
46 
47 std::optional<MatcherCtor>
48 RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName,
49                                    const Registry &matcherRegistry) {
50   auto it = matcherRegistry.constructors().find(matcherName);
51   return it == matcherRegistry.constructors().end()
52              ? std::optional<MatcherCtor>()
53              : it->second.get();
54 }
55 
56 std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
57     llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
58   // Starting with the above seed of acceptable top-level matcher types, compute
59   // the acceptable type set for the argument indicated by each context element.
60   std::set<ArgKind> typeSet;
61   typeSet.insert(ArgKind::Matcher);
62 
63   for (const auto &ctxEntry : context) {
64     MatcherCtor ctor = ctxEntry.first;
65     unsigned argNumber = ctxEntry.second;
66     std::vector<ArgKind> nextTypeSet;
67 
68     if (argNumber < ctor->getNumArgs())
69       ctor->getArgKinds(argNumber, nextTypeSet);
70 
71     typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
72   }
73 
74   return std::vector<ArgKind>(typeSet.begin(), typeSet.end());
75 }
76 
77 std::vector<MatcherCompletion>
78 RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
79                                        const Registry &matcherRegistry) {
80   std::vector<MatcherCompletion> completions;
81 
82   // Search the registry for acceptable matchers.
83   for (const auto &m : matcherRegistry.constructors()) {
84     const internal::MatcherDescriptor &matcher = *m.getValue();
85     llvm::StringRef name = m.getKey();
86 
87     unsigned numArgs = matcher.getNumArgs();
88     std::vector<std::vector<ArgKind>> argKinds(numArgs);
89 
90     for (const ArgKind &kind : acceptedTypes) {
91       if (kind != ArgKind::Matcher)
92         continue;
93 
94       for (unsigned arg = 0; arg != numArgs; ++arg)
95         matcher.getArgKinds(arg, argKinds[arg]);
96     }
97 
98     std::string decl;
99     llvm::raw_string_ostream os(decl);
100 
101     std::string typedText = std::string(name);
102     os << "Matcher: " << name << "(";
103 
104     for (const std::vector<ArgKind> &arg : argKinds) {
105       if (&arg != &argKinds[0])
106         os << ", ";
107 
108       bool firstArgKind = true;
109       // Two steps. First all non-matchers, then matchers only.
110       for (const ArgKind &argKind : arg) {
111         if (!firstArgKind)
112           os << "|";
113 
114         firstArgKind = false;
115         os << asArgString(argKind);
116       }
117     }
118 
119     os << ")";
120     typedText += "(";
121 
122     if (argKinds.empty())
123       typedText += ")";
124     else if (argKinds[0][0] == ArgKind::String)
125       typedText += "\"";
126 
127     completions.emplace_back(typedText, decl);
128   }
129 
130   return completions;
131 }
132 
133 VariantMatcher RegistryManager::constructMatcher(
134     MatcherCtor ctor, internal::SourceRange nameRange,
135     llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
136     internal::Diagnostics *error) {
137   VariantMatcher out = ctor->create(nameRange, args, error);
138   if (functionName.empty() || out.isNull())
139     return out;
140 
141   if (std::optional<DynMatcher> result = out.getDynMatcher()) {
142     result->setFunctionName(functionName);
143     return VariantMatcher::SingleMatcher(*result);
144   }
145 
146   error->addError(nameRange, internal::ErrorType::RegistryNotBindable);
147   return {};
148 }
149 
150 } // namespace mlir::query::matcher
151