xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp (revision 039b969b32b64b64123dce30dd28ec4e343d893f)
1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10 #include "../PassDetail.h"
11 #include "PredicateTree.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 using namespace mlir;
24 using namespace mlir::pdl_to_pdl_interp;
25 
26 //===----------------------------------------------------------------------===//
27 // PatternLowering
28 //===----------------------------------------------------------------------===//
29 
namespace {
/// This class generates operations within the PDL Interpreter dialect from a
/// given module containing PDL pattern operations.
struct PatternLowering {
public:
  PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule);

  /// Generate code for matching and rewriting based on the pattern operations
  /// within the module.
  void lower(ModuleOp module);

private:
  using ValueMap = llvm::ScopedHashTable<Position *, Value>;
  using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;

  /// Generate interpreter operations for the tree rooted at the given matcher
  /// node, in the specified region. Returns the block created for `node`.
  Block *generateMatcher(MatcherNode &node, Region &region);

  /// Get or create an access to the provided positional value in the current
  /// block. This operation may mutate the provided block pointer if nested
  /// regions (i.e., pdl_interp.foreach) are required.
  Value getValueAt(Block *&currentBlock, Position *pos);

  /// Create the interpreter predicate operations. This operation may mutate the
  /// provided current block pointer if nested regions (foreach loops) are
  /// required.
  void generate(BoolNode *boolNode, Block *&currentBlock, Value val);

  /// Create the interpreter switch / predicate operations, with several case
  /// destinations. This operation never mutates the provided current block
  /// pointer, because the switch operation does not need Values beyond `val`.
  void generate(SwitchNode *switchNode, Block *currentBlock, Value val);

  /// Create the interpreter operations to record a successful pattern match
  /// using the contained root operation. This operation may mutate the current
  /// block pointer if nested regions (i.e., pdl_interp.foreach) are required.
  void generate(SuccessNode *successNode, Block *&currentBlock);

