1 //===- Parser.cpp - Matcher expression parser -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Recursive parser implementation for the matcher expression grammar. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Parser.h" 14 15 #include <vector> 16 17 namespace mlir::query::matcher::internal { 18 19 // Simple structure to hold information for one token from the parser. 20 struct Parser::TokenInfo { 21 TokenInfo() = default; 22 23 // Method to set the kind and text of the token 24 void set(TokenKind newKind, llvm::StringRef newText) { 25 kind = newKind; 26 text = newText; 27 } 28 29 llvm::StringRef text; 30 TokenKind kind = TokenKind::Eof; 31 SourceRange range; 32 VariantValue value; 33 }; 34 35 class Parser::CodeTokenizer { 36 public: 37 // Constructor with matcherCode and error 38 explicit CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error) 39 : code(matcherCode), startOfLine(matcherCode), error(error) { 40 nextToken = getNextToken(); 41 } 42 43 // Constructor with matcherCode, error, and codeCompletionOffset 44 CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error, 45 unsigned codeCompletionOffset) 46 : code(matcherCode), startOfLine(matcherCode), error(error), 47 codeCompletionLocation(matcherCode.data() + codeCompletionOffset) { 48 nextToken = getNextToken(); 49 } 50 51 // Peek at next token without consuming it 52 const TokenInfo &peekNextToken() const { return nextToken; } 53 54 // Consume and return the next token 55 TokenInfo consumeNextToken() { 56 TokenInfo thisToken = nextToken; 57 nextToken = getNextToken(); 58 return thisToken; 59 } 60 61 // Skip any newline tokens 62 TokenInfo skipNewlines() { 63 while (nextToken.kind == TokenKind::NewLine) 64 nextToken = getNextToken(); 65 return nextToken; 66 } 67 68 // Consume and return next token, ignoring newlines 69 TokenInfo consumeNextTokenIgnoreNewlines() { 70 skipNewlines(); 71 return nextToken.kind == TokenKind::Eof ? nextToken : consumeNextToken(); 72 } 73 74 // Return kind of next token 75 TokenKind nextTokenKind() const { return nextToken.kind; } 76 77 private: 78 // Helper function to get the first character as a new StringRef and drop it 79 // from the original string 80 llvm::StringRef firstCharacterAndDrop(llvm::StringRef &str) { 81 assert(!str.empty()); 82 llvm::StringRef firstChar = str.substr(0, 1); 83 str = str.drop_front(); 84 return firstChar; 85 } 86 87 // Get next token, consuming whitespaces and handling different token types 88 TokenInfo getNextToken() { 89 consumeWhitespace(); 90 TokenInfo result; 91 result.range.start = currentLocation(); 92 93 // Code completion case 94 if (codeCompletionLocation && codeCompletionLocation <= code.data()) { 95 result.set(TokenKind::CodeCompletion, 96 llvm::StringRef(codeCompletionLocation, 0)); 97 codeCompletionLocation = nullptr; 98 return result; 99 } 100 101 // End of file case 102 if (code.empty()) { 103 result.set(TokenKind::Eof, ""); 104 return result; 105 } 106 107 // Switch to handle specific characters 108 switch (code[0]) { 109 case '#': 110 code = code.drop_until([](char c) { return c == '\n'; }); 111 return getNextToken(); 112 case ',': 113 result.set(TokenKind::Comma, firstCharacterAndDrop(code)); 114 break; 115 case '.': 116 result.set(TokenKind::Period, firstCharacterAndDrop(code)); 117 break; 118 case '\n': 119 ++line; 120 startOfLine = code.drop_front(); 121 result.set(TokenKind::NewLine, firstCharacterAndDrop(code)); 122 break; 123 case '(': 124 result.set(TokenKind::OpenParen, firstCharacterAndDrop(code)); 125 break; 126 case ')': 127 result.set(TokenKind::CloseParen, firstCharacterAndDrop(code)); 128 break; 129 case '"': 130 case '\'': 131 consumeStringLiteral(&result); 132 break; 133 default: 134 parseIdentifierOrInvalid(&result); 135 break; 136 } 137 138 result.range.end = currentLocation(); 139 return result; 140 } 141 142 // Consume a string literal, handle escape sequences and missing closing 143 // quote. 144 void consumeStringLiteral(TokenInfo *result) { 145 bool inEscape = false; 146 const char marker = code[0]; 147 for (size_t length = 1; length < code.size(); ++length) { 148 if (inEscape) { 149 inEscape = false; 150 continue; 151 } 152 if (code[length] == '\\') { 153 inEscape = true; 154 continue; 155 } 156 if (code[length] == marker) { 157 result->kind = TokenKind::Literal; 158 result->text = code.substr(0, length + 1); 159 result->value = code.substr(1, length - 1); 160 code = code.drop_front(length + 1); 161 return; 162 } 163 } 164 llvm::StringRef errorText = code; 165 code = code.drop_front(code.size()); 166 SourceRange range; 167 range.start = result->range.start; 168 range.end = currentLocation(); 169 error->addError(range, ErrorType::ParserStringError) << errorText; 170 result->kind = TokenKind::Error; 171 } 172 173 void parseIdentifierOrInvalid(TokenInfo *result) { 174 if (isalnum(code[0])) { 175 // Parse an identifier 176 size_t tokenLength = 1; 177 178 while (true) { 179 // A code completion location in/immediately after an identifier will 180 // cause the portion of the identifier before the code completion 181 // location to become a code completion token. 182 if (codeCompletionLocation == code.data() + tokenLength) { 183 codeCompletionLocation = nullptr; 184 result->kind = TokenKind::CodeCompletion; 185 result->text = code.substr(0, tokenLength); 186 code = code.drop_front(tokenLength); 187 return; 188 } 189 if (tokenLength == code.size() || !(isalnum(code[tokenLength]))) 190 break; 191 ++tokenLength; 192 } 193 result->kind = TokenKind::Ident; 194 result->text = code.substr(0, tokenLength); 195 code = code.drop_front(tokenLength); 196 } else { 197 result->kind = TokenKind::InvalidChar; 198 result->text = code.substr(0, 1); 199 code = code.drop_front(1); 200 } 201 } 202 203 // Consume all leading whitespace from code, except newlines 204 void consumeWhitespace() { code = code.ltrim(" \t\v\f\r"); } 205 206 // Returns the current location in the source code 207 SourceLocation currentLocation() { 208 SourceLocation location; 209 location.line = line; 210 location.column = code.data() - startOfLine.data() + 1; 211 return location; 212 } 213 214 llvm::StringRef code; 215 llvm::StringRef startOfLine; 216 unsigned line = 1; 217 Diagnostics *error; 218 TokenInfo nextToken; 219 const char *codeCompletionLocation = nullptr; 220 }; 221 222 Parser::Sema::~Sema() = default; 223 224 std::vector<ArgKind> Parser::Sema::getAcceptedCompletionTypes( 225 llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) { 226 return {}; 227 } 228 229 std::vector<MatcherCompletion> 230 Parser::Sema::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) { 231 return {}; 232 } 233 234 // Entry for the scope of a parser 235 struct Parser::ScopedContextEntry { 236 Parser *parser; 237 238 ScopedContextEntry(Parser *parser, MatcherCtor c) : parser(parser) { 239 parser->contextStack.emplace_back(c, 0u); 240 } 241 242 ~ScopedContextEntry() { parser->contextStack.pop_back(); } 243 244 void nextArg() { ++parser->contextStack.back().second; } 245 }; 246 247 // Parse and validate expressions starting with an identifier. 248 // This function can parse named values and matchers. In case of failure, it 249 // will try to determine the user's intent to give an appropriate error message. 250 bool Parser::parseIdentifierPrefixImpl(VariantValue *value) { 251 const TokenInfo nameToken = tokenizer->consumeNextToken(); 252 253 if (tokenizer->nextTokenKind() != TokenKind::OpenParen) { 254 // Parse as a named value. 255 auto namedValue = 256 namedValues ? namedValues->lookup(nameToken.text) : VariantValue(); 257 258 if (!namedValue.isMatcher()) { 259 error->addError(tokenizer->peekNextToken().range, 260 ErrorType::ParserNotAMatcher); 261 return false; 262 } 263 264 if (tokenizer->nextTokenKind() == TokenKind::NewLine) { 265 error->addError(tokenizer->peekNextToken().range, 266 ErrorType::ParserNoOpenParen) 267 << "NewLine"; 268 return false; 269 } 270 271 // If the syntax is correct and the name is not a matcher either, report 272 // an unknown named value. 273 if ((tokenizer->nextTokenKind() == TokenKind::Comma || 274 tokenizer->nextTokenKind() == TokenKind::CloseParen || 275 tokenizer->nextTokenKind() == TokenKind::NewLine || 276 tokenizer->nextTokenKind() == TokenKind::Eof) && 277 !sema->lookupMatcherCtor(nameToken.text)) { 278 error->addError(nameToken.range, ErrorType::RegistryValueNotFound) 279 << nameToken.text; 280 return false; 281 } 282 // Otherwise, fallback to the matcher parser. 283 } 284 285 tokenizer->skipNewlines(); 286 287 assert(nameToken.kind == TokenKind::Ident); 288 TokenInfo openToken = tokenizer->consumeNextToken(); 289 if (openToken.kind != TokenKind::OpenParen) { 290 error->addError(openToken.range, ErrorType::ParserNoOpenParen) 291 << openToken.text; 292 return false; 293 } 294 295 std::optional<MatcherCtor> ctor = sema->lookupMatcherCtor(nameToken.text); 296 297 // Parse as a matcher expression. 298 return parseMatcherExpressionImpl(nameToken, openToken, ctor, value); 299 } 300 301 // Parse the arguments of a matcher 302 bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor, 303 const TokenInfo &nameToken, TokenInfo &endToken) { 304 ScopedContextEntry sce(this, ctor); 305 306 while (tokenizer->nextTokenKind() != TokenKind::Eof) { 307 if (tokenizer->nextTokenKind() == TokenKind::CloseParen) { 308 // end of args. 309 endToken = tokenizer->consumeNextToken(); 310 break; 311 } 312 313 if (!args.empty()) { 314 // We must find a , token to continue. 315 TokenInfo commaToken = tokenizer->consumeNextToken(); 316 if (commaToken.kind != TokenKind::Comma) { 317 error->addError(commaToken.range, ErrorType::ParserNoComma) 318 << commaToken.text; 319 return false; 320 } 321 } 322 323 ParserValue argValue; 324 tokenizer->skipNewlines(); 325 326 argValue.text = tokenizer->peekNextToken().text; 327 argValue.range = tokenizer->peekNextToken().range; 328 if (!parseExpressionImpl(&argValue.value)) { 329 return false; 330 } 331 332 tokenizer->skipNewlines(); 333 args.push_back(argValue); 334 sce.nextArg(); 335 } 336 337 return true; 338 } 339 340 // Parse and validate a matcher expression. 341 bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken, 342 const TokenInfo &openToken, 343 std::optional<MatcherCtor> ctor, 344 VariantValue *value) { 345 if (!ctor) { 346 error->addError(nameToken.range, ErrorType::RegistryMatcherNotFound) 347 << nameToken.text; 348 // Do not return here. We need to continue to give completion suggestions. 349 } 350 351 std::vector<ParserValue> args; 352 TokenInfo endToken; 353 354 tokenizer->skipNewlines(); 355 356 if (!parseMatcherArgs(args, ctor.value_or(nullptr), nameToken, endToken)) { 357 return false; 358 } 359 360 // Check for the missing closing parenthesis 361 if (endToken.kind != TokenKind::CloseParen) { 362 error->addError(openToken.range, ErrorType::ParserNoCloseParen) 363 << nameToken.text; 364 return false; 365 } 366 367 if (!ctor) 368 return false; 369 // Merge the start and end infos. 370 SourceRange matcherRange = nameToken.range; 371 matcherRange.end = endToken.range.end; 372 VariantMatcher result = 373 sema->actOnMatcherExpression(*ctor, matcherRange, args, error); 374 if (result.isNull()) 375 return false; 376 *value = result; 377 return true; 378 } 379 380 // If the prefix of this completion matches the completion token, add it to 381 // completions minus the prefix. 382 void Parser::addCompletion(const TokenInfo &compToken, 383 const MatcherCompletion &completion) { 384 if (llvm::StringRef(completion.typedText).starts_with(compToken.text)) { 385 completions.emplace_back(completion.typedText.substr(compToken.text.size()), 386 completion.matcherDecl); 387 } 388 } 389 390 std::vector<MatcherCompletion> 391 Parser::getNamedValueCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) { 392 if (!namedValues) 393 return {}; 394 395 std::vector<MatcherCompletion> result; 396 for (const auto &entry : *namedValues) { 397 std::string decl = 398 (entry.getValue().getTypeAsString() + " " + entry.getKey()).str(); 399 result.emplace_back(entry.getKey(), decl); 400 } 401 return result; 402 } 403 404 void Parser::addExpressionCompletions() { 405 const TokenInfo compToken = tokenizer->consumeNextTokenIgnoreNewlines(); 406 assert(compToken.kind == TokenKind::CodeCompletion); 407 408 // We cannot complete code if there is an invalid element on the context 409 // stack. 410 for (const auto &entry : contextStack) { 411 if (!entry.first) 412 return; 413 } 414 415 auto acceptedTypes = sema->getAcceptedCompletionTypes(contextStack); 416 for (const auto &completion : sema->getMatcherCompletions(acceptedTypes)) { 417 addCompletion(compToken, completion); 418 } 419 420 for (const auto &completion : getNamedValueCompletions(acceptedTypes)) { 421 addCompletion(compToken, completion); 422 } 423 } 424 425 // Parse an <Expresssion> 426 bool Parser::parseExpressionImpl(VariantValue *value) { 427 switch (tokenizer->nextTokenKind()) { 428 case TokenKind::Literal: 429 *value = tokenizer->consumeNextToken().value; 430 return true; 431 case TokenKind::Ident: 432 return parseIdentifierPrefixImpl(value); 433 case TokenKind::CodeCompletion: 434 addExpressionCompletions(); 435 return false; 436 case TokenKind::Eof: 437 error->addError(tokenizer->consumeNextToken().range, 438 ErrorType::ParserNoCode); 439 return false; 440 441 case TokenKind::Error: 442 // This error was already reported by the tokenizer. 443 return false; 444 case TokenKind::NewLine: 445 case TokenKind::OpenParen: 446 case TokenKind::CloseParen: 447 case TokenKind::Comma: 448 case TokenKind::Period: 449 case TokenKind::InvalidChar: 450 const TokenInfo token = tokenizer->consumeNextToken(); 451 error->addError(token.range, ErrorType::ParserInvalidToken) 452 << (token.kind == TokenKind::NewLine ? "NewLine" : token.text); 453 return false; 454 } 455 456 llvm_unreachable("Unknown token kind."); 457 } 458 459 Parser::Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry, 460 const NamedValueMap *namedValues, Diagnostics *error) 461 : tokenizer(tokenizer), 462 sema(std::make_unique<RegistrySema>(matcherRegistry)), 463 namedValues(namedValues), error(error) {} 464 465 Parser::RegistrySema::~RegistrySema() = default; 466 467 std::optional<MatcherCtor> 468 Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) { 469 return RegistryManager::lookupMatcherCtor(matcherName, matcherRegistry); 470 } 471 472 VariantMatcher Parser::RegistrySema::actOnMatcherExpression( 473 MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args, 474 Diagnostics *error) { 475 return RegistryManager::constructMatcher(ctor, nameRange, args, error); 476 } 477 478 std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes( 479 llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) { 480 return RegistryManager::getAcceptedCompletionTypes(context); 481 } 482 483 std::vector<MatcherCompletion> Parser::RegistrySema::getMatcherCompletions( 484 llvm::ArrayRef<ArgKind> acceptedTypes) { 485 return RegistryManager::getMatcherCompletions(acceptedTypes, matcherRegistry); 486 } 487 488 bool Parser::parseExpression(llvm::StringRef &code, 489 const Registry &matcherRegistry, 490 const NamedValueMap *namedValues, 491 VariantValue *value, Diagnostics *error) { 492 CodeTokenizer tokenizer(code, error); 493 Parser parser(&tokenizer, matcherRegistry, namedValues, error); 494 if (!parser.parseExpressionImpl(value)) 495 return false; 496 auto nextToken = tokenizer.peekNextToken(); 497 if (nextToken.kind != TokenKind::Eof && 498 nextToken.kind != TokenKind::NewLine) { 499 error->addError(tokenizer.peekNextToken().range, 500 ErrorType::ParserTrailingCode); 501 return false; 502 } 503 return true; 504 } 505 506 std::vector<MatcherCompletion> 507 Parser::completeExpression(llvm::StringRef &code, unsigned completionOffset, 508 const Registry &matcherRegistry, 509 const NamedValueMap *namedValues) { 510 Diagnostics error; 511 CodeTokenizer tokenizer(code, &error, completionOffset); 512 Parser parser(&tokenizer, matcherRegistry, namedValues, &error); 513 VariantValue dummy; 514 parser.parseExpressionImpl(&dummy); 515 516 return parser.completions; 517 } 518 519 std::optional<DynMatcher> Parser::parseMatcherExpression( 520 llvm::StringRef &code, const Registry &matcherRegistry, 521 const NamedValueMap *namedValues, Diagnostics *error) { 522 VariantValue value; 523 if (!parseExpression(code, matcherRegistry, namedValues, &value, error)) 524 return std::nullopt; 525 if (!value.isMatcher()) { 526 error->addError(SourceRange(), ErrorType::ParserNotAMatcher); 527 return std::nullopt; 528 } 529 std::optional<DynMatcher> result = value.getMatcher().getDynMatcher(); 530 if (!result) { 531 error->addError(SourceRange(), ErrorType::ParserOverloadedType) 532 << value.getTypeAsString(); 533 } 534 return result; 535 } 536 537 } // namespace mlir::query::matcher::internal 538