1 //===- Nodes.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Tools/PDLL/AST/Nodes.h"
10 #include "mlir/Tools/PDLL/AST/Context.h"
11 #include "llvm/ADT/SmallPtrSet.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 #include <optional>
14
15 using namespace mlir;
16 using namespace mlir::pdll::ast;
17
18 /// Copy a string reference into the context with a null terminator.
copyStringWithNull(Context & ctx,StringRef str)19 static StringRef copyStringWithNull(Context &ctx, StringRef str) {
20 if (str.empty())
21 return str;
22
23 char *data = ctx.getAllocator().Allocate<char>(str.size() + 1);
24 std::copy(str.begin(), str.end(), data);
25 data[str.size()] = 0;
26 return StringRef(data, str.size());
27 }
28
29 //===----------------------------------------------------------------------===//
30 // Name
31 //===----------------------------------------------------------------------===//
32
create(Context & ctx,StringRef name,SMRange location)33 const Name &Name::create(Context &ctx, StringRef name, SMRange location) {
34 return *new (ctx.getAllocator().Allocate<Name>())
35 Name(copyStringWithNull(ctx, name), location);
36 }
37
38 //===----------------------------------------------------------------------===//
39 // Node
40 //===----------------------------------------------------------------------===//
41
42 namespace {
43 class NodeVisitor {
44 public:
NodeVisitor(function_ref<void (const Node *)> visitFn)45 explicit NodeVisitor(function_ref<void(const Node *)> visitFn)
46 : visitFn(visitFn) {}
47
visit(const Node * node)48 void visit(const Node *node) {
49 if (!node || !alreadyVisited.insert(node).second)
50 return;
51
52 visitFn(node);
53 TypeSwitch<const Node *>(node)
54 .Case<
55 // Statements.
56 const CompoundStmt, const EraseStmt, const LetStmt,
57 const ReplaceStmt, const ReturnStmt, const RewriteStmt,
58
59 // Expressions.
60 const AttributeExpr, const CallExpr, const DeclRefExpr,
61 const MemberAccessExpr, const OperationExpr, const RangeExpr,
62 const TupleExpr, const TypeExpr,
63
64 // Core Constraint Decls.
65 const AttrConstraintDecl, const OpConstraintDecl,
66 const TypeConstraintDecl, const TypeRangeConstraintDecl,
67 const ValueConstraintDecl, const ValueRangeConstraintDecl,
68
69 // Decls.
70 const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
71 const UserConstraintDecl, const UserRewriteDecl, const VariableDecl,
72
73 const Module>(
74 [&](auto derivedNode) { this->visitImpl(derivedNode); })
75 .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
76 }
77
78 private:
visitImpl(const CompoundStmt * stmt)79 void visitImpl(const CompoundStmt *stmt) {
80 for (const Node *child : stmt->getChildren())
81 visit(child);
82 }
visitImpl(const EraseStmt * stmt)83 void visitImpl(const EraseStmt *stmt) { visit(stmt->getRootOpExpr()); }
visitImpl(const LetStmt * stmt)84 void visitImpl(const LetStmt *stmt) { visit(stmt->getVarDecl()); }
visitImpl(const ReplaceStmt * stmt)85 void visitImpl(const ReplaceStmt *stmt) {
86 visit(stmt->getRootOpExpr());
87 for (const Node *child : stmt->getReplExprs())
88 visit(child);
89 }
visitImpl(const ReturnStmt * stmt)90 void visitImpl(const ReturnStmt *stmt) { visit(stmt->getResultExpr()); }
visitImpl(const RewriteStmt * stmt)91 void visitImpl(const RewriteStmt *stmt) {
92 visit(stmt->getRootOpExpr());
93 visit(stmt->getRewriteBody());
94 }
95
visitImpl(const AttributeExpr * expr)96 void visitImpl(const AttributeExpr *expr) {}
visitImpl(const CallExpr * expr)97 void visitImpl(const CallExpr *expr) {
98 visit(expr->getCallableExpr());
99 for (const Node *child : expr->getArguments())
100 visit(child);
101 }
visitImpl(const DeclRefExpr * expr)102 void visitImpl(const DeclRefExpr *expr) { visit(expr->getDecl()); }
visitImpl(const MemberAccessExpr * expr)103 void visitImpl(const MemberAccessExpr *expr) { visit(expr->getParentExpr()); }
visitImpl(const OperationExpr * expr)104 void visitImpl(const OperationExpr *expr) {
105 visit(expr->getNameDecl());
106 for (const Node *child : expr->getOperands())
107 visit(child);
108 for (const Node *child : expr->getResultTypes())
109 visit(child);
110 for (const Node *child : expr->getAttributes())
111 visit(child);
112 }
visitImpl(const RangeExpr * expr)113 void visitImpl(const RangeExpr *expr) {
114 for (const Node *child : expr->getElements())
115 visit(child);
116 }
visitImpl(const TupleExpr * expr)117 void visitImpl(const TupleExpr *expr) {
118 for (const Node *child : expr->getElements())
119 visit(child);
120 }
visitImpl(const TypeExpr * expr)121 void visitImpl(const TypeExpr *expr) {}
122
visitImpl(const AttrConstraintDecl * decl)123 void visitImpl(const AttrConstraintDecl *decl) { visit(decl->getTypeExpr()); }
visitImpl(const OpConstraintDecl * decl)124 void visitImpl(const OpConstraintDecl *decl) { visit(decl->getNameDecl()); }
visitImpl(const TypeConstraintDecl * decl)125 void visitImpl(const TypeConstraintDecl *decl) {}
visitImpl(const TypeRangeConstraintDecl * decl)126 void visitImpl(const TypeRangeConstraintDecl *decl) {}
visitImpl(const ValueConstraintDecl * decl)127 void visitImpl(const ValueConstraintDecl *decl) {
128 visit(decl->getTypeExpr());
129 }
visitImpl(const ValueRangeConstraintDecl * decl)130 void visitImpl(const ValueRangeConstraintDecl *decl) {
131 visit(decl->getTypeExpr());
132 }
133
visitImpl(const NamedAttributeDecl * decl)134 void visitImpl(const NamedAttributeDecl *decl) { visit(decl->getValue()); }
visitImpl(const OpNameDecl * decl)135 void visitImpl(const OpNameDecl *decl) {}
visitImpl(const PatternDecl * decl)136 void visitImpl(const PatternDecl *decl) { visit(decl->getBody()); }
visitImpl(const UserConstraintDecl * decl)137 void visitImpl(const UserConstraintDecl *decl) {
138 for (const Node *child : decl->getInputs())
139 visit(child);
140 for (const Node *child : decl->getResults())
141 visit(child);
142 visit(decl->getBody());
143 }
visitImpl(const UserRewriteDecl * decl)144 void visitImpl(const UserRewriteDecl *decl) {
145 for (const Node *child : decl->getInputs())
146 visit(child);
147 for (const Node *child : decl->getResults())
148 visit(child);
149 visit(decl->getBody());
150 }
visitImpl(const VariableDecl * decl)151 void visitImpl(const VariableDecl *decl) {
152 visit(decl->getInitExpr());
153 for (const ConstraintRef &child : decl->getConstraints())
154 visit(child.constraint);
155 }
156
visitImpl(const Module * module)157 void visitImpl(const Module *module) {
158 for (const Node *child : module->getChildren())
159 visit(child);
160 }
161
162 function_ref<void(const Node *)> visitFn;
163 SmallPtrSet<const Node *, 16> alreadyVisited;
164 };
165 } // namespace
166
walk(function_ref<void (const Node *)> walkFn) const167 void Node::walk(function_ref<void(const Node *)> walkFn) const {
168 return NodeVisitor(walkFn).visit(this);
169 }
170
171 //===----------------------------------------------------------------------===//
172 // DeclScope
173 //===----------------------------------------------------------------------===//
174
add(Decl * decl)175 void DeclScope::add(Decl *decl) {
176 const Name *name = decl->getName();
177 assert(name && "expected a named decl");
178 assert(!decls.count(name->getName()) && "decl with this name already exists");
179 decls.try_emplace(name->getName(), decl);
180 }
181
lookup(StringRef name)182 Decl *DeclScope::lookup(StringRef name) {
183 if (Decl *decl = decls.lookup(name))
184 return decl;
185 return parent ? parent->lookup(name) : nullptr;
186 }
187
188 //===----------------------------------------------------------------------===//
189 // CompoundStmt
190 //===----------------------------------------------------------------------===//
191
create(Context & ctx,SMRange loc,ArrayRef<Stmt * > children)192 CompoundStmt *CompoundStmt::create(Context &ctx, SMRange loc,
193 ArrayRef<Stmt *> children) {
194 unsigned allocSize = CompoundStmt::totalSizeToAlloc<Stmt *>(children.size());
195 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CompoundStmt));
196
197 CompoundStmt *stmt = new (rawData) CompoundStmt(loc, children.size());
198 std::uninitialized_copy(children.begin(), children.end(),
199 stmt->getChildren().begin());
200 return stmt;
201 }
202
203 //===----------------------------------------------------------------------===//
204 // LetStmt
205 //===----------------------------------------------------------------------===//
206
create(Context & ctx,SMRange loc,VariableDecl * varDecl)207 LetStmt *LetStmt::create(Context &ctx, SMRange loc, VariableDecl *varDecl) {
208 return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl);
209 }
210
211 //===----------------------------------------------------------------------===//
212 // OpRewriteStmt
213 //===----------------------------------------------------------------------===//
214
215 //===----------------------------------------------------------------------===//
216 // EraseStmt
217
create(Context & ctx,SMRange loc,Expr * rootOp)218 EraseStmt *EraseStmt::create(Context &ctx, SMRange loc, Expr *rootOp) {
219 return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
220 }
221
222 //===----------------------------------------------------------------------===//
223 // ReplaceStmt
224
create(Context & ctx,SMRange loc,Expr * rootOp,ArrayRef<Expr * > replExprs)225 ReplaceStmt *ReplaceStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
226 ArrayRef<Expr *> replExprs) {
227 unsigned allocSize = ReplaceStmt::totalSizeToAlloc<Expr *>(replExprs.size());
228 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(ReplaceStmt));
229
230 ReplaceStmt *stmt = new (rawData) ReplaceStmt(loc, rootOp, replExprs.size());
231 std::uninitialized_copy(replExprs.begin(), replExprs.end(),
232 stmt->getReplExprs().begin());
233 return stmt;
234 }
235
236 //===----------------------------------------------------------------------===//
237 // RewriteStmt
238
create(Context & ctx,SMRange loc,Expr * rootOp,CompoundStmt * rewriteBody)239 RewriteStmt *RewriteStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
240 CompoundStmt *rewriteBody) {
241 return new (ctx.getAllocator().Allocate<RewriteStmt>())
242 RewriteStmt(loc, rootOp, rewriteBody);
243 }
244
245 //===----------------------------------------------------------------------===//
246 // ReturnStmt
247 //===----------------------------------------------------------------------===//
248
create(Context & ctx,SMRange loc,Expr * resultExpr)249 ReturnStmt *ReturnStmt::create(Context &ctx, SMRange loc, Expr *resultExpr) {
250 return new (ctx.getAllocator().Allocate<ReturnStmt>())
251 ReturnStmt(loc, resultExpr);
252 }
253
254 //===----------------------------------------------------------------------===//
255 // AttributeExpr
256 //===----------------------------------------------------------------------===//
257
create(Context & ctx,SMRange loc,StringRef value)258 AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
259 StringRef value) {
260 return new (ctx.getAllocator().Allocate<AttributeExpr>())
261 AttributeExpr(ctx, loc, copyStringWithNull(ctx, value));
262 }
263
264 //===----------------------------------------------------------------------===//
265 // CallExpr
266 //===----------------------------------------------------------------------===//
267
create(Context & ctx,SMRange loc,Expr * callable,ArrayRef<Expr * > arguments,Type resultType,bool isNegated)268 CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
269 ArrayRef<Expr *> arguments, Type resultType,
270 bool isNegated) {
271 unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size());
272 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr));
273
274 CallExpr *expr = new (rawData)
275 CallExpr(loc, resultType, callable, arguments.size(), isNegated);
276 std::uninitialized_copy(arguments.begin(), arguments.end(),
277 expr->getArguments().begin());
278 return expr;
279 }
280
281 //===----------------------------------------------------------------------===//
282 // DeclRefExpr
283 //===----------------------------------------------------------------------===//
284
create(Context & ctx,SMRange loc,Decl * decl,Type type)285 DeclRefExpr *DeclRefExpr::create(Context &ctx, SMRange loc, Decl *decl,
286 Type type) {
287 return new (ctx.getAllocator().Allocate<DeclRefExpr>())
288 DeclRefExpr(loc, decl, type);
289 }
290
291 //===----------------------------------------------------------------------===//
292 // MemberAccessExpr
293 //===----------------------------------------------------------------------===//
294
create(Context & ctx,SMRange loc,const Expr * parentExpr,StringRef memberName,Type type)295 MemberAccessExpr *MemberAccessExpr::create(Context &ctx, SMRange loc,
296 const Expr *parentExpr,
297 StringRef memberName, Type type) {
298 return new (ctx.getAllocator().Allocate<MemberAccessExpr>()) MemberAccessExpr(
299 loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
300 }
301
302 //===----------------------------------------------------------------------===//
303 // OperationExpr
304 //===----------------------------------------------------------------------===//
305
306 OperationExpr *
create(Context & ctx,SMRange loc,const ods::Operation * odsOp,const OpNameDecl * name,ArrayRef<Expr * > operands,ArrayRef<Expr * > resultTypes,ArrayRef<NamedAttributeDecl * > attributes)307 OperationExpr::create(Context &ctx, SMRange loc, const ods::Operation *odsOp,
308 const OpNameDecl *name, ArrayRef<Expr *> operands,
309 ArrayRef<Expr *> resultTypes,
310 ArrayRef<NamedAttributeDecl *> attributes) {
311 unsigned allocSize =
312 OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>(
313 operands.size() + resultTypes.size(), attributes.size());
314 void *rawData =
315 ctx.getAllocator().Allocate(allocSize, alignof(OperationExpr));
316
317 Type resultType = OperationType::get(ctx, name->getName(), odsOp);
318 OperationExpr *opExpr = new (rawData)
319 OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(),
320 attributes.size(), name->getLoc());
321 std::uninitialized_copy(operands.begin(), operands.end(),
322 opExpr->getOperands().begin());
323 std::uninitialized_copy(resultTypes.begin(), resultTypes.end(),
324 opExpr->getResultTypes().begin());
325 std::uninitialized_copy(attributes.begin(), attributes.end(),
326 opExpr->getAttributes().begin());
327 return opExpr;
328 }
329
getName() const330 std::optional<StringRef> OperationExpr::getName() const {
331 return getNameDecl()->getName();
332 }
333
334 //===----------------------------------------------------------------------===//
335 // RangeExpr
336 //===----------------------------------------------------------------------===//
337
create(Context & ctx,SMRange loc,ArrayRef<Expr * > elements,RangeType type)338 RangeExpr *RangeExpr::create(Context &ctx, SMRange loc,
339 ArrayRef<Expr *> elements, RangeType type) {
340 unsigned allocSize = RangeExpr::totalSizeToAlloc<Expr *>(elements.size());
341 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(TupleExpr));
342
343 RangeExpr *expr = new (rawData) RangeExpr(loc, type, elements.size());
344 std::uninitialized_copy(elements.begin(), elements.end(),
345 expr->getElements().begin());
346 return expr;
347 }
348
349 //===----------------------------------------------------------------------===//
350 // TupleExpr
351 //===----------------------------------------------------------------------===//
352
create(Context & ctx,SMRange loc,ArrayRef<Expr * > elements,ArrayRef<StringRef> names)353 TupleExpr *TupleExpr::create(Context &ctx, SMRange loc,
354 ArrayRef<Expr *> elements,
355 ArrayRef<StringRef> names) {
356 unsigned allocSize = TupleExpr::totalSizeToAlloc<Expr *>(elements.size());
357 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(TupleExpr));
358
359 auto elementTypes = llvm::map_range(
360 elements, [](const Expr *expr) { return expr->getType(); });
361 TupleType type = TupleType::get(ctx, llvm::to_vector(elementTypes), names);
362
363 TupleExpr *expr = new (rawData) TupleExpr(loc, type);
364 std::uninitialized_copy(elements.begin(), elements.end(),
365 expr->getElements().begin());
366 return expr;
367 }
368
369 //===----------------------------------------------------------------------===//
370 // TypeExpr
371 //===----------------------------------------------------------------------===//
372
create(Context & ctx,SMRange loc,StringRef value)373 TypeExpr *TypeExpr::create(Context &ctx, SMRange loc, StringRef value) {
374 return new (ctx.getAllocator().Allocate<TypeExpr>())
375 TypeExpr(ctx, loc, copyStringWithNull(ctx, value));
376 }
377
378 //===----------------------------------------------------------------------===//
379 // Decl
380 //===----------------------------------------------------------------------===//
381
setDocComment(Context & ctx,StringRef comment)382 void Decl::setDocComment(Context &ctx, StringRef comment) {
383 docComment = comment.copy(ctx.getAllocator());
384 }
385
386 //===----------------------------------------------------------------------===//
387 // AttrConstraintDecl
388 //===----------------------------------------------------------------------===//
389
create(Context & ctx,SMRange loc,Expr * typeExpr)390 AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, SMRange loc,
391 Expr *typeExpr) {
392 return new (ctx.getAllocator().Allocate<AttrConstraintDecl>())
393 AttrConstraintDecl(loc, typeExpr);
394 }
395
396 //===----------------------------------------------------------------------===//
397 // OpConstraintDecl
398 //===----------------------------------------------------------------------===//
399
create(Context & ctx,SMRange loc,const OpNameDecl * nameDecl)400 OpConstraintDecl *OpConstraintDecl::create(Context &ctx, SMRange loc,
401 const OpNameDecl *nameDecl) {
402 if (!nameDecl)
403 nameDecl = OpNameDecl::create(ctx, SMRange());
404
405 return new (ctx.getAllocator().Allocate<OpConstraintDecl>())
406 OpConstraintDecl(loc, nameDecl);
407 }
408
getName() const409 std::optional<StringRef> OpConstraintDecl::getName() const {
410 return getNameDecl()->getName();
411 }
412
413 //===----------------------------------------------------------------------===//
414 // TypeConstraintDecl
415 //===----------------------------------------------------------------------===//
416
create(Context & ctx,SMRange loc)417 TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx, SMRange loc) {
418 return new (ctx.getAllocator().Allocate<TypeConstraintDecl>())
419 TypeConstraintDecl(loc);
420 }
421
422 //===----------------------------------------------------------------------===//
423 // TypeRangeConstraintDecl
424 //===----------------------------------------------------------------------===//
425
create(Context & ctx,SMRange loc)426 TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx,
427 SMRange loc) {
428 return new (ctx.getAllocator().Allocate<TypeRangeConstraintDecl>())
429 TypeRangeConstraintDecl(loc);
430 }
431
432 //===----------------------------------------------------------------------===//
433 // ValueConstraintDecl
434 //===----------------------------------------------------------------------===//
435
create(Context & ctx,SMRange loc,Expr * typeExpr)436 ValueConstraintDecl *ValueConstraintDecl::create(Context &ctx, SMRange loc,
437 Expr *typeExpr) {
438 return new (ctx.getAllocator().Allocate<ValueConstraintDecl>())
439 ValueConstraintDecl(loc, typeExpr);
440 }
441
442 //===----------------------------------------------------------------------===//
443 // ValueRangeConstraintDecl
444 //===----------------------------------------------------------------------===//
445
446 ValueRangeConstraintDecl *
create(Context & ctx,SMRange loc,Expr * typeExpr)447 ValueRangeConstraintDecl::create(Context &ctx, SMRange loc, Expr *typeExpr) {
448 return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>())
449 ValueRangeConstraintDecl(loc, typeExpr);
450 }
451
452 //===----------------------------------------------------------------------===//
453 // UserConstraintDecl
454 //===----------------------------------------------------------------------===//
455
456 std::optional<StringRef>
getNativeInputType(unsigned index) const457 UserConstraintDecl::getNativeInputType(unsigned index) const {
458 return hasNativeInputTypes ? getTrailingObjects<StringRef>()[index]
459 : std::optional<StringRef>();
460 }
461
createImpl(Context & ctx,const Name & name,ArrayRef<VariableDecl * > inputs,ArrayRef<StringRef> nativeInputTypes,ArrayRef<VariableDecl * > results,std::optional<StringRef> codeBlock,const CompoundStmt * body,Type resultType)462 UserConstraintDecl *UserConstraintDecl::createImpl(
463 Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
464 ArrayRef<StringRef> nativeInputTypes, ArrayRef<VariableDecl *> results,
465 std::optional<StringRef> codeBlock, const CompoundStmt *body,
466 Type resultType) {
467 bool hasNativeInputTypes = !nativeInputTypes.empty();
468 assert(!hasNativeInputTypes || nativeInputTypes.size() == inputs.size());
469
470 unsigned allocSize =
471 UserConstraintDecl::totalSizeToAlloc<VariableDecl *, StringRef>(
472 inputs.size() + results.size(),
473 hasNativeInputTypes ? inputs.size() : 0);
474 void *rawData =
475 ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl));
476 if (codeBlock)
477 codeBlock = codeBlock->copy(ctx.getAllocator());
478
479 UserConstraintDecl *decl = new (rawData)
480 UserConstraintDecl(name, inputs.size(), hasNativeInputTypes,
481 results.size(), codeBlock, body, resultType);
482 std::uninitialized_copy(inputs.begin(), inputs.end(),
483 decl->getInputs().begin());
484 std::uninitialized_copy(results.begin(), results.end(),
485 decl->getResults().begin());
486 if (hasNativeInputTypes) {
487 StringRef *nativeInputTypesPtr = decl->getTrailingObjects<StringRef>();
488 for (unsigned i = 0, e = inputs.size(); i < e; ++i)
489 nativeInputTypesPtr[i] = nativeInputTypes[i].copy(ctx.getAllocator());
490 }
491
492 return decl;
493 }
494
495 //===----------------------------------------------------------------------===//
496 // NamedAttributeDecl
497 //===----------------------------------------------------------------------===//
498
create(Context & ctx,const Name & name,Expr * value)499 NamedAttributeDecl *NamedAttributeDecl::create(Context &ctx, const Name &name,
500 Expr *value) {
501 return new (ctx.getAllocator().Allocate<NamedAttributeDecl>())
502 NamedAttributeDecl(name, value);
503 }
504
505 //===----------------------------------------------------------------------===//
506 // OpNameDecl
507 //===----------------------------------------------------------------------===//
508
create(Context & ctx,const Name & name)509 OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) {
510 return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(name);
511 }
create(Context & ctx,SMRange loc)512 OpNameDecl *OpNameDecl::create(Context &ctx, SMRange loc) {
513 return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(loc);
514 }
515
516 //===----------------------------------------------------------------------===//
517 // PatternDecl
518 //===----------------------------------------------------------------------===//
519
create(Context & ctx,SMRange loc,const Name * name,std::optional<uint16_t> benefit,bool hasBoundedRecursion,const CompoundStmt * body)520 PatternDecl *PatternDecl::create(Context &ctx, SMRange loc, const Name *name,
521 std::optional<uint16_t> benefit,
522 bool hasBoundedRecursion,
523 const CompoundStmt *body) {
524 return new (ctx.getAllocator().Allocate<PatternDecl>())
525 PatternDecl(loc, name, benefit, hasBoundedRecursion, body);
526 }
527
528 //===----------------------------------------------------------------------===//
529 // UserRewriteDecl
530 //===----------------------------------------------------------------------===//
531
createImpl(Context & ctx,const Name & name,ArrayRef<VariableDecl * > inputs,ArrayRef<VariableDecl * > results,std::optional<StringRef> codeBlock,const CompoundStmt * body,Type resultType)532 UserRewriteDecl *UserRewriteDecl::createImpl(Context &ctx, const Name &name,
533 ArrayRef<VariableDecl *> inputs,
534 ArrayRef<VariableDecl *> results,
535 std::optional<StringRef> codeBlock,
536 const CompoundStmt *body,
537 Type resultType) {
538 unsigned allocSize = UserRewriteDecl::totalSizeToAlloc<VariableDecl *>(
539 inputs.size() + results.size());
540 void *rawData =
541 ctx.getAllocator().Allocate(allocSize, alignof(UserRewriteDecl));
542 if (codeBlock)
543 codeBlock = codeBlock->copy(ctx.getAllocator());
544
545 UserRewriteDecl *decl = new (rawData) UserRewriteDecl(
546 name, inputs.size(), results.size(), codeBlock, body, resultType);
547 std::uninitialized_copy(inputs.begin(), inputs.end(),
548 decl->getInputs().begin());
549 std::uninitialized_copy(results.begin(), results.end(),
550 decl->getResults().begin());
551 return decl;
552 }
553
554 //===----------------------------------------------------------------------===//
555 // VariableDecl
556 //===----------------------------------------------------------------------===//
557
create(Context & ctx,const Name & name,Type type,Expr * initExpr,ArrayRef<ConstraintRef> constraints)558 VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type,
559 Expr *initExpr,
560 ArrayRef<ConstraintRef> constraints) {
561 unsigned allocSize =
562 VariableDecl::totalSizeToAlloc<ConstraintRef>(constraints.size());
563 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(VariableDecl));
564
565 VariableDecl *varDecl =
566 new (rawData) VariableDecl(name, type, initExpr, constraints.size());
567 std::uninitialized_copy(constraints.begin(), constraints.end(),
568 varDecl->getConstraints().begin());
569 return varDecl;
570 }
571
572 //===----------------------------------------------------------------------===//
573 // Module
574 //===----------------------------------------------------------------------===//
575
create(Context & ctx,SMLoc loc,ArrayRef<Decl * > children)576 Module *Module::create(Context &ctx, SMLoc loc, ArrayRef<Decl *> children) {
577 unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size());
578 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module));
579
580 Module *module = new (rawData) Module(loc, children.size());
581 std::uninitialized_copy(children.begin(), children.end(),
582 module->getChildren().begin());
583 return module;
584 }
585