1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/IndentedOstream.h" 12 #include "mlir/TableGen/Argument.h" 13 #include "mlir/TableGen/Attribute.h" 14 #include "mlir/TableGen/Constraint.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/Operator.h" 17 #include "mlir/Tools/PDLL/AST/Context.h" 18 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 19 #include "mlir/Tools/PDLL/AST/Nodes.h" 20 #include "mlir/Tools/PDLL/AST/Types.h" 21 #include "mlir/Tools/PDLL/ODS/Constraint.h" 22 #include "mlir/Tools/PDLL/ODS/Context.h" 23 #include "mlir/Tools/PDLL/ODS/Operation.h" 24 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/FormatVariadic.h" 28 #include "llvm/Support/ManagedStatic.h" 29 #include "llvm/Support/SaveAndRestore.h" 30 #include "llvm/Support/ScopedPrinter.h" 31 #include "llvm/TableGen/Error.h" 32 #include "llvm/TableGen/Parser.h" 33 #include <optional> 34 #include <string> 35 36 using namespace mlir; 37 using namespace mlir::pdll; 38 39 //===----------------------------------------------------------------------===// 40 // Parser 41 //===----------------------------------------------------------------------===// 42 43 namespace { 44 class Parser { 45 public: 46 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 47 bool enableDocumentation, CodeCompleteContext *codeCompleteContext) 48 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext), 49 curToken(lexer.lexToken()), enableDocumentation(enableDocumentation), 50 
typeTy(ast::TypeType::get(ctx)), valueTy(ast::ValueType::get(ctx)), 51 typeRangeTy(ast::TypeRangeType::get(ctx)), 52 valueRangeTy(ast::ValueRangeType::get(ctx)), 53 attrTy(ast::AttributeType::get(ctx)), 54 codeCompleteContext(codeCompleteContext) {} 55 56 /// Try to parse a new module. Returns nullptr in the case of failure. 57 FailureOr<ast::Module *> parseModule(); 58 59 private: 60 /// The current context of the parser. It allows for the parser to know a bit 61 /// about the construct it is nested within during parsing. This is used 62 /// specifically to provide additional verification during parsing, e.g. to 63 /// prevent using rewrites within a match context, matcher constraints within 64 /// a rewrite section, etc. 65 enum class ParserContext { 66 /// The parser is in the global context. 67 Global, 68 /// The parser is currently within a Constraint, which disallows all types 69 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 70 Constraint, 71 /// The parser is currently within the matcher portion of a Pattern, which 72 /// is allows a terminal operation rewrite statement but no other rewrite 73 /// transformations. 74 PatternMatch, 75 /// The parser is currently within a Rewrite, which disallows calls to 76 /// constraints, requires operation expressions to have names, etc. 77 Rewrite, 78 }; 79 80 /// The current specification context of an operations result type. This 81 /// indicates how the result types of an operation may be inferred. 82 enum class OpResultTypeContext { 83 /// The result types of the operation are not known to be inferred. 84 Explicit, 85 /// The result types of the operation are inferred from the root input of a 86 /// `replace` statement. 87 Replacement, 88 /// The result types of the operation are inferred by using the 89 /// `InferTypeOpInterface` interface provided by the operation. 
90 Interface, 91 }; 92 93 //===--------------------------------------------------------------------===// 94 // Parsing 95 //===--------------------------------------------------------------------===// 96 97 /// Push a new decl scope onto the lexer. 98 ast::DeclScope *pushDeclScope() { 99 ast::DeclScope *newScope = 100 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 101 return (curDeclScope = newScope); 102 } 103 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 104 105 /// Pop the last decl scope from the lexer. 106 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 107 108 /// Parse the body of an AST module. 109 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); 110 111 /// Try to convert the given expression to `type`. Returns failure and emits 112 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 113 /// invoked to attach notes to the emitted error diagnostic. On success, 114 /// `expr` is updated to the expression used to convert to `type`. 115 LogicalResult convertExpressionTo( 116 ast::Expr *&expr, ast::Type type, 117 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 118 LogicalResult 119 convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType, 120 ast::Type type, 121 function_ref<ast::InFlightDiagnostic()> emitErrorFn); 122 LogicalResult convertTupleExpressionTo( 123 ast::Expr *&expr, ast::TupleType exprType, ast::Type type, 124 function_ref<ast::InFlightDiagnostic()> emitErrorFn, 125 function_ref<void(ast::Diagnostic &diag)> noteAttachFn); 126 127 /// Given an operation expression, convert it to a Value or ValueRange 128 /// typed expression. 129 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 130 131 /// Lookup ODS information for the given operation, returns nullptr if no 132 /// information is found. 133 const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) { 134 return opName ? 
ctx.getODSContext().lookupOperation(*opName) : nullptr; 135 } 136 137 /// Process the given documentation string, or return an empty string if 138 /// documentation isn't enabled. 139 StringRef processDoc(StringRef doc) { 140 return enableDocumentation ? doc : StringRef(); 141 } 142 143 /// Process the given documentation string and format it, or return an empty 144 /// string if documentation isn't enabled. 145 std::string processAndFormatDoc(const Twine &doc) { 146 if (!enableDocumentation) 147 return ""; 148 std::string docStr; 149 { 150 llvm::raw_string_ostream docOS(docStr); 151 raw_indented_ostream(docOS).printReindented( 152 StringRef(docStr).rtrim(" \t")); 153 } 154 return docStr; 155 } 156 157 //===--------------------------------------------------------------------===// 158 // Directives 159 160 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); 161 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); 162 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, 163 SmallVectorImpl<ast::Decl *> &decls); 164 165 /// Process the records of a parsed tablegen include file. 166 void processTdIncludeRecords(const llvm::RecordKeeper &tdRecords, 167 SmallVectorImpl<ast::Decl *> &decls); 168 169 /// Create a user defined native constraint for a constraint imported from 170 /// ODS. 171 template <typename ConstraintT> 172 ast::Decl * 173 createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, 174 SMRange loc, ast::Type type, 175 StringRef nativeType, StringRef docString); 176 template <typename ConstraintT> 177 ast::Decl * 178 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 179 SMRange loc, ast::Type type, 180 StringRef nativeType); 181 182 //===--------------------------------------------------------------------===// 183 // Decls 184 185 /// This structure contains the set of pattern metadata that may be parsed. 
186 struct ParsedPatternMetadata { 187 std::optional<uint16_t> benefit; 188 bool hasBoundedRecursion = false; 189 }; 190 191 FailureOr<ast::Decl *> parseTopLevelDecl(); 192 FailureOr<ast::NamedAttributeDecl *> 193 parseNamedAttributeDecl(std::optional<StringRef> parentOpName); 194 195 /// Parse an argument variable as part of the signature of a 196 /// UserConstraintDecl or UserRewriteDecl. 197 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 198 199 /// Parse a result variable as part of the signature of a UserConstraintDecl 200 /// or UserRewriteDecl. 201 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 202 203 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 204 /// defined in a non-global context. 205 FailureOr<ast::UserConstraintDecl *> 206 parseUserConstraintDecl(bool isInline = false); 207 208 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 209 /// non-global context, such as within a Pattern/Constraint/etc. 210 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 211 212 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 213 /// PDLL constructs. 214 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 215 const ast::Name &name, bool isInline, 216 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 217 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 218 219 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being 220 /// defined in a non-global context. 221 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 222 223 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 224 /// non-global context, such as within a Pattern/Rewrite/etc. 225 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 226 227 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 228 /// PDLL constructs. 
229 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 230 const ast::Name &name, bool isInline, 231 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 232 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 233 234 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have 235 /// effectively the same syntax, and only differ on slight semantics (given 236 /// the different parsing contexts). 237 template <typename T, typename ParseUserPDLLDeclFnT> 238 FailureOr<T *> parseUserConstraintOrRewriteDecl( 239 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 240 StringRef anonymousNamePrefix, bool isInline); 241 242 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 243 /// These decls have effectively the same syntax. 244 template <typename T> 245 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 246 const ast::Name &name, bool isInline, 247 ArrayRef<ast::VariableDecl *> arguments, 248 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 249 250 /// Parse the functional signature (i.e. the arguments and results) of a 251 /// UserConstraintDecl or UserRewriteDecl. 252 LogicalResult parseUserConstraintOrRewriteSignature( 253 SmallVectorImpl<ast::VariableDecl *> &arguments, 254 SmallVectorImpl<ast::VariableDecl *> &results, 255 ast::DeclScope *&argumentScope, ast::Type &resultType); 256 257 /// Validate the return (which if present is specified by bodyIt) of a 258 /// UserConstraintDecl or UserRewriteDecl. 
259 LogicalResult validateUserConstraintOrRewriteReturn( 260 StringRef declType, ast::CompoundStmt *body, 261 ArrayRef<ast::Stmt *>::iterator bodyIt, 262 ArrayRef<ast::Stmt *>::iterator bodyE, 263 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 264 265 FailureOr<ast::CompoundStmt *> 266 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 267 bool expectTerminalSemicolon = true); 268 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 269 FailureOr<ast::Decl *> parsePatternDecl(); 270 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 271 272 /// Check to see if a decl has already been defined with the given name, if 273 /// one has emit and error and return failure. Returns success otherwise. 274 LogicalResult checkDefineNamedDecl(const ast::Name &name); 275 276 /// Try to define a variable decl with the given components, returns the 277 /// variable on success. 278 FailureOr<ast::VariableDecl *> 279 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 280 ast::Expr *initExpr, 281 ArrayRef<ast::ConstraintRef> constraints); 282 FailureOr<ast::VariableDecl *> 283 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 284 ArrayRef<ast::ConstraintRef> constraints); 285 286 /// Parse the constraint reference list for a variable decl. 287 LogicalResult parseVariableDeclConstraintList( 288 SmallVectorImpl<ast::ConstraintRef> &constraints); 289 290 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 291 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 292 293 /// Try to parse a single reference to a constraint. `typeConstraint` is the 294 /// location of a previously parsed type constraint for the entity that will 295 /// be constrained by the parsed constraint. `existingConstraints` are any 296 /// existing constraints that have already been parsed for the same entity 297 /// that will be constrained by this constraint. 
`allowInlineTypeConstraints` 298 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 299 FailureOr<ast::ConstraintRef> 300 parseConstraint(std::optional<SMRange> &typeConstraint, 301 ArrayRef<ast::ConstraintRef> existingConstraints, 302 bool allowInlineTypeConstraints); 303 304 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 305 /// argument or result variable. The constraints for these variables do not 306 /// allow inline type constraints, and only permit a single constraint. 307 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 308 309 //===--------------------------------------------------------------------===// 310 // Exprs 311 312 FailureOr<ast::Expr *> parseExpr(); 313 314 /// Identifier expressions. 315 FailureOr<ast::Expr *> parseAttributeExpr(); 316 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr, 317 bool isNegated = false); 318 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 319 FailureOr<ast::Expr *> parseIdentifierExpr(); 320 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 321 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 322 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 323 FailureOr<ast::Expr *> parseNegatedExpr(); 324 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 325 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 326 FailureOr<ast::Expr *> 327 parseOperationExpr(OpResultTypeContext inputResultTypeContext = 328 OpResultTypeContext::Explicit); 329 FailureOr<ast::Expr *> parseTupleExpr(); 330 FailureOr<ast::Expr *> parseTypeExpr(); 331 FailureOr<ast::Expr *> parseUnderscoreExpr(); 332 333 //===--------------------------------------------------------------------===// 334 // Stmts 335 336 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 337 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 338 FailureOr<ast::EraseStmt *> 
parseEraseStmt(); 339 FailureOr<ast::LetStmt *> parseLetStmt(); 340 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 341 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 342 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 343 344 //===--------------------------------------------------------------------===// 345 // Creation+Analysis 346 //===--------------------------------------------------------------------===// 347 348 //===--------------------------------------------------------------------===// 349 // Decls 350 351 /// Try to extract a callable from the given AST node. Returns nullptr on 352 /// failure. 353 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 354 355 /// Try to create a pattern decl with the given components, returning the 356 /// Pattern on success. 357 FailureOr<ast::PatternDecl *> 358 createPatternDecl(SMRange loc, const ast::Name *name, 359 const ParsedPatternMetadata &metadata, 360 ast::CompoundStmt *body); 361 362 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 363 /// of results, defined as part of the signature. 364 ast::Type 365 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 366 367 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 368 template <typename T> 369 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 370 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 371 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 372 ast::CompoundStmt *body); 373 374 /// Try to create a variable decl with the given components, returning the 375 /// Variable on success. 376 FailureOr<ast::VariableDecl *> 377 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 378 ArrayRef<ast::ConstraintRef> constraints); 379 380 /// Create a variable for an argument or result defined as part of the 381 /// signature of a UserConstraintDecl/UserRewriteDecl. 
382 FailureOr<ast::VariableDecl *> 383 createArgOrResultVariableDecl(StringRef name, SMRange loc, 384 const ast::ConstraintRef &constraint); 385 386 /// Validate the constraints used to constraint a variable decl. 387 /// `inferredType` is the type of the variable inferred by the constraints 388 /// within the list, and is updated to the most refined type as determined by 389 /// the constraints. Returns success if the constraint list is valid, failure 390 /// otherwise. 391 LogicalResult 392 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 393 ast::Type &inferredType); 394 /// Validate a single reference to a constraint. `inferredType` contains the 395 /// currently inferred variabled type and is refined within the type defined 396 /// by the constraint. Returns success if the constraint is valid, failure 397 /// otherwise. 398 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 399 ast::Type &inferredType); 400 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 401 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 402 403 //===--------------------------------------------------------------------===// 404 // Exprs 405 406 FailureOr<ast::CallExpr *> 407 createCallExpr(SMRange loc, ast::Expr *parentExpr, 408 MutableArrayRef<ast::Expr *> arguments, 409 bool isNegated = false); 410 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 411 FailureOr<ast::DeclRefExpr *> 412 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 413 ArrayRef<ast::ConstraintRef> constraints); 414 FailureOr<ast::MemberAccessExpr *> 415 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 416 417 /// Validate the member access `name` into the given parent expression. On 418 /// success, this also returns the type of the member accessed. 
419 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 420 StringRef name, SMRange loc); 421 FailureOr<ast::OperationExpr *> 422 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 423 OpResultTypeContext resultTypeContext, 424 SmallVectorImpl<ast::Expr *> &operands, 425 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 426 SmallVectorImpl<ast::Expr *> &results); 427 LogicalResult 428 validateOperationOperands(SMRange loc, std::optional<StringRef> name, 429 const ods::Operation *odsOp, 430 SmallVectorImpl<ast::Expr *> &operands); 431 LogicalResult validateOperationResults(SMRange loc, 432 std::optional<StringRef> name, 433 const ods::Operation *odsOp, 434 SmallVectorImpl<ast::Expr *> &results); 435 void checkOperationResultTypeInferrence(SMRange loc, StringRef name, 436 const ods::Operation *odsOp); 437 LogicalResult validateOperationOperandsOrResults( 438 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc, 439 std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values, 440 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 441 ast::RangeType rangeTy); 442 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 443 ArrayRef<ast::Expr *> elements, 444 ArrayRef<StringRef> elementNames); 445 446 //===--------------------------------------------------------------------===// 447 // Stmts 448 449 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 450 FailureOr<ast::ReplaceStmt *> 451 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 452 MutableArrayRef<ast::Expr *> replValues); 453 FailureOr<ast::RewriteStmt *> 454 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 455 ast::CompoundStmt *rewriteBody); 456 457 //===--------------------------------------------------------------------===// 458 // Code Completion 459 //===--------------------------------------------------------------------===// 460 461 /// The set of various code completion methods. 
Every completion method 462 /// returns `failure` to stop the parsing process after providing completion 463 /// results. 464 465 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr); 466 LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName); 467 LogicalResult codeCompleteConstraintName(ast::Type inferredType, 468 bool allowInlineTypeConstraints); 469 LogicalResult codeCompleteDialectName(); 470 LogicalResult codeCompleteOperationName(StringRef dialectName); 471 LogicalResult codeCompletePatternMetadata(); 472 LogicalResult codeCompleteIncludeFilename(StringRef curPath); 473 474 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); 475 void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName, 476 unsigned currentNumOperands); 477 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName, 478 unsigned currentNumResults); 479 480 //===--------------------------------------------------------------------===// 481 // Lexer Utilities 482 //===--------------------------------------------------------------------===// 483 484 /// If the current token has the specified kind, consume it and return true. 485 /// If not, return false. 486 bool consumeIf(Token::Kind kind) { 487 if (curToken.isNot(kind)) 488 return false; 489 consumeToken(kind); 490 return true; 491 } 492 493 /// Advance the current lexer onto the next token. 494 void consumeToken() { 495 assert(curToken.isNot(Token::eof, Token::error) && 496 "shouldn't advance past EOF or errors"); 497 curToken = lexer.lexToken(); 498 } 499 500 /// Advance the current lexer onto the next token, asserting what the expected 501 /// current token is. This is preferred to the above method because it leads 502 /// to more self-documenting code with better checking. 
503 void consumeToken(Token::Kind kind) { 504 assert(curToken.is(kind) && "consumed an unexpected token"); 505 consumeToken(); 506 } 507 508 /// Reset the lexer to the location at the given position. 509 void resetToken(SMRange tokLoc) { 510 lexer.resetPointer(tokLoc.Start.getPointer()); 511 curToken = lexer.lexToken(); 512 } 513 514 /// Consume the specified token if present and return success. On failure, 515 /// output a diagnostic and return failure. 516 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 517 if (curToken.getKind() != kind) 518 return emitError(curToken.getLoc(), msg); 519 consumeToken(); 520 return success(); 521 } 522 LogicalResult emitError(SMRange loc, const Twine &msg) { 523 lexer.emitError(loc, msg); 524 return failure(); 525 } 526 LogicalResult emitError(const Twine &msg) { 527 return emitError(curToken.getLoc(), msg); 528 } 529 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 530 const Twine ¬e) { 531 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 532 return failure(); 533 } 534 535 //===--------------------------------------------------------------------===// 536 // Fields 537 //===--------------------------------------------------------------------===// 538 539 /// The owning AST context. 540 ast::Context &ctx; 541 542 /// The lexer of this parser. 543 Lexer lexer; 544 545 /// The current token within the lexer. 546 Token curToken; 547 548 /// A flag indicating if the parser should add documentation to AST nodes when 549 /// viable. 550 bool enableDocumentation; 551 552 /// The most recently defined decl scope. 553 ast::DeclScope *curDeclScope = nullptr; 554 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 555 556 /// The current context of the parser. 557 ParserContext parserContext = ParserContext::Global; 558 559 /// Cached types to simplify verification and expression creation. 
560 ast::Type typeTy, valueTy; 561 ast::RangeType typeRangeTy, valueRangeTy; 562 ast::Type attrTy; 563 564 /// A counter used when naming anonymous constraints and rewrites. 565 unsigned anonymousDeclNameCounter = 0; 566 567 /// The optional code completion context. 568 CodeCompleteContext *codeCompleteContext; 569 }; 570 } // namespace 571 572 FailureOr<ast::Module *> Parser::parseModule() { 573 SMLoc moduleLoc = curToken.getStartLoc(); 574 pushDeclScope(); 575 576 // Parse the top-level decls of the module. 577 SmallVector<ast::Decl *> decls; 578 if (failed(parseModuleBody(decls))) 579 return popDeclScope(), failure(); 580 581 popDeclScope(); 582 return ast::Module::create(ctx, moduleLoc, decls); 583 } 584 585 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { 586 while (curToken.isNot(Token::eof)) { 587 if (curToken.is(Token::directive)) { 588 if (failed(parseDirective(decls))) 589 return failure(); 590 continue; 591 } 592 593 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 594 if (failed(decl)) 595 return failure(); 596 decls.push_back(*decl); 597 } 598 return success(); 599 } 600 601 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 602 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 603 valueRangeTy); 604 } 605 606 LogicalResult Parser::convertExpressionTo( 607 ast::Expr *&expr, ast::Type type, 608 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 609 ast::Type exprType = expr->getType(); 610 if (exprType == type) 611 return success(); 612 613 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 614 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 615 expr->getLoc(), llvm::formatv("unable to convert expression of type " 616 "`{0}` to the expected type of " 617 "`{1}`", 618 exprType, type)); 619 if (noteAttachFn) 620 noteAttachFn(*diag); 621 return diag; 622 }; 623 624 if (auto exprOpType = dyn_cast<ast::OperationType>(exprType)) 625 return 
convertOpExpressionTo(expr, exprOpType, type, emitConvertError); 626 627 // FIXME: Decide how to allow/support converting a single result to multiple, 628 // and multiple to a single result. For now, we just allow Single->Range, 629 // but this isn't something really supported in the PDL dialect. We should 630 // figure out some way to support both. 631 if ((exprType == valueTy || exprType == valueRangeTy) && 632 (type == valueTy || type == valueRangeTy)) 633 return success(); 634 if ((exprType == typeTy || exprType == typeRangeTy) && 635 (type == typeTy || type == typeRangeTy)) 636 return success(); 637 638 // Handle tuple types. 639 if (auto exprTupleType = dyn_cast<ast::TupleType>(exprType)) 640 return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError, 641 noteAttachFn); 642 643 return emitConvertError(); 644 } 645 646 LogicalResult Parser::convertOpExpressionTo( 647 ast::Expr *&expr, ast::OperationType exprType, ast::Type type, 648 function_ref<ast::InFlightDiagnostic()> emitErrorFn) { 649 // Two operation types are compatible if they have the same name, or if the 650 // expected type is more general. 651 if (auto opType = dyn_cast<ast::OperationType>(type)) { 652 if (opType.getName()) 653 return emitErrorFn(); 654 return success(); 655 } 656 657 // An operation can always convert to a ValueRange. 658 if (type == valueRangeTy) { 659 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 660 valueRangeTy); 661 return success(); 662 } 663 664 // Allow conversion to a single value by constraining the result range. 665 if (type == valueTy) { 666 // If the operation is registered, we can verify if it can ever have a 667 // single result. 
668 if (const ods::Operation *odsOp = exprType.getODSOperation()) { 669 if (odsOp->getResults().empty()) { 670 return emitErrorFn()->attachNote( 671 llvm::formatv("see the definition of `{0}`, which was defined " 672 "with zero results", 673 odsOp->getName()), 674 odsOp->getLoc()); 675 } 676 677 unsigned numSingleResults = llvm::count_if( 678 odsOp->getResults(), [](const ods::OperandOrResult &result) { 679 return result.getVariableLengthKind() == 680 ods::VariableLengthKind::Single; 681 }); 682 if (numSingleResults > 1) { 683 return emitErrorFn()->attachNote( 684 llvm::formatv("see the definition of `{0}`, which was defined " 685 "with at least {1} results", 686 odsOp->getName(), numSingleResults), 687 odsOp->getLoc()); 688 } 689 } 690 691 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 692 valueTy); 693 return success(); 694 } 695 return emitErrorFn(); 696 } 697 698 LogicalResult Parser::convertTupleExpressionTo( 699 ast::Expr *&expr, ast::TupleType exprType, ast::Type type, 700 function_ref<ast::InFlightDiagnostic()> emitErrorFn, 701 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 702 // Handle conversions between tuples. 703 if (auto tupleType = dyn_cast<ast::TupleType>(type)) { 704 if (tupleType.size() != exprType.size()) 705 return emitErrorFn(); 706 707 // Build a new tuple expression using each of the elements of the current 708 // tuple. 
709 SmallVector<ast::Expr *> newExprs; 710 for (unsigned i = 0, e = exprType.size(); i < e; ++i) { 711 newExprs.push_back(ast::MemberAccessExpr::create( 712 ctx, expr->getLoc(), expr, llvm::to_string(i), 713 exprType.getElementTypes()[i])); 714 715 auto diagFn = [&](ast::Diagnostic &diag) { 716 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 717 i, exprType)); 718 if (noteAttachFn) 719 noteAttachFn(diag); 720 }; 721 if (failed(convertExpressionTo(newExprs.back(), 722 tupleType.getElementTypes()[i], diagFn))) 723 return failure(); 724 } 725 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 726 tupleType.getElementNames()); 727 return success(); 728 } 729 730 // Handle conversion to a range. 731 auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes, 732 ast::RangeType resultTy) -> LogicalResult { 733 // TODO: We currently only allow range conversion within a rewrite context. 734 if (parserContext != ParserContext::Rewrite) { 735 return emitErrorFn()->attachNote("Tuple to Range conversion is currently " 736 "only allowed within a rewrite context"); 737 } 738 739 // All of the tuple elements must be allowed types. 740 for (ast::Type elementType : exprType.getElementTypes()) 741 if (!llvm::is_contained(allowedElementTypes, elementType)) 742 return emitErrorFn(); 743 744 // Build a new tuple expression using each of the elements of the current 745 // tuple. 
// Build one MemberAccessExpr per tuple element (named by its index) and wrap
// them into a RangeExpr of type `resultTy`.
    SmallVector<ast::Expr *> newExprs;
    for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
      newExprs.push_back(ast::MemberAccessExpr::create(
          ctx, expr->getLoc(), expr, llvm::to_string(i),
          exprType.getElementTypes()[i]));
    }
    expr = ast::RangeExpr::create(ctx, expr->getLoc(), newExprs, resultTy);
    return success();
  };
  // Only ValueRange and TypeRange targets are handled here; any other target
  // type is reported through `emitErrorFn`.
  if (type == valueRangeTy)
    return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
  if (type == typeRangeTy)
    return convertToRange({typeTy, typeRangeTy}, typeRangeTy);

  return emitErrorFn();
}

