xref: /llvm-project/mlir/lib/Query/Matcher/Parser.cpp (revision 58b44c8102afb0e76d1cb70d4a5d089f70d2f657)
1 //===- Parser.cpp - Matcher expression parser -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Recursive parser implementation for the matcher expression grammar.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Parser.h"
14 
15 #include <vector>
16 
17 namespace mlir::query::matcher::internal {
18 
19 // Simple structure to hold information for one token from the parser.
20 struct Parser::TokenInfo {
21   TokenInfo() = default;
22 
23   // Method to set the kind and text of the token
setmlir::query::matcher::internal::Parser::TokenInfo24   void set(TokenKind newKind, llvm::StringRef newText) {
25     kind = newKind;
26     text = newText;
27   }
28 
29   // Known identifiers.
30   static const char *const ID_Extract;
31 
32   llvm::StringRef text;
33   TokenKind kind = TokenKind::Eof;
34   SourceRange range;
35   VariantValue value;
36 };
37 
38 const char *const Parser::TokenInfo::ID_Extract = "extract";
39 
40 class Parser::CodeTokenizer {
41 public:
42   // Constructor with matcherCode and error
CodeTokenizer(llvm::StringRef matcherCode,Diagnostics * error)43   explicit CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error)
44       : code(matcherCode), startOfLine(matcherCode), error(error) {
45     nextToken = getNextToken();
46   }
47 
48   // Constructor with matcherCode, error, and codeCompletionOffset
CodeTokenizer(llvm::StringRef matcherCode,Diagnostics * error,unsigned codeCompletionOffset)49   CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error,
50                 unsigned codeCompletionOffset)
51       : code(matcherCode), startOfLine(matcherCode), error(error),
52         codeCompletionLocation(matcherCode.data() + codeCompletionOffset) {
53     nextToken = getNextToken();
54   }
55 
56   // Peek at next token without consuming it
peekNextToken() const57   const TokenInfo &peekNextToken() const { return nextToken; }
58 
59   // Consume and return the next token
consumeNextToken()60   TokenInfo consumeNextToken() {
61     TokenInfo thisToken = nextToken;
62     nextToken = getNextToken();
63     return thisToken;
64   }
65 
66   // Skip any newline tokens
skipNewlines()67   TokenInfo skipNewlines() {
68     while (nextToken.kind == TokenKind::NewLine)
69       nextToken = getNextToken();
70     return nextToken;
71   }
72 
73   // Consume and return next token, ignoring newlines
consumeNextTokenIgnoreNewlines()74   TokenInfo consumeNextTokenIgnoreNewlines() {
75     skipNewlines();
76     return nextToken.kind == TokenKind::Eof ? nextToken : consumeNextToken();
77   }
78 
79   // Return kind of next token
nextTokenKind() const80   TokenKind nextTokenKind() const { return nextToken.kind; }
81 
82 private:
83   // Helper function to get the first character as a new StringRef and drop it
84   // from the original string
firstCharacterAndDrop(llvm::StringRef & str)85   llvm::StringRef firstCharacterAndDrop(llvm::StringRef &str) {
86     assert(!str.empty());
87     llvm::StringRef firstChar = str.substr(0, 1);
88     str = str.drop_front();
89     return firstChar;
90   }
91 
92   // Get next token, consuming whitespaces and handling different token types
getNextToken()93   TokenInfo getNextToken() {
94     consumeWhitespace();
95     TokenInfo result;
96     result.range.start = currentLocation();
97 
98     // Code completion case
99     if (codeCompletionLocation && codeCompletionLocation <= code.data()) {
100       result.set(TokenKind::CodeCompletion,
101                  llvm::StringRef(codeCompletionLocation, 0));
102       codeCompletionLocation = nullptr;
103       return result;
104     }
105 
106     // End of file case
107     if (code.empty()) {
108       result.set(TokenKind::Eof, "");
109       return result;
110     }
111 
112     // Switch to handle specific characters
113     switch (code[0]) {
114     case '#':
115       code = code.drop_until([](char c) { return c == '\n'; });
116       return getNextToken();
117     case ',':
118       result.set(TokenKind::Comma, firstCharacterAndDrop(code));
119       break;
120     case '.':
121       result.set(TokenKind::Period, firstCharacterAndDrop(code));
122       break;
123     case '\n':
124       ++line;
125       startOfLine = code.drop_front();
126       result.set(TokenKind::NewLine, firstCharacterAndDrop(code));
127       break;
128     case '(':
129       result.set(TokenKind::OpenParen, firstCharacterAndDrop(code));
130       break;
131     case ')':
132       result.set(TokenKind::CloseParen, firstCharacterAndDrop(code));
133       break;
134     case '"':
135     case '\'':
136       consumeStringLiteral(&result);
137       break;
138     default:
139       parseIdentifierOrInvalid(&result);
140       break;
141     }
142 
143     result.range.end = currentLocation();
144     return result;
145   }
146 
147   // Consume a string literal, handle escape sequences and missing closing
148   // quote.
consumeStringLiteral(TokenInfo * result)149   void consumeStringLiteral(TokenInfo *result) {
150     bool inEscape = false;
151     const char marker = code[0];
152     for (size_t length = 1; length < code.size(); ++length) {
153       if (inEscape) {
154         inEscape = false;
155         continue;
156       }
157       if (code[length] == '\\') {
158         inEscape = true;
159         continue;
160       }
161       if (code[length] == marker) {
162         result->kind = TokenKind::Literal;
163         result->text = code.substr(0, length + 1);
164         result->value = code.substr(1, length - 1);
165         code = code.drop_front(length + 1);
166         return;
167       }
168     }
169     llvm::StringRef errorText = code;
170     code = code.drop_front(code.size());
171     SourceRange range;
172     range.start = result->range.start;
173     range.end = currentLocation();
174     error->addError(range, ErrorType::ParserStringError) << errorText;
175     result->kind = TokenKind::Error;
176   }
177 
parseIdentifierOrInvalid(TokenInfo * result)178   void parseIdentifierOrInvalid(TokenInfo *result) {
179     if (isalnum(code[0])) {
180       // Parse an identifier
181       size_t tokenLength = 1;
182 
183       while (true) {
184         // A code completion location in/immediately after an identifier will
185         // cause the portion of the identifier before the code completion
186         // location to become a code completion token.
187         if (codeCompletionLocation == code.data() + tokenLength) {
188           codeCompletionLocation = nullptr;
189           result->kind = TokenKind::CodeCompletion;
190           result->text = code.substr(0, tokenLength);
191           code = code.drop_front(tokenLength);
192           return;
193         }
194         if (tokenLength == code.size() || !(isalnum(code[tokenLength])))
195           break;
196         ++tokenLength;
197       }
198       result->kind = TokenKind::Ident;
199       result->text = code.substr(0, tokenLength);
200       code = code.drop_front(tokenLength);
201     } else {
202       result->kind = TokenKind::InvalidChar;
203       result->text = code.substr(0, 1);
204       code = code.drop_front(1);
205     }
206   }
207 
208   // Consume all leading whitespace from code, except newlines
consumeWhitespace()209   void consumeWhitespace() { code = code.ltrim(" \t\v\f\r"); }
210 
211   // Returns the current location in the source code
currentLocation()212   SourceLocation currentLocation() {
213     SourceLocation location;
214     location.line = line;
215     location.column = code.data() - startOfLine.data() + 1;
216     return location;
217   }
218 
219   llvm::StringRef code;
220   llvm::StringRef startOfLine;
221   unsigned line = 1;
222   Diagnostics *error;
223   TokenInfo nextToken;
224   const char *codeCompletionLocation = nullptr;
225 };
226 
227 Parser::Sema::~Sema() = default;
228 
getAcceptedCompletionTypes(llvm::ArrayRef<std::pair<MatcherCtor,unsigned>> context)229 std::vector<ArgKind> Parser::Sema::getAcceptedCompletionTypes(
230     llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
231   return {};
232 }
233 
234 std::vector<MatcherCompletion>
getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes)235 Parser::Sema::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) {
236   return {};
237 }
238 
239 // Entry for the scope of a parser
240 struct Parser::ScopedContextEntry {
241   Parser *parser;
242 
ScopedContextEntrymlir::query::matcher::internal::Parser::ScopedContextEntry243   ScopedContextEntry(Parser *parser, MatcherCtor c) : parser(parser) {
244     parser->contextStack.emplace_back(c, 0u);
245   }
246 
~ScopedContextEntrymlir::query::matcher::internal::Parser::ScopedContextEntry247   ~ScopedContextEntry() { parser->contextStack.pop_back(); }
248 
nextArgmlir::query::matcher::internal::Parser::ScopedContextEntry249   void nextArg() { ++parser->contextStack.back().second; }
250 };
251 
252 // Parse and validate expressions starting with an identifier.
253 // This function can parse named values and matchers. In case of failure, it
254 // will try to determine the user's intent to give an appropriate error message.
parseIdentifierPrefixImpl(VariantValue * value)255 bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
256   const TokenInfo nameToken = tokenizer->consumeNextToken();
257 
258   if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
259     // Parse as a named value.
260     auto namedValue =
261         namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
262 
263     if (!namedValue.isMatcher()) {
264       error->addError(tokenizer->peekNextToken().range,
265                       ErrorType::ParserNotAMatcher);
266       return false;
267     }
268 
269     if (tokenizer->nextTokenKind() == TokenKind::NewLine) {
270       error->addError(tokenizer->peekNextToken().range,
271                       ErrorType::ParserNoOpenParen)
272           << "NewLine";
273       return false;
274     }
275 
276     // If the syntax is correct and the name is not a matcher either, report
277     // an unknown named value.
278     if ((tokenizer->nextTokenKind() == TokenKind::Comma ||
279          tokenizer->nextTokenKind() == TokenKind::CloseParen ||
280          tokenizer->nextTokenKind() == TokenKind::NewLine ||
281          tokenizer->nextTokenKind() == TokenKind::Eof) &&
282         !sema->lookupMatcherCtor(nameToken.text)) {
283       error->addError(nameToken.range, ErrorType::RegistryValueNotFound)
284           << nameToken.text;
285       return false;
286     }
287     // Otherwise, fallback to the matcher parser.
288   }
289 
290   tokenizer->skipNewlines();
291 
292   assert(nameToken.kind == TokenKind::Ident);
293   TokenInfo openToken = tokenizer->consumeNextToken();
294   if (openToken.kind != TokenKind::OpenParen) {
295     error->addError(openToken.range, ErrorType::ParserNoOpenParen)
296         << openToken.text;
297     return false;
298   }
299 
300   std::optional<MatcherCtor> ctor = sema->lookupMatcherCtor(nameToken.text);
301 
302   // Parse as a matcher expression.
303   return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
304 }
305 
parseChainedExpression(std::string & argument)306 bool Parser::parseChainedExpression(std::string &argument) {
307   // Parse the parenthesized argument to .extract("foo")
308   // Note: EOF is handled inside the consume functions and would fail below when
309   // checking token kind.
310   const TokenInfo openToken = tokenizer->consumeNextToken();
311   const TokenInfo argumentToken = tokenizer->consumeNextTokenIgnoreNewlines();
312   const TokenInfo closeToken = tokenizer->consumeNextTokenIgnoreNewlines();
313 
314   if (openToken.kind != TokenKind::OpenParen) {
315     error->addError(openToken.range, ErrorType::ParserChainedExprNoOpenParen);
316     return false;
317   }
318 
319   if (argumentToken.kind != TokenKind::Literal ||
320       !argumentToken.value.isString()) {
321     error->addError(argumentToken.range,
322                     ErrorType::ParserChainedExprInvalidArg);
323     return false;
324   }
325 
326   if (closeToken.kind != TokenKind::CloseParen) {
327     error->addError(closeToken.range, ErrorType::ParserChainedExprNoCloseParen);
328     return false;
329   }
330 
331   // If all checks passed, extract the argument and return true.
332   argument = argumentToken.value.getString();
333   return true;
334 }
335 
336 // Parse the arguments of a matcher
parseMatcherArgs(std::vector<ParserValue> & args,MatcherCtor ctor,const TokenInfo & nameToken,TokenInfo & endToken)337 bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
338                               const TokenInfo &nameToken, TokenInfo &endToken) {
339   ScopedContextEntry sce(this, ctor);
340 
341   while (tokenizer->nextTokenKind() != TokenKind::Eof) {
342     if (tokenizer->nextTokenKind() == TokenKind::CloseParen) {
343       // end of args.
344       endToken = tokenizer->consumeNextToken();
345       break;
346     }
347 
348     if (!args.empty()) {
349       // We must find a , token to continue.
350       TokenInfo commaToken = tokenizer->consumeNextToken();
351       if (commaToken.kind != TokenKind::Comma) {
352         error->addError(commaToken.range, ErrorType::ParserNoComma)
353             << commaToken.text;
354         return false;
355       }
356     }
357 
358     ParserValue argValue;
359     tokenizer->skipNewlines();
360 
361     argValue.text = tokenizer->peekNextToken().text;
362     argValue.range = tokenizer->peekNextToken().range;
363     if (!parseExpressionImpl(&argValue.value)) {
364       return false;
365     }
366 
367     tokenizer->skipNewlines();
368     args.push_back(argValue);
369     sce.nextArg();
370   }
371 
372   return true;
373 }
374 
375 // Parse and validate a matcher expression.
parseMatcherExpressionImpl(const TokenInfo & nameToken,const TokenInfo & openToken,std::optional<MatcherCtor> ctor,VariantValue * value)376 bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
377                                         const TokenInfo &openToken,
378                                         std::optional<MatcherCtor> ctor,
379                                         VariantValue *value) {
380   if (!ctor) {
381     error->addError(nameToken.range, ErrorType::RegistryMatcherNotFound)
382         << nameToken.text;
383     // Do not return here. We need to continue to give completion suggestions.
384   }
385 
386   std::vector<ParserValue> args;
387   TokenInfo endToken;
388 
389   tokenizer->skipNewlines();
390 
391   if (!parseMatcherArgs(args, ctor.value_or(nullptr), nameToken, endToken)) {
392     return false;
393   }
394 
395   // Check for the missing closing parenthesis
396   if (endToken.kind != TokenKind::CloseParen) {
397     error->addError(openToken.range, ErrorType::ParserNoCloseParen)
398         << nameToken.text;
399     return false;
400   }
401 
402   std::string functionName;
403   if (tokenizer->peekNextToken().kind == TokenKind::Period) {
404     tokenizer->consumeNextToken();
405     TokenInfo chainCallToken = tokenizer->consumeNextToken();
406     if (chainCallToken.kind == TokenKind::CodeCompletion) {
407       addCompletion(chainCallToken, MatcherCompletion("extract(\"", "extract"));
408       return false;
409     }
410 
411     if (chainCallToken.kind != TokenKind::Ident ||
412         chainCallToken.text != TokenInfo::ID_Extract) {
413       error->addError(chainCallToken.range,
414                       ErrorType::ParserMalformedChainedExpr);
415       return false;
416     }
417 
418     if (chainCallToken.text == TokenInfo::ID_Extract &&
419         !parseChainedExpression(functionName))
420       return false;
421   }
422 
423   if (!ctor)
424     return false;
425   // Merge the start and end infos.
426   SourceRange matcherRange = nameToken.range;
427   matcherRange.end = endToken.range.end;
428   VariantMatcher result = sema->actOnMatcherExpression(
429       *ctor, matcherRange, functionName, args, error);
430   if (result.isNull())
431     return false;
432   *value = result;
433   return true;
434 }
435 
436 // If the prefix of this completion matches the completion token, add it to
437 // completions minus the prefix.
addCompletion(const TokenInfo & compToken,const MatcherCompletion & completion)438 void Parser::addCompletion(const TokenInfo &compToken,
439                            const MatcherCompletion &completion) {
440   if (llvm::StringRef(completion.typedText).starts_with(compToken.text)) {
441     completions.emplace_back(completion.typedText.substr(compToken.text.size()),
442                              completion.matcherDecl);
443   }
444 }
445 
446 std::vector<MatcherCompletion>
getNamedValueCompletions(llvm::ArrayRef<ArgKind> acceptedTypes)447 Parser::getNamedValueCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) {
448   if (!namedValues)
449     return {};
450 
451   std::vector<MatcherCompletion> result;
452   for (const auto &entry : *namedValues) {
453     std::string decl =
454         (entry.getValue().getTypeAsString() + " " + entry.getKey()).str();
455     result.emplace_back(entry.getKey(), decl);
456   }
457   return result;
458 }
459 
addExpressionCompletions()460 void Parser::addExpressionCompletions() {
461   const TokenInfo compToken = tokenizer->consumeNextTokenIgnoreNewlines();
462   assert(compToken.kind == TokenKind::CodeCompletion);
463 
464   // We cannot complete code if there is an invalid element on the context
465   // stack.
466   for (const auto &entry : contextStack) {
467     if (!entry.first)
468       return;
469   }
470 
471   auto acceptedTypes = sema->getAcceptedCompletionTypes(contextStack);
472   for (const auto &completion : sema->getMatcherCompletions(acceptedTypes)) {
473     addCompletion(compToken, completion);
474   }
475 
476   for (const auto &completion : getNamedValueCompletions(acceptedTypes)) {
477     addCompletion(compToken, completion);
478   }
479 }
480 
481 // Parse an <Expresssion>
parseExpressionImpl(VariantValue * value)482 bool Parser::parseExpressionImpl(VariantValue *value) {
483   switch (tokenizer->nextTokenKind()) {
484   case TokenKind::Literal:
485     *value = tokenizer->consumeNextToken().value;
486     return true;
487   case TokenKind::Ident:
488     return parseIdentifierPrefixImpl(value);
489   case TokenKind::CodeCompletion:
490     addExpressionCompletions();
491     return false;
492   case TokenKind::Eof:
493     error->addError(tokenizer->consumeNextToken().range,
494                     ErrorType::ParserNoCode);
495     return false;
496 
497   case TokenKind::Error:
498     // This error was already reported by the tokenizer.
499     return false;
500   case TokenKind::NewLine:
501   case TokenKind::OpenParen:
502   case TokenKind::CloseParen:
503   case TokenKind::Comma:
504   case TokenKind::Period:
505   case TokenKind::InvalidChar:
506     const TokenInfo token = tokenizer->consumeNextToken();
507     error->addError(token.range, ErrorType::ParserInvalidToken)
508         << (token.kind == TokenKind::NewLine ? "NewLine" : token.text);
509     return false;
510   }
511 
512   llvm_unreachable("Unknown token kind.");
513 }
514 
Parser(CodeTokenizer * tokenizer,const Registry & matcherRegistry,const NamedValueMap * namedValues,Diagnostics * error)515 Parser::Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
516                const NamedValueMap *namedValues, Diagnostics *error)
517     : tokenizer(tokenizer),
518       sema(std::make_unique<RegistrySema>(matcherRegistry)),
519       namedValues(namedValues), error(error) {}
520 
521 Parser::RegistrySema::~RegistrySema() = default;
522 
523 std::optional<MatcherCtor>
lookupMatcherCtor(llvm::StringRef matcherName)524 Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
525   return RegistryManager::lookupMatcherCtor(matcherName, matcherRegistry);
526 }
527 
actOnMatcherExpression(MatcherCtor ctor,SourceRange nameRange,llvm::StringRef functionName,llvm::ArrayRef<ParserValue> args,Diagnostics * error)528 VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
529     MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
530     llvm::ArrayRef<ParserValue> args, Diagnostics *error) {
531   return RegistryManager::constructMatcher(ctor, nameRange, functionName, args,
532                                            error);
533 }
534 
getAcceptedCompletionTypes(llvm::ArrayRef<std::pair<MatcherCtor,unsigned>> context)535 std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(
536     llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
537   return RegistryManager::getAcceptedCompletionTypes(context);
538 }
539 
getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes)540 std::vector<MatcherCompletion> Parser::RegistrySema::getMatcherCompletions(
541     llvm::ArrayRef<ArgKind> acceptedTypes) {
542   return RegistryManager::getMatcherCompletions(acceptedTypes, matcherRegistry);
543 }
544 
parseExpression(llvm::StringRef & code,const Registry & matcherRegistry,const NamedValueMap * namedValues,VariantValue * value,Diagnostics * error)545 bool Parser::parseExpression(llvm::StringRef &code,
546                              const Registry &matcherRegistry,
547                              const NamedValueMap *namedValues,
548                              VariantValue *value, Diagnostics *error) {
549   CodeTokenizer tokenizer(code, error);
550   Parser parser(&tokenizer, matcherRegistry, namedValues, error);
551   if (!parser.parseExpressionImpl(value))
552     return false;
553   auto nextToken = tokenizer.peekNextToken();
554   if (nextToken.kind != TokenKind::Eof &&
555       nextToken.kind != TokenKind::NewLine) {
556     error->addError(tokenizer.peekNextToken().range,
557                     ErrorType::ParserTrailingCode);
558     return false;
559   }
560   return true;
561 }
562 
563 std::vector<MatcherCompletion>
completeExpression(llvm::StringRef & code,unsigned completionOffset,const Registry & matcherRegistry,const NamedValueMap * namedValues)564 Parser::completeExpression(llvm::StringRef &code, unsigned completionOffset,
565                            const Registry &matcherRegistry,
566                            const NamedValueMap *namedValues) {
567   Diagnostics error;
568   CodeTokenizer tokenizer(code, &error, completionOffset);
569   Parser parser(&tokenizer, matcherRegistry, namedValues, &error);
570   VariantValue dummy;
571   parser.parseExpressionImpl(&dummy);
572 
573   return parser.completions;
574 }
575 
parseMatcherExpression(llvm::StringRef & code,const Registry & matcherRegistry,const NamedValueMap * namedValues,Diagnostics * error)576 std::optional<DynMatcher> Parser::parseMatcherExpression(
577     llvm::StringRef &code, const Registry &matcherRegistry,
578     const NamedValueMap *namedValues, Diagnostics *error) {
579   VariantValue value;
580   if (!parseExpression(code, matcherRegistry, namedValues, &value, error))
581     return std::nullopt;
582   if (!value.isMatcher()) {
583     error->addError(SourceRange(), ErrorType::ParserNotAMatcher);
584     return std::nullopt;
585   }
586   std::optional<DynMatcher> result = value.getMatcher().getDynMatcher();
587   if (!result) {
588     error->addError(SourceRange(), ErrorType::ParserOverloadedType)
589         << value.getTypeAsString();
590   }
591   return result;
592 }
593 
594 } // namespace mlir::query::matcher::internal
595