1 //===- FormatGen.cpp - Utilities for custom assembly formats ----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "FormatGen.h" 10 #include "llvm/ADT/StringSwitch.h" 11 #include "llvm/Support/SourceMgr.h" 12 #include "llvm/TableGen/Error.h" 13 14 using namespace mlir; 15 using namespace mlir::tblgen; 16 using llvm::SourceMgr; 17 18 //===----------------------------------------------------------------------===// 19 // FormatToken 20 //===----------------------------------------------------------------------===// 21 22 SMLoc FormatToken::getLoc() const { 23 return SMLoc::getFromPointer(spelling.data()); 24 } 25 26 //===----------------------------------------------------------------------===// 27 // FormatLexer 28 //===----------------------------------------------------------------------===// 29 30 FormatLexer::FormatLexer(SourceMgr &mgr, SMLoc loc) 31 : mgr(mgr), loc(loc), 32 curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()), 33 curPtr(curBuffer.begin()) {} 34 35 FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) { 36 mgr.PrintMessage(loc, SourceMgr::DK_Error, msg); 37 llvm::SrcMgr.PrintMessage(this->loc, SourceMgr::DK_Note, 38 "in custom assembly format for this operation"); 39 return formToken(FormatToken::error, loc.getPointer()); 40 } 41 42 FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) { 43 return emitError(SMLoc::getFromPointer(loc), msg); 44 } 45 46 FormatToken FormatLexer::emitErrorAndNote(SMLoc loc, const Twine &msg, 47 const Twine ¬e) { 48 mgr.PrintMessage(loc, SourceMgr::DK_Error, msg); 49 llvm::SrcMgr.PrintMessage(this->loc, SourceMgr::DK_Note, 50 "in custom assembly format for this operation"); 51 mgr.PrintMessage(loc, SourceMgr::DK_Note, note); 52 return formToken(FormatToken::error, loc.getPointer()); 53 } 54 55 int FormatLexer::getNextChar() { 56 char curChar = *curPtr++; 57 switch (curChar) { 58 default: 59 return (unsigned char)curChar; 60 case 0: { 61 // A nul character in the stream is either the end of the current buffer or 62 // a random nul in the file. Disambiguate that here. 63 if (curPtr - 1 != curBuffer.end()) 64 return 0; 65 66 // Otherwise, return end of file. 67 --curPtr; 68 return EOF; 69 } 70 case '\n': 71 case '\r': 72 // Handle the newline character by ignoring it and incrementing the line 73 // count. However, be careful about 'dos style' files with \n\r in them. 74 // Only treat a \n\r or \r\n as a single line. 75 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar) 76 ++curPtr; 77 return '\n'; 78 } 79 } 80 81 FormatToken FormatLexer::lexToken() { 82 const char *tokStart = curPtr; 83 84 // This always consumes at least one character. 85 int curChar = getNextChar(); 86 switch (curChar) { 87 default: 88 // Handle identifiers: [a-zA-Z_] 89 if (isalpha(curChar) || curChar == '_') 90 return lexIdentifier(tokStart); 91 92 // Unknown character, emit an error. 93 return emitError(tokStart, "unexpected character"); 94 case EOF: 95 // Return EOF denoting the end of lexing. 96 return formToken(FormatToken::eof, tokStart); 97 98 // Lex punctuation. 99 case '^': 100 return formToken(FormatToken::caret, tokStart); 101 case ':': 102 return formToken(FormatToken::colon, tokStart); 103 case ',': 104 return formToken(FormatToken::comma, tokStart); 105 case '=': 106 return formToken(FormatToken::equal, tokStart); 107 case '<': 108 return formToken(FormatToken::less, tokStart); 109 case '>': 110 return formToken(FormatToken::greater, tokStart); 111 case '?': 112 return formToken(FormatToken::question, tokStart); 113 case '(': 114 return formToken(FormatToken::l_paren, tokStart); 115 case ')': 116 return formToken(FormatToken::r_paren, tokStart); 117 case '*': 118 return formToken(FormatToken::star, tokStart); 119 case '|': 120 return formToken(FormatToken::pipe, tokStart); 121 122 // Ignore whitespace characters. 123 case 0: 124 case ' ': 125 case '\t': 126 case '\n': 127 return lexToken(); 128 129 case '`': 130 return lexLiteral(tokStart); 131 case '$': 132 return lexVariable(tokStart); 133 case '"': 134 return lexString(tokStart); 135 } 136 } 137 138 FormatToken FormatLexer::lexLiteral(const char *tokStart) { 139 assert(curPtr[-1] == '`'); 140 141 // Lex a literal surrounded by ``. 142 while (const char curChar = *curPtr++) { 143 if (curChar == '`') 144 return formToken(FormatToken::literal, tokStart); 145 } 146 return emitError(curPtr - 1, "unexpected end of file in literal"); 147 } 148 149 FormatToken FormatLexer::lexVariable(const char *tokStart) { 150 if (!isalpha(curPtr[0]) && curPtr[0] != '_') 151 return emitError(curPtr - 1, "expected variable name"); 152 153 // Otherwise, consume the rest of the characters. 154 while (isalnum(*curPtr) || *curPtr == '_') 155 ++curPtr; 156 return formToken(FormatToken::variable, tokStart); 157 } 158 159 FormatToken FormatLexer::lexString(const char *tokStart) { 160 // Lex until another quote, respecting escapes. 161 bool escape = false; 162 while (const char curChar = *curPtr++) { 163 if (!escape && curChar == '"') 164 return formToken(FormatToken::string, tokStart); 165 escape = curChar == '\\'; 166 } 167 return emitError(curPtr - 1, "unexpected end of file in string"); 168 } 169 170 FormatToken FormatLexer::lexIdentifier(const char *tokStart) { 171 // Match the rest of the identifier regex: [0-9a-zA-Z_\-]* 172 while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-') 173 ++curPtr; 174 175 // Check to see if this identifier is a keyword. 176 StringRef str(tokStart, curPtr - tokStart); 177 auto kind = 178 StringSwitch<FormatToken::Kind>(str) 179 .Case("attr-dict", FormatToken::kw_attr_dict) 180 .Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword) 181 .Case("prop-dict", FormatToken::kw_prop_dict) 182 .Case("custom", FormatToken::kw_custom) 183 .Case("functional-type", FormatToken::kw_functional_type) 184 .Case("oilist", FormatToken::kw_oilist) 185 .Case("operands", FormatToken::kw_operands) 186 .Case("params", FormatToken::kw_params) 187 .Case("ref", FormatToken::kw_ref) 188 .Case("regions", FormatToken::kw_regions) 189 .Case("results", FormatToken::kw_results) 190 .Case("struct", FormatToken::kw_struct) 191 .Case("successors", FormatToken::kw_successors) 192 .Case("type", FormatToken::kw_type) 193 .Case("qualified", FormatToken::kw_qualified) 194 .Default(FormatToken::identifier); 195 return FormatToken(kind, str); 196 } 197 198 //===----------------------------------------------------------------------===// 199 // FormatParser 200 //===----------------------------------------------------------------------===// 201 202 FormatElement::~FormatElement() = default; 203 204 FormatParser::~FormatParser() = default; 205 206 FailureOr<std::vector<FormatElement *>> FormatParser::parse() { 207 SMLoc loc = curToken.getLoc(); 208 209 // Parse each of the format elements into the main format. 210 std::vector<FormatElement *> elements; 211 while (curToken.getKind() != FormatToken::eof) { 212 FailureOr<FormatElement *> element = parseElement(TopLevelContext); 213 if (failed(element)) 214 return failure(); 215 elements.push_back(*element); 216 } 217 218 // Verify the format. 219 if (failed(verify(loc, elements))) 220 return failure(); 221 return elements; 222 } 223 224 //===----------------------------------------------------------------------===// 225 // Element Parsing 226 227 FailureOr<FormatElement *> FormatParser::parseElement(Context ctx) { 228 if (curToken.is(FormatToken::literal)) 229 return parseLiteral(ctx); 230 if (curToken.is(FormatToken::string)) 231 return parseString(ctx); 232 if (curToken.is(FormatToken::variable)) 233 return parseVariable(ctx); 234 if (curToken.isKeyword()) 235 return parseDirective(ctx); 236 if (curToken.is(FormatToken::l_paren)) 237 return parseOptionalGroup(ctx); 238 return emitError(curToken.getLoc(), 239 "expected literal, variable, directive, or optional group"); 240 } 241 242 FailureOr<FormatElement *> FormatParser::parseLiteral(Context ctx) { 243 FormatToken tok = curToken; 244 SMLoc loc = tok.getLoc(); 245 consumeToken(); 246 247 if (ctx != TopLevelContext) { 248 return emitError( 249 loc, 250 "literals may only be used in the top-level section of the format"); 251 } 252 // Get the spelling without the surrounding backticks. 253 StringRef value = tok.getSpelling(); 254 // Prevents things like `$arg0` or empty literals (when a literal is expected 255 // but not found) from getting segmentation faults. 256 if (value.size() < 2 || value[0] != '`' || value[value.size() - 1] != '`') 257 return emitError(tok.getLoc(), "expected literal, but got '" + value + "'"); 258 value = value.drop_front().drop_back(); 259 260 // The parsed literal is a space element (`` or ` `) or a newline. 261 if (value.empty() || value == " " || value == "\\n") 262 return create<WhitespaceElement>(value); 263 264 // Check that the parsed literal is valid. 265 if (!isValidLiteral(value, [&](Twine msg) { 266 (void)emitError(loc, "expected valid literal but got '" + value + 267 "': " + msg); 268 })) 269 return failure(); 270 return create<LiteralElement>(value); 271 } 272 273 FailureOr<FormatElement *> FormatParser::parseString(Context ctx) { 274 FormatToken tok = curToken; 275 SMLoc loc = tok.getLoc(); 276 consumeToken(); 277 278 if (ctx != CustomDirectiveContext) { 279 return emitError( 280 loc, "strings may only be used as 'custom' directive arguments"); 281 } 282 // Escape the string. 283 std::string value; 284 StringRef contents = tok.getSpelling().drop_front().drop_back(); 285 value.reserve(contents.size()); 286 bool escape = false; 287 for (char c : contents) { 288 escape = c == '\\'; 289 if (!escape) 290 value.push_back(c); 291 } 292 return create<StringElement>(std::move(value)); 293 } 294 295 FailureOr<FormatElement *> FormatParser::parseVariable(Context ctx) { 296 FormatToken tok = curToken; 297 SMLoc loc = tok.getLoc(); 298 consumeToken(); 299 300 // Get the name of the variable without the leading `$`. 301 StringRef name = tok.getSpelling().drop_front(); 302 return parseVariableImpl(loc, name, ctx); 303 } 304 305 FailureOr<FormatElement *> FormatParser::parseDirective(Context ctx) { 306 FormatToken tok = curToken; 307 SMLoc loc = tok.getLoc(); 308 consumeToken(); 309 310 if (tok.is(FormatToken::kw_custom)) 311 return parseCustomDirective(loc, ctx); 312 if (tok.is(FormatToken::kw_ref)) 313 return parseRefDirective(loc, ctx); 314 if (tok.is(FormatToken::kw_qualified)) 315 return parseQualifiedDirective(loc, ctx); 316 return parseDirectiveImpl(loc, tok.getKind(), ctx); 317 } 318 319 FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) { 320 SMLoc loc = curToken.getLoc(); 321 consumeToken(); 322 if (ctx != TopLevelContext) { 323 return emitError(loc, 324 "optional groups can only be used as top-level elements"); 325 } 326 327 // Parse the child elements for this optional group. 328 std::vector<FormatElement *> thenElements, elseElements; 329 FormatElement *anchor = nullptr; 330 auto parseChildElements = 331 [this, &anchor](std::vector<FormatElement *> &elements) -> LogicalResult { 332 do { 333 FailureOr<FormatElement *> element = parseElement(TopLevelContext); 334 if (failed(element)) 335 return failure(); 336 // Check for an anchor. 337 if (curToken.is(FormatToken::caret)) { 338 if (anchor) { 339 return emitError(curToken.getLoc(), 340 "only one element can be marked as the anchor of an " 341 "optional group"); 342 } 343 anchor = *element; 344 consumeToken(); 345 } 346 elements.push_back(*element); 347 } while (!curToken.is(FormatToken::r_paren)); 348 return success(); 349 }; 350 351 // Parse the 'then' elements. If the anchor was found in this group, then the 352 // optional is not inverted. 353 if (failed(parseChildElements(thenElements))) 354 return failure(); 355 consumeToken(); 356 bool inverted = !anchor; 357 358 // Parse the `else` elements of this optional group. 359 if (curToken.is(FormatToken::colon)) { 360 consumeToken(); 361 if (failed(parseToken( 362 FormatToken::l_paren, 363 "expected '(' to start else branch of optional group")) || 364 failed(parseChildElements(elseElements))) 365 return failure(); 366 consumeToken(); 367 } 368 if (failed(parseToken(FormatToken::question, 369 "expected '?' after optional group"))) 370 return failure(); 371 372 // The optional group is required to have an anchor. 373 if (!anchor) 374 return emitError(loc, "optional group has no anchor element"); 375 376 // Verify the child elements. 377 if (failed(verifyOptionalGroupElements(loc, thenElements, anchor)) || 378 failed(verifyOptionalGroupElements(loc, elseElements, nullptr))) 379 return failure(); 380 381 // Get the first parsable element. It must be an element that can be 382 // optionally-parsed. 383 auto isWhitespace = [](FormatElement *element) { 384 return isa<WhitespaceElement>(element); 385 }; 386 auto thenParseBegin = llvm::find_if_not(thenElements, isWhitespace); 387 auto elseParseBegin = llvm::find_if_not(elseElements, isWhitespace); 388 unsigned thenParseStart = std::distance(thenElements.begin(), thenParseBegin); 389 unsigned elseParseStart = std::distance(elseElements.begin(), elseParseBegin); 390 391 if (!isa<LiteralElement, VariableElement, CustomDirective>(*thenParseBegin)) { 392 return emitError(loc, "first parsable element of an optional group must be " 393 "a literal, variable, or custom directive"); 394 } 395 return create<OptionalElement>(std::move(thenElements), 396 std::move(elseElements), thenParseStart, 397 elseParseStart, anchor, inverted); 398 } 399 400 FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc, 401 Context ctx) { 402 if (ctx != TopLevelContext) 403 return emitError(loc, "'custom' is only valid as a top-level directive"); 404 405 FailureOr<FormatToken> nameTok; 406 if (failed(parseToken(FormatToken::less, 407 "expected '<' before custom directive name")) || 408 failed(nameTok = 409 parseToken(FormatToken::identifier, 410 "expected custom directive name identifier")) || 411 failed(parseToken(FormatToken::greater, 412 "expected '>' after custom directive name")) || 413 failed(parseToken(FormatToken::l_paren, 414 "expected '(' before custom directive parameters"))) 415 return failure(); 416 417 // Parse the arguments. 418 std::vector<FormatElement *> arguments; 419 while (true) { 420 FailureOr<FormatElement *> argument = parseElement(CustomDirectiveContext); 421 if (failed(argument)) 422 return failure(); 423 arguments.push_back(*argument); 424 if (!curToken.is(FormatToken::comma)) 425 break; 426 consumeToken(); 427 } 428 429 if (failed(parseToken(FormatToken::r_paren, 430 "expected ')' after custom directive parameters"))) 431 return failure(); 432 433 if (failed(verifyCustomDirectiveArguments(loc, arguments))) 434 return failure(); 435 return create<CustomDirective>(nameTok->getSpelling(), std::move(arguments)); 436 } 437 438 FailureOr<FormatElement *> FormatParser::parseRefDirective(SMLoc loc, 439 Context context) { 440 if (context != CustomDirectiveContext) 441 return emitError(loc, "'ref' is only valid within a `custom` directive"); 442 443 FailureOr<FormatElement *> arg; 444 if (failed(parseToken(FormatToken::l_paren, 445 "expected '(' before argument list")) || 446 failed(arg = parseElement(RefDirectiveContext)) || 447 failed( 448 parseToken(FormatToken::r_paren, "expected ')' after argument list"))) 449 return failure(); 450 451 return create<RefDirective>(*arg); 452 } 453 454 FailureOr<FormatElement *> FormatParser::parseQualifiedDirective(SMLoc loc, 455 Context ctx) { 456 if (failed(parseToken(FormatToken::l_paren, 457 "expected '(' before argument list"))) 458 return failure(); 459 FailureOr<FormatElement *> var = parseElement(ctx); 460 if (failed(var)) 461 return var; 462 if (failed(markQualified(loc, *var))) 463 return failure(); 464 if (failed( 465 parseToken(FormatToken::r_paren, "expected ')' after argument list"))) 466 return failure(); 467 return var; 468 } 469 470 //===----------------------------------------------------------------------===// 471 // Utility Functions 472 //===----------------------------------------------------------------------===// 473 474 bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value, 475 bool lastWasPunctuation) { 476 if (value.size() != 1 && value != "->") 477 return true; 478 if (lastWasPunctuation) 479 return !StringRef(">)}],").contains(value.front()); 480 return !StringRef("<>(){}[],").contains(value.front()); 481 } 482 483 bool mlir::tblgen::canFormatStringAsKeyword( 484 StringRef value, function_ref<void(Twine)> emitError) { 485 if (value.empty()) { 486 if (emitError) 487 emitError("keywords cannot be empty"); 488 return false; 489 } 490 if (!isalpha(value.front()) && value.front() != '_') { 491 if (emitError) 492 emitError("valid keyword starts with a letter or '_'"); 493 return false; 494 } 495 if (!llvm::all_of(value.drop_front(), [](char c) { 496 return isalnum(c) || c == '_' || c == '$' || c == '.'; 497 })) { 498 if (emitError) 499 emitError( 500 "keywords should contain only alphanum, '_', '$', or '.' characters"); 501 return false; 502 } 503 return true; 504 } 505 506 bool mlir::tblgen::isValidLiteral(StringRef value, 507 function_ref<void(Twine)> emitError) { 508 if (value.empty()) { 509 if (emitError) 510 emitError("literal can't be empty"); 511 return false; 512 } 513 char front = value.front(); 514 515 // If there is only one character, this must either be punctuation or a 516 // single character bare identifier. 517 if (value.size() == 1) { 518 StringRef bare = "_:,=<>()[]{}?+*"; 519 if (isalpha(front) || bare.contains(front)) 520 return true; 521 if (emitError) 522 emitError("single character literal must be a letter or one of '" + bare + 523 "'"); 524 return false; 525 } 526 // Check the punctuation that are larger than a single character. 527 if (value == "->") 528 return true; 529 if (value == "...") 530 return true; 531 532 // Otherwise, this must be an identifier. 533 return canFormatStringAsKeyword(value, emitError); 534 } 535 536 //===----------------------------------------------------------------------===// 537 // Commandline Options 538 //===----------------------------------------------------------------------===// 539 540 llvm::cl::opt<bool> mlir::tblgen::formatErrorIsFatal( 541 "asmformat-error-is-fatal", 542 llvm::cl::desc("Emit a fatal error if format parsing fails"), 543 llvm::cl::init(true)); 544