//===----------------------------------------------------------------------===//
// Directives

// Parse the directive at the current token. Only `#include` is recognized;
// any other directive spelling is reported as an error.
LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
  StringRef directive = curToken.getSpelling();
  if (directive == "#include")
    return parseInclude(decls);

  return emitError("unknown directive `" + directive + "`");
}

// Parse `#include "<file>"`.
//  - A `.pdll` file is lexed and parsed in place, appending its decls to
//    `decls`.
//  - A `.td` file is delegated to parseTdInclude.
//  - Any other extension is an error.
// A code-completion token in the file-name position triggers include-path
// completion instead.
LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::directive);

  // Handle code completion of the include file path.
  if (curToken.is(Token::code_complete_string))
    return codeCompleteIncludeFilename(curToken.getStringValue());

  // Parse the file being included.
  if (!curToken.isString())
    return emitError(loc,
                     "expected string file name after `include` directive");
  SMRange fileLoc = curToken.getLoc();
  // Keep an owning copy of the name: `filename` is a view into it and is used
  // after the token has been consumed.
  std::string filenameStr = curToken.getStringValue();
  StringRef filename = filenameStr;
  consumeToken();

  // Check the type of include. If ending with `.pdll`, this is another pdl file
  // to be parsed along with the current module.
  if (filename.ends_with(".pdll")) {
    if (failed(lexer.pushInclude(filename, fileLoc)))
      return emitError(fileLoc,
                       "unable to open include file `" + filename + "`");

    // If we added the include successfully, parse it into the current module.
    // Make sure to update to the next token after we finish parsing the nested
    // file.
    // NOTE(review): this assumes pushInclude switches the lexer to the included
    // buffer, so the first lexToken() reads from it and the second resumes the
    // including file once parseModuleBody returns -- confirm against Lexer.
    curToken = lexer.lexToken();
    LogicalResult result = parseModuleBody(decls);
    curToken = lexer.lexToken();
    return result;
  }

  // Otherwise, this must be a `.td` include.
  if (filename.ends_with(".td"))
    return parseTdInclude(filename, fileLoc, decls);

  return emitError(fileLoc,
                   "expected include filename to end with `.pdll` or `.td`");
}

