xref: /llvm-project/mlir/lib/Tools/PDLL/Parser/Parser.cpp (revision 095b41c6eedb3acc908dc63ee91ff77944c07d75)
1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/IndentedOstream.h"
12 #include "mlir/TableGen/Argument.h"
13 #include "mlir/TableGen/Attribute.h"
14 #include "mlir/TableGen/Constraint.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/Operator.h"
17 #include "mlir/Tools/PDLL/AST/Context.h"
18 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
19 #include "mlir/Tools/PDLL/AST/Nodes.h"
20 #include "mlir/Tools/PDLL/AST/Types.h"
21 #include "mlir/Tools/PDLL/ODS/Constraint.h"
22 #include "mlir/Tools/PDLL/ODS/Context.h"
23 #include "mlir/Tools/PDLL/ODS/Operation.h"
24 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/ManagedStatic.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/ScopedPrinter.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Parser.h"
33 #include <optional>
34 #include <string>
35 
36 using namespace mlir;
37 using namespace mlir::pdll;
38 
39 //===----------------------------------------------------------------------===//
40 // Parser
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 class Parser {
45 public:
46   Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
47          bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
48       : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
49         curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
50         typeTy(ast::TypeType::get(ctx)), valueTy(ast::ValueType::get(ctx)),
51         typeRangeTy(ast::TypeRangeType::get(ctx)),
52         valueRangeTy(ast::ValueRangeType::get(ctx)),
53         attrTy(ast::AttributeType::get(ctx)),
54         codeCompleteContext(codeCompleteContext) {}
55 
56   /// Try to parse a new module. Returns nullptr in the case of failure.
57   FailureOr<ast::Module *> parseModule();
58 
59 private:
60   /// The current context of the parser. It allows for the parser to know a bit
61   /// about the construct it is nested within during parsing. This is used
62   /// specifically to provide additional verification during parsing, e.g. to
63   /// prevent using rewrites within a match context, matcher constraints within
64   /// a rewrite section, etc.
65   enum class ParserContext {
66     /// The parser is in the global context.
67     Global,
68     /// The parser is currently within a Constraint, which disallows all types
69     /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
70     Constraint,
71     /// The parser is currently within the matcher portion of a Pattern, which
72     /// is allows a terminal operation rewrite statement but no other rewrite
73     /// transformations.
74     PatternMatch,
75     /// The parser is currently within a Rewrite, which disallows calls to
76     /// constraints, requires operation expressions to have names, etc.
77     Rewrite,
78   };
79 
80   /// The current specification context of an operations result type. This
81   /// indicates how the result types of an operation may be inferred.
82   enum class OpResultTypeContext {
83     /// The result types of the operation are not known to be inferred.
84     Explicit,
85     /// The result types of the operation are inferred from the root input of a
86     /// `replace` statement.
87     Replacement,
88     /// The result types of the operation are inferred by using the
89     /// `InferTypeOpInterface` interface provided by the operation.
90     Interface,
91   };
92 
93   //===--------------------------------------------------------------------===//
94   // Parsing
95   //===--------------------------------------------------------------------===//
96 
97   /// Push a new decl scope onto the lexer.
98   ast::DeclScope *pushDeclScope() {
99     ast::DeclScope *newScope =
100         new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
101     return (curDeclScope = newScope);
102   }
103   void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
104 
105   /// Pop the last decl scope from the lexer.
106   void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
107 
108   /// Parse the body of an AST module.
109   LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
110 
111   /// Try to convert the given expression to `type`. Returns failure and emits
112   /// an error if a conversion is not viable. On failure, `noteAttachFn` is
113   /// invoked to attach notes to the emitted error diagnostic. On success,
114   /// `expr` is updated to the expression used to convert to `type`.
115   LogicalResult convertExpressionTo(
116       ast::Expr *&expr, ast::Type type,
117       function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
118   LogicalResult
119   convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType,
120                         ast::Type type,
121                         function_ref<ast::InFlightDiagnostic()> emitErrorFn);
122   LogicalResult convertTupleExpressionTo(
123       ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
124       function_ref<ast::InFlightDiagnostic()> emitErrorFn,
125       function_ref<void(ast::Diagnostic &diag)> noteAttachFn);
126 
127   /// Given an operation expression, convert it to a Value or ValueRange
128   /// typed expression.
129   ast::Expr *convertOpToValue(const ast::Expr *opExpr);
130 
131   /// Lookup ODS information for the given operation, returns nullptr if no
132   /// information is found.
133   const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
134     return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
135   }
136 
137   /// Process the given documentation string, or return an empty string if
138   /// documentation isn't enabled.
139   StringRef processDoc(StringRef doc) {
140     return enableDocumentation ? doc : StringRef();
141   }
142 
143   /// Process the given documentation string and format it, or return an empty
144   /// string if documentation isn't enabled.
145   std::string processAndFormatDoc(const Twine &doc) {
146     if (!enableDocumentation)
147       return "";
148     std::string docStr;
149     {
150       llvm::raw_string_ostream docOS(docStr);
151       raw_indented_ostream(docOS).printReindented(
152           StringRef(docStr).rtrim(" \t"));
153     }
154     return docStr;
155   }
156 
157   //===--------------------------------------------------------------------===//
158   // Directives
159 
160   LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
161   LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
162   LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
163                                SmallVectorImpl<ast::Decl *> &decls);
164 
165   /// Process the records of a parsed tablegen include file.
166   void processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
167                                SmallVectorImpl<ast::Decl *> &decls);
168 
169   /// Create a user defined native constraint for a constraint imported from
170   /// ODS.
171   template <typename ConstraintT>
172   ast::Decl *
173   createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
174                                     SMRange loc, ast::Type type,
175                                     StringRef nativeType, StringRef docString);
176   template <typename ConstraintT>
177   ast::Decl *
178   createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
179                                     SMRange loc, ast::Type type,
180                                     StringRef nativeType);
181 
182   //===--------------------------------------------------------------------===//
183   // Decls
184 
185   /// This structure contains the set of pattern metadata that may be parsed.
186   struct ParsedPatternMetadata {
187     std::optional<uint16_t> benefit;
188     bool hasBoundedRecursion = false;
189   };
190 
191   FailureOr<ast::Decl *> parseTopLevelDecl();
192   FailureOr<ast::NamedAttributeDecl *>
193   parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
194 
195   /// Parse an argument variable as part of the signature of a
196   /// UserConstraintDecl or UserRewriteDecl.
197   FailureOr<ast::VariableDecl *> parseArgumentDecl();
198 
199   /// Parse a result variable as part of the signature of a UserConstraintDecl
200   /// or UserRewriteDecl.
201   FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
202 
203   /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
204   /// defined in a non-global context.
205   FailureOr<ast::UserConstraintDecl *>
206   parseUserConstraintDecl(bool isInline = false);
207 
208   /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
209   /// non-global context, such as within a Pattern/Constraint/etc.
210   FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
211 
212   /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
213   /// PDLL constructs.
214   FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
215       const ast::Name &name, bool isInline,
216       ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
217       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
218 
219   /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
220   /// defined in a non-global context.
221   FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
222 
223   /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
224   /// non-global context, such as within a Pattern/Rewrite/etc.
225   FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
226 
227   /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
228   /// PDLL constructs.
229   FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
230       const ast::Name &name, bool isInline,
231       ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
232       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
233 
234   /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
235   /// effectively the same syntax, and only differ on slight semantics (given
236   /// the different parsing contexts).
237   template <typename T, typename ParseUserPDLLDeclFnT>
238   FailureOr<T *> parseUserConstraintOrRewriteDecl(
239       ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
240       StringRef anonymousNamePrefix, bool isInline);
241 
242   /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
243   /// These decls have effectively the same syntax.
244   template <typename T>
245   FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
246       const ast::Name &name, bool isInline,
247       ArrayRef<ast::VariableDecl *> arguments,
248       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
249 
250   /// Parse the functional signature (i.e. the arguments and results) of a
251   /// UserConstraintDecl or UserRewriteDecl.
252   LogicalResult parseUserConstraintOrRewriteSignature(
253       SmallVectorImpl<ast::VariableDecl *> &arguments,
254       SmallVectorImpl<ast::VariableDecl *> &results,
255       ast::DeclScope *&argumentScope, ast::Type &resultType);
256 
257   /// Validate the return (which if present is specified by bodyIt) of a
258   /// UserConstraintDecl or UserRewriteDecl.
259   LogicalResult validateUserConstraintOrRewriteReturn(
260       StringRef declType, ast::CompoundStmt *body,
261       ArrayRef<ast::Stmt *>::iterator bodyIt,
262       ArrayRef<ast::Stmt *>::iterator bodyE,
263       ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
264 
265   FailureOr<ast::CompoundStmt *>
266   parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
267                   bool expectTerminalSemicolon = true);
268   FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
269   FailureOr<ast::Decl *> parsePatternDecl();
270   LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
271 
272   /// Check to see if a decl has already been defined with the given name, if
273   /// one has emit and error and return failure. Returns success otherwise.
274   LogicalResult checkDefineNamedDecl(const ast::Name &name);
275 
276   /// Try to define a variable decl with the given components, returns the
277   /// variable on success.
278   FailureOr<ast::VariableDecl *>
279   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
280                      ast::Expr *initExpr,
281                      ArrayRef<ast::ConstraintRef> constraints);
282   FailureOr<ast::VariableDecl *>
283   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
284                      ArrayRef<ast::ConstraintRef> constraints);
285 
286   /// Parse the constraint reference list for a variable decl.
287   LogicalResult parseVariableDeclConstraintList(
288       SmallVectorImpl<ast::ConstraintRef> &constraints);
289 
290   /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
291   FailureOr<ast::Expr *> parseTypeConstraintExpr();
292 
293   /// Try to parse a single reference to a constraint. `typeConstraint` is the
294   /// location of a previously parsed type constraint for the entity that will
295   /// be constrained by the parsed constraint. `existingConstraints` are any
296   /// existing constraints that have already been parsed for the same entity
297   /// that will be constrained by this constraint. `allowInlineTypeConstraints`
298   /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
299   FailureOr<ast::ConstraintRef>
300   parseConstraint(std::optional<SMRange> &typeConstraint,
301                   ArrayRef<ast::ConstraintRef> existingConstraints,
302                   bool allowInlineTypeConstraints);
303 
304   /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
305   /// argument or result variable. The constraints for these variables do not
306   /// allow inline type constraints, and only permit a single constraint.
307   FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
308 
309   //===--------------------------------------------------------------------===//
310   // Exprs
311 
312   FailureOr<ast::Expr *> parseExpr();
313 
314   /// Identifier expressions.
315   FailureOr<ast::Expr *> parseAttributeExpr();
316   FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
317                                        bool isNegated = false);
318   FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
319   FailureOr<ast::Expr *> parseIdentifierExpr();
320   FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
321   FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
322   FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
323   FailureOr<ast::Expr *> parseNegatedExpr();
324   FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
325   FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
326   FailureOr<ast::Expr *>
327   parseOperationExpr(OpResultTypeContext inputResultTypeContext =
328                          OpResultTypeContext::Explicit);
329   FailureOr<ast::Expr *> parseTupleExpr();
330   FailureOr<ast::Expr *> parseTypeExpr();
331   FailureOr<ast::Expr *> parseUnderscoreExpr();
332 
333   //===--------------------------------------------------------------------===//
334   // Stmts
335 
336   FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
337   FailureOr<ast::CompoundStmt *> parseCompoundStmt();
338   FailureOr<ast::EraseStmt *> parseEraseStmt();
339   FailureOr<ast::LetStmt *> parseLetStmt();
340   FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
341   FailureOr<ast::ReturnStmt *> parseReturnStmt();
342   FailureOr<ast::RewriteStmt *> parseRewriteStmt();
343 
344   //===--------------------------------------------------------------------===//
345   // Creation+Analysis
346   //===--------------------------------------------------------------------===//
347 
348   //===--------------------------------------------------------------------===//
349   // Decls
350 
351   /// Try to extract a callable from the given AST node. Returns nullptr on
352   /// failure.
353   ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
354 
355   /// Try to create a pattern decl with the given components, returning the
356   /// Pattern on success.
357   FailureOr<ast::PatternDecl *>
358   createPatternDecl(SMRange loc, const ast::Name *name,
359                     const ParsedPatternMetadata &metadata,
360                     ast::CompoundStmt *body);
361 
362   /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
363   /// of results, defined as part of the signature.
364   ast::Type
365   createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
366 
367   /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
368   template <typename T>
369   FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
370       const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
371       ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
372       ast::CompoundStmt *body);
373 
374   /// Try to create a variable decl with the given components, returning the
375   /// Variable on success.
376   FailureOr<ast::VariableDecl *>
377   createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
378                      ArrayRef<ast::ConstraintRef> constraints);
379 
380   /// Create a variable for an argument or result defined as part of the
381   /// signature of a UserConstraintDecl/UserRewriteDecl.
382   FailureOr<ast::VariableDecl *>
383   createArgOrResultVariableDecl(StringRef name, SMRange loc,
384                                 const ast::ConstraintRef &constraint);
385 
386   /// Validate the constraints used to constraint a variable decl.
387   /// `inferredType` is the type of the variable inferred by the constraints
388   /// within the list, and is updated to the most refined type as determined by
389   /// the constraints. Returns success if the constraint list is valid, failure
390   /// otherwise.
391   LogicalResult
392   validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
393                               ast::Type &inferredType);
394   /// Validate a single reference to a constraint. `inferredType` contains the
395   /// currently inferred variabled type and is refined within the type defined
396   /// by the constraint. Returns success if the constraint is valid, failure
397   /// otherwise.
398   LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
399                                            ast::Type &inferredType);
400   LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
401   LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
402 
403   //===--------------------------------------------------------------------===//
404   // Exprs
405 
406   FailureOr<ast::CallExpr *>
407   createCallExpr(SMRange loc, ast::Expr *parentExpr,
408                  MutableArrayRef<ast::Expr *> arguments,
409                  bool isNegated = false);
410   FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
411   FailureOr<ast::DeclRefExpr *>
412   createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
413                            ArrayRef<ast::ConstraintRef> constraints);
414   FailureOr<ast::MemberAccessExpr *>
415   createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
416 
417   /// Validate the member access `name` into the given parent expression. On
418   /// success, this also returns the type of the member accessed.
419   FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
420                                             StringRef name, SMRange loc);
421   FailureOr<ast::OperationExpr *>
422   createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
423                       OpResultTypeContext resultTypeContext,
424                       SmallVectorImpl<ast::Expr *> &operands,
425                       MutableArrayRef<ast::NamedAttributeDecl *> attributes,
426                       SmallVectorImpl<ast::Expr *> &results);
427   LogicalResult
428   validateOperationOperands(SMRange loc, std::optional<StringRef> name,
429                             const ods::Operation *odsOp,
430                             SmallVectorImpl<ast::Expr *> &operands);
431   LogicalResult validateOperationResults(SMRange loc,
432                                          std::optional<StringRef> name,
433                                          const ods::Operation *odsOp,
434                                          SmallVectorImpl<ast::Expr *> &results);
435   void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
436                                           const ods::Operation *odsOp);
437   LogicalResult validateOperationOperandsOrResults(
438       StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
439       std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
440       ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
441       ast::RangeType rangeTy);
442   FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
443                                               ArrayRef<ast::Expr *> elements,
444                                               ArrayRef<StringRef> elementNames);
445 
446   //===--------------------------------------------------------------------===//
447   // Stmts
448 
449   FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
450   FailureOr<ast::ReplaceStmt *>
451   createReplaceStmt(SMRange loc, ast::Expr *rootOp,
452                     MutableArrayRef<ast::Expr *> replValues);
453   FailureOr<ast::RewriteStmt *>
454   createRewriteStmt(SMRange loc, ast::Expr *rootOp,
455                     ast::CompoundStmt *rewriteBody);
456 
457   //===--------------------------------------------------------------------===//
458   // Code Completion
459   //===--------------------------------------------------------------------===//
460 
461   /// The set of various code completion methods. Every completion method
462   /// returns `failure` to stop the parsing process after providing completion
463   /// results.
464 
465   LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
466   LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
467   LogicalResult codeCompleteConstraintName(ast::Type inferredType,
468                                            bool allowInlineTypeConstraints);
469   LogicalResult codeCompleteDialectName();
470   LogicalResult codeCompleteOperationName(StringRef dialectName);
471   LogicalResult codeCompletePatternMetadata();
472   LogicalResult codeCompleteIncludeFilename(StringRef curPath);
473 
474   void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
475   void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
476                                               unsigned currentNumOperands);
477   void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
478                                              unsigned currentNumResults);
479 
480   //===--------------------------------------------------------------------===//
481   // Lexer Utilities
482   //===--------------------------------------------------------------------===//
483 
484   /// If the current token has the specified kind, consume it and return true.
485   /// If not, return false.
486   bool consumeIf(Token::Kind kind) {
487     if (curToken.isNot(kind))
488       return false;
489     consumeToken(kind);
490     return true;
491   }
492 
493   /// Advance the current lexer onto the next token.
494   void consumeToken() {
495     assert(curToken.isNot(Token::eof, Token::error) &&
496            "shouldn't advance past EOF or errors");
497     curToken = lexer.lexToken();
498   }
499 
500   /// Advance the current lexer onto the next token, asserting what the expected
501   /// current token is. This is preferred to the above method because it leads
502   /// to more self-documenting code with better checking.
503   void consumeToken(Token::Kind kind) {
504     assert(curToken.is(kind) && "consumed an unexpected token");
505     consumeToken();
506   }
507 
508   /// Reset the lexer to the location at the given position.
509   void resetToken(SMRange tokLoc) {
510     lexer.resetPointer(tokLoc.Start.getPointer());
511     curToken = lexer.lexToken();
512   }
513 
514   /// Consume the specified token if present and return success. On failure,
515   /// output a diagnostic and return failure.
516   LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
517     if (curToken.getKind() != kind)
518       return emitError(curToken.getLoc(), msg);
519     consumeToken();
520     return success();
521   }
522   LogicalResult emitError(SMRange loc, const Twine &msg) {
523     lexer.emitError(loc, msg);
524     return failure();
525   }
526   LogicalResult emitError(const Twine &msg) {
527     return emitError(curToken.getLoc(), msg);
528   }
529   LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
530                                  const Twine &note) {
531     lexer.emitErrorAndNote(loc, msg, noteLoc, note);
532     return failure();
533   }
534 
535   //===--------------------------------------------------------------------===//
536   // Fields
537   //===--------------------------------------------------------------------===//
538 
539   /// The owning AST context.
540   ast::Context &ctx;
541 
542   /// The lexer of this parser.
543   Lexer lexer;
544 
545   /// The current token within the lexer.
546   Token curToken;
547 
548   /// A flag indicating if the parser should add documentation to AST nodes when
549   /// viable.
550   bool enableDocumentation;
551 
552   /// The most recently defined decl scope.
553   ast::DeclScope *curDeclScope = nullptr;
554   llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
555 
556   /// The current context of the parser.
557   ParserContext parserContext = ParserContext::Global;
558 
559   /// Cached types to simplify verification and expression creation.
560   ast::Type typeTy, valueTy;
561   ast::RangeType typeRangeTy, valueRangeTy;
562   ast::Type attrTy;
563 
564   /// A counter used when naming anonymous constraints and rewrites.
565   unsigned anonymousDeclNameCounter = 0;
566 
567   /// The optional code completion context.
568   CodeCompleteContext *codeCompleteContext;
569 };
570 } // namespace
571 
572 FailureOr<ast::Module *> Parser::parseModule() {
573   SMLoc moduleLoc = curToken.getStartLoc();
574   pushDeclScope();
575 
576   // Parse the top-level decls of the module.
577   SmallVector<ast::Decl *> decls;
578   if (failed(parseModuleBody(decls)))
579     return popDeclScope(), failure();
580 
581   popDeclScope();
582   return ast::Module::create(ctx, moduleLoc, decls);
583 }
584 
585 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
586   while (curToken.isNot(Token::eof)) {
587     if (curToken.is(Token::directive)) {
588       if (failed(parseDirective(decls)))
589         return failure();
590       continue;
591     }
592 
593     FailureOr<ast::Decl *> decl = parseTopLevelDecl();
594     if (failed(decl))
595       return failure();
596     decls.push_back(*decl);
597   }
598   return success();
599 }
600 
601 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
602   return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
603                                                  valueRangeTy);
604 }
605 
606 LogicalResult Parser::convertExpressionTo(
607     ast::Expr *&expr, ast::Type type,
608     function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
609   ast::Type exprType = expr->getType();
610   if (exprType == type)
611     return success();
612 
613   auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
614     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
615         expr->getLoc(), llvm::formatv("unable to convert expression of type "
616                                       "`{0}` to the expected type of "
617                                       "`{1}`",
618                                       exprType, type));
619     if (noteAttachFn)
620       noteAttachFn(*diag);
621     return diag;
622   };
623 
624   if (auto exprOpType = dyn_cast<ast::OperationType>(exprType))
625     return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
626 
627   // FIXME: Decide how to allow/support converting a single result to multiple,
628   // and multiple to a single result. For now, we just allow Single->Range,
629   // but this isn't something really supported in the PDL dialect. We should
630   // figure out some way to support both.
631   if ((exprType == valueTy || exprType == valueRangeTy) &&
632       (type == valueTy || type == valueRangeTy))
633     return success();
634   if ((exprType == typeTy || exprType == typeRangeTy) &&
635       (type == typeTy || type == typeRangeTy))
636     return success();
637 
638   // Handle tuple types.
639   if (auto exprTupleType = dyn_cast<ast::TupleType>(exprType))
640     return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
641                                     noteAttachFn);
642 
643   return emitConvertError();
644 }
645 
646 LogicalResult Parser::convertOpExpressionTo(
647     ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
648     function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
649   // Two operation types are compatible if they have the same name, or if the
650   // expected type is more general.
651   if (auto opType = dyn_cast<ast::OperationType>(type)) {
652     if (opType.getName())
653       return emitErrorFn();
654     return success();
655   }
656 
657   // An operation can always convert to a ValueRange.
658   if (type == valueRangeTy) {
659     expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
660                                                    valueRangeTy);
661     return success();
662   }
663 
664   // Allow conversion to a single value by constraining the result range.
665   if (type == valueTy) {
666     // If the operation is registered, we can verify if it can ever have a
667     // single result.
668     if (const ods::Operation *odsOp = exprType.getODSOperation()) {
669       if (odsOp->getResults().empty()) {
670         return emitErrorFn()->attachNote(
671             llvm::formatv("see the definition of `{0}`, which was defined "
672                           "with zero results",
673                           odsOp->getName()),
674             odsOp->getLoc());
675       }
676 
677       unsigned numSingleResults = llvm::count_if(
678           odsOp->getResults(), [](const ods::OperandOrResult &result) {
679             return result.getVariableLengthKind() ==
680                    ods::VariableLengthKind::Single;
681           });
682       if (numSingleResults > 1) {
683         return emitErrorFn()->attachNote(
684             llvm::formatv("see the definition of `{0}`, which was defined "
685                           "with at least {1} results",
686                           odsOp->getName(), numSingleResults),
687             odsOp->getLoc());
688       }
689     }
690 
691     expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
692                                                    valueTy);
693     return success();
694   }
695   return emitErrorFn();
696 }
697 
698 LogicalResult Parser::convertTupleExpressionTo(
699     ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
700     function_ref<ast::InFlightDiagnostic()> emitErrorFn,
701     function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
702   // Handle conversions between tuples.
703   if (auto tupleType = dyn_cast<ast::TupleType>(type)) {
704     if (tupleType.size() != exprType.size())
705       return emitErrorFn();
706 
707     // Build a new tuple expression using each of the elements of the current
708     // tuple.
709     SmallVector<ast::Expr *> newExprs;
710     for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
711       newExprs.push_back(ast::MemberAccessExpr::create(
712           ctx, expr->getLoc(), expr, llvm::to_string(i),
713           exprType.getElementTypes()[i]));
714 
715       auto diagFn = [&](ast::Diagnostic &diag) {
716         diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
717                                       i, exprType));
718         if (noteAttachFn)
719           noteAttachFn(diag);
720       };
721       if (failed(convertExpressionTo(newExprs.back(),
722                                      tupleType.getElementTypes()[i], diagFn)))
723         return failure();
724     }
725     expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
726                                   tupleType.getElementNames());
727     return success();
728   }
729 
730   // Handle conversion to a range.
731   auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
732                             ast::RangeType resultTy) -> LogicalResult {
733     // TODO: We currently only allow range conversion within a rewrite context.
734     if (parserContext != ParserContext::Rewrite) {
735       return emitErrorFn()->attachNote("Tuple to Range conversion is currently "
736                                        "only allowed within a rewrite context");
737     }
738 
739     // All of the tuple elements must be allowed types.
740     for (ast::Type elementType : exprType.getElementTypes())
741       if (!llvm::is_contained(allowedElementTypes, elementType))
742         return emitErrorFn();
743 
744     // Build a new tuple expression using each of the elements of the current
745     // tuple.
746     SmallVector<ast::Expr *> newExprs;
747     for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
748       newExprs.push_back(ast::MemberAccessExpr::create(
749           ctx, expr->getLoc(), expr, llvm::to_string(i),
750           exprType.getElementTypes()[i]));
751     }
752     expr = ast::RangeExpr::create(ctx, expr->getLoc(), newExprs, resultTy);
753     return success();
754   };
755   if (type == valueRangeTy)
756     return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
757   if (type == typeRangeTy)
758     return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
759 
760   return emitErrorFn();
761 }
762 
763 //===----------------------------------------------------------------------===//
764 // Directives
765 
766 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
767   StringRef directive = curToken.getSpelling();
768   if (directive == "#include")
769     return parseInclude(decls);
770 
771   return emitError("unknown directive `" + directive + "`");
772 }
773 
774 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
775   SMRange loc = curToken.getLoc();
776   consumeToken(Token::directive);
777 
778   // Handle code completion of the include file path.
779   if (curToken.is(Token::code_complete_string))
780     return codeCompleteIncludeFilename(curToken.getStringValue());
781 
782   // Parse the file being included.
783   if (!curToken.isString())
784     return emitError(loc,
785                      "expected string file name after `include` directive");
786   SMRange fileLoc = curToken.getLoc();
787   std::string filenameStr = curToken.getStringValue();
788   StringRef filename = filenameStr;
789   consumeToken();
790 
791   // Check the type of include. If ending with `.pdll`, this is another pdl file
792   // to be parsed along with the current module.
793   if (filename.ends_with(".pdll")) {
794     if (failed(lexer.pushInclude(filename, fileLoc)))
795       return emitError(fileLoc,
796                        "unable to open include file `" + filename + "`");
797 
798     // If we added the include successfully, parse it into the current module.
799     // Make sure to update to the next token after we finish parsing the nested
800     // file.
801     curToken = lexer.lexToken();
802     LogicalResult result = parseModuleBody(decls);
803     curToken = lexer.lexToken();
804     return result;
805   }
806 
807   // Otherwise, this must be a `.td` include.
808   if (filename.ends_with(".td"))
809     return parseTdInclude(filename, fileLoc, decls);
810 
811   return emitError(fileLoc,
812                    "expected include filename to end with `.pdll` or `.td`");
813 }
814 
/// Parse the TableGen file `filename`, referenced by an include at `fileLoc`,
/// and import its records into the current module.
///
/// The file is parsed with a dedicated llvm::SourceMgr. Its diagnostics are
/// re-reported as parser errors located at `fileLoc`. On success, the records
/// are handed to processTdIncludeRecords (which may append to `decls`), and
/// the TableGen source buffers are moved into the parser's source manager.
///
/// Returns failure if the file cannot be opened or TableGen parsing fails.
LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
                                     SmallVectorImpl<ast::Decl *> &decls) {
  llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();

  // Use the parser's source manager to locate and open the file, but don't yet
  // add it as one of its buffers.
  std::string includedFile;
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
      parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
  if (!includeBuffer)
    return emitError(fileLoc, "unable to open include file `" + filename + "`");

  // Setup a separate source manager for parsing the tablegen file. It is given
  // the same include directories as the parser's source manager.
  llvm::SourceMgr tdSrcMgr;
  tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
  tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());

  // This class provides a context argument for the llvm::SourceMgr diagnostic
  // handler. It is a stack local, so it must stay alive while TableGen may
  // report diagnostics through `tdSrcMgr`, i.e. during TableGenParseFile below.
  struct DiagHandlerContext {
    Parser &parser;
    StringRef filename;
    llvm::SMRange loc;
  } handlerContext{*this, filename, fileLoc};

  // Set the diagnostic handler for the tablegen source manager. Every TableGen
  // diagnostic is reported as a parser *error* at the include location,
  // whatever its original severity, with the TableGen message appended.
  tdSrcMgr.setDiagHandler(
      [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
        auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
        (void)ctx->parser.emitError(
            ctx->loc,
            llvm::formatv("error while processing include file `{0}`: {1}",
                          ctx->filename, diag.getMessage()));
      },
      &handlerContext);

  // Parse the tablegen file. A `true` result means parsing failed. Such
  // errors are expected to have been reported through the handler above.
  llvm::RecordKeeper tdRecords;
  if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
    return failure();

  // Process the parsed records.
  processTdIncludeRecords(tdRecords, decls);

  // After we are done processing, move all of the tablegen source buffers to
  // the main parser source mgr. This allows for directly using source locations
  // from the .td files without needing to remap them. The buffers are only
  // transferred on success; the early return above skips this.
  parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
  return success();
}
864 
/// Import the records of a parsed TableGen file into PDLL.
///
/// * Every `Op` definition is registered in the ODS context, together with its
///   attributes, operands and results. An operation is skipped if
///   insertOperation reports that it was already present.
/// * Every `Attr`, `Type` and `OpInterface` definition becomes a native PDLL
///   constraint decl. The decl is appended to `decls`, and
///   createODSNativePDLLConstraintDecl also adds it to the current scope.
///   Skipped defs are: anonymous defs, defs whose name already resolves in the
///   current scope, and `DeclareInterfaceMethods` defs.
void Parser::processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
                                     SmallVectorImpl<ast::Decl *> &decls) {
  // Return the length kind of the given value. Optional takes precedence over
  // variadic; anything else is a single value.
  auto getLengthKind = [](const auto &value) {
    if (value.isOptional())
      return ods::VariableLengthKind::Optional;
    return value.isVariadic() ? ods::VariableLengthKind::Variadic
                              : ods::VariableLengthKind::Single;
  };

  // Insert a type constraint into the ODS context (using its unique def name,
  // processed summary and C++ type) and return the context's entry for it.
  ods::Context &odsContext = ctx.getODSContext();
  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
      -> const ods::TypeConstraint & {
    return odsContext.insertTypeConstraint(
        cst.constraint.getUniqueDefName(),
        processDoc(cst.constraint.getSummary()), cst.constraint.getCppType());
  };
  // Build a one-character source range that starts at `loc`.
  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
    return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
  };

  // Process the parsed tablegen records to build ODS information.
  /// Operations.
  for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
    tblgen::Operator op(def);

    // Check to see if this operation is known to support result type
    // inference, i.e. it carries the InferTypeOpInterface trait.
    bool supportsResultTypeInferrence =
        op.getTrait("::mlir::InferTypeOpInterface::Trait");

    auto [odsOp, inserted] = odsContext.insertOperation(
        op.getOperationName(), processDoc(op.getSummary()),
        processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
        supportsResultTypeInferrence, op.getLoc().front());

    // Ignore operations that have already been added (e.g. by an earlier
    // include), so their attributes/operands/results are not appended twice.
    if (!inserted)
      continue;

    for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
      odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
                             odsContext.insertAttributeConstraint(
                                 attr.attr.getUniqueDefName(),
                                 processDoc(attr.attr.getSummary()),
                                 attr.attr.getStorageType()));
    }
    for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
      odsOp->appendOperand(operand.name, getLengthKind(operand),
                           addTypeConstraint(operand));
    }
    for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
      odsOp->appendResult(result.name, getLengthKind(result),
                          addTypeConstraint(result));
    }
  }

  // Defs that should not become PDLL constraints. The scope lookup also skips
  // names declared by earlier iterations below, because each created
  // constraint is added to `curDeclScope`.
  auto shouldBeSkipped = [this](const llvm::Record *def) {
    return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
           def->isSubClassOf("DeclareInterfaceMethods");
  };

  /// Attr constraints.
  for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
    if (shouldBeSkipped(def))
      continue;

    tblgen::Attribute constraint(def);
    decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
        constraint, convertLocToRange(def->getLoc().front()), attrTy,
        constraint.getStorageType()));
  }
  /// Type constraints.
  for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
    if (shouldBeSkipped(def))
      continue;

    tblgen::TypeConstraint constraint(def);
    decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
        constraint, convertLocToRange(def->getLoc().front()), typeTy,
        constraint.getCppType()));
  }
  /// OpInterfaces. Each one becomes an `Op` constraint whose native body
  /// checks `llvm::isa<Namespace::Interface>(self)`.
  ast::Type opTy = ast::OperationType::get(ctx);
  for (const llvm::Record *def :
       tdRecords.getAllDerivedDefinitions("OpInterface")) {
    if (shouldBeSkipped(def))
      continue;

    SMRange loc = convertLocToRange(def->getLoc().front());

    std::string cppClassName =
        llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
                      def->getValueAsString("cppInterfaceName"))
            .str();
    std::string codeBlock =
        llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
                      cppClassName)
            .str();

    std::string desc =
        processAndFormatDoc(def->getValueAsString("description"));
    decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
        def->getName(), codeBlock, loc, opTy, cppClassName, desc));
  }
}
971 
972 template <typename ConstraintT>
973 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
974     StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
975     StringRef nativeType, StringRef docString) {
976   // Build the single input parameter.
977   ast::DeclScope *argScope = pushDeclScope();
978   auto *paramVar = ast::VariableDecl::create(
979       ctx, ast::Name::create(ctx, "self", loc), type,
980       /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
981   argScope->add(paramVar);
982   popDeclScope();
983 
984   // Build the native constraint.
985   auto *constraintDecl = ast::UserConstraintDecl::createNative(
986       ctx, ast::Name::create(ctx, name, loc), paramVar,
987       /*results=*/std::nullopt, codeBlock, ast::TupleType::get(ctx),
988       nativeType);
989   constraintDecl->setDocComment(ctx, docString);
990   curDeclScope->add(constraintDecl);
991   return constraintDecl;
992 }
993 
994 template <typename ConstraintT>
995 ast::Decl *
996 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
997                                           SMRange loc, ast::Type type,
998                                           StringRef nativeType) {
999   // Format the condition template.
1000   tblgen::FmtContext fmtContext;
1001   fmtContext.withSelf("self");
1002   std::string codeBlock = tblgen::tgfmt(
1003       "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
1004       &fmtContext);
1005 
1006   // If documentation was enabled, build the doc string for the generated
1007   // constraint. It would be nice to do this lazily, but TableGen information is
1008   // destroyed after we finish parsing the file.
1009   std::string docString;
1010   if (enableDocumentation) {
1011     StringRef desc = constraint.getDescription();
1012     docString = processAndFormatDoc(
1013         constraint.getSummary() +
1014         (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
1015   }
1016 
1017   return createODSNativePDLLConstraintDecl<ConstraintT>(
1018       constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
1019       docString);
1020 }
1021 
1022 //===----------------------------------------------------------------------===//
1023 // Decls
1024 
1025 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1026   FailureOr<ast::Decl *> decl;
1027   switch (curToken.getKind()) {
1028   case Token::kw_Constraint:
1029     decl = parseUserConstraintDecl();
1030     break;
1031   case Token::kw_Pattern:
1032     decl = parsePatternDecl();
1033     break;
1034   case Token::kw_Rewrite:
1035     decl = parseUserRewriteDecl();
1036     break;
1037   default:
1038     return emitError("expected top-level declaration, such as a `Pattern`");
1039   }
1040   if (failed(decl))
1041     return failure();
1042 
1043   // If the decl has a name, add it to the current scope.
1044   if (const ast::Name *name = (*decl)->getName()) {
1045     if (failed(checkDefineNamedDecl(*name)))
1046       return failure();
1047     curDeclScope->add(*decl);
1048   }
1049   return decl;
1050 }
1051 
1052 FailureOr<ast::NamedAttributeDecl *>
1053 Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1054   // Check for name code completion.
1055   if (curToken.is(Token::code_complete))
1056     return codeCompleteAttributeName(parentOpName);
1057 
1058   std::string attrNameStr;
1059   if (curToken.isString())
1060     attrNameStr = curToken.getStringValue();
1061   else if (curToken.is(Token::identifier) || curToken.isKeyword())
1062     attrNameStr = curToken.getSpelling().str();
1063   else
1064     return emitError("expected identifier or string attribute name");
1065   const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
1066   consumeToken();
1067 
1068   // Check for a value of the attribute.
1069   ast::Expr *attrValue = nullptr;
1070   if (consumeIf(Token::equal)) {
1071     FailureOr<ast::Expr *> attrExpr = parseExpr();
1072     if (failed(attrExpr))
1073       return failure();
1074     attrValue = *attrExpr;
1075   } else {
1076     // If there isn't a concrete value, create an expression representing a
1077     // UnitAttr.
1078     attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
1079   }
1080 
1081   return ast::NamedAttributeDecl::create(ctx, name, attrValue);
1082 }
1083 
1084 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1085     function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1086     bool expectTerminalSemicolon) {
1087   consumeToken(Token::equal_arrow);
1088 
1089   // Parse the single statement of the lambda body.
1090   SMLoc bodyStartLoc = curToken.getStartLoc();
1091   pushDeclScope();
1092   FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1093   bool failedToParse =
1094       failed(singleStatement) || failed(processStatementFn(*singleStatement));
1095   popDeclScope();
1096   if (failedToParse)
1097     return failure();
1098 
1099   SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1100   return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1101 }
1102 
1103 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1104   // Ensure that the argument is named.
1105   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1106     return emitError("expected identifier argument name");
1107 
1108   // Parse the argument similarly to a normal variable.
1109   StringRef name = curToken.getSpelling();
1110   SMRange nameLoc = curToken.getLoc();
1111   consumeToken();
1112 
1113   if (failed(
1114           parseToken(Token::colon, "expected `:` before argument constraint")))
1115     return failure();
1116 
1117   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1118   if (failed(cst))
1119     return failure();
1120 
1121   return createArgOrResultVariableDecl(name, nameLoc, *cst);
1122 }
1123 
1124 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1125   // Check to see if this result is named.
1126   if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1127     // Check to see if this name actually refers to a Constraint.
1128     if (!curDeclScope->lookup<ast::ConstraintDecl>(curToken.getSpelling())) {
1129       // If it wasn't a constraint, parse the result similarly to a variable. If
1130       // there is already an existing decl, we will emit an error when defining
1131       // this variable later.
1132       StringRef name = curToken.getSpelling();
1133       SMRange nameLoc = curToken.getLoc();
1134       consumeToken();
1135 
1136       if (failed(parseToken(Token::colon,
1137                             "expected `:` before result constraint")))
1138         return failure();
1139 
1140       FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1141       if (failed(cst))
1142         return failure();
1143 
1144       return createArgOrResultVariableDecl(name, nameLoc, *cst);
1145     }
1146   }
1147 
1148   // If it isn't named, we parse the constraint directly and create an unnamed
1149   // result variable.
1150   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1151   if (failed(cst))
1152     return failure();
1153 
1154   return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1155 }
1156 
1157 FailureOr<ast::UserConstraintDecl *>
1158 Parser::parseUserConstraintDecl(bool isInline) {
1159   // Constraints and rewrites have very similar formats, dispatch to a shared
1160   // interface for parsing.
1161   return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1162       [&](auto &&...args) {
1163         return this->parseUserPDLLConstraintDecl(args...);
1164       },
1165       ParserContext::Constraint, "constraint", isInline);
1166 }
1167 
1168 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1169   FailureOr<ast::UserConstraintDecl *> decl =
1170       parseUserConstraintDecl(/*isInline=*/true);
1171   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1172     return failure();
1173 
1174   curDeclScope->add(*decl);
1175   return decl;
1176 }
1177 
1178 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1179     const ast::Name &name, bool isInline,
1180     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1181     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1182   // Push the argument scope back onto the list, so that the body can
1183   // reference arguments.
1184   pushDeclScope(argumentScope);
1185 
1186   // Parse the body of the constraint. The body is either defined as a compound
1187   // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1188   ast::CompoundStmt *body;
1189   if (curToken.is(Token::equal_arrow)) {
1190     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1191         [&](ast::Stmt *&stmt) -> LogicalResult {
1192           ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1193           if (!stmtExpr) {
1194             return emitError(stmt->getLoc(),
1195                              "expected `Constraint` lambda body to contain a "
1196                              "single expression");
1197           }
1198           stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1199           return success();
1200         },
1201         /*expectTerminalSemicolon=*/!isInline);
1202     if (failed(bodyResult))
1203       return failure();
1204     body = *bodyResult;
1205   } else {
1206     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1207     if (failed(bodyResult))
1208       return failure();
1209     body = *bodyResult;
1210 
1211     // Verify the structure of the body.
1212     auto bodyIt = body->begin(), bodyE = body->end();
1213     for (; bodyIt != bodyE; ++bodyIt)
1214       if (isa<ast::ReturnStmt>(*bodyIt))
1215         break;
1216     if (failed(validateUserConstraintOrRewriteReturn(
1217             "Constraint", body, bodyIt, bodyE, results, resultType)))
1218       return failure();
1219   }
1220   popDeclScope();
1221 
1222   return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1223       name, arguments, results, resultType, body);
1224 }
1225 
1226 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1227   // Constraints and rewrites have very similar formats, dispatch to a shared
1228   // interface for parsing.
1229   return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1230       [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1231       ParserContext::Rewrite, "rewrite", isInline);
1232 }
1233 
1234 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1235   FailureOr<ast::UserRewriteDecl *> decl =
1236       parseUserRewriteDecl(/*isInline=*/true);
1237   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1238     return failure();
1239 
1240   curDeclScope->add(*decl);
1241   return decl;
1242 }
1243 
1244 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1245     const ast::Name &name, bool isInline,
1246     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1247     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1248   // Push the argument scope back onto the list, so that the body can
1249   // reference arguments.
1250   curDeclScope = argumentScope;
1251   ast::CompoundStmt *body;
1252   if (curToken.is(Token::equal_arrow)) {
1253     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1254         [&](ast::Stmt *&statement) -> LogicalResult {
1255           if (isa<ast::OpRewriteStmt>(statement))
1256             return success();
1257 
1258           ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1259           if (!statementExpr) {
1260             return emitError(
1261                 statement->getLoc(),
1262                 "expected `Rewrite` lambda body to contain a single expression "
1263                 "or an operation rewrite statement; such as `erase`, "
1264                 "`replace`, or `rewrite`");
1265           }
1266           statement =
1267               ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1268           return success();
1269         },
1270         /*expectTerminalSemicolon=*/!isInline);
1271     if (failed(bodyResult))
1272       return failure();
1273     body = *bodyResult;
1274   } else {
1275     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1276     if (failed(bodyResult))
1277       return failure();
1278     body = *bodyResult;
1279   }
1280   popDeclScope();
1281 
1282   // Verify the structure of the body.
1283   auto bodyIt = body->begin(), bodyE = body->end();
1284   for (; bodyIt != bodyE; ++bodyIt)
1285     if (isa<ast::ReturnStmt>(*bodyIt))
1286       break;
1287   if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1288                                                    bodyE, results, resultType)))
1289     return failure();
1290   return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1291       name, arguments, results, resultType, body);
1292 }
1293 
1294 template <typename T, typename ParseUserPDLLDeclFnT>
1295 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1296     ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1297     StringRef anonymousNamePrefix, bool isInline) {
1298   SMRange loc = curToken.getLoc();
1299   consumeToken();
1300   llvm::SaveAndRestore saveCtx(parserContext, declContext);
1301 
1302   // Parse the name of the decl.
1303   const ast::Name *name = nullptr;
1304   if (curToken.isNot(Token::identifier)) {
1305     // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1306     // in C++, so being unnamed is fine.
1307     if (!isInline)
1308       return emitError("expected identifier name");
1309 
1310     // Create a unique anonymous name to use, as the name for this decl is not
1311     // important.
1312     std::string anonName =
1313         llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1314                       anonymousDeclNameCounter++)
1315             .str();
1316     name = &ast::Name::create(ctx, anonName, loc);
1317   } else {
1318     // If a name was provided, we can use it directly.
1319     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1320     consumeToken(Token::identifier);
1321   }
1322 
1323   // Parse the functional signature of the decl.
1324   SmallVector<ast::VariableDecl *> arguments, results;
1325   ast::DeclScope *argumentScope;
1326   ast::Type resultType;
1327   if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1328                                                    argumentScope, resultType)))
1329     return failure();
1330 
1331   // Check to see which type of constraint this is. If the constraint contains a
1332   // compound body, this is a PDLL decl.
1333   if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1334     return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1335                            resultType);
1336 
1337   // Otherwise, this is a native decl.
1338   return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1339                                                    results, resultType);
1340 }
1341 
1342 template <typename T>
1343 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1344     const ast::Name &name, bool isInline,
1345     ArrayRef<ast::VariableDecl *> arguments,
1346     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1347   // If followed by a string, the native code body has also been specified.
1348   std::string codeStrStorage;
1349   std::optional<StringRef> optCodeStr;
1350   if (curToken.isString()) {
1351     codeStrStorage = curToken.getStringValue();
1352     optCodeStr = codeStrStorage;
1353     consumeToken();
1354   } else if (isInline) {
1355     return emitError(name.getLoc(),
1356                      "external declarations must be declared in global scope");
1357   } else if (curToken.is(Token::error)) {
1358     return failure();
1359   }
1360   if (failed(parseToken(Token::semicolon,
1361                         "expected `;` after native declaration")))
1362     return failure();
1363   return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1364 }
1365 
1366 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1367     SmallVectorImpl<ast::VariableDecl *> &arguments,
1368     SmallVectorImpl<ast::VariableDecl *> &results,
1369     ast::DeclScope *&argumentScope, ast::Type &resultType) {
1370   // Parse the argument list of the decl.
1371   if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1372     return failure();
1373 
1374   argumentScope = pushDeclScope();
1375   if (curToken.isNot(Token::r_paren)) {
1376     do {
1377       FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1378       if (failed(argument))
1379         return failure();
1380       arguments.emplace_back(*argument);
1381     } while (consumeIf(Token::comma));
1382   }
1383   popDeclScope();
1384   if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1385     return failure();
1386 
1387   // Parse the results of the decl.
1388   pushDeclScope();
1389   if (consumeIf(Token::arrow)) {
1390     auto parseResultFn = [&]() -> LogicalResult {
1391       FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1392       if (failed(result))
1393         return failure();
1394       results.emplace_back(*result);
1395       return success();
1396     };
1397 
1398     // Check for a list of results.
1399     if (consumeIf(Token::l_paren)) {
1400       do {
1401         if (failed(parseResultFn()))
1402           return failure();
1403       } while (consumeIf(Token::comma));
1404       if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1405         return failure();
1406 
1407       // Otherwise, there is only one result.
1408     } else if (failed(parseResultFn())) {
1409       return failure();
1410     }
1411   }
1412   popDeclScope();
1413 
1414   // Compute the result type of the decl.
1415   resultType = createUserConstraintRewriteResultType(results);
1416 
1417   // Verify that results are only named if there are more than one.
1418   if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1419     return emitError(
1420         results.front()->getLoc(),
1421         "cannot create a single-element tuple with an element label");
1422   }
1423   return success();
1424 }
1425 
1426 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1427     StringRef declType, ast::CompoundStmt *body,
1428     ArrayRef<ast::Stmt *>::iterator bodyIt,
1429     ArrayRef<ast::Stmt *>::iterator bodyE,
1430     ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1431   // Handle if a `return` was provided.
1432   if (bodyIt != bodyE) {
1433     // Emit an error if we have trailing statements after the return.
1434     if (std::next(bodyIt) != bodyE) {
1435       return emitError(
1436           (*std::next(bodyIt))->getLoc(),
1437           llvm::formatv("`return` terminated the `{0}` body, but found "
1438                         "trailing statements afterwards",
1439                         declType));
1440     }
1441 
1442     // Otherwise if a return wasn't provided, check that no results are
1443     // expected.
1444   } else if (!results.empty()) {
1445     return emitError(
1446         {body->getLoc().End, body->getLoc().End},
1447         llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1448                       declType, resultType));
1449   }
1450   return success();
1451 }
1452 
1453 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1454   return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1455     if (isa<ast::OpRewriteStmt>(statement))
1456       return success();
1457     return emitError(
1458         statement->getLoc(),
1459         "expected Pattern lambda body to contain a single operation "
1460         "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1461   });
1462 }
1463 
1464 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1465   SMRange loc = curToken.getLoc();
1466   consumeToken(Token::kw_Pattern);
1467   llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1468 
1469   // Check for an optional identifier for the pattern name.
1470   const ast::Name *name = nullptr;
1471   if (curToken.is(Token::identifier)) {
1472     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1473     consumeToken(Token::identifier);
1474   }
1475 
1476   // Parse any pattern metadata.
1477   ParsedPatternMetadata metadata;
1478   if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1479     return failure();
1480 
1481   // Parse the pattern body.
1482   ast::CompoundStmt *body;
1483 
1484   // Handle a lambda body.
1485   if (curToken.is(Token::equal_arrow)) {
1486     FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1487     if (failed(bodyResult))
1488       return failure();
1489     body = *bodyResult;
1490   } else {
1491     if (curToken.isNot(Token::l_brace))
1492       return emitError("expected `{` or `=>` to start pattern body");
1493     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1494     if (failed(bodyResult))
1495       return failure();
1496     body = *bodyResult;
1497 
1498     // Verify the body of the pattern.
1499     auto bodyIt = body->begin(), bodyE = body->end();
1500     for (; bodyIt != bodyE; ++bodyIt) {
1501       if (isa<ast::ReturnStmt>(*bodyIt)) {
1502         return emitError((*bodyIt)->getLoc(),
1503                          "`return` statements are only permitted within a "
1504                          "`Constraint` or `Rewrite` body");
1505       }
1506       // Break when we've found the rewrite statement.
1507       if (isa<ast::OpRewriteStmt>(*bodyIt))
1508         break;
1509     }
1510     if (bodyIt == bodyE) {
1511       return emitError(loc,
1512                        "expected Pattern body to terminate with an operation "
1513                        "rewrite statement, such as `erase`");
1514     }
1515     if (std::next(bodyIt) != bodyE) {
1516       return emitError((*std::next(bodyIt))->getLoc(),
1517                        "Pattern body was terminated by an operation "
1518                        "rewrite statement, but found trailing statements");
1519     }
1520   }
1521 
1522   return createPatternDecl(loc, name, metadata, body);
1523 }
1524 
1525 LogicalResult
1526 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1527   std::optional<SMRange> benefitLoc;
1528   std::optional<SMRange> hasBoundedRecursionLoc;
1529 
1530   do {
1531     // Handle metadata code completion.
1532     if (curToken.is(Token::code_complete))
1533       return codeCompletePatternMetadata();
1534 
1535     if (curToken.isNot(Token::identifier))
1536       return emitError("expected pattern metadata identifier");
1537     StringRef metadataStr = curToken.getSpelling();
1538     SMRange metadataLoc = curToken.getLoc();
1539     consumeToken(Token::identifier);
1540 
1541     // Parse the benefit metadata: benefit(<integer-value>)
1542     if (metadataStr == "benefit") {
1543       if (benefitLoc) {
1544         return emitErrorAndNote(metadataLoc,
1545                                 "pattern benefit has already been specified",
1546                                 *benefitLoc, "see previous definition here");
1547       }
1548       if (failed(parseToken(Token::l_paren,
1549                             "expected `(` before pattern benefit")))
1550         return failure();
1551 
1552       uint16_t benefitValue = 0;
1553       if (curToken.isNot(Token::integer))
1554         return emitError("expected integral pattern benefit");
1555       if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1556         return emitError(
1557             "expected pattern benefit to fit within a 16-bit integer");
1558       consumeToken(Token::integer);
1559 
1560       metadata.benefit = benefitValue;
1561       benefitLoc = metadataLoc;
1562 
1563       if (failed(
1564               parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1565         return failure();
1566       continue;
1567     }
1568 
1569     // Parse the bounded recursion metadata: recursion
1570     if (metadataStr == "recursion") {
1571       if (hasBoundedRecursionLoc) {
1572         return emitErrorAndNote(
1573             metadataLoc,
1574             "pattern recursion metadata has already been specified",
1575             *hasBoundedRecursionLoc, "see previous definition here");
1576       }
1577       metadata.hasBoundedRecursion = true;
1578       hasBoundedRecursionLoc = metadataLoc;
1579       continue;
1580     }
1581 
1582     return emitError(metadataLoc, "unknown pattern metadata");
1583   } while (consumeIf(Token::comma));
1584 
1585   return success();
1586 }
1587 
1588 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1589   consumeToken(Token::less);
1590 
1591   FailureOr<ast::Expr *> typeExpr = parseExpr();
1592   if (failed(typeExpr) ||
1593       failed(parseToken(Token::greater,
1594                         "expected `>` after variable type constraint")))
1595     return failure();
1596   return typeExpr;
1597 }
1598 
1599 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1600   assert(curDeclScope && "defining decl outside of a decl scope");
1601   if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1602     return emitErrorAndNote(
1603         name.getLoc(), "`" + name.getName() + "` has already been defined",
1604         lastDecl->getName()->getLoc(), "see previous definition here");
1605   }
1606   return success();
1607 }
1608 
1609 FailureOr<ast::VariableDecl *>
1610 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1611                            ast::Expr *initExpr,
1612                            ArrayRef<ast::ConstraintRef> constraints) {
1613   assert(curDeclScope && "defining variable outside of decl scope");
1614   const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1615 
1616   // If the name of the variable indicates a special variable, we don't add it
1617   // to the scope. This variable is local to the definition point.
1618   if (name.empty() || name == "_") {
1619     return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1620                                      constraints);
1621   }
1622   if (failed(checkDefineNamedDecl(nameDecl)))
1623     return failure();
1624 
1625   auto *varDecl =
1626       ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1627   curDeclScope->add(varDecl);
1628   return varDecl;
1629 }
1630 
1631 FailureOr<ast::VariableDecl *>
1632 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1633                            ArrayRef<ast::ConstraintRef> constraints) {
1634   return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1635                             constraints);
1636 }
1637 
1638 LogicalResult Parser::parseVariableDeclConstraintList(
1639     SmallVectorImpl<ast::ConstraintRef> &constraints) {
1640   std::optional<SMRange> typeConstraint;
1641   auto parseSingleConstraint = [&] {
1642     FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1643         typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1644     if (failed(constraint))
1645       return failure();
1646     constraints.push_back(*constraint);
1647     return success();
1648   };
1649 
1650   // Check to see if this is a single constraint, or a list.
1651   if (!consumeIf(Token::l_square))
1652     return parseSingleConstraint();
1653 
1654   do {
1655     if (failed(parseSingleConstraint()))
1656       return failure();
1657   } while (consumeIf(Token::comma));
1658   return parseToken(Token::r_square, "expected `]` after constraint list");
1659 }
1660 
1661 FailureOr<ast::ConstraintRef>
1662 Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1663                         ArrayRef<ast::ConstraintRef> existingConstraints,
1664                         bool allowInlineTypeConstraints) {
1665   auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1666     if (!allowInlineTypeConstraints) {
1667       return emitError(
1668           curToken.getLoc(),
1669           "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1670           "permitted on arguments or results");
1671     }
1672     if (typeConstraint)
1673       return emitErrorAndNote(
1674           curToken.getLoc(),
1675           "the type of this variable has already been constrained",
1676           *typeConstraint, "see previous constraint location here");
1677     FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1678     if (failed(constraintExpr))
1679       return failure();
1680     typeExpr = *constraintExpr;
1681     typeConstraint = typeExpr->getLoc();
1682     return success();
1683   };
1684 
1685   SMRange loc = curToken.getLoc();
1686   switch (curToken.getKind()) {
1687   case Token::kw_Attr: {
1688     consumeToken(Token::kw_Attr);
1689 
1690     // Check for a type constraint.
1691     ast::Expr *typeExpr = nullptr;
1692     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1693       return failure();
1694     return ast::ConstraintRef(
1695         ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1696   }
1697   case Token::kw_Op: {
1698     consumeToken(Token::kw_Op);
1699 
1700     // Parse an optional operation name. If the name isn't provided, this refers
1701     // to "any" operation.
1702     FailureOr<ast::OpNameDecl *> opName =
1703         parseWrappedOperationName(/*allowEmptyName=*/true);
1704     if (failed(opName))
1705       return failure();
1706 
1707     return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1708                               loc);
1709   }
1710   case Token::kw_Type:
1711     consumeToken(Token::kw_Type);
1712     return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1713   case Token::kw_TypeRange:
1714     consumeToken(Token::kw_TypeRange);
1715     return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
1716                               loc);
1717   case Token::kw_Value: {
1718     consumeToken(Token::kw_Value);
1719 
1720     // Check for a type constraint.
1721     ast::Expr *typeExpr = nullptr;
1722     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1723       return failure();
1724 
1725     return ast::ConstraintRef(
1726         ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1727   }
1728   case Token::kw_ValueRange: {
1729     consumeToken(Token::kw_ValueRange);
1730 
1731     // Check for a type constraint.
1732     ast::Expr *typeExpr = nullptr;
1733     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1734       return failure();
1735 
1736     return ast::ConstraintRef(
1737         ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1738   }
1739 
1740   case Token::kw_Constraint: {
1741     // Handle an inline constraint.
1742     FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1743     if (failed(decl))
1744       return failure();
1745     return ast::ConstraintRef(*decl, loc);
1746   }
1747   case Token::identifier: {
1748     StringRef constraintName = curToken.getSpelling();
1749     consumeToken(Token::identifier);
1750 
1751     // Lookup the referenced constraint.
1752     ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1753     if (!cstDecl) {
1754       return emitError(loc, "unknown reference to constraint `" +
1755                                 constraintName + "`");
1756     }
1757 
1758     // Handle a reference to a proper constraint.
1759     if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1760       return ast::ConstraintRef(cst, loc);
1761 
1762     return emitErrorAndNote(
1763         loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1764         "see the definition of `" + constraintName + "` here");
1765   }
1766     // Handle single entity constraint code completion.
1767   case Token::code_complete: {
1768     // Try to infer the current type for use by code completion.
1769     ast::Type inferredType;
1770     if (failed(validateVariableConstraints(existingConstraints, inferredType)))
1771       return failure();
1772 
1773     return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1774   }
1775   default:
1776     break;
1777   }
1778   return emitError(loc, "expected identifier constraint");
1779 }
1780 
1781 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1782   std::optional<SMRange> typeConstraint;
1783   return parseConstraint(typeConstraint, /*existingConstraints=*/std::nullopt,
1784                          /*allowInlineTypeConstraints=*/false);
1785 }
1786 
1787 //===----------------------------------------------------------------------===//
1788 // Exprs
1789 
1790 FailureOr<ast::Expr *> Parser::parseExpr() {
1791   if (curToken.is(Token::underscore))
1792     return parseUnderscoreExpr();
1793 
1794   // Parse the LHS expression.
1795   FailureOr<ast::Expr *> lhsExpr;
1796   switch (curToken.getKind()) {
1797   case Token::kw_attr:
1798     lhsExpr = parseAttributeExpr();
1799     break;
1800   case Token::kw_Constraint:
1801     lhsExpr = parseInlineConstraintLambdaExpr();
1802     break;
1803   case Token::kw_not:
1804     lhsExpr = parseNegatedExpr();
1805     break;
1806   case Token::identifier:
1807     lhsExpr = parseIdentifierExpr();
1808     break;
1809   case Token::kw_op:
1810     lhsExpr = parseOperationExpr();
1811     break;
1812   case Token::kw_Rewrite:
1813     lhsExpr = parseInlineRewriteLambdaExpr();
1814     break;
1815   case Token::kw_type:
1816     lhsExpr = parseTypeExpr();
1817     break;
1818   case Token::l_paren:
1819     lhsExpr = parseTupleExpr();
1820     break;
1821   default:
1822     return emitError("expected expression");
1823   }
1824   if (failed(lhsExpr))
1825     return failure();
1826 
1827   // Check for an operator expression.
1828   while (true) {
1829     switch (curToken.getKind()) {
1830     case Token::dot:
1831       lhsExpr = parseMemberAccessExpr(*lhsExpr);
1832       break;
1833     case Token::l_paren:
1834       lhsExpr = parseCallExpr(*lhsExpr);
1835       break;
1836     default:
1837       return lhsExpr;
1838     }
1839     if (failed(lhsExpr))
1840       return failure();
1841   }
1842 }
1843 
1844 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1845   SMRange loc = curToken.getLoc();
1846   consumeToken(Token::kw_attr);
1847 
1848   // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1849   // identifier.
1850   if (!consumeIf(Token::less)) {
1851     resetToken(loc);
1852     return parseIdentifierExpr();
1853   }
1854 
1855   if (!curToken.isString())
1856     return emitError("expected string literal containing MLIR attribute");
1857   std::string attrExpr = curToken.getStringValue();
1858   consumeToken();
1859 
1860   loc.End = curToken.getEndLoc();
1861   if (failed(
1862           parseToken(Token::greater, "expected `>` after attribute literal")))
1863     return failure();
1864   return ast::AttributeExpr::create(ctx, loc, attrExpr);
1865 }
1866 
1867 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
1868                                              bool isNegated) {
1869   consumeToken(Token::l_paren);
1870 
1871   // Parse the arguments of the call.
1872   SmallVector<ast::Expr *> arguments;
1873   if (curToken.isNot(Token::r_paren)) {
1874     do {
1875       // Handle code completion for the call arguments.
1876       if (curToken.is(Token::code_complete)) {
1877         codeCompleteCallSignature(parentExpr, arguments.size());
1878         return failure();
1879       }
1880 
1881       FailureOr<ast::Expr *> argument = parseExpr();
1882       if (failed(argument))
1883         return failure();
1884       arguments.push_back(*argument);
1885     } while (consumeIf(Token::comma));
1886   }
1887 
1888   SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1889   if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1890     return failure();
1891 
1892   return createCallExpr(loc, parentExpr, arguments, isNegated);
1893 }
1894 
1895 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1896   ast::Decl *decl = curDeclScope->lookup(name);
1897   if (!decl)
1898     return emitError(loc, "undefined reference to `" + name + "`");
1899 
1900   return createDeclRefExpr(loc, decl);
1901 }
1902 
1903 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1904   StringRef name = curToken.getSpelling();
1905   SMRange nameLoc = curToken.getLoc();
1906   consumeToken();
1907 
1908   // Check to see if this is a decl ref expression that defines a variable
1909   // inline.
1910   if (consumeIf(Token::colon)) {
1911     SmallVector<ast::ConstraintRef> constraints;
1912     if (failed(parseVariableDeclConstraintList(constraints)))
1913       return failure();
1914     ast::Type type;
1915     if (failed(validateVariableConstraints(constraints, type)))
1916       return failure();
1917     return createInlineVariableExpr(type, name, nameLoc, constraints);
1918   }
1919 
1920   return parseDeclRefExpr(name, nameLoc);
1921 }
1922 
1923 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1924   FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1925   if (failed(decl))
1926     return failure();
1927 
1928   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1929                                   ast::ConstraintType::get(ctx));
1930 }
1931 
1932 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1933   FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1934   if (failed(decl))
1935     return failure();
1936 
1937   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1938                                   ast::RewriteType::get(ctx));
1939 }
1940 
1941 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1942   SMRange dotLoc = curToken.getLoc();
1943   consumeToken(Token::dot);
1944 
1945   // Check for code completion of the member name.
1946   if (curToken.is(Token::code_complete))
1947     return codeCompleteMemberAccess(parentExpr);
1948 
1949   // Parse the member name.
1950   Token memberNameTok = curToken;
1951   if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1952       !memberNameTok.isKeyword())
1953     return emitError(dotLoc, "expected identifier or numeric member name");
1954   StringRef memberName = memberNameTok.getSpelling();
1955   SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1956   consumeToken();
1957 
1958   return createMemberAccessExpr(parentExpr, memberName, loc);
1959 }
1960 
1961 FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1962   consumeToken(Token::kw_not);
1963   // Only native constraints are supported after negation
1964   if (!curToken.is(Token::identifier))
1965     return emitError("expected native constraint");
1966   FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1967   if (failed(identifierExpr))
1968     return failure();
1969   if (!curToken.is(Token::l_paren))
1970     return emitError("expected `(` after function name");
1971   return parseCallExpr(*identifierExpr, /*isNegated = */ true);
1972 }
1973 
/// Parse an operation name of the form `dialect.op[.suffix...]`. Each
/// component may be an identifier or a keyword. If `allowEmptyName` is set
/// and no name is present, an unnamed OpNameDecl is returned instead.
FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
  SMRange loc = curToken.getLoc();

  // Check for code completion for the dialect name.
  if (curToken.is(Token::code_complete))
    return codeCompleteDialectName();

  // Handle the case of an no operation name.
  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
    if (allowEmptyName)
      return ast::OpNameDecl::create(ctx, SMRange());
    return emitError("expected dialect namespace");
  }
  StringRef name = curToken.getSpelling();
  consumeToken();

  // Otherwise, this is a literal operation name.
  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
    return failure();

  // Check for code completion for the operation name.
  if (curToken.is(Token::code_complete))
    return codeCompleteOperationName(name);

  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
    return emitError("expected operation name after dialect namespace");

  // Build the full name by widening the dialect's StringRef over the source
  // buffer: first by one character for the `.`, then by each subsequent
  // token's length.
  // NOTE(review): this assumes the tokens are directly adjacent in the
  // source. Whitespace between them (e.g. `dialect . op`) would produce a
  // name containing the whitespace and truncated at the end. Confirm the
  // lexer/grammar rules this out.
  name = StringRef(name.data(), name.size() + 1);
  do {
    name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
    loc.End = curToken.getEndLoc();
    consumeToken();
  } while (curToken.isAny(Token::identifier, Token::dot) ||
           curToken.isKeyword());
  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
}
2010 
2011 FailureOr<ast::OpNameDecl *>
2012 Parser::parseWrappedOperationName(bool allowEmptyName) {
2013   if (!consumeIf(Token::less))
2014     return ast::OpNameDecl::create(ctx, SMRange());
2015 
2016   FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2017   if (failed(opNameDecl))
2018     return failure();
2019 
2020   if (failed(parseToken(Token::greater, "expected `>` after operation name")))
2021     return failure();
2022   return opNameDecl;
2023 }
2024 
/// Parse an operation expression:
///   `op` `<` name? `>` (`(` operands `)`)? (`{` attrs `}`)? (`->` `(` types `)`)?
/// Outside a rewrite context, an omitted operand or result list is replaced
/// by an implicit unconstrained range variable. Inside a rewrite context,
/// the operation name is required. `inputResultTypeContext` is the initial
/// result-type context; it may be changed below based on what is written.
FailureOr<ast::Expr *>
Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_op);

  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
  // identifier.
  if (curToken.isNot(Token::less)) {
    resetToken(loc);
    return parseIdentifierExpr();
  }

  // Parse the operation name. The name may be elided, in which case the
  // operation refers to "any" operation (i.e. a difference between `MyOp` and
  // `Operation*`). Operation names within a rewrite context must be named.
  bool allowEmptyName = parserContext != ParserContext::Rewrite;
  FailureOr<ast::OpNameDecl *> opNameDecl =
      parseWrappedOperationName(allowEmptyName);
  if (failed(opNameDecl))
    return failure();
  std::optional<StringRef> opName = (*opNameDecl)->getName();

  // Functor used to create an implicit range variable, used for implicit "all"
  // operand or results variables. The `_` name keeps it out of the decl scope
  // (see defineVariableDecl).
  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
    FailureOr<ast::VariableDecl *> rangeVar =
        defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
    assert(succeeded(rangeVar) && "expected range variable to be valid");
    return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
  };

  // Check for the optional list of operands.
  SmallVector<ast::Expr *> operands;
  if (!consumeIf(Token::l_paren)) {
    // If the operand list isn't specified and we are in a match context, define
    // an inplace unconstrained operand range corresponding to all of the
    // operands of the operation. This avoids treating zero operands the same
    // way as "unconstrained operands".
    if (parserContext != ParserContext::Rewrite) {
      operands.push_back(createImplicitRangeVar(
          ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
    }
  } else if (!consumeIf(Token::r_paren)) {
    // If the operand list was specified and non-empty, parse the operands.
    do {
      // Check for operand signature code completion.
      if (curToken.is(Token::code_complete)) {
        codeCompleteOperationOperandsSignature(opName, operands.size());
        return failure();
      }

      FailureOr<ast::Expr *> operand = parseExpr();
      if (failed(operand))
        return failure();
      operands.push_back(*operand);
    } while (consumeIf(Token::comma));

    if (failed(parseToken(Token::r_paren,
                          "expected `)` after operation operand list")))
      return failure();
  }

  // Check for the optional list of attributes.
  SmallVector<ast::NamedAttributeDecl *> attributes;
  if (consumeIf(Token::l_brace)) {
    do {
      FailureOr<ast::NamedAttributeDecl *> decl =
          parseNamedAttributeDecl(opName);
      if (failed(decl))
        return failure();
      attributes.emplace_back(*decl);
    } while (consumeIf(Token::comma));

    if (failed(parseToken(Token::r_brace,
                          "expected `}` after operation attribute list")))
      return failure();
  }

  // Handle the result types of the operation.
  SmallVector<ast::Expr *> resultTypes;
  OpResultTypeContext resultTypeContext = inputResultTypeContext;

  // Check for an explicit list of result types.
  if (consumeIf(Token::arrow)) {
    if (failed(parseToken(Token::l_paren,
                          "expected `(` before operation result type list")))
      return failure();

    // If result types are provided, initially assume that the operation does
    // not rely on type inference. We don't assert that it isn't, because we
    // may be inferring the value of some type/type range variables, but given
    // that these variables may be defined in calls we can't always discern when
    // this is the case.
    resultTypeContext = OpResultTypeContext::Explicit;

    // Handle the case of an empty result list.
    if (!consumeIf(Token::r_paren)) {
      do {
        // Check for result signature code completion.
        if (curToken.is(Token::code_complete)) {
          codeCompleteOperationResultsSignature(opName, resultTypes.size());
          return failure();
        }

        FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
        if (failed(resultTypeExpr))
          return failure();
        resultTypes.push_back(*resultTypeExpr);
      } while (consumeIf(Token::comma));

      if (failed(parseToken(Token::r_paren,
                            "expected `)` after operation result type list")))
        return failure();
    }
  } else if (parserContext != ParserContext::Rewrite) {
    // If the result list isn't specified and we are in a match context, define
    // an inplace unconstrained result range corresponding to all of the results
    // of the operation. This avoids treating zero results the same way as
    // "unconstrained results".
    resultTypes.push_back(createImplicitRangeVar(
        ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
  } else if (resultTypeContext == OpResultTypeContext::Explicit) {
    // If the result list isn't specified and we are in a rewrite, try to infer
    // them at runtime instead.
    resultTypeContext = OpResultTypeContext::Interface;
  }

  return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
                             attributes, resultTypes);
}
2155 
/// Parse a tuple expression `( (label =)? expr, ... )`. Element labels are
/// optional and must be unique. Unlabeled elements get an empty name in
/// `elementNames`, which stays parallel to `elements`.
FailureOr<ast::Expr *> Parser::parseTupleExpr() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::l_paren);

  // Labels seen so far, mapped to their first location for diagnostics.
  DenseMap<StringRef, SMRange> usedNames;
  SmallVector<StringRef> elementNames;
  SmallVector<ast::Expr *> elements;
  if (curToken.isNot(Token::r_paren)) {
    do {
      // Check for the optional element name assignment before the value.
      StringRef elementName;
      if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
        Token elementNameTok = curToken;
        consumeToken();

        // The element name is only present if followed by an `=`.
        if (consumeIf(Token::equal)) {
          elementName = elementNameTok.getSpelling();

          // Check to see if this name is already used.
          auto elementNameIt =
              usedNames.try_emplace(elementName, elementNameTok.getLoc());
          if (!elementNameIt.second) {
            return emitErrorAndNote(
                elementNameTok.getLoc(),
                llvm::formatv("duplicate tuple element label `{0}`",
                              elementName),
                elementNameIt.first->getSecond(),
                "see previous label use here");
          }
        } else {
          // Otherwise, we treat this as part of an expression so reset the
          // lexer back to the start of the token just consumed.
          resetToken(elementNameTok.getLoc());
        }
      }
      elementNames.push_back(elementName);

      // Parse the tuple element value.
      FailureOr<ast::Expr *> element = parseExpr();
      if (failed(element))
        return failure();
      elements.push_back(*element);
    } while (consumeIf(Token::comma));
  }
  // The tuple spans through the closing `)`.
  loc.End = curToken.getEndLoc();
  if (failed(
          parseToken(Token::r_paren, "expected `)` after tuple element list")))
    return failure();
  return createTupleExpr(loc, elements, elementNames);
}
2207 
2208 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2209   SMRange loc = curToken.getLoc();
2210   consumeToken(Token::kw_type);
2211 
2212   // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2213   // identifier.
2214   if (!consumeIf(Token::less)) {
2215     resetToken(loc);
2216     return parseIdentifierExpr();
2217   }
2218 
2219   if (!curToken.isString())
2220     return emitError("expected string literal containing MLIR type");
2221   std::string attrExpr = curToken.getStringValue();
2222   consumeToken();
2223 
2224   loc.End = curToken.getEndLoc();
2225   if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2226     return failure();
2227   return ast::TypeExpr::create(ctx, loc, attrExpr);
2228 }
2229 
2230 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2231   StringRef name = curToken.getSpelling();
2232   SMRange nameLoc = curToken.getLoc();
2233   consumeToken(Token::underscore);
2234 
2235   // Underscore expressions require a constraint list.
2236   if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2237     return failure();
2238 
2239   // Parse the constraints for the expression.
2240   SmallVector<ast::ConstraintRef> constraints;
2241   if (failed(parseVariableDeclConstraintList(constraints)))
2242     return failure();
2243 
2244   ast::Type type;
2245   if (failed(validateVariableConstraints(constraints, type)))
2246     return failure();
2247   return createInlineVariableExpr(type, name, nameLoc, constraints);
2248 }
2249 
2250 //===----------------------------------------------------------------------===//
2251 // Stmts
2252 
2253 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2254   FailureOr<ast::Stmt *> stmt;
2255   switch (curToken.getKind()) {
2256   case Token::kw_erase:
2257     stmt = parseEraseStmt();
2258     break;
2259   case Token::kw_let:
2260     stmt = parseLetStmt();
2261     break;
2262   case Token::kw_replace:
2263     stmt = parseReplaceStmt();
2264     break;
2265   case Token::kw_return:
2266     stmt = parseReturnStmt();
2267     break;
2268   case Token::kw_rewrite:
2269     stmt = parseRewriteStmt();
2270     break;
2271   default:
2272     stmt = parseExpr();
2273     break;
2274   }
2275   if (failed(stmt) ||
2276       (expectTerminalSemicolon &&
2277        failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2278     return failure();
2279   return stmt;
2280 }
2281 
2282 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2283   SMLoc startLoc = curToken.getStartLoc();
2284   consumeToken(Token::l_brace);
2285 
2286   // Push a new block scope and parse any nested statements.
2287   pushDeclScope();
2288   SmallVector<ast::Stmt *> statements;
2289   while (curToken.isNot(Token::r_brace)) {
2290     FailureOr<ast::Stmt *> statement = parseStmt();
2291     if (failed(statement))
2292       return popDeclScope(), failure();
2293     statements.push_back(*statement);
2294   }
2295   popDeclScope();
2296 
2297   // Consume the end brace.
2298   SMRange location(startLoc, curToken.getEndLoc());
2299   consumeToken(Token::r_brace);
2300 
2301   return ast::CompoundStmt::create(ctx, location, statements);
2302 }
2303 
2304 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2305   if (parserContext == ParserContext::Constraint)
2306     return emitError("`erase` cannot be used within a Constraint");
2307   SMRange loc = curToken.getLoc();
2308   consumeToken(Token::kw_erase);
2309 
2310   // Parse the root operation expression.
2311   FailureOr<ast::Expr *> rootOp = parseExpr();
2312   if (failed(rootOp))
2313     return failure();
2314 
2315   return createEraseStmt(loc, *rootOp);
2316 }
2317 
/// Parse a `let` statement of the form
///   `let name [: constraint-list] [= initializer]`
/// and define the resulting variable in the current scope.
///
/// Diagnostics are emitted for a missing or reserved (`_`) variable name, and
/// for type constraints (on `Attr`/`Value`/`ValueRange`) combined with an
/// initializer. Type inference and the final definition are delegated to
/// `createVariableDecl`.
FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_let);

  // Parse the name of the new variable. Besides plain identifiers, dependent
  // keywords are also accepted as variable names here.
  SMRange varLoc = curToken.getLoc();
  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
    // `_` is a reserved variable name.
    if (curToken.is(Token::underscore)) {
      return emitError(varLoc,
                       "`_` may only be used to define \"inline\" variables");
    }
    return emitError(varLoc,
                     "expected identifier after `let` to name a new variable");
  }
  StringRef varName = curToken.getSpelling();
  consumeToken();

  // Parse the optional set of constraints, introduced by `:`.
  SmallVector<ast::ConstraintRef> constraints;
  if (consumeIf(Token::colon) &&
      failed(parseVariableDeclConstraintList(constraints)))
    return failure();

  // Parse the optional initializer expression, introduced by `=`. It stays
  // null when absent.
  ast::Expr *initializer = nullptr;
  if (consumeIf(Token::equal)) {
    FailureOr<ast::Expr *> initOrFailure = parseExpr();
    if (failed(initOrFailure))
      return failure();
    initializer = *initOrFailure;

    // Check that the constraints are compatible with having an initializer,
    // e.g. type constraints cannot be used with initializers. Only the
    // Attr/Value/ValueRange constraint kinds carry a type expression; every
    // other constraint kind is accepted via `.Default(success())`.
    for (ast::ConstraintRef constraint : constraints) {
      LogicalResult result =
          TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
              .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
                    ast::ValueRangeConstraintDecl>([&](const auto *cst) {
                // NOTE(review): `this->` is spelled out inside this generic
                // lambda, presumably to sidestep implicit-`this` member lookup
                // issues on some compilers.
                if (cst->getTypeExpr()) {
                  return this->emitError(
                      constraint.referenceLoc,
                      "type constraints are not permitted on variables with "
                      "initializers");
                }
                return success();
              })
              .Default(success());
      if (failed(result))
        return failure();
    }
  }

  // Infer/validate the variable type and define the variable.
  FailureOr<ast::VariableDecl *> varDecl =
      createVariableDecl(varName, varLoc, initializer, constraints);
  if (failed(varDecl))
    return failure();
  return ast::LetStmt::create(ctx, loc, *varDecl);
}
2377 
2378 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2379   if (parserContext == ParserContext::Constraint)
2380     return emitError("`replace` cannot be used within a Constraint");
2381   SMRange loc = curToken.getLoc();
2382   consumeToken(Token::kw_replace);
2383 
2384   // Parse the root operation expression.
2385   FailureOr<ast::Expr *> rootOp = parseExpr();
2386   if (failed(rootOp))
2387     return failure();
2388 
2389   if (failed(
2390           parseToken(Token::kw_with, "expected `with` after root operation")))
2391     return failure();
2392 
2393   // The replacement portion of this statement is within a rewrite context.
2394   llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2395 
2396   // Parse the replacement values.
2397   SmallVector<ast::Expr *> replValues;
2398   if (consumeIf(Token::l_paren)) {
2399     if (consumeIf(Token::r_paren)) {
2400       return emitError(
2401           loc, "expected at least one replacement value, consider using "
2402                "`erase` if no replacement values are desired");
2403     }
2404 
2405     do {
2406       FailureOr<ast::Expr *> replExpr = parseExpr();
2407       if (failed(replExpr))
2408         return failure();
2409       replValues.emplace_back(*replExpr);
2410     } while (consumeIf(Token::comma));
2411 
2412     if (failed(parseToken(Token::r_paren,
2413                           "expected `)` after replacement values")))
2414       return failure();
2415   } else {
2416     // Handle replacement with an operation uniquely, as the replacement
2417     // operation supports type inferrence from the root operation.
2418     FailureOr<ast::Expr *> replExpr;
2419     if (curToken.is(Token::kw_op))
2420       replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2421     else
2422       replExpr = parseExpr();
2423     if (failed(replExpr))
2424       return failure();
2425     replValues.emplace_back(*replExpr);
2426   }
2427 
2428   return createReplaceStmt(loc, *rootOp, replValues);
2429 }
2430 
2431 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2432   SMRange loc = curToken.getLoc();
2433   consumeToken(Token::kw_return);
2434 
2435   // Parse the result value.
2436   FailureOr<ast::Expr *> resultExpr = parseExpr();
2437   if (failed(resultExpr))
2438     return failure();
2439 
2440   return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2441 }
2442 
2443 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2444   if (parserContext == ParserContext::Constraint)
2445     return emitError("`rewrite` cannot be used within a Constraint");
2446   SMRange loc = curToken.getLoc();
2447   consumeToken(Token::kw_rewrite);
2448 
2449   // Parse the root operation.
2450   FailureOr<ast::Expr *> rootOp = parseExpr();
2451   if (failed(rootOp))
2452     return failure();
2453 
2454   if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2455     return failure();
2456 
2457   if (curToken.isNot(Token::l_brace))
2458     return emitError("expected `{` to start rewrite body");
2459 
2460   // The rewrite body of this statement is within a rewrite context.
2461   llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2462 
2463   FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2464   if (failed(rewriteBody))
2465     return failure();
2466 
2467   // Verify the rewrite body.
2468   for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2469     if (isa<ast::ReturnStmt>(stmt)) {
2470       return emitError(stmt->getLoc(),
2471                        "`return` statements are only permitted within a "
2472                        "`Constraint` or `Rewrite` body");
2473     }
2474   }
2475 
2476   return createRewriteStmt(loc, *rootOp, *rewriteBody);
2477 }
2478 
2479 //===----------------------------------------------------------------------===//
2480 // Creation+Analysis
2481 //===----------------------------------------------------------------------===//
2482 
2483 //===----------------------------------------------------------------------===//
2484 // Decls
2485 
2486 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2487   // Unwrap reference expressions.
2488   if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2489     node = init->getDecl();
2490   return dyn_cast<ast::CallableDecl>(node);
2491 }
2492 
2493 FailureOr<ast::PatternDecl *>
2494 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2495                           const ParsedPatternMetadata &metadata,
2496                           ast::CompoundStmt *body) {
2497   return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2498                                   metadata.hasBoundedRecursion, body);
2499 }
2500 
2501 ast::Type Parser::createUserConstraintRewriteResultType(
2502     ArrayRef<ast::VariableDecl *> results) {
2503   // Single result decls use the type of the single result.
2504   if (results.size() == 1)
2505     return results[0]->getType();
2506 
2507   // Multiple results use a tuple type, with the types and names grabbed from
2508   // the result variable decls.
2509   auto resultTypes = llvm::map_range(
2510       results, [&](const auto *result) { return result->getType(); });
2511   auto resultNames = llvm::map_range(
2512       results, [&](const auto *result) { return result->getName().getName(); });
2513   return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2514                              llvm::to_vector(resultNames));
2515 }
2516 
2517 template <typename T>
2518 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2519     const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2520     ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2521     ast::CompoundStmt *body) {
2522   if (!body->getChildren().empty()) {
2523     if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2524       ast::Expr *resultExpr = retStmt->getResultExpr();
2525 
2526       // Process the result of the decl. If no explicit signature results
2527       // were provided, check for return type inference. Otherwise, check that
2528       // the return expression can be converted to the expected type.
2529       if (results.empty())
2530         resultType = resultExpr->getType();
2531       else if (failed(convertExpressionTo(resultExpr, resultType)))
2532         return failure();
2533       else
2534         retStmt->setResultExpr(resultExpr);
2535     }
2536   }
2537   return T::createPDLL(ctx, name, arguments, results, body, resultType);
2538 }
2539 
2540 FailureOr<ast::VariableDecl *>
2541 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2542                            ArrayRef<ast::ConstraintRef> constraints) {
2543   // The type of the variable, which is expected to be inferred by either a
2544   // constraint or an initializer expression.
2545   ast::Type type;
2546   if (failed(validateVariableConstraints(constraints, type)))
2547     return failure();
2548 
2549   if (initializer) {
2550     // Update the variable type based on the initializer, or try to convert the
2551     // initializer to the existing type.
2552     if (!type)
2553       type = initializer->getType();
2554     else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2555       type = mergedType;
2556     else if (failed(convertExpressionTo(initializer, type)))
2557       return failure();
2558 
2559     // Otherwise, if there is no initializer check that the type has already
2560     // been resolved from the constraint list.
2561   } else if (!type) {
2562     return emitErrorAndNote(
2563         loc, "unable to infer type for variable `" + name + "`", loc,
2564         "the type of a variable must be inferable from the constraint "
2565         "list or the initializer");
2566   }
2567 
2568   // Constraint types cannot be used when defining variables.
2569   if (isa<ast::ConstraintType, ast::RewriteType>(type)) {
2570     return emitError(
2571         loc, llvm::formatv("unable to define variable of `{0}` type", type));
2572   }
2573 
2574   // Try to define a variable with the given name.
2575   FailureOr<ast::VariableDecl *> varDecl =
2576       defineVariableDecl(name, loc, type, initializer, constraints);
2577   if (failed(varDecl))
2578     return failure();
2579 
2580   return *varDecl;
2581 }
2582 
2583 FailureOr<ast::VariableDecl *>
2584 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2585                                       const ast::ConstraintRef &constraint) {
2586   ast::Type argType;
2587   if (failed(validateVariableConstraint(constraint, argType)))
2588     return failure();
2589   return defineVariableDecl(name, loc, argType, constraint);
2590 }
2591 
2592 LogicalResult
2593 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2594                                     ast::Type &inferredType) {
2595   for (const ast::ConstraintRef &ref : constraints)
2596     if (failed(validateVariableConstraint(ref, inferredType)))
2597       return failure();
2598   return success();
2599 }
2600 
/// Validate a single constraint reference `ref` applied to a variable, and
/// merge the type it implies into `inferredType`.
///
/// On entry `inferredType` may be null, meaning nothing has been inferred yet.
/// On success it holds either this constraint's type or that type refined
/// with the previously inferred one. Returns failure (after emitting a
/// diagnostic) if the constraint is malformed or its type is incompatible
/// with `inferredType`.
LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
                                                 ast::Type &inferredType) {
  // The variable type implied by this particular constraint.
  ast::Type constraintType;
  if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
    // `Attr` constraint: an optional type expression must be a single `Type`.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = ast::AttributeType::get(ctx);
  } else if (const auto *cst =
                 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
    // `Op` constraint: an operation type carrying the op name and whatever
    // ODS definition `lookupODSOperation` finds for it.
    constraintType = ast::OperationType::get(
        ctx, cst->getName(), lookupODSOperation(cst->getName()));
  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
    constraintType = typeTy;
  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
    constraintType = typeRangeTy;
  } else if (const auto *cst =
                 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
    // `Value` constraint: an optional type expression must be a single `Type`.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = valueTy;
  } else if (const auto *cst =
                 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
    // `ValueRange` constraint: an optional type expression must be a
    // `TypeRange`.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeRangeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = valueRangeTy;
  } else if (const auto *cst =
                 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
    // A user `Constraint` used in a variable constraint list must take
    // exactly one input; the variable takes that input's type.
    ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
    if (inputs.size() != 1) {
      return emitErrorAndNote(ref.referenceLoc,
                              "`Constraint`s applied via a variable constraint "
                              "list must take a single input, but got " +
                                  Twine(inputs.size()),
                              cst->getLoc(),
                              "see definition of constraint here");
    }
    constraintType = inputs.front()->getType();
  } else {
    llvm_unreachable("unknown constraint type");
  }

  // Check that the constraint type is compatible with the current inferred
  // type: adopt it if nothing was inferred yet, otherwise require that the
  // two types can be refined into a common type.
  if (!inferredType) {
    inferredType = constraintType;
  } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
    inferredType = mergedTy;
  } else {
    return emitError(ref.referenceLoc,
                     llvm::formatv("constraint type `{0}` is incompatible "
                                   "with the previously inferred type `{1}`",
                                   constraintType, inferredType));
  }
  return success();
}
2662 
2663 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2664   ast::Type typeExprType = typeExpr->getType();
2665   if (typeExprType != typeTy) {
2666     return emitError(typeExpr->getLoc(),
2667                      "expected expression of `Type` in type constraint");
2668   }
2669   return success();
2670 }
2671 
2672 LogicalResult
2673 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2674   ast::Type typeExprType = typeExpr->getType();
2675   if (typeExprType != typeRangeTy) {
2676     return emitError(typeExpr->getLoc(),
2677                      "expected expression of `TypeRange` in type constraint");
2678   }
2679   return success();
2680 }
2681 
2682 //===----------------------------------------------------------------------===//
2683 // Exprs
2684 
2685 FailureOr<ast::CallExpr *>
2686 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2687                        MutableArrayRef<ast::Expr *> arguments, bool isNegated) {
2688   ast::Type parentType = parentExpr->getType();
2689 
2690   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2691   if (!callableDecl) {
2692     return emitError(loc,
2693                      llvm::formatv("expected a reference to a callable "
2694                                    "`Constraint` or `Rewrite`, but got: `{0}`",
2695                                    parentType));
2696   }
2697   if (parserContext == ParserContext::Rewrite) {
2698     if (isa<ast::UserConstraintDecl>(callableDecl))
2699       return emitError(
2700           loc, "unable to invoke `Constraint` within a rewrite section");
2701     if (isNegated)
2702       return emitError(loc, "unable to negate a Rewrite");
2703   } else {
2704     if (isa<ast::UserRewriteDecl>(callableDecl))
2705       return emitError(loc,
2706                        "unable to invoke `Rewrite` within a match section");
2707     if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
2708       return emitError(loc, "unable to negate non native constraints");
2709   }
2710 
2711   // Verify the arguments of the call.
2712   /// Handle size mismatch.
2713   ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2714   if (callArgs.size() != arguments.size()) {
2715     return emitErrorAndNote(
2716         loc,
2717         llvm::formatv("invalid number of arguments for {0} call; expected "
2718                       "{1}, but got {2}",
2719                       callableDecl->getCallableType(), callArgs.size(),
2720                       arguments.size()),
2721         callableDecl->getLoc(),
2722         llvm::formatv("see the definition of {0} here",
2723                       callableDecl->getName()->getName()));
2724   }
2725 
2726   /// Handle argument type mismatch.
2727   auto attachDiagFn = [&](ast::Diagnostic &diag) {
2728     diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2729                                   callableDecl->getName()->getName()),
2730                     callableDecl->getLoc());
2731   };
2732   for (auto it : llvm::zip(callArgs, arguments)) {
2733     if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2734                                    attachDiagFn)))
2735       return failure();
2736   }
2737 
2738   return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2739                                callableDecl->getResultType(), isNegated);
2740 }
2741 
2742 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2743                                                         ast::Decl *decl) {
2744   // Check the type of decl being referenced.
2745   ast::Type declType;
2746   if (isa<ast::ConstraintDecl>(decl))
2747     declType = ast::ConstraintType::get(ctx);
2748   else if (isa<ast::UserRewriteDecl>(decl))
2749     declType = ast::RewriteType::get(ctx);
2750   else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2751     declType = varDecl->getType();
2752   else
2753     return emitError(loc, "invalid reference to `" +
2754                               decl->getName()->getName() + "`");
2755 
2756   return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2757 }
2758 
2759 FailureOr<ast::DeclRefExpr *>
2760 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2761                                  ArrayRef<ast::ConstraintRef> constraints) {
2762   FailureOr<ast::VariableDecl *> decl =
2763       defineVariableDecl(name, loc, type, constraints);
2764   if (failed(decl))
2765     return failure();
2766   return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2767 }
2768 
2769 FailureOr<ast::MemberAccessExpr *>
2770 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2771                                SMRange loc) {
2772   // Validate the member name for the given parent expression.
2773   FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2774   if (failed(memberType))
2775     return failure();
2776 
2777   return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2778 }
2779 
2780 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2781                                                   StringRef name, SMRange loc) {
2782   ast::Type parentType = parentExpr->getType();
2783   if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
2784     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2785       return valueRangeTy;
2786 
2787     // Verify member access based on the operation type.
2788     if (const ods::Operation *odsOp = opType.getODSOperation()) {
2789       auto results = odsOp->getResults();
2790 
2791       // Handle indexed results.
2792       unsigned index = 0;
2793       if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2794           index < results.size()) {
2795         return results[index].isVariadic() ? valueRangeTy : valueTy;
2796       }
2797 
2798       // Handle named results.
2799       const auto *it = llvm::find_if(results, [&](const auto &result) {
2800         return result.getName() == name;
2801       });
2802       if (it != results.end())
2803         return it->isVariadic() ? valueRangeTy : valueTy;
2804     } else if (llvm::isDigit(name[0])) {
2805       // Allow unchecked numeric indexing of the results of unregistered
2806       // operations. It returns a single value.
2807       return valueTy;
2808     }
2809   } else if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
2810     // Handle indexed results.
2811     unsigned index = 0;
2812     if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2813         index < tupleType.size()) {
2814       return tupleType.getElementTypes()[index];
2815     }
2816 
2817     // Handle named results.
2818     auto elementNames = tupleType.getElementNames();
2819     const auto *it = llvm::find(elementNames, name);
2820     if (it != elementNames.end())
2821       return tupleType.getElementTypes()[it - elementNames.begin()];
2822   }
2823   return emitError(
2824       loc,
2825       llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2826                     name, parentType));
2827 }
2828 
2829 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2830     SMRange loc, const ast::OpNameDecl *name,
2831     OpResultTypeContext resultTypeContext,
2832     SmallVectorImpl<ast::Expr *> &operands,
2833     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2834     SmallVectorImpl<ast::Expr *> &results) {
2835   std::optional<StringRef> opNameRef = name->getName();
2836   const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2837 
2838   // Verify the inputs operands.
2839   if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2840     return failure();
2841 
2842   // Verify the attribute list.
2843   for (ast::NamedAttributeDecl *attr : attributes) {
2844     // Check for an attribute type, or a type awaiting resolution.
2845     ast::Type attrType = attr->getValue()->getType();
2846     if (!isa<ast::AttributeType>(attrType)) {
2847       return emitError(
2848           attr->getValue()->getLoc(),
2849           llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2850     }
2851   }
2852 
2853   assert(
2854       (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2855       "unexpected inferrence when results were explicitly specified");
2856 
2857   // If we aren't relying on type inferrence, or explicit results were provided,
2858   // validate them.
2859   if (resultTypeContext == OpResultTypeContext::Explicit) {
2860     if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2861       return failure();
2862 
2863     // Validate the use of interface based type inferrence for this operation.
2864   } else if (resultTypeContext == OpResultTypeContext::Interface) {
2865     assert(opNameRef &&
2866            "expected valid operation name when inferring operation results");
2867     checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2868   }
2869 
2870   return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2871                                     attributes);
2872 }
2873 
2874 LogicalResult
2875 Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2876                                   const ods::Operation *odsOp,
2877                                   SmallVectorImpl<ast::Expr *> &operands) {
2878   return validateOperationOperandsOrResults(
2879       "operand", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2880       operands, odsOp ? odsOp->getOperands() : std::nullopt, valueTy,
2881       valueRangeTy);
2882 }
2883 
2884 LogicalResult
2885 Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2886                                  const ods::Operation *odsOp,
2887                                  SmallVectorImpl<ast::Expr *> &results) {
2888   return validateOperationOperandsOrResults(
2889       "result", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2890       results, odsOp ? odsOp->getResults() : std::nullopt, typeTy, typeRangeTy);
2891 }
2892 
2893 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2894                                                 const ods::Operation *odsOp) {
2895   // If the operation might not have inferrence support, emit a warning to the
2896   // user. We don't emit an error because the interface might be added to the
2897   // operation at runtime. It's rare, but it could still happen. We emit a
2898   // warning here instead.
2899 
2900   // Handle inferrence warnings for unknown operations.
2901   if (!odsOp) {
2902     ctx.getDiagEngine().emitWarning(
2903         loc, llvm::formatv(
2904                  "operation result types are marked to be inferred, but "
2905                  "`{0}` is unknown. Ensure that `{0}` supports zero "
2906                  "results or implements `InferTypeOpInterface`. Include "
2907                  "the ODS definition of this operation to remove this warning.",
2908                  opName));
2909     return;
2910   }
2911 
2912   // Handle inferrence warnings for known operations that expected at least one
2913   // result, but don't have inference support. An elided results list can mean
2914   // "zero-results", and we don't want to warn when that is the expected
2915   // behavior.
2916   bool requiresInferrence =
2917       llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2918         return !result.isVariableLength();
2919       });
2920   if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2921     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
2922         loc,
2923         llvm::formatv("operation result types are marked to be inferred, but "
2924                       "`{0}` does not provide an implementation of "
2925                       "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2926                       "`InferTypeOpInterface` at runtime, or add support to "
2927                       "the ODS definition to remove this warning.",
2928                       opName));
2929     diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2930                      odsOp->getLoc());
2931     return;
2932   }
2933 }
2934 
2935 LogicalResult Parser::validateOperationOperandsOrResults(
2936     StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2937     std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
2938     ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2939     ast::RangeType rangeTy) {
2940   // All operation types accept a single range parameter.
2941   if (values.size() == 1) {
2942     if (failed(convertExpressionTo(values[0], rangeTy)))
2943       return failure();
2944     return success();
2945   }
2946 
2947   /// If the operation has ODS information, we can more accurately verify the
2948   /// values.
2949   if (odsOpLoc) {
2950     auto emitSizeMismatchError = [&] {
2951       return emitErrorAndNote(
2952           loc,
2953           llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2954                         "{2}, but got {3}",
2955                         groupName, *name, odsValues.size(), values.size()),
2956           *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2957     };
2958 
2959     // Handle the case where no values were provided.
2960     if (values.empty()) {
2961       // If we don't expect any on the ODS side, we are done.
2962       if (odsValues.empty())
2963         return success();
2964 
2965       // If we do, check if we actually need to provide values (i.e. if any of
2966       // the values are actually required).
2967       unsigned numVariadic = 0;
2968       for (const auto &odsValue : odsValues) {
2969         if (!odsValue.isVariableLength())
2970           return emitSizeMismatchError();
2971         ++numVariadic;
2972       }
2973 
2974       // If we are in a non-rewrite context, we don't need to do anything more.
2975       // Zero-values is a valid constraint on the operation.
2976       if (parserContext != ParserContext::Rewrite)
2977         return success();
2978 
2979       // Otherwise, when in a rewrite we may need to provide values to match the
2980       // ODS signature of the operation to create.
2981 
2982       // If we only have one variadic value, just use an empty list.
2983       if (numVariadic == 1)
2984         return success();
2985 
2986       // Otherwise, create dummy values for each of the entries so that we
2987       // adhere to the ODS signature.
2988       for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2989         values.push_back(ast::RangeExpr::create(
2990             ctx, loc, /*elements=*/std::nullopt, rangeTy));
2991       }
2992       return success();
2993     }
2994 
2995     // Verify that the number of values provided matches the number of value
2996     // groups ODS expects.
2997     if (odsValues.size() != values.size())
2998       return emitSizeMismatchError();
2999 
3000     auto diagFn = [&](ast::Diagnostic &diag) {
3001       diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
3002                       *odsOpLoc);
3003     };
3004     for (unsigned i = 0, e = values.size(); i < e; ++i) {
3005       ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3006       if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
3007         return failure();
3008     }
3009     return success();
3010   }
3011 
3012   // Otherwise, accept the value groups as they have been defined and just
3013   // ensure they are one of the expected types.
3014   for (ast::Expr *&valueExpr : values) {
3015     ast::Type valueExprType = valueExpr->getType();
3016 
3017     // Check if this is one of the expected types.
3018     if (valueExprType == rangeTy || valueExprType == singleTy)
3019       continue;
3020 
3021     // If the operand is an Operation, allow converting to a Value or
3022     // ValueRange. This situations arises quite often with nested operation
3023     // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
3024     if (singleTy == valueTy) {
3025       if (isa<ast::OperationType>(valueExprType)) {
3026         valueExpr = convertOpToValue(valueExpr);
3027         continue;
3028       }
3029     }
3030 
3031     // Otherwise, try to convert the expression to a range.
3032     if (succeeded(convertExpressionTo(valueExpr, rangeTy)))
3033       continue;
3034 
3035     return emitError(
3036         valueExpr->getLoc(),
3037         llvm::formatv(
3038             "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3039             singleTy, rangeTy, valueExprType));
3040   }
3041   return success();
3042 }
3043 
3044 FailureOr<ast::TupleExpr *>
3045 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
3046                         ArrayRef<StringRef> elementNames) {
3047   for (const ast::Expr *element : elements) {
3048     ast::Type eleTy = element->getType();
3049     if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(eleTy)) {
3050       return emitError(
3051           element->getLoc(),
3052           llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
3053     }
3054   }
3055   return ast::TupleExpr::create(ctx, loc, elements, elementNames);
3056 }
3057 
3058 //===----------------------------------------------------------------------===//
3059 // Stmts
3060 
3061 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3062                                                     ast::Expr *rootOp) {
3063   // Check that root is an Operation.
3064   ast::Type rootType = rootOp->getType();
3065   if (!isa<ast::OperationType>(rootType))
3066     return emitError(rootOp->getLoc(), "expected `Op` expression");
3067 
3068   return ast::EraseStmt::create(ctx, loc, rootOp);
3069 }
3070 
3071 FailureOr<ast::ReplaceStmt *>
3072 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
3073                           MutableArrayRef<ast::Expr *> replValues) {
3074   // Check that root is an Operation.
3075   ast::Type rootType = rootOp->getType();
3076   if (!isa<ast::OperationType>(rootType)) {
3077     return emitError(
3078         rootOp->getLoc(),
3079         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3080   }
3081 
3082   // If there are multiple replacement values, we implicitly convert any Op
3083   // expressions to the value form.
3084   bool shouldConvertOpToValues = replValues.size() > 1;
3085   for (ast::Expr *&replExpr : replValues) {
3086     ast::Type replType = replExpr->getType();
3087 
3088     // Check that replExpr is an Operation, Value, or ValueRange.
3089     if (isa<ast::OperationType>(replType)) {
3090       if (shouldConvertOpToValues)
3091         replExpr = convertOpToValue(replExpr);
3092       continue;
3093     }
3094 
3095     if (replType != valueTy && replType != valueRangeTy) {
3096       return emitError(replExpr->getLoc(),
3097                        llvm::formatv("expected `Op`, `Value` or `ValueRange` "
3098                                      "expression, but got `{0}`",
3099                                      replType));
3100     }
3101   }
3102 
3103   return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
3104 }
3105 
3106 FailureOr<ast::RewriteStmt *>
3107 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3108                           ast::CompoundStmt *rewriteBody) {
3109   // Check that root is an Operation.
3110   ast::Type rootType = rootOp->getType();
3111   if (!isa<ast::OperationType>(rootType)) {
3112     return emitError(
3113         rootOp->getLoc(),
3114         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3115   }
3116 
3117   return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3118 }
3119 
3120 //===----------------------------------------------------------------------===//
3121 // Code Completion
3122 //===----------------------------------------------------------------------===//
3123 
3124 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3125   ast::Type parentType = parentExpr->getType();
3126   if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType))
3127     codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3128   else if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(parentType))
3129     codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3130   return failure();
3131 }
3132 
3133 LogicalResult
3134 Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3135   if (opName)
3136     codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3137   return failure();
3138 }
3139 
3140 LogicalResult
3141 Parser::codeCompleteConstraintName(ast::Type inferredType,
3142                                    bool allowInlineTypeConstraints) {
3143   codeCompleteContext->codeCompleteConstraintName(
3144       inferredType, allowInlineTypeConstraints, curDeclScope);
3145   return failure();
3146 }
3147 
3148 LogicalResult Parser::codeCompleteDialectName() {
3149   codeCompleteContext->codeCompleteDialectName();
3150   return failure();
3151 }
3152 
3153 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3154   codeCompleteContext->codeCompleteOperationName(dialectName);
3155   return failure();
3156 }
3157 
3158 LogicalResult Parser::codeCompletePatternMetadata() {
3159   codeCompleteContext->codeCompletePatternMetadata();
3160   return failure();
3161 }
3162 
3163 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3164   codeCompleteContext->codeCompleteIncludeFilename(curPath);
3165   return failure();
3166 }
3167 
3168 void Parser::codeCompleteCallSignature(ast::Node *parent,
3169                                        unsigned currentNumArgs) {
3170   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3171   if (!callableDecl)
3172     return;
3173 
3174   codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3175 }
3176 
3177 void Parser::codeCompleteOperationOperandsSignature(
3178     std::optional<StringRef> opName, unsigned currentNumOperands) {
3179   codeCompleteContext->codeCompleteOperationOperandsSignature(
3180       opName, currentNumOperands);
3181 }
3182 
3183 void Parser::codeCompleteOperationResultsSignature(
3184     std::optional<StringRef> opName, unsigned currentNumResults) {
3185   codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3186                                                              currentNumResults);
3187 }
3188 
3189 //===----------------------------------------------------------------------===//
3190 // Parser
3191 //===----------------------------------------------------------------------===//
3192 
3193 FailureOr<ast::Module *>
3194 mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3195                          bool enableDocumentation,
3196                          CodeCompleteContext *codeCompleteContext) {
3197   Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3198   return parser.parseModule();
3199 }
3200