  /// Generate a rewriter function for the given pattern operation, and return
  /// a reference to that function. Positions of match values consumed by the
  /// rewriter are appended to `usedMatchValues`.
  SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
                                 SmallVectorImpl<Position *> &usedMatchValues);

  /// Generate the rewriter code for the given operation.
  void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::AttributeOp attrOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::EraseOp eraseOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::OperationOp operationOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ReplaceOp replaceOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ResultOp resultOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ResultsOp resultOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::TypeOp typeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::TypesOp typeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);

  /// Generate the values used for resolving the result types of an operation
  /// created within a dag rewriter region. If the result types of the operation
  /// should be inferred, `hasInferredResultTypes` is set to true.
  void generateOperationResultTypeRewriter(
      pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
      SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
      bool &hasInferredResultTypes);

  /// A builder to use when generating interpreter operations.
  OpBuilder builder;

  /// The matcher function used for all match related logic within PDL patterns.
  pdl_interp::FuncOp matcherFunc;

  /// The rewriter module containing all of the rewrite related logic within
  /// PDL patterns.
  ModuleOp rewriterModule;

  /// The symbol table of the rewriter module used for insertion.
  SymbolTable rewriterSymbolTable;

  /// A scoped map connecting a position with the corresponding interpreter
  /// value. Each generated matcher node opens its own scope.
  ValueMap values;

  /// A stack of blocks used as the failure destination for matcher nodes that
  /// don't have an explicit failure path.
  SmallVector<Block *, 8> failureBlockStack;

  /// A mapping between values defined in a pattern match, and the corresponding
  /// positional value.
  DenseMap<Value, Position *> valueToPosition;

  /// The set of operation values whose location will be used for newly
  /// generated operations (passed to pdl_interp.record_match).
  SetVector<Value> locOps;
};
} // namespace
140 
141 PatternLowering::PatternLowering(pdl_interp::FuncOp matcherFunc,
142                                  ModuleOp rewriterModule)
143     : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
144       rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
145 
146 void PatternLowering::lower(ModuleOp module) {
147   PredicateUniquer predicateUniquer;
148   PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
149 
150   // Define top-level scope for the arguments to the matcher function.
151   ValueMapScope topLevelValueScope(values);
152 
153   // Insert the root operation, i.e. argument to the matcher, at the root
154   // position.
155   Block *matcherEntryBlock = &matcherFunc.front();
156   values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
157 
158   // Generate a root matcher node from the provided PDL module.
159   std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
160       module, predicateBuilder, valueToPosition);
161   Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
162   assert(failureBlockStack.empty() && "failed to empty the stack");
163 
164   // After generation, merged the first matched block into the entry.
165   matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
166                                             firstMatcherBlock->getOperations());
167   firstMatcherBlock->erase();
168 }
169 
170 Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
171   // Push a new scope for the values used by this matcher.
172   Block *block = &region.emplaceBlock();
173   ValueMapScope scope(values);
174 
175   // If this is the return node, simply insert the corresponding interpreter
176   // finalize.
177   if (isa<ExitNode>(node)) {
178     builder.setInsertionPointToEnd(block);
179     builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
180     return block;
181   }
182 
183   // Get the next block in the match sequence.
184   // This is intentionally executed first, before we get the value for the
185   // position associated with the node, so that we preserve an "there exist"
186   // semantics: if getting a value requires an upward traversal (going from a
187   // value to its consumers), we want to perform the check on all the consumers
188   // before we pass control to the failure node.
189   std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
190   Block *failureBlock;
191   if (failureNode) {
192     failureBlock = generateMatcher(*failureNode, region);
193     failureBlockStack.push_back(failureBlock);
194   } else {
195     assert(!failureBlockStack.empty() && "expected valid failure block");
196     failureBlock = failureBlockStack.back();
197   }
198 
199   // If this node contains a position, get the corresponding value for this
200   // block.
201   Block *currentBlock = block;
202   Position *position = node.getPosition();
203   Value val = position ? getValueAt(currentBlock, position) : Value();
204 
205   // If this value corresponds to an operation, record that we are going to use
206   // its location as part of a fused location.
207   bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
208   if (isOperationValue)
209     locOps.insert(val);
210 
211   // Dispatch to the correct method based on derived node type.
212   TypeSwitch<MatcherNode *>(&node)
213       .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
214         this->generate(derivedNode, currentBlock, val);
215       })
216       .Case([&](SuccessNode *successNode) {
217         generate(successNode, currentBlock);
218       });
219 
220   // Pop all the failure blocks that were inserted due to nesting of
221   // pdl_interp.iterate.
222   while (failureBlockStack.back() != failureBlock) {
223     failureBlockStack.pop_back();
224     assert(!failureBlockStack.empty() && "unable to locate failure block");
225   }
226 
227   // Pop the new failure block.
228   if (failureNode)
229     failureBlockStack.pop_back();
230 
231   if (isOperationValue)
232     locOps.remove(val);
233 
234   return block;
235 }
236 
237 Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
238   if (Value val = values.lookup(pos))
239     return val;
240 
241   // Get the value for the parent position.
242   Value parentVal;
243   if (Position *parent = pos->getParent())
244     parentVal = getValueAt(currentBlock, parent);
245 
246   // TODO: Use a location from the position.
247   Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
248   builder.setInsertionPointToEnd(currentBlock);
249   Value value;
250   switch (pos->getKind()) {
251   case Predicates::OperationPos: {
252     auto *operationPos = cast<OperationPosition>(pos);
253     if (operationPos->isOperandDefiningOp())
254       // Standard (downward) traversal which directly follows the defining op.
255       value = builder.create<pdl_interp::GetDefiningOpOp>(
256           loc, builder.getType<pdl::OperationType>(), parentVal);
257     else
258       // A passthrough operation position.
259       value = parentVal;
260     break;
261   }
262   case Predicates::UsersPos: {
263     auto *usersPos = cast<UsersPosition>(pos);
264 
265     // The first operation retrieves the representative value of a range.
266     // This applies only when the parent is a range of values and we were
267     // requested to use a representative value (e.g., upward traversal).
268     if (parentVal.getType().isa<pdl::RangeType>() &&
269         usersPos->useRepresentative())
270       value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
271     else
272       value = parentVal;
273 
274     // The second operation retrieves the users.
275     value = builder.create<pdl_interp::GetUsersOp>(loc, value);
276     break;
277   }
278   case Predicates::ForEachPos: {
279     assert(!failureBlockStack.empty() && "expected valid failure block");
280     auto foreach = builder.create<pdl_interp::ForEachOp>(
281         loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
282     value = foreach.getLoopVariable();
283 
284     // Create the continuation block.
285     Block *continueBlock = builder.createBlock(&foreach.getRegion());
286     builder.create<pdl_interp::ContinueOp>(loc);
287     failureBlockStack.push_back(continueBlock);
288 
289     currentBlock = &foreach.getRegion().front();
290     break;
291   }
292   case Predicates::OperandPos: {
293     auto *operandPos = cast<OperandPosition>(pos);
294     value = builder.create<pdl_interp::GetOperandOp>(
295         loc, builder.getType<pdl::ValueType>(), parentVal,
296         operandPos->getOperandNumber());
297     break;
298   }
299   case Predicates::OperandGroupPos: {
300     auto *operandPos = cast<OperandGroupPosition>(pos);
301     Type valueTy = builder.getType<pdl::ValueType>();
302     value = builder.create<pdl_interp::GetOperandsOp>(
303         loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
304         parentVal, operandPos->getOperandGroupNumber());
305     break;
306   }
307   case Predicates::AttributePos: {
308     auto *attrPos = cast<AttributePosition>(pos);
309     value = builder.create<pdl_interp::GetAttributeOp>(
310         loc, builder.getType<pdl::AttributeType>(), parentVal,
311         attrPos->getName().strref());
312     break;
313   }
314   case Predicates::TypePos: {
315     if (parentVal.getType().isa<pdl::AttributeType>())
316       value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
317     else
318       value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
319     break;
320   }
321   case Predicates::ResultPos: {
322     auto *resPos = cast<ResultPosition>(pos);
323     value = builder.create<pdl_interp::GetResultOp>(
324         loc, builder.getType<pdl::ValueType>(), parentVal,
325         resPos->getResultNumber());
326     break;
327   }
328   case Predicates::ResultGroupPos: {
329     auto *resPos = cast<ResultGroupPosition>(pos);
330     Type valueTy = builder.getType<pdl::ValueType>();
331     value = builder.create<pdl_interp::GetResultsOp>(
332         loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
333         parentVal, resPos->getResultGroupNumber());
334     break;
335   }
336   case Predicates::AttributeLiteralPos: {
337     auto *attrPos = cast<AttributeLiteralPosition>(pos);
338     value =
339         builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
340     break;
341   }
342   case Predicates::TypeLiteralPos: {
343     auto *typePos = cast<TypeLiteralPosition>(pos);
344     Attribute rawTypeAttr = typePos->getValue();
345     if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>())
346       value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
347     else
348       value = builder.create<pdl_interp::CreateTypesOp>(
349           loc, rawTypeAttr.cast<ArrayAttr>());
350     break;
351   }
352   default:
353     llvm_unreachable("Generating unknown Position getter");
354     break;
355   }
356 
357   values.insert(pos, value);
358   return value;
359 }
360 
361 void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
362                                Value val) {
363   Location loc = val.getLoc();
364   Qualifier *question = boolNode->getQuestion();
365   Qualifier *answer = boolNode->getAnswer();
366   Region *region = currentBlock->getParent();
367 
368   // Execute the getValue queries first, so that we create success
369   // matcher in the correct (possibly nested) region.
370   SmallVector<Value> args;
371   if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
372     args = {getValueAt(currentBlock, equalToQuestion->getValue())};
373   } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
374     for (Position *position : cstQuestion->getArgs())
375       args.push_back(getValueAt(currentBlock, position));
376   }
377 
378   // Generate the matcher in the current (potentially nested) region
379   // and get the failure successor.
380   Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
381   Block *failure = failureBlockStack.back();
382 
383   // Finally, create the predicate.
384   builder.setInsertionPointToEnd(currentBlock);
385   Predicates::Kind kind = question->getKind();
386   switch (kind) {
387   case Predicates::IsNotNullQuestion:
388     builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
389     break;
390   case Predicates::OperationNameQuestion: {
391     auto *opNameAnswer = cast<OperationNameAnswer>(answer);
392     builder.create<pdl_interp::CheckOperationNameOp>(
393         loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
394     break;
395   }
396   case Predicates::TypeQuestion: {
397     auto *ans = cast<TypeAnswer>(answer);
398     if (val.getType().isa<pdl::RangeType>())
399       builder.create<pdl_interp::CheckTypesOp>(
400           loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
401     else
402       builder.create<pdl_interp::CheckTypeOp>(
403           loc, val, ans->getValue().cast<TypeAttr>(), success, failure);
404     break;
405   }
406   case Predicates::AttributeQuestion: {
407     auto *ans = cast<AttributeAnswer>(answer);
408     builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
409                                                  success, failure);
410     break;
411   }
412   case Predicates::OperandCountAtLeastQuestion:
413   case Predicates::OperandCountQuestion:
414     builder.create<pdl_interp::CheckOperandCountOp>(
415         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
416         /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
417         success, failure);
418     break;
419   case Predicates::ResultCountAtLeastQuestion:
420   case Predicates::ResultCountQuestion:
421     builder.create<pdl_interp::CheckResultCountOp>(
422         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
423         /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
424         success, failure);
425     break;
426   case Predicates::EqualToQuestion: {
427     bool trueAnswer = isa<TrueAnswer>(answer);
428     builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
429                                            trueAnswer ? success : failure,
430                                            trueAnswer ? failure : success);
431     break;
432   }
433   case Predicates::ConstraintQuestion: {
434     auto *cstQuestion = cast<ConstraintQuestion>(question);
435     builder.create<pdl_interp::ApplyConstraintOp>(loc, cstQuestion->getName(),
436                                                   args, success, failure);
437     break;
438   }
439   default:
440     llvm_unreachable("Generating unknown Predicate operation");
441   }
442 }
443 
444 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
445 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
446                            llvm::MapVector<Qualifier *, Block *> &dests) {
447   std::vector<ValT> values;
448   std::vector<Block *> blocks;
449   values.reserve(dests.size());
450   blocks.reserve(dests.size());
451   for (const auto &it : dests) {
452     blocks.push_back(it.second);
453     values.push_back(cast<PredT>(it.first)->getValue());
454   }
455   builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
456 }
457 
458 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
459                                Value val) {
460   Qualifier *question = switchNode->getQuestion();
461   Region *region = currentBlock->getParent();
462   Block *defaultDest = failureBlockStack.back();
463 
464   // If the switch question is not an exact answer, i.e. for the `at_least`
465   // cases, we generate a special block sequence.
466   Predicates::Kind kind = question->getKind();
467   if (kind == Predicates::OperandCountAtLeastQuestion ||
468       kind == Predicates::ResultCountAtLeastQuestion) {
469     // Order the children such that the cases are in reverse numerical order.
470     SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
471         llvm::seq<unsigned>(0, switchNode->getChildren().size()));
472     llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
473       return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
474              cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
475     });
476 
477     // Build the destination for each child using the next highest child as a
478     // a failure destination. This essentially creates the following control
479     // flow:
480     //
481     // if (operand_count < 1)
482     //   goto failure
483     // if (child1.match())
484     //   ...
485     //
486     // if (operand_count < 2)
487     //   goto failure
488     // if (child2.match())
489     //   ...
490     //
491     // failure:
492     //   ...
493     //
494     failureBlockStack.push_back(defaultDest);
495     Location loc = val.getLoc();
496     for (unsigned idx : sortedChildren) {
497       auto &child = switchNode->getChild(idx);
498       Block *childBlock = generateMatcher(*child.second, *region);
499       Block *predicateBlock = builder.createBlock(childBlock);
500       builder.setInsertionPointToEnd(predicateBlock);
501       unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
502       switch (kind) {
503       case Predicates::OperandCountAtLeastQuestion:
504         builder.create<pdl_interp::CheckOperandCountOp>(
505             loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
506         break;
507       case Predicates::ResultCountAtLeastQuestion:
508         builder.create<pdl_interp::CheckResultCountOp>(
509             loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
510         break;
511       default:
512         llvm_unreachable("Generating invalid AtLeast operation");
513       }
514       failureBlockStack.back() = predicateBlock;
515     }
516     Block *firstPredicateBlock = failureBlockStack.pop_back_val();
517     currentBlock->getOperations().splice(currentBlock->end(),
518                                          firstPredicateBlock->getOperations());
519     firstPredicateBlock->erase();
520     return;
521   }
522 
523   // Otherwise, generate each of the children and generate an interpreter
524   // switch.
525   llvm::MapVector<Qualifier *, Block *> children;
526   for (auto &it : switchNode->getChildren())
527     children.insert({it.first, generateMatcher(*it.second, *region)});
528   builder.setInsertionPointToEnd(currentBlock);
529 
530   switch (question->getKind()) {
531   case Predicates::OperandCountQuestion:
532     return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
533                           int32_t>(val, defaultDest, builder, children);
534   case Predicates::ResultCountQuestion:
535     return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
536                           int32_t>(val, defaultDest, builder, children);
537   case Predicates::OperationNameQuestion:
538     return createSwitchOp<pdl_interp::SwitchOperationNameOp,
539                           OperationNameAnswer>(val, defaultDest, builder,
540                                                children);
541   case Predicates::TypeQuestion:
542     if (val.getType().isa<pdl::RangeType>()) {
543       return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
544           val, defaultDest, builder, children);
545     }
546     return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
547         val, defaultDest, builder, children);
548   case Predicates::AttributeQuestion:
549     return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
550         val, defaultDest, builder, children);
551   default:
552     llvm_unreachable("Generating unknown switch predicate.");
553   }
554 }
555 
556 void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
557   pdl::PatternOp pattern = successNode->getPattern();
558   Value root = successNode->getRoot();
559 
560   // Generate a rewriter for the pattern this success node represents, and track
561   // any values used from the match region.
562   SmallVector<Position *, 8> usedMatchValues;
563   SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
564 
565   // Process any values used in the rewrite that are defined in the match.
566   std::vector<Value> mappedMatchValues;
567   mappedMatchValues.reserve(usedMatchValues.size());
568   for (Position *position : usedMatchValues)
569     mappedMatchValues.push_back(getValueAt(currentBlock, position));
570 
571   // Collect the set of operations generated by the rewriter.
572   SmallVector<StringRef, 4> generatedOps;
573   for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
574     generatedOps.push_back(*op.name());
575   ArrayAttr generatedOpsAttr;
576   if (!generatedOps.empty())
577     generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
578 
579   // Grab the root kind if present.
580   StringAttr rootKindAttr;
581   if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
582     if (Optional<StringRef> rootKind = rootOp.name())
583       rootKindAttr = builder.getStringAttr(*rootKind);
584 
585   builder.setInsertionPointToEnd(currentBlock);
586   builder.create<pdl_interp::RecordMatchOp>(
587       pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
588       rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
589       failureBlockStack.back());
590 }
591 
592 SymbolRefAttr PatternLowering::generateRewriter(
593     pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
594   builder.setInsertionPointToEnd(rewriterModule.getBody());
595   auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
596       pattern.getLoc(), "pdl_generated_rewriter",
597       builder.getFunctionType(llvm::None, llvm::None));
598   rewriterSymbolTable.insert(rewriterFunc);
599 
600   // Generate the rewriter function body.
601   builder.setInsertionPointToEnd(&rewriterFunc.front());
602 
603   // Map an input operand of the pattern to a generated interpreter value.
604   DenseMap<Value, Value> rewriteValues;
605   auto mapRewriteValue = [&](Value oldValue) {
606     Value &newValue = rewriteValues[oldValue];
607     if (newValue)
608       return newValue;
609 
610     // Prefer materializing constants directly when possible.
611     Operation *oldOp = oldValue.getDefiningOp();
612     if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
613       if (Attribute value = attrOp.valueAttr()) {
614         return newValue = builder.create<pdl_interp::CreateAttributeOp>(
615                    attrOp.getLoc(), value);
616       }
617     } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
618       if (TypeAttr type = typeOp.typeAttr()) {
619         return newValue = builder.create<pdl_interp::CreateTypeOp>(
620                    typeOp.getLoc(), type);
621       }
622     } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
623       if (ArrayAttr type = typeOp.typesAttr()) {
624         return newValue = builder.create<pdl_interp::CreateTypesOp>(
625                    typeOp.getLoc(), typeOp.getType(), type);
626       }
627     }
628 
629     // Otherwise, add this as an input to the rewriter.
630     Position *inputPos = valueToPosition.lookup(oldValue);
631     assert(inputPos && "expected value to be a pattern input");
632     usedMatchValues.push_back(inputPos);
633     return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
634                                                        oldValue.getLoc());
635   };
636 
637   // If this is a custom rewriter, simply dispatch to the registered rewrite
638   // method.
639   pdl::RewriteOp rewriter = pattern.getRewriter();
640   if (StringAttr rewriteName = rewriter.nameAttr()) {
641     SmallVector<Value> args;
642     if (rewriter.root())
643       args.push_back(mapRewriteValue(rewriter.root()));
644     auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
645     args.append(mappedArgs.begin(), mappedArgs.end());
646     builder.create<pdl_interp::ApplyRewriteOp>(
647         rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
648   } else {
649     // Otherwise this is a dag rewriter defined using PDL operations.
650     for (Operation &rewriteOp : *rewriter.getBody()) {
651       llvm::TypeSwitch<Operation *>(&rewriteOp)
652           .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
653                 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
654                 pdl::TypeOp, pdl::TypesOp>([&](auto op) {
655             this->generateRewriter(op, rewriteValues, mapRewriteValue);
656           });
657     }
658   }
659 
660   // Update the signature of the rewrite function.
661   rewriterFunc.setType(builder.getFunctionType(
662       llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
663       /*results=*/llvm::None));
664 
665   builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
666   return SymbolRefAttr::get(
667       builder.getContext(),
668       pdl_interp::PDLInterpDialect::getRewriterModuleName(),
669       SymbolRefAttr::get(rewriterFunc));
670 }
671 
672 void PatternLowering::generateRewriter(
673     pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
674     function_ref<Value(Value)> mapRewriteValue) {
675   SmallVector<Value, 2> arguments;
676   for (Value argument : rewriteOp.args())
677     arguments.push_back(mapRewriteValue(argument));
678   auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
679       rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
680       arguments);
681   for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
682     rewriteValues[std::get<0>(it)] = std::get<1>(it);
683 }
684 
685 void PatternLowering::generateRewriter(
686     pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
687     function_ref<Value(Value)> mapRewriteValue) {
688   Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
689       attrOp.getLoc(), attrOp.valueAttr());
690   rewriteValues[attrOp] = newAttr;
691 }
692 
693 void PatternLowering::generateRewriter(
694     pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
695     function_ref<Value(Value)> mapRewriteValue) {
696   builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
697                                       mapRewriteValue(eraseOp.operation()));
698 }
699 
700 void PatternLowering::generateRewriter(
701     pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
702     function_ref<Value(Value)> mapRewriteValue) {
703   SmallVector<Value, 4> operands;
704   for (Value operand : operationOp.operands())
705     operands.push_back(mapRewriteValue(operand));
706 
707   SmallVector<Value, 4> attributes;
708   for (Value attr : operationOp.attributes())
709     attributes.push_back(mapRewriteValue(attr));
710 
711   bool hasInferredResultTypes = false;
712   SmallVector<Value, 2> types;
713   generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
714                                       rewriteValues, hasInferredResultTypes);
715 
716   // Create the new operation.
717   Location loc = operationOp.getLoc();
718   Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
719       loc, *operationOp.name(), types, hasInferredResultTypes, operands,
720       attributes, operationOp.attributeNames());
721   rewriteValues[operationOp.op()] = createdOp;
722 
723   // Generate accesses for any results that have their types constrained.
724   // Handle the case where there is a single range representing all of the
725   // result types.
726   OperandRange resultTys = operationOp.types();
727   if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
728     Value &type = rewriteValues[resultTys[0]];
729     if (!type) {
730       auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
731       type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
732     }
733     return;
734   }
735 
736   // Otherwise, populate the individual results.
737   bool seenVariableLength = false;
738   Type valueTy = builder.getType<pdl::ValueType>();
739   Type valueRangeTy = pdl::RangeType::get(valueTy);
740   for (const auto &it : llvm::enumerate(resultTys)) {
741     Value &type = rewriteValues[it.value()];
742     if (type)
743       continue;
744     bool isVariadic = it.value().getType().isa<pdl::RangeType>();
745     seenVariableLength |= isVariadic;
746 
747     // After a variable length result has been seen, we need to use result
748     // groups because the exact index of the result is not statically known.
749     Value resultVal;
750     if (seenVariableLength)
751       resultVal = builder.create<pdl_interp::GetResultsOp>(
752           loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
753     else
754       resultVal = builder.create<pdl_interp::GetResultOp>(
755           loc, valueTy, createdOp, it.index());
756     type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
757   }
758 }
759 
760 void PatternLowering::generateRewriter(
761     pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
762     function_ref<Value(Value)> mapRewriteValue) {
763   SmallVector<Value, 4> replOperands;
764 
765   // If the replacement was another operation, get its results. `pdl` allows
766   // for using an operation for simplicitly, but the interpreter isn't as
767   // user facing.
768   if (Value replOp = replaceOp.replOperation()) {
769     // Don't use replace if we know the replaced operation has no results.
770     auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
771     if (!opOp || !opOp.types().empty()) {
772       replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
773           replOp.getLoc(), mapRewriteValue(replOp)));
774     }
775   } else {
776     for (Value operand : replaceOp.replValues())
777       replOperands.push_back(mapRewriteValue(operand));
778   }
779 
780   // If there are no replacement values, just create an erase instead.
781   if (replOperands.empty()) {
782     builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
783                                         mapRewriteValue(replaceOp.operation()));
784     return;
785   }
786 
787   builder.create<pdl_interp::ReplaceOp>(
788       replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
789 }
790 
791 void PatternLowering::generateRewriter(
792     pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
793     function_ref<Value(Value)> mapRewriteValue) {
794   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
795       resultOp.getLoc(), builder.getType<pdl::ValueType>(),
796       mapRewriteValue(resultOp.parent()), resultOp.index());
797 }
798 
799 void PatternLowering::generateRewriter(
800     pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
801     function_ref<Value(Value)> mapRewriteValue) {
802   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
803       resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()),
804       resultOp.index());
805 }
806 
807 void PatternLowering::generateRewriter(
808     pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
809     function_ref<Value(Value)> mapRewriteValue) {
810   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
811   // type.
812   if (TypeAttr typeAttr = typeOp.typeAttr()) {
813     rewriteValues[typeOp] =
814         builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
815   }
816 }
817 
818 void PatternLowering::generateRewriter(
819     pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
820     function_ref<Value(Value)> mapRewriteValue) {
821   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
822   // type.
823   if (ArrayAttr typeAttr = typeOp.typesAttr()) {
824     rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
825         typeOp.getLoc(), typeOp.getType(), typeAttr);
826   }
827 }
828 
829 void PatternLowering::generateOperationResultTypeRewriter(
830     pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
831     SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
832     bool &hasInferredResultTypes) {
833   // Look for an operation that was replaced by `op`. The result types will be
834   // inferred from the results that were replaced.
835   Block *rewriterBlock = op->getBlock();
836   for (OpOperand &use : op.op().getUses()) {
837     // Check that the use corresponds to a ReplaceOp and that it is the
838     // replacement value, not the operation being replaced.
839     pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
840     if (!replOpUser || use.getOperandNumber() == 0)
841       continue;
842     // Make sure the replaced operation was defined before this one.
843     Value replOpVal = replOpUser.operation();
844     Operation *replacedOp = replOpVal.getDefiningOp();
845     if (replacedOp->getBlock() == rewriterBlock &&
846         !replacedOp->isBeforeInBlock(op))
847       continue;
848 
849     Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
850         replacedOp->getLoc(), mapRewriteValue(replOpVal));
851     types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
852         replacedOp->getLoc(), replacedOpResults));
853     return;
854   }
855 
856   // Try to handle resolution for each of the result types individually. This is
857   // preferred over type inferrence because it will allow for us to use existing
858   // types directly, as opposed to trying to rebuild the type list.
859   OperandRange resultTypeValues = op.types();
860   auto tryResolveResultTypes = [&] {
861     types.reserve(resultTypeValues.size());
862     for (const auto &it : llvm::enumerate(resultTypeValues)) {
863       Value resultType = it.value();
864 
865       // Check for an already translated value.
866       if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
867         types.push_back(existingRewriteValue);
868         continue;
869       }
870 
871       // Check for an input from the matcher.
872       if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
873         types.push_back(mapRewriteValue(resultType));
874         continue;
875       }
876 
877       // Otherwise, we couldn't infer the result types. Bail out here to see if
878       // we can infer the types for this operation from another way.
879       types.clear();
880       return failure();
881     }
882     return success();
883   };
884   if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
885     return;
886 
887   // Otherwise, check if the operation has type inference support itself.
888   if (op.hasTypeInference()) {
889     hasInferredResultTypes = true;
890     return;
891   }
892 
893   // If the types could not be inferred from any context and there weren't any
894   // explicit result types, assume the user actually meant for the operation to
895   // have no results.
896   if (resultTypeValues.empty())
897     return;
898 
899   // The verifier asserts that the result types of each pdl.operation can be
900   // inferred. If we reach here, there is a bug either in the logic above or
901   // in the verifier for pdl.operation.
902   op->emitOpError() << "unable to infer result type for operation";
903   llvm_unreachable("unable to infer result type for operation");
904 }
905 
906 //===----------------------------------------------------------------------===//
907 // Conversion Pass
908 //===----------------------------------------------------------------------===//
909 
910 namespace {
911 struct PDLToPDLInterpPass
912     : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
913   void runOnOperation() final;
914 };
915 } // namespace
916 
917 /// Convert the given module containing PDL pattern operations into a PDL
918 /// Interpreter operations.
919 void PDLToPDLInterpPass::runOnOperation() {
920   ModuleOp module = getOperation();
921 
922   // Create the main matcher function This function contains all of the match
923   // related functionality from patterns in the module.
924   OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
925   auto matcherFunc = builder.create<pdl_interp::FuncOp>(
926       module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
927       builder.getFunctionType(builder.getType<pdl::OperationType>(),
928                               /*results=*/llvm::None),
929       /*attrs=*/llvm::None);
930 
931   // Create a nested module to hold the functions invoked for rewriting the IR
932   // after a successful match.
933   ModuleOp rewriterModule = builder.create<ModuleOp>(
934       module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
935 
936   // Generate the code for the patterns within the module.
937   PatternLowering generator(matcherFunc, rewriterModule);
938   generator.lower(module);
939 
940   // After generation, delete all of the pattern operations.
941   for (pdl::PatternOp pattern :
942        llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
943     pattern.erase();
944 }
945 
946 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
947   return std::make_unique<PDLToPDLInterpPass>();
948 }
949