// Parse a TableGen `.td` include with a dedicated llvm::SourceMgr, then
// register the records it defines via processTdIncludeRecords.
// Returns failure if:
//  - the file cannot be opened, or
//  - TableGenParseFile reports an error. In that case the diagnostic handler
//    installed below has already emitted the error at `fileLoc`.
LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
                                     SmallVectorImpl<ast::Decl *> &decls) {
  llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();

  // Use the source manager to open the file, but don't yet add it.
  // `includedFile` is an out-parameter of OpenIncludeFile and is not used below.
  std::string includedFile;
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
      parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
  if (!includeBuffer)
    return emitError(fileLoc, "unable to open include file `" + filename + "`");

  // Setup the source manager for parsing the tablegen file. It reuses the
  // parser's include directories, presumably so that nested `include`s inside
  // the .td file resolve the same way.
  llvm::SourceMgr tdSrcMgr;
  tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
  tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());

  // This class provides a context argument for the llvm::SourceMgr diagnostic
  // handler. It lives in this stack frame, which outlives the
  // TableGenParseFile call below.
  struct DiagHandlerContext {
    Parser &parser;
    StringRef filename;
    llvm::SMRange loc;
  } handlerContext{*this, filename, fileLoc};

  // Set the diagnostic handler for the tablegen source manager. Every TableGen
  // diagnostic is re-emitted as a PDLL error anchored at the include location.
  tdSrcMgr.setDiagHandler(
      [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
        auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
        (void)ctx->parser.emitError(
            ctx->loc,
            llvm::formatv("error while processing include file `{0}`: {1}",
                          ctx->filename, diag.getMessage()));
      },
      &handlerContext);

  // Parse the tablegen file. TableGenParseFile returns true on failure.
  llvm::RecordKeeper tdRecords;
  if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
    return failure();

  // Process the parsed records.
  processTdIncludeRecords(tdRecords, decls);

  // After we are done processing, move all of the tablegen source buffers to
  // the main parser source mgr. This allows for directly using source locations
  // from the .td files without needing to remap them.
  parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
  return success();
}

// Translate the records of a parsed `.td` file:
//  - Every `Op` is registered in the ODS context, together with its
//    attributes, operands and results. Ops do not produce AST decls.
//  - Every `Attr`, `Type` and `OpInterface` definition becomes a native PDLL
//    constraint decl, which is added to the current scope and appended to
//    `decls`.
void Parser::processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
                                     SmallVectorImpl<ast::Decl *> &decls) {
  // Return the length kind of the given value.
  auto getLengthKind = [](const auto &value) {
    if (value.isOptional())
      return ods::VariableLengthKind::Optional;
    return value.isVariadic() ? ods::VariableLengthKind::Variadic
                              : ods::VariableLengthKind::Single;
  };

  // Insert a type constraint into the ODS context.
  ods::Context &odsContext = ctx.getODSContext();
  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
      -> const ods::TypeConstraint & {
    return odsContext.insertTypeConstraint(
        cst.constraint.getUniqueDefName(),
        processDoc(cst.constraint.getSummary()), cst.constraint.getCppType());
  };
  // Widen a point location into a one-character range.
  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
    return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
  };

  // Process the parsed tablegen records to build ODS information.
  /// Operations.
  for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
    tblgen::Operator op(def);

    // Check to see if this operation is known to support type inference.
    bool supportsResultTypeInferrence =
        op.getTrait("::mlir::InferTypeOpInterface::Trait");

    auto [odsOp, inserted] = odsContext.insertOperation(
        op.getOperationName(), processDoc(op.getSummary()),
        processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
        supportsResultTypeInferrence, op.getLoc().front());

    // Ignore operations that have already been added.
    if (!inserted)
      continue;

    for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
      odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
                             odsContext.insertAttributeConstraint(
                                 attr.attr.getUniqueDefName(),
                                 processDoc(attr.attr.getSummary()),
                                 attr.attr.getStorageType()));
    }
    for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
      odsOp->appendOperand(operand.name, getLengthKind(operand),
                           addTypeConstraint(operand));
    }
    for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
      odsOp->appendResult(result.name, getLengthKind(result),
                          addTypeConstraint(result));
    }
  }

  // A record is skipped if any of the following holds:
  //  - it is anonymous;
  //  - its name is already bound in the current scope (constraint decls
  //    created below are added to that scope, so repeated definitions are
  //    skipped);
  //  - it derives from `DeclareInterfaceMethods`.
  auto shouldBeSkipped = [this](const llvm::Record *def) {
    return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
           def->isSubClassOf("DeclareInterfaceMethods");
  };

  /// Attr constraints.
  for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
    if (shouldBeSkipped(def))
      continue;

    tblgen::Attribute constraint(def);
    decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
        constraint, convertLocToRange(def->getLoc().front()), attrTy,
        constraint.getStorageType()));
  }
  /// Type constraints.
  for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
    if (shouldBeSkipped(def))
      continue;

    tblgen::TypeConstraint constraint(def);
    decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
        constraint, convertLocToRange(def->getLoc().front()), typeTy,
        constraint.getCppType()));
  }
  /// OpInterfaces.
  // Each interface becomes an Op constraint whose native body checks
  // `llvm::isa<cppNamespace::cppInterfaceName>(self)`.
  ast::Type opTy = ast::OperationType::get(ctx);
  for (const llvm::Record *def :
       tdRecords.getAllDerivedDefinitions("OpInterface")) {
    if (shouldBeSkipped(def))
      continue;

    SMRange loc = convertLocToRange(def->getLoc().front());

    std::string cppClassName =
        llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
                      def->getValueAsString("cppInterfaceName"))
            .str();
    std::string codeBlock =
        llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
                      cppClassName)
            .str();

    std::string desc =
        processAndFormatDoc(def->getValueAsString("description"));
    decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
        def->getName(), codeBlock, loc, opTy, cppClassName, desc));
  }
}

// Build a native UserConstraintDecl named `name`. The decl:
//  - takes a single argument `self` of `type`, constrained by a fresh
//    ConstraintT;
//  - has no results and uses the C++ `codeBlock` as its body;
//  - carries `nativeType` and `docString` as its native type and doc comment.
// The decl is added to the current scope and returned.
template <typename ConstraintT>
ast::Decl *Parser::createODSNativePDLLConstraintDecl(
    StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
    StringRef nativeType, StringRef docString) {
  // Build the single input parameter. Its scope is pushed and immediately
  // popped; it only exists to hold `self`.
  ast::DeclScope *argScope = pushDeclScope();
  auto *paramVar = ast::VariableDecl::create(
      ctx, ast::Name::create(ctx, "self", loc), type,
      /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
  argScope->add(paramVar);
  popDeclScope();

  // Build the native constraint.
  auto *constraintDecl = ast::UserConstraintDecl::createNative(
      ctx, ast::Name::create(ctx, name, loc), paramVar,
      /*results=*/std::nullopt, codeBlock, ast::TupleType::get(ctx),
      nativeType);
  constraintDecl->setDocComment(ctx, docString);
  curDeclScope->add(constraintDecl);
  return constraintDecl;
}

// Overload for ODS constraints.
//  - The code block is `return ::mlir::success(<condition template>);`, with
//    the template formatted using `self` as the self substitution.
//  - The name is the constraint's unique def name.
//  - The doc string is the summary plus, if present, the description. It is
//    only built when documentation is enabled.
template <typename ConstraintT>
ast::Decl *
Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
                                          SMRange loc, ast::Type type,
                                          StringRef nativeType) {
  // Format the condition template.
  tblgen::FmtContext fmtContext;
  fmtContext.withSelf("self");
  std::string codeBlock = tblgen::tgfmt(
      "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
      &fmtContext);

  // If documentation was enabled, build the doc string for the generated
  // constraint. It would be nice to do this lazily, but TableGen information is
  // destroyed after we finish parsing the file.
  std::string docString;
  if (enableDocumentation) {
    StringRef desc = constraint.getDescription();
    docString = processAndFormatDoc(
        constraint.getSummary() +
        (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
  }

  return createODSNativePDLLConstraintDecl<ConstraintT>(
      constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
      docString);
}

//===----------------------------------------------------------------------===//
// Decls

// Parse a top-level `Constraint`, `Pattern`, or `Rewrite` decl. A named decl is
// checked for redefinition and then added to the current scope.
FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
  FailureOr<ast::Decl *> decl;
  switch (curToken.getKind()) {
  case Token::kw_Constraint:
    decl = parseUserConstraintDecl();
    break;
  case Token::kw_Pattern:
    decl = parsePatternDecl();
    break;
  case Token::kw_Rewrite:
    decl = parseUserRewriteDecl();
    break;
  default:
    return emitError("expected top-level declaration, such as a `Pattern`");
  }
  if (failed(decl))
    return failure();

  // If the decl has a name, add it to the current scope.
  if (const ast::Name *name = (*decl)->getName()) {
    if (failed(checkDefineNamedDecl(*name)))
      return failure();
    curDeclScope->add(*decl);
  }
  return decl;
}

// Parse `name [= expr]` inside an operation's attribute list.
//  - The name may be a string, an identifier, or a keyword.
//  - Without `= expr`, the value is a `unit` AttributeExpr.
//  - `parentOpName` is only used for name code completion.
FailureOr<ast::NamedAttributeDecl *>
Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
  // Check for name code completion.
  if (curToken.is(Token::code_complete))
    return codeCompleteAttributeName(parentOpName);

  std::string attrNameStr;
  if (curToken.isString())
    attrNameStr = curToken.getStringValue();
  else if (curToken.is(Token::identifier) || curToken.isKeyword())
    attrNameStr = curToken.getSpelling().str();
  else
    return emitError("expected identifier or string attribute name");
  const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
  consumeToken();

  // Check for a value of the attribute.
  ast::Expr *attrValue = nullptr;
  if (consumeIf(Token::equal)) {
    FailureOr<ast::Expr *> attrExpr = parseExpr();
    if (failed(attrExpr))
      return failure();
    attrValue = *attrExpr;
  } else {
    // If there isn't a concrete value, create an expression representing a
    // UnitAttr.
    attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
  }

  return ast::NamedAttributeDecl::create(ctx, name, attrValue);
}

// Parse `=> <stmt>` as a single-statement body in its own scope.
// `processStatementFn` may validate the statement or replace it in place
// before it is wrapped into a CompoundStmt spanning the statement's tokens.
FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
    function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
    bool expectTerminalSemicolon) {
  consumeToken(Token::equal_arrow);

  // Parse the single statement of the lambda body.
  SMLoc bodyStartLoc = curToken.getStartLoc();
  pushDeclScope();
  FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
  // Pop the scope before bailing out so failure and success leave the scope
  // stack in the same state.
  bool failedToParse =
      failed(singleStatement) || failed(processStatementFn(*singleStatement));
  popDeclScope();
  if (failedToParse)
    return failure();

  SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
  return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
}

// Parse a `name : constraint` argument of a Constraint or Rewrite signature.
FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
  // Ensure that the argument is named.
  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
    return emitError("expected identifier argument name");

  // Parse the argument similarly to a normal variable.
  StringRef name = curToken.getSpelling();
  SMRange nameLoc = curToken.getLoc();
  consumeToken();

  if (failed(
          parseToken(Token::colon, "expected `:` before argument constraint")))
    return failure();

  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
  if (failed(cst))
    return failure();

  return createArgOrResultVariableDecl(name, nameLoc, *cst);
}

// Parse a result: either `name : constraint` or a bare constraint.
// - An identifier that resolves to a ConstraintDecl is treated as the
//   constraint of an unnamed result.
// - An unnamed result gets an empty name.
// NOTE(review): `resultNum` is not used by this definition.
FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
  // Check to see if this result is named.
  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
    // Check to see if this name actually refers to a Constraint.
    if (!curDeclScope->lookup<ast::ConstraintDecl>(curToken.getSpelling())) {
      // If it wasn't a constraint, parse the result similarly to a variable. If
      // there is already an existing decl, we will emit an error when defining
      // this variable later.
      StringRef name = curToken.getSpelling();
      SMRange nameLoc = curToken.getLoc();
      consumeToken();

      if (failed(parseToken(Token::colon,
                            "expected `:` before result constraint")))
        return failure();

      FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
      if (failed(cst))
        return failure();

      return createArgOrResultVariableDecl(name, nameLoc, *cst);
    }
  }

  // If it isn't named, we parse the constraint directly and create an unnamed
  // result variable.
  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
  if (failed(cst))
    return failure();

  return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
}

