xref: /llvm-project/mlir/lib/Tools/PDLL/AST/Nodes.cpp (revision 930916c7f3622870b40138dafcc5f94740404e8c)
1 //===- Nodes.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/AST/Nodes.h"
10 #include "mlir/Tools/PDLL/AST/Context.h"
11 #include "llvm/ADT/SmallPtrSet.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 #include <optional>
14 
15 using namespace mlir;
16 using namespace mlir::pdll::ast;
17 
18 /// Copy a string reference into the context with a null terminator.
copyStringWithNull(Context & ctx,StringRef str)19 static StringRef copyStringWithNull(Context &ctx, StringRef str) {
20   if (str.empty())
21     return str;
22 
23   char *data = ctx.getAllocator().Allocate<char>(str.size() + 1);
24   std::copy(str.begin(), str.end(), data);
25   data[str.size()] = 0;
26   return StringRef(data, str.size());
27 }
28 
29 //===----------------------------------------------------------------------===//
30 // Name
31 //===----------------------------------------------------------------------===//
32 
create(Context & ctx,StringRef name,SMRange location)33 const Name &Name::create(Context &ctx, StringRef name, SMRange location) {
34   return *new (ctx.getAllocator().Allocate<Name>())
35       Name(copyStringWithNull(ctx, name), location);
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Node
40 //===----------------------------------------------------------------------===//
41 
42 namespace {
43 class NodeVisitor {
44 public:
NodeVisitor(function_ref<void (const Node *)> visitFn)45   explicit NodeVisitor(function_ref<void(const Node *)> visitFn)
46       : visitFn(visitFn) {}
47 
visit(const Node * node)48   void visit(const Node *node) {
49     if (!node || !alreadyVisited.insert(node).second)
50       return;
51 
52     visitFn(node);
53     TypeSwitch<const Node *>(node)
54         .Case<
55             // Statements.
56             const CompoundStmt, const EraseStmt, const LetStmt,
57             const ReplaceStmt, const ReturnStmt, const RewriteStmt,
58 
59             // Expressions.
60             const AttributeExpr, const CallExpr, const DeclRefExpr,
61             const MemberAccessExpr, const OperationExpr, const RangeExpr,
62             const TupleExpr, const TypeExpr,
63 
64             // Core Constraint Decls.
65             const AttrConstraintDecl, const OpConstraintDecl,
66             const TypeConstraintDecl, const TypeRangeConstraintDecl,
67             const ValueConstraintDecl, const ValueRangeConstraintDecl,
68 
69             // Decls.
70             const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
71             const UserConstraintDecl, const UserRewriteDecl, const VariableDecl,
72 
73             const Module>(
74             [&](auto derivedNode) { this->visitImpl(derivedNode); })
75         .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
76   }
77 
78 private:
visitImpl(const CompoundStmt * stmt)79   void visitImpl(const CompoundStmt *stmt) {
80     for (const Node *child : stmt->getChildren())
81       visit(child);
82   }
visitImpl(const EraseStmt * stmt)83   void visitImpl(const EraseStmt *stmt) { visit(stmt->getRootOpExpr()); }
visitImpl(const LetStmt * stmt)84   void visitImpl(const LetStmt *stmt) { visit(stmt->getVarDecl()); }
visitImpl(const ReplaceStmt * stmt)85   void visitImpl(const ReplaceStmt *stmt) {
86     visit(stmt->getRootOpExpr());
87     for (const Node *child : stmt->getReplExprs())
88       visit(child);
89   }
visitImpl(const ReturnStmt * stmt)90   void visitImpl(const ReturnStmt *stmt) { visit(stmt->getResultExpr()); }
visitImpl(const RewriteStmt * stmt)91   void visitImpl(const RewriteStmt *stmt) {
92     visit(stmt->getRootOpExpr());
93     visit(stmt->getRewriteBody());
94   }
95 
visitImpl(const AttributeExpr * expr)96   void visitImpl(const AttributeExpr *expr) {}
visitImpl(const CallExpr * expr)97   void visitImpl(const CallExpr *expr) {
98     visit(expr->getCallableExpr());
99     for (const Node *child : expr->getArguments())
100       visit(child);
101   }
visitImpl(const DeclRefExpr * expr)102   void visitImpl(const DeclRefExpr *expr) { visit(expr->getDecl()); }
visitImpl(const MemberAccessExpr * expr)103   void visitImpl(const MemberAccessExpr *expr) { visit(expr->getParentExpr()); }
visitImpl(const OperationExpr * expr)104   void visitImpl(const OperationExpr *expr) {
105     visit(expr->getNameDecl());
106     for (const Node *child : expr->getOperands())
107       visit(child);
108     for (const Node *child : expr->getResultTypes())
109       visit(child);
110     for (const Node *child : expr->getAttributes())
111       visit(child);
112   }
visitImpl(const RangeExpr * expr)113   void visitImpl(const RangeExpr *expr) {
114     for (const Node *child : expr->getElements())
115       visit(child);
116   }
visitImpl(const TupleExpr * expr)117   void visitImpl(const TupleExpr *expr) {
118     for (const Node *child : expr->getElements())
119       visit(child);
120   }
visitImpl(const TypeExpr * expr)121   void visitImpl(const TypeExpr *expr) {}
122 
visitImpl(const AttrConstraintDecl * decl)123   void visitImpl(const AttrConstraintDecl *decl) { visit(decl->getTypeExpr()); }
visitImpl(const OpConstraintDecl * decl)124   void visitImpl(const OpConstraintDecl *decl) { visit(decl->getNameDecl()); }
visitImpl(const TypeConstraintDecl * decl)125   void visitImpl(const TypeConstraintDecl *decl) {}
visitImpl(const TypeRangeConstraintDecl * decl)126   void visitImpl(const TypeRangeConstraintDecl *decl) {}
visitImpl(const ValueConstraintDecl * decl)127   void visitImpl(const ValueConstraintDecl *decl) {
128     visit(decl->getTypeExpr());
129   }
visitImpl(const ValueRangeConstraintDecl * decl)130   void visitImpl(const ValueRangeConstraintDecl *decl) {
131     visit(decl->getTypeExpr());
132   }
133 
visitImpl(const NamedAttributeDecl * decl)134   void visitImpl(const NamedAttributeDecl *decl) { visit(decl->getValue()); }
visitImpl(const OpNameDecl * decl)135   void visitImpl(const OpNameDecl *decl) {}
visitImpl(const PatternDecl * decl)136   void visitImpl(const PatternDecl *decl) { visit(decl->getBody()); }
visitImpl(const UserConstraintDecl * decl)137   void visitImpl(const UserConstraintDecl *decl) {
138     for (const Node *child : decl->getInputs())
139       visit(child);
140     for (const Node *child : decl->getResults())
141       visit(child);
142     visit(decl->getBody());
143   }
visitImpl(const UserRewriteDecl * decl)144   void visitImpl(const UserRewriteDecl *decl) {
145     for (const Node *child : decl->getInputs())
146       visit(child);
147     for (const Node *child : decl->getResults())
148       visit(child);
149     visit(decl->getBody());
150   }
visitImpl(const VariableDecl * decl)151   void visitImpl(const VariableDecl *decl) {
152     visit(decl->getInitExpr());
153     for (const ConstraintRef &child : decl->getConstraints())
154       visit(child.constraint);
155   }
156 
visitImpl(const Module * module)157   void visitImpl(const Module *module) {
158     for (const Node *child : module->getChildren())
159       visit(child);
160   }
161 
162   function_ref<void(const Node *)> visitFn;
163   SmallPtrSet<const Node *, 16> alreadyVisited;
164 };
165 } // namespace
166 
walk(function_ref<void (const Node *)> walkFn) const167 void Node::walk(function_ref<void(const Node *)> walkFn) const {
168   return NodeVisitor(walkFn).visit(this);
169 }
170 
171 //===----------------------------------------------------------------------===//
172 // DeclScope
173 //===----------------------------------------------------------------------===//
174 
add(Decl * decl)175 void DeclScope::add(Decl *decl) {
176   const Name *name = decl->getName();
177   assert(name && "expected a named decl");
178   assert(!decls.count(name->getName()) && "decl with this name already exists");
179   decls.try_emplace(name->getName(), decl);
180 }
181 
lookup(StringRef name)182 Decl *DeclScope::lookup(StringRef name) {
183   if (Decl *decl = decls.lookup(name))
184     return decl;
185   return parent ? parent->lookup(name) : nullptr;
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // CompoundStmt
190 //===----------------------------------------------------------------------===//
191 
create(Context & ctx,SMRange loc,ArrayRef<Stmt * > children)192 CompoundStmt *CompoundStmt::create(Context &ctx, SMRange loc,
193                                    ArrayRef<Stmt *> children) {
194   unsigned allocSize = CompoundStmt::totalSizeToAlloc<Stmt *>(children.size());
195   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CompoundStmt));
196 
197   CompoundStmt *stmt = new (rawData) CompoundStmt(loc, children.size());
198   std::uninitialized_copy(children.begin(), children.end(),
199                           stmt->getChildren().begin());
200   return stmt;
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // LetStmt
205 //===----------------------------------------------------------------------===//
206 
create(Context & ctx,SMRange loc,VariableDecl * varDecl)207 LetStmt *LetStmt::create(Context &ctx, SMRange loc, VariableDecl *varDecl) {
208   return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl);
209 }
210 
211 //===----------------------------------------------------------------------===//
212 // OpRewriteStmt
213 //===----------------------------------------------------------------------===//
214 
215 //===----------------------------------------------------------------------===//
216 // EraseStmt
217 
create(Context & ctx,SMRange loc,Expr * rootOp)218 EraseStmt *EraseStmt::create(Context &ctx, SMRange loc, Expr *rootOp) {
219   return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
220 }
221 
222 //===----------------------------------------------------------------------===//
223 // ReplaceStmt
224 
create(Context & ctx,SMRange loc,Expr * rootOp,ArrayRef<Expr * > replExprs)225 ReplaceStmt *ReplaceStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
226                                  ArrayRef<Expr *> replExprs) {
227   unsigned allocSize = ReplaceStmt::totalSizeToAlloc<Expr *>(replExprs.size());
228   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(ReplaceStmt));
229 
230   ReplaceStmt *stmt = new (rawData) ReplaceStmt(loc, rootOp, replExprs.size());
231   std::uninitialized_copy(replExprs.begin(), replExprs.end(),
232                           stmt->getReplExprs().begin());
233   return stmt;
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // RewriteStmt
238 
create(Context & ctx,SMRange loc,Expr * rootOp,CompoundStmt * rewriteBody)239 RewriteStmt *RewriteStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
240                                  CompoundStmt *rewriteBody) {
241   return new (ctx.getAllocator().Allocate<RewriteStmt>())
242       RewriteStmt(loc, rootOp, rewriteBody);
243 }
244 
245 //===----------------------------------------------------------------------===//
246 // ReturnStmt
247 //===----------------------------------------------------------------------===//
248 
create(Context & ctx,SMRange loc,Expr * resultExpr)249 ReturnStmt *ReturnStmt::create(Context &ctx, SMRange loc, Expr *resultExpr) {
250   return new (ctx.getAllocator().Allocate<ReturnStmt>())
251       ReturnStmt(loc, resultExpr);
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // AttributeExpr
256 //===----------------------------------------------------------------------===//
257 
create(Context & ctx,SMRange loc,StringRef value)258 AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
259                                      StringRef value) {
260   return new (ctx.getAllocator().Allocate<AttributeExpr>())
261       AttributeExpr(ctx, loc, copyStringWithNull(ctx, value));
262 }
263 
264 //===----------------------------------------------------------------------===//
265 // CallExpr
266 //===----------------------------------------------------------------------===//
267 
create(Context & ctx,SMRange loc,Expr * callable,ArrayRef<Expr * > arguments,Type resultType,bool isNegated)268 CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
269                            ArrayRef<Expr *> arguments, Type resultType,
270                            bool isNegated) {
271   unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size());
272   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr));
273 
274   CallExpr *expr = new (rawData)
275       CallExpr(loc, resultType, callable, arguments.size(), isNegated);
276   std::uninitialized_copy(arguments.begin(), arguments.end(),
277                           expr->getArguments().begin());
278   return expr;
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // DeclRefExpr
283 //===----------------------------------------------------------------------===//
284 
create(Context & ctx,SMRange loc,Decl * decl,Type type)285 DeclRefExpr *DeclRefExpr::create(Context &ctx, SMRange loc, Decl *decl,
286                                  Type type) {
287   return new (ctx.getAllocator().Allocate<DeclRefExpr>())
288       DeclRefExpr(loc, decl, type);
289 }
290 
291 //===----------------------------------------------------------------------===//
292 // MemberAccessExpr
293 //===----------------------------------------------------------------------===//
294 
create(Context & ctx,SMRange loc,const Expr * parentExpr,StringRef memberName,Type type)295 MemberAccessExpr *MemberAccessExpr::create(Context &ctx, SMRange loc,
296                                            const Expr *parentExpr,
297                                            StringRef memberName, Type type) {
298   return new (ctx.getAllocator().Allocate<MemberAccessExpr>()) MemberAccessExpr(
299       loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // OperationExpr
304 //===----------------------------------------------------------------------===//
305 
306 OperationExpr *
create(Context & ctx,SMRange loc,const ods::Operation * odsOp,const OpNameDecl * name,ArrayRef<Expr * > operands,ArrayRef<Expr * > resultTypes,ArrayRef<NamedAttributeDecl * > attributes)307 OperationExpr::create(Context &ctx, SMRange loc, const ods::Operation *odsOp,
308                       const OpNameDecl *name, ArrayRef<Expr *> operands,
309                       ArrayRef<Expr *> resultTypes,
310                       ArrayRef<NamedAttributeDecl *> attributes) {
311   unsigned allocSize =
312       OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>(
313           operands.size() + resultTypes.size(), attributes.size());
314   void *rawData =
315       ctx.getAllocator().Allocate(allocSize, alignof(OperationExpr));
316 
317   Type resultType = OperationType::get(ctx, name->getName(), odsOp);
318   OperationExpr *opExpr = new (rawData)
319       OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(),
320                     attributes.size(), name->getLoc());
321   std::uninitialized_copy(operands.begin(), operands.end(),
322                           opExpr->getOperands().begin());
323   std::uninitialized_copy(resultTypes.begin(), resultTypes.end(),
324                           opExpr->getResultTypes().begin());
325   std::uninitialized_copy(attributes.begin(), attributes.end(),
326                           opExpr->getAttributes().begin());
327   return opExpr;
328 }
329 
getName() const330 std::optional<StringRef> OperationExpr::getName() const {
331   return getNameDecl()->getName();
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // RangeExpr
336 //===----------------------------------------------------------------------===//
337 
create(Context & ctx,SMRange loc,ArrayRef<Expr * > elements,RangeType type)338 RangeExpr *RangeExpr::create(Context &ctx, SMRange loc,
339                              ArrayRef<Expr *> elements, RangeType type) {
340   unsigned allocSize = RangeExpr::totalSizeToAlloc<Expr *>(elements.size());
341   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(TupleExpr));
342 
343   RangeExpr *expr = new (rawData) RangeExpr(loc, type, elements.size());
344   std::uninitialized_copy(elements.begin(), elements.end(),
345                           expr->getElements().begin());
346   return expr;
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // TupleExpr
351 //===----------------------------------------------------------------------===//
352 
create(Context & ctx,SMRange loc,ArrayRef<Expr * > elements,ArrayRef<StringRef> names)353 TupleExpr *TupleExpr::create(Context &ctx, SMRange loc,
354                              ArrayRef<Expr *> elements,
355                              ArrayRef<StringRef> names) {
356   unsigned allocSize = TupleExpr::totalSizeToAlloc<Expr *>(elements.size());
357   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(TupleExpr));
358 
359   auto elementTypes = llvm::map_range(
360       elements, [](const Expr *expr) { return expr->getType(); });
361   TupleType type = TupleType::get(ctx, llvm::to_vector(elementTypes), names);
362 
363   TupleExpr *expr = new (rawData) TupleExpr(loc, type);
364   std::uninitialized_copy(elements.begin(), elements.end(),
365                           expr->getElements().begin());
366   return expr;
367 }
368 
369 //===----------------------------------------------------------------------===//
370 // TypeExpr
371 //===----------------------------------------------------------------------===//
372 
create(Context & ctx,SMRange loc,StringRef value)373 TypeExpr *TypeExpr::create(Context &ctx, SMRange loc, StringRef value) {
374   return new (ctx.getAllocator().Allocate<TypeExpr>())
375       TypeExpr(ctx, loc, copyStringWithNull(ctx, value));
376 }
377 
378 //===----------------------------------------------------------------------===//
379 // Decl
380 //===----------------------------------------------------------------------===//
381 
setDocComment(Context & ctx,StringRef comment)382 void Decl::setDocComment(Context &ctx, StringRef comment) {
383   docComment = comment.copy(ctx.getAllocator());
384 }
385 
386 //===----------------------------------------------------------------------===//
387 // AttrConstraintDecl
388 //===----------------------------------------------------------------------===//
389 
create(Context & ctx,SMRange loc,Expr * typeExpr)390 AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, SMRange loc,
391                                                Expr *typeExpr) {
392   return new (ctx.getAllocator().Allocate<AttrConstraintDecl>())
393       AttrConstraintDecl(loc, typeExpr);
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // OpConstraintDecl
398 //===----------------------------------------------------------------------===//
399 
create(Context & ctx,SMRange loc,const OpNameDecl * nameDecl)400 OpConstraintDecl *OpConstraintDecl::create(Context &ctx, SMRange loc,
401                                            const OpNameDecl *nameDecl) {
402   if (!nameDecl)
403     nameDecl = OpNameDecl::create(ctx, SMRange());
404 
405   return new (ctx.getAllocator().Allocate<OpConstraintDecl>())
406       OpConstraintDecl(loc, nameDecl);
407 }
408 
getName() const409 std::optional<StringRef> OpConstraintDecl::getName() const {
410   return getNameDecl()->getName();
411 }
412 
413 //===----------------------------------------------------------------------===//
414 // TypeConstraintDecl
415 //===----------------------------------------------------------------------===//
416 
create(Context & ctx,SMRange loc)417 TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx, SMRange loc) {
418   return new (ctx.getAllocator().Allocate<TypeConstraintDecl>())
419       TypeConstraintDecl(loc);
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // TypeRangeConstraintDecl
424 //===----------------------------------------------------------------------===//
425 
create(Context & ctx,SMRange loc)426 TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx,
427                                                          SMRange loc) {
428   return new (ctx.getAllocator().Allocate<TypeRangeConstraintDecl>())
429       TypeRangeConstraintDecl(loc);
430 }
431 
432 //===----------------------------------------------------------------------===//
433 // ValueConstraintDecl
434 //===----------------------------------------------------------------------===//
435 
create(Context & ctx,SMRange loc,Expr * typeExpr)436 ValueConstraintDecl *ValueConstraintDecl::create(Context &ctx, SMRange loc,
437                                                  Expr *typeExpr) {
438   return new (ctx.getAllocator().Allocate<ValueConstraintDecl>())
439       ValueConstraintDecl(loc, typeExpr);
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // ValueRangeConstraintDecl
444 //===----------------------------------------------------------------------===//
445 
446 ValueRangeConstraintDecl *
create(Context & ctx,SMRange loc,Expr * typeExpr)447 ValueRangeConstraintDecl::create(Context &ctx, SMRange loc, Expr *typeExpr) {
448   return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>())
449       ValueRangeConstraintDecl(loc, typeExpr);
450 }
451 
452 //===----------------------------------------------------------------------===//
453 // UserConstraintDecl
454 //===----------------------------------------------------------------------===//
455 
456 std::optional<StringRef>
getNativeInputType(unsigned index) const457 UserConstraintDecl::getNativeInputType(unsigned index) const {
458   return hasNativeInputTypes ? getTrailingObjects<StringRef>()[index]
459                              : std::optional<StringRef>();
460 }
461 
createImpl(Context & ctx,const Name & name,ArrayRef<VariableDecl * > inputs,ArrayRef<StringRef> nativeInputTypes,ArrayRef<VariableDecl * > results,std::optional<StringRef> codeBlock,const CompoundStmt * body,Type resultType)462 UserConstraintDecl *UserConstraintDecl::createImpl(
463     Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
464     ArrayRef<StringRef> nativeInputTypes, ArrayRef<VariableDecl *> results,
465     std::optional<StringRef> codeBlock, const CompoundStmt *body,
466     Type resultType) {
467   bool hasNativeInputTypes = !nativeInputTypes.empty();
468   assert(!hasNativeInputTypes || nativeInputTypes.size() == inputs.size());
469 
470   unsigned allocSize =
471       UserConstraintDecl::totalSizeToAlloc<VariableDecl *, StringRef>(
472           inputs.size() + results.size(),
473           hasNativeInputTypes ? inputs.size() : 0);
474   void *rawData =
475       ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl));
476   if (codeBlock)
477     codeBlock = codeBlock->copy(ctx.getAllocator());
478 
479   UserConstraintDecl *decl = new (rawData)
480       UserConstraintDecl(name, inputs.size(), hasNativeInputTypes,
481                          results.size(), codeBlock, body, resultType);
482   std::uninitialized_copy(inputs.begin(), inputs.end(),
483                           decl->getInputs().begin());
484   std::uninitialized_copy(results.begin(), results.end(),
485                           decl->getResults().begin());
486   if (hasNativeInputTypes) {
487     StringRef *nativeInputTypesPtr = decl->getTrailingObjects<StringRef>();
488     for (unsigned i = 0, e = inputs.size(); i < e; ++i)
489       nativeInputTypesPtr[i] = nativeInputTypes[i].copy(ctx.getAllocator());
490   }
491 
492   return decl;
493 }
494 
495 //===----------------------------------------------------------------------===//
496 // NamedAttributeDecl
497 //===----------------------------------------------------------------------===//
498 
create(Context & ctx,const Name & name,Expr * value)499 NamedAttributeDecl *NamedAttributeDecl::create(Context &ctx, const Name &name,
500                                                Expr *value) {
501   return new (ctx.getAllocator().Allocate<NamedAttributeDecl>())
502       NamedAttributeDecl(name, value);
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // OpNameDecl
507 //===----------------------------------------------------------------------===//
508 
create(Context & ctx,const Name & name)509 OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) {
510   return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(name);
511 }
create(Context & ctx,SMRange loc)512 OpNameDecl *OpNameDecl::create(Context &ctx, SMRange loc) {
513   return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(loc);
514 }
515 
516 //===----------------------------------------------------------------------===//
517 // PatternDecl
518 //===----------------------------------------------------------------------===//
519 
create(Context & ctx,SMRange loc,const Name * name,std::optional<uint16_t> benefit,bool hasBoundedRecursion,const CompoundStmt * body)520 PatternDecl *PatternDecl::create(Context &ctx, SMRange loc, const Name *name,
521                                  std::optional<uint16_t> benefit,
522                                  bool hasBoundedRecursion,
523                                  const CompoundStmt *body) {
524   return new (ctx.getAllocator().Allocate<PatternDecl>())
525       PatternDecl(loc, name, benefit, hasBoundedRecursion, body);
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // UserRewriteDecl
530 //===----------------------------------------------------------------------===//
531 
createImpl(Context & ctx,const Name & name,ArrayRef<VariableDecl * > inputs,ArrayRef<VariableDecl * > results,std::optional<StringRef> codeBlock,const CompoundStmt * body,Type resultType)532 UserRewriteDecl *UserRewriteDecl::createImpl(Context &ctx, const Name &name,
533                                              ArrayRef<VariableDecl *> inputs,
534                                              ArrayRef<VariableDecl *> results,
535                                              std::optional<StringRef> codeBlock,
536                                              const CompoundStmt *body,
537                                              Type resultType) {
538   unsigned allocSize = UserRewriteDecl::totalSizeToAlloc<VariableDecl *>(
539       inputs.size() + results.size());
540   void *rawData =
541       ctx.getAllocator().Allocate(allocSize, alignof(UserRewriteDecl));
542   if (codeBlock)
543     codeBlock = codeBlock->copy(ctx.getAllocator());
544 
545   UserRewriteDecl *decl = new (rawData) UserRewriteDecl(
546       name, inputs.size(), results.size(), codeBlock, body, resultType);
547   std::uninitialized_copy(inputs.begin(), inputs.end(),
548                           decl->getInputs().begin());
549   std::uninitialized_copy(results.begin(), results.end(),
550                           decl->getResults().begin());
551   return decl;
552 }
553 
554 //===----------------------------------------------------------------------===//
555 // VariableDecl
556 //===----------------------------------------------------------------------===//
557 
create(Context & ctx,const Name & name,Type type,Expr * initExpr,ArrayRef<ConstraintRef> constraints)558 VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type,
559                                    Expr *initExpr,
560                                    ArrayRef<ConstraintRef> constraints) {
561   unsigned allocSize =
562       VariableDecl::totalSizeToAlloc<ConstraintRef>(constraints.size());
563   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(VariableDecl));
564 
565   VariableDecl *varDecl =
566       new (rawData) VariableDecl(name, type, initExpr, constraints.size());
567   std::uninitialized_copy(constraints.begin(), constraints.end(),
568                           varDecl->getConstraints().begin());
569   return varDecl;
570 }
571 
572 //===----------------------------------------------------------------------===//
573 // Module
574 //===----------------------------------------------------------------------===//
575 
create(Context & ctx,SMLoc loc,ArrayRef<Decl * > children)576 Module *Module::create(Context &ctx, SMLoc loc, ArrayRef<Decl *> children) {
577   unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size());
578   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module));
579 
580   Module *module = new (rawData) Module(loc, children.size());
581   std::uninitialized_copy(children.begin(), children.end(),
582                           module->getChildren().begin());
583   return module;
584 }
585