xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp (revision 41760a6b40c1f14e0622aea4a2ee4b4a93c40ec1)
1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10 #include "../PassDetail.h"
11 #include "PredicateTree.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 using namespace mlir;
24 using namespace mlir::pdl_to_pdl_interp;
25 
26 //===----------------------------------------------------------------------===//
27 // PatternLowering
28 //===----------------------------------------------------------------------===//
29 
namespace {
/// This class generates operations within the PDL Interpreter dialect from a
/// given module containing PDL pattern operations.
struct PatternLowering {
public:
  PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule);

  /// Generate code for matching and rewriting based on the pattern operations
  /// within the module.
  void lower(ModuleOp module);

private:
  /// Scoped map from matcher positions to the interpreter values that hold
  /// them. `generateMatcher` opens a new scope per lowered matcher node, so
  /// values materialized for a node are forgotten once that node is done.
  using ValueMap = llvm::ScopedHashTable<Position *, Value>;
  using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;

  /// Generate interpreter operations for the tree rooted at the given matcher
  /// node, in the specified region.
  Block *generateMatcher(MatcherNode &node, Region &region);

  /// Get or create an access to the provided positional value in the current
  /// block. This operation may mutate the provided block pointer if nested
  /// regions (i.e., pdl_interp.foreach) are required.
  Value getValueAt(Block *&currentBlock, Position *pos);

  /// Create the interpreter predicate operations. This operation may mutate the
  /// provided current block pointer if nested regions (i.e.,
  /// pdl_interp.foreach) are required.
  void generate(BoolNode *boolNode, Block *&currentBlock, Value val);

  /// Create the interpreter switch / predicate operations, with several case
  /// destinations. This operation never mutates the provided current block
  /// pointer, because the switch operation does not need Values beyond `val`.
  void generate(SwitchNode *switchNode, Block *currentBlock, Value val);

  /// Create the interpreter operations to record a successful pattern match
  /// using the contained root operation. This operation may mutate the current
  /// block pointer if nested regions (i.e., pdl_interp.foreach) are required.
  void generate(SuccessNode *successNode, Block *&currentBlock);