// Parse a `Constraint` decl, in the Constraint parser context.
FailureOr<ast::UserConstraintDecl *>
Parser::parseUserConstraintDecl(bool isInline) {
  // Constraints and rewrites have very similar formats, dispatch to a shared
  // interface for parsing.
  return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
      [&](auto &&...args) {
        return this->parseUserPDLLConstraintDecl(args...);
      },
      ParserContext::Constraint, "constraint", isInline);
}

// Parse an inline (possibly anonymous) constraint, check it for redefinition,
// and add it to the current scope.
FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
  FailureOr<ast::UserConstraintDecl *> decl =
      parseUserConstraintDecl(/*isInline=*/true);
  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
    return failure();

  curDeclScope->add(*decl);
  return decl;
}

// Parse the PDLL body of a constraint whose signature was already parsed.
// - Lambda body (`=> <expr>`): must be a single expression, which is turned
//   into a ReturnStmt.
// - Compound body (`{ ... }`): validated so that no statement follows the
//   first `return`, and so that a return exists when results are declared.
FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
    const ast::Name &name, bool isInline,
    ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
  // Push the argument scope back onto the list, so that the body can
  // reference arguments.
  pushDeclScope(argumentScope);

  // Parse the body of the constraint. The body is either defined as a compound
  // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
  ast::CompoundStmt *body;
  if (curToken.is(Token::equal_arrow)) {
    FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
        [&](ast::Stmt *&stmt) -> LogicalResult {
          ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
          if (!stmtExpr) {
            return emitError(stmt->getLoc(),
                             "expected `Constraint` lambda body to contain a "
                             "single expression");
          }
          stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
          return success();
        },
        /*expectTerminalSemicolon=*/!isInline);
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;
  } else {
    FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;

    // Verify the structure of the body.
    auto bodyIt = body->begin(), bodyE = body->end();
    for (; bodyIt != bodyE; ++bodyIt)
      if (isa<ast::ReturnStmt>(*bodyIt))
        break;
    if (failed(validateUserConstraintOrRewriteReturn(
            "Constraint", body, bodyIt, bodyE, results, resultType)))
      return failure();
  }
  popDeclScope();

  return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
      name, arguments, results, resultType, body);
}

// Parse a `Rewrite` decl, in the Rewrite parser context.
FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
  // Constraints and rewrites have very similar formats, dispatch to a shared
  // interface for parsing.
  return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
      [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
      ParserContext::Rewrite, "rewrite", isInline);
}

// Parse an inline (possibly anonymous) rewrite, check it for redefinition, and
// add it to the current scope.
FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
  FailureOr<ast::UserRewriteDecl *> decl =
      parseUserRewriteDecl(/*isInline=*/true);
  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
    return failure();

  curDeclScope->add(*decl);
  return decl;
}

// Parse the PDLL body of a rewrite whose signature was already parsed.
// - Lambda body: may be an operation rewrite statement (kept as is) or a
//   single expression (turned into a ReturnStmt).
// - Either body form is then checked for trailing statements after a
//   `return`, and for a missing return when results are declared.
FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
    const ast::Name &name, bool isInline,
    ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
  // Push the argument scope back onto the list, so that the body can
  // reference arguments.
  // NOTE(review): this assigns curDeclScope directly, whereas
  // parseUserPDLLConstraintDecl calls pushDeclScope(argumentScope) for the same
  // purpose; confirm the two are equivalent.
  curDeclScope = argumentScope;
  ast::CompoundStmt *body;
  if (curToken.is(Token::equal_arrow)) {
    FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
        [&](ast::Stmt *&statement) -> LogicalResult {
          if (isa<ast::OpRewriteStmt>(statement))
            return success();

          ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
          if (!statementExpr) {
            return emitError(
                statement->getLoc(),
                "expected `Rewrite` lambda body to contain a single expression "
                "or an operation rewrite statement; such as `erase`, "
                "`replace`, or `rewrite`");
          }
          statement =
              ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
          return success();
        },
        /*expectTerminalSemicolon=*/!isInline);
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;
  } else {
    FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;
  }
  popDeclScope();

  // Verify the structure of the body.
  auto bodyIt = body->begin(), bodyE = body->end();
  for (; bodyIt != bodyE; ++bodyIt)
    if (isa<ast::ReturnStmt>(*bodyIt))
      break;
  if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
                                                   bodyE, results, resultType)))
    return failure();
  return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
      name, arguments, results, resultType, body);
}

// Shared parser for `Constraint` and `Rewrite` decls. The parser context is
// switched to `declContext` for the duration of the call. Steps:
//  1. Consume the leading keyword.
//  2. Parse the name. The name may be omitted only when `isInline`; a unique
//     `<anonymous_<prefix>_<N>>` name is then generated.
//  3. Parse the signature.
//  4. Either dispatch to `parseUserPDLLFn` (body starts with `{` or `=>`), or
//     parse a native declaration.
template <typename T, typename ParseUserPDLLDeclFnT>
FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
    ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
    StringRef anonymousNamePrefix, bool isInline) {
  SMRange loc = curToken.getLoc();
  consumeToken();
  llvm::SaveAndRestore saveCtx(parserContext, declContext);

  // Parse the name of the decl.
  const ast::Name *name = nullptr;
  if (curToken.isNot(Token::identifier)) {
    // Only inline decls can be un-named. Inline decls are similar to "lambdas"
    // in C++, so being unnamed is fine.
    if (!isInline)
      return emitError("expected identifier name");

    // Create a unique anonymous name to use, as the name for this decl is not
    // important.
    std::string anonName =
        llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
                      anonymousDeclNameCounter++)
            .str();
    name = &ast::Name::create(ctx, anonName, loc);
  } else {
    // If a name was provided, we can use it directly.
    name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
    consumeToken(Token::identifier);
  }

  // Parse the functional signature of the decl.
  SmallVector<ast::VariableDecl *> arguments, results;
  ast::DeclScope *argumentScope;
  ast::Type resultType;
  if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
                                                   argumentScope, resultType)))
    return failure();

  // Check to see which type of constraint this is. If the constraint contains a
  // compound body, this is a PDLL decl.
  if (curToken.isAny(Token::l_brace, Token::equal_arrow))
    return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
                           resultType);

  // Otherwise, this is a native decl.
  return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
                                                   results, resultType);
}

// Parse the remainder of a native decl: an optional C++ code string followed by
// `;`. An inline native decl must provide the code string.
template <typename T>
FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
    const ast::Name &name, bool isInline,
    ArrayRef<ast::VariableDecl *> arguments,
    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
  // If followed by a string, the native code body has also been specified.
  // `codeStrStorage` owns the text that `optCodeStr` views.
  std::string codeStrStorage;
  std::optional<StringRef> optCodeStr;
  if (curToken.isString()) {
    codeStrStorage = curToken.getStringValue();
    optCodeStr = codeStrStorage;
    consumeToken();
  } else if (isInline) {
    return emitError(name.getLoc(),
                     "external declarations must be declared in global scope");
  } else if (curToken.is(Token::error)) {
    // Presumably the lexer has already reported this error token.
    return failure();
  }
  if (failed(parseToken(Token::semicolon,
                        "expected `;` after native declaration")))
    return failure();
  return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
}

// Parse `(args) [-> result | -> (results)]`.
// - Arguments are declared in a new scope, returned through `argumentScope`
//   so the body parser can re-enter it.
// - Results are parsed in a separate scope that is popped afterwards.
// - `resultType` is computed from the results.
// - A single result is rejected if it carries a name.
LogicalResult Parser::parseUserConstraintOrRewriteSignature(
    SmallVectorImpl<ast::VariableDecl *> &arguments,
    SmallVectorImpl<ast::VariableDecl *> &results,
    ast::DeclScope *&argumentScope, ast::Type &resultType) {
  // Parse the argument list of the decl.
  if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
    return failure();

  argumentScope = pushDeclScope();
  if (curToken.isNot(Token::r_paren)) {
    do {
      FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
      if (failed(argument))
        return failure();
      arguments.emplace_back(*argument);
    } while (consumeIf(Token::comma));
  }
  popDeclScope();
  if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
    return failure();

  // Parse the results of the decl.
  pushDeclScope();
  if (consumeIf(Token::arrow)) {
    auto parseResultFn = [&]() -> LogicalResult {
      FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
      if (failed(result))
        return failure();
      results.emplace_back(*result);
      return success();
    };

    // Check for a list of results.
    if (consumeIf(Token::l_paren)) {
      do {
        if (failed(parseResultFn()))
          return failure();
      } while (consumeIf(Token::comma));
      if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
        return failure();

      // Otherwise, there is only one result.
    } else if (failed(parseResultFn())) {
      return failure();
    }
  }
  popDeclScope();

  // Compute the result type of the decl.
  resultType = createUserConstraintRewriteResultType(results);

  // Verify that results are only named if there are more than one.
  if (results.size() == 1 && !results.front()->getName().getName().empty()) {
    return emitError(
        results.front()->getLoc(),
        "cannot create a single-element tuple with an element label");
  }
  return success();
}

// Check the `return` placement of a Constraint/Rewrite body.
// `bodyIt` points at the first ReturnStmt in `body`, or equals `bodyE` if there
// is none. Two cases are errors:
//  - statements that follow the return;
//  - a missing return when `results` is non-empty. This is reported at the
//    end of the body.
LogicalResult Parser::validateUserConstraintOrRewriteReturn(
    StringRef declType, ast::CompoundStmt *body,
    ArrayRef<ast::Stmt *>::iterator bodyIt,
    ArrayRef<ast::Stmt *>::iterator bodyE,
    ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
  // Handle if a `return` was provided.
  if (bodyIt != bodyE) {
    // Emit an error if we have trailing statements after the return.
    if (std::next(bodyIt) != bodyE) {
      return emitError(
          (*std::next(bodyIt))->getLoc(),
          llvm::formatv("`return` terminated the `{0}` body, but found "
                        "trailing statements afterwards",
                        declType));
    }

    // Otherwise if a return wasn't provided, check that no results are
    // expected.
  } else if (!results.empty()) {
    return emitError(
        {body->getLoc().End, body->getLoc().End},
        llvm::formatv("missing return in a `{0}` expected to return `{1}`",
                      declType, resultType));
  }
  return success();
}

// Parse `=> <stmt>` for a Pattern; the statement must be an operation rewrite.
FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
  return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
    if (isa<ast::OpRewriteStmt>(statement))
      return success();
    return emitError(
        statement->getLoc(),
        "expected Pattern lambda body to contain a single operation "
        "rewrite statement, such as `erase`, `replace`, or `rewrite`");
  });
}

// Parse `Pattern [name] [with <metadata>] (=> <stmt> | { ... })`, in the
// PatternMatch parser context. A compound body must:
//  - contain no `return` before its rewrite statement;
//  - end with exactly one operation rewrite statement.
FailureOr<ast::Decl *> Parser::parsePatternDecl() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_Pattern);
  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);

  // Check for an optional identifier for the pattern name.
  const ast::Name *name = nullptr;
  if (curToken.is(Token::identifier)) {
    name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
    consumeToken(Token::identifier);
  }

  // Parse any pattern metadata.
  ParsedPatternMetadata metadata;
  if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
    return failure();

  // Parse the pattern body.
  ast::CompoundStmt *body;

  // Handle a lambda body.
  if (curToken.is(Token::equal_arrow)) {
    FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;
  } else {
    if (curToken.isNot(Token::l_brace))
      return emitError("expected `{` or `=>` to start pattern body");
    FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;

    // Verify the body of the pattern.
    auto bodyIt = body->begin(), bodyE = body->end();
    for (; bodyIt != bodyE; ++bodyIt) {
      if (isa<ast::ReturnStmt>(*bodyIt)) {
        return emitError((*bodyIt)->getLoc(),
                         "`return` statements are only permitted within a "
                         "`Constraint` or `Rewrite` body");
      }
      // Break when we've found the rewrite statement.
      if (isa<ast::OpRewriteStmt>(*bodyIt))
        break;
    }
    if (bodyIt == bodyE) {
      return emitError(loc,
                       "expected Pattern body to terminate with an operation "
                       "rewrite statement, such as `erase`");
    }
    if (std::next(bodyIt) != bodyE) {
      return emitError((*std::next(bodyIt))->getLoc(),
                       "Pattern body was terminated by an operation "
                       "rewrite statement, but found trailing statements");
    }
  }

  return createPatternDecl(loc, name, metadata, body);
}

// Parse comma-separated pattern metadata into `metadata`. Recognized entries:
//  - `benefit(<integer>)`: must fit in 16 bits;
//  - `recursion`.
// Each entry may appear at most once; a repeat is reported with a note at the
// first occurrence.
LogicalResult
Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
  std::optional<SMRange> benefitLoc;
  std::optional<SMRange> hasBoundedRecursionLoc;

  do {
    // Handle metadata code completion.
    if (curToken.is(Token::code_complete))
      return codeCompletePatternMetadata();

    if (curToken.isNot(Token::identifier))
      return emitError("expected pattern metadata identifier");
    StringRef metadataStr = curToken.getSpelling();
    SMRange metadataLoc = curToken.getLoc();
    consumeToken(Token::identifier);

    // Parse the benefit metadata: benefit(<integer-value>)
    if (metadataStr == "benefit") {
      if (benefitLoc) {
        return emitErrorAndNote(metadataLoc,
                                "pattern benefit has already been specified",
                                *benefitLoc, "see previous definition here");
      }
      if (failed(parseToken(Token::l_paren,
                            "expected `(` before pattern benefit")))
        return failure();

      uint16_t benefitValue = 0;
      if (curToken.isNot(Token::integer))
        return emitError("expected integral pattern benefit");
      // getAsInteger returns true when the spelling does not fit in uint16_t.
      if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
        return emitError(
            "expected pattern benefit to fit within a 16-bit integer");
      consumeToken(Token::integer);

      metadata.benefit = benefitValue;
      benefitLoc = metadataLoc;

      if (failed(
              parseToken(Token::r_paren, "expected `)` after pattern benefit")))
        return failure();
      continue;
    }

    // Parse the bounded recursion metadata: recursion
    if (metadataStr == "recursion") {
      if (hasBoundedRecursionLoc) {
        return emitErrorAndNote(
            metadataLoc,
            "pattern recursion metadata has already been specified",
            *hasBoundedRecursionLoc, "see previous definition here");
      }
      metadata.hasBoundedRecursion = true;
      hasBoundedRecursionLoc = metadataLoc;
      continue;
    }

    return emitError(metadataLoc, "unknown pattern metadata");
  } while (consumeIf(Token::comma));

  return success();
}

