1 //===- Lexer.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "Lexer.h"
10 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
11 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
12 #include "llvm/ADT/StringExtras.h"
13 #include "llvm/ADT/StringSwitch.h"
14 #include "llvm/Support/SourceMgr.h"
15
16 using namespace mlir;
17 using namespace mlir::pdll;
18
19 //===----------------------------------------------------------------------===//
20 // Token
21 //===----------------------------------------------------------------------===//
22
getStringValue() const23 std::string Token::getStringValue() const {
24 assert(getKind() == string || getKind() == string_block ||
25 getKind() == code_complete_string);
26
27 // Start by dropping the quotes.
28 StringRef bytes = getSpelling();
29 if (is(string))
30 bytes = bytes.drop_front().drop_back();
31 else if (is(string_block))
32 bytes = bytes.drop_front(2).drop_back(2);
33
34 std::string result;
35 result.reserve(bytes.size());
36 for (unsigned i = 0, e = bytes.size(); i != e;) {
37 auto c = bytes[i++];
38 if (c != '\\') {
39 result.push_back(c);
40 continue;
41 }
42
43 assert(i + 1 <= e && "invalid string should be caught by lexer");
44 auto c1 = bytes[i++];
45 switch (c1) {
46 case '"':
47 case '\\':
48 result.push_back(c1);
49 continue;
50 case 'n':
51 result.push_back('\n');
52 continue;
53 case 't':
54 result.push_back('\t');
55 continue;
56 default:
57 break;
58 }
59
60 assert(i + 1 <= e && "invalid string should be caught by lexer");
61 auto c2 = bytes[i++];
62
63 assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape");
64 result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2));
65 }
66
67 return result;
68 }
69
70 //===----------------------------------------------------------------------===//
71 // Lexer
72 //===----------------------------------------------------------------------===//
73
Lexer(llvm::SourceMgr & mgr,ast::DiagnosticEngine & diagEngine,CodeCompleteContext * codeCompleteContext)74 Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
75 CodeCompleteContext *codeCompleteContext)
76 : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false),
77 codeCompletionLocation(nullptr) {
78 curBufferID = mgr.getMainFileID();
79 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
80 curPtr = curBuffer.begin();
81
82 // Set the code completion location if necessary.
83 if (codeCompleteContext) {
84 codeCompletionLocation =
85 codeCompleteContext->getCodeCompleteLoc().getPointer();
86 }
87
88 // If the diag engine has no handler, add a default that emits to the
89 // SourceMgr.
90 if (!diagEngine.getHandlerFn()) {
91 diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) {
92 srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(),
93 diag.getMessage());
94 for (const ast::Diagnostic ¬e : diag.getNotes())
95 srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(),
96 note.getMessage());
97 });
98 addedHandlerToDiagEngine = true;
99 }
100 }
101
~Lexer()102 Lexer::~Lexer() {
103 if (addedHandlerToDiagEngine)
104 diagEngine.setHandlerFn(nullptr);
105 }
106
pushInclude(StringRef filename,SMRange includeLoc)107 LogicalResult Lexer::pushInclude(StringRef filename, SMRange includeLoc) {
108 std::string includedFile;
109 int bufferID =
110 srcMgr.AddIncludeFile(filename.str(), includeLoc.End, includedFile);
111 if (!bufferID)
112 return failure();
113
114 curBufferID = bufferID;
115 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
116 curPtr = curBuffer.begin();
117 return success();
118 }
119
emitError(SMRange loc,const Twine & msg)120 Token Lexer::emitError(SMRange loc, const Twine &msg) {
121 diagEngine.emitError(loc, msg);
122 return formToken(Token::error, loc.Start.getPointer());
123 }
emitErrorAndNote(SMRange loc,const Twine & msg,SMRange noteLoc,const Twine & note)124 Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
125 const Twine ¬e) {
126 diagEngine.emitError(loc, msg)->attachNote(note, noteLoc);
127 return formToken(Token::error, loc.Start.getPointer());
128 }
emitError(const char * loc,const Twine & msg)129 Token Lexer::emitError(const char *loc, const Twine &msg) {
130 return emitError(
131 SMRange(SMLoc::getFromPointer(loc), SMLoc::getFromPointer(loc + 1)), msg);
132 }
133
getNextChar()134 int Lexer::getNextChar() {
135 char curChar = *curPtr++;
136 switch (curChar) {
137 default:
138 return static_cast<unsigned char>(curChar);
139 case 0: {
140 // A nul character in the stream is either the end of the current buffer
141 // or a random nul in the file. Disambiguate that here.
142 if (curPtr - 1 != curBuffer.end())
143 return 0;
144
145 // Otherwise, return end of file.
146 --curPtr;
147 return EOF;
148 }
149 case '\n':
150 case '\r':
151 // Handle the newline character by ignoring it and incrementing the line
152 // count. However, be careful about 'dos style' files with \n\r in them.
153 // Only treat a \n\r or \r\n as a single line.
154 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
155 ++curPtr;
156 return '\n';
157 }
158 }
159
lexToken()160 Token Lexer::lexToken() {
161 while (true) {
162 const char *tokStart = curPtr;
163
164 // Check to see if this token is at the code completion location.
165 if (tokStart == codeCompletionLocation)
166 return formToken(Token::code_complete, tokStart);
167
168 // This always consumes at least one character.
169 int curChar = getNextChar();
170 switch (curChar) {
171 default:
172 // Handle identifiers: [a-zA-Z_]
173 if (isalpha(curChar) || curChar == '_')
174 return lexIdentifier(tokStart);
175
176 // Unknown character, emit an error.
177 return emitError(tokStart, "unexpected character");
178 case EOF: {
179 // Return EOF denoting the end of lexing.
180 Token eof = formToken(Token::eof, tokStart);
181
182 // Check to see if we are in an included file.
183 SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID);
184 if (parentIncludeLoc.isValid()) {
185 curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc);
186 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
187 curPtr = parentIncludeLoc.getPointer();
188 }
189
190 return eof;
191 }
192
193 // Lex punctuation.
194 case '-':
195 if (*curPtr == '>') {
196 ++curPtr;
197 return formToken(Token::arrow, tokStart);
198 }
199 return emitError(tokStart, "unexpected character");
200 case ':':
201 return formToken(Token::colon, tokStart);
202 case ',':
203 return formToken(Token::comma, tokStart);
204 case '.':
205 return formToken(Token::dot, tokStart);
206 case '=':
207 if (*curPtr == '>') {
208 ++curPtr;
209 return formToken(Token::equal_arrow, tokStart);
210 }
211 return formToken(Token::equal, tokStart);
212 case ';':
213 return formToken(Token::semicolon, tokStart);
214 case '[':
215 if (*curPtr == '{') {
216 ++curPtr;
217 return lexString(tokStart, /*isStringBlock=*/true);
218 }
219 return formToken(Token::l_square, tokStart);
220 case ']':
221 return formToken(Token::r_square, tokStart);
222
223 case '<':
224 return formToken(Token::less, tokStart);
225 case '>':
226 return formToken(Token::greater, tokStart);
227 case '{':
228 return formToken(Token::l_brace, tokStart);
229 case '}':
230 return formToken(Token::r_brace, tokStart);
231 case '(':
232 return formToken(Token::l_paren, tokStart);
233 case ')':
234 return formToken(Token::r_paren, tokStart);
235 case '/':
236 if (*curPtr == '/') {
237 lexComment();
238 continue;
239 }
240 return emitError(tokStart, "unexpected character");
241
242 // Ignore whitespace characters.
243 case 0:
244 case ' ':
245 case '\t':
246 case '\n':
247 return lexToken();
248
249 case '#':
250 return lexDirective(tokStart);
251 case '"':
252 return lexString(tokStart, /*isStringBlock=*/false);
253
254 case '0':
255 case '1':
256 case '2':
257 case '3':
258 case '4':
259 case '5':
260 case '6':
261 case '7':
262 case '8':
263 case '9':
264 return lexNumber(tokStart);
265 }
266 }
267 }
268
269 /// Skip a comment line, starting with a '//'.
lexComment()270 void Lexer::lexComment() {
271 // Advance over the second '/' in a '//' comment.
272 assert(*curPtr == '/');
273 ++curPtr;
274
275 while (true) {
276 switch (*curPtr++) {
277 case '\n':
278 case '\r':
279 // Newline is end of comment.
280 return;
281 case 0:
282 // If this is the end of the buffer, end the comment.
283 if (curPtr - 1 == curBuffer.end()) {
284 --curPtr;
285 return;
286 }
287 [[fallthrough]];
288 default:
289 // Skip over other characters.
290 break;
291 }
292 }
293 }
294
lexDirective(const char * tokStart)295 Token Lexer::lexDirective(const char *tokStart) {
296 // Match the rest with an identifier regex: [0-9a-zA-Z_]*
297 while (isalnum(*curPtr) || *curPtr == '_')
298 ++curPtr;
299
300 StringRef str(tokStart, curPtr - tokStart);
301 return Token(Token::directive, str);
302 }
303
lexIdentifier(const char * tokStart)304 Token Lexer::lexIdentifier(const char *tokStart) {
305 // Match the rest of the identifier regex: [0-9a-zA-Z_]*
306 while (isalnum(*curPtr) || *curPtr == '_')
307 ++curPtr;
308
309 // Check to see if this identifier is a keyword.
310 StringRef str(tokStart, curPtr - tokStart);
311 Token::Kind kind = StringSwitch<Token::Kind>(str)
312 .Case("attr", Token::kw_attr)
313 .Case("Attr", Token::kw_Attr)
314 .Case("erase", Token::kw_erase)
315 .Case("let", Token::kw_let)
316 .Case("Constraint", Token::kw_Constraint)
317 .Case("not", Token::kw_not)
318 .Case("op", Token::kw_op)
319 .Case("Op", Token::kw_Op)
320 .Case("OpName", Token::kw_OpName)
321 .Case("Pattern", Token::kw_Pattern)
322 .Case("replace", Token::kw_replace)
323 .Case("return", Token::kw_return)
324 .Case("rewrite", Token::kw_rewrite)
325 .Case("Rewrite", Token::kw_Rewrite)
326 .Case("type", Token::kw_type)
327 .Case("Type", Token::kw_Type)
328 .Case("TypeRange", Token::kw_TypeRange)
329 .Case("Value", Token::kw_Value)
330 .Case("ValueRange", Token::kw_ValueRange)
331 .Case("with", Token::kw_with)
332 .Case("_", Token::underscore)
333 .Default(Token::identifier);
334 return Token(kind, str);
335 }
336
lexNumber(const char * tokStart)337 Token Lexer::lexNumber(const char *tokStart) {
338 assert(isdigit(curPtr[-1]));
339
340 // Handle the normal decimal case.
341 while (isdigit(*curPtr))
342 ++curPtr;
343
344 return formToken(Token::integer, tokStart);
345 }
346
lexString(const char * tokStart,bool isStringBlock)347 Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
348 while (true) {
349 // Check to see if there is a code completion location within the string. In
350 // these cases we generate a completion location and place the currently
351 // lexed string within the token (without the quotes). This allows for the
352 // parser to use the partially lexed string when computing the completion
353 // results.
354 if (curPtr == codeCompletionLocation) {
355 return formToken(Token::code_complete_string,
356 tokStart + (isStringBlock ? 2 : 1));
357 }
358
359 switch (*curPtr++) {
360 case '"':
361 // If this is a string block, we only end the string when we encounter a
362 // `}]`.
363 if (!isStringBlock)
364 return formToken(Token::string, tokStart);
365 continue;
366 case '}':
367 // If this is a string block, we only end the string when we encounter a
368 // `}]`.
369 if (!isStringBlock || *curPtr != ']')
370 continue;
371 ++curPtr;
372 return formToken(Token::string_block, tokStart);
373 case 0: {
374 // If this is a random nul character in the middle of a string, just
375 // include it. If it is the end of file, then it is an error.
376 if (curPtr - 1 != curBuffer.end())
377 continue;
378 --curPtr;
379
380 StringRef expectedEndStr = isStringBlock ? "}]" : "\"";
381 return emitError(curPtr - 1,
382 "expected '" + expectedEndStr + "' in string literal");
383 }
384
385 case '\n':
386 case '\v':
387 case '\f':
388 // String blocks allow multiple lines.
389 if (!isStringBlock)
390 return emitError(curPtr - 1, "expected '\"' in string literal");
391 continue;
392
393 case '\\':
394 // Handle explicitly a few escapes.
395 if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
396 *curPtr == 't') {
397 ++curPtr;
398 } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) {
399 // Support \xx for two hex digits.
400 curPtr += 2;
401 } else {
402 return emitError(curPtr - 1, "unknown escape in string literal");
403 }
404 continue;
405
406 default:
407 continue;
408 }
409 }
410 }
411