xref: /llvm-project/mlir/include/mlir/Tools/PDLL/AST/Nodes.h (revision d2353695f8cb864f88475d3a921249b0dcbcc6f4)
1 //===- Nodes.h --------------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_TOOLS_PDLL_AST_NODES_H_
10 #define MLIR_TOOLS_PDLL_AST_NODES_H_
11 
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/Tools/PDLL/AST/Types.h"
14 #include "llvm/ADT/StringMap.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/Support/SMLoc.h"
17 #include "llvm/Support/SourceMgr.h"
18 #include "llvm/Support/TrailingObjects.h"
19 #include <optional>
20 
21 namespace mlir {
22 namespace pdll {
23 namespace ast {
24 class Context;
25 class Decl;
26 class Expr;
27 class NamedAttributeDecl;
28 class OpNameDecl;
29 class VariableDecl;
30 
31 //===----------------------------------------------------------------------===//
32 // Name
33 //===----------------------------------------------------------------------===//
34 
35 /// This class provides a convenient API for interacting with source names. It
36 /// contains a string name as well as the source location for that name.
37 struct Name {
38   static const Name &create(Context &ctx, StringRef name, SMRange location);
39 
40   /// Return the raw string name.
getNameName41   StringRef getName() const { return name; }
42 
43   /// Get the location of this name.
getLocName44   SMRange getLoc() const { return location; }
45 
46 private:
47   Name() = delete;
48   Name(const Name &) = delete;
49   Name &operator=(const Name &) = delete;
NameName50   Name(StringRef name, SMRange location) : name(name), location(location) {}
51 
52   /// The string name of the decl.
53   StringRef name;
54   /// The location of the decl name.
55   SMRange location;
56 };
57 
58 //===----------------------------------------------------------------------===//
59 // DeclScope
60 //===----------------------------------------------------------------------===//
61 
62 /// This class represents a scope for named AST decls. A scope determines the
63 /// visibility and lifetime of a named declaration.
64 class DeclScope {
65 public:
66   /// Create a new scope with an optional parent scope.
parent(parent)67   DeclScope(DeclScope *parent = nullptr) : parent(parent) {}
68 
69   /// Return the parent scope of this scope, or nullptr if there is no parent.
getParentScope()70   DeclScope *getParentScope() { return parent; }
getParentScope()71   const DeclScope *getParentScope() const { return parent; }
72 
73   /// Return all of the decls within this scope.
getDecls()74   auto getDecls() const { return llvm::make_second_range(decls); }
75 
76   /// Add a new decl to the scope.
77   void add(Decl *decl);
78 
79   /// Lookup a decl with the given name starting from this scope. Returns
80   /// nullptr if no decl could be found.
81   Decl *lookup(StringRef name);
82   template <typename T>
lookup(StringRef name)83   T *lookup(StringRef name) {
84     return dyn_cast_or_null<T>(lookup(name));
85   }
lookup(StringRef name)86   const Decl *lookup(StringRef name) const {
87     return const_cast<DeclScope *>(this)->lookup(name);
88   }
89   template <typename T>
lookup(StringRef name)90   const T *lookup(StringRef name) const {
91     return dyn_cast_or_null<T>(lookup(name));
92   }
93 
94 private:
95   /// The parent scope, or null if this is a top-level scope.
96   DeclScope *parent;
97   /// The decls defined within this scope.
98   llvm::StringMap<Decl *> decls;
99 };
100 
101 //===----------------------------------------------------------------------===//
102 // Node
103 //===----------------------------------------------------------------------===//
104 
105 /// This class represents a base AST node. All AST nodes are derived from this
106 /// class, and it contains many of the base functionality for interacting with
107 /// nodes.
108 class Node {
109 public:
110   /// This CRTP class provides several utilies when defining new AST nodes.
111   template <typename T, typename BaseT>
112   class NodeBase : public BaseT {
113   public:
114     using Base = NodeBase<T, BaseT>;
115 
116     /// Provide type casting support.
classof(const Node * node)117     static bool classof(const Node *node) {
118       return node->getTypeID() == TypeID::get<T>();
119     }
120 
121   protected:
122     template <typename... Args>
NodeBase(SMRange loc,Args &&...args)123     explicit NodeBase(SMRange loc, Args &&...args)
124         : BaseT(TypeID::get<T>(), loc, std::forward<Args>(args)...) {}
125   };
126 
127   /// Return the type identifier of this node.
getTypeID()128   TypeID getTypeID() const { return typeID; }
129 
130   /// Return the location of this node.
getLoc()131   SMRange getLoc() const { return loc; }
132 
133   /// Print this node to the given stream.
134   void print(raw_ostream &os) const;
135 
136   /// Walk all of the nodes including, and nested under, this node in pre-order.
137   void walk(function_ref<void(const Node *)> walkFn) const;
138   template <typename WalkFnT, typename ArgT = typename llvm::function_traits<
139                                   WalkFnT>::template arg_t<0>>
140   std::enable_if_t<!std::is_convertible<const Node *, ArgT>::value>
walk(WalkFnT && walkFn)141   walk(WalkFnT &&walkFn) const {
142     walk([&](const Node *node) {
143       if (const ArgT *derivedNode = dyn_cast<ArgT>(node))
144         walkFn(derivedNode);
145     });
146   }
147 
148 protected:
Node(TypeID typeID,SMRange loc)149   Node(TypeID typeID, SMRange loc) : typeID(typeID), loc(loc) {}
150 
151 private:
152   /// A unique type identifier for this node.
153   TypeID typeID;
154 
155   /// The location of this node.
156   SMRange loc;
157 };
158 
159 //===----------------------------------------------------------------------===//
160 // Stmt
161 //===----------------------------------------------------------------------===//
162 
163 /// This class represents a base AST Statement node.
164 class Stmt : public Node {
165 public:
166   using Node::Node;
167 
168   /// Provide type casting support.
169   static bool classof(const Node *node);
170 };
171 
172 //===----------------------------------------------------------------------===//
173 // CompoundStmt
174 //===----------------------------------------------------------------------===//
175 
176 /// This statement represents a compound statement, which contains a collection
177 /// of other statements.
178 class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>,
179                            private llvm::TrailingObjects<CompoundStmt, Stmt *> {
180 public:
181   static CompoundStmt *create(Context &ctx, SMRange location,
182                               ArrayRef<Stmt *> children);
183 
184   /// Return the children of this compound statement.
getChildren()185   MutableArrayRef<Stmt *> getChildren() {
186     return {getTrailingObjects<Stmt *>(), numChildren};
187   }
getChildren()188   ArrayRef<Stmt *> getChildren() const {
189     return const_cast<CompoundStmt *>(this)->getChildren();
190   }
begin()191   ArrayRef<Stmt *>::iterator begin() const { return getChildren().begin(); }
end()192   ArrayRef<Stmt *>::iterator end() const { return getChildren().end(); }
193 
194 private:
CompoundStmt(SMRange location,unsigned numChildren)195   CompoundStmt(SMRange location, unsigned numChildren)
196       : Base(location), numChildren(numChildren) {}
197 
198   /// The number of held children statements.
199   unsigned numChildren;
200 
201   // Allow access to various privates.
202   friend class llvm::TrailingObjects<CompoundStmt, Stmt *>;
203 };
204 
205 //===----------------------------------------------------------------------===//
206 // LetStmt
207 //===----------------------------------------------------------------------===//
208 
209 /// This statement represents a `let` statement in PDLL. This statement is used
210 /// to define variables.
211 class LetStmt final : public Node::NodeBase<LetStmt, Stmt> {
212 public:
213   static LetStmt *create(Context &ctx, SMRange loc, VariableDecl *varDecl);
214 
215   /// Return the variable defined by this statement.
getVarDecl()216   VariableDecl *getVarDecl() const { return varDecl; }
217 
218 private:
LetStmt(SMRange loc,VariableDecl * varDecl)219   LetStmt(SMRange loc, VariableDecl *varDecl) : Base(loc), varDecl(varDecl) {}
220 
221   /// The variable defined by this statement.
222   VariableDecl *varDecl;
223 };
224 
225 //===----------------------------------------------------------------------===//
226 // OpRewriteStmt
227 //===----------------------------------------------------------------------===//
228 
229 /// This class represents a base operation rewrite statement. Operation rewrite
230 /// statements perform a set of transformations on a given root operation.
231 class OpRewriteStmt : public Stmt {
232 public:
233   /// Provide type casting support.
234   static bool classof(const Node *node);
235 
236   /// Return the root operation of this rewrite.
getRootOpExpr()237   Expr *getRootOpExpr() const { return rootOp; }
238 
239 protected:
OpRewriteStmt(TypeID typeID,SMRange loc,Expr * rootOp)240   OpRewriteStmt(TypeID typeID, SMRange loc, Expr *rootOp)
241       : Stmt(typeID, loc), rootOp(rootOp) {}
242 
243 protected:
244   /// The root operation being rewritten.
245   Expr *rootOp;
246 };
247 
248 //===----------------------------------------------------------------------===//
249 // EraseStmt
250 
251 /// This statement represents the `erase` statement in PDLL. This statement
252 /// erases the given root operation, corresponding roughly to the
253 /// PatternRewriter::eraseOp API.
254 class EraseStmt final : public Node::NodeBase<EraseStmt, OpRewriteStmt> {
255 public:
256   static EraseStmt *create(Context &ctx, SMRange loc, Expr *rootOp);
257 
258 private:
EraseStmt(SMRange loc,Expr * rootOp)259   EraseStmt(SMRange loc, Expr *rootOp) : Base(loc, rootOp) {}
260 };
261 
262 //===----------------------------------------------------------------------===//
263 // ReplaceStmt
264 
265 /// This statement represents the `replace` statement in PDLL. This statement
266 /// replace the given root operation with a set of values, corresponding roughly
267 /// to the PatternRewriter::replaceOp API.
268 class ReplaceStmt final : public Node::NodeBase<ReplaceStmt, OpRewriteStmt>,
269                           private llvm::TrailingObjects<ReplaceStmt, Expr *> {
270 public:
271   static ReplaceStmt *create(Context &ctx, SMRange loc, Expr *rootOp,
272                              ArrayRef<Expr *> replExprs);
273 
274   /// Return the replacement values of this statement.
getReplExprs()275   MutableArrayRef<Expr *> getReplExprs() {
276     return {getTrailingObjects<Expr *>(), numReplExprs};
277   }
getReplExprs()278   ArrayRef<Expr *> getReplExprs() const {
279     return const_cast<ReplaceStmt *>(this)->getReplExprs();
280   }
281 
282 private:
ReplaceStmt(SMRange loc,Expr * rootOp,unsigned numReplExprs)283   ReplaceStmt(SMRange loc, Expr *rootOp, unsigned numReplExprs)
284       : Base(loc, rootOp), numReplExprs(numReplExprs) {}
285 
286   /// The number of replacement values within this statement.
287   unsigned numReplExprs;
288 
289   /// TrailingObject utilities.
290   friend class llvm::TrailingObjects<ReplaceStmt, Expr *>;
291 };
292 
293 //===----------------------------------------------------------------------===//
294 // RewriteStmt
295 
296 /// This statement represents an operation rewrite that contains a block of
297 /// nested rewrite commands. This allows for building more complex operation
298 /// rewrites that span across multiple statements, which may be unconnected.
299 class RewriteStmt final : public Node::NodeBase<RewriteStmt, OpRewriteStmt> {
300 public:
301   static RewriteStmt *create(Context &ctx, SMRange loc, Expr *rootOp,
302                              CompoundStmt *rewriteBody);
303 
304   /// Return the compound rewrite body.
getRewriteBody()305   CompoundStmt *getRewriteBody() const { return rewriteBody; }
306 
307 private:
RewriteStmt(SMRange loc,Expr * rootOp,CompoundStmt * rewriteBody)308   RewriteStmt(SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
309       : Base(loc, rootOp), rewriteBody(rewriteBody) {}
310 
311   /// The body of nested rewriters within this statement.
312   CompoundStmt *rewriteBody;
313 };
314 
315 //===----------------------------------------------------------------------===//
316 // ReturnStmt
317 //===----------------------------------------------------------------------===//
318 
319 /// This statement represents a return from a "callable" like decl, e.g. a
320 /// Constraint or a Rewrite.
321 class ReturnStmt final : public Node::NodeBase<ReturnStmt, Stmt> {
322 public:
323   static ReturnStmt *create(Context &ctx, SMRange loc, Expr *resultExpr);
324 
325   /// Return the result expression of this statement.
getResultExpr()326   Expr *getResultExpr() { return resultExpr; }
getResultExpr()327   const Expr *getResultExpr() const { return resultExpr; }
328 
329   /// Set the result expression of this statement.
setResultExpr(Expr * expr)330   void setResultExpr(Expr *expr) { resultExpr = expr; }
331 
332 private:
ReturnStmt(SMRange loc,Expr * resultExpr)333   ReturnStmt(SMRange loc, Expr *resultExpr)
334       : Base(loc), resultExpr(resultExpr) {}
335 
336   // The result expression of this statement.
337   Expr *resultExpr;
338 };
339 
340 //===----------------------------------------------------------------------===//
341 // Expr
342 //===----------------------------------------------------------------------===//
343 
344 /// This class represents a base AST Expression node.
345 class Expr : public Stmt {
346 public:
347   /// Return the type of this expression.
getType()348   Type getType() const { return type; }
349 
350   /// Provide type casting support.
351   static bool classof(const Node *node);
352 
353 protected:
Expr(TypeID typeID,SMRange loc,Type type)354   Expr(TypeID typeID, SMRange loc, Type type) : Stmt(typeID, loc), type(type) {}
355 
356 private:
357   /// The type of this expression.
358   Type type;
359 };
360 
361 //===----------------------------------------------------------------------===//
362 // AttributeExpr
363 //===----------------------------------------------------------------------===//
364 
365 /// This expression represents a literal MLIR Attribute, and contains the
366 /// textual assembly format of that attribute.
367 class AttributeExpr : public Node::NodeBase<AttributeExpr, Expr> {
368 public:
369   static AttributeExpr *create(Context &ctx, SMRange loc, StringRef value);
370 
371   /// Get the raw value of this expression. This is the textual assembly format
372   /// of the MLIR Attribute.
getValue()373   StringRef getValue() const { return value; }
374 
375 private:
AttributeExpr(Context & ctx,SMRange loc,StringRef value)376   AttributeExpr(Context &ctx, SMRange loc, StringRef value)
377       : Base(loc, AttributeType::get(ctx)), value(value) {}
378 
379   /// The value referenced by this expression.
380   StringRef value;
381 };
382 
383 //===----------------------------------------------------------------------===//
384 // CallExpr
385 //===----------------------------------------------------------------------===//
386 
387 /// This expression represents a call to a decl, such as a
388 /// UserConstraintDecl/UserRewriteDecl.
389 class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
390                        private llvm::TrailingObjects<CallExpr, Expr *> {
391 public:
392   static CallExpr *create(Context &ctx, SMRange loc, Expr *callable,
393                           ArrayRef<Expr *> arguments, Type resultType,
394                           bool isNegated = false);
395 
396   /// Return the callable of this call.
getCallableExpr()397   Expr *getCallableExpr() const { return callable; }
398 
399   /// Return the arguments of this call.
getArguments()400   MutableArrayRef<Expr *> getArguments() {
401     return {getTrailingObjects<Expr *>(), numArgs};
402   }
getArguments()403   ArrayRef<Expr *> getArguments() const {
404     return const_cast<CallExpr *>(this)->getArguments();
405   }
406 
407   /// Returns whether the result of this call is to be negated.
getIsNegated()408   bool getIsNegated() const { return isNegated; }
409 
410 private:
CallExpr(SMRange loc,Type type,Expr * callable,unsigned numArgs,bool isNegated)411   CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs,
412            bool isNegated)
413       : Base(loc, type), callable(callable), numArgs(numArgs),
414         isNegated(isNegated) {}
415 
416   /// The callable of this call.
417   Expr *callable;
418 
419   /// The number of arguments of the call.
420   unsigned numArgs;
421 
422   /// TrailingObject utilities.
423   friend llvm::TrailingObjects<CallExpr, Expr *>;
424 
425   // Is the result of this call to be negated.
426   bool isNegated;
427 };
428 
429 //===----------------------------------------------------------------------===//
430 // DeclRefExpr
431 //===----------------------------------------------------------------------===//
432 
433 /// This expression represents a reference to a Decl node.
434 class DeclRefExpr : public Node::NodeBase<DeclRefExpr, Expr> {
435 public:
436   static DeclRefExpr *create(Context &ctx, SMRange loc, Decl *decl, Type type);
437 
438   /// Get the decl referenced by this expression.
getDecl()439   Decl *getDecl() const { return decl; }
440 
441 private:
DeclRefExpr(SMRange loc,Decl * decl,Type type)442   DeclRefExpr(SMRange loc, Decl *decl, Type type)
443       : Base(loc, type), decl(decl) {}
444 
445   /// The decl referenced by this expression.
446   Decl *decl;
447 };
448 
449 //===----------------------------------------------------------------------===//
450 // MemberAccessExpr
451 //===----------------------------------------------------------------------===//
452 
453 /// This expression represents a named member or field access of a given parent
454 /// expression.
455 class MemberAccessExpr : public Node::NodeBase<MemberAccessExpr, Expr> {
456 public:
457   static MemberAccessExpr *create(Context &ctx, SMRange loc,
458                                   const Expr *parentExpr, StringRef memberName,
459                                   Type type);
460 
461   /// Get the parent expression of this access.
getParentExpr()462   const Expr *getParentExpr() const { return parentExpr; }
463 
464   /// Return the name of the member being accessed.
getMemberName()465   StringRef getMemberName() const { return memberName; }
466 
467 private:
MemberAccessExpr(SMRange loc,const Expr * parentExpr,StringRef memberName,Type type)468   MemberAccessExpr(SMRange loc, const Expr *parentExpr, StringRef memberName,
469                    Type type)
470       : Base(loc, type), parentExpr(parentExpr), memberName(memberName) {}
471 
472   /// The parent expression of this access.
473   const Expr *parentExpr;
474 
475   /// The name of the member being accessed from the parent.
476   StringRef memberName;
477 };
478 
479 //===----------------------------------------------------------------------===//
480 // AllResultsMemberAccessExpr
481 
482 /// This class represents an instance of MemberAccessExpr that references all
483 /// results of an operation.
484 class AllResultsMemberAccessExpr : public MemberAccessExpr {
485 public:
486   /// Return the member name used for the "all-results" access.
getMemberName()487   static StringRef getMemberName() { return "$results"; }
488 
create(Context & ctx,SMRange loc,const Expr * parentExpr,Type type)489   static AllResultsMemberAccessExpr *create(Context &ctx, SMRange loc,
490                                             const Expr *parentExpr, Type type) {
491     return cast<AllResultsMemberAccessExpr>(
492         MemberAccessExpr::create(ctx, loc, parentExpr, getMemberName(), type));
493   }
494 
495   /// Provide type casting support.
classof(const Node * node)496   static bool classof(const Node *node) {
497     const MemberAccessExpr *memAccess = dyn_cast<MemberAccessExpr>(node);
498     return memAccess && memAccess->getMemberName() == getMemberName();
499   }
500 };
501 
502 //===----------------------------------------------------------------------===//
503 // OperationExpr
504 //===----------------------------------------------------------------------===//
505 
506 /// This expression represents the structural form of an MLIR Operation. It
507 /// represents either an input operation to match, or an operation to create
508 /// within a rewrite.
509 class OperationExpr final
510     : public Node::NodeBase<OperationExpr, Expr>,
511       private llvm::TrailingObjects<OperationExpr, Expr *,
512                                     NamedAttributeDecl *> {
513 public:
514   static OperationExpr *create(Context &ctx, SMRange loc,
515                                const ods::Operation *odsOp,
516                                const OpNameDecl *nameDecl,
517                                ArrayRef<Expr *> operands,
518                                ArrayRef<Expr *> resultTypes,
519                                ArrayRef<NamedAttributeDecl *> attributes);
520 
521   /// Return the name of the operation, or std::nullopt if there isn't one.
522   std::optional<StringRef> getName() const;
523 
524   /// Return the declaration of the operation name.
getNameDecl()525   const OpNameDecl *getNameDecl() const { return nameDecl; }
526 
527   /// Return the location of the name of the operation expression, or an invalid
528   /// location if there isn't a name.
getNameLoc()529   SMRange getNameLoc() const { return nameLoc; }
530 
531   /// Return the operands of this operation.
getOperands()532   MutableArrayRef<Expr *> getOperands() {
533     return {getTrailingObjects<Expr *>(), numOperands};
534   }
getOperands()535   ArrayRef<Expr *> getOperands() const {
536     return const_cast<OperationExpr *>(this)->getOperands();
537   }
538 
539   /// Return the result types of this operation.
getResultTypes()540   MutableArrayRef<Expr *> getResultTypes() {
541     return {getTrailingObjects<Expr *>() + numOperands, numResultTypes};
542   }
getResultTypes()543   MutableArrayRef<Expr *> getResultTypes() const {
544     return const_cast<OperationExpr *>(this)->getResultTypes();
545   }
546 
547   /// Return the attributes of this operation.
getAttributes()548   MutableArrayRef<NamedAttributeDecl *> getAttributes() {
549     return {getTrailingObjects<NamedAttributeDecl *>(), numAttributes};
550   }
getAttributes()551   MutableArrayRef<NamedAttributeDecl *> getAttributes() const {
552     return const_cast<OperationExpr *>(this)->getAttributes();
553   }
554 
555 private:
OperationExpr(SMRange loc,Type type,const OpNameDecl * nameDecl,unsigned numOperands,unsigned numResultTypes,unsigned numAttributes,SMRange nameLoc)556   OperationExpr(SMRange loc, Type type, const OpNameDecl *nameDecl,
557                 unsigned numOperands, unsigned numResultTypes,
558                 unsigned numAttributes, SMRange nameLoc)
559       : Base(loc, type), nameDecl(nameDecl), numOperands(numOperands),
560         numResultTypes(numResultTypes), numAttributes(numAttributes),
561         nameLoc(nameLoc) {}
562 
563   /// The name decl of this expression.
564   const OpNameDecl *nameDecl;
565 
566   /// The number of operands, result types, and attributes of the operation.
567   unsigned numOperands, numResultTypes, numAttributes;
568 
569   /// The location of the operation name in the expression if it has a name.
570   SMRange nameLoc;
571 
572   /// TrailingObject utilities.
573   friend llvm::TrailingObjects<OperationExpr, Expr *, NamedAttributeDecl *>;
numTrailingObjects(OverloadToken<Expr * >)574   size_t numTrailingObjects(OverloadToken<Expr *>) const {
575     return numOperands + numResultTypes;
576   }
577 };
578 
579 //===----------------------------------------------------------------------===//
580 // RangeExpr
581 //===----------------------------------------------------------------------===//
582 
583 /// This expression builds a range from a set of element values (which may be
584 /// ranges themselves).
585 class RangeExpr final : public Node::NodeBase<RangeExpr, Expr>,
586                         private llvm::TrailingObjects<RangeExpr, Expr *> {
587 public:
588   static RangeExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements,
589                            RangeType type);
590 
591   /// Return the element expressions of this range.
getElements()592   MutableArrayRef<Expr *> getElements() {
593     return {getTrailingObjects<Expr *>(), numElements};
594   }
getElements()595   ArrayRef<Expr *> getElements() const {
596     return const_cast<RangeExpr *>(this)->getElements();
597   }
598 
599   /// Return the range result type of this expression.
getType()600   RangeType getType() const { return mlir::cast<RangeType>(Base::getType()); }
601 
602 private:
RangeExpr(SMRange loc,RangeType type,unsigned numElements)603   RangeExpr(SMRange loc, RangeType type, unsigned numElements)
604       : Base(loc, type), numElements(numElements) {}
605 
606   /// The number of element values for this range.
607   unsigned numElements;
608 
609   /// TrailingObject utilities.
610   friend class llvm::TrailingObjects<RangeExpr, Expr *>;
611 };
612 
613 //===----------------------------------------------------------------------===//
614 // TupleExpr
615 //===----------------------------------------------------------------------===//
616 
617 /// This expression builds a tuple from a set of element values.
618 class TupleExpr final : public Node::NodeBase<TupleExpr, Expr>,
619                         private llvm::TrailingObjects<TupleExpr, Expr *> {
620 public:
621   static TupleExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements,
622                            ArrayRef<StringRef> elementNames);
623 
624   /// Return the element expressions of this tuple.
getElements()625   MutableArrayRef<Expr *> getElements() {
626     return {getTrailingObjects<Expr *>(), getType().size()};
627   }
getElements()628   ArrayRef<Expr *> getElements() const {
629     return const_cast<TupleExpr *>(this)->getElements();
630   }
631 
632   /// Return the tuple result type of this expression.
getType()633   TupleType getType() const { return mlir::cast<TupleType>(Base::getType()); }
634 
635 private:
TupleExpr(SMRange loc,TupleType type)636   TupleExpr(SMRange loc, TupleType type) : Base(loc, type) {}
637 
638   /// TrailingObject utilities.
639   friend class llvm::TrailingObjects<TupleExpr, Expr *>;
640 };
641 
642 //===----------------------------------------------------------------------===//
643 // TypeExpr
644 //===----------------------------------------------------------------------===//
645 
646 /// This expression represents a literal MLIR Type, and contains the textual
647 /// assembly format of that type.
648 class TypeExpr : public Node::NodeBase<TypeExpr, Expr> {
649 public:
650   static TypeExpr *create(Context &ctx, SMRange loc, StringRef value);
651 
652   /// Get the raw value of this expression. This is the textual assembly format
653   /// of the MLIR Type.
getValue()654   StringRef getValue() const { return value; }
655 
656 private:
TypeExpr(Context & ctx,SMRange loc,StringRef value)657   TypeExpr(Context &ctx, SMRange loc, StringRef value)
658       : Base(loc, TypeType::get(ctx)), value(value) {}
659 
660   /// The value referenced by this expression.
661   StringRef value;
662 };
663 
664 //===----------------------------------------------------------------------===//
665 // Decl
666 //===----------------------------------------------------------------------===//
667 
668 /// This class represents the base Decl node.
669 class Decl : public Node {
670 public:
671   /// Return the name of the decl, or nullptr if it doesn't have one.
getName()672   const Name *getName() const { return name; }
673 
674   /// Provide type casting support.
675   static bool classof(const Node *node);
676 
677   /// Set the documentation comment for this decl.
678   void setDocComment(Context &ctx, StringRef comment);
679 
680   /// Return the documentation comment attached to this decl if it has been set.
681   /// Otherwise, returns std::nullopt.
getDocComment()682   std::optional<StringRef> getDocComment() const { return docComment; }
683 
684 protected:
685   Decl(TypeID typeID, SMRange loc, const Name *name = nullptr)
Node(typeID,loc)686       : Node(typeID, loc), name(name) {}
687 
688 private:
689   /// The name of the decl. This is optional for some decls, such as
690   /// PatternDecl.
691   const Name *name;
692 
693   /// The documentation comment attached to this decl. Defaults to std::nullopt
694   /// if the comment is unset/unknown.
695   std::optional<StringRef> docComment;
696 };
697 
698 //===----------------------------------------------------------------------===//
699 // ConstraintDecl
700 //===----------------------------------------------------------------------===//
701 
702 /// This class represents the base of all AST Constraint decls. Constraints
703 /// apply matcher conditions to, and define the type of PDLL variables.
704 class ConstraintDecl : public Decl {
705 public:
706   /// Provide type casting support.
707   static bool classof(const Node *node);
708 
709 protected:
710   ConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr)
Decl(typeID,loc,name)711       : Decl(typeID, loc, name) {}
712 };
713 
714 /// This class represents a reference to a constraint, and contains a constraint
715 /// and the location of the reference.
716 struct ConstraintRef {
ConstraintRefConstraintRef717   ConstraintRef(const ConstraintDecl *constraint, SMRange refLoc)
718       : constraint(constraint), referenceLoc(refLoc) {}
ConstraintRefConstraintRef719   explicit ConstraintRef(const ConstraintDecl *constraint)
720       : ConstraintRef(constraint, constraint->getLoc()) {}
721 
722   const ConstraintDecl *constraint;
723   SMRange referenceLoc;
724 };
725 
726 //===----------------------------------------------------------------------===//
727 // CoreConstraintDecl
728 //===----------------------------------------------------------------------===//
729 
730 /// This class represents the base of all "core" constraints. Core constraints
731 /// are those that generally represent a concrete IR construct, such as
732 /// `Type`s or `Value`s.
733 class CoreConstraintDecl : public ConstraintDecl {
734 public:
735   /// Provide type casting support.
736   static bool classof(const Node *node);
737 
738 protected:
739   CoreConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr)
ConstraintDecl(typeID,loc,name)740       : ConstraintDecl(typeID, loc, name) {}
741 };
742 
743 //===----------------------------------------------------------------------===//
744 // AttrConstraintDecl
745 
746 /// The class represents an Attribute constraint, and constrains a variable to
747 /// be an Attribute.
748 class AttrConstraintDecl
749     : public Node::NodeBase<AttrConstraintDecl, CoreConstraintDecl> {
750 public:
751   static AttrConstraintDecl *create(Context &ctx, SMRange loc,
752                                     Expr *typeExpr = nullptr);
753 
754   /// Return the optional type the attribute is constrained to.
getTypeExpr()755   Expr *getTypeExpr() { return typeExpr; }
getTypeExpr()756   const Expr *getTypeExpr() const { return typeExpr; }
757 
758 protected:
AttrConstraintDecl(SMRange loc,Expr * typeExpr)759   AttrConstraintDecl(SMRange loc, Expr *typeExpr)
760       : Base(loc), typeExpr(typeExpr) {}
761 
762   /// An optional type that the attribute is constrained to.
763   Expr *typeExpr;
764 };
765 
766 //===----------------------------------------------------------------------===//
767 // OpConstraintDecl
768 
769 /// The class represents an Operation constraint, and constrains a variable to
770 /// be an Operation.
771 class OpConstraintDecl
772     : public Node::NodeBase<OpConstraintDecl, CoreConstraintDecl> {
773 public:
774   static OpConstraintDecl *create(Context &ctx, SMRange loc,
775                                   const OpNameDecl *nameDecl = nullptr);
776 
777   /// Return the name of the operation, or std::nullopt if there isn't one.
778   std::optional<StringRef> getName() const;
779 
780   /// Return the declaration of the operation name.
getNameDecl()781   const OpNameDecl *getNameDecl() const { return nameDecl; }
782 
783 protected:
OpConstraintDecl(SMRange loc,const OpNameDecl * nameDecl)784   explicit OpConstraintDecl(SMRange loc, const OpNameDecl *nameDecl)
785       : Base(loc), nameDecl(nameDecl) {}
786 
787   /// The operation name of this constraint.
788   const OpNameDecl *nameDecl;
789 };
790 
791 //===----------------------------------------------------------------------===//
792 // TypeConstraintDecl
793 
794 /// The class represents a Type constraint, and constrains a variable to be a
795 /// Type.
796 class TypeConstraintDecl
797     : public Node::NodeBase<TypeConstraintDecl, CoreConstraintDecl> {
798 public:
799   static TypeConstraintDecl *create(Context &ctx, SMRange loc);
800 
801 protected:
802   using Base::Base;
803 };
804 
805 //===----------------------------------------------------------------------===//
806 // TypeRangeConstraintDecl
807 
808 /// The class represents a TypeRange constraint, and constrains a variable to be
809 /// a TypeRange.
810 class TypeRangeConstraintDecl
811     : public Node::NodeBase<TypeRangeConstraintDecl, CoreConstraintDecl> {
812 public:
813   static TypeRangeConstraintDecl *create(Context &ctx, SMRange loc);
814 
815 protected:
816   using Base::Base;
817 };
818 
819 //===----------------------------------------------------------------------===//
820 // ValueConstraintDecl
821 
822 /// The class represents a Value constraint, and constrains a variable to be a
823 /// Value.
824 class ValueConstraintDecl
825     : public Node::NodeBase<ValueConstraintDecl, CoreConstraintDecl> {
826 public:
827   static ValueConstraintDecl *create(Context &ctx, SMRange loc, Expr *typeExpr);
828 
829   /// Return the optional type the value is constrained to.
getTypeExpr()830   Expr *getTypeExpr() { return typeExpr; }
getTypeExpr()831   const Expr *getTypeExpr() const { return typeExpr; }
832 
833 protected:
ValueConstraintDecl(SMRange loc,Expr * typeExpr)834   ValueConstraintDecl(SMRange loc, Expr *typeExpr)
835       : Base(loc), typeExpr(typeExpr) {}
836 
837   /// An optional type that the value is constrained to.
838   Expr *typeExpr;
839 };
840 
841 //===----------------------------------------------------------------------===//
842 // ValueRangeConstraintDecl
843 
844 /// The class represents a ValueRange constraint, and constrains a variable to
845 /// be a ValueRange.
846 class ValueRangeConstraintDecl
847     : public Node::NodeBase<ValueRangeConstraintDecl, CoreConstraintDecl> {
848 public:
849   static ValueRangeConstraintDecl *create(Context &ctx, SMRange loc,
850                                           Expr *typeExpr = nullptr);
851 
852   /// Return the optional type the value range is constrained to.
getTypeExpr()853   Expr *getTypeExpr() { return typeExpr; }
getTypeExpr()854   const Expr *getTypeExpr() const { return typeExpr; }
855 
856 protected:
ValueRangeConstraintDecl(SMRange loc,Expr * typeExpr)857   ValueRangeConstraintDecl(SMRange loc, Expr *typeExpr)
858       : Base(loc), typeExpr(typeExpr) {}
859 
860   /// An optional type that the value range is constrained to.
861   Expr *typeExpr;
862 };
863 
864 //===----------------------------------------------------------------------===//
865 // UserConstraintDecl
866 //===----------------------------------------------------------------------===//
867 
868 /// This decl represents a user defined constraint. This is either:
869 ///   * an imported native constraint
870 ///     - Similar to an external function declaration. This is a native
871 ///       constraint defined externally, and imported into PDLL via a
872 ///       declaration.
873 ///   * a native constraint defined in PDLL
874 ///     - This is a native constraint, i.e. a constraint whose implementation is
875 ///       defined in C++(or potentially some other non-PDLL language). The
876 ///       implementation of this constraint is specified as a string code block
877 ///       in PDLL.
878 ///   * a PDLL constraint
879 ///     - This is a constraint which is defined using only PDLL constructs.
880 class UserConstraintDecl final
881     : public Node::NodeBase<UserConstraintDecl, ConstraintDecl>,
882       llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef> {
883 public:
884   /// Create a native constraint with the given optional code block.
885   static UserConstraintDecl *
886   createNative(Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
887                ArrayRef<VariableDecl *> results,
888                std::optional<StringRef> codeBlock, Type resultType,
889                ArrayRef<StringRef> nativeInputTypes = {}) {
890     return createImpl(ctx, name, inputs, nativeInputTypes, results, codeBlock,
891                       /*body=*/nullptr, resultType);
892   }
893 
894   /// Create a PDLL constraint with the given body.
createPDLL(Context & ctx,const Name & name,ArrayRef<VariableDecl * > inputs,ArrayRef<VariableDecl * > results,const CompoundStmt * body,Type resultType)895   static UserConstraintDecl *createPDLL(Context &ctx, const Name &name,
896                                         ArrayRef<VariableDecl *> inputs,
897                                         ArrayRef<VariableDecl *> results,
898                                         const CompoundStmt *body,
899                                         Type resultType) {
900     return createImpl(ctx, name, inputs, /*nativeInputTypes=*/std::nullopt,
901                       results, /*codeBlock=*/std::nullopt, body, resultType);
902   }
903 
904   /// Return the name of the constraint.
getName()905   const Name &getName() const { return *Decl::getName(); }
906 
907   /// Return the input arguments of this constraint.
getInputs()908   MutableArrayRef<VariableDecl *> getInputs() {
909     return {getTrailingObjects<VariableDecl *>(), numInputs};
910   }
getInputs()911   ArrayRef<VariableDecl *> getInputs() const {
912     return const_cast<UserConstraintDecl *>(this)->getInputs();
913   }
914 
915   /// Return the explicit native type to use for the given input. Returns
916   /// std::nullopt if no explicit type was set.
917   std::optional<StringRef> getNativeInputType(unsigned index) const;
918 
919   /// Return the explicit results of the constraint declaration. May be empty,
920   /// even if the constraint has results (e.g. in the case of inferred results).
getResults()921   MutableArrayRef<VariableDecl *> getResults() {
922     return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
923   }
getResults()924   ArrayRef<VariableDecl *> getResults() const {
925     return const_cast<UserConstraintDecl *>(this)->getResults();
926   }
927 
928   /// Return the optional code block of this constraint, if this is a native
929   /// constraint with a provided implementation.
getCodeBlock()930   std::optional<StringRef> getCodeBlock() const { return codeBlock; }
931 
932   /// Return the body of this constraint if this constraint is a PDLL
933   /// constraint, otherwise returns nullptr.
getBody()934   const CompoundStmt *getBody() const { return constraintBody; }
935 
936   /// Return the result type of this constraint.
getResultType()937   Type getResultType() const { return resultType; }
938 
939   /// Returns true if this constraint is external.
isExternal()940   bool isExternal() const { return !constraintBody && !codeBlock; }
941 
942 private:
943   /// Create either a PDLL constraint or a native constraint with the given
944   /// components.
945   static UserConstraintDecl *createImpl(Context &ctx, const Name &name,
946                                         ArrayRef<VariableDecl *> inputs,
947                                         ArrayRef<StringRef> nativeInputTypes,
948                                         ArrayRef<VariableDecl *> results,
949                                         std::optional<StringRef> codeBlock,
950                                         const CompoundStmt *body,
951                                         Type resultType);
952 
UserConstraintDecl(const Name & name,unsigned numInputs,bool hasNativeInputTypes,unsigned numResults,std::optional<StringRef> codeBlock,const CompoundStmt * body,Type resultType)953   UserConstraintDecl(const Name &name, unsigned numInputs,
954                      bool hasNativeInputTypes, unsigned numResults,
955                      std::optional<StringRef> codeBlock,
956                      const CompoundStmt *body, Type resultType)
957       : Base(name.getLoc(), &name), numInputs(numInputs),
958         numResults(numResults), codeBlock(codeBlock), constraintBody(body),
959         resultType(resultType), hasNativeInputTypes(hasNativeInputTypes) {}
960 
961   /// The number of inputs to this constraint.
962   unsigned numInputs;
963 
964   /// The number of explicit results to this constraint.
965   unsigned numResults;
966 
967   /// The optional code block of this constraint.
968   std::optional<StringRef> codeBlock;
969 
970   /// The optional body of this constraint.
971   const CompoundStmt *constraintBody;
972 
973   /// The result type of the constraint.
974   Type resultType;
975 
976   /// Flag indicating if this constraint has explicit native input types.
977   bool hasNativeInputTypes;
978 
979   /// Allow access to various internals.
980   friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef>;
numTrailingObjects(OverloadToken<VariableDecl * >)981   size_t numTrailingObjects(OverloadToken<VariableDecl *>) const {
982     return numInputs + numResults;
983   }
984 };
985 
986 //===----------------------------------------------------------------------===//
987 // NamedAttributeDecl
988 //===----------------------------------------------------------------------===//
989 
990 /// This Decl represents a NamedAttribute, and contains a string name and
991 /// attribute value.
992 class NamedAttributeDecl : public Node::NodeBase<NamedAttributeDecl, Decl> {
993 public:
994   static NamedAttributeDecl *create(Context &ctx, const Name &name,
995                                     Expr *value);
996 
997   /// Return the name of the attribute.
getName()998   const Name &getName() const { return *Decl::getName(); }
999 
1000   /// Return value of the attribute.
getValue()1001   Expr *getValue() const { return value; }
1002 
1003 private:
NamedAttributeDecl(const Name & name,Expr * value)1004   NamedAttributeDecl(const Name &name, Expr *value)
1005       : Base(name.getLoc(), &name), value(value) {}
1006 
1007   /// The value of the attribute.
1008   Expr *value;
1009 };
1010 
1011 //===----------------------------------------------------------------------===//
1012 // OpNameDecl
1013 //===----------------------------------------------------------------------===//
1014 
1015 /// This Decl represents an OperationName.
1016 class OpNameDecl : public Node::NodeBase<OpNameDecl, Decl> {
1017 public:
1018   static OpNameDecl *create(Context &ctx, const Name &name);
1019   static OpNameDecl *create(Context &ctx, SMRange loc);
1020 
1021   /// Return the name of this operation, or std::nullopt if the name is unknown.
getName()1022   std::optional<StringRef> getName() const {
1023     const Name *name = Decl::getName();
1024     return name ? std::optional<StringRef>(name->getName()) : std::nullopt;
1025   }
1026 
1027 private:
OpNameDecl(const Name & name)1028   explicit OpNameDecl(const Name &name) : Base(name.getLoc(), &name) {}
OpNameDecl(SMRange loc)1029   explicit OpNameDecl(SMRange loc) : Base(loc) {}
1030 };
1031 
1032 //===----------------------------------------------------------------------===//
1033 // PatternDecl
1034 //===----------------------------------------------------------------------===//
1035 
1036 /// This Decl represents a single Pattern.
1037 class PatternDecl : public Node::NodeBase<PatternDecl, Decl> {
1038 public:
1039   static PatternDecl *create(Context &ctx, SMRange location, const Name *name,
1040                              std::optional<uint16_t> benefit,
1041                              bool hasBoundedRecursion,
1042                              const CompoundStmt *body);
1043 
1044   /// Return the benefit of this pattern if specified, or std::nullopt.
getBenefit()1045   std::optional<uint16_t> getBenefit() const { return benefit; }
1046 
1047   /// Return if this pattern has bounded rewrite recursion.
hasBoundedRewriteRecursion()1048   bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
1049 
1050   /// Return the body of this pattern.
getBody()1051   const CompoundStmt *getBody() const { return patternBody; }
1052 
1053   /// Return the root rewrite statement of this pattern.
getRootRewriteStmt()1054   const OpRewriteStmt *getRootRewriteStmt() const {
1055     return cast<OpRewriteStmt>(patternBody->getChildren().back());
1056   }
1057 
1058 private:
PatternDecl(SMRange loc,const Name * name,std::optional<uint16_t> benefit,bool hasBoundedRecursion,const CompoundStmt * body)1059   PatternDecl(SMRange loc, const Name *name, std::optional<uint16_t> benefit,
1060               bool hasBoundedRecursion, const CompoundStmt *body)
1061       : Base(loc, name), benefit(benefit),
1062         hasBoundedRecursion(hasBoundedRecursion), patternBody(body) {}
1063 
1064   /// The benefit of the pattern if it was explicitly specified, std::nullopt
1065   /// otherwise.
1066   std::optional<uint16_t> benefit;
1067 
1068   /// If the pattern has properly bounded rewrite recursion or not.
1069   bool hasBoundedRecursion;
1070 
1071   /// The compound statement representing the body of the pattern.
1072   const CompoundStmt *patternBody;
1073 };
1074 
1075 //===----------------------------------------------------------------------===//
1076 // UserRewriteDecl
1077 //===----------------------------------------------------------------------===//
1078 
1079 /// This decl represents a user defined rewrite. This is either:
1080 ///   * an imported native rewrite
1081 ///     - Similar to an external function declaration. This is a native
1082 ///       rewrite defined externally, and imported into PDLL via a declaration.
1083 ///   * a native rewrite defined in PDLL
1084 ///     - This is a native rewrite, i.e. a rewrite whose implementation is
1085 ///       defined in C++(or potentially some other non-PDLL language). The
1086 ///       implementation of this rewrite is specified as a string code block
1087 ///       in PDLL.
1088 ///   * a PDLL rewrite
1089 ///     - This is a rewrite which is defined using only PDLL constructs.
1090 class UserRewriteDecl final
1091     : public Node::NodeBase<UserRewriteDecl, Decl>,
1092       llvm::TrailingObjects<UserRewriteDecl, VariableDecl *> {
1093 public:
1094   /// Create a native rewrite with the given optional code block.
createNative(Context & ctx,const Name & name,ArrayRef<VariableDecl * > inputs,ArrayRef<VariableDecl * > results,std::optional<StringRef> codeBlock,Type resultType)1095   static UserRewriteDecl *createNative(Context &ctx, const Name &name,
1096                                        ArrayRef<VariableDecl *> inputs,
1097                                        ArrayRef<VariableDecl *> results,
1098                                        std::optional<StringRef> codeBlock,
1099                                        Type resultType) {
1100     return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr,
1101                       resultType);
1102   }
1103 
1104   /// Create a PDLL rewrite with the given body.
createPDLL(Context & ctx,const Name & name,ArrayRef<VariableDecl * > inputs,ArrayRef<VariableDecl * > results,const CompoundStmt * body,Type resultType)1105   static UserRewriteDecl *createPDLL(Context &ctx, const Name &name,
1106                                      ArrayRef<VariableDecl *> inputs,
1107                                      ArrayRef<VariableDecl *> results,
1108                                      const CompoundStmt *body,
1109                                      Type resultType) {
1110     return createImpl(ctx, name, inputs, results, /*codeBlock=*/std::nullopt,
1111                       body, resultType);
1112   }
1113 
1114   /// Return the name of the rewrite.
getName()1115   const Name &getName() const { return *Decl::getName(); }
1116 
1117   /// Return the input arguments of this rewrite.
getInputs()1118   MutableArrayRef<VariableDecl *> getInputs() {
1119     return {getTrailingObjects<VariableDecl *>(), numInputs};
1120   }
getInputs()1121   ArrayRef<VariableDecl *> getInputs() const {
1122     return const_cast<UserRewriteDecl *>(this)->getInputs();
1123   }
1124 
1125   /// Return the explicit results of the rewrite declaration. May be empty,
1126   /// even if the rewrite has results (e.g. in the case of inferred results).
getResults()1127   MutableArrayRef<VariableDecl *> getResults() {
1128     return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
1129   }
getResults()1130   ArrayRef<VariableDecl *> getResults() const {
1131     return const_cast<UserRewriteDecl *>(this)->getResults();
1132   }
1133 
1134   /// Return the optional code block of this rewrite, if this is a native
1135   /// rewrite with a provided implementation.
getCodeBlock()1136   std::optional<StringRef> getCodeBlock() const { return codeBlock; }
1137 
1138   /// Return the body of this rewrite if this rewrite is a PDLL rewrite,
1139   /// otherwise returns nullptr.
getBody()1140   const CompoundStmt *getBody() const { return rewriteBody; }
1141 
1142   /// Return the result type of this rewrite.
getResultType()1143   Type getResultType() const { return resultType; }
1144 
1145   /// Returns true if this rewrite is external.
isExternal()1146   bool isExternal() const { return !rewriteBody && !codeBlock; }
1147 
1148 private:
1149   /// Create either a PDLL rewrite or a native rewrite with the given
1150   /// components.
1151   static UserRewriteDecl *createImpl(Context &ctx, const Name &name,
1152                                      ArrayRef<VariableDecl *> inputs,
1153                                      ArrayRef<VariableDecl *> results,
1154                                      std::optional<StringRef> codeBlock,
1155                                      const CompoundStmt *body, Type resultType);
1156 
UserRewriteDecl(const Name & name,unsigned numInputs,unsigned numResults,std::optional<StringRef> codeBlock,const CompoundStmt * body,Type resultType)1157   UserRewriteDecl(const Name &name, unsigned numInputs, unsigned numResults,
1158                   std::optional<StringRef> codeBlock, const CompoundStmt *body,
1159                   Type resultType)
1160       : Base(name.getLoc(), &name), numInputs(numInputs),
1161         numResults(numResults), codeBlock(codeBlock), rewriteBody(body),
1162         resultType(resultType) {}
1163 
1164   /// The number of inputs to this rewrite.
1165   unsigned numInputs;
1166 
1167   /// The number of explicit results to this rewrite.
1168   unsigned numResults;
1169 
1170   /// The optional code block of this rewrite.
1171   std::optional<StringRef> codeBlock;
1172 
1173   /// The optional body of this rewrite.
1174   const CompoundStmt *rewriteBody;
1175 
1176   /// The result type of the rewrite.
1177   Type resultType;
1178 
1179   /// Allow access to various internals.
1180   friend llvm::TrailingObjects<UserRewriteDecl, VariableDecl *>;
1181 };
1182 
1183 //===----------------------------------------------------------------------===//
1184 // CallableDecl
1185 //===----------------------------------------------------------------------===//
1186 
1187 /// This decl represents a shared interface for all callable decls.
1188 class CallableDecl : public Decl {
1189 public:
1190   /// Return the callable type of this decl.
getCallableType()1191   StringRef getCallableType() const {
1192     if (isa<UserConstraintDecl>(this))
1193       return "constraint";
1194     assert(isa<UserRewriteDecl>(this) && "unknown callable type");
1195     return "rewrite";
1196   }
1197 
1198   /// Return the inputs of this decl.
getInputs()1199   ArrayRef<VariableDecl *> getInputs() const {
1200     if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1201       return cst->getInputs();
1202     return cast<UserRewriteDecl>(this)->getInputs();
1203   }
1204 
1205   /// Return the result type of this decl.
getResultType()1206   Type getResultType() const {
1207     if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1208       return cst->getResultType();
1209     return cast<UserRewriteDecl>(this)->getResultType();
1210   }
1211 
1212   /// Return the explicit results of the declaration. Note that these may be
1213   /// empty, even if the callable has results (e.g. in the case of inferred
1214   /// results).
getResults()1215   ArrayRef<VariableDecl *> getResults() const {
1216     if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1217       return cst->getResults();
1218     return cast<UserRewriteDecl>(this)->getResults();
1219   }
1220 
1221   /// Return the optional code block of this callable, if this is a native
1222   /// callable with a provided implementation.
getCodeBlock()1223   std::optional<StringRef> getCodeBlock() const {
1224     if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1225       return cst->getCodeBlock();
1226     return cast<UserRewriteDecl>(this)->getCodeBlock();
1227   }
1228 
1229   /// Support LLVM type casting facilities.
classof(const Node * decl)1230   static bool classof(const Node *decl) {
1231     return isa<UserConstraintDecl, UserRewriteDecl>(decl);
1232   }
1233 };
1234 
1235 //===----------------------------------------------------------------------===//
1236 // VariableDecl
1237 //===----------------------------------------------------------------------===//
1238 
1239 /// This Decl represents the definition of a PDLL variable.
1240 class VariableDecl final
1241     : public Node::NodeBase<VariableDecl, Decl>,
1242       private llvm::TrailingObjects<VariableDecl, ConstraintRef> {
1243 public:
1244   static VariableDecl *create(Context &ctx, const Name &name, Type type,
1245                               Expr *initExpr,
1246                               ArrayRef<ConstraintRef> constraints);
1247 
1248   /// Return the constraints of this variable.
getConstraints()1249   MutableArrayRef<ConstraintRef> getConstraints() {
1250     return {getTrailingObjects<ConstraintRef>(), numConstraints};
1251   }
getConstraints()1252   ArrayRef<ConstraintRef> getConstraints() const {
1253     return const_cast<VariableDecl *>(this)->getConstraints();
1254   }
1255 
1256   /// Return the initializer expression of this statement, or nullptr if there
1257   /// was no initializer.
getInitExpr()1258   Expr *getInitExpr() const { return initExpr; }
1259 
1260   /// Return the name of the decl.
getName()1261   const Name &getName() const { return *Decl::getName(); }
1262 
1263   /// Return the type of the decl.
getType()1264   Type getType() const { return type; }
1265 
1266 private:
VariableDecl(const Name & name,Type type,Expr * initExpr,unsigned numConstraints)1267   VariableDecl(const Name &name, Type type, Expr *initExpr,
1268                unsigned numConstraints)
1269       : Base(name.getLoc(), &name), type(type), initExpr(initExpr),
1270         numConstraints(numConstraints) {}
1271 
1272   /// The type of the variable.
1273   Type type;
1274 
1275   /// The optional initializer expression of this statement.
1276   Expr *initExpr;
1277 
1278   /// The number of constraints attached to this variable.
1279   unsigned numConstraints;
1280 
1281   /// Allow access to various internals.
1282   friend llvm::TrailingObjects<VariableDecl, ConstraintRef>;
1283 };
1284 
1285 //===----------------------------------------------------------------------===//
1286 // Module
1287 //===----------------------------------------------------------------------===//
1288 
1289 /// This class represents a top-level AST module.
1290 class Module final : public Node::NodeBase<Module, Node>,
1291                      private llvm::TrailingObjects<Module, Decl *> {
1292 public:
1293   static Module *create(Context &ctx, SMLoc loc, ArrayRef<Decl *> children);
1294 
1295   /// Return the children of this module.
getChildren()1296   MutableArrayRef<Decl *> getChildren() {
1297     return {getTrailingObjects<Decl *>(), numChildren};
1298   }
getChildren()1299   ArrayRef<Decl *> getChildren() const {
1300     return const_cast<Module *>(this)->getChildren();
1301   }
1302 
1303 private:
Module(SMLoc loc,unsigned numChildren)1304   Module(SMLoc loc, unsigned numChildren)
1305       : Base(SMRange{loc, loc}), numChildren(numChildren) {}
1306 
1307   /// The number of decls held by this module.
1308   unsigned numChildren;
1309 
1310   /// Allow access to various internals.
1311   friend llvm::TrailingObjects<Module, Decl *>;
1312 };
1313 
1314 //===----------------------------------------------------------------------===//
// Deferred Method Definitions
1316 //===----------------------------------------------------------------------===//
1317 
classof(const Node * node)1318 inline bool Decl::classof(const Node *node) {
1319   return isa<ConstraintDecl, NamedAttributeDecl, OpNameDecl, PatternDecl,
1320              UserRewriteDecl, VariableDecl>(node);
1321 }
1322 
classof(const Node * node)1323 inline bool ConstraintDecl::classof(const Node *node) {
1324   return isa<CoreConstraintDecl, UserConstraintDecl>(node);
1325 }
1326 
classof(const Node * node)1327 inline bool CoreConstraintDecl::classof(const Node *node) {
1328   return isa<AttrConstraintDecl, OpConstraintDecl, TypeConstraintDecl,
1329              TypeRangeConstraintDecl, ValueConstraintDecl,
1330              ValueRangeConstraintDecl>(node);
1331 }
1332 
classof(const Node * node)1333 inline bool Expr::classof(const Node *node) {
1334   return isa<AttributeExpr, CallExpr, DeclRefExpr, MemberAccessExpr,
1335              OperationExpr, RangeExpr, TupleExpr, TypeExpr>(node);
1336 }
1337 
classof(const Node * node)1338 inline bool OpRewriteStmt::classof(const Node *node) {
1339   return isa<EraseStmt, ReplaceStmt, RewriteStmt>(node);
1340 }
1341 
classof(const Node * node)1342 inline bool Stmt::classof(const Node *node) {
1343   return isa<CompoundStmt, LetStmt, OpRewriteStmt, Expr>(node);
1344 }
1345 
1346 } // namespace ast
1347 } // namespace pdll
1348 } // namespace mlir
1349 
1350 #endif // MLIR_TOOLS_PDLL_AST_NODES_H_
1351