// Parse `<expr>` (the leading `<` is the current token) and return the
// expression between the angle brackets.
FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
  consumeToken(Token::less);

  FailureOr<ast::Expr *> typeExpr = parseExpr();
  if (failed(typeExpr) ||
      failed(parseToken(Token::greater,
                        "expected `>` after variable type constraint")))
    return failure();
  return typeExpr;
}

// Fail if `curDeclScope->lookup` already finds a decl with this name; the error
// carries a note pointing at the previous definition.
LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
  assert(curDeclScope && "defining decl outside of a decl scope");
  if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
    return emitErrorAndNote(
        name.getLoc(), "`" + name.getName() + "` has already been defined",
        lastDecl->getName()->getLoc(), "see previous definition here");
  }
  return success();
}

// Create a VariableDecl.
// - Empty or `_` names: not added to the current scope, so they are never
//   checked for redefinition.
// - Any other name: checked for redefinition, then added to the scope.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                           ast::Expr *initExpr,
                           ArrayRef<ast::ConstraintRef> constraints) {
  assert(curDeclScope && "defining variable outside of decl scope");
  const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);

  // If the name of the variable indicates a special variable, we don't add it
  // to the scope. This variable is local to the definition point.
  if (name.empty() || name == "_") {
    return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
                                     constraints);
  }
  if (failed(checkDefineNamedDecl(nameDecl)))
    return failure();

  auto *varDecl =
      ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
  curDeclScope->add(varDecl);
  return varDecl;
}

// Convenience overload for a variable without an initializer.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                           ArrayRef<ast::ConstraintRef> constraints) {
  return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
                            constraints);
}

// Parse either a single constraint or a bracketed list `[c1, c2, ...]`,
// appending each to `constraints`. Inline type constraints are allowed, but at
// most one per variable.
LogicalResult Parser::parseVariableDeclConstraintList(
    SmallVectorImpl<ast::ConstraintRef> &constraints) {
  std::optional<SMRange> typeConstraint;
  auto parseSingleConstraint = [&] {
    FailureOr<ast::ConstraintRef> constraint = parseConstraint(
        typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
    if (failed(constraint))
      return failure();
    constraints.push_back(*constraint);
    return success();
  };

  // Check to see if this is a single constraint, or a list.
  if (!consumeIf(Token::l_square))
    return parseSingleConstraint();

  do {
    if (failed(parseSingleConstraint()))
      return failure();
  } while (consumeIf(Token::comma));
  return parseToken(Token::r_square, "expected `]` after constraint list");
}

// Parse a single constraint reference. Accepted forms:
//  - a builtin keyword: `Attr`, `Op`, `Type`, `TypeRange`, `Value`,
//    `ValueRange`;
//  - an inline `Constraint` decl;
//  - an identifier naming an existing ConstraintDecl.
// Details:
//  - `Attr`, `Value` and `ValueRange` accept an optional `<type>` expression.
//  - `Op` takes an optional wrapped operation name.
//  - `typeConstraint` records where the variable's type was last constrained,
//    so a second inline type constraint is rejected with a note.
//  - On a code-completion token, the type implied by `existingConstraints`
//    drives constraint-name completion.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints,
                        bool allowInlineTypeConstraints) {
  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
    if (!allowInlineTypeConstraints) {
      return emitError(
          curToken.getLoc(),
          "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
          "permitted on arguments or results");
    }
    if (typeConstraint)
      return emitErrorAndNote(
          curToken.getLoc(),
          "the type of this variable has already been constrained",
          *typeConstraint, "see previous constraint location here");
    FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
    if (failed(constraintExpr))
      return failure();
    typeExpr = *constraintExpr;
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  SMRange loc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(Token::kw_Attr);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_Op: {
    consumeToken(Token::kw_Op);

    // Parse an optional operation name. If the name isn't provided, this refers
    // to "any" operation.
    FailureOr<ast::OpNameDecl *> opName =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(opName))
      return failure();

    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
                              loc);
  }
  case Token::kw_Type:
    consumeToken(Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
  case Token::kw_TypeRange:
    consumeToken(Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);
  case Token::kw_Value: {
    consumeToken(Token::kw_Value);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_ValueRange: {
    consumeToken(Token::kw_ValueRange);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
  }

  case Token::kw_Constraint: {
    // Handle an inline constraint.
    FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
    if (failed(decl))
      return failure();
    return ast::ConstraintRef(*decl, loc);
  }
  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(Token::identifier);

    // Lookup the referenced constraint.
    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
    if (!cstDecl) {
      return emitError(loc, "unknown reference to constraint `" +
                                constraintName + "`");
    }

    // Handle a reference to a proper constraint.
    if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
      return ast::ConstraintRef(cst, loc);

    return emitErrorAndNote(
        loc, "invalid reference to non-constraint", cstDecl->getLoc(),
        "see the definition of `" + constraintName + "` here");
  }
  // Handle single entity constraint code completion.
  case Token::code_complete: {
    // Try to infer the current type for use by code completion.
    ast::Type inferredType;
    if (failed(validateVariableConstraints(existingConstraints, inferredType)))
      return failure();

    return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
  }
  default:
    break;
  }
  return emitError(loc, "expected identifier constraint");
}

// Parse the constraint of a Constraint/Rewrite argument or result. Such
// constraints start with no prior constraints, and inline type constraints are
// rejected.
FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
  std::optional<SMRange> typeConstraint;
  return parseConstraint(typeConstraint, /*existingConstraints=*/std::nullopt,
                         /*allowInlineTypeConstraints=*/false);
}

//===----------------------------------------------------------------------===//
// Exprs

