xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp (revision c80e6edba4a9593f0587e27fa0ac825ebe174afd)
1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10 
11 #include "PredicateTree.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::pdl_to_pdl_interp;
30 
31 //===----------------------------------------------------------------------===//
32 // PatternLowering
33 //===----------------------------------------------------------------------===//
34 
namespace {
/// This class generates operations within the PDL Interpreter dialect from a
/// given module containing PDL pattern operations. All matching logic is
/// emitted into a single matcher function, and the rewrite logic of each
/// pattern is emitted as its own function inside the rewriter module.
struct PatternLowering {
public:
  /// Create a lowering that emits match logic into `matcherFunc` and rewrite
  /// functions into `rewriterModule`. `configMap` may be null; when it is not,
  /// each generated pdl_interp.record_match op is mapped to the configuration
  /// set (if any) of the pattern it was lowered from.
  PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
                  DenseMap<Operation *, PDLPatternConfigSet *> *configMap);

  /// Generate code for matching and rewriting based on the pattern operations
  /// within the module.
  void lower(ModuleOp module);

private:
  /// Scoped mapping from a position to the interpreter value holding it, so
  /// that values materialized in a matcher block are visible to nested blocks
  /// and discarded when the block's scope ends.
  using ValueMap = llvm::ScopedHashTable<Position *, Value>;
  using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;

  /// Generate interpreter operations for the tree rooted at the given matcher
  /// node, in the specified region. Returns the entry block created for the
  /// node.
  Block *generateMatcher(MatcherNode &node, Region &region);

  /// Get or create an access to the provided positional value in the current
  /// block. This operation may mutate the provided block pointer if nested
  /// regions (i.e., pdl_interp.iterate) are required.
  Value getValueAt(Block *&currentBlock, Position *pos);

  /// Create the interpreter predicate operations. This operation may mutate the
  /// provided current block pointer if nested regions (iterates) are required.
  void generate(BoolNode *boolNode, Block *&currentBlock, Value val);

  /// Create the interpreter switch / predicate operations, with several case
  /// destinations. This operation never mutates the provided current block
  /// pointer, because the switch operation does not need Values beyond `val`.
  void generate(SwitchNode *switchNode, Block *currentBlock, Value val);

  /// Create the interpreter operations to record a successful pattern match
  /// using the contained root operation. This operation may mutate the current
  /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
  void generate(SuccessNode *successNode, Block *&currentBlock);

