xref: /llvm-project/mlir/lib/Query/Matcher/Parser.cpp (revision 732a5cba8c739ed40a7280b5d74ca717910c2c4c)
1 //===- Parser.cpp - Matcher expression parser -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Recursive parser implementation for the matcher expression grammar.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Parser.h"
14 
15 #include <vector>
16 
17 namespace mlir::query::matcher::internal {
18 
19 // Simple structure to hold information for one token from the parser.
20 struct Parser::TokenInfo {
21   TokenInfo() = default;
22 
23   // Method to set the kind and text of the token
24   void set(TokenKind newKind, llvm::StringRef newText) {
25     kind = newKind;
26     text = newText;
27   }
28 
29   llvm::StringRef text;
30   TokenKind kind = TokenKind::Eof;
31   SourceRange range;
32   VariantValue value;
33 };
34 
35 class Parser::CodeTokenizer {
36 public:
37   // Constructor with matcherCode and error
38   explicit CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error)
39       : code(matcherCode), startOfLine(matcherCode), error(error) {
40     nextToken = getNextToken();
41   }
42 
43   // Constructor with matcherCode, error, and codeCompletionOffset
44   CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error,
45                 unsigned codeCompletionOffset)
46       : code(matcherCode), startOfLine(matcherCode), error(error),
47         codeCompletionLocation(matcherCode.data() + codeCompletionOffset) {
48     nextToken = getNextToken();
49   }
50 
51   // Peek at next token without consuming it
52   const TokenInfo &peekNextToken() const { return nextToken; }
53 
54   // Consume and return the next token
55   TokenInfo consumeNextToken() {
56     TokenInfo thisToken = nextToken;
57     nextToken = getNextToken();
58     return thisToken;
59   }
60 
61   // Skip any newline tokens
62   TokenInfo skipNewlines() {
63     while (nextToken.kind == TokenKind::NewLine)
64       nextToken = getNextToken();
65     return nextToken;
66   }
67 
68   // Consume and return next token, ignoring newlines
69   TokenInfo consumeNextTokenIgnoreNewlines() {
70     skipNewlines();
71     return nextToken.kind == TokenKind::Eof ? nextToken : consumeNextToken();
72   }
73 
74   // Return kind of next token
75   TokenKind nextTokenKind() const { return nextToken.kind; }
76 
77 private:
78   // Helper function to get the first character as a new StringRef and drop it
79   // from the original string
80   llvm::StringRef firstCharacterAndDrop(llvm::StringRef &str) {
81     assert(!str.empty());
82     llvm::StringRef firstChar = str.substr(0, 1);
83     str = str.drop_front();
84     return firstChar;
85   }
86 
87   // Get next token, consuming whitespaces and handling different token types
88   TokenInfo getNextToken() {
89     consumeWhitespace();
90     TokenInfo result;
91     result.range.start = currentLocation();
92 
93     // Code completion case
94     if (codeCompletionLocation && codeCompletionLocation <= code.data()) {
95       result.set(TokenKind::CodeCompletion,
96                  llvm::StringRef(codeCompletionLocation, 0));
97       codeCompletionLocation = nullptr;
98       return result;
99     }
100 
101     // End of file case
102     if (code.empty()) {
103       result.set(TokenKind::Eof, "");
104       return result;
105     }
106 
107     // Switch to handle specific characters
108     switch (code[0]) {
109     case '#':
110       code = code.drop_until([](char c) { return c == '\n'; });
111       return getNextToken();
112     case ',':
113       result.set(TokenKind::Comma, firstCharacterAndDrop(code));
114       break;
115     case '.':
116       result.set(TokenKind::Period, firstCharacterAndDrop(code));
117       break;
118     case '\n':
119       ++line;
120       startOfLine = code.drop_front();
121       result.set(TokenKind::NewLine, firstCharacterAndDrop(code));
122       break;
123     case '(':
124       result.set(TokenKind::OpenParen, firstCharacterAndDrop(code));
125       break;
126     case ')':
127       result.set(TokenKind::CloseParen, firstCharacterAndDrop(code));
128       break;
129     case '"':
130     case '\'':
131       consumeStringLiteral(&result);
132       break;
133     default:
134       parseIdentifierOrInvalid(&result);
135       break;
136     }
137 
138     result.range.end = currentLocation();
139     return result;
140   }
141 
142   // Consume a string literal, handle escape sequences and missing closing
143   // quote.
144   void consumeStringLiteral(TokenInfo *result) {
145     bool inEscape = false;
146     const char marker = code[0];
147     for (size_t length = 1; length < code.size(); ++length) {
148       if (inEscape) {
149         inEscape = false;
150         continue;
151       }
152       if (code[length] == '\\') {
153         inEscape = true;
154         continue;
155       }
156       if (code[length] == marker) {
157         result->kind = TokenKind::Literal;
158         result->text = code.substr(0, length + 1);
159         result->value = code.substr(1, length - 1);
160         code = code.drop_front(length + 1);
161         return;
162       }
163     }
164     llvm::StringRef errorText = code;
165     code = code.drop_front(code.size());
166     SourceRange range;
167     range.start = result->range.start;
168     range.end = currentLocation();
169     error->addError(range, ErrorType::ParserStringError) << errorText;
170     result->kind = TokenKind::Error;
171   }
172 
173   void parseIdentifierOrInvalid(TokenInfo *result) {
174     if (isalnum(code[0])) {
175       // Parse an identifier
176       size_t tokenLength = 1;
177 
178       while (true) {
179         // A code completion location in/immediately after an identifier will
180         // cause the portion of the identifier before the code completion
181         // location to become a code completion token.
182         if (codeCompletionLocation == code.data() + tokenLength) {
183           codeCompletionLocation = nullptr;
184           result->kind = TokenKind::CodeCompletion;
185           result->text = code.substr(0, tokenLength);
186           code = code.drop_front(tokenLength);
187           return;
188         }
189         if (tokenLength == code.size() || !(isalnum(code[tokenLength])))
190           break;
191         ++tokenLength;
192       }
193       result->kind = TokenKind::Ident;
194       result->text = code.substr(0, tokenLength);
195       code = code.drop_front(tokenLength);
196     } else {
197       result->kind = TokenKind::InvalidChar;
198       result->text = code.substr(0, 1);
199       code = code.drop_front(1);
200     }
201   }
202 
203   // Consume all leading whitespace from code, except newlines
204   void consumeWhitespace() { code = code.ltrim(" \t\v\f\r"); }
205 
206   // Returns the current location in the source code
207   SourceLocation currentLocation() {
208     SourceLocation location;
209     location.line = line;
210     location.column = code.data() - startOfLine.data() + 1;
211     return location;
212   }
213 
214   llvm::StringRef code;
215   llvm::StringRef startOfLine;
216   unsigned line = 1;
217   Diagnostics *error;
218   TokenInfo nextToken;
219   const char *codeCompletionLocation = nullptr;
220 };
221 
// Defaulted destructor of Parser::Sema, defined out of line in this file.
Parser::Sema::~Sema() = default;
223 
224 std::vector<ArgKind> Parser::Sema::getAcceptedCompletionTypes(
225     llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
226   return {};
227 }
228 
229 std::vector<MatcherCompletion>
230 Parser::Sema::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) {
231   return {};
232 }
233 
234 // Entry for the scope of a parser
235 struct Parser::ScopedContextEntry {
236   Parser *parser;
237 
238   ScopedContextEntry(Parser *parser, MatcherCtor c) : parser(parser) {
239     parser->contextStack.emplace_back(c, 0u);
240   }
241 
242   ~ScopedContextEntry() { parser->contextStack.pop_back(); }
243 
244   void nextArg() { ++parser->contextStack.back().second; }
245 };
246 
247 // Parse and validate expressions starting with an identifier.
248 // This function can parse named values and matchers. In case of failure, it
249 // will try to determine the user's intent to give an appropriate error message.
250 bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
251   const TokenInfo nameToken = tokenizer->consumeNextToken();
252 
253   if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
254     // Parse as a named value.
255     auto namedValue =
256         namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
257 
258     if (!namedValue.isMatcher()) {
259       error->addError(tokenizer->peekNextToken().range,
260                       ErrorType::ParserNotAMatcher);
261       return false;
262     }
263 
264     if (tokenizer->nextTokenKind() == TokenKind::NewLine) {
265       error->addError(tokenizer->peekNextToken().range,
266                       ErrorType::ParserNoOpenParen)
267           << "NewLine";
268       return false;
269     }
270 
271     // If the syntax is correct and the name is not a matcher either, report
272     // an unknown named value.
273     if ((tokenizer->nextTokenKind() == TokenKind::Comma ||
274          tokenizer->nextTokenKind() == TokenKind::CloseParen ||
275          tokenizer->nextTokenKind() == TokenKind::NewLine ||
276          tokenizer->nextTokenKind() == TokenKind::Eof) &&
277         !sema->lookupMatcherCtor(nameToken.text)) {
278       error->addError(nameToken.range, ErrorType::RegistryValueNotFound)
279           << nameToken.text;
280       return false;
281     }
282     // Otherwise, fallback to the matcher parser.
283   }
284 
285   tokenizer->skipNewlines();
286 
287   assert(nameToken.kind == TokenKind::Ident);
288   TokenInfo openToken = tokenizer->consumeNextToken();
289   if (openToken.kind != TokenKind::OpenParen) {
290     error->addError(openToken.range, ErrorType::ParserNoOpenParen)
291         << openToken.text;
292     return false;
293   }
294 
295   std::optional<MatcherCtor> ctor = sema->lookupMatcherCtor(nameToken.text);
296 
297   // Parse as a matcher expression.
298   return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
299 }
300 
301 // Parse the arguments of a matcher
302 bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
303                               const TokenInfo &nameToken, TokenInfo &endToken) {
304   ScopedContextEntry sce(this, ctor);
305 
306   while (tokenizer->nextTokenKind() != TokenKind::Eof) {
307     if (tokenizer->nextTokenKind() == TokenKind::CloseParen) {
308       // end of args.
309       endToken = tokenizer->consumeNextToken();
310       break;
311     }
312 
313     if (!args.empty()) {
314       // We must find a , token to continue.
315       TokenInfo commaToken = tokenizer->consumeNextToken();
316       if (commaToken.kind != TokenKind::Comma) {
317         error->addError(commaToken.range, ErrorType::ParserNoComma)
318             << commaToken.text;
319         return false;
320       }
321     }
322 
323     ParserValue argValue;
324     tokenizer->skipNewlines();
325 
326     argValue.text = tokenizer->peekNextToken().text;
327     argValue.range = tokenizer->peekNextToken().range;
328     if (!parseExpressionImpl(&argValue.value)) {
329       return false;
330     }
331 
332     tokenizer->skipNewlines();
333     args.push_back(argValue);
334     sce.nextArg();
335   }
336 
337   return true;
338 }
339 
340 // Parse and validate a matcher expression.
341 bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
342                                         const TokenInfo &openToken,
343                                         std::optional<MatcherCtor> ctor,
344                                         VariantValue *value) {
345   if (!ctor) {
346     error->addError(nameToken.range, ErrorType::RegistryMatcherNotFound)
347         << nameToken.text;
348     // Do not return here. We need to continue to give completion suggestions.
349   }
350 
351   std::vector<ParserValue> args;
352   TokenInfo endToken;
353 
354   tokenizer->skipNewlines();
355 
356   if (!parseMatcherArgs(args, ctor.value_or(nullptr), nameToken, endToken)) {
357     return false;
358   }
359 
360   // Check for the missing closing parenthesis
361   if (endToken.kind != TokenKind::CloseParen) {
362     error->addError(openToken.range, ErrorType::ParserNoCloseParen)
363         << nameToken.text;
364     return false;
365   }
366 
367   if (!ctor)
368     return false;
369   // Merge the start and end infos.
370   SourceRange matcherRange = nameToken.range;
371   matcherRange.end = endToken.range.end;
372   VariantMatcher result =
373       sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
374   if (result.isNull())
375     return false;
376   *value = result;
377   return true;
378 }
379 
380 // If the prefix of this completion matches the completion token, add it to
381 // completions minus the prefix.
382 void Parser::addCompletion(const TokenInfo &compToken,
383                            const MatcherCompletion &completion) {
384   if (llvm::StringRef(completion.typedText).starts_with(compToken.text)) {
385     completions.emplace_back(completion.typedText.substr(compToken.text.size()),
386                              completion.matcherDecl);
387   }
388 }
389 
390 std::vector<MatcherCompletion>
391 Parser::getNamedValueCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) {
392   if (!namedValues)
393     return {};
394 
395   std::vector<MatcherCompletion> result;
396   for (const auto &entry : *namedValues) {
397     std::string decl =
398         (entry.getValue().getTypeAsString() + " " + entry.getKey()).str();
399     result.emplace_back(entry.getKey(), decl);
400   }
401   return result;
402 }
403 
404 void Parser::addExpressionCompletions() {
405   const TokenInfo compToken = tokenizer->consumeNextTokenIgnoreNewlines();
406   assert(compToken.kind == TokenKind::CodeCompletion);
407 
408   // We cannot complete code if there is an invalid element on the context
409   // stack.
410   for (const auto &entry : contextStack) {
411     if (!entry.first)
412       return;
413   }
414 
415   auto acceptedTypes = sema->getAcceptedCompletionTypes(contextStack);
416   for (const auto &completion : sema->getMatcherCompletions(acceptedTypes)) {
417     addCompletion(compToken, completion);
418   }
419 
420   for (const auto &completion : getNamedValueCompletions(acceptedTypes)) {
421     addCompletion(compToken, completion);
422   }
423 }
424 
425 // Parse an <Expresssion>
426 bool Parser::parseExpressionImpl(VariantValue *value) {
427   switch (tokenizer->nextTokenKind()) {
428   case TokenKind::Literal:
429     *value = tokenizer->consumeNextToken().value;
430     return true;
431   case TokenKind::Ident:
432     return parseIdentifierPrefixImpl(value);
433   case TokenKind::CodeCompletion:
434     addExpressionCompletions();
435     return false;
436   case TokenKind::Eof:
437     error->addError(tokenizer->consumeNextToken().range,
438                     ErrorType::ParserNoCode);
439     return false;
440 
441   case TokenKind::Error:
442     // This error was already reported by the tokenizer.
443     return false;
444   case TokenKind::NewLine:
445   case TokenKind::OpenParen:
446   case TokenKind::CloseParen:
447   case TokenKind::Comma:
448   case TokenKind::Period:
449   case TokenKind::InvalidChar:
450     const TokenInfo token = tokenizer->consumeNextToken();
451     error->addError(token.range, ErrorType::ParserInvalidToken)
452         << (token.kind == TokenKind::NewLine ? "NewLine" : token.text);
453     return false;
454   }
455 
456   llvm_unreachable("Unknown token kind.");
457 }
458 
// Constructs a parser that reads tokens from `tokenizer`, uses a RegistrySema
// created over `matcherRegistry` as its sema, looks names up in `namedValues`
// (which may be null; it is null-checked before use in this file) and reports
// errors to `error`. The tokenizer, named values and diagnostics pointers are
// stored as given.
Parser::Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
               const NamedValueMap *namedValues, Diagnostics *error)
    : tokenizer(tokenizer),
      sema(std::make_unique<RegistrySema>(matcherRegistry)),
      namedValues(namedValues), error(error) {}