  /// Generate a rewriter function for the given pattern operation, and return
  /// a reference to that function. The positions of match-side values that the
  /// rewriter takes as arguments are appended to `usedMatchValues`.
  SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
                                 SmallVectorImpl<Position *> &usedMatchValues);

  /// Generate the rewriter code for the given operation. `rewriteValues` maps
  /// PDL values to the rewriter-side values replacing them, and
  /// `mapRewriteValue` returns (creating if needed) the rewriter-side value
  /// for a PDL value.
  void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::AttributeOp attrOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::EraseOp eraseOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::OperationOp operationOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ReplaceOp replaceOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ResultOp resultOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ResultsOp resultOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::TypeOp typeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::TypesOp typeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);

  /// Generate the values used for resolving the result types of an operation
  /// created within a dag rewriter region.
  void generateOperationResultTypeRewriter(
      pdl::OperationOp op, SmallVectorImpl<Value> &types,
      DenseMap<Value, Value> &rewriteValues,
      function_ref<Value(Value)> mapRewriteValue);

  /// A builder to use when generating interpreter operations.
  OpBuilder builder;

  /// The matcher function used for all match related logic within PDL patterns.
  FuncOp matcherFunc;

  /// The rewriter module containing all the rewrite related logic within PDL
  /// patterns.
  ModuleOp rewriterModule;

  /// The symbol table of the rewriter module used for insertion.
  SymbolTable rewriterSymbolTable;

  /// A scoped map connecting a position with the corresponding interpreter
  /// value.
  ValueMap values;

  /// A stack of blocks used as the failure destination for matcher nodes that
  /// don't have an explicit failure path.
  SmallVector<Block *, 8> failureBlockStack;

  /// A mapping between values defined in a pattern match, and the corresponding
  /// positional value.
  DenseMap<Value, Position *> valueToPosition;

  /// The set of operation values whose location will be used for newly
  /// generated operations.
  SetVector<Value> locOps;
};
} // namespace
139 
140 PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule)
141     : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
142       rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
143 
144 void PatternLowering::lower(ModuleOp module) {
145   PredicateUniquer predicateUniquer;
146   PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
147 
148   // Define top-level scope for the arguments to the matcher function.
149   ValueMapScope topLevelValueScope(values);
150 
151   // Insert the root operation, i.e. argument to the matcher, at the root
152   // position.
153   Block *matcherEntryBlock = matcherFunc.addEntryBlock();
154   values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
155 
156   // Generate a root matcher node from the provided PDL module.
157   std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
158       module, predicateBuilder, valueToPosition);
159   Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
160   assert(failureBlockStack.empty() && "failed to empty the stack");
161 
162   // After generation, merged the first matched block into the entry.
163   matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
164                                             firstMatcherBlock->getOperations());
165   firstMatcherBlock->erase();
166 }
167 
168 Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
169   // Push a new scope for the values used by this matcher.
170   Block *block = &region.emplaceBlock();
171   ValueMapScope scope(values);
172 
173   // If this is the return node, simply insert the corresponding interpreter
174   // finalize.
175   if (isa<ExitNode>(node)) {
176     builder.setInsertionPointToEnd(block);
177     builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
178     return block;
179   }
180 
181   // Get the next block in the match sequence.
182   // This is intentionally executed first, before we get the value for the
183   // position associated with the node, so that we preserve an "there exist"
184   // semantics: if getting a value requires an upward traversal (going from a
185   // value to its consumers), we want to perform the check on all the consumers
186   // before we pass control to the failure node.
187   std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
188   Block *failureBlock;
189   if (failureNode) {
190     failureBlock = generateMatcher(*failureNode, region);
191     failureBlockStack.push_back(failureBlock);
192   } else {
193     assert(!failureBlockStack.empty() && "expected valid failure block");
194     failureBlock = failureBlockStack.back();
195   }
196 
197   // If this node contains a position, get the corresponding value for this
198   // block.
199   Block *currentBlock = block;
200   Position *position = node.getPosition();
201   Value val = position ? getValueAt(currentBlock, position) : Value();
202 
203   // If this value corresponds to an operation, record that we are going to use
204   // its location as part of a fused location.
205   bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
206   if (isOperationValue)
207     locOps.insert(val);
208 
209   // Dispatch to the correct method based on derived node type.
210   TypeSwitch<MatcherNode *>(&node)
211       .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
212         this->generate(derivedNode, currentBlock, val);
213       })
214       .Case([&](SuccessNode *successNode) {
215         generate(successNode, currentBlock);
216       });
217 
218   // Pop all the failure blocks that were inserted due to nesting of
219   // pdl_interp.iterate.
220   while (failureBlockStack.back() != failureBlock) {
221     failureBlockStack.pop_back();
222     assert(!failureBlockStack.empty() && "unable to locate failure block");
223   }
224 
225   // Pop the new failure block.
226   if (failureNode)
227     failureBlockStack.pop_back();
228 
229   if (isOperationValue)
230     locOps.remove(val);
231 
232   return block;
233 }
234 
/// Return the interpreter value for `pos`, materializing it if necessary.
///
/// If `pos` already has a value in the current scope of `values`, that value is
/// returned unchanged. Otherwise:
/// 1. The parent position's value is obtained first (recursively).
/// 2. The accessor operation for `pos` is appended to `currentBlock`.
/// 3. The result is cached in `values` for the current scope.
///
/// For a ForEachPos, a `pdl_interp.foreach` is emitted and a block holding a
/// `pdl_interp.continue` is pushed onto `failureBlockStack`. That block becomes
/// the failure destination for everything matched inside the loop
/// (presumably advancing to the next element). `currentBlock` is redirected to
/// the loop region's first block, which is why it is passed by reference.
Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
  if (Value val = values.lookup(pos))
    return val;

  // Get the value for the parent position.
  Value parentVal;
  if (Position *parent = pos->getParent())
    parentVal = getValueAt(currentBlock, parent);

  // TODO: Use a location from the position.
  Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
  // The recursive call above may have moved `currentBlock` into a nested
  // region, so the insertion point is (re)set only after it returns.
  builder.setInsertionPointToEnd(currentBlock);
  Value value;
  switch (pos->getKind()) {
  case Predicates::OperationPos: {
    auto *operationPos = cast<OperationPosition>(pos);
    if (operationPos->isOperandDefiningOp())
      // Standard (downward) traversal which directly follows the defining op.
      value = builder.create<pdl_interp::GetDefiningOpOp>(
          loc, builder.getType<pdl::OperationType>(), parentVal);
    else
      // A passthrough operation position.
      value = parentVal;
    break;
  }
  case Predicates::UsersPos: {
    auto *usersPos = cast<UsersPosition>(pos);

    // The first operation retrieves the representative value of a range.
    // This applies only when the parent is a range of values and we were
    // requested to use a representative value (e.g., upward traversal).
    if (parentVal.getType().isa<pdl::RangeType>() &&
        usersPos->useRepresentative())
      value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
    else
      value = parentVal;

    // The second operation retrieves the users.
    value = builder.create<pdl_interp::GetUsersOp>(loc, value);
    break;
  }
  case Predicates::ForEachPos: {
    assert(!failureBlockStack.empty() && "expected valid failure block");
    // The loop's own successor is the failure destination that was active
    // before entering the loop.
    auto foreach = builder.create<pdl_interp::ForEachOp>(
        loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
    value = foreach.getLoopVariable();

    // Create the continuation block.
    // It becomes the failure destination for everything nested in the loop;
    // generateMatcher pops it again once the enclosing node is lowered.
    Block *continueBlock = builder.createBlock(&foreach.region());
    builder.create<pdl_interp::ContinueOp>(loc);
    failureBlockStack.push_back(continueBlock);

    // Subsequent code is emitted inside the loop body (the region's first
    // block, which precedes the continuation block created above).
    currentBlock = &foreach.region().front();
    break;
  }
  case Predicates::OperandPos: {
    auto *operandPos = cast<OperandPosition>(pos);
    value = builder.create<pdl_interp::GetOperandOp>(
        loc, builder.getType<pdl::ValueType>(), parentVal,
        operandPos->getOperandNumber());
    break;
  }
  case Predicates::OperandGroupPos: {
    auto *operandPos = cast<OperandGroupPosition>(pos);
    Type valueTy = builder.getType<pdl::ValueType>();
    // Variadic groups produce a range of values; others a single value.
    value = builder.create<pdl_interp::GetOperandsOp>(
        loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
        parentVal, operandPos->getOperandGroupNumber());
    break;
  }
  case Predicates::AttributePos: {
    auto *attrPos = cast<AttributePosition>(pos);
    value = builder.create<pdl_interp::GetAttributeOp>(
        loc, builder.getType<pdl::AttributeType>(), parentVal,
        attrPos->getName().strref());
    break;
  }
  case Predicates::TypePos: {
    // The type of an attribute and the type of a value use different ops.
    if (parentVal.getType().isa<pdl::AttributeType>())
      value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
    else
      value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
    break;
  }
  case Predicates::ResultPos: {
    auto *resPos = cast<ResultPosition>(pos);
    value = builder.create<pdl_interp::GetResultOp>(
        loc, builder.getType<pdl::ValueType>(), parentVal,
        resPos->getResultNumber());
    break;
  }
  case Predicates::ResultGroupPos: {
    auto *resPos = cast<ResultGroupPosition>(pos);
    Type valueTy = builder.getType<pdl::ValueType>();
    // Variadic groups produce a range of values; others a single value.
    value = builder.create<pdl_interp::GetResultsOp>(
        loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
        parentVal, resPos->getResultGroupNumber());
    break;
  }
  case Predicates::AttributeLiteralPos: {
    auto *attrPos = cast<AttributeLiteralPosition>(pos);
    value =
        builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
    break;
  }
  case Predicates::TypeLiteralPos: {
    auto *typePos = cast<TypeLiteralPosition>(pos);
    Attribute rawTypeAttr = typePos->getValue();
    // A single type is stored as a TypeAttr, a list of types as an ArrayAttr.
    if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>())
      value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
    else
      value = builder.create<pdl_interp::CreateTypesOp>(
          loc, rawTypeAttr.cast<ArrayAttr>());
    break;
  }
  default:
    llvm_unreachable("Generating unknown Position getter");
    break;
  }

  values.insert(pos, value);
  return value;
}
358 
/// Lower a BoolNode into the single predicate operation that tests `val`
/// against the node's question/answer pair.
///
/// Normally the predicate branches to the lowered success subtree when it
/// holds and to the top of `failureBlockStack` otherwise. For an EqualTo
/// question whose answer is not a TrueAnswer, the two destinations are
/// swapped.
void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
                               Value val) {
  Location loc = val.getLoc();
  Qualifier *question = boolNode->getQuestion();
  Qualifier *answer = boolNode->getAnswer();
  Region *region = currentBlock->getParent();

  // Execute the getValue queries first, so that we create success
  // matcher in the correct (possibly nested) region.
  SmallVector<Value> args;
  if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
    args = {getValueAt(currentBlock, equalToQuestion->getValue())};
  } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
    for (Position *position : cstQuestion->getArgs())
      args.push_back(getValueAt(currentBlock, position));
  }

  // Generate the matcher in the current (potentially nested) region
  // and get the failure successor.
  // Note: `region` was captured before the getValueAt calls above, so the
  // success subtree is placed in the region that held the original block.
  Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
  Block *failure = failureBlockStack.back();

  // Finally, create the predicate.
  builder.setInsertionPointToEnd(currentBlock);
  Predicates::Kind kind = question->getKind();
  switch (kind) {
  case Predicates::IsNotNullQuestion:
    builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
    break;
  case Predicates::OperationNameQuestion: {
    auto *opNameAnswer = cast<OperationNameAnswer>(answer);
    builder.create<pdl_interp::CheckOperationNameOp>(
        loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
    break;
  }
  case Predicates::TypeQuestion: {
    auto *ans = cast<TypeAnswer>(answer);
    // Ranges of types are checked against an ArrayAttr, single types against
    // a TypeAttr.
    if (val.getType().isa<pdl::RangeType>())
      builder.create<pdl_interp::CheckTypesOp>(
          loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
    else
      builder.create<pdl_interp::CheckTypeOp>(
          loc, val, ans->getValue().cast<TypeAttr>(), success, failure);
    break;
  }
  case Predicates::AttributeQuestion: {
    auto *ans = cast<AttributeAnswer>(answer);
    builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
                                                 success, failure);
    break;
  }
  case Predicates::OperandCountAtLeastQuestion:
  case Predicates::OperandCountQuestion:
    builder.create<pdl_interp::CheckOperandCountOp>(
        loc, val, cast<UnsignedAnswer>(answer)->getValue(),
        /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
        success, failure);
    break;
  case Predicates::ResultCountAtLeastQuestion:
  case Predicates::ResultCountQuestion:
    builder.create<pdl_interp::CheckResultCountOp>(
        loc, val, cast<UnsignedAnswer>(answer)->getValue(),
        /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
        success, failure);
    break;
  case Predicates::EqualToQuestion: {
    // A negated equality check simply swaps the two destinations.
    bool trueAnswer = isa<TrueAnswer>(answer);
    builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
                                           trueAnswer ? success : failure,
                                           trueAnswer ? failure : success);
    break;
  }
  case Predicates::ConstraintQuestion: {
    auto *cstQuestion = cast<ConstraintQuestion>(question);
    builder.create<pdl_interp::ApplyConstraintOp>(
        loc, cstQuestion->getName(), args, cstQuestion->getParams(), success,
        failure);
    break;
  }
  default:
    llvm_unreachable("Generating unknown Predicate operation");
  }
}
442 
443 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
444 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
445                            llvm::MapVector<Qualifier *, Block *> &dests) {
446   std::vector<ValT> values;
447   std::vector<Block *> blocks;
448   values.reserve(dests.size());
449   blocks.reserve(dests.size());
450   for (const auto &it : dests) {
451     blocks.push_back(it.second);
452     values.push_back(cast<PredT>(it.first)->getValue());
453   }
454   builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
455 }
456 
/// Lower a SwitchNode testing `val` against several possible answers.
///
/// Exact questions (operation name, type, attribute, exact operand/result
/// count) become a single `pdl_interp.switch_*` operation appended to
/// `currentBlock`, with the current top of `failureBlockStack` as the default
/// destination. `at_least` count questions cannot be expressed as a switch and
/// are lowered into a chain of `check_*_count` predicates instead. In both
/// cases `currentBlock` itself is not redirected.
void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
                               Value val) {
  Qualifier *question = switchNode->getQuestion();
  Region *region = currentBlock->getParent();
  Block *defaultDest = failureBlockStack.back();

  // If the switch question is not an exact answer, i.e. for the `at_least`
  // cases, we generate a special block sequence.
  Predicates::Kind kind = question->getKind();
  if (kind == Predicates::OperandCountAtLeastQuestion ||
      kind == Predicates::ResultCountAtLeastQuestion) {
    // Order the children such that the cases are in reverse numerical order.
    SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
        llvm::seq<unsigned>(0, switchNode->getChildren().size()));
    llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
      return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
             cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
    });

    // Build the destination for each child using the next highest child as a
    // failure destination. This essentially creates the following control
    // flow:
    //
    // if (operand_count < 1)
    //   goto failure
    // if (child1.match())
    //   ...
    //
    // if (operand_count < 2)
    //   goto failure
    // if (child2.match())
    //   ...
    //
    // failure:
    //   ...
    //
    // The top of `failureBlockStack` always holds the predicate block of the
    // previously processed (higher) threshold, which is what each child's
    // matcher falls back to when it has no explicit failure node.
    failureBlockStack.push_back(defaultDest);
    Location loc = val.getLoc();
    for (unsigned idx : sortedChildren) {
      auto &child = switchNode->getChild(idx);
      Block *childBlock = generateMatcher(*child.second, *region);
      // The predicate guarding this child is placed before the child's block.
      Block *predicateBlock = builder.createBlock(childBlock);
      builder.setInsertionPointToEnd(predicateBlock);
      unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
      // If the count is below this threshold it is also below every
      // (higher) threshold handled earlier, so failing jumps straight to
      // `defaultDest`.
      switch (kind) {
      case Predicates::OperandCountAtLeastQuestion:
        builder.create<pdl_interp::CheckOperandCountOp>(
            loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
        break;
      case Predicates::ResultCountAtLeastQuestion:
        builder.create<pdl_interp::CheckResultCountOp>(
            loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
        break;
      default:
        llvm_unreachable("Generating invalid AtLeast operation");
      }
      failureBlockStack.back() = predicateBlock;
    }
    // The last predicate built checks the smallest threshold and is the entry
    // of the chain: inline it into `currentBlock`.
    Block *firstPredicateBlock = failureBlockStack.pop_back_val();
    currentBlock->getOperations().splice(currentBlock->end(),
                                         firstPredicateBlock->getOperations());
    firstPredicateBlock->erase();
    return;
  }

  // Otherwise, generate each of the children and generate an interpreter
  // switch.
  llvm::MapVector<Qualifier *, Block *> children;
  for (auto &it : switchNode->getChildren())
    children.insert({it.first, generateMatcher(*it.second, *region)});
  builder.setInsertionPointToEnd(currentBlock);

  switch (question->getKind()) {
  case Predicates::OperandCountQuestion:
    return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
                          int32_t>(val, defaultDest, builder, children);
  case Predicates::ResultCountQuestion:
    return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
                          int32_t>(val, defaultDest, builder, children);
  case Predicates::OperationNameQuestion:
    return createSwitchOp<pdl_interp::SwitchOperationNameOp,
                          OperationNameAnswer>(val, defaultDest, builder,
                                               children);
  case Predicates::TypeQuestion:
    // Ranges of types switch over type lists; single types over types.
    if (val.getType().isa<pdl::RangeType>()) {
      return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
          val, defaultDest, builder, children);
    }
    return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
        val, defaultDest, builder, children);
  case Predicates::AttributeQuestion:
    return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
        val, defaultDest, builder, children);
  default:
    llvm_unreachable("Generating unknown switch predicate.");
  }
}
554 
555 void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
556   pdl::PatternOp pattern = successNode->getPattern();
557   Value root = successNode->getRoot();
558 
559   // Generate a rewriter for the pattern this success node represents, and track
560   // any values used from the match region.
561   SmallVector<Position *, 8> usedMatchValues;
562   SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
563 
564   // Process any values used in the rewrite that are defined in the match.
565   std::vector<Value> mappedMatchValues;
566   mappedMatchValues.reserve(usedMatchValues.size());
567   for (Position *position : usedMatchValues)
568     mappedMatchValues.push_back(getValueAt(currentBlock, position));
569 
570   // Collect the set of operations generated by the rewriter.
571   SmallVector<StringRef, 4> generatedOps;
572   for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
573     generatedOps.push_back(*op.name());
574   ArrayAttr generatedOpsAttr;
575   if (!generatedOps.empty())
576     generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
577 
578   // Grab the root kind if present.
579   StringAttr rootKindAttr;
580   if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
581     if (Optional<StringRef> rootKind = rootOp.name())
582       rootKindAttr = builder.getStringAttr(*rootKind);
583 
584   builder.setInsertionPointToEnd(currentBlock);
585   builder.create<pdl_interp::RecordMatchOp>(
586       pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
587       rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
588       failureBlockStack.back());
589 }
590 
/// Create a rewriter function for `pattern` in the rewriter module and return
/// a nested symbol reference to it
/// (`@<rewriter module name>::@<function name>`).
///
/// Every PDL value the rewrite reads that cannot be rematerialized as a
/// constant becomes an argument of the rewriter function. The argument's
/// match-side position is appended to `usedMatchValues`, in the same order as
/// the function arguments.
SymbolRefAttr PatternLowering::generateRewriter(
    pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
  // The signature starts empty and is fixed up below once all arguments are
  // known. NOTE(review): SymbolTable::insert is expected to uniquify the name
  // if "pdl_generated_rewriter" is already taken -- confirm.
  FuncOp rewriterFunc =
      FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter",
                     builder.getFunctionType(llvm::None, llvm::None));
  rewriterSymbolTable.insert(rewriterFunc);

  // Generate the rewriter function body.
  builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock());

  // Map an input operand of the pattern to a generated interpreter value.
  // Results are memoized in `rewriteValues`.
  DenseMap<Value, Value> rewriteValues;
  auto mapRewriteValue = [&](Value oldValue) {
    Value &newValue = rewriteValues[oldValue];
    if (newValue)
      return newValue;

    // Prefer materializing constants directly when possible.
    // NOTE(review): dyn_cast asserts on null, so this assumes every PDL value
    // reaching here is an operation result (never a block argument) -- confirm.
    Operation *oldOp = oldValue.getDefiningOp();
    if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
      if (Attribute value = attrOp.valueAttr()) {
        return newValue = builder.create<pdl_interp::CreateAttributeOp>(
                   attrOp.getLoc(), value);
      }
    } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
      if (TypeAttr type = typeOp.typeAttr()) {
        return newValue = builder.create<pdl_interp::CreateTypeOp>(
                   typeOp.getLoc(), type);
      }
    } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
      if (ArrayAttr type = typeOp.typesAttr()) {
        return newValue = builder.create<pdl_interp::CreateTypesOp>(
                   typeOp.getLoc(), typeOp.getType(), type);
      }
    }

    // Otherwise, add this as an input to the rewriter.
    // The value must come from the match, so it has a recorded position; the
    // matcher later passes the value at that position as this argument.
    Position *inputPos = valueToPosition.lookup(oldValue);
    assert(inputPos && "expected value to be a pattern input");
    usedMatchValues.push_back(inputPos);
    return newValue = rewriterFunc.front().addArgument(oldValue.getType());
  };

  // If this is a custom rewriter, simply dispatch to the registered rewrite
  // method.
  pdl::RewriteOp rewriter = pattern.getRewriter();
  if (StringAttr rewriteName = rewriter.nameAttr()) {
    // Arguments are the root (if any) followed by the external arguments.
    SmallVector<Value> args;
    if (rewriter.root())
      args.push_back(mapRewriteValue(rewriter.root()));
    auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
    args.append(mappedArgs.begin(), mappedArgs.end());
    builder.create<pdl_interp::ApplyRewriteOp>(
        rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args,
        rewriter.externalConstParamsAttr());
  } else {
    // Otherwise this is a dag rewriter defined using PDL operations.
    // Operations of other kinds in the body are skipped by the TypeSwitch.
    for (Operation &rewriteOp : *rewriter.getBody()) {
      llvm::TypeSwitch<Operation *>(&rewriteOp)
          .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
                pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
                pdl::TypeOp, pdl::TypesOp>([&](auto op) {
            this->generateRewriter(op, rewriteValues, mapRewriteValue);
          });
    }
  }

  // Update the signature of the rewrite function.
  rewriterFunc.setType(builder.getFunctionType(
      llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
      /*results=*/llvm::None));

  builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
  return SymbolRefAttr::get(
      builder.getContext(),
      pdl_interp::PDLInterpDialect::getRewriterModuleName(),
      SymbolRefAttr::get(rewriterFunc));
}
669 
670 void PatternLowering::generateRewriter(
671     pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
672     function_ref<Value(Value)> mapRewriteValue) {
673   SmallVector<Value, 2> arguments;
674   for (Value argument : rewriteOp.args())
675     arguments.push_back(mapRewriteValue(argument));
676   auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
677       rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
678       arguments, rewriteOp.constParamsAttr());
679   for (auto it : llvm::zip(rewriteOp.results(), interpOp.results()))
680     rewriteValues[std::get<0>(it)] = std::get<1>(it);
681 }
682 
683 void PatternLowering::generateRewriter(
684     pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
685     function_ref<Value(Value)> mapRewriteValue) {
686   Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
687       attrOp.getLoc(), attrOp.valueAttr());
688   rewriteValues[attrOp] = newAttr;
689 }
690 
691 void PatternLowering::generateRewriter(
692     pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
693     function_ref<Value(Value)> mapRewriteValue) {
694   builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
695                                       mapRewriteValue(eraseOp.operation()));
696 }
697 
698 void PatternLowering::generateRewriter(
699     pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
700     function_ref<Value(Value)> mapRewriteValue) {
701   SmallVector<Value, 4> operands;
702   for (Value operand : operationOp.operands())
703     operands.push_back(mapRewriteValue(operand));
704 
705   SmallVector<Value, 4> attributes;
706   for (Value attr : operationOp.attributes())
707     attributes.push_back(mapRewriteValue(attr));
708 
709   SmallVector<Value, 2> types;
710   generateOperationResultTypeRewriter(operationOp, types, rewriteValues,
711                                       mapRewriteValue);
712 
713   // Create the new operation.
714   Location loc = operationOp.getLoc();
715   Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
716       loc, *operationOp.name(), types, operands, attributes,
717       operationOp.attributeNames());
718   rewriteValues[operationOp.op()] = createdOp;
719 
720   // Generate accesses for any results that have their types constrained.
721   // Handle the case where there is a single range representing all of the
722   // result types.
723   OperandRange resultTys = operationOp.types();
724   if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
725     Value &type = rewriteValues[resultTys[0]];
726     if (!type) {
727       auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
728       type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
729     }
730     return;
731   }
732 
733   // Otherwise, populate the individual results.
734   bool seenVariableLength = false;
735   Type valueTy = builder.getType<pdl::ValueType>();
736   Type valueRangeTy = pdl::RangeType::get(valueTy);
737   for (const auto &it : llvm::enumerate(resultTys)) {
738     Value &type = rewriteValues[it.value()];
739     if (type)
740       continue;
741     bool isVariadic = it.value().getType().isa<pdl::RangeType>();
742     seenVariableLength |= isVariadic;
743 
744     // After a variable length result has been seen, we need to use result
745     // groups because the exact index of the result is not statically known.
746     Value resultVal;
747     if (seenVariableLength)
748       resultVal = builder.create<pdl_interp::GetResultsOp>(
749           loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
750     else
751       resultVal = builder.create<pdl_interp::GetResultOp>(
752           loc, valueTy, createdOp, it.index());
753     type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
754   }
755 }
756 
757 void PatternLowering::generateRewriter(
758     pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
759     function_ref<Value(Value)> mapRewriteValue) {
760   SmallVector<Value, 4> replOperands;
761 
762   // If the replacement was another operation, get its results. `pdl` allows
763   // for using an operation for simplicitly, but the interpreter isn't as
764   // user facing.
765   if (Value replOp = replaceOp.replOperation()) {
766     // Don't use replace if we know the replaced operation has no results.
767     auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
768     if (!opOp || !opOp.types().empty()) {
769       replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
770           replOp.getLoc(), mapRewriteValue(replOp)));
771     }
772   } else {
773     for (Value operand : replaceOp.replValues())
774       replOperands.push_back(mapRewriteValue(operand));
775   }
776 
777   // If there are no replacement values, just create an erase instead.
778   if (replOperands.empty()) {
779     builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
780                                         mapRewriteValue(replaceOp.operation()));
781     return;
782   }
783 
784   builder.create<pdl_interp::ReplaceOp>(
785       replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
786 }
787 
788 void PatternLowering::generateRewriter(
789     pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
790     function_ref<Value(Value)> mapRewriteValue) {
791   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
792       resultOp.getLoc(), builder.getType<pdl::ValueType>(),
793       mapRewriteValue(resultOp.parent()), resultOp.index());
794 }
795 
796 void PatternLowering::generateRewriter(
797     pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
798     function_ref<Value(Value)> mapRewriteValue) {
799   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
800       resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()),
801       resultOp.index());
802 }
803 
804 void PatternLowering::generateRewriter(
805     pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
806     function_ref<Value(Value)> mapRewriteValue) {
807   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
808   // type.
809   if (TypeAttr typeAttr = typeOp.typeAttr()) {
810     rewriteValues[typeOp] =
811         builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
812   }
813 }
814 
815 void PatternLowering::generateRewriter(
816     pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
817     function_ref<Value(Value)> mapRewriteValue) {
818   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
819   // type.
820   if (ArrayAttr typeAttr = typeOp.typesAttr()) {
821     rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
822         typeOp.getLoc(), typeOp.getType(), typeAttr);
823   }
824 }
825 
826 void PatternLowering::generateOperationResultTypeRewriter(
827     pdl::OperationOp op, SmallVectorImpl<Value> &types,
828     DenseMap<Value, Value> &rewriteValues,
829     function_ref<Value(Value)> mapRewriteValue) {
830   // Look for an operation that was replaced by `op`. The result types will be
831   // inferred from the results that were replaced.
832   Block *rewriterBlock = op->getBlock();
833   for (OpOperand &use : op.op().getUses()) {
834     // Check that the use corresponds to a ReplaceOp and that it is the
835     // replacement value, not the operation being replaced.
836     pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
837     if (!replOpUser || use.getOperandNumber() == 0)
838       continue;
839     // Make sure the replaced operation was defined before this one.
840     Value replOpVal = replOpUser.operation();
841     Operation *replacedOp = replOpVal.getDefiningOp();
842     if (replacedOp->getBlock() == rewriterBlock &&
843         !replacedOp->isBeforeInBlock(op))
844       continue;
845 
846     Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
847         replacedOp->getLoc(), mapRewriteValue(replOpVal));
848     types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
849         replacedOp->getLoc(), replacedOpResults));
850     return;
851   }
852 
853   // Check if the operation has type inference support.
854   if (op.hasTypeInference()) {
855     types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc()));
856     return;
857   }
858 
859   // Otherwise, handle inference for each of the result types individually.
860   OperandRange resultTypeValues = op.types();
861   types.reserve(resultTypeValues.size());
862   for (const auto &it : llvm::enumerate(resultTypeValues)) {
863     Value resultType = it.value();
864 
865     // Check for an already translated value.
866     if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
867       types.push_back(existingRewriteValue);
868       continue;
869     }
870 
871     // Check for an input from the matcher.
872     if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
873       types.push_back(mapRewriteValue(resultType));
874       continue;
875     }
876 
877     // The verifier asserts that the result types of each pdl.operation can be
878     // inferred. If we reach here, there is a bug either in the logic above or
879     // in the verifier for pdl.operation.
880     op->emitOpError() << "unable to infer result type for operation";
881     llvm_unreachable("unable to infer result type for operation");
882   }
883 }
884 
885 //===----------------------------------------------------------------------===//
886 // Conversion Pass
887 //===----------------------------------------------------------------------===//
888 
889 namespace {
890 struct PDLToPDLInterpPass
891     : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
892   void runOnOperation() final;
893 };
894 } // namespace
895 
896 /// Convert the given module containing PDL pattern operations into a PDL
897 /// Interpreter operations.
898 void PDLToPDLInterpPass::runOnOperation() {
899   ModuleOp module = getOperation();
900 
901   // Create the main matcher function This function contains all of the match
902   // related functionality from patterns in the module.
903   OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
904   FuncOp matcherFunc = builder.create<FuncOp>(
905       module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
906       builder.getFunctionType(builder.getType<pdl::OperationType>(),
907                               /*results=*/llvm::None),
908       /*attrs=*/llvm::None);
909 
910   // Create a nested module to hold the functions invoked for rewriting the IR
911   // after a successful match.
912   ModuleOp rewriterModule = builder.create<ModuleOp>(
913       module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
914 
915   // Generate the code for the patterns within the module.
916   PatternLowering generator(matcherFunc, rewriterModule);
917   generator.lower(module);
918 
919   // After generation, delete all of the pattern operations.
920   for (pdl::PatternOp pattern :
921        llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
922     pattern.erase();
923 }
924 
925 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
926   return std::make_unique<PDLToPDLInterpPass>();
927 }
928