xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp (revision 8ec28af8eaff5acd0df3e53340159c034f08533d)
1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10 
11 #include "PredicateTree.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::pdl_to_pdl_interp;
30 
31 //===----------------------------------------------------------------------===//
32 // PatternLowering
33 //===----------------------------------------------------------------------===//
34 
35 namespace {
36 /// This class generators operations within the PDL Interpreter dialect from a
37 /// given module containing PDL pattern operations.
38 struct PatternLowering {
39 public:
40   PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
41                   DenseMap<Operation *, PDLPatternConfigSet *> *configMap);
42 
43   /// Generate code for matching and rewriting based on the pattern operations
44   /// within the module.
45   void lower(ModuleOp module);
46 
47 private:
48   using ValueMap = llvm::ScopedHashTable<Position *, Value>;
49   using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
50 
51   /// Generate interpreter operations for the tree rooted at the given matcher
52   /// node, in the specified region.
53   Block *generateMatcher(MatcherNode &node, Region &region,
54                          Block *block = nullptr);
55 
56   /// Get or create an access to the provided positional value in the current
57   /// block. This operation may mutate the provided block pointer if nested
58   /// regions (i.e., pdl_interp.iterate) are required.
59   Value getValueAt(Block *&currentBlock, Position *pos);
60 
61   /// Create the interpreter predicate operations. This operation may mutate the
62   /// provided current block pointer if nested regions (iterates) are required.
63   void generate(BoolNode *boolNode, Block *&currentBlock, Value val);
64 
65   /// Create the interpreter switch / predicate operations, with several case
66   /// destinations. This operation never mutates the provided current block
67   /// pointer, because the switch operation does not need Values beyond `val`.
68   void generate(SwitchNode *switchNode, Block *currentBlock, Value val);
69 
70   /// Create the interpreter operations to record a successful pattern match
71   /// using the contained root operation. This operation may mutate the current
72   /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
73   void generate(SuccessNode *successNode, Block *&currentBlock);
74 
75   /// Generate a rewriter function for the given pattern operation, and returns
76   /// a reference to that function.
77   SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
78                                  SmallVectorImpl<Position *> &usedMatchValues);
79 
80   /// Generate the rewriter code for the given operation.
81   void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
82                         DenseMap<Value, Value> &rewriteValues,
83                         function_ref<Value(Value)> mapRewriteValue);
84   void generateRewriter(pdl::AttributeOp attrOp,
85                         DenseMap<Value, Value> &rewriteValues,
86                         function_ref<Value(Value)> mapRewriteValue);
87   void generateRewriter(pdl::EraseOp eraseOp,
88                         DenseMap<Value, Value> &rewriteValues,
89                         function_ref<Value(Value)> mapRewriteValue);
90   void generateRewriter(pdl::OperationOp operationOp,
91                         DenseMap<Value, Value> &rewriteValues,
92                         function_ref<Value(Value)> mapRewriteValue);
93   void generateRewriter(pdl::RangeOp rangeOp,
94                         DenseMap<Value, Value> &rewriteValues,
95                         function_ref<Value(Value)> mapRewriteValue);
96   void generateRewriter(pdl::ReplaceOp replaceOp,
97                         DenseMap<Value, Value> &rewriteValues,
98                         function_ref<Value(Value)> mapRewriteValue);
99   void generateRewriter(pdl::ResultOp resultOp,
100                         DenseMap<Value, Value> &rewriteValues,
101                         function_ref<Value(Value)> mapRewriteValue);
102   void generateRewriter(pdl::ResultsOp resultOp,
103                         DenseMap<Value, Value> &rewriteValues,
104                         function_ref<Value(Value)> mapRewriteValue);
105   void generateRewriter(pdl::TypeOp typeOp,
106                         DenseMap<Value, Value> &rewriteValues,
107                         function_ref<Value(Value)> mapRewriteValue);
108   void generateRewriter(pdl::TypesOp typeOp,
109                         DenseMap<Value, Value> &rewriteValues,
110                         function_ref<Value(Value)> mapRewriteValue);
111 
112   /// Generate the values used for resolving the result types of an operation
113   /// created within a dag rewriter region. If the result types of the operation
114   /// should be inferred, `hasInferredResultTypes` is set to true.
115   void generateOperationResultTypeRewriter(
116       pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
117       SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
118       bool &hasInferredResultTypes);
119 
120   /// A builder to use when generating interpreter operations.
121   OpBuilder builder;
122 
123   /// The matcher function used for all match related logic within PDL patterns.
124   pdl_interp::FuncOp matcherFunc;
125 
126   /// The rewriter module containing the all rewrite related logic within PDL
127   /// patterns.
128   ModuleOp rewriterModule;
129 
130   /// The symbol table of the rewriter module used for insertion.
131   SymbolTable rewriterSymbolTable;
132 
133   /// A scoped map connecting a position with the corresponding interpreter
134   /// value.
135   ValueMap values;
136 
137   /// A stack of blocks used as the failure destination for matcher nodes that
138   /// don't have an explicit failure path.
139   SmallVector<Block *, 8> failureBlockStack;
140 
141   /// A mapping between values defined in a pattern match, and the corresponding
142   /// positional value.
143   DenseMap<Value, Position *> valueToPosition;
144 
145   /// The set of operation values whose location will be used for newly
146   /// generated operations.
147   SetVector<Value> locOps;
148 
149   /// A mapping between pattern operations and the corresponding configuration
150   /// set.
151   DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
152 
153   /// A mapping from a constraint question to the ApplyConstraintOp
154   /// that implements it.
155   DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
156 };
157 } // namespace
158 
PatternLowering(pdl_interp::FuncOp matcherFunc,ModuleOp rewriterModule,DenseMap<Operation *,PDLPatternConfigSet * > * configMap)159 PatternLowering::PatternLowering(
160     pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
161     DenseMap<Operation *, PDLPatternConfigSet *> *configMap)
162     : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
163       rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
164       configMap(configMap) {}
165 
lower(ModuleOp module)166 void PatternLowering::lower(ModuleOp module) {
167   PredicateUniquer predicateUniquer;
168   PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
169 
170   // Define top-level scope for the arguments to the matcher function.
171   ValueMapScope topLevelValueScope(values);
172 
173   // Insert the root operation, i.e. argument to the matcher, at the root
174   // position.
175   Block *matcherEntryBlock = &matcherFunc.front();
176   values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
177 
178   // Generate a root matcher node from the provided PDL module.
179   std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
180       module, predicateBuilder, valueToPosition);
181   Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
182   assert(failureBlockStack.empty() && "failed to empty the stack");
183 
184   // After generation, merged the first matched block into the entry.
185   matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
186                                             firstMatcherBlock->getOperations());
187   firstMatcherBlock->erase();
188 }
189 
generateMatcher(MatcherNode & node,Region & region,Block * block)190 Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
191                                         Block *block) {
192   // Push a new scope for the values used by this matcher.
193   if (!block)
194     block = &region.emplaceBlock();
195   ValueMapScope scope(values);
196 
197   // If this is the return node, simply insert the corresponding interpreter
198   // finalize.
199   if (isa<ExitNode>(node)) {
200     builder.setInsertionPointToEnd(block);
201     builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
202     return block;
203   }
204 
205   // Get the next block in the match sequence.
206   // This is intentionally executed first, before we get the value for the
207   // position associated with the node, so that we preserve an "there exist"
208   // semantics: if getting a value requires an upward traversal (going from a
209   // value to its consumers), we want to perform the check on all the consumers
210   // before we pass control to the failure node.
211   std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
212   Block *failureBlock;
213   if (failureNode) {
214     failureBlock = generateMatcher(*failureNode, region);
215     failureBlockStack.push_back(failureBlock);
216   } else {
217     assert(!failureBlockStack.empty() && "expected valid failure block");
218     failureBlock = failureBlockStack.back();
219   }
220 
221   // If this node contains a position, get the corresponding value for this
222   // block.
223   Block *currentBlock = block;
224   Position *position = node.getPosition();
225   Value val = position ? getValueAt(currentBlock, position) : Value();
226 
227   // If this value corresponds to an operation, record that we are going to use
228   // its location as part of a fused location.
229   bool isOperationValue = val && isa<pdl::OperationType>(val.getType());
230   if (isOperationValue)
231     locOps.insert(val);
232 
233   // Dispatch to the correct method based on derived node type.
234   TypeSwitch<MatcherNode *>(&node)
235       .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
236         this->generate(derivedNode, currentBlock, val);
237       })
238       .Case([&](SuccessNode *successNode) {
239         generate(successNode, currentBlock);
240       });
241 
242   // Pop all the failure blocks that were inserted due to nesting of
243   // pdl_interp.iterate.
244   while (failureBlockStack.back() != failureBlock) {
245     failureBlockStack.pop_back();
246     assert(!failureBlockStack.empty() && "unable to locate failure block");
247   }
248 
249   // Pop the new failure block.
250   if (failureNode)
251     failureBlockStack.pop_back();
252 
253   if (isOperationValue)
254     locOps.remove(val);
255 
256   return block;
257 }
258 
getValueAt(Block * & currentBlock,Position * pos)259 Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
260   if (Value val = values.lookup(pos))
261     return val;
262 
263   // Get the value for the parent position.
264   Value parentVal;
265   if (Position *parent = pos->getParent())
266     parentVal = getValueAt(currentBlock, parent);
267 
268   // TODO: Use a location from the position.
269   Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
270   builder.setInsertionPointToEnd(currentBlock);
271   Value value;
272   switch (pos->getKind()) {
273   case Predicates::OperationPos: {
274     auto *operationPos = cast<OperationPosition>(pos);
275     if (operationPos->isOperandDefiningOp())
276       // Standard (downward) traversal which directly follows the defining op.
277       value = builder.create<pdl_interp::GetDefiningOpOp>(
278           loc, builder.getType<pdl::OperationType>(), parentVal);
279     else
280       // A passthrough operation position.
281       value = parentVal;
282     break;
283   }
284   case Predicates::UsersPos: {
285     auto *usersPos = cast<UsersPosition>(pos);
286 
287     // The first operation retrieves the representative value of a range.
288     // This applies only when the parent is a range of values and we were
289     // requested to use a representative value (e.g., upward traversal).
290     if (isa<pdl::RangeType>(parentVal.getType()) &&
291         usersPos->useRepresentative())
292       value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
293     else
294       value = parentVal;
295 
296     // The second operation retrieves the users.
297     value = builder.create<pdl_interp::GetUsersOp>(loc, value);
298     break;
299   }
300   case Predicates::ForEachPos: {
301     assert(!failureBlockStack.empty() && "expected valid failure block");
302     auto foreach = builder.create<pdl_interp::ForEachOp>(
303         loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
304     value = foreach.getLoopVariable();
305 
306     // Create the continuation block.
307     Block *continueBlock = builder.createBlock(&foreach.getRegion());
308     builder.create<pdl_interp::ContinueOp>(loc);
309     failureBlockStack.push_back(continueBlock);
310 
311     currentBlock = &foreach.getRegion().front();
312     break;
313   }
314   case Predicates::OperandPos: {
315     auto *operandPos = cast<OperandPosition>(pos);
316     value = builder.create<pdl_interp::GetOperandOp>(
317         loc, builder.getType<pdl::ValueType>(), parentVal,
318         operandPos->getOperandNumber());
319     break;
320   }
321   case Predicates::OperandGroupPos: {
322     auto *operandPos = cast<OperandGroupPosition>(pos);
323     Type valueTy = builder.getType<pdl::ValueType>();
324     value = builder.create<pdl_interp::GetOperandsOp>(
325         loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
326         parentVal, operandPos->getOperandGroupNumber());
327     break;
328   }
329   case Predicates::AttributePos: {
330     auto *attrPos = cast<AttributePosition>(pos);
331     value = builder.create<pdl_interp::GetAttributeOp>(
332         loc, builder.getType<pdl::AttributeType>(), parentVal,
333         attrPos->getName().strref());
334     break;
335   }
336   case Predicates::TypePos: {
337     if (isa<pdl::AttributeType>(parentVal.getType()))
338       value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
339     else
340       value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
341     break;
342   }
343   case Predicates::ResultPos: {
344     auto *resPos = cast<ResultPosition>(pos);
345     value = builder.create<pdl_interp::GetResultOp>(
346         loc, builder.getType<pdl::ValueType>(), parentVal,
347         resPos->getResultNumber());
348     break;
349   }
350   case Predicates::ResultGroupPos: {
351     auto *resPos = cast<ResultGroupPosition>(pos);
352     Type valueTy = builder.getType<pdl::ValueType>();
353     value = builder.create<pdl_interp::GetResultsOp>(
354         loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
355         parentVal, resPos->getResultGroupNumber());
356     break;
357   }
358   case Predicates::AttributeLiteralPos: {
359     auto *attrPos = cast<AttributeLiteralPosition>(pos);
360     value =
361         builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
362     break;
363   }
364   case Predicates::TypeLiteralPos: {
365     auto *typePos = cast<TypeLiteralPosition>(pos);
366     Attribute rawTypeAttr = typePos->getValue();
367     if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
368       value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
369     else
370       value = builder.create<pdl_interp::CreateTypesOp>(
371           loc, cast<ArrayAttr>(rawTypeAttr));
372     break;
373   }
374   case Predicates::ConstraintResultPos: {
375     // Due to the order of traversal, the ApplyConstraintOp has already been
376     // created and we can find it in constraintOpMap.
377     auto *constrResPos = cast<ConstraintPosition>(pos);
378     auto i = constraintOpMap.find(constrResPos->getQuestion());
379     assert(i != constraintOpMap.end());
380     value = i->second->getResult(constrResPos->getIndex());
381     break;
382   }
383   default:
384     llvm_unreachable("Generating unknown Position getter");
385     break;
386   }
387 
388   values.insert(pos, value);
389   return value;
390 }
391 
generate(BoolNode * boolNode,Block * & currentBlock,Value val)392 void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
393                                Value val) {
394   Location loc = val.getLoc();
395   Qualifier *question = boolNode->getQuestion();
396   Qualifier *answer = boolNode->getAnswer();
397   Region *region = currentBlock->getParent();
398 
399   // Execute the getValue queries first, so that we create success
400   // matcher in the correct (possibly nested) region.
401   SmallVector<Value> args;
402   if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
403     args = {getValueAt(currentBlock, equalToQuestion->getValue())};
404   } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
405     for (Position *position : cstQuestion->getArgs())
406       args.push_back(getValueAt(currentBlock, position));
407   }
408 
409   // Generate a new block as success successor and get the failure successor.
410   Block *success = &region->emplaceBlock();
411   Block *failure = failureBlockStack.back();
412 
413   // Create the predicate.
414   builder.setInsertionPointToEnd(currentBlock);
415   Predicates::Kind kind = question->getKind();
416   switch (kind) {
417   case Predicates::IsNotNullQuestion:
418     builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
419     break;
420   case Predicates::OperationNameQuestion: {
421     auto *opNameAnswer = cast<OperationNameAnswer>(answer);
422     builder.create<pdl_interp::CheckOperationNameOp>(
423         loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
424     break;
425   }
426   case Predicates::TypeQuestion: {
427     auto *ans = cast<TypeAnswer>(answer);
428     if (isa<pdl::RangeType>(val.getType()))
429       builder.create<pdl_interp::CheckTypesOp>(
430           loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
431     else
432       builder.create<pdl_interp::CheckTypeOp>(
433           loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
434     break;
435   }
436   case Predicates::AttributeQuestion: {
437     auto *ans = cast<AttributeAnswer>(answer);
438     builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
439                                                  success, failure);
440     break;
441   }
442   case Predicates::OperandCountAtLeastQuestion:
443   case Predicates::OperandCountQuestion:
444     builder.create<pdl_interp::CheckOperandCountOp>(
445         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
446         /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
447         success, failure);
448     break;
449   case Predicates::ResultCountAtLeastQuestion:
450   case Predicates::ResultCountQuestion:
451     builder.create<pdl_interp::CheckResultCountOp>(
452         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
453         /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
454         success, failure);
455     break;
456   case Predicates::EqualToQuestion: {
457     bool trueAnswer = isa<TrueAnswer>(answer);
458     builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
459                                            trueAnswer ? success : failure,
460                                            trueAnswer ? failure : success);
461     break;
462   }
463   case Predicates::ConstraintQuestion: {
464     auto *cstQuestion = cast<ConstraintQuestion>(question);
465     auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
466         loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
467         cstQuestion->getIsNegated(), success, failure);
468 
469     constraintOpMap.insert({cstQuestion, applyConstraintOp});
470     break;
471   }
472   default:
473     llvm_unreachable("Generating unknown Predicate operation");
474   }
475 
476   // Generate the matcher in the current (potentially nested) region.
477   // This might use the results of the current predicate.
478   generateMatcher(*boolNode->getSuccessNode(), *region, success);
479 }
480 
481 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
createSwitchOp(Value val,Block * defaultDest,OpBuilder & builder,llvm::MapVector<Qualifier *,Block * > & dests)482 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
483                            llvm::MapVector<Qualifier *, Block *> &dests) {
484   std::vector<ValT> values;
485   std::vector<Block *> blocks;
486   values.reserve(dests.size());
487   blocks.reserve(dests.size());
488   for (const auto &it : dests) {
489     blocks.push_back(it.second);
490     values.push_back(cast<PredT>(it.first)->getValue());
491   }
492   builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
493 }
494 
generate(SwitchNode * switchNode,Block * currentBlock,Value val)495 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
496                                Value val) {
497   Qualifier *question = switchNode->getQuestion();
498   Region *region = currentBlock->getParent();
499   Block *defaultDest = failureBlockStack.back();
500 
501   // If the switch question is not an exact answer, i.e. for the `at_least`
502   // cases, we generate a special block sequence.
503   Predicates::Kind kind = question->getKind();
504   if (kind == Predicates::OperandCountAtLeastQuestion ||
505       kind == Predicates::ResultCountAtLeastQuestion) {
506     // Order the children such that the cases are in reverse numerical order.
507     SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
508         llvm::seq<unsigned>(0, switchNode->getChildren().size()));
509     llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
510       return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
511              cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
512     });
513 
514     // Build the destination for each child using the next highest child as a
515     // a failure destination. This essentially creates the following control
516     // flow:
517     //
518     // if (operand_count < 1)
519     //   goto failure
520     // if (child1.match())
521     //   ...
522     //
523     // if (operand_count < 2)
524     //   goto failure
525     // if (child2.match())
526     //   ...
527     //
528     // failure:
529     //   ...
530     //
531     failureBlockStack.push_back(defaultDest);
532     Location loc = val.getLoc();
533     for (unsigned idx : sortedChildren) {
534       auto &child = switchNode->getChild(idx);
535       Block *childBlock = generateMatcher(*child.second, *region);
536       Block *predicateBlock = builder.createBlock(childBlock);
537       builder.setInsertionPointToEnd(predicateBlock);
538       unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
539       switch (kind) {
540       case Predicates::OperandCountAtLeastQuestion:
541         builder.create<pdl_interp::CheckOperandCountOp>(
542             loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
543         break;
544       case Predicates::ResultCountAtLeastQuestion:
545         builder.create<pdl_interp::CheckResultCountOp>(
546             loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
547         break;
548       default:
549         llvm_unreachable("Generating invalid AtLeast operation");
550       }
551       failureBlockStack.back() = predicateBlock;
552     }
553     Block *firstPredicateBlock = failureBlockStack.pop_back_val();
554     currentBlock->getOperations().splice(currentBlock->end(),
555                                          firstPredicateBlock->getOperations());
556     firstPredicateBlock->erase();
557     return;
558   }
559 
560   // Otherwise, generate each of the children and generate an interpreter
561   // switch.
562   llvm::MapVector<Qualifier *, Block *> children;
563   for (auto &it : switchNode->getChildren())
564     children.insert({it.first, generateMatcher(*it.second, *region)});
565   builder.setInsertionPointToEnd(currentBlock);
566 
567   switch (question->getKind()) {
568   case Predicates::OperandCountQuestion:
569     return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
570                           int32_t>(val, defaultDest, builder, children);
571   case Predicates::ResultCountQuestion:
572     return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
573                           int32_t>(val, defaultDest, builder, children);
574   case Predicates::OperationNameQuestion:
575     return createSwitchOp<pdl_interp::SwitchOperationNameOp,
576                           OperationNameAnswer>(val, defaultDest, builder,
577                                                children);
578   case Predicates::TypeQuestion:
579     if (isa<pdl::RangeType>(val.getType())) {
580       return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
581           val, defaultDest, builder, children);
582     }
583     return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
584         val, defaultDest, builder, children);
585   case Predicates::AttributeQuestion:
586     return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
587         val, defaultDest, builder, children);
588   default:
589     llvm_unreachable("Generating unknown switch predicate.");
590   }
591 }
592 
generate(SuccessNode * successNode,Block * & currentBlock)593 void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
594   pdl::PatternOp pattern = successNode->getPattern();
595   Value root = successNode->getRoot();
596 
597   // Generate a rewriter for the pattern this success node represents, and track
598   // any values used from the match region.
599   SmallVector<Position *, 8> usedMatchValues;
600   SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
601 
602   // Process any values used in the rewrite that are defined in the match.
603   std::vector<Value> mappedMatchValues;
604   mappedMatchValues.reserve(usedMatchValues.size());
605   for (Position *position : usedMatchValues)
606     mappedMatchValues.push_back(getValueAt(currentBlock, position));
607 
608   // Collect the set of operations generated by the rewriter.
609   SmallVector<StringRef, 4> generatedOps;
610   for (auto op :
611        pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
612     generatedOps.push_back(*op.getOpName());
613   ArrayAttr generatedOpsAttr;
614   if (!generatedOps.empty())
615     generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
616 
617   // Grab the root kind if present.
618   StringAttr rootKindAttr;
619   if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
620     if (std::optional<StringRef> rootKind = rootOp.getOpName())
621       rootKindAttr = builder.getStringAttr(*rootKind);
622 
623   builder.setInsertionPointToEnd(currentBlock);
624   auto matchOp = builder.create<pdl_interp::RecordMatchOp>(
625       pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
626       rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
627       failureBlockStack.back());
628 
629   // Set the config of the lowered match to the parent pattern.
630   if (configMap)
631     configMap->try_emplace(matchOp, configMap->lookup(pattern));
632 }
633 
generateRewriter(pdl::PatternOp pattern,SmallVectorImpl<Position * > & usedMatchValues)634 SymbolRefAttr PatternLowering::generateRewriter(
635     pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
636   builder.setInsertionPointToEnd(rewriterModule.getBody());
637   auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
638       pattern.getLoc(), "pdl_generated_rewriter",
639       builder.getFunctionType(std::nullopt, std::nullopt));
640   rewriterSymbolTable.insert(rewriterFunc);
641 
642   // Generate the rewriter function body.
643   builder.setInsertionPointToEnd(&rewriterFunc.front());
644 
645   // Map an input operand of the pattern to a generated interpreter value.
646   DenseMap<Value, Value> rewriteValues;
647   auto mapRewriteValue = [&](Value oldValue) {
648     Value &newValue = rewriteValues[oldValue];
649     if (newValue)
650       return newValue;
651 
652     // Prefer materializing constants directly when possible.
653     Operation *oldOp = oldValue.getDefiningOp();
654     if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
655       if (Attribute value = attrOp.getValueAttr()) {
656         return newValue = builder.create<pdl_interp::CreateAttributeOp>(
657                    attrOp.getLoc(), value);
658       }
659     } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
660       if (TypeAttr type = typeOp.getConstantTypeAttr()) {
661         return newValue = builder.create<pdl_interp::CreateTypeOp>(
662                    typeOp.getLoc(), type);
663       }
664     } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
665       if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
666         return newValue = builder.create<pdl_interp::CreateTypesOp>(
667                    typeOp.getLoc(), typeOp.getType(), type);
668       }
669     }
670 
671     // Otherwise, add this as an input to the rewriter.
672     Position *inputPos = valueToPosition.lookup(oldValue);
673     assert(inputPos && "expected value to be a pattern input");
674     usedMatchValues.push_back(inputPos);
675     return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
676                                                        oldValue.getLoc());
677   };
678 
679   // If this is a custom rewriter, simply dispatch to the registered rewrite
680   // method.
681   pdl::RewriteOp rewriter = pattern.getRewriter();
682   if (StringAttr rewriteName = rewriter.getNameAttr()) {
683     SmallVector<Value> args;
684     if (rewriter.getRoot())
685       args.push_back(mapRewriteValue(rewriter.getRoot()));
686     auto mappedArgs =
687         llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
688     args.append(mappedArgs.begin(), mappedArgs.end());
689     builder.create<pdl_interp::ApplyRewriteOp>(
690         rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
691   } else {
692     // Otherwise this is a dag rewriter defined using PDL operations.
693     for (Operation &rewriteOp : *rewriter.getBody()) {
694       llvm::TypeSwitch<Operation *>(&rewriteOp)
695           .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
696                 pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
697                 pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) {
698             this->generateRewriter(op, rewriteValues, mapRewriteValue);
699           });
700     }
701   }
702 
703   // Update the signature of the rewrite function.
704   rewriterFunc.setType(builder.getFunctionType(
705       llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
706       /*results=*/std::nullopt));
707 
708   builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
709   return SymbolRefAttr::get(
710       builder.getContext(),
711       pdl_interp::PDLInterpDialect::getRewriterModuleName(),
712       SymbolRefAttr::get(rewriterFunc));
713 }
714 
generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)715 void PatternLowering::generateRewriter(
716     pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
717     function_ref<Value(Value)> mapRewriteValue) {
718   SmallVector<Value, 2> arguments;
719   for (Value argument : rewriteOp.getArgs())
720     arguments.push_back(mapRewriteValue(argument));
721   auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
722       rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(),
723       arguments);
724   for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
725     rewriteValues[std::get<0>(it)] = std::get<1>(it);
726 }
727 
generateRewriter(pdl::AttributeOp attrOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)728 void PatternLowering::generateRewriter(
729     pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
730     function_ref<Value(Value)> mapRewriteValue) {
731   Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
732       attrOp.getLoc(), attrOp.getValueAttr());
733   rewriteValues[attrOp] = newAttr;
734 }
735 
generateRewriter(pdl::EraseOp eraseOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)736 void PatternLowering::generateRewriter(
737     pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
738     function_ref<Value(Value)> mapRewriteValue) {
739   builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
740                                       mapRewriteValue(eraseOp.getOpValue()));
741 }
742 
generateRewriter(pdl::OperationOp operationOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)743 void PatternLowering::generateRewriter(
744     pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
745     function_ref<Value(Value)> mapRewriteValue) {
746   SmallVector<Value, 4> operands;
747   for (Value operand : operationOp.getOperandValues())
748     operands.push_back(mapRewriteValue(operand));
749 
750   SmallVector<Value, 4> attributes;
751   for (Value attr : operationOp.getAttributeValues())
752     attributes.push_back(mapRewriteValue(attr));
753 
754   bool hasInferredResultTypes = false;
755   SmallVector<Value, 2> types;
756   generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
757                                       rewriteValues, hasInferredResultTypes);
758 
759   // Create the new operation.
760   Location loc = operationOp.getLoc();
761   Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
762       loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
763       attributes, operationOp.getAttributeValueNames());
764   rewriteValues[operationOp.getOp()] = createdOp;
765 
766   // Generate accesses for any results that have their types constrained.
767   // Handle the case where there is a single range representing all of the
768   // result types.
769   OperandRange resultTys = operationOp.getTypeValues();
770   if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
771     Value &type = rewriteValues[resultTys[0]];
772     if (!type) {
773       auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
774       type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
775     }
776     return;
777   }
778 
779   // Otherwise, populate the individual results.
780   bool seenVariableLength = false;
781   Type valueTy = builder.getType<pdl::ValueType>();
782   Type valueRangeTy = pdl::RangeType::get(valueTy);
783   for (const auto &it : llvm::enumerate(resultTys)) {
784     Value &type = rewriteValues[it.value()];
785     if (type)
786       continue;
787     bool isVariadic = isa<pdl::RangeType>(it.value().getType());
788     seenVariableLength |= isVariadic;
789 
790     // After a variable length result has been seen, we need to use result
791     // groups because the exact index of the result is not statically known.
792     Value resultVal;
793     if (seenVariableLength)
794       resultVal = builder.create<pdl_interp::GetResultsOp>(
795           loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
796     else
797       resultVal = builder.create<pdl_interp::GetResultOp>(
798           loc, valueTy, createdOp, it.index());
799     type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
800   }
801 }
802 
generateRewriter(pdl::RangeOp rangeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)803 void PatternLowering::generateRewriter(
804     pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
805     function_ref<Value(Value)> mapRewriteValue) {
806   SmallVector<Value, 4> replOperands;
807   for (Value operand : rangeOp.getArguments())
808     replOperands.push_back(mapRewriteValue(operand));
809   rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>(
810       rangeOp.getLoc(), rangeOp.getType(), replOperands);
811 }
812 
generateRewriter(pdl::ReplaceOp replaceOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)813 void PatternLowering::generateRewriter(
814     pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
815     function_ref<Value(Value)> mapRewriteValue) {
816   SmallVector<Value, 4> replOperands;
817 
818   // If the replacement was another operation, get its results. `pdl` allows
819   // for using an operation for simplicitly, but the interpreter isn't as
820   // user facing.
821   if (Value replOp = replaceOp.getReplOperation()) {
822     // Don't use replace if we know the replaced operation has no results.
823     auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
824     if (!opOp || !opOp.getTypeValues().empty()) {
825       replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
826           replOp.getLoc(), mapRewriteValue(replOp)));
827     }
828   } else {
829     for (Value operand : replaceOp.getReplValues())
830       replOperands.push_back(mapRewriteValue(operand));
831   }
832 
833   // If there are no replacement values, just create an erase instead.
834   if (replOperands.empty()) {
835     builder.create<pdl_interp::EraseOp>(
836         replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()));
837     return;
838   }
839 
840   builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(),
841                                         mapRewriteValue(replaceOp.getOpValue()),
842                                         replOperands);
843 }
844 
generateRewriter(pdl::ResultOp resultOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)845 void PatternLowering::generateRewriter(
846     pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
847     function_ref<Value(Value)> mapRewriteValue) {
848   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
849       resultOp.getLoc(), builder.getType<pdl::ValueType>(),
850       mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
851 }
852 
generateRewriter(pdl::ResultsOp resultOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)853 void PatternLowering::generateRewriter(
854     pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
855     function_ref<Value(Value)> mapRewriteValue) {
856   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
857       resultOp.getLoc(), resultOp.getType(),
858       mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
859 }
860 
generateRewriter(pdl::TypeOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)861 void PatternLowering::generateRewriter(
862     pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
863     function_ref<Value(Value)> mapRewriteValue) {
864   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
865   // type.
866   if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
867     rewriteValues[typeOp] =
868         builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
869   }
870 }
871 
generateRewriter(pdl::TypesOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)872 void PatternLowering::generateRewriter(
873     pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
874     function_ref<Value(Value)> mapRewriteValue) {
875   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
876   // type.
877   if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
878     rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
879         typeOp.getLoc(), typeOp.getType(), typeAttr);
880   }
881 }
882 
generateOperationResultTypeRewriter(pdl::OperationOp op,function_ref<Value (Value)> mapRewriteValue,SmallVectorImpl<Value> & types,DenseMap<Value,Value> & rewriteValues,bool & hasInferredResultTypes)883 void PatternLowering::generateOperationResultTypeRewriter(
884     pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
885     SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
886     bool &hasInferredResultTypes) {
887   Block *rewriterBlock = op->getBlock();
888 
889   // Try to handle resolution for each of the result types individually. This is
890   // preferred over type inferrence because it will allow for us to use existing
891   // types directly, as opposed to trying to rebuild the type list.
892   OperandRange resultTypeValues = op.getTypeValues();
893   auto tryResolveResultTypes = [&] {
894     types.reserve(resultTypeValues.size());
895     for (const auto &it : llvm::enumerate(resultTypeValues)) {
896       Value resultType = it.value();
897 
898       // Check for an already translated value.
899       if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
900         types.push_back(existingRewriteValue);
901         continue;
902       }
903 
904       // Check for an input from the matcher.
905       if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
906         types.push_back(mapRewriteValue(resultType));
907         continue;
908       }
909 
910       // Otherwise, we couldn't infer the result types. Bail out here to see if
911       // we can infer the types for this operation from another way.
912       types.clear();
913       return failure();
914     }
915     return success();
916   };
917   if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
918     return;
919 
920   // Otherwise, check if the operation has type inference support itself.
921   if (op.hasTypeInference()) {
922     hasInferredResultTypes = true;
923     return;
924   }
925 
926   // Look for an operation that was replaced by `op`. The result types will be
927   // inferred from the results that were replaced.
928   for (OpOperand &use : op.getOp().getUses()) {
929     // Check that the use corresponds to a ReplaceOp and that it is the
930     // replacement value, not the operation being replaced.
931     pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
932     if (!replOpUser || use.getOperandNumber() == 0)
933       continue;
934     // Make sure the replaced operation was defined before this one. PDL
935     // rewrites only have single block regions, so if the op isn't in the
936     // rewriter block (i.e. the current block of the operation) we already know
937     // it dominates (i.e. it's in the matcher).
938     Value replOpVal = replOpUser.getOpValue();
939     Operation *replacedOp = replOpVal.getDefiningOp();
940     if (replacedOp->getBlock() == rewriterBlock &&
941         !replacedOp->isBeforeInBlock(op))
942       continue;
943 
944     Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
945         replacedOp->getLoc(), mapRewriteValue(replOpVal));
946     types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
947         replacedOp->getLoc(), replacedOpResults));
948     return;
949   }
950 
951   // If the types could not be inferred from any context and there weren't any
952   // explicit result types, assume the user actually meant for the operation to
953   // have no results.
954   if (resultTypeValues.empty())
955     return;
956 
957   // The verifier asserts that the result types of each pdl.getOperation can be
958   // inferred. If we reach here, there is a bug either in the logic above or
959   // in the verifier for pdl.getOperation.
960   op->emitOpError() << "unable to infer result type for operation";
961   llvm_unreachable("unable to infer result type for operation");
962 }
963 
964 //===----------------------------------------------------------------------===//
965 // Conversion Pass
966 //===----------------------------------------------------------------------===//
967 
968 namespace {
969 struct PDLToPDLInterpPass
970     : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
971   PDLToPDLInterpPass() = default;
972   PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default;
PDLToPDLInterpPass__anon4598d56a0811::PDLToPDLInterpPass973   PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
974       : configMap(&configMap) {}
975   void runOnOperation() final;
976 
977   /// A map containing the configuration for each pattern.
978   DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr;
979 };
980 } // namespace
981 
982 /// Convert the given module containing PDL pattern operations into a PDL
983 /// Interpreter operations.
runOnOperation()984 void PDLToPDLInterpPass::runOnOperation() {
985   ModuleOp module = getOperation();
986 
987   // Create the main matcher function This function contains all of the match
988   // related functionality from patterns in the module.
989   OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
990   auto matcherFunc = builder.create<pdl_interp::FuncOp>(
991       module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
992       builder.getFunctionType(builder.getType<pdl::OperationType>(),
993                               /*results=*/std::nullopt),
994       /*attrs=*/std::nullopt);
995 
996   // Create a nested module to hold the functions invoked for rewriting the IR
997   // after a successful match.
998   ModuleOp rewriterModule = builder.create<ModuleOp>(
999       module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
1000 
1001   // Generate the code for the patterns within the module.
1002   PatternLowering generator(matcherFunc, rewriterModule, configMap);
1003   generator.lower(module);
1004 
1005   // After generation, delete all of the pattern operations.
1006   for (pdl::PatternOp pattern :
1007        llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
1008     // Drop the now dead config mappings.
1009     if (configMap)
1010       configMap->erase(pattern);
1011 
1012     pattern.erase();
1013   }
1014 }
1015 
createPDLToPDLInterpPass()1016 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
1017   return std::make_unique<PDLToPDLInterpPass>();
1018 }
createPDLToPDLInterpPass(DenseMap<Operation *,PDLPatternConfigSet * > & configMap)1019 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass(
1020     DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
1021   return std::make_unique<PDLToPDLInterpPass>(configMap);
1022 }
1023