464 
// Defaulted destructor of Parser::RegistrySema, defined out of line in this
// file.
Parser::RegistrySema::~RegistrySema() = default;
466 
// Looks up the constructor for `matcherName` by forwarding to
// RegistryManager::lookupMatcherCtor with this sema's matcher registry.
// Callers in this file treat an empty optional as "no such matcher".
std::optional<MatcherCtor>
Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
  return RegistryManager::lookupMatcherCtor(matcherName, matcherRegistry);
}
471 
// Builds the matcher for a parsed matcher expression by forwarding `ctor`,
// `nameRange`, `args` and `error` unchanged to
// RegistryManager::constructMatcher and returning its result.
VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
    MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
    Diagnostics *error) {
  return RegistryManager::constructMatcher(ctor, nameRange, args, error);
}
477 
// Returns the argument kinds accepted at the completion position described by
// `context`, as computed by RegistryManager::getAcceptedCompletionTypes.
std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(
    llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
  return RegistryManager::getAcceptedCompletionTypes(context);
}
482 
// Returns the matcher completions for `acceptedTypes` by forwarding to
// RegistryManager::getMatcherCompletions with this sema's matcher registry.
std::vector<MatcherCompletion> Parser::RegistrySema::getMatcherCompletions(
    llvm::ArrayRef<ArgKind> acceptedTypes) {
  return RegistryManager::getMatcherCompletions(acceptedTypes, matcherRegistry);
}
487 
488 bool Parser::parseExpression(llvm::StringRef &code,
489                              const Registry &matcherRegistry,
490                              const NamedValueMap *namedValues,
491                              VariantValue *value, Diagnostics *error) {
492   CodeTokenizer tokenizer(code, error);
493   Parser parser(&tokenizer, matcherRegistry, namedValues, error);
494   if (!parser.parseExpressionImpl(value))
495     return false;
496   auto nextToken = tokenizer.peekNextToken();
497   if (nextToken.kind != TokenKind::Eof &&
498       nextToken.kind != TokenKind::NewLine) {
499     error->addError(tokenizer.peekNextToken().range,
500                     ErrorType::ParserTrailingCode);
501     return false;
502   }
503   return true;
504 }
505 
506 std::vector<MatcherCompletion>
507 Parser::completeExpression(llvm::StringRef &code, unsigned completionOffset,
508                            const Registry &matcherRegistry,
509                            const NamedValueMap *namedValues) {
510   Diagnostics error;
511   CodeTokenizer tokenizer(code, &error, completionOffset);
512   Parser parser(&tokenizer, matcherRegistry, namedValues, &error);
513   VariantValue dummy;
514   parser.parseExpressionImpl(&dummy);
515 
516   return parser.completions;
517 }
518 
519 std::optional<DynMatcher> Parser::parseMatcherExpression(
520     llvm::StringRef &code, const Registry &matcherRegistry,
521     const NamedValueMap *namedValues, Diagnostics *error) {
522   VariantValue value;
523   if (!parseExpression(code, matcherRegistry, namedValues, &value, error))
524     return std::nullopt;
525   if (!value.isMatcher()) {
526     error->addError(SourceRange(), ErrorType::ParserNotAMatcher);
527     return std::nullopt;
528   }
529   std::optional<DynMatcher> result = value.getMatcher().getDynMatcher();
530   if (!result) {
531     error->addError(SourceRange(), ErrorType::ParserOverloadedType)
532         << value.getTypeAsString();
533   }
534   return result;
535 }
536 
537 } // namespace mlir::query::matcher::internal
538