xref: /llvm-project/mlir/lib/AsmParser/AsmParserImpl.h (revision 4548bff0e8139d4f375f1078dd50a74116eae0a2)
1 //===- AsmParserImpl.h - MLIR AsmParserImpl Class ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_LIB_ASMPARSER_ASMPARSERIMPL_H
10 #define MLIR_LIB_ASMPARSER_ASMPARSERIMPL_H
11 
12 #include "Parser.h"
13 #include "mlir/AsmParser/AsmParserState.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "llvm/Support/Base64.h"
17 #include <optional>
18 
19 namespace mlir {
20 namespace detail {
21 //===----------------------------------------------------------------------===//
22 // AsmParserImpl
23 //===----------------------------------------------------------------------===//
24 
25 /// This class provides the implementation of the generic parser methods within
26 /// AsmParser.
27 template <typename BaseT>
28 class AsmParserImpl : public BaseT {
29 public:
30   AsmParserImpl(SMLoc nameLoc, Parser &parser)
31       : nameLoc(nameLoc), parser(parser) {}
32   ~AsmParserImpl() override = default;
33 
34   /// Return the location of the original name token.
35   SMLoc getNameLoc() const override { return nameLoc; }
36 
37   //===--------------------------------------------------------------------===//
38   // Utilities
39   //===--------------------------------------------------------------------===//
40 
41   /// Return if any errors were emitted during parsing.
42   bool didEmitError() const { return emittedError; }
43 
44   /// Emit a diagnostic at the specified location and return failure.
45   InFlightDiagnostic emitError(SMLoc loc, const Twine &message) override {
46     emittedError = true;
47     return parser.emitError(loc, message);
48   }
49 
50   /// Return a builder which provides useful access to MLIRContext, global
51   /// objects like types and attributes.
52   Builder &getBuilder() const override { return parser.builder; }
53 
54   /// Get the location of the next token and store it into the argument.  This
55   /// always succeeds.
56   SMLoc getCurrentLocation() override { return parser.getToken().getLoc(); }
57 
58   /// Re-encode the given source location as an MLIR location and return it.
59   Location getEncodedSourceLoc(SMLoc loc) override {
60     return parser.getEncodedSourceLocation(loc);
61   }
62 
63   //===--------------------------------------------------------------------===//
64   // Token Parsing
65   //===--------------------------------------------------------------------===//
66 
67   using Delimiter = AsmParser::Delimiter;
68 
69   /// Parse a `->` token.
70   ParseResult parseArrow() override {
71     return parser.parseToken(Token::arrow, "expected '->'");
72   }
73 
74   /// Parses a `->` if present.
75   ParseResult parseOptionalArrow() override {
76     return success(parser.consumeIf(Token::arrow));
77   }
78 
79   /// Parse a '{' token.
80   ParseResult parseLBrace() override {
81     return parser.parseToken(Token::l_brace, "expected '{'");
82   }
83 
84   /// Parse a '{' token if present
85   ParseResult parseOptionalLBrace() override {
86     return success(parser.consumeIf(Token::l_brace));
87   }
88 
89   /// Parse a `}` token.
90   ParseResult parseRBrace() override {
91     return parser.parseToken(Token::r_brace, "expected '}'");
92   }
93 
94   /// Parse a `}` token if present
95   ParseResult parseOptionalRBrace() override {
96     return success(parser.consumeIf(Token::r_brace));
97   }
98 
99   /// Parse a `:` token.
100   ParseResult parseColon() override {
101     return parser.parseToken(Token::colon, "expected ':'");
102   }
103 
104   /// Parse a `:` token if present.
105   ParseResult parseOptionalColon() override {
106     return success(parser.consumeIf(Token::colon));
107   }
108 
109   /// Parse a `,` token.
110   ParseResult parseComma() override {
111     return parser.parseToken(Token::comma, "expected ','");
112   }
113 
114   /// Parse a `,` token if present.
115   ParseResult parseOptionalComma() override {
116     return success(parser.consumeIf(Token::comma));
117   }
118 
119   /// Parses a `...`.
120   ParseResult parseEllipsis() override {
121     return parser.parseToken(Token::ellipsis, "expected '...'");
122   }
123 
124   /// Parses a `...` if present.
125   ParseResult parseOptionalEllipsis() override {
126     return success(parser.consumeIf(Token::ellipsis));
127   }
128 
129   /// Parse a `=` token.
130   ParseResult parseEqual() override {
131     return parser.parseToken(Token::equal, "expected '='");
132   }
133 
134   /// Parse a `=` token if present.
135   ParseResult parseOptionalEqual() override {
136     return success(parser.consumeIf(Token::equal));
137   }
138 
139   /// Parse a '<' token.
140   ParseResult parseLess() override {
141     return parser.parseToken(Token::less, "expected '<'");
142   }
143 
144   /// Parse a `<` token if present.
145   ParseResult parseOptionalLess() override {
146     return success(parser.consumeIf(Token::less));
147   }
148 
149   /// Parse a '>' token.
150   ParseResult parseGreater() override {
151     return parser.parseToken(Token::greater, "expected '>'");
152   }
153 
154   /// Parse a `>` token if present.
155   ParseResult parseOptionalGreater() override {
156     return success(parser.consumeIf(Token::greater));
157   }
158 
159   /// Parse a `(` token.
160   ParseResult parseLParen() override {
161     return parser.parseToken(Token::l_paren, "expected '('");
162   }
163 
164   /// Parses a '(' if present.
165   ParseResult parseOptionalLParen() override {
166     return success(parser.consumeIf(Token::l_paren));
167   }
168 
169   /// Parse a `)` token.
170   ParseResult parseRParen() override {
171     return parser.parseToken(Token::r_paren, "expected ')'");
172   }
173 
174   /// Parses a ')' if present.
175   ParseResult parseOptionalRParen() override {
176     return success(parser.consumeIf(Token::r_paren));
177   }
178 
179   /// Parse a `[` token.
180   ParseResult parseLSquare() override {
181     return parser.parseToken(Token::l_square, "expected '['");
182   }
183 
184   /// Parses a '[' if present.
185   ParseResult parseOptionalLSquare() override {
186     return success(parser.consumeIf(Token::l_square));
187   }
188 
189   /// Parse a `]` token.
190   ParseResult parseRSquare() override {
191     return parser.parseToken(Token::r_square, "expected ']'");
192   }
193 
194   /// Parses a ']' if present.
195   ParseResult parseOptionalRSquare() override {
196     return success(parser.consumeIf(Token::r_square));
197   }
198 
199   /// Parses a '?' token.
200   ParseResult parseQuestion() override {
201     return parser.parseToken(Token::question, "expected '?'");
202   }
203 
204   /// Parses a '?' if present.
205   ParseResult parseOptionalQuestion() override {
206     return success(parser.consumeIf(Token::question));
207   }
208 
209   /// Parses a '*' token.
210   ParseResult parseStar() override {
211     return parser.parseToken(Token::star, "expected '*'");
212   }
213 
214   /// Parses a '*' if present.
215   ParseResult parseOptionalStar() override {
216     return success(parser.consumeIf(Token::star));
217   }
218 
219   /// Parses a '+' token.
220   ParseResult parsePlus() override {
221     return parser.parseToken(Token::plus, "expected '+'");
222   }
223 
224   /// Parses a '+' token if present.
225   ParseResult parseOptionalPlus() override {
226     return success(parser.consumeIf(Token::plus));
227   }
228 
229   /// Parses a '-' token.
230   ParseResult parseMinus() override {
231     return parser.parseToken(Token::minus, "expected '-'");
232   }
233 
234   /// Parses a '-' token if present.
235   ParseResult parseOptionalMinus() override {
236     return success(parser.consumeIf(Token::minus));
237   }
238 
239   /// Parse a '|' token.
240   ParseResult parseVerticalBar() override {
241     return parser.parseToken(Token::vertical_bar, "expected '|'");
242   }
243 
244   /// Parse a '|' token if present.
245   ParseResult parseOptionalVerticalBar() override {
246     return success(parser.consumeIf(Token::vertical_bar));
247   }
248 
249   /// Parses a quoted string token if present.
250   ParseResult parseOptionalString(std::string *string) override {
251     if (!parser.getToken().is(Token::string))
252       return failure();
253 
254     if (string)
255       *string = parser.getToken().getStringValue();
256     parser.consumeToken();
257     return success();
258   }
259 
260   /// Parses a Base64 encoded string of bytes.
261   ParseResult parseBase64Bytes(std::vector<char> *bytes) override {
262     auto loc = getCurrentLocation();
263     if (!parser.getToken().is(Token::string))
264       return emitError(loc, "expected string");
265 
266     if (bytes) {
267       // decodeBase64 doesn't modify its input so we can use the token spelling
268       // and just slice off the quotes/whitespaces if there are any. Whitespace
269       // and quotes cannot appear as part of a (standard) base64 encoded string,
270       // so this is safe to do.
271       StringRef b64QuotedString = parser.getTokenSpelling();
272       StringRef b64String =
273           b64QuotedString.ltrim("\"  \t\n\v\f\r").rtrim("\" \t\n\v\f\r");
274       if (auto err = llvm::decodeBase64(b64String, *bytes))
275         return emitError(loc, toString(std::move(err)));
276     }
277 
278     parser.consumeToken();
279     return success();
280   }
281 
282   /// Parse a floating point value with given semantics from the stream. Since
283   /// this implementation parses the string as double precision and only
284   /// afterwards converts the value to the requested semantic, precision may be
285   /// lost.
286   ParseResult parseFloat(const llvm::fltSemantics &semantics,
287                          APFloat &result) override {
288     bool isNegative = parser.consumeIf(Token::minus);
289     Token curTok = parser.getToken();
290     std::optional<APFloat> apResult;
291     if (failed(parser.parseFloatFromLiteral(apResult, curTok, isNegative,
292                                             semantics)))
293       return failure();
294     parser.consumeToken();
295     result = *apResult;
296     return success();
297   }
298 
299   /// Parse a floating point value from the stream.
300   ParseResult parseFloat(double &result) override {
301     llvm::APFloat apResult(0.0);
302     if (parseFloat(APFloat::IEEEdouble(), apResult))
303       return failure();
304 
305     result = apResult.convertToDouble();
306     return success();
307   }
308 
309   /// Parse an optional integer value from the stream.
310   OptionalParseResult parseOptionalInteger(APInt &result) override {
311     return parser.parseOptionalInteger(result);
312   }
313 
314   /// Parse an optional integer value from the stream.
315   OptionalParseResult parseOptionalDecimalInteger(APInt &result) override {
316     return parser.parseOptionalDecimalInteger(result);
317   }
318 
319   /// Parse a list of comma-separated items with an optional delimiter.  If a
320   /// delimiter is provided, then an empty list is allowed.  If not, then at
321   /// least one element will be parsed.
322   ParseResult parseCommaSeparatedList(Delimiter delimiter,
323                                       function_ref<ParseResult()> parseElt,
324                                       StringRef contextMessage) override {
325     return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
326   }
327 
328   //===--------------------------------------------------------------------===//
329   // Keyword Parsing
330   //===--------------------------------------------------------------------===//
331 
332   ParseResult parseKeyword(StringRef keyword, const Twine &msg) override {
333     if (parser.getToken().isCodeCompletion())
334       return parser.codeCompleteExpectedTokens(keyword);
335 
336     auto loc = getCurrentLocation();
337     if (parseOptionalKeyword(keyword))
338       return emitError(loc, "expected '") << keyword << "'" << msg;
339     return success();
340   }
341   using AsmParser::parseKeyword;
342 
343   /// Parse the given keyword if present.
344   ParseResult parseOptionalKeyword(StringRef keyword) override {
345     if (parser.getToken().isCodeCompletion())
346       return parser.codeCompleteOptionalTokens(keyword);
347 
348     // Check that the current token has the same spelling.
349     if (!parser.isCurrentTokenAKeyword() ||
350         parser.getTokenSpelling() != keyword)
351       return failure();
352     parser.consumeToken();
353     return success();
354   }
355 
356   /// Parse a keyword, if present, into 'keyword'.
357   ParseResult parseOptionalKeyword(StringRef *keyword) override {
358     // Check that the current token is a keyword.
359     if (!parser.isCurrentTokenAKeyword())
360       return failure();
361 
362     *keyword = parser.getTokenSpelling();
363     parser.consumeToken();
364     return success();
365   }
366 
367   /// Parse a keyword if it is one of the 'allowedKeywords'.
368   ParseResult
369   parseOptionalKeyword(StringRef *keyword,
370                        ArrayRef<StringRef> allowedKeywords) override {
371     if (parser.getToken().isCodeCompletion())
372       return parser.codeCompleteOptionalTokens(allowedKeywords);
373 
374     // Check that the current token is a keyword.
375     if (!parser.isCurrentTokenAKeyword())
376       return failure();
377 
378     StringRef currentKeyword = parser.getTokenSpelling();
379     if (llvm::is_contained(allowedKeywords, currentKeyword)) {
380       *keyword = currentKeyword;
381       parser.consumeToken();
382       return success();
383     }
384 
385     return failure();
386   }
387 
388   /// Parse an optional keyword or string and set instance into 'result'.`
389   ParseResult parseOptionalKeywordOrString(std::string *result) override {
390     StringRef keyword;
391     if (succeeded(parseOptionalKeyword(&keyword))) {
392       *result = keyword.str();
393       return success();
394     }
395 
396     return parseOptionalString(result);
397   }
398 
399   //===--------------------------------------------------------------------===//
400   // Attribute Parsing
401   //===--------------------------------------------------------------------===//
402 
403   /// Parse an arbitrary attribute and return it in result.
404   ParseResult parseAttribute(Attribute &result, Type type) override {
405     result = parser.parseAttribute(type);
406     return success(static_cast<bool>(result));
407   }
408 
409   /// Parse a custom attribute with the provided callback, unless the next
410   /// token is `#`, in which case the generic parser is invoked.
411   ParseResult parseCustomAttributeWithFallback(
412       Attribute &result, Type type,
413       function_ref<ParseResult(Attribute &result, Type type)> parseAttribute)
414       override {
415     if (parser.getToken().isNot(Token::hash_identifier))
416       return parseAttribute(result, type);
417     result = parser.parseAttribute(type);
418     return success(static_cast<bool>(result));
419   }
420 
421   /// Parse a custom attribute with the provided callback, unless the next
422   /// token is `#`, in which case the generic parser is invoked.
423   ParseResult parseCustomTypeWithFallback(
424       Type &result,
425       function_ref<ParseResult(Type &result)> parseType) override {
426     if (parser.getToken().isNot(Token::exclamation_identifier))
427       return parseType(result);
428     result = parser.parseType();
429     return success(static_cast<bool>(result));
430   }
431 
432   OptionalParseResult parseOptionalAttribute(Attribute &result,
433                                              Type type) override {
434     return parser.parseOptionalAttribute(result, type);
435   }
436   OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
437                                              Type type) override {
438     return parser.parseOptionalAttribute(result, type);
439   }
440   OptionalParseResult parseOptionalAttribute(StringAttr &result,
441                                              Type type) override {
442     return parser.parseOptionalAttribute(result, type);
443   }
444   OptionalParseResult parseOptionalAttribute(SymbolRefAttr &result,
445                                              Type type) override {
446     return parser.parseOptionalAttribute(result, type);
447   }
448 
449   /// Parse a named dictionary into 'result' if it is present.
450   ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
451     if (parser.getToken().isNot(Token::l_brace))
452       return success();
453     return parser.parseAttributeDict(result);
454   }
455 
456   /// Parse a named dictionary into 'result' if the `attributes` keyword is
457   /// present.
458   ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result) override {
459     if (failed(parseOptionalKeyword("attributes")))
460       return success();
461     return parser.parseAttributeDict(result);
462   }
463 
464   /// Parse an affine map instance into 'map'.
465   ParseResult parseAffineMap(AffineMap &map) override {
466     return parser.parseAffineMapReference(map);
467   }
468 
469   /// Parse an affine expr instance into 'expr' using the already computed
470   /// mapping from symbols to affine expressions in 'symbolSet'.
471   ParseResult
472   parseAffineExpr(ArrayRef<std::pair<StringRef, AffineExpr>> symbolSet,
473                   AffineExpr &expr) override {
474     return parser.parseAffineExprReference(symbolSet, expr);
475   }
476 
477   /// Parse an integer set instance into 'set'.
478   ParseResult parseIntegerSet(IntegerSet &set) override {
479     return parser.parseIntegerSetReference(set);
480   }
481 
482   //===--------------------------------------------------------------------===//
483   // Identifier Parsing
484   //===--------------------------------------------------------------------===//
485 
486   /// Parse an optional @-identifier and store it (without the '@' symbol) in a
487   /// string attribute named 'attrName'.
488   ParseResult parseOptionalSymbolName(StringAttr &result) override {
489     Token atToken = parser.getToken();
490     if (atToken.isNot(Token::at_identifier))
491       return failure();
492 
493     result = getBuilder().getStringAttr(atToken.getSymbolReference());
494     parser.consumeToken();
495 
496     // If we are populating the assembly parser state, record this as a symbol
497     // reference.
498     if (parser.getState().asmState) {
499       parser.getState().asmState->addUses(SymbolRefAttr::get(result),
500                                           atToken.getLocRange());
501     }
502     return success();
503   }
504 
505   //===--------------------------------------------------------------------===//
506   // Resource Parsing
507   //===--------------------------------------------------------------------===//
508 
509   /// Parse a handle to a resource within the assembly format.
510   FailureOr<AsmDialectResourceHandle>
511   parseResourceHandle(Dialect *dialect) override {
512     const auto *interface = dyn_cast<OpAsmDialectInterface>(dialect);
513     if (!interface) {
514       return parser.emitError() << "dialect '" << dialect->getNamespace()
515                                 << "' does not expect resource handles";
516     }
517     StringRef resourceName;
518     return parser.parseResourceHandle(interface, resourceName);
519   }
520 
521   //===--------------------------------------------------------------------===//
522   // Type Parsing
523   //===--------------------------------------------------------------------===//
524 
525   /// Parse a type.
526   ParseResult parseType(Type &result) override {
527     return failure(!(result = parser.parseType()));
528   }
529 
530   /// Parse an optional type.
531   OptionalParseResult parseOptionalType(Type &result) override {
532     return parser.parseOptionalType(result);
533   }
534 
535   /// Parse an arrow followed by a type list.
536   ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) override {
537     if (parseArrow() || parser.parseFunctionResultTypes(result))
538       return failure();
539     return success();
540   }
541 
542   /// Parse an optional arrow followed by a type list.
543   ParseResult
544   parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override {
545     if (!parser.consumeIf(Token::arrow))
546       return success();
547     return parser.parseFunctionResultTypes(result);
548   }
549 
550   /// Parse a colon followed by a type.
551   ParseResult parseColonType(Type &result) override {
552     return failure(parser.parseToken(Token::colon, "expected ':'") ||
553                    !(result = parser.parseType()));
554   }
555 
556   /// Parse a colon followed by a type list, which must have at least one type.
557   ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override {
558     if (parser.parseToken(Token::colon, "expected ':'"))
559       return failure();
560     return parser.parseTypeListNoParens(result);
561   }
562 
563   /// Parse an optional colon followed by a type list, which if present must
564   /// have at least one type.
565   ParseResult
566   parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override {
567     if (!parser.consumeIf(Token::colon))
568       return success();
569     return parser.parseTypeListNoParens(result);
570   }
571 
572   ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
573                                  bool allowDynamic,
574                                  bool withTrailingX) override {
575     return parser.parseDimensionListRanked(dimensions, allowDynamic,
576                                            withTrailingX);
577   }
578 
579   ParseResult parseXInDimensionList() override {
580     return parser.parseXInDimensionList();
581   }
582 
583   LogicalResult pushCyclicParsing(const void *opaquePointer) override {
584     return success(parser.getState().cyclicParsingStack.insert(opaquePointer));
585   }
586 
587   void popCyclicParsing() override {
588     parser.getState().cyclicParsingStack.pop_back();
589   }
590 
591   //===--------------------------------------------------------------------===//
592   // Code Completion
593   //===--------------------------------------------------------------------===//
594 
595   /// Parse a keyword, or an empty string if the current location signals a code
596   /// completion.
597   ParseResult parseKeywordOrCompletion(StringRef *keyword) override {
598     Token tok = parser.getToken();
599     if (tok.isCodeCompletion() && tok.getSpelling().empty()) {
600       *keyword = "";
601       return success();
602     }
603     return parseKeyword(keyword);
604   }
605 
606   /// Signal the code completion of a set of expected tokens.
607   void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) override {
608     Token tok = parser.getToken();
609     if (tok.isCodeCompletion() && tok.getSpelling().empty())
610       (void)parser.codeCompleteExpectedTokens(tokens);
611   }
612 
613 protected:
614   /// The source location of the dialect symbol.
615   SMLoc nameLoc;
616 
617   /// The main parser.
618   Parser &parser;
619 
620   /// A flag that indicates if any errors were emitted during parsing.
621   bool emittedError = false;
622 };
623 } // namespace detail
624 } // namespace mlir
625 
626 #endif // MLIR_LIB_ASMPARSER_ASMPARSERIMPL_H
627