xref: /llvm-project/mlir/tools/mlir-tblgen/FormatGen.cpp (revision bccd37f69fdc7b5cd00d9231cabbe74bfe38f598)
1 //===- FormatGen.cpp - Utilities for custom assembly formats ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "FormatGen.h"
10 #include "llvm/ADT/StringSwitch.h"
11 #include "llvm/Support/SourceMgr.h"
12 #include "llvm/TableGen/Error.h"
13 
14 using namespace mlir;
15 using namespace mlir::tblgen;
16 using llvm::SourceMgr;
17 
18 //===----------------------------------------------------------------------===//
19 // FormatToken
20 //===----------------------------------------------------------------------===//
21 
22 SMLoc FormatToken::getLoc() const {
23   return SMLoc::getFromPointer(spelling.data());
24 }
25 
26 //===----------------------------------------------------------------------===//
27 // FormatLexer
28 //===----------------------------------------------------------------------===//
29 
30 FormatLexer::FormatLexer(SourceMgr &mgr, SMLoc loc)
31     : mgr(mgr), loc(loc),
32       curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
33       curPtr(curBuffer.begin()) {}
34 
35 FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) {
36   mgr.PrintMessage(loc, SourceMgr::DK_Error, msg);
37   llvm::SrcMgr.PrintMessage(this->loc, SourceMgr::DK_Note,
38                             "in custom assembly format for this operation");
39   return formToken(FormatToken::error, loc.getPointer());
40 }
41 
42 FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) {
43   return emitError(SMLoc::getFromPointer(loc), msg);
44 }
45 
46 FormatToken FormatLexer::emitErrorAndNote(SMLoc loc, const Twine &msg,
47                                           const Twine &note) {
48   mgr.PrintMessage(loc, SourceMgr::DK_Error, msg);
49   llvm::SrcMgr.PrintMessage(this->loc, SourceMgr::DK_Note,
50                             "in custom assembly format for this operation");
51   mgr.PrintMessage(loc, SourceMgr::DK_Note, note);
52   return formToken(FormatToken::error, loc.getPointer());
53 }
54 
55 int FormatLexer::getNextChar() {
56   char curChar = *curPtr++;
57   switch (curChar) {
58   default:
59     return (unsigned char)curChar;
60   case 0: {
61     // A nul character in the stream is either the end of the current buffer or
62     // a random nul in the file. Disambiguate that here.
63     if (curPtr - 1 != curBuffer.end())
64       return 0;
65 
66     // Otherwise, return end of file.
67     --curPtr;
68     return EOF;
69   }
70   case '\n':
71   case '\r':
72     // Handle the newline character by ignoring it and incrementing the line
73     // count. However, be careful about 'dos style' files with \n\r in them.
74     // Only treat a \n\r or \r\n as a single line.
75     if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
76       ++curPtr;
77     return '\n';
78   }
79 }
80 
81 FormatToken FormatLexer::lexToken() {
82   const char *tokStart = curPtr;
83 
84   // This always consumes at least one character.
85   int curChar = getNextChar();
86   switch (curChar) {
87   default:
88     // Handle identifiers: [a-zA-Z_]
89     if (isalpha(curChar) || curChar == '_')
90       return lexIdentifier(tokStart);
91 
92     // Unknown character, emit an error.
93     return emitError(tokStart, "unexpected character");
94   case EOF:
95     // Return EOF denoting the end of lexing.
96     return formToken(FormatToken::eof, tokStart);
97 
98   // Lex punctuation.
99   case '^':
100     return formToken(FormatToken::caret, tokStart);
101   case ':':
102     return formToken(FormatToken::colon, tokStart);
103   case ',':
104     return formToken(FormatToken::comma, tokStart);
105   case '=':
106     return formToken(FormatToken::equal, tokStart);
107   case '<':
108     return formToken(FormatToken::less, tokStart);
109   case '>':
110     return formToken(FormatToken::greater, tokStart);
111   case '?':
112     return formToken(FormatToken::question, tokStart);
113   case '(':
114     return formToken(FormatToken::l_paren, tokStart);
115   case ')':
116     return formToken(FormatToken::r_paren, tokStart);
117   case '*':
118     return formToken(FormatToken::star, tokStart);
119   case '|':
120     return formToken(FormatToken::pipe, tokStart);
121 
122   // Ignore whitespace characters.
123   case 0:
124   case ' ':
125   case '\t':
126   case '\n':
127     return lexToken();
128 
129   case '`':
130     return lexLiteral(tokStart);
131   case '$':
132     return lexVariable(tokStart);
133   case '"':
134     return lexString(tokStart);
135   }
136 }
137 
138 FormatToken FormatLexer::lexLiteral(const char *tokStart) {
139   assert(curPtr[-1] == '`');
140 
141   // Lex a literal surrounded by ``.
142   while (const char curChar = *curPtr++) {
143     if (curChar == '`')
144       return formToken(FormatToken::literal, tokStart);
145   }
146   return emitError(curPtr - 1, "unexpected end of file in literal");
147 }
148 
149 FormatToken FormatLexer::lexVariable(const char *tokStart) {
150   if (!isalpha(curPtr[0]) && curPtr[0] != '_')
151     return emitError(curPtr - 1, "expected variable name");
152 
153   // Otherwise, consume the rest of the characters.
154   while (isalnum(*curPtr) || *curPtr == '_')
155     ++curPtr;
156   return formToken(FormatToken::variable, tokStart);
157 }
158 
159 FormatToken FormatLexer::lexString(const char *tokStart) {
160   // Lex until another quote, respecting escapes.
161   bool escape = false;
162   while (const char curChar = *curPtr++) {
163     if (!escape && curChar == '"')
164       return formToken(FormatToken::string, tokStart);
165     escape = curChar == '\\';
166   }
167   return emitError(curPtr - 1, "unexpected end of file in string");
168 }
169 
170 FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
171   // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
172   while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
173     ++curPtr;
174 
175   // Check to see if this identifier is a keyword.
176   StringRef str(tokStart, curPtr - tokStart);
177   auto kind =
178       StringSwitch<FormatToken::Kind>(str)
179           .Case("attr-dict", FormatToken::kw_attr_dict)
180           .Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword)
181           .Case("prop-dict", FormatToken::kw_prop_dict)
182           .Case("custom", FormatToken::kw_custom)
183           .Case("functional-type", FormatToken::kw_functional_type)
184           .Case("oilist", FormatToken::kw_oilist)
185           .Case("operands", FormatToken::kw_operands)
186           .Case("params", FormatToken::kw_params)
187           .Case("ref", FormatToken::kw_ref)
188           .Case("regions", FormatToken::kw_regions)
189           .Case("results", FormatToken::kw_results)
190           .Case("struct", FormatToken::kw_struct)
191           .Case("successors", FormatToken::kw_successors)
192           .Case("type", FormatToken::kw_type)
193           .Case("qualified", FormatToken::kw_qualified)
194           .Default(FormatToken::identifier);
195   return FormatToken(kind, str);
196 }
197 
198 //===----------------------------------------------------------------------===//
199 // FormatParser
200 //===----------------------------------------------------------------------===//
201 
// Out-of-line defaulted destructor for FormatElement.
// NOTE(review): presumably defined here rather than in the header to give the
// class a single home translation unit -- confirm against FormatGen.h.
FormatElement::~FormatElement() = default;

