xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp (revision 310c3ee4724435464db36148a30c40aaf89bcc1d)
1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10 
11 #include "PredicateTree.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::pdl_to_pdl_interp;
30 
31 //===----------------------------------------------------------------------===//
32 // PatternLowering
33 //===----------------------------------------------------------------------===//
34 
35 namespace {
36 /// This class generators operations within the PDL Interpreter dialect from a
37 /// given module containing PDL pattern operations.
38 struct PatternLowering {
39 public:
40   PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule);
41 
42   /// Generate code for matching and rewriting based on the pattern operations
43   /// within the module.
44   void lower(ModuleOp module);
45 
46 private:
47   using ValueMap = llvm::ScopedHashTable<Position *, Value>;
48   using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
49 
50   /// Generate interpreter operations for the tree rooted at the given matcher
51   /// node, in the specified region.
52   Block *generateMatcher(MatcherNode &node, Region &region);
53 
54   /// Get or create an access to the provided positional value in the current
55   /// block. This operation may mutate the provided block pointer if nested
56   /// regions (i.e., pdl_interp.iterate) are required.
57   Value getValueAt(Block *&currentBlock, Position *pos);
58 
59   /// Create the interpreter predicate operations. This operation may mutate the
60   /// provided current block pointer if nested regions (iterates) are required.
61   void generate(BoolNode *boolNode, Block *&currentBlock, Value val);
62 
63   /// Create the interpreter switch / predicate operations, with several case
64   /// destinations. This operation never mutates the provided current block
65   /// pointer, because the switch operation does not need Values beyond `val`.
66   void generate(SwitchNode *switchNode, Block *currentBlock, Value val);
67 
68   /// Create the interpreter operations to record a successful pattern match
69   /// using the contained root operation. This operation may mutate the current
70   /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
71   void generate(SuccessNode *successNode, Block *&currentBlock);
72 
73   /// Generate a rewriter function for the given pattern operation, and returns
74   /// a reference to that function.
75   SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
76                                  SmallVectorImpl<Position *> &usedMatchValues);
77 
78   /// Generate the rewriter code for the given operation.
79   void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
80                         DenseMap<Value, Value> &rewriteValues,
81                         function_ref<Value(Value)> mapRewriteValue);
82   void generateRewriter(pdl::AttributeOp attrOp,
83                         DenseMap<Value, Value> &rewriteValues,
84                         function_ref<Value(Value)> mapRewriteValue);
85   void generateRewriter(pdl::EraseOp eraseOp,
86                         DenseMap<Value, Value> &rewriteValues,
87                         function_ref<Value(Value)> mapRewriteValue);
88   void generateRewriter(pdl::OperationOp operationOp,
89                         DenseMap<Value, Value> &rewriteValues,
90                         function_ref<Value(Value)> mapRewriteValue);
91   void generateRewriter(pdl::ReplaceOp replaceOp,
92                         DenseMap<Value, Value> &rewriteValues,
93                         function_ref<Value(Value)> mapRewriteValue);
94   void generateRewriter(pdl::ResultOp resultOp,
95                         DenseMap<Value, Value> &rewriteValues,
96                         function_ref<Value(Value)> mapRewriteValue);
97   void generateRewriter(pdl::ResultsOp resultOp,
98                         DenseMap<Value, Value> &rewriteValues,
99                         function_ref<Value(Value)> mapRewriteValue);
100   void generateRewriter(pdl::TypeOp typeOp,
101                         DenseMap<Value, Value> &rewriteValues,
102                         function_ref<Value(Value)> mapRewriteValue);
103   void generateRewriter(pdl::TypesOp typeOp,
104                         DenseMap<Value, Value> &rewriteValues,
105                         function_ref<Value(Value)> mapRewriteValue);
106 
107   /// Generate the values used for resolving the result types of an operation
108   /// created within a dag rewriter region. If the result types of the operation
109   /// should be inferred, `hasInferredResultTypes` is set to true.
110   void generateOperationResultTypeRewriter(
111       pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
112       SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
113       bool &hasInferredResultTypes);
114 
115   /// A builder to use when generating interpreter operations.
116   OpBuilder builder;
117 
118   /// The matcher function used for all match related logic within PDL patterns.
119   pdl_interp::FuncOp matcherFunc;
120 
121   /// The rewriter module containing the all rewrite related logic within PDL
122   /// patterns.
123   ModuleOp rewriterModule;
124 
125   /// The symbol table of the rewriter module used for insertion.
126   SymbolTable rewriterSymbolTable;
127 
128   /// A scoped map connecting a position with the corresponding interpreter
129   /// value.
130   ValueMap values;
131 
132   /// A stack of blocks used as the failure destination for matcher nodes that
133   /// don't have an explicit failure path.
134   SmallVector<Block *, 8> failureBlockStack;
135 
136   /// A mapping between values defined in a pattern match, and the corresponding
137   /// positional value.
138   DenseMap<Value, Position *> valueToPosition;
139 
140   /// The set of operation values whose whose location will be used for newly
141   /// generated operations.
142   SetVector<Value> locOps;
143 };
144 } // namespace
145 
146 PatternLowering::PatternLowering(pdl_interp::FuncOp matcherFunc,
147                                  ModuleOp rewriterModule)
148     : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
149       rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
150 
151 void PatternLowering::lower(ModuleOp module) {
152   PredicateUniquer predicateUniquer;
153   PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
154 
155   // Define top-level scope for the arguments to the matcher function.
156   ValueMapScope topLevelValueScope(values);
157 
158   // Insert the root operation, i.e. argument to the matcher, at the root
159   // position.
160   Block *matcherEntryBlock = &matcherFunc.front();
161   values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
162 
163   // Generate a root matcher node from the provided PDL module.
164   std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
165       module, predicateBuilder, valueToPosition);
166   Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
167   assert(failureBlockStack.empty() && "failed to empty the stack");
168 
169   // After generation, merged the first matched block into the entry.
170   matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
171                                             firstMatcherBlock->getOperations());
172   firstMatcherBlock->erase();
173 }
174 
175 Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
176   // Push a new scope for the values used by this matcher.
177   Block *block = &region.emplaceBlock();
178   ValueMapScope scope(values);
179 
180   // If this is the return node, simply insert the corresponding interpreter
181   // finalize.
182   if (isa<ExitNode>(node)) {
183     builder.setInsertionPointToEnd(block);
184     builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
185     return block;
186   }
187 
188   // Get the next block in the match sequence.
189   // This is intentionally executed first, before we get the value for the
190   // position associated with the node, so that we preserve an "there exist"
191   // semantics: if getting a value requires an upward traversal (going from a
192   // value to its consumers), we want to perform the check on all the consumers
193   // before we pass control to the failure node.
194   std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
195   Block *failureBlock;
196   if (failureNode) {
197     failureBlock = generateMatcher(*failureNode, region);
198     failureBlockStack.push_back(failureBlock);
199   } else {
200     assert(!failureBlockStack.empty() && "expected valid failure block");
201     failureBlock = failureBlockStack.back();
202   }
203 
204   // If this node contains a position, get the corresponding value for this
205   // block.
206   Block *currentBlock = block;
207   Position *position = node.getPosition();
208   Value val = position ? getValueAt(currentBlock, position) : Value();
209 
210   // If this value corresponds to an operation, record that we are going to use
211   // its location as part of a fused location.
212   bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
213   if (isOperationValue)
214     locOps.insert(val);
215 
216   // Dispatch to the correct method based on derived node type.
217   TypeSwitch<MatcherNode *>(&node)
218       .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
219         this->generate(derivedNode, currentBlock, val);
220       })
221       .Case([&](SuccessNode *successNode) {
222         generate(successNode, currentBlock);
223       });
224 
225   // Pop all the failure blocks that were inserted due to nesting of
226   // pdl_interp.iterate.
227   while (failureBlockStack.back() != failureBlock) {
228     failureBlockStack.pop_back();
229     assert(!failureBlockStack.empty() && "unable to locate failure block");
230   }
231 
232   // Pop the new failure block.
233   if (failureNode)
234     failureBlockStack.pop_back();
235 
236   if (isOperationValue)
237     locOps.remove(val);
238 
239   return block;
240 }
241 
242 Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
243   if (Value val = values.lookup(pos))
244     return val;
245 
246   // Get the value for the parent position.
247   Value parentVal;
248   if (Position *parent = pos->getParent())
249     parentVal = getValueAt(currentBlock, parent);
250 
251   // TODO: Use a location from the position.
252   Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
253   builder.setInsertionPointToEnd(currentBlock);
254   Value value;
255   switch (pos->getKind()) {
256   case Predicates::OperationPos: {
257     auto *operationPos = cast<OperationPosition>(pos);
258     if (operationPos->isOperandDefiningOp())
259       // Standard (downward) traversal which directly follows the defining op.
260       value = builder.create<pdl_interp::GetDefiningOpOp>(
261           loc, builder.getType<pdl::OperationType>(), parentVal);
262     else
263       // A passthrough operation position.
264       value = parentVal;
265     break;
266   }
267   case Predicates::UsersPos: {
268     auto *usersPos = cast<UsersPosition>(pos);
269 
270     // The first operation retrieves the representative value of a range.
271     // This applies only when the parent is a range of values and we were
272     // requested to use a representative value (e.g., upward traversal).
273     if (parentVal.getType().isa<pdl::RangeType>() &&
274         usersPos->useRepresentative())
275       value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
276     else
277       value = parentVal;
278 
279     // The second operation retrieves the users.
280     value = builder.create<pdl_interp::GetUsersOp>(loc, value);
281     break;
282   }
283   case Predicates::ForEachPos: {
284     assert(!failureBlockStack.empty() && "expected valid failure block");
285     auto foreach = builder.create<pdl_interp::ForEachOp>(
286         loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
287     value = foreach.getLoopVariable();
288 
289     // Create the continuation block.
290     Block *continueBlock = builder.createBlock(&foreach.getRegion());
291     builder.create<pdl_interp::ContinueOp>(loc);
292     failureBlockStack.push_back(continueBlock);
293 
294     currentBlock = &foreach.getRegion().front();
295     break;
296   }
297   case Predicates::OperandPos: {
298     auto *operandPos = cast<OperandPosition>(pos);
299     value = builder.create<pdl_interp::GetOperandOp>(
300         loc, builder.getType<pdl::ValueType>(), parentVal,
301         operandPos->getOperandNumber());
302     break;
303   }
304   case Predicates::OperandGroupPos: {
305     auto *operandPos = cast<OperandGroupPosition>(pos);
306     Type valueTy = builder.getType<pdl::ValueType>();
307     value = builder.create<pdl_interp::GetOperandsOp>(
308         loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
309         parentVal, operandPos->getOperandGroupNumber());
310     break;
311   }
312   case Predicates::AttributePos: {
313     auto *attrPos = cast<AttributePosition>(pos);
314     value = builder.create<pdl_interp::GetAttributeOp>(
315         loc, builder.getType<pdl::AttributeType>(), parentVal,
316         attrPos->getName().strref());
317     break;
318   }
319   case Predicates::TypePos: {
320     if (parentVal.getType().isa<pdl::AttributeType>())
321       value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
322     else
323       value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
324     break;
325   }
326   case Predicates::ResultPos: {
327     auto *resPos = cast<ResultPosition>(pos);
328     value = builder.create<pdl_interp::GetResultOp>(
329         loc, builder.getType<pdl::ValueType>(), parentVal,
330         resPos->getResultNumber());
331     break;
332   }
333   case Predicates::ResultGroupPos: {
334     auto *resPos = cast<ResultGroupPosition>(pos);
335     Type valueTy = builder.getType<pdl::ValueType>();
336     value = builder.create<pdl_interp::GetResultsOp>(
337         loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
338         parentVal, resPos->getResultGroupNumber());
339     break;
340   }
341   case Predicates::AttributeLiteralPos: {
342     auto *attrPos = cast<AttributeLiteralPosition>(pos);
343     value =
344         builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
345     break;
346   }
347   case Predicates::TypeLiteralPos: {
348     auto *typePos = cast<TypeLiteralPosition>(pos);
349     Attribute rawTypeAttr = typePos->getValue();
350     if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>())
351       value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
352     else
353       value = builder.create<pdl_interp::CreateTypesOp>(
354           loc, rawTypeAttr.cast<ArrayAttr>());
355     break;
356   }
357   default:
358     llvm_unreachable("Generating unknown Position getter");
359     break;
360   }
361 
362   values.insert(pos, value);
363   return value;
364 }
365 
366 void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
367                                Value val) {
368   Location loc = val.getLoc();
369   Qualifier *question = boolNode->getQuestion();
370   Qualifier *answer = boolNode->getAnswer();
371   Region *region = currentBlock->getParent();
372 
373   // Execute the getValue queries first, so that we create success
374   // matcher in the correct (possibly nested) region.
375   SmallVector<Value> args;
376   if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
377     args = {getValueAt(currentBlock, equalToQuestion->getValue())};
378   } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
379     for (Position *position : cstQuestion->getArgs())
380       args.push_back(getValueAt(currentBlock, position));
381   }
382 
383   // Generate the matcher in the current (potentially nested) region
384   // and get the failure successor.
385   Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
386   Block *failure = failureBlockStack.back();
387 
388   // Finally, create the predicate.
389   builder.setInsertionPointToEnd(currentBlock);
390   Predicates::Kind kind = question->getKind();
391   switch (kind) {
392   case Predicates::IsNotNullQuestion:
393     builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
394     break;
395   case Predicates::OperationNameQuestion: {
396     auto *opNameAnswer = cast<OperationNameAnswer>(answer);
397     builder.create<pdl_interp::CheckOperationNameOp>(
398         loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
399     break;
400   }
401   case Predicates::TypeQuestion: {
402     auto *ans = cast<TypeAnswer>(answer);
403     if (val.getType().isa<pdl::RangeType>())
404       builder.create<pdl_interp::CheckTypesOp>(
405           loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
406     else
407       builder.create<pdl_interp::CheckTypeOp>(
408           loc, val, ans->getValue().cast<TypeAttr>(), success, failure);
409     break;
410   }
411   case Predicates::AttributeQuestion: {
412     auto *ans = cast<AttributeAnswer>(answer);
413     builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
414                                                  success, failure);
415     break;
416   }
417   case Predicates::OperandCountAtLeastQuestion:
418   case Predicates::OperandCountQuestion:
419     builder.create<pdl_interp::CheckOperandCountOp>(
420         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
421         /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
422         success, failure);
423     break;
424   case Predicates::ResultCountAtLeastQuestion:
425   case Predicates::ResultCountQuestion:
426     builder.create<pdl_interp::CheckResultCountOp>(
427         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
428         /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
429         success, failure);
430     break;
431   case Predicates::EqualToQuestion: {
432     bool trueAnswer = isa<TrueAnswer>(answer);
433     builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
434                                            trueAnswer ? success : failure,
435                                            trueAnswer ? failure : success);
436     break;
437   }
438   case Predicates::ConstraintQuestion: {
439     auto *cstQuestion = cast<ConstraintQuestion>(question);
440     builder.create<pdl_interp::ApplyConstraintOp>(loc, cstQuestion->getName(),
441                                                   args, success, failure);
442     break;
443   }
444   default:
445     llvm_unreachable("Generating unknown Predicate operation");
446   }
447 }
448 
449 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
450 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
451                            llvm::MapVector<Qualifier *, Block *> &dests) {
452   std::vector<ValT> values;
453   std::vector<Block *> blocks;
454   values.reserve(dests.size());
455   blocks.reserve(dests.size());
456   for (const auto &it : dests) {
457     blocks.push_back(it.second);
458     values.push_back(cast<PredT>(it.first)->getValue());
459   }
460   builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
461 }
462 
463 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
464                                Value val) {
465   Qualifier *question = switchNode->getQuestion();
466   Region *region = currentBlock->getParent();
467   Block *defaultDest = failureBlockStack.back();
468 
469   // If the switch question is not an exact answer, i.e. for the `at_least`
470   // cases, we generate a special block sequence.
471   Predicates::Kind kind = question->getKind();
472   if (kind == Predicates::OperandCountAtLeastQuestion ||
473       kind == Predicates::ResultCountAtLeastQuestion) {
474     // Order the children such that the cases are in reverse numerical order.
475     SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
476         llvm::seq<unsigned>(0, switchNode->getChildren().size()));
477     llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
478       return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
479              cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
480     });
481 
482     // Build the destination for each child using the next highest child as a
483     // a failure destination. This essentially creates the following control
484     // flow:
485     //
486     // if (operand_count < 1)
487     //   goto failure
488     // if (child1.match())
489     //   ...
490     //
491     // if (operand_count < 2)
492     //   goto failure
493     // if (child2.match())
494     //   ...
495     //
496     // failure:
497     //   ...
498     //
499     failureBlockStack.push_back(defaultDest);
500     Location loc = val.getLoc();
501     for (unsigned idx : sortedChildren) {
502       auto &child = switchNode->getChild(idx);
503       Block *childBlock = generateMatcher(*child.second, *region);
504       Block *predicateBlock = builder.createBlock(childBlock);
505       builder.setInsertionPointToEnd(predicateBlock);
506       unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
507       switch (kind) {
508       case Predicates::OperandCountAtLeastQuestion:
509         builder.create<pdl_interp::CheckOperandCountOp>(
510             loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
511         break;
512       case Predicates::ResultCountAtLeastQuestion:
513         builder.create<pdl_interp::CheckResultCountOp>(
514             loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
515         break;
516       default:
517         llvm_unreachable("Generating invalid AtLeast operation");
518       }
519       failureBlockStack.back() = predicateBlock;
520     }
521     Block *firstPredicateBlock = failureBlockStack.pop_back_val();
522     currentBlock->getOperations().splice(currentBlock->end(),
523                                          firstPredicateBlock->getOperations());
524     firstPredicateBlock->erase();
525     return;
526   }
527 
528   // Otherwise, generate each of the children and generate an interpreter
529   // switch.
530   llvm::MapVector<Qualifier *, Block *> children;
531   for (auto &it : switchNode->getChildren())
532     children.insert({it.first, generateMatcher(*it.second, *region)});
533   builder.setInsertionPointToEnd(currentBlock);
534 
535   switch (question->getKind()) {
536   case Predicates::OperandCountQuestion:
537     return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
538                           int32_t>(val, defaultDest, builder, children);
539   case Predicates::ResultCountQuestion:
540     return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
541                           int32_t>(val, defaultDest, builder, children);
542   case Predicates::OperationNameQuestion:
543     return createSwitchOp<pdl_interp::SwitchOperationNameOp,
544                           OperationNameAnswer>(val, defaultDest, builder,
545                                                children);
546   case Predicates::TypeQuestion:
547     if (val.getType().isa<pdl::RangeType>()) {
548       return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
549           val, defaultDest, builder, children);
550     }
551     return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
552         val, defaultDest, builder, children);
553   case Predicates::AttributeQuestion:
554     return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
555         val, defaultDest, builder, children);
556   default:
557     llvm_unreachable("Generating unknown switch predicate.");
558   }
559 }
560 
561 void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
562   pdl::PatternOp pattern = successNode->getPattern();
563   Value root = successNode->getRoot();
564 
565   // Generate a rewriter for the pattern this success node represents, and track
566   // any values used from the match region.
567   SmallVector<Position *, 8> usedMatchValues;
568   SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
569 
570   // Process any values used in the rewrite that are defined in the match.
571   std::vector<Value> mappedMatchValues;
572   mappedMatchValues.reserve(usedMatchValues.size());
573   for (Position *position : usedMatchValues)
574     mappedMatchValues.push_back(getValueAt(currentBlock, position));
575 
576   // Collect the set of operations generated by the rewriter.
577   SmallVector<StringRef, 4> generatedOps;
578   for (auto op :
579        pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
580     generatedOps.push_back(*op.getOpName());
581   ArrayAttr generatedOpsAttr;
582   if (!generatedOps.empty())
583     generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
584 
585   // Grab the root kind if present.
586   StringAttr rootKindAttr;
587   if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
588     if (Optional<StringRef> rootKind = rootOp.getOpName())
589       rootKindAttr = builder.getStringAttr(*rootKind);
590 
591   builder.setInsertionPointToEnd(currentBlock);
592   builder.create<pdl_interp::RecordMatchOp>(
593       pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
594       rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
595       failureBlockStack.back());
596 }
597 
598 SymbolRefAttr PatternLowering::generateRewriter(
599     pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
600   builder.setInsertionPointToEnd(rewriterModule.getBody());
601   auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
602       pattern.getLoc(), "pdl_generated_rewriter",
603       builder.getFunctionType(llvm::None, llvm::None));
604   rewriterSymbolTable.insert(rewriterFunc);
605 
606   // Generate the rewriter function body.
607   builder.setInsertionPointToEnd(&rewriterFunc.front());
608 
609   // Map an input operand of the pattern to a generated interpreter value.
610   DenseMap<Value, Value> rewriteValues;
611   auto mapRewriteValue = [&](Value oldValue) {
612     Value &newValue = rewriteValues[oldValue];
613     if (newValue)
614       return newValue;
615 
616     // Prefer materializing constants directly when possible.
617     Operation *oldOp = oldValue.getDefiningOp();
618     if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
619       if (Attribute value = attrOp.getValueAttr()) {
620         return newValue = builder.create<pdl_interp::CreateAttributeOp>(
621                    attrOp.getLoc(), value);
622       }
623     } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
624       if (TypeAttr type = typeOp.getConstantTypeAttr()) {
625         return newValue = builder.create<pdl_interp::CreateTypeOp>(
626                    typeOp.getLoc(), type);
627       }
628     } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
629       if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
630         return newValue = builder.create<pdl_interp::CreateTypesOp>(
631                    typeOp.getLoc(), typeOp.getType(), type);
632       }
633     }
634 
635     // Otherwise, add this as an input to the rewriter.
636     Position *inputPos = valueToPosition.lookup(oldValue);
637     assert(inputPos && "expected value to be a pattern input");
638     usedMatchValues.push_back(inputPos);
639     return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
640                                                        oldValue.getLoc());
641   };
642 
643   // If this is a custom rewriter, simply dispatch to the registered rewrite
644   // method.
645   pdl::RewriteOp rewriter = pattern.getRewriter();
646   if (StringAttr rewriteName = rewriter.getNameAttr()) {
647     SmallVector<Value> args;
648     if (rewriter.getRoot())
649       args.push_back(mapRewriteValue(rewriter.getRoot()));
650     auto mappedArgs =
651         llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
652     args.append(mappedArgs.begin(), mappedArgs.end());
653     builder.create<pdl_interp::ApplyRewriteOp>(
654         rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
655   } else {
656     // Otherwise this is a dag rewriter defined using PDL operations.
657     for (Operation &rewriteOp : *rewriter.getBody()) {
658       llvm::TypeSwitch<Operation *>(&rewriteOp)
659           .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
660                 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
661                 pdl::TypeOp, pdl::TypesOp>([&](auto op) {
662             this->generateRewriter(op, rewriteValues, mapRewriteValue);
663           });
664     }
665   }
666 
667   // Update the signature of the rewrite function.
668   rewriterFunc.setType(builder.getFunctionType(
669       llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
670       /*results=*/llvm::None));
671 
672   builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
673   return SymbolRefAttr::get(
674       builder.getContext(),
675       pdl_interp::PDLInterpDialect::getRewriterModuleName(),
676       SymbolRefAttr::get(rewriterFunc));
677 }
678 
679 void PatternLowering::generateRewriter(
680     pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
681     function_ref<Value(Value)> mapRewriteValue) {
682   SmallVector<Value, 2> arguments;
683   for (Value argument : rewriteOp.getArgs())
684     arguments.push_back(mapRewriteValue(argument));
685   auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
686       rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(),
687       arguments);
688   for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
689     rewriteValues[std::get<0>(it)] = std::get<1>(it);
690 }
691 
692 void PatternLowering::generateRewriter(
693     pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
694     function_ref<Value(Value)> mapRewriteValue) {
695   Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
696       attrOp.getLoc(), attrOp.getValueAttr());
697   rewriteValues[attrOp] = newAttr;
698 }
699 
700 void PatternLowering::generateRewriter(
701     pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
702     function_ref<Value(Value)> mapRewriteValue) {
703   builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
704                                       mapRewriteValue(eraseOp.getOpValue()));
705 }
706 
/// Lower a rewriter-side `pdl.operation` into `pdl_interp.create_operation`.
///
/// Operand and attribute values are translated through `mapRewriteValue`.
/// Result types are resolved by `generateOperationResultTypeRewriter`, which
/// either fills `types` or sets `hasInferredResultTypes`. The created operation
/// is recorded as the translation of the PDL operation value.
///
/// Afterwards, every result type value of the PDL op that still has no
/// translation is bound to the type of the corresponding result of the created
/// op. This lets later users of those type values refer to them.
void PatternLowering::generateRewriter(
    pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  SmallVector<Value, 4> operands;
  for (Value operand : operationOp.getOperandValues())
    operands.push_back(mapRewriteValue(operand));

  SmallVector<Value, 4> attributes;
  for (Value attr : operationOp.getAttributeValues())
    attributes.push_back(mapRewriteValue(attr));

  bool hasInferredResultTypes = false;
  SmallVector<Value, 2> types;
  generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
                                      rewriteValues, hasInferredResultTypes);

  // Create the new operation.
  // NOTE(review): the op name is dereferenced unconditionally. Presumably the
  // PDL verifier guarantees that operations in a rewrite region are named;
  // confirm.
  Location loc = operationOp.getLoc();
  Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
      loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
      attributes, operationOp.getAttributeValueNames());
  rewriteValues[operationOp.getOp()] = createdOp;

  // Generate accesses for any results that have their types constrained.
  // Handle the case where there is a single range representing all of the
  // result types.
  OperandRange resultTys = operationOp.getTypeValues();
  if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
    // The reference into `rewriteValues` stays valid across the builder calls
    // below because creating ops does not modify the map.
    Value &type = rewriteValues[resultTys[0]];
    if (!type) {
      auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
      type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
    }
    return;
  }

  // Otherwise, populate the individual results.
  bool seenVariableLength = false;
  Type valueTy = builder.getType<pdl::ValueType>();
  Type valueRangeTy = pdl::RangeType::get(valueTy);
  for (const auto &it : llvm::enumerate(resultTys)) {
    // Type values that already have a translation are left untouched.
    Value &type = rewriteValues[it.value()];
    if (type)
      continue;
    bool isVariadic = it.value().getType().isa<pdl::RangeType>();
    seenVariableLength |= isVariadic;

    // After a variable length result has been seen, we need to use result
    // groups because the exact index of the result is not statically known.
    // A non-variadic group still yields a single `pdl.value`, hence the
    // `valueTy` choice for it.
    Value resultVal;
    if (seenVariableLength)
      resultVal = builder.create<pdl_interp::GetResultsOp>(
          loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
    else
      resultVal = builder.create<pdl_interp::GetResultOp>(
          loc, valueTy, createdOp, it.index());
    type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
  }
}
766 
767 void PatternLowering::generateRewriter(
768     pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
769     function_ref<Value(Value)> mapRewriteValue) {
770   SmallVector<Value, 4> replOperands;
771 
772   // If the replacement was another operation, get its results. `pdl` allows
773   // for using an operation for simplicitly, but the interpreter isn't as
774   // user facing.
775   if (Value replOp = replaceOp.getReplOperation()) {
776     // Don't use replace if we know the replaced operation has no results.
777     auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
778     if (!opOp || !opOp.getTypeValues().empty()) {
779       replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
780           replOp.getLoc(), mapRewriteValue(replOp)));
781     }
782   } else {
783     for (Value operand : replaceOp.getReplValues())
784       replOperands.push_back(mapRewriteValue(operand));
785   }
786 
787   // If there are no replacement values, just create an erase instead.
788   if (replOperands.empty()) {
789     builder.create<pdl_interp::EraseOp>(
790         replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()));
791     return;
792   }
793 
794   builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(),
795                                         mapRewriteValue(replaceOp.getOpValue()),
796                                         replOperands);
797 }
798 
799 void PatternLowering::generateRewriter(
800     pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
801     function_ref<Value(Value)> mapRewriteValue) {
802   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
803       resultOp.getLoc(), builder.getType<pdl::ValueType>(),
804       mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
805 }
806 
807 void PatternLowering::generateRewriter(
808     pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
809     function_ref<Value(Value)> mapRewriteValue) {
810   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
811       resultOp.getLoc(), resultOp.getType(),
812       mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
813 }
814 
815 void PatternLowering::generateRewriter(
816     pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
817     function_ref<Value(Value)> mapRewriteValue) {
818   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
819   // type.
820   if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
821     rewriteValues[typeOp] =
822         builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
823   }
824 }
825 
826 void PatternLowering::generateRewriter(
827     pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
828     function_ref<Value(Value)> mapRewriteValue) {
829   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
830   // type.
831   if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
832     rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
833         typeOp.getLoc(), typeOp.getType(), typeAttr);
834   }
835 }
836 
837 void PatternLowering::generateOperationResultTypeRewriter(
838     pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
839     SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
840     bool &hasInferredResultTypes) {
841   Block *rewriterBlock = op->getBlock();
842 
843   // Try to handle resolution for each of the result types individually. This is
844   // preferred over type inferrence because it will allow for us to use existing
845   // types directly, as opposed to trying to rebuild the type list.
846   OperandRange resultTypeValues = op.getTypeValues();
847   auto tryResolveResultTypes = [&] {
848     types.reserve(resultTypeValues.size());
849     for (const auto &it : llvm::enumerate(resultTypeValues)) {
850       Value resultType = it.value();
851 
852       // Check for an already translated value.
853       if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
854         types.push_back(existingRewriteValue);
855         continue;
856       }
857 
858       // Check for an input from the matcher.
859       if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
860         types.push_back(mapRewriteValue(resultType));
861         continue;
862       }
863 
864       // Otherwise, we couldn't infer the result types. Bail out here to see if
865       // we can infer the types for this operation from another way.
866       types.clear();
867       return failure();
868     }
869     return success();
870   };
871   if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
872     return;
873 
874   // Otherwise, check if the operation has type inference support itself.
875   if (op.hasTypeInference()) {
876     hasInferredResultTypes = true;
877     return;
878   }
879 
880   // Look for an operation that was replaced by `op`. The result types will be
881   // inferred from the results that were replaced.
882   for (OpOperand &use : op.getOp().getUses()) {
883     // Check that the use corresponds to a ReplaceOp and that it is the
884     // replacement value, not the operation being replaced.
885     pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
886     if (!replOpUser || use.getOperandNumber() == 0)
887       continue;
888     // Make sure the replaced operation was defined before this one. PDL
889     // rewrites only have single block regions, so if the op isn't in the
890     // rewriter block (i.e. the current block of the operation) we already know
891     // it dominates (i.e. it's in the matcher).
892     Value replOpVal = replOpUser.getOpValue();
893     Operation *replacedOp = replOpVal.getDefiningOp();
894     if (replacedOp->getBlock() == rewriterBlock &&
895         !replacedOp->isBeforeInBlock(op))
896       continue;
897 
898     Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
899         replacedOp->getLoc(), mapRewriteValue(replOpVal));
900     types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
901         replacedOp->getLoc(), replacedOpResults));
902     return;
903   }
904 
905   // If the types could not be inferred from any context and there weren't any
906   // explicit result types, assume the user actually meant for the operation to
907   // have no results.
908   if (resultTypeValues.empty())
909     return;
910 
911   // The verifier asserts that the result types of each pdl.getOperation can be
912   // inferred. If we reach here, there is a bug either in the logic above or
913   // in the verifier for pdl.getOperation.
914   op->emitOpError() << "unable to infer result type for operation";
915   llvm_unreachable("unable to infer result type for operation");
916 }
917 
918 //===----------------------------------------------------------------------===//
919 // Conversion Pass
920 //===----------------------------------------------------------------------===//
921 
922 namespace {
923 struct PDLToPDLInterpPass
924     : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
925   void runOnOperation() final;
926 };
927 } // namespace
928 
929 /// Convert the given module containing PDL pattern operations into a PDL
930 /// Interpreter operations.
931 void PDLToPDLInterpPass::runOnOperation() {
932   ModuleOp module = getOperation();
933 
934   // Create the main matcher function This function contains all of the match
935   // related functionality from patterns in the module.
936   OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
937   auto matcherFunc = builder.create<pdl_interp::FuncOp>(
938       module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
939       builder.getFunctionType(builder.getType<pdl::OperationType>(),
940                               /*results=*/llvm::None),
941       /*attrs=*/llvm::None);
942 
943   // Create a nested module to hold the functions invoked for rewriting the IR
944   // after a successful match.
945   ModuleOp rewriterModule = builder.create<ModuleOp>(
946       module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
947 
948   // Generate the code for the patterns within the module.
949   PatternLowering generator(matcherFunc, rewriterModule);
950   generator.lower(module);
951 
952   // After generation, delete all of the pattern operations.
953   for (pdl::PatternOp pattern :
954        llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
955     pattern.erase();
956 }
957 
958 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
959   return std::make_unique<PDLToPDLInterpPass>();
960 }
961