xref: /llvm-project/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp (revision d2353695f8cb864f88475d3a921249b0dcbcc6f4)
1 //===- MLIRGen.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
10 #include "mlir/AsmParser/AsmParser.h"
11 #include "mlir/Dialect/PDL/IR/PDL.h"
12 #include "mlir/Dialect/PDL/IR/PDLOps.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/Verifier.h"
17 #include "mlir/Tools/PDLL/AST/Context.h"
18 #include "mlir/Tools/PDLL/AST/Nodes.h"
19 #include "mlir/Tools/PDLL/AST/Types.h"
20 #include "mlir/Tools/PDLL/ODS/Context.h"
21 #include "mlir/Tools/PDLL/ODS/Operation.h"
22 #include "llvm/ADT/ScopedHashTable.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::pdll;
29 
30 //===----------------------------------------------------------------------===//
31 // CodeGen
32 //===----------------------------------------------------------------------===//
33 
34 namespace {
35 class CodeGen {
36 public:
CodeGen(MLIRContext * mlirContext,const ast::Context & context,const llvm::SourceMgr & sourceMgr)37   CodeGen(MLIRContext *mlirContext, const ast::Context &context,
38           const llvm::SourceMgr &sourceMgr)
39       : builder(mlirContext), odsContext(context.getODSContext()),
40         sourceMgr(sourceMgr) {
41     // Make sure that the PDL dialect is loaded.
42     mlirContext->loadDialect<pdl::PDLDialect>();
43   }
44 
45   OwningOpRef<ModuleOp> generate(const ast::Module &module);
46 
47 private:
48   /// Generate an MLIR location from the given source location.
49   Location genLoc(llvm::SMLoc loc);
genLoc(llvm::SMRange loc)50   Location genLoc(llvm::SMRange loc) { return genLoc(loc.Start); }
51 
52   /// Generate an MLIR type from the given source type.
53   Type genType(ast::Type type);
54 
55   /// Generate MLIR for the given AST node.
56   void gen(const ast::Node *node);
57 
58   //===--------------------------------------------------------------------===//
59   // Statements
60   //===--------------------------------------------------------------------===//
61 
62   void genImpl(const ast::CompoundStmt *stmt);
63   void genImpl(const ast::EraseStmt *stmt);
64   void genImpl(const ast::LetStmt *stmt);
65   void genImpl(const ast::ReplaceStmt *stmt);
66   void genImpl(const ast::RewriteStmt *stmt);
67   void genImpl(const ast::ReturnStmt *stmt);
68 
69   //===--------------------------------------------------------------------===//
70   // Decls
71   //===--------------------------------------------------------------------===//
72 
73   void genImpl(const ast::UserConstraintDecl *decl);
74   void genImpl(const ast::UserRewriteDecl *decl);
75   void genImpl(const ast::PatternDecl *decl);
76 
77   /// Generate the set of MLIR values defined for the given variable decl, and
78   /// apply any attached constraints.
79   SmallVector<Value> genVar(const ast::VariableDecl *varDecl);
80 
81   /// Generate the value for a variable that does not have an initializer
82   /// expression, i.e. create the PDL value based on the type/constraints of the
83   /// variable.
84   Value genNonInitializerVar(const ast::VariableDecl *varDecl, Location loc);
85 
86   /// Apply the constraints of the given variable to `values`, which correspond
87   /// to the MLIR values of the variable.
88   void applyVarConstraints(const ast::VariableDecl *varDecl, ValueRange values);
89 
90   //===--------------------------------------------------------------------===//
91   // Expressions
92   //===--------------------------------------------------------------------===//
93 
94   Value genSingleExpr(const ast::Expr *expr);
95   SmallVector<Value> genExpr(const ast::Expr *expr);
96   Value genExprImpl(const ast::AttributeExpr *expr);
97   SmallVector<Value> genExprImpl(const ast::CallExpr *expr);
98   SmallVector<Value> genExprImpl(const ast::DeclRefExpr *expr);
99   Value genExprImpl(const ast::MemberAccessExpr *expr);
100   Value genExprImpl(const ast::OperationExpr *expr);
101   Value genExprImpl(const ast::RangeExpr *expr);
102   SmallVector<Value> genExprImpl(const ast::TupleExpr *expr);
103   Value genExprImpl(const ast::TypeExpr *expr);
104 
105   SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl,
106                                        Location loc, ValueRange inputs,
107                                        bool isNegated = false);
108   SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl,
109                                     Location loc, ValueRange inputs);
110   template <typename PDLOpT, typename T>
111   SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc,
112                                                 ValueRange inputs,
113                                                 bool isNegated = false);
114 
115   //===--------------------------------------------------------------------===//
116   // Fields
117   //===--------------------------------------------------------------------===//
118 
119   /// The MLIR builder used for building the resultant IR.
120   OpBuilder builder;
121 
122   /// A map from variable declarations to the MLIR equivalent.
123   using VariableMapTy =
124       llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
125   VariableMapTy variables;
126 
127   /// A reference to the ODS context.
128   const ods::Context &odsContext;
129 
130   /// The source manager of the PDLL ast.
131   const llvm::SourceMgr &sourceMgr;
132 };
133 } // namespace
134 
generate(const ast::Module & module)135 OwningOpRef<ModuleOp> CodeGen::generate(const ast::Module &module) {
136   OwningOpRef<ModuleOp> mlirModule =
137       builder.create<ModuleOp>(genLoc(module.getLoc()));
138   builder.setInsertionPointToStart(mlirModule->getBody());
139 
140   // Generate code for each of the decls within the module.
141   for (const ast::Decl *decl : module.getChildren())
142     gen(decl);
143 
144   return mlirModule;
145 }
146 
genLoc(llvm::SMLoc loc)147 Location CodeGen::genLoc(llvm::SMLoc loc) {
148   unsigned fileID = sourceMgr.FindBufferContainingLoc(loc);
149 
150   // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can
151   //       use it here.
152   auto &bufferInfo = sourceMgr.getBufferInfo(fileID);
153   unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer());
154   unsigned column =
155       (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1;
156   auto *buffer = sourceMgr.getMemoryBuffer(fileID);
157 
158   return FileLineColLoc::get(builder.getContext(),
159                              buffer->getBufferIdentifier(), lineNo, column);
160 }
161 
genType(ast::Type type)162 Type CodeGen::genType(ast::Type type) {
163   return TypeSwitch<ast::Type, Type>(type)
164       .Case([&](ast::AttributeType astType) -> Type {
165         return builder.getType<pdl::AttributeType>();
166       })
167       .Case([&](ast::OperationType astType) -> Type {
168         return builder.getType<pdl::OperationType>();
169       })
170       .Case([&](ast::TypeType astType) -> Type {
171         return builder.getType<pdl::TypeType>();
172       })
173       .Case([&](ast::ValueType astType) -> Type {
174         return builder.getType<pdl::ValueType>();
175       })
176       .Case([&](ast::RangeType astType) -> Type {
177         return pdl::RangeType::get(genType(astType.getElementType()));
178       });
179 }
180 
gen(const ast::Node * node)181 void CodeGen::gen(const ast::Node *node) {
182   TypeSwitch<const ast::Node *>(node)
183       .Case<const ast::CompoundStmt, const ast::EraseStmt, const ast::LetStmt,
184             const ast::ReplaceStmt, const ast::RewriteStmt,
185             const ast::ReturnStmt, const ast::UserConstraintDecl,
186             const ast::UserRewriteDecl, const ast::PatternDecl>(
187           [&](auto derivedNode) { this->genImpl(derivedNode); })
188       .Case([&](const ast::Expr *expr) { genExpr(expr); });
189 }
190 
191 //===----------------------------------------------------------------------===//
192 // CodeGen: Statements
193 //===----------------------------------------------------------------------===//
194 
genImpl(const ast::CompoundStmt * stmt)195 void CodeGen::genImpl(const ast::CompoundStmt *stmt) {
196   VariableMapTy::ScopeTy varScope(variables);
197   for (const ast::Stmt *childStmt : stmt->getChildren())
198     gen(childStmt);
199 }
200 
201 /// If the given builder is nested under a PDL PatternOp, build a rewrite
202 /// operation and update the builder to nest under it. This is necessary for
203 /// PDLL operation rewrite statements that are directly nested within a Pattern.
checkAndNestUnderRewriteOp(OpBuilder & builder,Value rootExpr,Location loc)204 static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
205                                        Location loc) {
206   if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) {
207     pdl::RewriteOp rewrite =
208         builder.create<pdl::RewriteOp>(loc, rootExpr, /*name=*/StringAttr(),
209                                        /*externalArgs=*/ValueRange());
210     builder.createBlock(&rewrite.getBodyRegion());
211   }
212 }
213 
genImpl(const ast::EraseStmt * stmt)214 void CodeGen::genImpl(const ast::EraseStmt *stmt) {
215   OpBuilder::InsertionGuard insertGuard(builder);
216   Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
217   Location loc = genLoc(stmt->getLoc());
218 
219   // Make sure we are nested in a RewriteOp.
220   OpBuilder::InsertionGuard guard(builder);
221   checkAndNestUnderRewriteOp(builder, rootExpr, loc);
222   builder.create<pdl::EraseOp>(loc, rootExpr);
223 }
224 
genImpl(const ast::LetStmt * stmt)225 void CodeGen::genImpl(const ast::LetStmt *stmt) { genVar(stmt->getVarDecl()); }
226 
genImpl(const ast::ReplaceStmt * stmt)227 void CodeGen::genImpl(const ast::ReplaceStmt *stmt) {
228   OpBuilder::InsertionGuard insertGuard(builder);
229   Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
230   Location loc = genLoc(stmt->getLoc());
231 
232   // Make sure we are nested in a RewriteOp.
233   OpBuilder::InsertionGuard guard(builder);
234   checkAndNestUnderRewriteOp(builder, rootExpr, loc);
235 
236   SmallVector<Value> replValues;
237   for (ast::Expr *replExpr : stmt->getReplExprs())
238     replValues.push_back(genSingleExpr(replExpr));
239 
240   // Check to see if the statement has a replacement operation, or a range of
241   // replacement values.
242   bool usesReplOperation =
243       replValues.size() == 1 &&
244       isa<pdl::OperationType>(replValues.front().getType());
245   builder.create<pdl::ReplaceOp>(
246       loc, rootExpr, usesReplOperation ? replValues[0] : Value(),
247       usesReplOperation ? ValueRange() : ValueRange(replValues));
248 }
249 
genImpl(const ast::RewriteStmt * stmt)250 void CodeGen::genImpl(const ast::RewriteStmt *stmt) {
251   OpBuilder::InsertionGuard insertGuard(builder);
252   Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
253 
254   // Make sure we are nested in a RewriteOp.
255   OpBuilder::InsertionGuard guard(builder);
256   checkAndNestUnderRewriteOp(builder, rootExpr, genLoc(stmt->getLoc()));
257   gen(stmt->getRewriteBody());
258 }
259 
genImpl(const ast::ReturnStmt * stmt)260 void CodeGen::genImpl(const ast::ReturnStmt *stmt) {
261   // ReturnStmt generation is handled by the respective constraint or rewrite
262   // parent node.
263 }
264 
265 //===----------------------------------------------------------------------===//
266 // CodeGen: Decls
267 //===----------------------------------------------------------------------===//
268 
genImpl(const ast::UserConstraintDecl * decl)269 void CodeGen::genImpl(const ast::UserConstraintDecl *decl) {
270   // All PDLL constraints get inlined when called, and the main native
271   // constraint declarations doesn't require any MLIR to be generated, only uses
272   // of it do.
273 }
274 
genImpl(const ast::UserRewriteDecl * decl)275 void CodeGen::genImpl(const ast::UserRewriteDecl *decl) {
276   // All PDLL rewrites get inlined when called, and the main native
277   // rewrite declarations doesn't require any MLIR to be generated, only uses
278   // of it do.
279 }
280 
genImpl(const ast::PatternDecl * decl)281 void CodeGen::genImpl(const ast::PatternDecl *decl) {
282   const ast::Name *name = decl->getName();
283 
284   // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it
285   // here.
286   pdl::PatternOp pattern = builder.create<pdl::PatternOp>(
287       genLoc(decl->getLoc()), decl->getBenefit(),
288       name ? std::optional<StringRef>(name->getName())
289            : std::optional<StringRef>());
290 
291   OpBuilder::InsertionGuard savedInsertPoint(builder);
292   builder.setInsertionPointToStart(pattern.getBody());
293   gen(decl->getBody());
294 }
295 
genVar(const ast::VariableDecl * varDecl)296 SmallVector<Value> CodeGen::genVar(const ast::VariableDecl *varDecl) {
297   auto it = variables.begin(varDecl);
298   if (it != variables.end())
299     return *it;
300 
301   // If the variable has an initial value, use that as the base value.
302   // Otherwise, generate a value using the constraint list.
303   SmallVector<Value> values;
304   if (const ast::Expr *initExpr = varDecl->getInitExpr())
305     values = genExpr(initExpr);
306   else
307     values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->getLoc())));
308 
309   // Apply the constraints of the values of the variable.
310   applyVarConstraints(varDecl, values);
311 
312   variables.insert(varDecl, values);
313   return values;
314 }
315 
genNonInitializerVar(const ast::VariableDecl * varDecl,Location loc)316 Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
317                                     Location loc) {
318   // A functor used to generate expressions nested
319   auto getTypeConstraint = [&]() -> Value {
320     for (const ast::ConstraintRef &constraint : varDecl->getConstraints()) {
321       Value typeValue =
322           TypeSwitch<const ast::Node *, Value>(constraint.constraint)
323               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
324                     ast::ValueRangeConstraintDecl>(
325                   [&, this](auto *cst) -> Value {
326                     if (auto *typeConstraintExpr = cst->getTypeExpr())
327                       return this->genSingleExpr(typeConstraintExpr);
328                     return Value();
329                   })
330               .Default(Value());
331       if (typeValue)
332         return typeValue;
333     }
334     return Value();
335   };
336 
337   // Generate a value based on the type of the variable.
338   ast::Type type = varDecl->getType();
339   Type mlirType = genType(type);
340   if (isa<ast::ValueType>(type))
341     return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint());
342   if (isa<ast::TypeType>(type))
343     return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr());
344   if (isa<ast::AttributeType>(type))
345     return builder.create<pdl::AttributeOp>(loc, getTypeConstraint());
346   if (ast::OperationType opType = dyn_cast<ast::OperationType>(type)) {
347     Value operands = builder.create<pdl::OperandsOp>(
348         loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()),
349         /*type=*/Value());
350     Value results = builder.create<pdl::TypesOp>(
351         loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()),
352         /*types=*/ArrayAttr());
353     return builder.create<pdl::OperationOp>(
354         loc, opType.getName(), operands, std::nullopt, ValueRange(), results);
355   }
356 
357   if (ast::RangeType rangeTy = dyn_cast<ast::RangeType>(type)) {
358     ast::Type eleTy = rangeTy.getElementType();
359     if (isa<ast::ValueType>(eleTy))
360       return builder.create<pdl::OperandsOp>(loc, mlirType,
361                                              getTypeConstraint());
362     if (isa<ast::TypeType>(eleTy))
363       return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr());
364   }
365 
366   llvm_unreachable("invalid non-initialized variable type");
367 }
368 
applyVarConstraints(const ast::VariableDecl * varDecl,ValueRange values)369 void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl,
370                                   ValueRange values) {
371   // Generate calls to any user constraints that were attached via the
372   // constraint list.
373   for (const ast::ConstraintRef &ref : varDecl->getConstraints())
374     if (const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint))
375       genConstraintCall(userCst, genLoc(ref.referenceLoc), values);
376 }
377 
378 //===----------------------------------------------------------------------===//
379 // CodeGen: Expressions
380 //===----------------------------------------------------------------------===//
381 
genSingleExpr(const ast::Expr * expr)382 Value CodeGen::genSingleExpr(const ast::Expr *expr) {
383   return TypeSwitch<const ast::Expr *, Value>(expr)
384       .Case<const ast::AttributeExpr, const ast::MemberAccessExpr,
385             const ast::OperationExpr, const ast::RangeExpr,
386             const ast::TypeExpr>(
387           [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
388       .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
389           [&](auto derivedNode) {
390             SmallVector<Value> results = this->genExprImpl(derivedNode);
391             assert(results.size() == 1 && "expected single expression result");
392             return results[0];
393           });
394 }
395 
genExpr(const ast::Expr * expr)396 SmallVector<Value> CodeGen::genExpr(const ast::Expr *expr) {
397   return TypeSwitch<const ast::Expr *, SmallVector<Value>>(expr)
398       .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
399           [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
400       .Default([&](const ast::Expr *expr) -> SmallVector<Value> {
401         return {genSingleExpr(expr)};
402       });
403 }
404 
genExprImpl(const ast::AttributeExpr * expr)405 Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) {
406   Attribute attr = parseAttribute(expr->getValue(), builder.getContext());
407   assert(attr && "invalid MLIR attribute data");
408   return builder.create<pdl::AttributeOp>(genLoc(expr->getLoc()), attr);
409 }
410 
genExprImpl(const ast::CallExpr * expr)411 SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) {
412   Location loc = genLoc(expr->getLoc());
413   SmallVector<Value> arguments;
414   for (const ast::Expr *arg : expr->getArguments())
415     arguments.push_back(genSingleExpr(arg));
416 
417   // Resolve the callable expression of this call.
418   auto *callableExpr = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr());
419   assert(callableExpr && "unhandled CallExpr callable");
420 
421   // Generate the PDL based on the type of callable.
422   const ast::Decl *callable = callableExpr->getDecl();
423   if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
424     return genConstraintCall(decl, loc, arguments, expr->getIsNegated());
425   if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
426     return genRewriteCall(decl, loc, arguments);
427   llvm_unreachable("unhandled CallExpr callable");
428 }
429 
genExprImpl(const ast::DeclRefExpr * expr)430 SmallVector<Value> CodeGen::genExprImpl(const ast::DeclRefExpr *expr) {
431   if (const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->getDecl()))
432     return genVar(varDecl);
433   llvm_unreachable("unknown decl reference expression");
434 }
435 
genExprImpl(const ast::MemberAccessExpr * expr)436 Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
437   Location loc = genLoc(expr->getLoc());
438   StringRef name = expr->getMemberName();
439   SmallVector<Value> parentExprs = genExpr(expr->getParentExpr());
440   ast::Type parentType = expr->getParentExpr()->getType();
441 
442   // Handle operation based member access.
443   if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
444     if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
445       Type mlirType = genType(expr->getType());
446       if (isa<pdl::ValueType>(mlirType))
447         return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0],
448                                              builder.getI32IntegerAttr(0));
449       return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]);
450     }
451 
452     const ods::Operation *odsOp = opType.getODSOperation();
453     if (!odsOp) {
454       assert(llvm::isDigit(name[0]) &&
455              "unregistered op only allows numeric indexing");
456       unsigned resultIndex;
457       name.getAsInteger(/*Radix=*/10, resultIndex);
458       IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
459       return builder.create<pdl::ResultOp>(loc, genType(expr->getType()),
460                                            parentExprs[0], index);
461     }
462 
463     // Find the result with the member name or by index.
464     ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
465     unsigned resultIndex = results.size();
466     if (llvm::isDigit(name[0])) {
467       name.getAsInteger(/*Radix=*/10, resultIndex);
468     } else {
469       auto findFn = [&](const ods::OperandOrResult &result) {
470         return result.getName() == name;
471       };
472       resultIndex = llvm::find_if(results, findFn) - results.begin();
473     }
474     assert(resultIndex < results.size() && "invalid result index");
475 
476     // Generate the result access.
477     IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
478     return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()),
479                                           parentExprs[0], index);
480   }
481 
482   // Handle tuple based member access.
483   if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
484     auto elementNames = tupleType.getElementNames();
485 
486     // The index is either a numeric index, or a name.
487     unsigned index = 0;
488     if (llvm::isDigit(name[0]))
489       name.getAsInteger(/*Radix=*/10, index);
490     else
491       index = llvm::find(elementNames, name) - elementNames.begin();
492 
493     assert(index < parentExprs.size() && "invalid result index");
494     return parentExprs[index];
495   }
496 
497   llvm_unreachable("unhandled member access expression");
498 }
499 
genExprImpl(const ast::OperationExpr * expr)500 Value CodeGen::genExprImpl(const ast::OperationExpr *expr) {
501   Location loc = genLoc(expr->getLoc());
502   std::optional<StringRef> opName = expr->getName();
503 
504   // Operands.
505   SmallVector<Value> operands;
506   for (const ast::Expr *operand : expr->getOperands())
507     operands.push_back(genSingleExpr(operand));
508 
509   // Attributes.
510   SmallVector<StringRef> attrNames;
511   SmallVector<Value> attrValues;
512   for (const ast::NamedAttributeDecl *attr : expr->getAttributes()) {
513     attrNames.push_back(attr->getName().getName());
514     attrValues.push_back(genSingleExpr(attr->getValue()));
515   }
516 
517   // Results.
518   SmallVector<Value> results;
519   for (const ast::Expr *result : expr->getResultTypes())
520     results.push_back(genSingleExpr(result));
521 
522   return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames,
523                                           attrValues, results);
524 }
525 
genExprImpl(const ast::RangeExpr * expr)526 Value CodeGen::genExprImpl(const ast::RangeExpr *expr) {
527   SmallVector<Value> elements;
528   for (const ast::Expr *element : expr->getElements())
529     llvm::append_range(elements, genExpr(element));
530 
531   return builder.create<pdl::RangeOp>(genLoc(expr->getLoc()),
532                                       genType(expr->getType()), elements);
533 }
534 
genExprImpl(const ast::TupleExpr * expr)535 SmallVector<Value> CodeGen::genExprImpl(const ast::TupleExpr *expr) {
536   SmallVector<Value> elements;
537   for (const ast::Expr *element : expr->getElements())
538     elements.push_back(genSingleExpr(element));
539   return elements;
540 }
541 
genExprImpl(const ast::TypeExpr * expr)542 Value CodeGen::genExprImpl(const ast::TypeExpr *expr) {
543   Type type = parseType(expr->getValue(), builder.getContext());
544   assert(type && "invalid MLIR type data");
545   return builder.create<pdl::TypeOp>(genLoc(expr->getLoc()),
546                                      builder.getType<pdl::TypeType>(),
547                                      TypeAttr::get(type));
548 }
549 
550 SmallVector<Value>
genConstraintCall(const ast::UserConstraintDecl * decl,Location loc,ValueRange inputs,bool isNegated)551 CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc,
552                            ValueRange inputs, bool isNegated) {
553   // Apply any constraints defined on the arguments to the input values.
554   for (auto it : llvm::zip(decl->getInputs(), inputs))
555     applyVarConstraints(std::get<0>(it), std::get<1>(it));
556 
557   // Generate the constraint call.
558   SmallVector<Value> results =
559       genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(
560           decl, loc, inputs, isNegated);
561 
562   // Apply any constraints defined on the results of the constraint.
563   for (auto it : llvm::zip(decl->getResults(), results))
564     applyVarConstraints(std::get<0>(it), std::get<1>(it));
565   return results;
566 }
567 
genRewriteCall(const ast::UserRewriteDecl * decl,Location loc,ValueRange inputs)568 SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl,
569                                            Location loc, ValueRange inputs) {
570   return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc,
571                                                                inputs);
572 }
573 
574 template <typename PDLOpT, typename T>
575 SmallVector<Value>
genConstraintOrRewriteCall(const T * decl,Location loc,ValueRange inputs,bool isNegated)576 CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc,
577                                     ValueRange inputs, bool isNegated) {
578   const ast::CompoundStmt *cstBody = decl->getBody();
579 
580   // If the decl doesn't have a statement body, it is a native decl.
581   if (!cstBody) {
582     ast::Type declResultType = decl->getResultType();
583     SmallVector<Type> resultTypes;
584     if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(declResultType)) {
585       for (ast::Type type : tupleType.getElementTypes())
586         resultTypes.push_back(genType(type));
587     } else {
588       resultTypes.push_back(genType(declResultType));
589     }
590     PDLOpT pdlOp = builder.create<PDLOpT>(loc, resultTypes,
591                                           decl->getName().getName(), inputs);
592     if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)
593       cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(true);
594     return pdlOp->getResults();
595   }
596 
597   // Otherwise, this is a PDLL decl.
598   VariableMapTy::ScopeTy varScope(variables);
599 
600   // Map the inputs of the call to the decl arguments.
601   // Note: This is only valid because we do not support recursion, meaning
602   // we don't need to worry about conflicting mappings here.
603   for (auto it : llvm::zip(inputs, decl->getInputs()))
604     variables.insert(std::get<1>(it), {std::get<0>(it)});
605 
606   // Visit the body of the call as normal.
607   gen(cstBody);
608 
609   // If the decl has no results, there is nothing to do.
610   if (cstBody->getChildren().empty())
611     return SmallVector<Value>();
612   auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->getChildren().back());
613   if (!returnStmt)
614     return SmallVector<Value>();
615 
616   // Otherwise, grab the results from the return statement.
617   return genExpr(returnStmt->getResultExpr());
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // MLIRGen
622 //===----------------------------------------------------------------------===//
623 
codegenPDLLToMLIR(MLIRContext * mlirContext,const ast::Context & context,const llvm::SourceMgr & sourceMgr,const ast::Module & module)624 OwningOpRef<ModuleOp> mlir::pdll::codegenPDLLToMLIR(
625     MLIRContext *mlirContext, const ast::Context &context,
626     const llvm::SourceMgr &sourceMgr, const ast::Module &module) {
627   CodeGen codegen(mlirContext, context, sourceMgr);
628   OwningOpRef<ModuleOp> mlirModule = codegen.generate(module);
629   if (failed(verify(*mlirModule)))
630     return nullptr;
631   return mlirModule;
632 }
633