// Out-of-line defaulted destructor for FormatParser.
FormatParser::~FormatParser() = default;
205 
206 FailureOr<std::vector<FormatElement *>> FormatParser::parse() {
207   SMLoc loc = curToken.getLoc();
208 
209   // Parse each of the format elements into the main format.
210   std::vector<FormatElement *> elements;
211   while (curToken.getKind() != FormatToken::eof) {
212     FailureOr<FormatElement *> element = parseElement(TopLevelContext);
213     if (failed(element))
214       return failure();
215     elements.push_back(*element);
216   }
217 
218   // Verify the format.
219   if (failed(verify(loc, elements)))
220     return failure();
221   return elements;
222 }
223 
224 //===----------------------------------------------------------------------===//
225 // Element Parsing
226 
227 FailureOr<FormatElement *> FormatParser::parseElement(Context ctx) {
228   if (curToken.is(FormatToken::literal))
229     return parseLiteral(ctx);
230   if (curToken.is(FormatToken::string))
231     return parseString(ctx);
232   if (curToken.is(FormatToken::variable))
233     return parseVariable(ctx);
234   if (curToken.isKeyword())
235     return parseDirective(ctx);
236   if (curToken.is(FormatToken::l_paren))
237     return parseOptionalGroup(ctx);
238   return emitError(curToken.getLoc(),
239                    "expected literal, variable, directive, or optional group");
240 }
241 
242 FailureOr<FormatElement *> FormatParser::parseLiteral(Context ctx) {
243   FormatToken tok = curToken;
244   SMLoc loc = tok.getLoc();
245   consumeToken();
246 
247   if (ctx != TopLevelContext) {
248     return emitError(
249         loc,
250         "literals may only be used in the top-level section of the format");
251   }
252   // Get the spelling without the surrounding backticks.
253   StringRef value = tok.getSpelling();
254   // Prevents things like `$arg0` or empty literals (when a literal is expected
255   // but not found) from getting segmentation faults.
256   if (value.size() < 2 || value[0] != '`' || value[value.size() - 1] != '`')
257     return emitError(tok.getLoc(), "expected literal, but got '" + value + "'");
258   value = value.drop_front().drop_back();
259 
260   // The parsed literal is a space element (`` or ` `) or a newline.
261   if (value.empty() || value == " " || value == "\\n")
262     return create<WhitespaceElement>(value);
263 
264   // Check that the parsed literal is valid.
265   if (!isValidLiteral(value, [&](Twine msg) {
266         (void)emitError(loc, "expected valid literal but got '" + value +
267                                  "': " + msg);
268       }))
269     return failure();
270   return create<LiteralElement>(value);
271 }
272 
273 FailureOr<FormatElement *> FormatParser::parseString(Context ctx) {
274   FormatToken tok = curToken;
275   SMLoc loc = tok.getLoc();
276   consumeToken();
277 
278   if (ctx != CustomDirectiveContext) {
279     return emitError(
280         loc, "strings may only be used as 'custom' directive arguments");
281   }
282   // Escape the string.
283   std::string value;
284   StringRef contents = tok.getSpelling().drop_front().drop_back();
285   value.reserve(contents.size());
286   bool escape = false;
287   for (char c : contents) {
288     escape = c == '\\';
289     if (!escape)
290       value.push_back(c);
291   }
292   return create<StringElement>(std::move(value));
293 }
294 
295 FailureOr<FormatElement *> FormatParser::parseVariable(Context ctx) {
296   FormatToken tok = curToken;
297   SMLoc loc = tok.getLoc();
298   consumeToken();
299 
300   // Get the name of the variable without the leading `$`.
301   StringRef name = tok.getSpelling().drop_front();
302   return parseVariableImpl(loc, name, ctx);
303 }
304 
305 FailureOr<FormatElement *> FormatParser::parseDirective(Context ctx) {
306   FormatToken tok = curToken;
307   SMLoc loc = tok.getLoc();
308   consumeToken();
309 
310   if (tok.is(FormatToken::kw_custom))
311     return parseCustomDirective(loc, ctx);
312   if (tok.is(FormatToken::kw_ref))
313     return parseRefDirective(loc, ctx);
314   if (tok.is(FormatToken::kw_qualified))
315     return parseQualifiedDirective(loc, ctx);
316   return parseDirectiveImpl(loc, tok.getKind(), ctx);
317 }
318 
319 FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {
320   SMLoc loc = curToken.getLoc();
321   consumeToken();
322   if (ctx != TopLevelContext) {
323     return emitError(loc,
324                      "optional groups can only be used as top-level elements");
325   }
326 
327   // Parse the child elements for this optional group.
328   std::vector<FormatElement *> thenElements, elseElements;
329   FormatElement *anchor = nullptr;
330   auto parseChildElements =
331       [this, &anchor](std::vector<FormatElement *> &elements) -> LogicalResult {
332     do {
333       FailureOr<FormatElement *> element = parseElement(TopLevelContext);
334       if (failed(element))
335         return failure();
336       // Check for an anchor.
337       if (curToken.is(FormatToken::caret)) {
338         if (anchor) {
339           return emitError(curToken.getLoc(),
340                            "only one element can be marked as the anchor of an "
341                            "optional group");
342         }
343         anchor = *element;
344         consumeToken();
345       }
346       elements.push_back(*element);
347     } while (!curToken.is(FormatToken::r_paren));
348     return success();
349   };
350 
351   // Parse the 'then' elements. If the anchor was found in this group, then the
352   // optional is not inverted.
353   if (failed(parseChildElements(thenElements)))
354     return failure();
355   consumeToken();
356   bool inverted = !anchor;
357 
358   // Parse the `else` elements of this optional group.
359   if (curToken.is(FormatToken::colon)) {
360     consumeToken();
361     if (failed(parseToken(
362             FormatToken::l_paren,
363             "expected '(' to start else branch of optional group")) ||
364         failed(parseChildElements(elseElements)))
365       return failure();
366     consumeToken();
367   }
368   if (failed(parseToken(FormatToken::question,
369                         "expected '?' after optional group")))
370     return failure();
371 
372   // The optional group is required to have an anchor.
373   if (!anchor)
374     return emitError(loc, "optional group has no anchor element");
375 
376   // Verify the child elements.
377   if (failed(verifyOptionalGroupElements(loc, thenElements, anchor)) ||
378       failed(verifyOptionalGroupElements(loc, elseElements, nullptr)))
379     return failure();
380 
381   // Get the first parsable element. It must be an element that can be
382   // optionally-parsed.
383   auto isWhitespace = [](FormatElement *element) {
384     return isa<WhitespaceElement>(element);
385   };
386   auto thenParseBegin = llvm::find_if_not(thenElements, isWhitespace);
387   auto elseParseBegin = llvm::find_if_not(elseElements, isWhitespace);
388   unsigned thenParseStart = std::distance(thenElements.begin(), thenParseBegin);
389   unsigned elseParseStart = std::distance(elseElements.begin(), elseParseBegin);
390 
391   if (!isa<LiteralElement, VariableElement, CustomDirective>(*thenParseBegin)) {
392     return emitError(loc, "first parsable element of an optional group must be "
393                           "a literal, variable, or custom directive");
394   }
395   return create<OptionalElement>(std::move(thenElements),
396                                  std::move(elseElements), thenParseStart,
397                                  elseParseStart, anchor, inverted);
398 }
399 
400 FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
401                                                               Context ctx) {
402   if (ctx != TopLevelContext)
403     return emitError(loc, "'custom' is only valid as a top-level directive");
404 
405   FailureOr<FormatToken> nameTok;
406   if (failed(parseToken(FormatToken::less,
407                         "expected '<' before custom directive name")) ||
408       failed(nameTok =
409                  parseToken(FormatToken::identifier,
410                             "expected custom directive name identifier")) ||
411       failed(parseToken(FormatToken::greater,
412                         "expected '>' after custom directive name")) ||
413       failed(parseToken(FormatToken::l_paren,
414                         "expected '(' before custom directive parameters")))
415     return failure();
416 
417   // Parse the arguments.
418   std::vector<FormatElement *> arguments;
419   while (true) {
420     FailureOr<FormatElement *> argument = parseElement(CustomDirectiveContext);
421     if (failed(argument))
422       return failure();
423     arguments.push_back(*argument);
424     if (!curToken.is(FormatToken::comma))
425       break;
426     consumeToken();
427   }
428 
429   if (failed(parseToken(FormatToken::r_paren,
430                         "expected ')' after custom directive parameters")))
431     return failure();
432 
433   if (failed(verifyCustomDirectiveArguments(loc, arguments)))
434     return failure();
435   return create<CustomDirective>(nameTok->getSpelling(), std::move(arguments));
436 }
437 
438 FailureOr<FormatElement *> FormatParser::parseRefDirective(SMLoc loc,
439                                                            Context context) {
440   if (context != CustomDirectiveContext)
441     return emitError(loc, "'ref' is only valid within a `custom` directive");
442 
443   FailureOr<FormatElement *> arg;
444   if (failed(parseToken(FormatToken::l_paren,
445                         "expected '(' before argument list")) ||
446       failed(arg = parseElement(RefDirectiveContext)) ||
447       failed(
448           parseToken(FormatToken::r_paren, "expected ')' after argument list")))
449     return failure();
450 
451   return create<RefDirective>(*arg);
452 }
453 
454 FailureOr<FormatElement *> FormatParser::parseQualifiedDirective(SMLoc loc,
455                                                                  Context ctx) {
456   if (failed(parseToken(FormatToken::l_paren,
457                         "expected '(' before argument list")))
458     return failure();
459   FailureOr<FormatElement *> var = parseElement(ctx);
460   if (failed(var))
461     return var;
462   if (failed(markQualified(loc, *var)))
463     return failure();
464   if (failed(
465           parseToken(FormatToken::r_paren, "expected ')' after argument list")))
466     return failure();
467   return var;
468 }
469 
470 //===----------------------------------------------------------------------===//
471 // Utility Functions
472 //===----------------------------------------------------------------------===//
473 
474 bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value,
475                                          bool lastWasPunctuation) {
476   if (value.size() != 1 && value != "->")
477     return true;
478   if (lastWasPunctuation)
479     return !StringRef(">)}],").contains(value.front());
480   return !StringRef("<>(){}[],").contains(value.front());
481 }
482 
483 bool mlir::tblgen::canFormatStringAsKeyword(
484     StringRef value, function_ref<void(Twine)> emitError) {
485   if (value.empty()) {
486     if (emitError)
487       emitError("keywords cannot be empty");
488     return false;
489   }
490   if (!isalpha(value.front()) && value.front() != '_') {
491     if (emitError)
492       emitError("valid keyword starts with a letter or '_'");
493     return false;
494   }
495   if (!llvm::all_of(value.drop_front(), [](char c) {
496         return isalnum(c) || c == '_' || c == '$' || c == '.';
497       })) {
498     if (emitError)
499       emitError(
500           "keywords should contain only alphanum, '_', '$', or '.' characters");
501     return false;
502   }
503   return true;
504 }
505 
506 bool mlir::tblgen::isValidLiteral(StringRef value,
507                                   function_ref<void(Twine)> emitError) {
508   if (value.empty()) {
509     if (emitError)
510       emitError("literal can't be empty");
511     return false;
512   }
513   char front = value.front();
514 
515   // If there is only one character, this must either be punctuation or a
516   // single character bare identifier.
517   if (value.size() == 1) {
518     StringRef bare = "_:,=<>()[]{}?+*";
519     if (isalpha(front) || bare.contains(front))
520       return true;
521     if (emitError)
522       emitError("single character literal must be a letter or one of '" + bare +
523                 "'");
524     return false;
525   }
526   // Check the punctuation that are larger than a single character.
527   if (value == "->")
528     return true;
529   if (value == "...")
530     return true;
531 
532   // Otherwise, this must be an identifier.
533   return canFormatStringAsKeyword(value, emitError);
534 }
535 
536 //===----------------------------------------------------------------------===//
537 // Commandline Options
538 //===----------------------------------------------------------------------===//
539 
// Definition of the boolean command-line option `asmformat-error-is-fatal`,
// which is initialized to true.
// NOTE(review): the code that reads this flag is not in this file --
// presumably the format generators consult it when parsing fails; confirm.
llvm::cl::opt<bool> mlir::tblgen::formatErrorIsFatal(
    "asmformat-error-is-fatal",
    llvm::cl::desc("Emit a fatal error if format parsing fails"),
    llvm::cl::init(true));
544