  /// Generate a rewriter function for the given pattern operation, and returns
  /// a reference to that function. The positions of match-side values used by
  /// the rewriter are appended to `usedMatchValues`.
  SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
                                 SmallVectorImpl<Position *> &usedMatchValues);

  /// Generate the rewriter code for the given operation. `rewriteValues` maps
  /// PDL values to the interpreter values created for them, and
  /// `mapRewriteValue` looks up (or materializes) the interpreter value for a
  /// PDL value used as an operand.
  void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::AttributeOp attrOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::EraseOp eraseOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::OperationOp operationOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::RangeOp rangeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ReplaceOp replaceOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ResultOp resultOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ResultsOp resultOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::TypeOp typeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::TypesOp typeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);

  /// Generate the values used for resolving the result types of an operation
  /// created within a dag rewriter region. If the result types of the operation
  /// should be inferred, `hasInferredResultTypes` is set to true.
  void generateOperationResultTypeRewriter(
      pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
      SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
      bool &hasInferredResultTypes);

  /// A builder to use when generating interpreter operations.
  OpBuilder builder;

  /// The matcher function used for all match related logic within PDL patterns.
  pdl_interp::FuncOp matcherFunc;

  /// The rewriter module containing all of the rewrite-related logic within
  /// PDL patterns.
  ModuleOp rewriterModule;

  /// The symbol table of the rewriter module used for insertion.
  SymbolTable rewriterSymbolTable;

  /// A scoped map connecting a position with the corresponding interpreter
  /// value.
  ValueMap values;

  /// A stack of blocks used as the failure destination for matcher nodes that
  /// don't have an explicit failure path. Entries are pushed for explicit
  /// failure nodes and for the continuation blocks of pdl_interp.foreach loops.
  SmallVector<Block *, 8> failureBlockStack;

  /// A mapping between values defined in a pattern match, and the corresponding
  /// positional value.
  DenseMap<Value, Position *> valueToPosition;

  /// The set of operation values whose location will be used for newly
  /// generated operations. Values are added while a matcher node at an
  /// operation position is being lowered, and are passed to every
  /// pdl_interp.record_match created beneath that node.
  SetVector<Value> locOps;

  /// A mapping between pattern operations and the corresponding configuration
  /// set. May be null.
  DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
};
} // namespace
153 
154 PatternLowering::PatternLowering(
155     pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
156     DenseMap<Operation *, PDLPatternConfigSet *> *configMap)
157     : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
158       rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
159       configMap(configMap) {}
160 
161 void PatternLowering::lower(ModuleOp module) {
162   PredicateUniquer predicateUniquer;
163   PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
164 
165   // Define top-level scope for the arguments to the matcher function.
166   ValueMapScope topLevelValueScope(values);
167 
168   // Insert the root operation, i.e. argument to the matcher, at the root
169   // position.
170   Block *matcherEntryBlock = &matcherFunc.front();
171   values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
172 
173   // Generate a root matcher node from the provided PDL module.
174   std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
175       module, predicateBuilder, valueToPosition);
176   Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
177   assert(failureBlockStack.empty() && "failed to empty the stack");
178 
179   // After generation, merged the first matched block into the entry.
180   matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
181                                             firstMatcherBlock->getOperations());
182   firstMatcherBlock->erase();
183 }
184 
185 Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
186   // Push a new scope for the values used by this matcher.
187   Block *block = &region.emplaceBlock();
188   ValueMapScope scope(values);
189 
190   // If this is the return node, simply insert the corresponding interpreter
191   // finalize.
192   if (isa<ExitNode>(node)) {
193     builder.setInsertionPointToEnd(block);
194     builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
195     return block;
196   }
197 
198   // Get the next block in the match sequence.
199   // This is intentionally executed first, before we get the value for the
200   // position associated with the node, so that we preserve an "there exist"
201   // semantics: if getting a value requires an upward traversal (going from a
202   // value to its consumers), we want to perform the check on all the consumers
203   // before we pass control to the failure node.
204   std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
205   Block *failureBlock;
206   if (failureNode) {
207     failureBlock = generateMatcher(*failureNode, region);
208     failureBlockStack.push_back(failureBlock);
209   } else {
210     assert(!failureBlockStack.empty() && "expected valid failure block");
211     failureBlock = failureBlockStack.back();
212   }
213 
214   // If this node contains a position, get the corresponding value for this
215   // block.
216   Block *currentBlock = block;
217   Position *position = node.getPosition();
218   Value val = position ? getValueAt(currentBlock, position) : Value();
219 
220   // If this value corresponds to an operation, record that we are going to use
221   // its location as part of a fused location.
222   bool isOperationValue = val && isa<pdl::OperationType>(val.getType());
223   if (isOperationValue)
224     locOps.insert(val);
225 
226   // Dispatch to the correct method based on derived node type.
227   TypeSwitch<MatcherNode *>(&node)
228       .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
229         this->generate(derivedNode, currentBlock, val);
230       })
231       .Case([&](SuccessNode *successNode) {
232         generate(successNode, currentBlock);
233       });
234 
235   // Pop all the failure blocks that were inserted due to nesting of
236   // pdl_interp.iterate.
237   while (failureBlockStack.back() != failureBlock) {
238     failureBlockStack.pop_back();
239     assert(!failureBlockStack.empty() && "unable to locate failure block");
240   }
241 
242   // Pop the new failure block.
243   if (failureNode)
244     failureBlockStack.pop_back();
245 
246   if (isOperationValue)
247     locOps.remove(val);
248 
249   return block;
250 }
251 
/// Return the interpreter value for `pos`. If it has not been materialized in
/// the current or an enclosing scope, materialize it (and, recursively, any
/// missing ancestor positions) at the end of `currentBlock` and cache it in
/// the current `values` scope.
///
/// For a ForEachPos, a pdl_interp.foreach is created. `currentBlock` is then
/// redirected into the loop body, and the loop's continuation block (which
/// holds a pdl_interp.continue) is pushed onto `failureBlockStack`, so failures
/// inside the body move on to the next iteration. generateMatcher pops those
/// extra failure blocks once the node is lowered.
Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
  // Reuse a value already materialized for this position.
  if (Value val = values.lookup(pos))
    return val;

  // Get the value for the parent position.
  Value parentVal;
  if (Position *parent = pos->getParent())
    parentVal = getValueAt(currentBlock, parent);

  // TODO: Use a location from the position.
  Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
  // Set the insertion point only after the recursive call above, which may
  // have redirected `currentBlock` into a nested region.
  builder.setInsertionPointToEnd(currentBlock);
  Value value;
  switch (pos->getKind()) {
  case Predicates::OperationPos: {
    auto *operationPos = cast<OperationPosition>(pos);
    if (operationPos->isOperandDefiningOp())
      // Standard (downward) traversal which directly follows the defining op.
      value = builder.create<pdl_interp::GetDefiningOpOp>(
          loc, builder.getType<pdl::OperationType>(), parentVal);
    else
      // A passthrough operation position.
      value = parentVal;
    break;
  }
  case Predicates::UsersPos: {
    auto *usersPos = cast<UsersPosition>(pos);

    // The first operation retrieves the representative value of a range.
    // This applies only when the parent is a range of values and we were
    // requested to use a representative value (e.g., upward traversal).
    if (isa<pdl::RangeType>(parentVal.getType()) &&
        usersPos->useRepresentative())
      value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
    else
      value = parentVal;

    // The second operation retrieves the users.
    value = builder.create<pdl_interp::GetUsersOp>(loc, value);
    break;
  }
  case Predicates::ForEachPos: {
    // Iterate over the parent range. The current failure block is passed as
    // the foreach's successor block.
    assert(!failureBlockStack.empty() && "expected valid failure block");
    auto foreach = builder.create<pdl_interp::ForEachOp>(
        loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
    value = foreach.getLoopVariable();

    // Create the continuation block.
    Block *continueBlock = builder.createBlock(&foreach.getRegion());
    builder.create<pdl_interp::ContinueOp>(loc);
    failureBlockStack.push_back(continueBlock);

    // Subsequent matching happens inside the loop body (the region's first
    // block).
    currentBlock = &foreach.getRegion().front();
    break;
  }
  case Predicates::OperandPos: {
    auto *operandPos = cast<OperandPosition>(pos);
    value = builder.create<pdl_interp::GetOperandOp>(
        loc, builder.getType<pdl::ValueType>(), parentVal,
        operandPos->getOperandNumber());
    break;
  }
  case Predicates::OperandGroupPos: {
    // Variadic groups produce a range of values; otherwise a single value.
    auto *operandPos = cast<OperandGroupPosition>(pos);
    Type valueTy = builder.getType<pdl::ValueType>();
    value = builder.create<pdl_interp::GetOperandsOp>(
        loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
        parentVal, operandPos->getOperandGroupNumber());
    break;
  }
  case Predicates::AttributePos: {
    auto *attrPos = cast<AttributePosition>(pos);
    value = builder.create<pdl_interp::GetAttributeOp>(
        loc, builder.getType<pdl::AttributeType>(), parentVal,
        attrPos->getName().strref());
    break;
  }
  case Predicates::TypePos: {
    // A type position is either the type of an attribute or of a value.
    if (isa<pdl::AttributeType>(parentVal.getType()))
      value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
    else
      value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
    break;
  }
  case Predicates::ResultPos: {
    auto *resPos = cast<ResultPosition>(pos);
    value = builder.create<pdl_interp::GetResultOp>(
        loc, builder.getType<pdl::ValueType>(), parentVal,
        resPos->getResultNumber());
    break;
  }
  case Predicates::ResultGroupPos: {
    // Variadic groups produce a range of values; otherwise a single value.
    auto *resPos = cast<ResultGroupPosition>(pos);
    Type valueTy = builder.getType<pdl::ValueType>();
    value = builder.create<pdl_interp::GetResultsOp>(
        loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
        parentVal, resPos->getResultGroupNumber());
    break;
  }
  case Predicates::AttributeLiteralPos: {
    // Literal positions are materialized as constants rather than accessed.
    auto *attrPos = cast<AttributeLiteralPosition>(pos);
    value =
        builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
    break;
  }
  case Predicates::TypeLiteralPos: {
    // A single TypeAttr yields one type; otherwise the literal is an ArrayAttr
    // of types and yields a type range.
    auto *typePos = cast<TypeLiteralPosition>(pos);
    Attribute rawTypeAttr = typePos->getValue();
    if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
      value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
    else
      value = builder.create<pdl_interp::CreateTypesOp>(
          loc, cast<ArrayAttr>(rawTypeAttr));
    break;
  }
  default:
    llvm_unreachable("Generating unknown Position getter");
    break;
  }

  // Cache the value in the current scope for later lookups.
  values.insert(pos, value);
  return value;
}
375 
/// Emit the single predicate check described by `boolNode` on `val` at the
/// end of `currentBlock`. Control branches to the lowered success subtree when
/// the check holds, and to the current failure block (top of
/// `failureBlockStack`) otherwise. For EqualTo questions, the two destinations
/// are swapped when the expected answer is not a TrueAnswer.
///
/// `currentBlock` may be redirected into a nested region while materializing
/// the values the question compares against.
void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
                               Value val) {
  Location loc = val.getLoc();
  Qualifier *question = boolNode->getQuestion();
  Qualifier *answer = boolNode->getAnswer();
  Region *region = currentBlock->getParent();

  // Execute the getValue queries first, so that we create success
  // matcher in the correct (possibly nested) region.
  SmallVector<Value> args;
  if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
    args = {getValueAt(currentBlock, equalToQuestion->getValue())};
  } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
    for (Position *position : cstQuestion->getArgs())
      args.push_back(getValueAt(currentBlock, position));
  }

  // Generate the matcher in the current (potentially nested) region
  // and get the failure successor.
  Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
  Block *failure = failureBlockStack.back();

  // Finally, create the predicate. The insertion point must be reset here
  // because generating the success subtree moved it elsewhere.
  builder.setInsertionPointToEnd(currentBlock);
  Predicates::Kind kind = question->getKind();
  switch (kind) {
  case Predicates::IsNotNullQuestion:
    builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
    break;
  case Predicates::OperationNameQuestion: {
    auto *opNameAnswer = cast<OperationNameAnswer>(answer);
    builder.create<pdl_interp::CheckOperationNameOp>(
        loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
    break;
  }
  case Predicates::TypeQuestion: {
    // Ranges are checked against an ArrayAttr of types, single values against
    // a TypeAttr.
    auto *ans = cast<TypeAnswer>(answer);
    if (isa<pdl::RangeType>(val.getType()))
      builder.create<pdl_interp::CheckTypesOp>(
          loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
    else
      builder.create<pdl_interp::CheckTypeOp>(
          loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
    break;
  }
  case Predicates::AttributeQuestion: {
    auto *ans = cast<AttributeAnswer>(answer);
    builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
                                                 success, failure);
    break;
  }
  case Predicates::OperandCountAtLeastQuestion:
  case Predicates::OperandCountQuestion:
    builder.create<pdl_interp::CheckOperandCountOp>(
        loc, val, cast<UnsignedAnswer>(answer)->getValue(),
        /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
        success, failure);
    break;
  case Predicates::ResultCountAtLeastQuestion:
  case Predicates::ResultCountQuestion:
    builder.create<pdl_interp::CheckResultCountOp>(
        loc, val, cast<UnsignedAnswer>(answer)->getValue(),
        /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
        success, failure);
    break;
  case Predicates::EqualToQuestion: {
    // An "equal" answer continues on equality; otherwise on inequality.
    bool trueAnswer = isa<TrueAnswer>(answer);
    builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
                                           trueAnswer ? success : failure,
                                           trueAnswer ? failure : success);
    break;
  }
  case Predicates::ConstraintQuestion: {
    auto *cstQuestion = cast<ConstraintQuestion>(question);
    builder.create<pdl_interp::ApplyConstraintOp>(
        loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
        failure);
    break;
  }
  default:
    llvm_unreachable("Generating unknown Predicate operation");
  }
}
459 
460 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
461 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
462                            llvm::MapVector<Qualifier *, Block *> &dests) {
463   std::vector<ValT> values;
464   std::vector<Block *> blocks;
465   values.reserve(dests.size());
466   blocks.reserve(dests.size());
467   for (const auto &it : dests) {
468     blocks.push_back(it.second);
469     values.push_back(cast<PredT>(it.first)->getValue());
470   }
471   builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
472 }
473 
/// Emit a multi-way dispatch on `val` for `switchNode` at the end of
/// `currentBlock`. The current failure block (top of `failureBlockStack`) is
/// the default destination.
///
/// Exact-answer questions lower to a single pdl_interp.switch_* op. The
/// `at_least` count questions cannot be expressed as a switch, so they lower
/// to a chain of check blocks instead.
void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
                               Value val) {
  Qualifier *question = switchNode->getQuestion();
  Region *region = currentBlock->getParent();
  Block *defaultDest = failureBlockStack.back();

  // If the switch question is not an exact answer, i.e. for the `at_least`
  // cases, we generate a special block sequence.
  Predicates::Kind kind = question->getKind();
  if (kind == Predicates::OperandCountAtLeastQuestion ||
      kind == Predicates::ResultCountAtLeastQuestion) {
    // Order the children such that the cases are in reverse numerical order.
    SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
        llvm::seq<unsigned>(0, switchNode->getChildren().size()));
    llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
      return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
             cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
    });

    // Build the destination for each child using the next highest child as a
    // failure destination. This essentially creates the following control
    // flow:
    //
    // if (operand_count < 1)
    //   goto failure
    // if (child1.match())
    //   ...
    //
    // if (operand_count < 2)
    //   goto failure
    // if (child2.match())
    //   ...
    //
    // failure:
    //   ...
    //
    // Children are generated from the largest count down. The top of
    // `failureBlockStack` is replaced after each child with that child's
    // predicate block, so the next (smaller) child falls back to it. The last
    // predicate block created (smallest count) is the chain's entry and is
    // spliced into `currentBlock` below.
    failureBlockStack.push_back(defaultDest);
    Location loc = val.getLoc();
    for (unsigned idx : sortedChildren) {
      auto &child = switchNode->getChild(idx);
      Block *childBlock = generateMatcher(*child.second, *region);
      Block *predicateBlock = builder.createBlock(childBlock);
      builder.setInsertionPointToEnd(predicateBlock);
      unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
      switch (kind) {
      case Predicates::OperandCountAtLeastQuestion:
        builder.create<pdl_interp::CheckOperandCountOp>(
            loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
        break;
      case Predicates::ResultCountAtLeastQuestion:
        builder.create<pdl_interp::CheckResultCountOp>(
            loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
        break;
      default:
        llvm_unreachable("Generating invalid AtLeast operation");
      }
      failureBlockStack.back() = predicateBlock;
    }
    Block *firstPredicateBlock = failureBlockStack.pop_back_val();
    currentBlock->getOperations().splice(currentBlock->end(),
                                         firstPredicateBlock->getOperations());
    firstPredicateBlock->erase();
    return;
  }

  // Otherwise, generate each of the children and generate an interpreter
  // switch.
  llvm::MapVector<Qualifier *, Block *> children;
  for (auto &it : switchNode->getChildren())
    children.insert({it.first, generateMatcher(*it.second, *region)});
  builder.setInsertionPointToEnd(currentBlock);

  switch (question->getKind()) {
  case Predicates::OperandCountQuestion:
    return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
                          int32_t>(val, defaultDest, builder, children);
  case Predicates::ResultCountQuestion:
    return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
                          int32_t>(val, defaultDest, builder, children);
  case Predicates::OperationNameQuestion:
    return createSwitchOp<pdl_interp::SwitchOperationNameOp,
                          OperationNameAnswer>(val, defaultDest, builder,
                                               children);
  case Predicates::TypeQuestion:
    // Ranges switch over type arrays; single values switch over types.
    if (isa<pdl::RangeType>(val.getType())) {
      return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
          val, defaultDest, builder, children);
    }
    return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
        val, defaultDest, builder, children);
  case Predicates::AttributeQuestion:
    return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
        val, defaultDest, builder, children);
  default:
    llvm_unreachable("Generating unknown switch predicate.");
  }
}
571 
572 void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
573   pdl::PatternOp pattern = successNode->getPattern();
574   Value root = successNode->getRoot();
575 
576   // Generate a rewriter for the pattern this success node represents, and track
577   // any values used from the match region.
578   SmallVector<Position *, 8> usedMatchValues;
579   SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
580 
581   // Process any values used in the rewrite that are defined in the match.
582   std::vector<Value> mappedMatchValues;
583   mappedMatchValues.reserve(usedMatchValues.size());
584   for (Position *position : usedMatchValues)
585     mappedMatchValues.push_back(getValueAt(currentBlock, position));
586 
587   // Collect the set of operations generated by the rewriter.
588   SmallVector<StringRef, 4> generatedOps;
589   for (auto op :
590        pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
591     generatedOps.push_back(*op.getOpName());
592   ArrayAttr generatedOpsAttr;
593   if (!generatedOps.empty())
594     generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
595 
596   // Grab the root kind if present.
597   StringAttr rootKindAttr;
598   if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
599     if (std::optional<StringRef> rootKind = rootOp.getOpName())
600       rootKindAttr = builder.getStringAttr(*rootKind);
601 
602   builder.setInsertionPointToEnd(currentBlock);
603   auto matchOp = builder.create<pdl_interp::RecordMatchOp>(
604       pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
605       rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
606       failureBlockStack.back());
607 
608   // Set the config of the lowered match to the parent pattern.
609   if (configMap)
610     configMap->try_emplace(matchOp, configMap->lookup(pattern));
611 }
612 
/// Generate a new pdl_interp.func in the rewriter module implementing the
/// rewrite region of `pattern`, and return a symbol reference to it nested
/// under the rewriter module's name.
///
/// Every value the rewrite consumes from the match side becomes an argument
/// of the generated function. The position of each such value is appended to
/// `usedMatchValues` in argument order, so the caller can materialize the
/// matching values and pass them to pdl_interp.record_match.
SymbolRefAttr PatternLowering::generateRewriter(
    pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
  // Create an argument-less function for now; its signature is fixed up once
  // the body has been generated and all inputs are known.
  builder.setInsertionPointToEnd(rewriterModule.getBody());
  auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
      pattern.getLoc(), "pdl_generated_rewriter",
      builder.getFunctionType(std::nullopt, std::nullopt));
  rewriterSymbolTable.insert(rewriterFunc);

  // Generate the rewriter function body.
  builder.setInsertionPointToEnd(&rewriterFunc.front());

  // Map an input operand of the pattern to a generated interpreter value.
  DenseMap<Value, Value> rewriteValues;
  auto mapRewriteValue = [&](Value oldValue) {
    // Bind to the map slot so that whichever path below produces the value
    // also caches it; nothing below inserts into `rewriteValues` while this
    // reference is live.
    Value &newValue = rewriteValues[oldValue];
    if (newValue)
      return newValue;

    // Prefer materializing constants directly when possible.
    Operation *oldOp = oldValue.getDefiningOp();
    if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
      if (Attribute value = attrOp.getValueAttr()) {
        return newValue = builder.create<pdl_interp::CreateAttributeOp>(
                   attrOp.getLoc(), value);
      }
    } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
      if (TypeAttr type = typeOp.getConstantTypeAttr()) {
        return newValue = builder.create<pdl_interp::CreateTypeOp>(
                   typeOp.getLoc(), type);
      }
    } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
      if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
        return newValue = builder.create<pdl_interp::CreateTypesOp>(
                   typeOp.getLoc(), typeOp.getType(), type);
      }
    }

    // Otherwise, add this as an input to the rewriter. The position push and
    // the argument append happen together, which keeps `usedMatchValues` in
    // the same order as the function's arguments.
    Position *inputPos = valueToPosition.lookup(oldValue);
    assert(inputPos && "expected value to be a pattern input");
    usedMatchValues.push_back(inputPos);
    return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
                                                       oldValue.getLoc());
  };

  // If this is a custom rewriter, simply dispatch to the registered rewrite
  // method.
  pdl::RewriteOp rewriter = pattern.getRewriter();
  if (StringAttr rewriteName = rewriter.getNameAttr()) {
    // The root (if present) comes first, followed by the external arguments.
    SmallVector<Value> args;
    if (rewriter.getRoot())
      args.push_back(mapRewriteValue(rewriter.getRoot()));
    auto mappedArgs =
        llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
    args.append(mappedArgs.begin(), mappedArgs.end());
    builder.create<pdl_interp::ApplyRewriteOp>(
        rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
  } else {
    // Otherwise this is a dag rewriter defined using PDL operations.
    for (Operation &rewriteOp : *rewriter.getBody()) {
      llvm::TypeSwitch<Operation *>(&rewriteOp)
          .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
                pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
                pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) {
            this->generateRewriter(op, rewriteValues, mapRewriteValue);
          });
    }
  }

  // Update the signature of the rewrite function. Arguments are appended
  // lazily by `mapRewriteValue`, so they are only known at this point.
  rewriterFunc.setType(builder.getFunctionType(
      llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
      /*results=*/std::nullopt));

  builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
  return SymbolRefAttr::get(
      builder.getContext(),
      pdl_interp::PDLInterpDialect::getRewriterModuleName(),
      SymbolRefAttr::get(rewriterFunc));
}
693 
694 void PatternLowering::generateRewriter(
695     pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
696     function_ref<Value(Value)> mapRewriteValue) {
697   SmallVector<Value, 2> arguments;
698   for (Value argument : rewriteOp.getArgs())
699     arguments.push_back(mapRewriteValue(argument));
700   auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
701       rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(),
702       arguments);
703   for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
704     rewriteValues[std::get<0>(it)] = std::get<1>(it);
705 }
706 
707 void PatternLowering::generateRewriter(
708     pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
709     function_ref<Value(Value)> mapRewriteValue) {
710   Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
711       attrOp.getLoc(), attrOp.getValueAttr());
712   rewriteValues[attrOp] = newAttr;
713 }
714 
715 void PatternLowering::generateRewriter(
716     pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
717     function_ref<Value(Value)> mapRewriteValue) {
718   builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
719                                       mapRewriteValue(eraseOp.getOpValue()));
720 }
721 
722 void PatternLowering::generateRewriter(
723     pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
724     function_ref<Value(Value)> mapRewriteValue) {
725   SmallVector<Value, 4> operands;
726   for (Value operand : operationOp.getOperandValues())
727     operands.push_back(mapRewriteValue(operand));
728 
729   SmallVector<Value, 4> attributes;
730   for (Value attr : operationOp.getAttributeValues())
731     attributes.push_back(mapRewriteValue(attr));
732 
733   bool hasInferredResultTypes = false;
734   SmallVector<Value, 2> types;
735   generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
736                                       rewriteValues, hasInferredResultTypes);
737 
738   // Create the new operation.
739   Location loc = operationOp.getLoc();
740   Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
741       loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
742       attributes, operationOp.getAttributeValueNames());
743   rewriteValues[operationOp.getOp()] = createdOp;
744 
745   // Generate accesses for any results that have their types constrained.
746   // Handle the case where there is a single range representing all of the
747   // result types.
748   OperandRange resultTys = operationOp.getTypeValues();
749   if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
750     Value &type = rewriteValues[resultTys[0]];
751     if (!type) {
752       auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
753       type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
754     }
755     return;
756   }
757 
758   // Otherwise, populate the individual results.
759   bool seenVariableLength = false;
760   Type valueTy = builder.getType<pdl::ValueType>();
761   Type valueRangeTy = pdl::RangeType::get(valueTy);
762   for (const auto &it : llvm::enumerate(resultTys)) {
763     Value &type = rewriteValues[it.value()];
764     if (type)
765       continue;
766     bool isVariadic = isa<pdl::RangeType>(it.value().getType());
767     seenVariableLength |= isVariadic;
768 
769     // After a variable length result has been seen, we need to use result
770     // groups because the exact index of the result is not statically known.
771     Value resultVal;
772     if (seenVariableLength)
773       resultVal = builder.create<pdl_interp::GetResultsOp>(
774           loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
775     else
776       resultVal = builder.create<pdl_interp::GetResultOp>(
777           loc, valueTy, createdOp, it.index());
778     type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
779   }
780 }
781 
782 void PatternLowering::generateRewriter(
783     pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
784     function_ref<Value(Value)> mapRewriteValue) {
785   SmallVector<Value, 4> replOperands;
786   for (Value operand : rangeOp.getArguments())
787     replOperands.push_back(mapRewriteValue(operand));
788   rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>(
789       rangeOp.getLoc(), rangeOp.getType(), replOperands);
790 }
791 
792 void PatternLowering::generateRewriter(
793     pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
794     function_ref<Value(Value)> mapRewriteValue) {
795   SmallVector<Value, 4> replOperands;
796 
797   // If the replacement was another operation, get its results. `pdl` allows
798   // for using an operation for simplicitly, but the interpreter isn't as
799   // user facing.
800   if (Value replOp = replaceOp.getReplOperation()) {
801     // Don't use replace if we know the replaced operation has no results.
802     auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
803     if (!opOp || !opOp.getTypeValues().empty()) {
804       replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
805           replOp.getLoc(), mapRewriteValue(replOp)));
806     }
807   } else {
808     for (Value operand : replaceOp.getReplValues())
809       replOperands.push_back(mapRewriteValue(operand));
810   }
811 
812   // If there are no replacement values, just create an erase instead.
813   if (replOperands.empty()) {
814     builder.create<pdl_interp::EraseOp>(
815         replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()));
816     return;
817   }
818 
819   builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(),
820                                         mapRewriteValue(replaceOp.getOpValue()),
821                                         replOperands);
822 }
823 
824 void PatternLowering::generateRewriter(
825     pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
826     function_ref<Value(Value)> mapRewriteValue) {
827   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
828       resultOp.getLoc(), builder.getType<pdl::ValueType>(),
829       mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
830 }
831 
832 void PatternLowering::generateRewriter(
833     pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
834     function_ref<Value(Value)> mapRewriteValue) {
835   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
836       resultOp.getLoc(), resultOp.getType(),
837       mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
838 }
839 
840 void PatternLowering::generateRewriter(
841     pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
842     function_ref<Value(Value)> mapRewriteValue) {
843   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
844   // type.
845   if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
846     rewriteValues[typeOp] =
847         builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
848   }
849 }
850 
851 void PatternLowering::generateRewriter(
852     pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
853     function_ref<Value(Value)> mapRewriteValue) {
854   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
855   // type.
856   if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
857     rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
858         typeOp.getLoc(), typeOp.getType(), typeAttr);
859   }
860 }
861 
862 void PatternLowering::generateOperationResultTypeRewriter(
863     pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
864     SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
865     bool &hasInferredResultTypes) {
866   Block *rewriterBlock = op->getBlock();
867 
868   // Try to handle resolution for each of the result types individually. This is
869   // preferred over type inferrence because it will allow for us to use existing
870   // types directly, as opposed to trying to rebuild the type list.
871   OperandRange resultTypeValues = op.getTypeValues();
872   auto tryResolveResultTypes = [&] {
873     types.reserve(resultTypeValues.size());
874     for (const auto &it : llvm::enumerate(resultTypeValues)) {
875       Value resultType = it.value();
876 
877       // Check for an already translated value.
878       if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
879         types.push_back(existingRewriteValue);
880         continue;
881       }
882 
883       // Check for an input from the matcher.
884       if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
885         types.push_back(mapRewriteValue(resultType));
886         continue;
887       }
888 
889       // Otherwise, we couldn't infer the result types. Bail out here to see if
890       // we can infer the types for this operation from another way.
891       types.clear();
892       return failure();
893     }
894     return success();
895   };
896   if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
897     return;
898 
899   // Otherwise, check if the operation has type inference support itself.
900   if (op.hasTypeInference()) {
901     hasInferredResultTypes = true;
902     return;
903   }
904 
905   // Look for an operation that was replaced by `op`. The result types will be
906   // inferred from the results that were replaced.
907   for (OpOperand &use : op.getOp().getUses()) {
908     // Check that the use corresponds to a ReplaceOp and that it is the
909     // replacement value, not the operation being replaced.
910     pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
911     if (!replOpUser || use.getOperandNumber() == 0)
912       continue;
913     // Make sure the replaced operation was defined before this one. PDL
914     // rewrites only have single block regions, so if the op isn't in the
915     // rewriter block (i.e. the current block of the operation) we already know
916     // it dominates (i.e. it's in the matcher).
917     Value replOpVal = replOpUser.getOpValue();
918     Operation *replacedOp = replOpVal.getDefiningOp();
919     if (replacedOp->getBlock() == rewriterBlock &&
920         !replacedOp->isBeforeInBlock(op))
921       continue;
922 
923     Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
924         replacedOp->getLoc(), mapRewriteValue(replOpVal));
925     types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
926         replacedOp->getLoc(), replacedOpResults));
927     return;
928   }
929 
930   // If the types could not be inferred from any context and there weren't any
931   // explicit result types, assume the user actually meant for the operation to
932   // have no results.
933   if (resultTypeValues.empty())
934     return;
935 
936   // The verifier asserts that the result types of each pdl.getOperation can be
937   // inferred. If we reach here, there is a bug either in the logic above or
938   // in the verifier for pdl.getOperation.
939   op->emitOpError() << "unable to infer result type for operation";
940   llvm_unreachable("unable to infer result type for operation");
941 }
942 
943 //===----------------------------------------------------------------------===//
944 // Conversion Pass
945 //===----------------------------------------------------------------------===//
946 
947 namespace {
948 struct PDLToPDLInterpPass
949     : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
950   PDLToPDLInterpPass() = default;
951   PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default;
952   PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
953       : configMap(&configMap) {}
954   void runOnOperation() final;
955 
956   /// A map containing the configuration for each pattern.
957   DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr;
958 };
959 } // namespace
960 
961 /// Convert the given module containing PDL pattern operations into a PDL
962 /// Interpreter operations.
963 void PDLToPDLInterpPass::runOnOperation() {
964   ModuleOp module = getOperation();
965 
966   // Create the main matcher function This function contains all of the match
967   // related functionality from patterns in the module.
968   OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
969   auto matcherFunc = builder.create<pdl_interp::FuncOp>(
970       module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
971       builder.getFunctionType(builder.getType<pdl::OperationType>(),
972                               /*results=*/std::nullopt),
973       /*attrs=*/std::nullopt);
974 
975   // Create a nested module to hold the functions invoked for rewriting the IR
976   // after a successful match.
977   ModuleOp rewriterModule = builder.create<ModuleOp>(
978       module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
979 
980   // Generate the code for the patterns within the module.
981   PatternLowering generator(matcherFunc, rewriterModule, configMap);
982   generator.lower(module);
983 
984   // After generation, delete all of the pattern operations.
985   for (pdl::PatternOp pattern :
986        llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
987     // Drop the now dead config mappings.
988     if (configMap)
989       configMap->erase(pattern);
990 
991     pattern.erase();
992   }
993 }
994 
995 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
996   return std::make_unique<PDLToPDLInterpPass>();
997 }
998 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass(
999     DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
1000   return std::make_unique<PDLToPDLInterpPass>(configMap);
1001 }
1002