1 //===- Parser.cpp - Matcher expression parser -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Recursive parser implementation for the matcher expression grammar.
10 //
11 //===----------------------------------------------------------------------===//
12
#include "Parser.h"

#include <utility>
#include <vector>
16
17 namespace mlir::query::matcher::internal {
18
19 // Simple structure to hold information for one token from the parser.
20 struct Parser::TokenInfo {
21 TokenInfo() = default;
22
23 // Method to set the kind and text of the token
setmlir::query::matcher::internal::Parser::TokenInfo24 void set(TokenKind newKind, llvm::StringRef newText) {
25 kind = newKind;
26 text = newText;
27 }
28
29 // Known identifiers.
30 static const char *const ID_Extract;
31
32 llvm::StringRef text;
33 TokenKind kind = TokenKind::Eof;
34 SourceRange range;
35 VariantValue value;
36 };
37
// Spelling of the chained call accepted after a matcher expression; the
// identifier following a '.' is compared against this in
// parseMatcherExpressionImpl, and anything else is rejected there.
const char *const Parser::TokenInfo::ID_Extract = "extract";
39
40 class Parser::CodeTokenizer {
41 public:
42 // Constructor with matcherCode and error
CodeTokenizer(llvm::StringRef matcherCode,Diagnostics * error)43 explicit CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error)
44 : code(matcherCode), startOfLine(matcherCode), error(error) {
45 nextToken = getNextToken();
46 }
47
48 // Constructor with matcherCode, error, and codeCompletionOffset
CodeTokenizer(llvm::StringRef matcherCode,Diagnostics * error,unsigned codeCompletionOffset)49 CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error,
50 unsigned codeCompletionOffset)
51 : code(matcherCode), startOfLine(matcherCode), error(error),
52 codeCompletionLocation(matcherCode.data() + codeCompletionOffset) {
53 nextToken = getNextToken();
54 }
55
56 // Peek at next token without consuming it
peekNextToken() const57 const TokenInfo &peekNextToken() const { return nextToken; }
58
59 // Consume and return the next token
consumeNextToken()60 TokenInfo consumeNextToken() {
61 TokenInfo thisToken = nextToken;
62 nextToken = getNextToken();
63 return thisToken;
64 }
65
66 // Skip any newline tokens
skipNewlines()67 TokenInfo skipNewlines() {
68 while (nextToken.kind == TokenKind::NewLine)
69 nextToken = getNextToken();
70 return nextToken;
71 }
72
73 // Consume and return next token, ignoring newlines
consumeNextTokenIgnoreNewlines()74 TokenInfo consumeNextTokenIgnoreNewlines() {
75 skipNewlines();
76 return nextToken.kind == TokenKind::Eof ? nextToken : consumeNextToken();
77 }
78
79 // Return kind of next token
nextTokenKind() const80 TokenKind nextTokenKind() const { return nextToken.kind; }
81
82 private:
83 // Helper function to get the first character as a new StringRef and drop it
84 // from the original string
firstCharacterAndDrop(llvm::StringRef & str)85 llvm::StringRef firstCharacterAndDrop(llvm::StringRef &str) {
86 assert(!str.empty());
87 llvm::StringRef firstChar = str.substr(0, 1);
88 str = str.drop_front();
89 return firstChar;
90 }
91
92 // Get next token, consuming whitespaces and handling different token types
getNextToken()93 TokenInfo getNextToken() {
94 consumeWhitespace();
95 TokenInfo result;
96 result.range.start = currentLocation();
97
98 // Code completion case
99 if (codeCompletionLocation && codeCompletionLocation <= code.data()) {
100 result.set(TokenKind::CodeCompletion,
101 llvm::StringRef(codeCompletionLocation, 0));
102 codeCompletionLocation = nullptr;
103 return result;
104 }
105
106 // End of file case
107 if (code.empty()) {
108 result.set(TokenKind::Eof, "");
109 return result;
110 }
111
112 // Switch to handle specific characters
113 switch (code[0]) {
114 case '#':
115 code = code.drop_until([](char c) { return c == '\n'; });
116 return getNextToken();
117 case ',':
118 result.set(TokenKind::Comma, firstCharacterAndDrop(code));
119 break;
120 case '.':
121 result.set(TokenKind::Period, firstCharacterAndDrop(code));
122 break;
123 case '\n':
124 ++line;
125 startOfLine = code.drop_front();
126 result.set(TokenKind::NewLine, firstCharacterAndDrop(code));
127 break;
128 case '(':
129 result.set(TokenKind::OpenParen, firstCharacterAndDrop(code));
130 break;
131 case ')':
132 result.set(TokenKind::CloseParen, firstCharacterAndDrop(code));
133 break;
134 case '"':
135 case '\'':
136 consumeStringLiteral(&result);
137 break;
138 default:
139 parseIdentifierOrInvalid(&result);
140 break;
141 }
142
143 result.range.end = currentLocation();
144 return result;
145 }
146
147 // Consume a string literal, handle escape sequences and missing closing
148 // quote.
consumeStringLiteral(TokenInfo * result)149 void consumeStringLiteral(TokenInfo *result) {
150 bool inEscape = false;
151 const char marker = code[0];
152 for (size_t length = 1; length < code.size(); ++length) {
153 if (inEscape) {
154 inEscape = false;
155 continue;
156 }
157 if (code[length] == '\\') {
158 inEscape = true;
159 continue;
160 }
161 if (code[length] == marker) {
162 result->kind = TokenKind::Literal;
163 result->text = code.substr(0, length + 1);
164 result->value = code.substr(1, length - 1);
165 code = code.drop_front(length + 1);
166 return;
167 }
168 }
169 llvm::StringRef errorText = code;
170 code = code.drop_front(code.size());
171 SourceRange range;
172 range.start = result->range.start;
173 range.end = currentLocation();
174 error->addError(range, ErrorType::ParserStringError) << errorText;
175 result->kind = TokenKind::Error;
176 }
177
parseIdentifierOrInvalid(TokenInfo * result)178 void parseIdentifierOrInvalid(TokenInfo *result) {
179 if (isalnum(code[0])) {
180 // Parse an identifier
181 size_t tokenLength = 1;
182
183 while (true) {
184 // A code completion location in/immediately after an identifier will
185 // cause the portion of the identifier before the code completion
186 // location to become a code completion token.
187 if (codeCompletionLocation == code.data() + tokenLength) {
188 codeCompletionLocation = nullptr;
189 result->kind = TokenKind::CodeCompletion;
190 result->text = code.substr(0, tokenLength);
191 code = code.drop_front(tokenLength);
192 return;
193 }
194 if (tokenLength == code.size() || !(isalnum(code[tokenLength])))
195 break;
196 ++tokenLength;
197 }
198 result->kind = TokenKind::Ident;
199 result->text = code.substr(0, tokenLength);
200 code = code.drop_front(tokenLength);
201 } else {
202 result->kind = TokenKind::InvalidChar;
203 result->text = code.substr(0, 1);
204 code = code.drop_front(1);
205 }
206 }
207
208 // Consume all leading whitespace from code, except newlines
consumeWhitespace()209 void consumeWhitespace() { code = code.ltrim(" \t\v\f\r"); }
210
211 // Returns the current location in the source code
currentLocation()212 SourceLocation currentLocation() {
213 SourceLocation location;
214 location.line = line;
215 location.column = code.data() - startOfLine.data() + 1;
216 return location;
217 }
218
219 llvm::StringRef code;
220 llvm::StringRef startOfLine;
221 unsigned line = 1;
222 Diagnostics *error;
223 TokenInfo nextToken;
224 const char *codeCompletionLocation = nullptr;
225 };
226
// Out-of-line definition of the Sema destructor; nothing to release here.
Parser::Sema::~Sema() = default;
228
getAcceptedCompletionTypes(llvm::ArrayRef<std::pair<MatcherCtor,unsigned>> context)229 std::vector<ArgKind> Parser::Sema::getAcceptedCompletionTypes(
230 llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
231 return {};
232 }
233
234 std::vector<MatcherCompletion>
getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes)235 Parser::Sema::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) {
236 return {};
237 }
238
239 // Entry for the scope of a parser
240 struct Parser::ScopedContextEntry {
241 Parser *parser;
242
ScopedContextEntrymlir::query::matcher::internal::Parser::ScopedContextEntry243 ScopedContextEntry(Parser *parser, MatcherCtor c) : parser(parser) {
244 parser->contextStack.emplace_back(c, 0u);
245 }
246
~ScopedContextEntrymlir::query::matcher::internal::Parser::ScopedContextEntry247 ~ScopedContextEntry() { parser->contextStack.pop_back(); }
248
nextArgmlir::query::matcher::internal::Parser::ScopedContextEntry249 void nextArg() { ++parser->contextStack.back().second; }
250 };
251
252 // Parse and validate expressions starting with an identifier.
253 // This function can parse named values and matchers. In case of failure, it
254 // will try to determine the user's intent to give an appropriate error message.
parseIdentifierPrefixImpl(VariantValue * value)255 bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
256 const TokenInfo nameToken = tokenizer->consumeNextToken();
257
258 if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
259 // Parse as a named value.
260 auto namedValue =
261 namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
262
263 if (!namedValue.isMatcher()) {
264 error->addError(tokenizer->peekNextToken().range,
265 ErrorType::ParserNotAMatcher);
266 return false;
267 }
268
269 if (tokenizer->nextTokenKind() == TokenKind::NewLine) {
270 error->addError(tokenizer->peekNextToken().range,
271 ErrorType::ParserNoOpenParen)
272 << "NewLine";
273 return false;
274 }
275
276 // If the syntax is correct and the name is not a matcher either, report
277 // an unknown named value.
278 if ((tokenizer->nextTokenKind() == TokenKind::Comma ||
279 tokenizer->nextTokenKind() == TokenKind::CloseParen ||
280 tokenizer->nextTokenKind() == TokenKind::NewLine ||
281 tokenizer->nextTokenKind() == TokenKind::Eof) &&
282 !sema->lookupMatcherCtor(nameToken.text)) {
283 error->addError(nameToken.range, ErrorType::RegistryValueNotFound)
284 << nameToken.text;
285 return false;
286 }
287 // Otherwise, fallback to the matcher parser.
288 }
289
290 tokenizer->skipNewlines();
291
292 assert(nameToken.kind == TokenKind::Ident);
293 TokenInfo openToken = tokenizer->consumeNextToken();
294 if (openToken.kind != TokenKind::OpenParen) {
295 error->addError(openToken.range, ErrorType::ParserNoOpenParen)
296 << openToken.text;
297 return false;
298 }
299
300 std::optional<MatcherCtor> ctor = sema->lookupMatcherCtor(nameToken.text);
301
302 // Parse as a matcher expression.
303 return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
304 }
305
parseChainedExpression(std::string & argument)306 bool Parser::parseChainedExpression(std::string &argument) {
307 // Parse the parenthesized argument to .extract("foo")
308 // Note: EOF is handled inside the consume functions and would fail below when
309 // checking token kind.
310 const TokenInfo openToken = tokenizer->consumeNextToken();
311 const TokenInfo argumentToken = tokenizer->consumeNextTokenIgnoreNewlines();
312 const TokenInfo closeToken = tokenizer->consumeNextTokenIgnoreNewlines();
313
314 if (openToken.kind != TokenKind::OpenParen) {
315 error->addError(openToken.range, ErrorType::ParserChainedExprNoOpenParen);
316 return false;
317 }
318
319 if (argumentToken.kind != TokenKind::Literal ||
320 !argumentToken.value.isString()) {
321 error->addError(argumentToken.range,
322 ErrorType::ParserChainedExprInvalidArg);
323 return false;
324 }
325
326 if (closeToken.kind != TokenKind::CloseParen) {
327 error->addError(closeToken.range, ErrorType::ParserChainedExprNoCloseParen);
328 return false;
329 }
330
331 // If all checks passed, extract the argument and return true.
332 argument = argumentToken.value.getString();
333 return true;
334 }
335
336 // Parse the arguments of a matcher
parseMatcherArgs(std::vector<ParserValue> & args,MatcherCtor ctor,const TokenInfo & nameToken,TokenInfo & endToken)337 bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
338 const TokenInfo &nameToken, TokenInfo &endToken) {
339 ScopedContextEntry sce(this, ctor);
340
341 while (tokenizer->nextTokenKind() != TokenKind::Eof) {
342 if (tokenizer->nextTokenKind() == TokenKind::CloseParen) {
343 // end of args.
344 endToken = tokenizer->consumeNextToken();
345 break;
346 }
347
348 if (!args.empty()) {
349 // We must find a , token to continue.
350 TokenInfo commaToken = tokenizer->consumeNextToken();
351 if (commaToken.kind != TokenKind::Comma) {
352 error->addError(commaToken.range, ErrorType::ParserNoComma)
353 << commaToken.text;
354 return false;
355 }
356 }
357
358 ParserValue argValue;
359 tokenizer->skipNewlines();
360
361 argValue.text = tokenizer->peekNextToken().text;
362 argValue.range = tokenizer->peekNextToken().range;
363 if (!parseExpressionImpl(&argValue.value)) {
364 return false;
365 }
366
367 tokenizer->skipNewlines();
368 args.push_back(argValue);
369 sce.nextArg();
370 }
371
372 return true;
373 }
374
375 // Parse and validate a matcher expression.
parseMatcherExpressionImpl(const TokenInfo & nameToken,const TokenInfo & openToken,std::optional<MatcherCtor> ctor,VariantValue * value)376 bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
377 const TokenInfo &openToken,
378 std::optional<MatcherCtor> ctor,
379 VariantValue *value) {
380 if (!ctor) {
381 error->addError(nameToken.range, ErrorType::RegistryMatcherNotFound)
382 << nameToken.text;
383 // Do not return here. We need to continue to give completion suggestions.
384 }
385
386 std::vector<ParserValue> args;
387 TokenInfo endToken;
388
389 tokenizer->skipNewlines();
390
391 if (!parseMatcherArgs(args, ctor.value_or(nullptr), nameToken, endToken)) {
392 return false;
393 }
394
395 // Check for the missing closing parenthesis
396 if (endToken.kind != TokenKind::CloseParen) {
397 error->addError(openToken.range, ErrorType::ParserNoCloseParen)
398 << nameToken.text;
399 return false;
400 }
401
402 std::string functionName;
403 if (tokenizer->peekNextToken().kind == TokenKind::Period) {
404 tokenizer->consumeNextToken();
405 TokenInfo chainCallToken = tokenizer->consumeNextToken();
406 if (chainCallToken.kind == TokenKind::CodeCompletion) {
407 addCompletion(chainCallToken, MatcherCompletion("extract(\"", "extract"));
408 return false;
409 }
410
411 if (chainCallToken.kind != TokenKind::Ident ||
412 chainCallToken.text != TokenInfo::ID_Extract) {
413 error->addError(chainCallToken.range,
414 ErrorType::ParserMalformedChainedExpr);
415 return false;
416 }
417
418 if (chainCallToken.text == TokenInfo::ID_Extract &&
419 !parseChainedExpression(functionName))
420 return false;
421 }
422
423 if (!ctor)
424 return false;
425 // Merge the start and end infos.
426 SourceRange matcherRange = nameToken.range;
427 matcherRange.end = endToken.range.end;
428 VariantMatcher result = sema->actOnMatcherExpression(
429 *ctor, matcherRange, functionName, args, error);
430 if (result.isNull())
431 return false;
432 *value = result;
433 return true;
434 }
435
436 // If the prefix of this completion matches the completion token, add it to
437 // completions minus the prefix.
addCompletion(const TokenInfo & compToken,const MatcherCompletion & completion)438 void Parser::addCompletion(const TokenInfo &compToken,
439 const MatcherCompletion &completion) {
440 if (llvm::StringRef(completion.typedText).starts_with(compToken.text)) {
441 completions.emplace_back(completion.typedText.substr(compToken.text.size()),
442 completion.matcherDecl);
443 }
444 }
445
446 std::vector<MatcherCompletion>
getNamedValueCompletions(llvm::ArrayRef<ArgKind> acceptedTypes)447 Parser::getNamedValueCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) {
448 if (!namedValues)
449 return {};
450
451 std::vector<MatcherCompletion> result;
452 for (const auto &entry : *namedValues) {
453 std::string decl =
454 (entry.getValue().getTypeAsString() + " " + entry.getKey()).str();
455 result.emplace_back(entry.getKey(), decl);
456 }
457 return result;
458 }
459
addExpressionCompletions()460 void Parser::addExpressionCompletions() {
461 const TokenInfo compToken = tokenizer->consumeNextTokenIgnoreNewlines();
462 assert(compToken.kind == TokenKind::CodeCompletion);
463
464 // We cannot complete code if there is an invalid element on the context
465 // stack.
466 for (const auto &entry : contextStack) {
467 if (!entry.first)
468 return;
469 }
470
471 auto acceptedTypes = sema->getAcceptedCompletionTypes(contextStack);
472 for (const auto &completion : sema->getMatcherCompletions(acceptedTypes)) {
473 addCompletion(compToken, completion);
474 }
475
476 for (const auto &completion : getNamedValueCompletions(acceptedTypes)) {
477 addCompletion(compToken, completion);
478 }
479 }
480
481 // Parse an <Expresssion>
parseExpressionImpl(VariantValue * value)482 bool Parser::parseExpressionImpl(VariantValue *value) {
483 switch (tokenizer->nextTokenKind()) {
484 case TokenKind::Literal:
485 *value = tokenizer->consumeNextToken().value;
486 return true;
487 case TokenKind::Ident:
488 return parseIdentifierPrefixImpl(value);
489 case TokenKind::CodeCompletion:
490 addExpressionCompletions();
491 return false;
492 case TokenKind::Eof:
493 error->addError(tokenizer->consumeNextToken().range,
494 ErrorType::ParserNoCode);
495 return false;
496
497 case TokenKind::Error:
498 // This error was already reported by the tokenizer.
499 return false;
500 case TokenKind::NewLine:
501 case TokenKind::OpenParen:
502 case TokenKind::CloseParen:
503 case TokenKind::Comma:
504 case TokenKind::Period:
505 case TokenKind::InvalidChar:
506 const TokenInfo token = tokenizer->consumeNextToken();
507 error->addError(token.range, ErrorType::ParserInvalidToken)
508 << (token.kind == TokenKind::NewLine ? "NewLine" : token.text);
509 return false;
510 }
511
512 llvm_unreachable("Unknown token kind.");
513 }
514
// Builds a parser that reads tokens from `tokenizer` and reports problems to
// `error`. Matcher lookups and construction go through a RegistrySema created
// here over `matcherRegistry`; `namedValues` is stored as given and is
// null-checked wherever it is used.
Parser::Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
               const NamedValueMap *namedValues, Diagnostics *error)
    : tokenizer(tokenizer),
      sema(std::make_unique<RegistrySema>(matcherRegistry)),
      namedValues(namedValues), error(error) {}
