xref: /llvm-project/mlir/lib/Tools/PDLL/Parser/Lexer.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- Lexer.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Lexer.h"
10 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
11 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
12 #include "llvm/ADT/StringExtras.h"
13 #include "llvm/ADT/StringSwitch.h"
14 #include "llvm/Support/SourceMgr.h"
15 
16 using namespace mlir;
17 using namespace mlir::pdll;
18 
19 //===----------------------------------------------------------------------===//
20 // Token
21 //===----------------------------------------------------------------------===//
22 
getStringValue() const23 std::string Token::getStringValue() const {
24   assert(getKind() == string || getKind() == string_block ||
25          getKind() == code_complete_string);
26 
27   // Start by dropping the quotes.
28   StringRef bytes = getSpelling();
29   if (is(string))
30     bytes = bytes.drop_front().drop_back();
31   else if (is(string_block))
32     bytes = bytes.drop_front(2).drop_back(2);
33 
34   std::string result;
35   result.reserve(bytes.size());
36   for (unsigned i = 0, e = bytes.size(); i != e;) {
37     auto c = bytes[i++];
38     if (c != '\\') {
39       result.push_back(c);
40       continue;
41     }
42 
43     assert(i + 1 <= e && "invalid string should be caught by lexer");
44     auto c1 = bytes[i++];
45     switch (c1) {
46     case '"':
47     case '\\':
48       result.push_back(c1);
49       continue;
50     case 'n':
51       result.push_back('\n');
52       continue;
53     case 't':
54       result.push_back('\t');
55       continue;
56     default:
57       break;
58     }
59 
60     assert(i + 1 <= e && "invalid string should be caught by lexer");
61     auto c2 = bytes[i++];
62 
63     assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape");
64     result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2));
65   }
66 
67   return result;
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // Lexer
72 //===----------------------------------------------------------------------===//
73 
Lexer(llvm::SourceMgr & mgr,ast::DiagnosticEngine & diagEngine,CodeCompleteContext * codeCompleteContext)74 Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
75              CodeCompleteContext *codeCompleteContext)
76     : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false),
77       codeCompletionLocation(nullptr) {
78   curBufferID = mgr.getMainFileID();
79   curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
80   curPtr = curBuffer.begin();
81 
82   // Set the code completion location if necessary.
83   if (codeCompleteContext) {
84     codeCompletionLocation =
85         codeCompleteContext->getCodeCompleteLoc().getPointer();
86   }
87 
88   // If the diag engine has no handler, add a default that emits to the
89   // SourceMgr.
90   if (!diagEngine.getHandlerFn()) {
91     diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) {
92       srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(),
93                           diag.getMessage());
94       for (const ast::Diagnostic &note : diag.getNotes())
95         srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(),
96                             note.getMessage());
97     });
98     addedHandlerToDiagEngine = true;
99   }
100 }
101 
~Lexer()102 Lexer::~Lexer() {
103   if (addedHandlerToDiagEngine)
104     diagEngine.setHandlerFn(nullptr);
105 }
106 
pushInclude(StringRef filename,SMRange includeLoc)107 LogicalResult Lexer::pushInclude(StringRef filename, SMRange includeLoc) {
108   std::string includedFile;
109   int bufferID =
110       srcMgr.AddIncludeFile(filename.str(), includeLoc.End, includedFile);
111   if (!bufferID)
112     return failure();
113 
114   curBufferID = bufferID;
115   curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
116   curPtr = curBuffer.begin();
117   return success();
118 }
119 
emitError(SMRange loc,const Twine & msg)120 Token Lexer::emitError(SMRange loc, const Twine &msg) {
121   diagEngine.emitError(loc, msg);
122   return formToken(Token::error, loc.Start.getPointer());
123 }
emitErrorAndNote(SMRange loc,const Twine & msg,SMRange noteLoc,const Twine & note)124 Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
125                               const Twine &note) {
126   diagEngine.emitError(loc, msg)->attachNote(note, noteLoc);
127   return formToken(Token::error, loc.Start.getPointer());
128 }
emitError(const char * loc,const Twine & msg)129 Token Lexer::emitError(const char *loc, const Twine &msg) {
130   return emitError(
131       SMRange(SMLoc::getFromPointer(loc), SMLoc::getFromPointer(loc + 1)), msg);
132 }
133 
getNextChar()134 int Lexer::getNextChar() {
135   char curChar = *curPtr++;
136   switch (curChar) {
137   default:
138     return static_cast<unsigned char>(curChar);
139   case 0: {
140     // A nul character in the stream is either the end of the current buffer
141     // or a random nul in the file. Disambiguate that here.
142     if (curPtr - 1 != curBuffer.end())
143       return 0;
144 
145     // Otherwise, return end of file.
146     --curPtr;
147     return EOF;
148   }
149   case '\n':
150   case '\r':
151     // Handle the newline character by ignoring it and incrementing the line
152     // count. However, be careful about 'dos style' files with \n\r in them.
153     // Only treat a \n\r or \r\n as a single line.
154     if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
155       ++curPtr;
156     return '\n';
157   }
158 }
159 
lexToken()160 Token Lexer::lexToken() {
161   while (true) {
162     const char *tokStart = curPtr;
163 
164     // Check to see if this token is at the code completion location.
165     if (tokStart == codeCompletionLocation)
166       return formToken(Token::code_complete, tokStart);
167 
168     // This always consumes at least one character.
169     int curChar = getNextChar();
170     switch (curChar) {
171     default:
172       // Handle identifiers: [a-zA-Z_]
173       if (isalpha(curChar) || curChar == '_')
174         return lexIdentifier(tokStart);
175 
176       // Unknown character, emit an error.
177       return emitError(tokStart, "unexpected character");
178     case EOF: {
179       // Return EOF denoting the end of lexing.
180       Token eof = formToken(Token::eof, tokStart);
181 
182       // Check to see if we are in an included file.
183       SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID);
184       if (parentIncludeLoc.isValid()) {
185         curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc);
186         curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
187         curPtr = parentIncludeLoc.getPointer();
188       }
189 
190       return eof;
191     }
192 
193     // Lex punctuation.
194     case '-':
195       if (*curPtr == '>') {
196         ++curPtr;
197         return formToken(Token::arrow, tokStart);
198       }
199       return emitError(tokStart, "unexpected character");
200     case ':':
201       return formToken(Token::colon, tokStart);
202     case ',':
203       return formToken(Token::comma, tokStart);
204     case '.':
205       return formToken(Token::dot, tokStart);
206     case '=':
207       if (*curPtr == '>') {
208         ++curPtr;
209         return formToken(Token::equal_arrow, tokStart);
210       }
211       return formToken(Token::equal, tokStart);
212     case ';':
213       return formToken(Token::semicolon, tokStart);
214     case '[':
215       if (*curPtr == '{') {
216         ++curPtr;
217         return lexString(tokStart, /*isStringBlock=*/true);
218       }
219       return formToken(Token::l_square, tokStart);
220     case ']':
221       return formToken(Token::r_square, tokStart);
222 
223     case '<':
224       return formToken(Token::less, tokStart);
225     case '>':
226       return formToken(Token::greater, tokStart);
227     case '{':
228       return formToken(Token::l_brace, tokStart);
229     case '}':
230       return formToken(Token::r_brace, tokStart);
231     case '(':
232       return formToken(Token::l_paren, tokStart);
233     case ')':
234       return formToken(Token::r_paren, tokStart);
235     case '/':
236       if (*curPtr == '/') {
237         lexComment();
238         continue;
239       }
240       return emitError(tokStart, "unexpected character");
241 
242     // Ignore whitespace characters.
243     case 0:
244     case ' ':
245     case '\t':
246     case '\n':
247       return lexToken();
248 
249     case '#':
250       return lexDirective(tokStart);
251     case '"':
252       return lexString(tokStart, /*isStringBlock=*/false);
253 
254     case '0':
255     case '1':
256     case '2':
257     case '3':
258     case '4':
259     case '5':
260     case '6':
261     case '7':
262     case '8':
263     case '9':
264       return lexNumber(tokStart);
265     }
266   }
267 }
268 
269 /// Skip a comment line, starting with a '//'.
lexComment()270 void Lexer::lexComment() {
271   // Advance over the second '/' in a '//' comment.
272   assert(*curPtr == '/');
273   ++curPtr;
274 
275   while (true) {
276     switch (*curPtr++) {
277     case '\n':
278     case '\r':
279       // Newline is end of comment.
280       return;
281     case 0:
282       // If this is the end of the buffer, end the comment.
283       if (curPtr - 1 == curBuffer.end()) {
284         --curPtr;
285         return;
286       }
287       [[fallthrough]];
288     default:
289       // Skip over other characters.
290       break;
291     }
292   }
293 }
294 
lexDirective(const char * tokStart)295 Token Lexer::lexDirective(const char *tokStart) {
296   // Match the rest with an identifier regex: [0-9a-zA-Z_]*
297   while (isalnum(*curPtr) || *curPtr == '_')
298     ++curPtr;
299 
300   StringRef str(tokStart, curPtr - tokStart);
301   return Token(Token::directive, str);
302 }
303 
lexIdentifier(const char * tokStart)304 Token Lexer::lexIdentifier(const char *tokStart) {
305   // Match the rest of the identifier regex: [0-9a-zA-Z_]*
306   while (isalnum(*curPtr) || *curPtr == '_')
307     ++curPtr;
308 
309   // Check to see if this identifier is a keyword.
310   StringRef str(tokStart, curPtr - tokStart);
311   Token::Kind kind = StringSwitch<Token::Kind>(str)
312                          .Case("attr", Token::kw_attr)
313                          .Case("Attr", Token::kw_Attr)
314                          .Case("erase", Token::kw_erase)
315                          .Case("let", Token::kw_let)
316                          .Case("Constraint", Token::kw_Constraint)
317                          .Case("not", Token::kw_not)
318                          .Case("op", Token::kw_op)
319                          .Case("Op", Token::kw_Op)
320                          .Case("OpName", Token::kw_OpName)
321                          .Case("Pattern", Token::kw_Pattern)
322                          .Case("replace", Token::kw_replace)
323                          .Case("return", Token::kw_return)
324                          .Case("rewrite", Token::kw_rewrite)
325                          .Case("Rewrite", Token::kw_Rewrite)
326                          .Case("type", Token::kw_type)
327                          .Case("Type", Token::kw_Type)
328                          .Case("TypeRange", Token::kw_TypeRange)
329                          .Case("Value", Token::kw_Value)
330                          .Case("ValueRange", Token::kw_ValueRange)
331                          .Case("with", Token::kw_with)
332                          .Case("_", Token::underscore)
333                          .Default(Token::identifier);
334   return Token(kind, str);
335 }
336 
lexNumber(const char * tokStart)337 Token Lexer::lexNumber(const char *tokStart) {
338   assert(isdigit(curPtr[-1]));
339 
340   // Handle the normal decimal case.
341   while (isdigit(*curPtr))
342     ++curPtr;
343 
344   return formToken(Token::integer, tokStart);
345 }
346 
lexString(const char * tokStart,bool isStringBlock)347 Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
348   while (true) {
349     // Check to see if there is a code completion location within the string. In
350     // these cases we generate a completion location and place the currently
351     // lexed string within the token (without the quotes). This allows for the
352     // parser to use the partially lexed string when computing the completion
353     // results.
354     if (curPtr == codeCompletionLocation) {
355       return formToken(Token::code_complete_string,
356                        tokStart + (isStringBlock ? 2 : 1));
357     }
358 
359     switch (*curPtr++) {
360     case '"':
361       // If this is a string block, we only end the string when we encounter a
362       // `}]`.
363       if (!isStringBlock)
364         return formToken(Token::string, tokStart);
365       continue;
366     case '}':
367       // If this is a string block, we only end the string when we encounter a
368       // `}]`.
369       if (!isStringBlock || *curPtr != ']')
370         continue;
371       ++curPtr;
372       return formToken(Token::string_block, tokStart);
373     case 0: {
374       // If this is a random nul character in the middle of a string, just
375       // include it. If it is the end of file, then it is an error.
376       if (curPtr - 1 != curBuffer.end())
377         continue;
378       --curPtr;
379 
380       StringRef expectedEndStr = isStringBlock ? "}]" : "\"";
381       return emitError(curPtr - 1,
382                        "expected '" + expectedEndStr + "' in string literal");
383     }
384 
385     case '\n':
386     case '\v':
387     case '\f':
388       // String blocks allow multiple lines.
389       if (!isStringBlock)
390         return emitError(curPtr - 1, "expected '\"' in string literal");
391       continue;
392 
393     case '\\':
394       // Handle explicitly a few escapes.
395       if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
396           *curPtr == 't') {
397         ++curPtr;
398       } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) {
399         // Support \xx for two hex digits.
400         curPtr += 2;
401       } else {
402         return emitError(curPtr - 1, "unknown escape in string literal");
403       }
404       continue;
405 
406     default:
407       continue;
408     }
409   }
410 }
411