FailureOr<ast::Expr *> Parser::parseExpr() {
  if (curToken.is(Token::underscore))
    return parseUnderscoreExpr();

  // Parse the LHS expression.
1795 FailureOr<ast::Expr *> lhsExpr; 1796 switch (curToken.getKind()) { 1797 case Token::kw_attr: 1798 lhsExpr = parseAttributeExpr(); 1799 break; 1800 case Token::kw_Constraint: 1801 lhsExpr = parseInlineConstraintLambdaExpr(); 1802 break; 1803 case Token::kw_not: 1804 lhsExpr = parseNegatedExpr(); 1805 break; 1806 case Token::identifier: 1807 lhsExpr = parseIdentifierExpr(); 1808 break; 1809 case Token::kw_op: 1810 lhsExpr = parseOperationExpr(); 1811 break; 1812 case Token::kw_Rewrite: 1813 lhsExpr = parseInlineRewriteLambdaExpr(); 1814 break; 1815 case Token::kw_type: 1816 lhsExpr = parseTypeExpr(); 1817 break; 1818 case Token::l_paren: 1819 lhsExpr = parseTupleExpr(); 1820 break; 1821 default: 1822 return emitError("expected expression"); 1823 } 1824 if (failed(lhsExpr)) 1825 return failure(); 1826 1827 // Check for an operator expression. 1828 while (true) { 1829 switch (curToken.getKind()) { 1830 case Token::dot: 1831 lhsExpr = parseMemberAccessExpr(*lhsExpr); 1832 break; 1833 case Token::l_paren: 1834 lhsExpr = parseCallExpr(*lhsExpr); 1835 break; 1836 default: 1837 return lhsExpr; 1838 } 1839 if (failed(lhsExpr)) 1840 return failure(); 1841 } 1842 } 1843 1844 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 1845 SMRange loc = curToken.getLoc(); 1846 consumeToken(Token::kw_attr); 1847 1848 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 1849 // identifier. 
1850 if (!consumeIf(Token::less)) { 1851 resetToken(loc); 1852 return parseIdentifierExpr(); 1853 } 1854 1855 if (!curToken.isString()) 1856 return emitError("expected string literal containing MLIR attribute"); 1857 std::string attrExpr = curToken.getStringValue(); 1858 consumeToken(); 1859 1860 loc.End = curToken.getEndLoc(); 1861 if (failed( 1862 parseToken(Token::greater, "expected `>` after attribute literal"))) 1863 return failure(); 1864 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1865 } 1866 1867 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr, 1868 bool isNegated) { 1869 consumeToken(Token::l_paren); 1870 1871 // Parse the arguments of the call. 1872 SmallVector<ast::Expr *> arguments; 1873 if (curToken.isNot(Token::r_paren)) { 1874 do { 1875 // Handle code completion for the call arguments. 1876 if (curToken.is(Token::code_complete)) { 1877 codeCompleteCallSignature(parentExpr, arguments.size()); 1878 return failure(); 1879 } 1880 1881 FailureOr<ast::Expr *> argument = parseExpr(); 1882 if (failed(argument)) 1883 return failure(); 1884 arguments.push_back(*argument); 1885 } while (consumeIf(Token::comma)); 1886 } 1887 1888 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); 1889 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1890 return failure(); 1891 1892 return createCallExpr(loc, parentExpr, arguments, isNegated); 1893 } 1894 1895 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1896 ast::Decl *decl = curDeclScope->lookup(name); 1897 if (!decl) 1898 return emitError(loc, "undefined reference to `" + name + "`"); 1899 1900 return createDeclRefExpr(loc, decl); 1901 } 1902 1903 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1904 StringRef name = curToken.getSpelling(); 1905 SMRange nameLoc = curToken.getLoc(); 1906 consumeToken(); 1907 1908 // Check to see if this is a decl ref expression that defines a variable 1909 // inline. 
1910 if (consumeIf(Token::colon)) { 1911 SmallVector<ast::ConstraintRef> constraints; 1912 if (failed(parseVariableDeclConstraintList(constraints))) 1913 return failure(); 1914 ast::Type type; 1915 if (failed(validateVariableConstraints(constraints, type))) 1916 return failure(); 1917 return createInlineVariableExpr(type, name, nameLoc, constraints); 1918 } 1919 1920 return parseDeclRefExpr(name, nameLoc); 1921 } 1922 1923 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1924 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1925 if (failed(decl)) 1926 return failure(); 1927 1928 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1929 ast::ConstraintType::get(ctx)); 1930 } 1931 1932 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1933 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1934 if (failed(decl)) 1935 return failure(); 1936 1937 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1938 ast::RewriteType::get(ctx)); 1939 } 1940 1941 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1942 SMRange dotLoc = curToken.getLoc(); 1943 consumeToken(Token::dot); 1944 1945 // Check for code completion of the member name. 1946 if (curToken.is(Token::code_complete)) 1947 return codeCompleteMemberAccess(parentExpr); 1948 1949 // Parse the member name. 
1950 Token memberNameTok = curToken; 1951 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1952 !memberNameTok.isKeyword()) 1953 return emitError(dotLoc, "expected identifier or numeric member name"); 1954 StringRef memberName = memberNameTok.getSpelling(); 1955 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); 1956 consumeToken(); 1957 1958 return createMemberAccessExpr(parentExpr, memberName, loc); 1959 } 1960 1961 FailureOr<ast::Expr *> Parser::parseNegatedExpr() { 1962 consumeToken(Token::kw_not); 1963 // Only native constraints are supported after negation 1964 if (!curToken.is(Token::identifier)) 1965 return emitError("expected native constraint"); 1966 FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr(); 1967 if (failed(identifierExpr)) 1968 return failure(); 1969 if (!curToken.is(Token::l_paren)) 1970 return emitError("expected `(` after function name"); 1971 return parseCallExpr(*identifierExpr, /*isNegated = */ true); 1972 } 1973 1974 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1975 SMRange loc = curToken.getLoc(); 1976 1977 // Check for code completion for the dialect name. 1978 if (curToken.is(Token::code_complete)) 1979 return codeCompleteDialectName(); 1980 1981 // Handle the case of an no operation name. 1982 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1983 if (allowEmptyName) 1984 return ast::OpNameDecl::create(ctx, SMRange()); 1985 return emitError("expected dialect namespace"); 1986 } 1987 StringRef name = curToken.getSpelling(); 1988 consumeToken(); 1989 1990 // Otherwise, this is a literal operation name. 1991 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1992 return failure(); 1993 1994 // Check for code completion for the operation name. 
1995 if (curToken.is(Token::code_complete)) 1996 return codeCompleteOperationName(name); 1997 1998 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1999 return emitError("expected operation name after dialect namespace"); 2000 2001 name = StringRef(name.data(), name.size() + 1); 2002 do { 2003 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 2004 loc.End = curToken.getEndLoc(); 2005 consumeToken(); 2006 } while (curToken.isAny(Token::identifier, Token::dot) || 2007 curToken.isKeyword()); 2008 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 2009 } 2010 2011 FailureOr<ast::OpNameDecl *> 2012 Parser::parseWrappedOperationName(bool allowEmptyName) { 2013 if (!consumeIf(Token::less)) 2014 return ast::OpNameDecl::create(ctx, SMRange()); 2015 2016 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 2017 if (failed(opNameDecl)) 2018 return failure(); 2019 2020 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 2021 return failure(); 2022 return opNameDecl; 2023 } 2024 2025 FailureOr<ast::Expr *> 2026 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) { 2027 SMRange loc = curToken.getLoc(); 2028 consumeToken(Token::kw_op); 2029 2030 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 2031 // identifier. 2032 if (curToken.isNot(Token::less)) { 2033 resetToken(loc); 2034 return parseIdentifierExpr(); 2035 } 2036 2037 // Parse the operation name. The name may be elided, in which case the 2038 // operation refers to "any" operation(i.e. a difference between `MyOp` and 2039 // `Operation*`). Operation names within a rewrite context must be named. 
2040 bool allowEmptyName = parserContext != ParserContext::Rewrite; 2041 FailureOr<ast::OpNameDecl *> opNameDecl = 2042 parseWrappedOperationName(allowEmptyName); 2043 if (failed(opNameDecl)) 2044 return failure(); 2045 std::optional<StringRef> opName = (*opNameDecl)->getName(); 2046 2047 // Functor used to create an implicit range variable, used for implicit "all" 2048 // operand or results variables. 2049 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { 2050 FailureOr<ast::VariableDecl *> rangeVar = 2051 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc)); 2052 assert(succeeded(rangeVar) && "expected range variable to be valid"); 2053 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type); 2054 }; 2055 2056 // Check for the optional list of operands. 2057 SmallVector<ast::Expr *> operands; 2058 if (!consumeIf(Token::l_paren)) { 2059 // If the operand list isn't specified and we are in a match context, define 2060 // an inplace unconstrained operand range corresponding to all of the 2061 // operands of the operation. This avoids treating zero operands the same 2062 // way as "unconstrained operands". 2063 if (parserContext != ParserContext::Rewrite) { 2064 operands.push_back(createImplicitRangeVar( 2065 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); 2066 } 2067 } else if (!consumeIf(Token::r_paren)) { 2068 // If the operand list was specified and non-empty, parse the operands. 2069 do { 2070 // Check for operand signature code completion. 
2071 if (curToken.is(Token::code_complete)) { 2072 codeCompleteOperationOperandsSignature(opName, operands.size()); 2073 return failure(); 2074 } 2075 2076 FailureOr<ast::Expr *> operand = parseExpr(); 2077 if (failed(operand)) 2078 return failure(); 2079 operands.push_back(*operand); 2080 } while (consumeIf(Token::comma)); 2081 2082 if (failed(parseToken(Token::r_paren, 2083 "expected `)` after operation operand list"))) 2084 return failure(); 2085 } 2086 2087 // Check for the optional list of attributes. 2088 SmallVector<ast::NamedAttributeDecl *> attributes; 2089 if (consumeIf(Token::l_brace)) { 2090 do { 2091 FailureOr<ast::NamedAttributeDecl *> decl = 2092 parseNamedAttributeDecl(opName); 2093 if (failed(decl)) 2094 return failure(); 2095 attributes.emplace_back(*decl); 2096 } while (consumeIf(Token::comma)); 2097 2098 if (failed(parseToken(Token::r_brace, 2099 "expected `}` after operation attribute list"))) 2100 return failure(); 2101 } 2102 2103 // Handle the result types of the operation. 2104 SmallVector<ast::Expr *> resultTypes; 2105 OpResultTypeContext resultTypeContext = inputResultTypeContext; 2106 2107 // Check for an explicit list of result types. 2108 if (consumeIf(Token::arrow)) { 2109 if (failed(parseToken(Token::l_paren, 2110 "expected `(` before operation result type list"))) 2111 return failure(); 2112 2113 // If result types are provided, initially assume that the operation does 2114 // not rely on type inferrence. We don't assert that it isn't, because we 2115 // may be inferring the value of some type/type range variables, but given 2116 // that these variables may be defined in calls we can't always discern when 2117 // this is the case. 2118 resultTypeContext = OpResultTypeContext::Explicit; 2119 2120 // Handle the case of an empty result list. 2121 if (!consumeIf(Token::r_paren)) { 2122 do { 2123 // Check for result signature code completion. 
2124 if (curToken.is(Token::code_complete)) { 2125 codeCompleteOperationResultsSignature(opName, resultTypes.size()); 2126 return failure(); 2127 } 2128 2129 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 2130 if (failed(resultTypeExpr)) 2131 return failure(); 2132 resultTypes.push_back(*resultTypeExpr); 2133 } while (consumeIf(Token::comma)); 2134 2135 if (failed(parseToken(Token::r_paren, 2136 "expected `)` after operation result type list"))) 2137 return failure(); 2138 } 2139 } else if (parserContext != ParserContext::Rewrite) { 2140 // If the result list isn't specified and we are in a match context, define 2141 // an inplace unconstrained result range corresponding to all of the results 2142 // of the operation. This avoids treating zero results the same way as 2143 // "unconstrained results". 2144 resultTypes.push_back(createImplicitRangeVar( 2145 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); 2146 } else if (resultTypeContext == OpResultTypeContext::Explicit) { 2147 // If the result list isn't specified and we are in a rewrite, try to infer 2148 // them at runtime instead. 2149 resultTypeContext = OpResultTypeContext::Interface; 2150 } 2151 2152 return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands, 2153 attributes, resultTypes); 2154 } 2155 2156 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 2157 SMRange loc = curToken.getLoc(); 2158 consumeToken(Token::l_paren); 2159 2160 DenseMap<StringRef, SMRange> usedNames; 2161 SmallVector<StringRef> elementNames; 2162 SmallVector<ast::Expr *> elements; 2163 if (curToken.isNot(Token::r_paren)) { 2164 do { 2165 // Check for the optional element name assignment before the value. 2166 StringRef elementName; 2167 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 2168 Token elementNameTok = curToken; 2169 consumeToken(); 2170 2171 // The element name is only present if followed by an `=`. 
2172 if (consumeIf(Token::equal)) { 2173 elementName = elementNameTok.getSpelling(); 2174 2175 // Check to see if this name is already used. 2176 auto elementNameIt = 2177 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 2178 if (!elementNameIt.second) { 2179 return emitErrorAndNote( 2180 elementNameTok.getLoc(), 2181 llvm::formatv("duplicate tuple element label `{0}`", 2182 elementName), 2183 elementNameIt.first->getSecond(), 2184 "see previous label use here"); 2185 } 2186 } else { 2187 // Otherwise, we treat this as part of an expression so reset the 2188 // lexer. 2189 resetToken(elementNameTok.getLoc()); 2190 } 2191 } 2192 elementNames.push_back(elementName); 2193 2194 // Parse the tuple element value. 2195 FailureOr<ast::Expr *> element = parseExpr(); 2196 if (failed(element)) 2197 return failure(); 2198 elements.push_back(*element); 2199 } while (consumeIf(Token::comma)); 2200 } 2201 loc.End = curToken.getEndLoc(); 2202 if (failed( 2203 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 2204 return failure(); 2205 return createTupleExpr(loc, elements, elementNames); 2206 } 2207 2208 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 2209 SMRange loc = curToken.getLoc(); 2210 consumeToken(Token::kw_type); 2211 2212 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 2213 // identifier. 
2214 if (!consumeIf(Token::less)) { 2215 resetToken(loc); 2216 return parseIdentifierExpr(); 2217 } 2218 2219 if (!curToken.isString()) 2220 return emitError("expected string literal containing MLIR type"); 2221 std::string attrExpr = curToken.getStringValue(); 2222 consumeToken(); 2223 2224 loc.End = curToken.getEndLoc(); 2225 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 2226 return failure(); 2227 return ast::TypeExpr::create(ctx, loc, attrExpr); 2228 } 2229 2230 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 2231 StringRef name = curToken.getSpelling(); 2232 SMRange nameLoc = curToken.getLoc(); 2233 consumeToken(Token::underscore); 2234 2235 // Underscore expressions require a constraint list. 2236 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 2237 return failure(); 2238 2239 // Parse the constraints for the expression. 2240 SmallVector<ast::ConstraintRef> constraints; 2241 if (failed(parseVariableDeclConstraintList(constraints))) 2242 return failure(); 2243 2244 ast::Type type; 2245 if (failed(validateVariableConstraints(constraints, type))) 2246 return failure(); 2247 return createInlineVariableExpr(type, name, nameLoc, constraints); 2248 } 2249 2250 //===----------------------------------------------------------------------===// 2251 // Stmts 2252 2253 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 2254 FailureOr<ast::Stmt *> stmt; 2255 switch (curToken.getKind()) { 2256 case Token::kw_erase: 2257 stmt = parseEraseStmt(); 2258 break; 2259 case Token::kw_let: 2260 stmt = parseLetStmt(); 2261 break; 2262 case Token::kw_replace: 2263 stmt = parseReplaceStmt(); 2264 break; 2265 case Token::kw_return: 2266 stmt = parseReturnStmt(); 2267 break; 2268 case Token::kw_rewrite: 2269 stmt = parseRewriteStmt(); 2270 break; 2271 default: 2272 stmt = parseExpr(); 2273 break; 2274 } 2275 if (failed(stmt) || 2276 (expectTerminalSemicolon && 2277 failed(parseToken(Token::semicolon, 
"expected `;` after statement")))) 2278 return failure(); 2279 return stmt; 2280 } 2281 2282 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 2283 SMLoc startLoc = curToken.getStartLoc(); 2284 consumeToken(Token::l_brace); 2285 2286 // Push a new block scope and parse any nested statements. 2287 pushDeclScope(); 2288 SmallVector<ast::Stmt *> statements; 2289 while (curToken.isNot(Token::r_brace)) { 2290 FailureOr<ast::Stmt *> statement = parseStmt(); 2291 if (failed(statement)) 2292 return popDeclScope(), failure(); 2293 statements.push_back(*statement); 2294 } 2295 popDeclScope(); 2296 2297 // Consume the end brace. 2298 SMRange location(startLoc, curToken.getEndLoc()); 2299 consumeToken(Token::r_brace); 2300 2301 return ast::CompoundStmt::create(ctx, location, statements); 2302 } 2303 2304 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 2305 if (parserContext == ParserContext::Constraint) 2306 return emitError("`erase` cannot be used within a Constraint"); 2307 SMRange loc = curToken.getLoc(); 2308 consumeToken(Token::kw_erase); 2309 2310 // Parse the root operation expression. 2311 FailureOr<ast::Expr *> rootOp = parseExpr(); 2312 if (failed(rootOp)) 2313 return failure(); 2314 2315 return createEraseStmt(loc, *rootOp); 2316 } 2317 2318 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 2319 SMRange loc = curToken.getLoc(); 2320 consumeToken(Token::kw_let); 2321 2322 // Parse the name of the new variable. 2323 SMRange varLoc = curToken.getLoc(); 2324 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 2325 // `_` is a reserved variable name. 2326 if (curToken.is(Token::underscore)) { 2327 return emitError(varLoc, 2328 "`_` may only be used to define \"inline\" variables"); 2329 } 2330 return emitError(varLoc, 2331 "expected identifier after `let` to name a new variable"); 2332 } 2333 StringRef varName = curToken.getSpelling(); 2334 consumeToken(); 2335 2336 // Parse the optional set of constraints. 
2337 SmallVector<ast::ConstraintRef> constraints; 2338 if (consumeIf(Token::colon) && 2339 failed(parseVariableDeclConstraintList(constraints))) 2340 return failure(); 2341 2342 // Parse the optional initializer expression. 2343 ast::Expr *initializer = nullptr; 2344 if (consumeIf(Token::equal)) { 2345 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 2346 if (failed(initOrFailure)) 2347 return failure(); 2348 initializer = *initOrFailure; 2349 2350 // Check that the constraints are compatible with having an initializer, 2351 // e.g. type constraints cannot be used with initializers. 2352 for (ast::ConstraintRef constraint : constraints) { 2353 LogicalResult result = 2354 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 2355 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 2356 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 2357 if (cst->getTypeExpr()) { 2358 return this->emitError( 2359 constraint.referenceLoc, 2360 "type constraints are not permitted on variables with " 2361 "initializers"); 2362 } 2363 return success(); 2364 }) 2365 .Default(success()); 2366 if (failed(result)) 2367 return failure(); 2368 } 2369 } 2370 2371 FailureOr<ast::VariableDecl *> varDecl = 2372 createVariableDecl(varName, varLoc, initializer, constraints); 2373 if (failed(varDecl)) 2374 return failure(); 2375 return ast::LetStmt::create(ctx, loc, *varDecl); 2376 } 2377 2378 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 2379 if (parserContext == ParserContext::Constraint) 2380 return emitError("`replace` cannot be used within a Constraint"); 2381 SMRange loc = curToken.getLoc(); 2382 consumeToken(Token::kw_replace); 2383 2384 // Parse the root operation expression. 
2385 FailureOr<ast::Expr *> rootOp = parseExpr(); 2386 if (failed(rootOp)) 2387 return failure(); 2388 2389 if (failed( 2390 parseToken(Token::kw_with, "expected `with` after root operation"))) 2391 return failure(); 2392 2393 // The replacement portion of this statement is within a rewrite context. 2394 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite); 2395 2396 // Parse the replacement values. 2397 SmallVector<ast::Expr *> replValues; 2398 if (consumeIf(Token::l_paren)) { 2399 if (consumeIf(Token::r_paren)) { 2400 return emitError( 2401 loc, "expected at least one replacement value, consider using " 2402 "`erase` if no replacement values are desired"); 2403 } 2404 2405 do { 2406 FailureOr<ast::Expr *> replExpr = parseExpr(); 2407 if (failed(replExpr)) 2408 return failure(); 2409 replValues.emplace_back(*replExpr); 2410 } while (consumeIf(Token::comma)); 2411 2412 if (failed(parseToken(Token::r_paren, 2413 "expected `)` after replacement values"))) 2414 return failure(); 2415 } else { 2416 // Handle replacement with an operation uniquely, as the replacement 2417 // operation supports type inferrence from the root operation. 2418 FailureOr<ast::Expr *> replExpr; 2419 if (curToken.is(Token::kw_op)) 2420 replExpr = parseOperationExpr(OpResultTypeContext::Replacement); 2421 else 2422 replExpr = parseExpr(); 2423 if (failed(replExpr)) 2424 return failure(); 2425 replValues.emplace_back(*replExpr); 2426 } 2427 2428 return createReplaceStmt(loc, *rootOp, replValues); 2429 } 2430 2431 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 2432 SMRange loc = curToken.getLoc(); 2433 consumeToken(Token::kw_return); 2434 2435 // Parse the result value. 
2436 FailureOr<ast::Expr *> resultExpr = parseExpr(); 2437 if (failed(resultExpr)) 2438 return failure(); 2439 2440 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 2441 } 2442 2443 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 2444 if (parserContext == ParserContext::Constraint) 2445 return emitError("`rewrite` cannot be used within a Constraint"); 2446 SMRange loc = curToken.getLoc(); 2447 consumeToken(Token::kw_rewrite); 2448 2449 // Parse the root operation. 2450 FailureOr<ast::Expr *> rootOp = parseExpr(); 2451 if (failed(rootOp)) 2452 return failure(); 2453 2454 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 2455 return failure(); 2456 2457 if (curToken.isNot(Token::l_brace)) 2458 return emitError("expected `{` to start rewrite body"); 2459 2460 // The rewrite body of this statement is within a rewrite context. 2461 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite); 2462 2463 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 2464 if (failed(rewriteBody)) 2465 return failure(); 2466 2467 // Verify the rewrite body. 2468 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 2469 if (isa<ast::ReturnStmt>(stmt)) { 2470 return emitError(stmt->getLoc(), 2471 "`return` statements are only permitted within a " 2472 "`Constraint` or `Rewrite` body"); 2473 } 2474 } 2475 2476 return createRewriteStmt(loc, *rootOp, *rewriteBody); 2477 } 2478 2479 //===----------------------------------------------------------------------===// 2480 // Creation+Analysis 2481 //===----------------------------------------------------------------------===// 2482 2483 //===----------------------------------------------------------------------===// 2484 // Decls 2485 2486 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 2487 // Unwrap reference expressions. 
2488 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 2489 node = init->getDecl(); 2490 return dyn_cast<ast::CallableDecl>(node); 2491 } 2492 2493 FailureOr<ast::PatternDecl *> 2494 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 2495 const ParsedPatternMetadata &metadata, 2496 ast::CompoundStmt *body) { 2497 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 2498 metadata.hasBoundedRecursion, body); 2499 } 2500 2501 ast::Type Parser::createUserConstraintRewriteResultType( 2502 ArrayRef<ast::VariableDecl *> results) { 2503 // Single result decls use the type of the single result. 2504 if (results.size() == 1) 2505 return results[0]->getType(); 2506 2507 // Multiple results use a tuple type, with the types and names grabbed from 2508 // the result variable decls. 2509 auto resultTypes = llvm::map_range( 2510 results, [&](const auto *result) { return result->getType(); }); 2511 auto resultNames = llvm::map_range( 2512 results, [&](const auto *result) { return result->getName().getName(); }); 2513 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 2514 llvm::to_vector(resultNames)); 2515 } 2516 2517 template <typename T> 2518 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2519 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2520 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2521 ast::CompoundStmt *body) { 2522 if (!body->getChildren().empty()) { 2523 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2524 ast::Expr *resultExpr = retStmt->getResultExpr(); 2525 2526 // Process the result of the decl. If no explicit signature results 2527 // were provided, check for return type inference. Otherwise, check that 2528 // the return expression can be converted to the expected type. 
2529 if (results.empty()) 2530 resultType = resultExpr->getType(); 2531 else if (failed(convertExpressionTo(resultExpr, resultType))) 2532 return failure(); 2533 else 2534 retStmt->setResultExpr(resultExpr); 2535 } 2536 } 2537 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2538 } 2539 2540 FailureOr<ast::VariableDecl *> 2541 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2542 ArrayRef<ast::ConstraintRef> constraints) { 2543 // The type of the variable, which is expected to be inferred by either a 2544 // constraint or an initializer expression. 2545 ast::Type type; 2546 if (failed(validateVariableConstraints(constraints, type))) 2547 return failure(); 2548 2549 if (initializer) { 2550 // Update the variable type based on the initializer, or try to convert the 2551 // initializer to the existing type. 2552 if (!type) 2553 type = initializer->getType(); 2554 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2555 type = mergedType; 2556 else if (failed(convertExpressionTo(initializer, type))) 2557 return failure(); 2558 2559 // Otherwise, if there is no initializer check that the type has already 2560 // been resolved from the constraint list. 2561 } else if (!type) { 2562 return emitErrorAndNote( 2563 loc, "unable to infer type for variable `" + name + "`", loc, 2564 "the type of a variable must be inferable from the constraint " 2565 "list or the initializer"); 2566 } 2567 2568 // Constraint types cannot be used when defining variables. 2569 if (isa<ast::ConstraintType, ast::RewriteType>(type)) { 2570 return emitError( 2571 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2572 } 2573 2574 // Try to define a variable with the given name. 
2575 FailureOr<ast::VariableDecl *> varDecl = 2576 defineVariableDecl(name, loc, type, initializer, constraints); 2577 if (failed(varDecl)) 2578 return failure(); 2579 2580 return *varDecl; 2581 } 2582 2583 FailureOr<ast::VariableDecl *> 2584 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2585 const ast::ConstraintRef &constraint) { 2586 ast::Type argType; 2587 if (failed(validateVariableConstraint(constraint, argType))) 2588 return failure(); 2589 return defineVariableDecl(name, loc, argType, constraint); 2590 } 2591 2592 LogicalResult 2593 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2594 ast::Type &inferredType) { 2595 for (const ast::ConstraintRef &ref : constraints) 2596 if (failed(validateVariableConstraint(ref, inferredType))) 2597 return failure(); 2598 return success(); 2599 } 2600 2601 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2602 ast::Type &inferredType) { 2603 ast::Type constraintType; 2604 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2605 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2606 if (failed(validateTypeConstraintExpr(typeExpr))) 2607 return failure(); 2608 } 2609 constraintType = ast::AttributeType::get(ctx); 2610 } else if (const auto *cst = 2611 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2612 constraintType = ast::OperationType::get( 2613 ctx, cst->getName(), lookupODSOperation(cst->getName())); 2614 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2615 constraintType = typeTy; 2616 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2617 constraintType = typeRangeTy; 2618 } else if (const auto *cst = 2619 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2620 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2621 if (failed(validateTypeConstraintExpr(typeExpr))) 2622 return failure(); 2623 } 2624 constraintType = valueTy; 2625 } else if (const auto *cst = 2626 
dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2627 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2628 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2629 return failure(); 2630 } 2631 constraintType = valueRangeTy; 2632 } else if (const auto *cst = 2633 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2634 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2635 if (inputs.size() != 1) { 2636 return emitErrorAndNote(ref.referenceLoc, 2637 "`Constraint`s applied via a variable constraint " 2638 "list must take a single input, but got " + 2639 Twine(inputs.size()), 2640 cst->getLoc(), 2641 "see definition of constraint here"); 2642 } 2643 constraintType = inputs.front()->getType(); 2644 } else { 2645 llvm_unreachable("unknown constraint type"); 2646 } 2647 2648 // Check that the constraint type is compatible with the current inferred 2649 // type. 2650 if (!inferredType) { 2651 inferredType = constraintType; 2652 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2653 inferredType = mergedTy; 2654 } else { 2655 return emitError(ref.referenceLoc, 2656 llvm::formatv("constraint type `{0}` is incompatible " 2657 "with the previously inferred type `{1}`", 2658 constraintType, inferredType)); 2659 } 2660 return success(); 2661 } 2662 2663 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2664 ast::Type typeExprType = typeExpr->getType(); 2665 if (typeExprType != typeTy) { 2666 return emitError(typeExpr->getLoc(), 2667 "expected expression of `Type` in type constraint"); 2668 } 2669 return success(); 2670 } 2671 2672 LogicalResult 2673 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2674 ast::Type typeExprType = typeExpr->getType(); 2675 if (typeExprType != typeRangeTy) { 2676 return emitError(typeExpr->getLoc(), 2677 "expected expression of `TypeRange` in type constraint"); 2678 } 2679 return success(); 2680 } 2681 2682 
//===----------------------------------------------------------------------===// 2683 // Exprs 2684 2685 FailureOr<ast::CallExpr *> 2686 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2687 MutableArrayRef<ast::Expr *> arguments, bool isNegated) { 2688 ast::Type parentType = parentExpr->getType(); 2689 2690 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2691 if (!callableDecl) { 2692 return emitError(loc, 2693 llvm::formatv("expected a reference to a callable " 2694 "`Constraint` or `Rewrite`, but got: `{0}`", 2695 parentType)); 2696 } 2697 if (parserContext == ParserContext::Rewrite) { 2698 if (isa<ast::UserConstraintDecl>(callableDecl)) 2699 return emitError( 2700 loc, "unable to invoke `Constraint` within a rewrite section"); 2701 if (isNegated) 2702 return emitError(loc, "unable to negate a Rewrite"); 2703 } else { 2704 if (isa<ast::UserRewriteDecl>(callableDecl)) 2705 return emitError(loc, 2706 "unable to invoke `Rewrite` within a match section"); 2707 if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody()) 2708 return emitError(loc, "unable to negate non native constraints"); 2709 } 2710 2711 // Verify the arguments of the call. 2712 /// Handle size mismatch. 2713 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2714 if (callArgs.size() != arguments.size()) { 2715 return emitErrorAndNote( 2716 loc, 2717 llvm::formatv("invalid number of arguments for {0} call; expected " 2718 "{1}, but got {2}", 2719 callableDecl->getCallableType(), callArgs.size(), 2720 arguments.size()), 2721 callableDecl->getLoc(), 2722 llvm::formatv("see the definition of {0} here", 2723 callableDecl->getName()->getName())); 2724 } 2725 2726 /// Handle argument type mismatch. 
2727 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2728 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2729 callableDecl->getName()->getName()), 2730 callableDecl->getLoc()); 2731 }; 2732 for (auto it : llvm::zip(callArgs, arguments)) { 2733 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2734 attachDiagFn))) 2735 return failure(); 2736 } 2737 2738 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2739 callableDecl->getResultType(), isNegated); 2740 } 2741 2742 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2743 ast::Decl *decl) { 2744 // Check the type of decl being referenced. 2745 ast::Type declType; 2746 if (isa<ast::ConstraintDecl>(decl)) 2747 declType = ast::ConstraintType::get(ctx); 2748 else if (isa<ast::UserRewriteDecl>(decl)) 2749 declType = ast::RewriteType::get(ctx); 2750 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2751 declType = varDecl->getType(); 2752 else 2753 return emitError(loc, "invalid reference to `" + 2754 decl->getName()->getName() + "`"); 2755 2756 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2757 } 2758 2759 FailureOr<ast::DeclRefExpr *> 2760 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2761 ArrayRef<ast::ConstraintRef> constraints) { 2762 FailureOr<ast::VariableDecl *> decl = 2763 defineVariableDecl(name, loc, type, constraints); 2764 if (failed(decl)) 2765 return failure(); 2766 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2767 } 2768 2769 FailureOr<ast::MemberAccessExpr *> 2770 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2771 SMRange loc) { 2772 // Validate the member name for the given parent expression. 
2773 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2774 if (failed(memberType)) 2775 return failure(); 2776 2777 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2778 } 2779 2780 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2781 StringRef name, SMRange loc) { 2782 ast::Type parentType = parentExpr->getType(); 2783 if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) { 2784 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2785 return valueRangeTy; 2786 2787 // Verify member access based on the operation type. 2788 if (const ods::Operation *odsOp = opType.getODSOperation()) { 2789 auto results = odsOp->getResults(); 2790 2791 // Handle indexed results. 2792 unsigned index = 0; 2793 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2794 index < results.size()) { 2795 return results[index].isVariadic() ? valueRangeTy : valueTy; 2796 } 2797 2798 // Handle named results. 2799 const auto *it = llvm::find_if(results, [&](const auto &result) { 2800 return result.getName() == name; 2801 }); 2802 if (it != results.end()) 2803 return it->isVariadic() ? valueRangeTy : valueTy; 2804 } else if (llvm::isDigit(name[0])) { 2805 // Allow unchecked numeric indexing of the results of unregistered 2806 // operations. It returns a single value. 2807 return valueTy; 2808 } 2809 } else if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) { 2810 // Handle indexed results. 2811 unsigned index = 0; 2812 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2813 index < tupleType.size()) { 2814 return tupleType.getElementTypes()[index]; 2815 } 2816 2817 // Handle named results. 
2818 auto elementNames = tupleType.getElementNames(); 2819 const auto *it = llvm::find(elementNames, name); 2820 if (it != elementNames.end()) 2821 return tupleType.getElementTypes()[it - elementNames.begin()]; 2822 } 2823 return emitError( 2824 loc, 2825 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2826 name, parentType)); 2827 } 2828 2829 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2830 SMRange loc, const ast::OpNameDecl *name, 2831 OpResultTypeContext resultTypeContext, 2832 SmallVectorImpl<ast::Expr *> &operands, 2833 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2834 SmallVectorImpl<ast::Expr *> &results) { 2835 std::optional<StringRef> opNameRef = name->getName(); 2836 const ods::Operation *odsOp = lookupODSOperation(opNameRef); 2837 2838 // Verify the inputs operands. 2839 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands))) 2840 return failure(); 2841 2842 // Verify the attribute list. 2843 for (ast::NamedAttributeDecl *attr : attributes) { 2844 // Check for an attribute type, or a type awaiting resolution. 2845 ast::Type attrType = attr->getValue()->getType(); 2846 if (!isa<ast::AttributeType>(attrType)) { 2847 return emitError( 2848 attr->getValue()->getLoc(), 2849 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2850 } 2851 } 2852 2853 assert( 2854 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) && 2855 "unexpected inferrence when results were explicitly specified"); 2856 2857 // If we aren't relying on type inferrence, or explicit results were provided, 2858 // validate them. 2859 if (resultTypeContext == OpResultTypeContext::Explicit) { 2860 if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) 2861 return failure(); 2862 2863 // Validate the use of interface based type inferrence for this operation. 
2864 } else if (resultTypeContext == OpResultTypeContext::Interface) { 2865 assert(opNameRef && 2866 "expected valid operation name when inferring operation results"); 2867 checkOperationResultTypeInferrence(loc, *opNameRef, odsOp); 2868 } 2869 2870 return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results, 2871 attributes); 2872 } 2873 2874 LogicalResult 2875 Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name, 2876 const ods::Operation *odsOp, 2877 SmallVectorImpl<ast::Expr *> &operands) { 2878 return validateOperationOperandsOrResults( 2879 "operand", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name, 2880 operands, odsOp ? odsOp->getOperands() : std::nullopt, valueTy, 2881 valueRangeTy); 2882 } 2883 2884 LogicalResult 2885 Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name, 2886 const ods::Operation *odsOp, 2887 SmallVectorImpl<ast::Expr *> &results) { 2888 return validateOperationOperandsOrResults( 2889 "result", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name, 2890 results, odsOp ? odsOp->getResults() : std::nullopt, typeTy, typeRangeTy); 2891 } 2892 2893 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName, 2894 const ods::Operation *odsOp) { 2895 // If the operation might not have inferrence support, emit a warning to the 2896 // user. We don't emit an error because the interface might be added to the 2897 // operation at runtime. It's rare, but it could still happen. We emit a 2898 // warning here instead. 2899 2900 // Handle inferrence warnings for unknown operations. 2901 if (!odsOp) { 2902 ctx.getDiagEngine().emitWarning( 2903 loc, llvm::formatv( 2904 "operation result types are marked to be inferred, but " 2905 "`{0}` is unknown. Ensure that `{0}` supports zero " 2906 "results or implements `InferTypeOpInterface`. 
Include " 2907 "the ODS definition of this operation to remove this warning.", 2908 opName)); 2909 return; 2910 } 2911 2912 // Handle inferrence warnings for known operations that expected at least one 2913 // result, but don't have inference support. An elided results list can mean 2914 // "zero-results", and we don't want to warn when that is the expected 2915 // behavior. 2916 bool requiresInferrence = 2917 llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) { 2918 return !result.isVariableLength(); 2919 }); 2920 if (requiresInferrence && !odsOp->hasResultTypeInferrence()) { 2921 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning( 2922 loc, 2923 llvm::formatv("operation result types are marked to be inferred, but " 2924 "`{0}` does not provide an implementation of " 2925 "`InferTypeOpInterface`. Ensure that `{0}` attaches " 2926 "`InferTypeOpInterface` at runtime, or add support to " 2927 "the ODS definition to remove this warning.", 2928 opName)); 2929 diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName), 2930 odsOp->getLoc()); 2931 return; 2932 } 2933 } 2934 2935 LogicalResult Parser::validateOperationOperandsOrResults( 2936 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc, 2937 std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values, 2938 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 2939 ast::RangeType rangeTy) { 2940 // All operation types accept a single range parameter. 2941 if (values.size() == 1) { 2942 if (failed(convertExpressionTo(values[0], rangeTy))) 2943 return failure(); 2944 return success(); 2945 } 2946 2947 /// If the operation has ODS information, we can more accurately verify the 2948 /// values. 
2949 if (odsOpLoc) { 2950 auto emitSizeMismatchError = [&] { 2951 return emitErrorAndNote( 2952 loc, 2953 llvm::formatv("invalid number of {0} groups for `{1}`; expected " 2954 "{2}, but got {3}", 2955 groupName, *name, odsValues.size(), values.size()), 2956 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); 2957 }; 2958 2959 // Handle the case where no values were provided. 2960 if (values.empty()) { 2961 // If we don't expect any on the ODS side, we are done. 2962 if (odsValues.empty()) 2963 return success(); 2964 2965 // If we do, check if we actually need to provide values (i.e. if any of 2966 // the values are actually required). 2967 unsigned numVariadic = 0; 2968 for (const auto &odsValue : odsValues) { 2969 if (!odsValue.isVariableLength()) 2970 return emitSizeMismatchError(); 2971 ++numVariadic; 2972 } 2973 2974 // If we are in a non-rewrite context, we don't need to do anything more. 2975 // Zero-values is a valid constraint on the operation. 2976 if (parserContext != ParserContext::Rewrite) 2977 return success(); 2978 2979 // Otherwise, when in a rewrite we may need to provide values to match the 2980 // ODS signature of the operation to create. 2981 2982 // If we only have one variadic value, just use an empty list. 2983 if (numVariadic == 1) 2984 return success(); 2985 2986 // Otherwise, create dummy values for each of the entries so that we 2987 // adhere to the ODS signature. 2988 for (unsigned i = 0, e = odsValues.size(); i < e; ++i) { 2989 values.push_back(ast::RangeExpr::create( 2990 ctx, loc, /*elements=*/std::nullopt, rangeTy)); 2991 } 2992 return success(); 2993 } 2994 2995 // Verify that the number of values provided matches the number of value 2996 // groups ODS expects. 
2997 if (odsValues.size() != values.size()) 2998 return emitSizeMismatchError(); 2999 3000 auto diagFn = [&](ast::Diagnostic &diag) { 3001 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), 3002 *odsOpLoc); 3003 }; 3004 for (unsigned i = 0, e = values.size(); i < e; ++i) { 3005 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; 3006 if (failed(convertExpressionTo(values[i], expectedType, diagFn))) 3007 return failure(); 3008 } 3009 return success(); 3010 } 3011 3012 // Otherwise, accept the value groups as they have been defined and just 3013 // ensure they are one of the expected types. 3014 for (ast::Expr *&valueExpr : values) { 3015 ast::Type valueExprType = valueExpr->getType(); 3016 3017 // Check if this is one of the expected types. 3018 if (valueExprType == rangeTy || valueExprType == singleTy) 3019 continue; 3020 3021 // If the operand is an Operation, allow converting to a Value or 3022 // ValueRange. This situations arises quite often with nested operation 3023 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 3024 if (singleTy == valueTy) { 3025 if (isa<ast::OperationType>(valueExprType)) { 3026 valueExpr = convertOpToValue(valueExpr); 3027 continue; 3028 } 3029 } 3030 3031 // Otherwise, try to convert the expression to a range. 
3032 if (succeeded(convertExpressionTo(valueExpr, rangeTy))) 3033 continue; 3034 3035 return emitError( 3036 valueExpr->getLoc(), 3037 llvm::formatv( 3038 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 3039 singleTy, rangeTy, valueExprType)); 3040 } 3041 return success(); 3042 } 3043 3044 FailureOr<ast::TupleExpr *> 3045 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 3046 ArrayRef<StringRef> elementNames) { 3047 for (const ast::Expr *element : elements) { 3048 ast::Type eleTy = element->getType(); 3049 if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(eleTy)) { 3050 return emitError( 3051 element->getLoc(), 3052 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 3053 } 3054 } 3055 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 3056 } 3057 3058 //===----------------------------------------------------------------------===// 3059 // Stmts 3060 3061 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 3062 ast::Expr *rootOp) { 3063 // Check that root is an Operation. 3064 ast::Type rootType = rootOp->getType(); 3065 if (!isa<ast::OperationType>(rootType)) 3066 return emitError(rootOp->getLoc(), "expected `Op` expression"); 3067 3068 return ast::EraseStmt::create(ctx, loc, rootOp); 3069 } 3070 3071 FailureOr<ast::ReplaceStmt *> 3072 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 3073 MutableArrayRef<ast::Expr *> replValues) { 3074 // Check that root is an Operation. 3075 ast::Type rootType = rootOp->getType(); 3076 if (!isa<ast::OperationType>(rootType)) { 3077 return emitError( 3078 rootOp->getLoc(), 3079 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 3080 } 3081 3082 // If there are multiple replacement values, we implicitly convert any Op 3083 // expressions to the value form. 
3084 bool shouldConvertOpToValues = replValues.size() > 1; 3085 for (ast::Expr *&replExpr : replValues) { 3086 ast::Type replType = replExpr->getType(); 3087 3088 // Check that replExpr is an Operation, Value, or ValueRange. 3089 if (isa<ast::OperationType>(replType)) { 3090 if (shouldConvertOpToValues) 3091 replExpr = convertOpToValue(replExpr); 3092 continue; 3093 } 3094 3095 if (replType != valueTy && replType != valueRangeTy) { 3096 return emitError(replExpr->getLoc(), 3097 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 3098 "expression, but got `{0}`", 3099 replType)); 3100 } 3101 } 3102 3103 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 3104 } 3105 3106 FailureOr<ast::RewriteStmt *> 3107 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 3108 ast::CompoundStmt *rewriteBody) { 3109 // Check that root is an Operation. 3110 ast::Type rootType = rootOp->getType(); 3111 if (!isa<ast::OperationType>(rootType)) { 3112 return emitError( 3113 rootOp->getLoc(), 3114 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 3115 } 3116 3117 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 3118 } 3119 3120 //===----------------------------------------------------------------------===// 3121 // Code Completion 3122 //===----------------------------------------------------------------------===// 3123 3124 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) { 3125 ast::Type parentType = parentExpr->getType(); 3126 if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) 3127 codeCompleteContext->codeCompleteOperationMemberAccess(opType); 3128 else if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(parentType)) 3129 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType); 3130 return failure(); 3131 } 3132 3133 LogicalResult 3134 Parser::codeCompleteAttributeName(std::optional<StringRef> opName) { 3135 if (opName) 3136 
codeCompleteContext->codeCompleteOperationAttributeName(*opName); 3137 return failure(); 3138 } 3139 3140 LogicalResult 3141 Parser::codeCompleteConstraintName(ast::Type inferredType, 3142 bool allowInlineTypeConstraints) { 3143 codeCompleteContext->codeCompleteConstraintName( 3144 inferredType, allowInlineTypeConstraints, curDeclScope); 3145 return failure(); 3146 } 3147 3148 LogicalResult Parser::codeCompleteDialectName() { 3149 codeCompleteContext->codeCompleteDialectName(); 3150 return failure(); 3151 } 3152 3153 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) { 3154 codeCompleteContext->codeCompleteOperationName(dialectName); 3155 return failure(); 3156 } 3157 3158 LogicalResult Parser::codeCompletePatternMetadata() { 3159 codeCompleteContext->codeCompletePatternMetadata(); 3160 return failure(); 3161 } 3162 3163 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) { 3164 codeCompleteContext->codeCompleteIncludeFilename(curPath); 3165 return failure(); 3166 } 3167 3168 void Parser::codeCompleteCallSignature(ast::Node *parent, 3169 unsigned currentNumArgs) { 3170 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent); 3171 if (!callableDecl) 3172 return; 3173 3174 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs); 3175 } 3176 3177 void Parser::codeCompleteOperationOperandsSignature( 3178 std::optional<StringRef> opName, unsigned currentNumOperands) { 3179 codeCompleteContext->codeCompleteOperationOperandsSignature( 3180 opName, currentNumOperands); 3181 } 3182 3183 void Parser::codeCompleteOperationResultsSignature( 3184 std::optional<StringRef> opName, unsigned currentNumResults) { 3185 codeCompleteContext->codeCompleteOperationResultsSignature(opName, 3186 currentNumResults); 3187 } 3188 3189 //===----------------------------------------------------------------------===// 3190 // Parser 3191 //===----------------------------------------------------------------------===// 3192 3193 
/// Parse PDLL from `sourceMgr` into a module using the AST context `ctx`.
///
/// Returns the parsed module, or failure if parsing fails. `codeCompleteContext`
/// is forwarded to the parser and its lexer (presumably null when code
/// completion is not requested).
FailureOr<ast::Module *>
mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
                         bool enableDocumentation,
                         CodeCompleteContext *codeCompleteContext) {
  // Build a temporary parser and immediately run it over the whole module.
  return Parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext)
      .parseModule();
}