520
// Out-of-line definition of the RegistrySema destructor; defaulted.
Parser::RegistrySema::~RegistrySema() = default;
522
523 std::optional<MatcherCtor>
lookupMatcherCtor(llvm::StringRef matcherName)524 Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
525 return RegistryManager::lookupMatcherCtor(matcherName, matcherRegistry);
526 }
527
actOnMatcherExpression(MatcherCtor ctor,SourceRange nameRange,llvm::StringRef functionName,llvm::ArrayRef<ParserValue> args,Diagnostics * error)528 VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
529 MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
530 llvm::ArrayRef<ParserValue> args, Diagnostics *error) {
531 return RegistryManager::constructMatcher(ctor, nameRange, functionName, args,
532 error);
533 }
534
getAcceptedCompletionTypes(llvm::ArrayRef<std::pair<MatcherCtor,unsigned>> context)535 std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(
536 llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
537 return RegistryManager::getAcceptedCompletionTypes(context);
538 }
539
getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes)540 std::vector<MatcherCompletion> Parser::RegistrySema::getMatcherCompletions(
541 llvm::ArrayRef<ArgKind> acceptedTypes) {
542 return RegistryManager::getMatcherCompletions(acceptedTypes, matcherRegistry);
543 }
544
parseExpression(llvm::StringRef & code,const Registry & matcherRegistry,const NamedValueMap * namedValues,VariantValue * value,Diagnostics * error)545 bool Parser::parseExpression(llvm::StringRef &code,
546 const Registry &matcherRegistry,
547 const NamedValueMap *namedValues,
548 VariantValue *value, Diagnostics *error) {
549 CodeTokenizer tokenizer(code, error);
550 Parser parser(&tokenizer, matcherRegistry, namedValues, error);
551 if (!parser.parseExpressionImpl(value))
552 return false;
553 auto nextToken = tokenizer.peekNextToken();
554 if (nextToken.kind != TokenKind::Eof &&
555 nextToken.kind != TokenKind::NewLine) {
556 error->addError(tokenizer.peekNextToken().range,
557 ErrorType::ParserTrailingCode);
558 return false;
559 }
560 return true;
561 }
562
563 std::vector<MatcherCompletion>
completeExpression(llvm::StringRef & code,unsigned completionOffset,const Registry & matcherRegistry,const NamedValueMap * namedValues)564 Parser::completeExpression(llvm::StringRef &code, unsigned completionOffset,
565 const Registry &matcherRegistry,
566 const NamedValueMap *namedValues) {
567 Diagnostics error;
568 CodeTokenizer tokenizer(code, &error, completionOffset);
569 Parser parser(&tokenizer, matcherRegistry, namedValues, &error);
570 VariantValue dummy;
571 parser.parseExpressionImpl(&dummy);
572
573 return parser.completions;
574 }
575
parseMatcherExpression(llvm::StringRef & code,const Registry & matcherRegistry,const NamedValueMap * namedValues,Diagnostics * error)576 std::optional<DynMatcher> Parser::parseMatcherExpression(
577 llvm::StringRef &code, const Registry &matcherRegistry,
578 const NamedValueMap *namedValues, Diagnostics *error) {
579 VariantValue value;
580 if (!parseExpression(code, matcherRegistry, namedValues, &value, error))
581 return std::nullopt;
582 if (!value.isMatcher()) {
583 error->addError(SourceRange(), ErrorType::ParserNotAMatcher);
584 return std::nullopt;
585 }
586 std::optional<DynMatcher> result = value.getMatcher().getDynMatcher();
587 if (!result) {
588 error->addError(SourceRange(), ErrorType::ParserOverloadedType)
589 << value.getTypeAsString();
590 }
591 return result;
592 }
593
594 } // namespace mlir::query::matcher::internal
595