xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp (revision b69f10f5a118ec6efe5edfae2deb376b53877170)
1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10 
11 #include "PredicateTree.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::pdl_to_pdl_interp;
30 
31 //===----------------------------------------------------------------------===//
32 // PatternLowering
33 //===----------------------------------------------------------------------===//
34 
35 namespace {
36 /// This class generators operations within the PDL Interpreter dialect from a
37 /// given module containing PDL pattern operations.
38 struct PatternLowering {
39 public:
40   PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule);
41 
42   /// Generate code for matching and rewriting based on the pattern operations
43   /// within the module.
44   void lower(ModuleOp module);
45 
46 private:
47   using ValueMap = llvm::ScopedHashTable<Position *, Value>;
48   using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
49 
50   /// Generate interpreter operations for the tree rooted at the given matcher
51   /// node, in the specified region.
52   Block *generateMatcher(MatcherNode &node, Region &region);
53 
54   /// Get or create an access to the provided positional value in the current
55   /// block. This operation may mutate the provided block pointer if nested
56   /// regions (i.e., pdl_interp.iterate) are required.
57   Value getValueAt(Block *&currentBlock, Position *pos);
58 
59   /// Create the interpreter predicate operations. This operation may mutate the
60   /// provided current block pointer if nested regions (iterates) are required.
61   void generate(BoolNode *boolNode, Block *&currentBlock, Value val);
62 
63   /// Create the interpreter switch / predicate operations, with several case
64   /// destinations. This operation never mutates the provided current block
65   /// pointer, because the switch operation does not need Values beyond `val`.
66   void generate(SwitchNode *switchNode, Block *currentBlock, Value val);
67 
68   /// Create the interpreter operations to record a successful pattern match
69   /// using the contained root operation. This operation may mutate the current
70   /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
71   void generate(SuccessNode *successNode, Block *&currentBlock);
72 
73   /// Generate a rewriter function for the given pattern operation, and returns
74   /// a reference to that function.
75   SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
76                                  SmallVectorImpl<Position *> &usedMatchValues);
77 
78   /// Generate the rewriter code for the given operation.
79   void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
80                         DenseMap<Value, Value> &rewriteValues,
81                         function_ref<Value(Value)> mapRewriteValue);
82   void generateRewriter(pdl::AttributeOp attrOp,
83                         DenseMap<Value, Value> &rewriteValues,
84                         function_ref<Value(Value)> mapRewriteValue);
85   void generateRewriter(pdl::EraseOp eraseOp,
86                         DenseMap<Value, Value> &rewriteValues,
87                         function_ref<Value(Value)> mapRewriteValue);
88   void generateRewriter(pdl::OperationOp operationOp,
89                         DenseMap<Value, Value> &rewriteValues,
90                         function_ref<Value(Value)> mapRewriteValue);
91   void generateRewriter(pdl::ReplaceOp replaceOp,
92                         DenseMap<Value, Value> &rewriteValues,
93                         function_ref<Value(Value)> mapRewriteValue);
94   void generateRewriter(pdl::ResultOp resultOp,
95                         DenseMap<Value, Value> &rewriteValues,
96                         function_ref<Value(Value)> mapRewriteValue);
97   void generateRewriter(pdl::ResultsOp resultOp,
98                         DenseMap<Value, Value> &rewriteValues,
99                         function_ref<Value(Value)> mapRewriteValue);
100   void generateRewriter(pdl::TypeOp typeOp,
101                         DenseMap<Value, Value> &rewriteValues,
102                         function_ref<Value(Value)> mapRewriteValue);
103   void generateRewriter(pdl::TypesOp typeOp,
104                         DenseMap<Value, Value> &rewriteValues,
105                         function_ref<Value(Value)> mapRewriteValue);
106 
107   /// Generate the values used for resolving the result types of an operation
108   /// created within a dag rewriter region. If the result types of the operation
109   /// should be inferred, `hasInferredResultTypes` is set to true.
110   void generateOperationResultTypeRewriter(
111       pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
112       SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
113       bool &hasInferredResultTypes);
114 
115   /// A builder to use when generating interpreter operations.
116   OpBuilder builder;
117 
118   /// The matcher function used for all match related logic within PDL patterns.
119   pdl_interp::FuncOp matcherFunc;
120 
121   /// The rewriter module containing the all rewrite related logic within PDL
122   /// patterns.
123   ModuleOp rewriterModule;
124 
125   /// The symbol table of the rewriter module used for insertion.
126   SymbolTable rewriterSymbolTable;
127 
128   /// A scoped map connecting a position with the corresponding interpreter
129   /// value.
130   ValueMap values;
131 
132   /// A stack of blocks used as the failure destination for matcher nodes that
133   /// don't have an explicit failure path.
134   SmallVector<Block *, 8> failureBlockStack;
135 
136   /// A mapping between values defined in a pattern match, and the corresponding
137   /// positional value.
138   DenseMap<Value, Position *> valueToPosition;
139 
140   /// The set of operation values whose whose location will be used for newly
141   /// generated operations.
142   SetVector<Value> locOps;
143 };
144 } // namespace
145 
146 PatternLowering::PatternLowering(pdl_interp::FuncOp matcherFunc,
147                                  ModuleOp rewriterModule)
148     : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
149       rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
150 
151 void PatternLowering::lower(ModuleOp module) {
152   PredicateUniquer predicateUniquer;
153   PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
154 
155   // Define top-level scope for the arguments to the matcher function.
156   ValueMapScope topLevelValueScope(values);
157 
158   // Insert the root operation, i.e. argument to the matcher, at the root
159   // position.
160   Block *matcherEntryBlock = &matcherFunc.front();
161   values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
162 
163   // Generate a root matcher node from the provided PDL module.
164   std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
165       module, predicateBuilder, valueToPosition);
166   Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
167   assert(failureBlockStack.empty() && "failed to empty the stack");
168 
169   // After generation, merged the first matched block into the entry.
170   matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
171                                             firstMatcherBlock->getOperations());
172   firstMatcherBlock->erase();
173 }
174 
175 Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
176   // Push a new scope for the values used by this matcher.
177   Block *block = &region.emplaceBlock();
178   ValueMapScope scope(values);
179 
180   // If this is the return node, simply insert the corresponding interpreter
181   // finalize.
182   if (isa<ExitNode>(node)) {
183     builder.setInsertionPointToEnd(block);
184     builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
185     return block;
186   }
187 
188   // Get the next block in the match sequence.
189   // This is intentionally executed first, before we get the value for the
190   // position associated with the node, so that we preserve an "there exist"
191   // semantics: if getting a value requires an upward traversal (going from a
192   // value to its consumers), we want to perform the check on all the consumers
193   // before we pass control to the failure node.
194   std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
195   Block *failureBlock;
196   if (failureNode) {
197     failureBlock = generateMatcher(*failureNode, region);
198     failureBlockStack.push_back(failureBlock);
199   } else {
200     assert(!failureBlockStack.empty() && "expected valid failure block");
201     failureBlock = failureBlockStack.back();
202   }
203 
204   // If this node contains a position, get the corresponding value for this
205   // block.
206   Block *currentBlock = block;
207   Position *position = node.getPosition();
208   Value val = position ? getValueAt(currentBlock, position) : Value();
209 
210   // If this value corresponds to an operation, record that we are going to use
211   // its location as part of a fused location.
212   bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
213   if (isOperationValue)
214     locOps.insert(val);
215 
216   // Dispatch to the correct method based on derived node type.
217   TypeSwitch<MatcherNode *>(&node)
218       .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
219         this->generate(derivedNode, currentBlock, val);
220       })
221       .Case([&](SuccessNode *successNode) {
222         generate(successNode, currentBlock);
223       });
224 
225   // Pop all the failure blocks that were inserted due to nesting of
226   // pdl_interp.iterate.
227   while (failureBlockStack.back() != failureBlock) {
228     failureBlockStack.pop_back();
229     assert(!failureBlockStack.empty() && "unable to locate failure block");
230   }
231 
232   // Pop the new failure block.
233   if (failureNode)
234     failureBlockStack.pop_back();
235 
236   if (isOperationValue)
237     locOps.remove(val);
238 
239   return block;
240 }
241 
242 Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
243   if (Value val = values.lookup(pos))
244     return val;
245 
246   // Get the value for the parent position.
247   Value parentVal;
248   if (Position *parent = pos->getParent())
249     parentVal = getValueAt(currentBlock, parent);
250 
251   // TODO: Use a location from the position.
252   Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
253   builder.setInsertionPointToEnd(currentBlock);
254   Value value;
255   switch (pos->getKind()) {
256   case Predicates::OperationPos: {
257     auto *operationPos = cast<OperationPosition>(pos);
258     if (operationPos->isOperandDefiningOp())
259       // Standard (downward) traversal which directly follows the defining op.
260       value = builder.create<pdl_interp::GetDefiningOpOp>(
261           loc, builder.getType<pdl::OperationType>(), parentVal);
262     else
263       // A passthrough operation position.
264       value = parentVal;
265     break;
266   }
267   case Predicates::UsersPos: {
268     auto *usersPos = cast<UsersPosition>(pos);
269 
270     // The first operation retrieves the representative value of a range.
271     // This applies only when the parent is a range of values and we were
272     // requested to use a representative value (e.g., upward traversal).
273     if (parentVal.getType().isa<pdl::RangeType>() &&
274         usersPos->useRepresentative())
275       value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
276     else
277       value = parentVal;
278 
279     // The second operation retrieves the users.
280     value = builder.create<pdl_interp::GetUsersOp>(loc, value);
281     break;
282   }
283   case Predicates::ForEachPos: {
284     assert(!failureBlockStack.empty() && "expected valid failure block");
285     auto foreach = builder.create<pdl_interp::ForEachOp>(
286         loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
287     value = foreach.getLoopVariable();
288 
289     // Create the continuation block.
290     Block *continueBlock = builder.createBlock(&foreach.getRegion());
291     builder.create<pdl_interp::ContinueOp>(loc);
292     failureBlockStack.push_back(continueBlock);
293 
294     currentBlock = &foreach.getRegion().front();
295     break;
296   }
297   case Predicates::OperandPos: {
298     auto *operandPos = cast<OperandPosition>(pos);
299     value = builder.create<pdl_interp::GetOperandOp>(
300         loc, builder.getType<pdl::ValueType>(), parentVal,
301         operandPos->getOperandNumber());
302     break;
303   }
304   case Predicates::OperandGroupPos: {
305     auto *operandPos = cast<OperandGroupPosition>(pos);
306     Type valueTy = builder.getType<pdl::ValueType>();
307     value = builder.create<pdl_interp::GetOperandsOp>(
308         loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
309         parentVal, operandPos->getOperandGroupNumber());
310     break;
311   }
312   case Predicates::AttributePos: {
313     auto *attrPos = cast<AttributePosition>(pos);
314     value = builder.create<pdl_interp::GetAttributeOp>(
315         loc, builder.getType<pdl::AttributeType>(), parentVal,
316         attrPos->getName().strref());
317     break;
318   }
319   case Predicates::TypePos: {
320     if (parentVal.getType().isa<pdl::AttributeType>())
321       value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
322     else
323       value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
324     break;
325   }
326   case Predicates::ResultPos: {
327     auto *resPos = cast<ResultPosition>(pos);
328     value = builder.create<pdl_interp::GetResultOp>(
329         loc, builder.getType<pdl::ValueType>(), parentVal,
330         resPos->getResultNumber());
331     break;
332   }
333   case Predicates::ResultGroupPos: {
334     auto *resPos = cast<ResultGroupPosition>(pos);
335     Type valueTy = builder.getType<pdl::ValueType>();
336     value = builder.create<pdl_interp::GetResultsOp>(
337         loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
338         parentVal, resPos->getResultGroupNumber());
339     break;
340   }
341   case Predicates::AttributeLiteralPos: {
342     auto *attrPos = cast<AttributeLiteralPosition>(pos);
343     value =
344         builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
345     break;
346   }
347   case Predicates::TypeLiteralPos: {
348     auto *typePos = cast<TypeLiteralPosition>(pos);
349     Attribute rawTypeAttr = typePos->getValue();
350     if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>())
351       value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
352     else
353       value = builder.create<pdl_interp::CreateTypesOp>(
354           loc, rawTypeAttr.cast<ArrayAttr>());
355     break;
356   }
357   default:
358     llvm_unreachable("Generating unknown Position getter");
359     break;
360   }
361 
362   values.insert(pos, value);
363   return value;
364 }
365 
366 void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
367                                Value val) {
368   Location loc = val.getLoc();
369   Qualifier *question = boolNode->getQuestion();
370   Qualifier *answer = boolNode->getAnswer();
371   Region *region = currentBlock->getParent();
372 
373   // Execute the getValue queries first, so that we create success
374   // matcher in the correct (possibly nested) region.
375   SmallVector<Value> args;
376   if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
377     args = {getValueAt(currentBlock, equalToQuestion->getValue())};
378   } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
379     for (Position *position : cstQuestion->getArgs())
380       args.push_back(getValueAt(currentBlock, position));
381   }
382 
383   // Generate the matcher in the current (potentially nested) region
384   // and get the failure successor.
385   Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
386   Block *failure = failureBlockStack.back();
387 
388   // Finally, create the predicate.
389   builder.setInsertionPointToEnd(currentBlock);
390   Predicates::Kind kind = question->getKind();
391   switch (kind) {
392   case Predicates::IsNotNullQuestion:
393     builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
394     break;
395   case Predicates::OperationNameQuestion: {
396     auto *opNameAnswer = cast<OperationNameAnswer>(answer);
397     builder.create<pdl_interp::CheckOperationNameOp>(
398         loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
399     break;
400   }
401   case Predicates::TypeQuestion: {
402     auto *ans = cast<TypeAnswer>(answer);
403     if (val.getType().isa<pdl::RangeType>())
404       builder.create<pdl_interp::CheckTypesOp>(
405           loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
406     else
407       builder.create<pdl_interp::CheckTypeOp>(
408           loc, val, ans->getValue().cast<TypeAttr>(), success, failure);
409     break;
410   }
411   case Predicates::AttributeQuestion: {
412     auto *ans = cast<AttributeAnswer>(answer);
413     builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
414                                                  success, failure);
415     break;
416   }
417   case Predicates::OperandCountAtLeastQuestion:
418   case Predicates::OperandCountQuestion:
419     builder.create<pdl_interp::CheckOperandCountOp>(
420         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
421         /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
422         success, failure);
423     break;
424   case Predicates::ResultCountAtLeastQuestion:
425   case Predicates::ResultCountQuestion:
426     builder.create<pdl_interp::CheckResultCountOp>(
427         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
428         /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
429         success, failure);
430     break;
431   case Predicates::EqualToQuestion: {
432     bool trueAnswer = isa<TrueAnswer>(answer);
433     builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
434                                            trueAnswer ? success : failure,
435                                            trueAnswer ? failure : success);
436     break;
437   }
438   case Predicates::ConstraintQuestion: {
439     auto *cstQuestion = cast<ConstraintQuestion>(question);
440     builder.create<pdl_interp::ApplyConstraintOp>(loc, cstQuestion->getName(),
441                                                   args, success, failure);
442     break;
443   }
444   default:
445     llvm_unreachable("Generating unknown Predicate operation");
446   }
447 }
448 
449 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
450 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
451                            llvm::MapVector<Qualifier *, Block *> &dests) {
452   std::vector<ValT> values;
453   std::vector<Block *> blocks;
454   values.reserve(dests.size());
455   blocks.reserve(dests.size());
456   for (const auto &it : dests) {
457     blocks.push_back(it.second);
458     values.push_back(cast<PredT>(it.first)->getValue());
459   }
460   builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
461 }
462 
463 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
464                                Value val) {
465   Qualifier *question = switchNode->getQuestion();
466   Region *region = currentBlock->getParent();
467   Block *defaultDest = failureBlockStack.back();
468 
469   // If the switch question is not an exact answer, i.e. for the `at_least`
470   // cases, we generate a special block sequence.
471   Predicates::Kind kind = question->getKind();
472   if (kind == Predicates::OperandCountAtLeastQuestion ||
473       kind == Predicates::ResultCountAtLeastQuestion) {
474     // Order the children such that the cases are in reverse numerical order.
475     SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
476         llvm::seq<unsigned>(0, switchNode->getChildren().size()));
477     llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
478       return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
479              cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
480     });
481 
482     // Build the destination for each child using the next highest child as a
483     // a failure destination. This essentially creates the following control
484     // flow:
485     //
486     // if (operand_count < 1)
487     //   goto failure
488     // if (child1.match())
489     //   ...
490     //
491     // if (operand_count < 2)
492     //   goto failure
493     // if (child2.match())
494     //   ...
495     //
496     // failure:
497     //   ...
498     //
499     failureBlockStack.push_back(defaultDest);
500     Location loc = val.getLoc();
501     for (unsigned idx : sortedChildren) {
502       auto &child = switchNode->getChild(idx);
503       Block *childBlock = generateMatcher(*child.second, *region);
504       Block *predicateBlock = builder.createBlock(childBlock);
505       builder.setInsertionPointToEnd(predicateBlock);
506       unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
507       switch (kind) {
508       case Predicates::OperandCountAtLeastQuestion:
509         builder.create<pdl_interp::CheckOperandCountOp>(
510             loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
511         break;
512       case Predicates::ResultCountAtLeastQuestion:
513         builder.create<pdl_interp::CheckResultCountOp>(
514             loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
515         break;
516       default:
517         llvm_unreachable("Generating invalid AtLeast operation");
518       }
519       failureBlockStack.back() = predicateBlock;
520     }
521     Block *firstPredicateBlock = failureBlockStack.pop_back_val();
522     currentBlock->getOperations().splice(currentBlock->end(),
523                                          firstPredicateBlock->getOperations());
524     firstPredicateBlock->erase();
525     return;
526   }
527 
528   // Otherwise, generate each of the children and generate an interpreter
529   // switch.
530   llvm::MapVector<Qualifier *, Block *> children;
531   for (auto &it : switchNode->getChildren())
532     children.insert({it.first, generateMatcher(*it.second, *region)});
533   builder.setInsertionPointToEnd(currentBlock);
534 
535   switch (question->getKind()) {
536   case Predicates::OperandCountQuestion:
537     return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
538                           int32_t>(val, defaultDest, builder, children);
539   case Predicates::ResultCountQuestion:
540     return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
541                           int32_t>(val, defaultDest, builder, children);
542   case Predicates::OperationNameQuestion:
543     return createSwitchOp<pdl_interp::SwitchOperationNameOp,
544                           OperationNameAnswer>(val, defaultDest, builder,
545                                                children);
546   case Predicates::TypeQuestion:
547     if (val.getType().isa<pdl::RangeType>()) {
548       return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
549           val, defaultDest, builder, children);
550     }
551     return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
552         val, defaultDest, builder, children);
553   case Predicates::AttributeQuestion:
554     return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
555         val, defaultDest, builder, children);
556   default:
557     llvm_unreachable("Generating unknown switch predicate.");
558   }
559 }
560 
561 void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
562   pdl::PatternOp pattern = successNode->getPattern();
563   Value root = successNode->getRoot();
564 
565   // Generate a rewriter for the pattern this success node represents, and track
566   // any values used from the match region.
567   SmallVector<Position *, 8> usedMatchValues;
568   SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
569 
570   // Process any values used in the rewrite that are defined in the match.
571   std::vector<Value> mappedMatchValues;
572   mappedMatchValues.reserve(usedMatchValues.size());
573   for (Position *position : usedMatchValues)
574     mappedMatchValues.push_back(getValueAt(currentBlock, position));
575 
576   // Collect the set of operations generated by the rewriter.
577   SmallVector<StringRef, 4> generatedOps;
578   for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
579     generatedOps.push_back(*op.name());
580   ArrayAttr generatedOpsAttr;
581   if (!generatedOps.empty())
582     generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
583 
584   // Grab the root kind if present.
585   StringAttr rootKindAttr;
586   if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
587     if (Optional<StringRef> rootKind = rootOp.name())
588       rootKindAttr = builder.getStringAttr(*rootKind);
589 
590   builder.setInsertionPointToEnd(currentBlock);
591   builder.create<pdl_interp::RecordMatchOp>(
592       pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
593       rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
594       failureBlockStack.back());
595 }
596 
597 SymbolRefAttr PatternLowering::generateRewriter(
598     pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
599   builder.setInsertionPointToEnd(rewriterModule.getBody());
600   auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
601       pattern.getLoc(), "pdl_generated_rewriter",
602       builder.getFunctionType(llvm::None, llvm::None));
603   rewriterSymbolTable.insert(rewriterFunc);
604 
605   // Generate the rewriter function body.
606   builder.setInsertionPointToEnd(&rewriterFunc.front());
607 
608   // Map an input operand of the pattern to a generated interpreter value.
609   DenseMap<Value, Value> rewriteValues;
610   auto mapRewriteValue = [&](Value oldValue) {
611     Value &newValue = rewriteValues[oldValue];
612     if (newValue)
613       return newValue;
614 
615     // Prefer materializing constants directly when possible.
616     Operation *oldOp = oldValue.getDefiningOp();
617     if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
618       if (Attribute value = attrOp.valueAttr()) {
619         return newValue = builder.create<pdl_interp::CreateAttributeOp>(
620                    attrOp.getLoc(), value);
621       }
622     } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
623       if (TypeAttr type = typeOp.typeAttr()) {
624         return newValue = builder.create<pdl_interp::CreateTypeOp>(
625                    typeOp.getLoc(), type);
626       }
627     } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
628       if (ArrayAttr type = typeOp.typesAttr()) {
629         return newValue = builder.create<pdl_interp::CreateTypesOp>(
630                    typeOp.getLoc(), typeOp.getType(), type);
631       }
632     }
633 
634     // Otherwise, add this as an input to the rewriter.
635     Position *inputPos = valueToPosition.lookup(oldValue);
636     assert(inputPos && "expected value to be a pattern input");
637     usedMatchValues.push_back(inputPos);
638     return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
639                                                        oldValue.getLoc());
640   };
641 
642   // If this is a custom rewriter, simply dispatch to the registered rewrite
643   // method.
644   pdl::RewriteOp rewriter = pattern.getRewriter();
645   if (StringAttr rewriteName = rewriter.nameAttr()) {
646     SmallVector<Value> args;
647     if (rewriter.root())
648       args.push_back(mapRewriteValue(rewriter.root()));
649     auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
650     args.append(mappedArgs.begin(), mappedArgs.end());
651     builder.create<pdl_interp::ApplyRewriteOp>(
652         rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
653   } else {
654     // Otherwise this is a dag rewriter defined using PDL operations.
655     for (Operation &rewriteOp : *rewriter.getBody()) {
656       llvm::TypeSwitch<Operation *>(&rewriteOp)
657           .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
658                 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
659                 pdl::TypeOp, pdl::TypesOp>([&](auto op) {
660             this->generateRewriter(op, rewriteValues, mapRewriteValue);
661           });
662     }
663   }
664 
665   // Update the signature of the rewrite function.
666   rewriterFunc.setType(builder.getFunctionType(
667       llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
668       /*results=*/llvm::None));
669 
670   builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
671   return SymbolRefAttr::get(
672       builder.getContext(),
673       pdl_interp::PDLInterpDialect::getRewriterModuleName(),
674       SymbolRefAttr::get(rewriterFunc));
675 }
676 
677 void PatternLowering::generateRewriter(
678     pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
679     function_ref<Value(Value)> mapRewriteValue) {
680   SmallVector<Value, 2> arguments;
681   for (Value argument : rewriteOp.args())
682     arguments.push_back(mapRewriteValue(argument));
683   auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
684       rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
685       arguments);
686   for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
687     rewriteValues[std::get<0>(it)] = std::get<1>(it);
688 }
689 
690 void PatternLowering::generateRewriter(
691     pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
692     function_ref<Value(Value)> mapRewriteValue) {
693   Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
694       attrOp.getLoc(), attrOp.valueAttr());
695   rewriteValues[attrOp] = newAttr;
696 }
697 
698 void PatternLowering::generateRewriter(
699     pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
700     function_ref<Value(Value)> mapRewriteValue) {
701   builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
702                                       mapRewriteValue(eraseOp.operation()));
703 }
704 
705 void PatternLowering::generateRewriter(
706     pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
707     function_ref<Value(Value)> mapRewriteValue) {
708   SmallVector<Value, 4> operands;
709   for (Value operand : operationOp.operands())
710     operands.push_back(mapRewriteValue(operand));
711 
712   SmallVector<Value, 4> attributes;
713   for (Value attr : operationOp.attributes())
714     attributes.push_back(mapRewriteValue(attr));
715 
716   bool hasInferredResultTypes = false;
717   SmallVector<Value, 2> types;
718   generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
719                                       rewriteValues, hasInferredResultTypes);
720 
721   // Create the new operation.
722   Location loc = operationOp.getLoc();
723   Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
724       loc, *operationOp.name(), types, hasInferredResultTypes, operands,
725       attributes, operationOp.attributeNames());
726   rewriteValues[operationOp.op()] = createdOp;
727 
728   // Generate accesses for any results that have their types constrained.
729   // Handle the case where there is a single range representing all of the
730   // result types.
731   OperandRange resultTys = operationOp.types();
732   if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
733     Value &type = rewriteValues[resultTys[0]];
734     if (!type) {
735       auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
736       type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
737     }
738     return;
739   }
740 
741   // Otherwise, populate the individual results.
742   bool seenVariableLength = false;
743   Type valueTy = builder.getType<pdl::ValueType>();
744   Type valueRangeTy = pdl::RangeType::get(valueTy);
745   for (const auto &it : llvm::enumerate(resultTys)) {
746     Value &type = rewriteValues[it.value()];
747     if (type)
748       continue;
749     bool isVariadic = it.value().getType().isa<pdl::RangeType>();
750     seenVariableLength |= isVariadic;
751 
752     // After a variable length result has been seen, we need to use result
753     // groups because the exact index of the result is not statically known.
754     Value resultVal;
755     if (seenVariableLength)
756       resultVal = builder.create<pdl_interp::GetResultsOp>(
757           loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
758     else
759       resultVal = builder.create<pdl_interp::GetResultOp>(
760           loc, valueTy, createdOp, it.index());
761     type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
762   }
763 }
764 
765 void PatternLowering::generateRewriter(
766     pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
767     function_ref<Value(Value)> mapRewriteValue) {
768   SmallVector<Value, 4> replOperands;
769 
770   // If the replacement was another operation, get its results. `pdl` allows
771   // for using an operation for simplicitly, but the interpreter isn't as
772   // user facing.
773   if (Value replOp = replaceOp.replOperation()) {
774     // Don't use replace if we know the replaced operation has no results.
775     auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
776     if (!opOp || !opOp.types().empty()) {
777       replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
778           replOp.getLoc(), mapRewriteValue(replOp)));
779     }
780   } else {
781     for (Value operand : replaceOp.replValues())
782       replOperands.push_back(mapRewriteValue(operand));
783   }
784 
785   // If there are no replacement values, just create an erase instead.
786   if (replOperands.empty()) {
787     builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
788                                         mapRewriteValue(replaceOp.operation()));
789     return;
790   }
791 
792   builder.create<pdl_interp::ReplaceOp>(
793       replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
794 }
795 
796 void PatternLowering::generateRewriter(
797     pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
798     function_ref<Value(Value)> mapRewriteValue) {
799   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
800       resultOp.getLoc(), builder.getType<pdl::ValueType>(),
801       mapRewriteValue(resultOp.parent()), resultOp.index());
802 }
803 
804 void PatternLowering::generateRewriter(
805     pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
806     function_ref<Value(Value)> mapRewriteValue) {
807   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
808       resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()),
809       resultOp.index());
810 }
811 
812 void PatternLowering::generateRewriter(
813     pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
814     function_ref<Value(Value)> mapRewriteValue) {
815   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
816   // type.
817   if (TypeAttr typeAttr = typeOp.typeAttr()) {
818     rewriteValues[typeOp] =
819         builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
820   }
821 }
822 
823 void PatternLowering::generateRewriter(
824     pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
825     function_ref<Value(Value)> mapRewriteValue) {
826   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
827   // type.
828   if (ArrayAttr typeAttr = typeOp.typesAttr()) {
829     rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
830         typeOp.getLoc(), typeOp.getType(), typeAttr);
831   }
832 }
833 
/// Resolve the result types to use when creating `op` in the rewriter.
///
/// Strategies are tried in order:
///   1. Resolve every explicit result type individually, using an existing
///      rewrite value or a value mapped in from the matcher. On success the
///      resolved values are appended to `types`.
///   2. If `op` supports type inference, set `hasInferredResultTypes` and
///      leave `types` without explicit entries.
///   3. If `op` is the replacement of a `pdl.replace` whose replaced
///      operation is defined before `op` (or in the matcher), use the types
///      of that replaced operation's results. They are appended to `types`
///      as a single type range.
/// If none apply and `op` has no explicit result types, it is treated as
/// having no results. Reaching the end with unresolved explicit result types
/// indicates a bug and aborts.
void PatternLowering::generateOperationResultTypeRewriter(
    pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
    SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
    bool &hasInferredResultTypes) {
  Block *rewriterBlock = op->getBlock();

  // Try to handle resolution for each of the result types individually. This is
  // preferred over type inference because it will allow for us to use existing
  // types directly, as opposed to trying to rebuild the type list.
  OperandRange resultTypeValues = op.types();
  auto tryResolveResultTypes = [&] {
    types.reserve(resultTypeValues.size());
    for (const auto &it : llvm::enumerate(resultTypeValues)) {
      Value resultType = it.value();

      // Check for an already translated value.
      if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
        types.push_back(existingRewriteValue);
        continue;
      }

      // Check for an input from the matcher: a type defined outside the
      // rewriter block is mapped in as a rewriter input.
      if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
        types.push_back(mapRewriteValue(resultType));
        continue;
      }

      // Otherwise, we couldn't resolve this result type. Bail out here (after
      // discarding any partially collected types) so that the types for this
      // operation can be inferred in another way.
      types.clear();
      return failure();
    }
    return success();
  };
  if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
    return;

  // Otherwise, check if the operation has type inference support itself.
  if (op.hasTypeInference()) {
    hasInferredResultTypes = true;
    return;
  }

  // Look for an operation that was replaced by `op`. The result types will be
  // inferred from the results that were replaced.
  for (OpOperand &use : op.op().getUses()) {
    // Check that the use corresponds to a ReplaceOp and that it is the
    // replacement value, not the operation being replaced (operand 0).
    pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
    if (!replOpUser || use.getOperandNumber() == 0)
      continue;
    // Make sure the replaced operation was defined before this one. PDL
    // rewrites only have single block regions, so if the op isn't in the
    // rewriter block (i.e. the current block of the operation) we already know
    // it dominates (i.e. it's in the matcher).
    Value replOpVal = replOpUser.operation();
    Operation *replacedOp = replOpVal.getDefiningOp();
    if (replacedOp->getBlock() == rewriterBlock &&
        !replacedOp->isBeforeInBlock(op))
      continue;

    // Use the types of all results of the replaced operation as one range.
    Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
        replacedOp->getLoc(), mapRewriteValue(replOpVal));
    types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
        replacedOp->getLoc(), replacedOpResults));
    return;
  }

  // If the types could not be inferred from any context and there weren't any
  // explicit result types, assume the user actually meant for the operation to
  // have no results.
  if (resultTypeValues.empty())
    return;

  // The verifier asserts that the result types of each pdl.operation can be
  // inferred. If we reach here, there is a bug either in the logic above or
  // in the verifier for pdl.operation.
  op->emitOpError() << "unable to infer result type for operation";
  llvm_unreachable("unable to infer result type for operation");
}
914 
915 //===----------------------------------------------------------------------===//
916 // Conversion Pass
917 //===----------------------------------------------------------------------===//
918 
919 namespace {
920 struct PDLToPDLInterpPass
921     : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
922   void runOnOperation() final;
923 };
924 } // namespace
925 
926 /// Convert the given module containing PDL pattern operations into a PDL
927 /// Interpreter operations.
928 void PDLToPDLInterpPass::runOnOperation() {
929   ModuleOp module = getOperation();
930 
931   // Create the main matcher function This function contains all of the match
932   // related functionality from patterns in the module.
933   OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
934   auto matcherFunc = builder.create<pdl_interp::FuncOp>(
935       module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
936       builder.getFunctionType(builder.getType<pdl::OperationType>(),
937                               /*results=*/llvm::None),
938       /*attrs=*/llvm::None);
939 
940   // Create a nested module to hold the functions invoked for rewriting the IR
941   // after a successful match.
942   ModuleOp rewriterModule = builder.create<ModuleOp>(
943       module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
944 
945   // Generate the code for the patterns within the module.
946   PatternLowering generator(matcherFunc, rewriterModule);
947   generator.lower(module);
948 
949   // After generation, delete all of the pattern operations.
950   for (pdl::PatternOp pattern :
951        llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
952     pattern.erase();
953 }
954 
955 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
956   return std::make_unique<PDLToPDLInterpPass>();
957 }
958