xref: /llvm-project/mlir/lib/AsmParser/DialectSymbolParser.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser  --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the dialect symbols, such as extended
10 // attributes and types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "AsmParserImpl.h"
15 #include "Parser.h"
16 #include "mlir/AsmParser/AsmParserState.h"
17 #include "mlir/IR/AsmState.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/BuiltinAttributeInterfaces.h"
20 #include "mlir/IR/BuiltinAttributes.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Dialect.h"
23 #include "mlir/IR/DialectImplementation.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/Support/LLVM.h"
26 #include "llvm/Support/MemoryBuffer.h"
27 #include "llvm/Support/SourceMgr.h"
28 #include <cassert>
29 #include <cstddef>
30 #include <utility>
31 
32 using namespace mlir;
33 using namespace mlir::detail;
34 using llvm::MemoryBuffer;
35 using llvm::SourceMgr;
36 
37 namespace {
38 /// This class provides the main implementation of the DialectAsmParser that
39 /// allows for dialects to parse attributes and types. This allows for dialect
40 /// hooking into the main MLIR parsing logic.
41 class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
42 public:
CustomDialectAsmParser(StringRef fullSpec,Parser & parser)43   CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
44       : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
45         fullSpec(fullSpec) {}
46   ~CustomDialectAsmParser() override = default;
47 
48   /// Returns the full specification of the symbol being parsed. This allows
49   /// for using a separate parser if necessary.
getFullSymbolSpec() const50   StringRef getFullSymbolSpec() const override { return fullSpec; }
51 
52 private:
53   /// The full symbol specification.
54   StringRef fullSpec;
55 };
56 } // namespace
57 
58 ///
59 ///   pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
60 ///   pretty-dialect-sym-contents ::= pretty-dialect-sym-body
61 ///                                  | '(' pretty-dialect-sym-contents+ ')'
62 ///                                  | '[' pretty-dialect-sym-contents+ ']'
63 ///                                  | '{' pretty-dialect-sym-contents+ '}'
64 ///                                  | '[^[<({>\])}\0]+'
65 ///
parseDialectSymbolBody(StringRef & body,bool & isCodeCompletion)66 ParseResult Parser::parseDialectSymbolBody(StringRef &body,
67                                            bool &isCodeCompletion) {
68   // Symbol bodies are a relatively unstructured format that contains a series
69   // of properly nested punctuation, with anything else in the middle. Scan
70   // ahead to find it and consume it if successful, otherwise emit an error.
71   const char *curPtr = getTokenSpelling().data();
72 
73   // Scan over the nested punctuation, bailing out on error and consuming until
74   // we find the end. We know that we're currently looking at the '<', so we can
75   // go until we find the matching '>' character.
76   assert(*curPtr == '<');
77   SmallVector<char, 8> nestedPunctuation;
78   const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
79 
80   // Functor used to emit an unbalanced punctuation error.
81   auto emitPunctError = [&] {
82     return emitError() << "unbalanced '" << nestedPunctuation.back()
83                        << "' character in pretty dialect name";
84   };
85   // Functor used to check for unbalanced punctuation.
86   auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult {
87     if (nestedPunctuation.back() != expectedToken)
88       return emitPunctError();
89     nestedPunctuation.pop_back();
90     return success();
91   };
92   do {
93     // Handle code completions, which may appear in the middle of the symbol
94     // body.
95     if (curPtr == codeCompleteLoc) {
96       isCodeCompletion = true;
97       nestedPunctuation.clear();
98       break;
99     }
100 
101     char c = *curPtr++;
102     switch (c) {
103     case '\0':
104       // This also handles the EOF case.
105       if (!nestedPunctuation.empty())
106         return emitPunctError();
107       return emitError("unexpected nul or EOF in pretty dialect name");
108     case '<':
109     case '[':
110     case '(':
111     case '{':
112       nestedPunctuation.push_back(c);
113       continue;
114 
115     case '-':
116       // The sequence `->` is treated as special token.
117       if (*curPtr == '>')
118         ++curPtr;
119       continue;
120 
121     case '>':
122       if (failed(checkNestedPunctuation('<')))
123         return failure();
124       break;
125     case ']':
126       if (failed(checkNestedPunctuation('[')))
127         return failure();
128       break;
129     case ')':
130       if (failed(checkNestedPunctuation('(')))
131         return failure();
132       break;
133     case '}':
134       if (failed(checkNestedPunctuation('{')))
135         return failure();
136       break;
137     case '"': {
138       // Dispatch to the lexer to lex past strings.
139       resetToken(curPtr - 1);
140       curPtr = state.curToken.getEndLoc().getPointer();
141 
142       // Handle code completions, which may appear in the middle of the symbol
143       // body.
144       if (state.curToken.isCodeCompletion()) {
145         isCodeCompletion = true;
146         nestedPunctuation.clear();
147         break;
148       }
149 
150       // Otherwise, ensure this token was actually a string.
151       if (state.curToken.isNot(Token::string))
152         return failure();
153       break;
154     }
155 
156     default:
157       continue;
158     }
159   } while (!nestedPunctuation.empty());
160 
161   // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
162   // consuming all this stuff, and return.
163   resetToken(curPtr);
164 
165   unsigned length = curPtr - body.begin();
166   body = StringRef(body.data(), length);
167   return success();
168 }
169 
170 /// Parse an extended dialect symbol.
171 template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
parseExtendedSymbol(Parser & p,AsmParserState * asmState,SymbolAliasMap & aliases,CreateFn && createSymbol)172 static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
173                                   SymbolAliasMap &aliases,
174                                   CreateFn &&createSymbol) {
175   Token tok = p.getToken();
176 
177   // Handle code completion of the extended symbol.
178   StringRef identifier = tok.getSpelling().drop_front();
179   if (tok.isCodeCompletion() && identifier.empty())
180     return p.codeCompleteDialectSymbol(aliases);
181 
182   // Parse the dialect namespace.
183   SMRange range = p.getToken().getLocRange();
184   SMLoc loc = p.getToken().getLoc();
185   p.consumeToken();
186 
187   // Check to see if this is a pretty name.
188   auto [dialectName, symbolData] = identifier.split('.');
189   bool isPrettyName = !symbolData.empty() || identifier.back() == '.';
190 
191   // Check to see if the symbol has trailing data, i.e. has an immediately
192   // following '<'.
193   bool hasTrailingData =
194       p.getToken().is(Token::less) &&
195       identifier.bytes_end() == p.getTokenSpelling().bytes_begin();
196 
197   // If there is no '<' token following this, and if the typename contains no
198   // dot, then we are parsing a symbol alias.
199   if (!hasTrailingData && !isPrettyName) {
200     // Check for an alias for this type.
201     auto aliasIt = aliases.find(identifier);
202     if (aliasIt == aliases.end())
203       return (p.emitWrongTokenError("undefined symbol alias id '" + identifier +
204                                     "'"),
205               nullptr);
206     if (asmState) {
207       if constexpr (std::is_same_v<Symbol, Type>)
208         asmState->addTypeAliasUses(identifier, range);
209       else
210         asmState->addAttrAliasUses(identifier, range);
211     }
212     return aliasIt->second;
213   }
214 
215   // If this isn't an alias, we are parsing a dialect-specific symbol. If the
216   // name contains a dot, then this is the "pretty" form. If not, it is the
217   // verbose form that looks like <...>.
218   if (!isPrettyName) {
219     // Point the symbol data to the end of the dialect name to start.
220     symbolData = StringRef(dialectName.end(), 0);
221 
222     // Parse the body of the symbol.
223     bool isCodeCompletion = false;
224     if (p.parseDialectSymbolBody(symbolData, isCodeCompletion))
225       return nullptr;
226     symbolData = symbolData.drop_front();
227 
228     // If the body contained a code completion it won't have the trailing `>`
229     // token, so don't drop it.
230     if (!isCodeCompletion)
231       symbolData = symbolData.drop_back();
232   } else {
233     loc = SMLoc::getFromPointer(symbolData.data());
234 
235     // If the dialect's symbol is followed immediately by a <, then lex the body
236     // of it into prettyName.
237     if (hasTrailingData && p.parseDialectSymbolBody(symbolData))
238       return nullptr;
239   }
240 
241   return createSymbol(dialectName, symbolData, loc);
242 }
243 
244 /// Parse an extended attribute.
245 ///
246 ///   extended-attribute ::= (dialect-attribute | attribute-alias)
247 ///   dialect-attribute  ::= `#` dialect-namespace `<` attr-data `>`
248 ///                          (`:` type)?
249 ///                        | `#` alias-name pretty-dialect-sym-body? (`:` type)?
250 ///   attribute-alias    ::= `#` alias-name
251 ///
parseExtendedAttr(Type type)252 Attribute Parser::parseExtendedAttr(Type type) {
253   MLIRContext *ctx = getContext();
254   Attribute attr = parseExtendedSymbol<Attribute>(
255       *this, state.asmState, state.symbols.attributeAliasDefinitions,
256       [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
257         // Parse an optional trailing colon type.
258         Type attrType = type;
259         if (consumeIf(Token::colon) && !(attrType = parseType()))
260           return Attribute();
261 
262         // If we found a registered dialect, then ask it to parse the attribute.
263         if (Dialect *dialect =
264                 builder.getContext()->getOrLoadDialect(dialectName)) {
265           // Temporarily reset the lexer to let the dialect parse the attribute.
266           const char *curLexerPos = getToken().getLoc().getPointer();
267           resetToken(symbolData.data());
268 
269           // Parse the attribute.
270           CustomDialectAsmParser customParser(symbolData, *this);
271           Attribute attr = dialect->parseAttribute(customParser, attrType);
272           resetToken(curLexerPos);
273           return attr;
274         }
275 
276         // Otherwise, form a new opaque attribute.
277         return OpaqueAttr::getChecked(
278             [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName),
279             symbolData, attrType ? attrType : NoneType::get(ctx));
280       });
281 
282   // Ensure that the attribute has the same type as requested.
283   auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
284   if (type && typedAttr && typedAttr.getType() != type) {
285     emitError("attribute type different than expected: expected ")
286         << type << ", but got " << typedAttr.getType();
287     return nullptr;
288   }
289   return attr;
290 }
291 
292 /// Parse an extended type.
293 ///
294 ///   extended-type ::= (dialect-type | type-alias)
295 ///   dialect-type  ::= `!` dialect-namespace `<` `"` type-data `"` `>`
296 ///   dialect-type  ::= `!` alias-name pretty-dialect-attribute-body?
297 ///   type-alias    ::= `!` alias-name
298 ///
parseExtendedType()299 Type Parser::parseExtendedType() {
300   MLIRContext *ctx = getContext();
301   return parseExtendedSymbol<Type>(
302       *this, state.asmState, state.symbols.typeAliasDefinitions,
303       [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
304         // If we found a registered dialect, then ask it to parse the type.
305         if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
306           // Temporarily reset the lexer to let the dialect parse the type.
307           const char *curLexerPos = getToken().getLoc().getPointer();
308           resetToken(symbolData.data());
309 
310           // Parse the type.
311           CustomDialectAsmParser customParser(symbolData, *this);
312           Type type = dialect->parseType(customParser);
313           resetToken(curLexerPos);
314           return type;
315         }
316 
317         // Otherwise, form a new opaque type.
318         return OpaqueType::getChecked([&] { return emitError(loc); },
319                                       StringAttr::get(ctx, dialectName),
320                                       symbolData);
321       });
322 }
323 
324 //===----------------------------------------------------------------------===//
325 // mlir::parseAttribute/parseType
326 //===----------------------------------------------------------------------===//
327 
328 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
329 /// parsing failed, nullptr is returned.
330 template <typename T, typename ParserFn>
parseSymbol(StringRef inputStr,MLIRContext * context,size_t * numReadOut,bool isKnownNullTerminated,ParserFn && parserFn)331 static T parseSymbol(StringRef inputStr, MLIRContext *context,
332                      size_t *numReadOut, bool isKnownNullTerminated,
333                      ParserFn &&parserFn) {
334   // Set the buffer name to the string being parsed, so that it appears in error
335   // diagnostics.
336   auto memBuffer =
337       isKnownNullTerminated
338           ? MemoryBuffer::getMemBuffer(inputStr,
339                                        /*BufferName=*/inputStr)
340           : MemoryBuffer::getMemBufferCopy(inputStr, /*BufferName=*/inputStr);
341   SourceMgr sourceMgr;
342   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
343   SymbolState aliasState;
344   ParserConfig config(context);
345   ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
346                     /*codeCompleteContext=*/nullptr);
347   Parser parser(state);
348 
349   Token startTok = parser.getToken();
350   T symbol = parserFn(parser);
351   if (!symbol)
352     return T();
353 
354   // Provide the number of bytes that were read.
355   Token endTok = parser.getToken();
356   size_t numRead =
357       endTok.getLoc().getPointer() - startTok.getLoc().getPointer();
358   if (numReadOut) {
359     *numReadOut = numRead;
360   } else if (numRead != inputStr.size()) {
361     parser.emitError(endTok.getLoc()) << "found trailing characters: '"
362                                       << inputStr.drop_front(numRead) << "'";
363     return T();
364   }
365   return symbol;
366 }
367 
parseAttribute(StringRef attrStr,MLIRContext * context,Type type,size_t * numRead,bool isKnownNullTerminated)368 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
369                                Type type, size_t *numRead,
370                                bool isKnownNullTerminated) {
371   return parseSymbol<Attribute>(
372       attrStr, context, numRead, isKnownNullTerminated,
373       [type](Parser &parser) { return parser.parseAttribute(type); });
374 }
parseType(StringRef typeStr,MLIRContext * context,size_t * numRead,bool isKnownNullTerminated)375 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
376                      bool isKnownNullTerminated) {
377   return parseSymbol<Type>(typeStr, context, numRead, isKnownNullTerminated,
378                            [](Parser &parser) { return parser.parseType(); });
379 }
380