xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp (revision 1f13963ec14a5c664633f78856e70de1d40258cd)
1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PredicateTree.h"
10 #include "mlir/Dialect/PDL/IR/PDL.h"
11 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
12 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/Interfaces/InferTypeOpInterface.h"
15 #include "llvm/ADT/TypeSwitch.h"
16 
17 using namespace mlir;
18 using namespace mlir::pdl_to_pdl_interp;
19 
20 //===----------------------------------------------------------------------===//
21 // Predicate List Building
22 //===----------------------------------------------------------------------===//
23 
24 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
25                               Value val, PredicateBuilder &builder,
26                               DenseMap<Value, Position *> &inputs,
27                               Position *pos);
28 
29 /// Compares the depths of two positions.
30 static bool comparePosDepth(Position *lhs, Position *rhs) {
31   return lhs->getOperationDepth() < rhs->getOperationDepth();
32 }
33 
34 /// Returns the number of non-range elements within `values`.
35 static unsigned getNumNonRangeValues(ValueRange values) {
36   return llvm::count_if(values.getTypes(),
37                         [](Type type) { return !type.isa<pdl::RangeType>(); });
38 }
39 
40 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
41                               Value val, PredicateBuilder &builder,
42                               DenseMap<Value, Position *> &inputs,
43                               AttributePosition *pos) {
44   assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type");
45   pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
46   predList.emplace_back(pos, builder.getIsNotNull());
47 
48   // If the attribute has a type or value, add a constraint.
49   if (Value type = attr.type())
50     getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
51   else if (Attribute value = attr.valueAttr())
52     predList.emplace_back(pos, builder.getAttributeConstraint(value));
53 }
54 
55 /// Collect all of the predicates for the given operand position.
56 static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
57                                      Value val, PredicateBuilder &builder,
58                                      DenseMap<Value, Position *> &inputs,
59                                      Position *pos) {
60   Type valueType = val.getType();
61   bool isVariadic = valueType.isa<pdl::RangeType>();
62 
63   // If this is a typed operand, add a type constraint.
64   TypeSwitch<Operation *>(val.getDefiningOp())
65       .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
66         // Prevent traversal into a null value if the operand has a proper
67         // index.
68         if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
69             cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
70           predList.emplace_back(pos, builder.getIsNotNull());
71 
72         if (Value type = op.type())
73           getTreePredicates(predList, type, builder, inputs,
74                             builder.getType(pos));
75       })
76       .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
77         Optional<unsigned> index = op.index();
78 
79         // Prevent traversal into a null value if the result has a proper index.
80         if (index)
81           predList.emplace_back(pos, builder.getIsNotNull());
82 
83         // Get the parent operation of this operand.
84         OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
85         predList.emplace_back(parentPos, builder.getIsNotNull());
86 
87         // Ensure that the operands match the corresponding results of the
88         // parent operation.
89         Position *resultPos = nullptr;
90         if (std::is_same<pdl::ResultOp, decltype(op)>::value)
91           resultPos = builder.getResult(parentPos, *index);
92         else
93           resultPos = builder.getResultGroup(parentPos, index, isVariadic);
94         predList.emplace_back(resultPos, builder.getEqualTo(pos));
95 
96         // Collect the predicates of the parent operation.
97         getTreePredicates(predList, op.parent(), builder, inputs,
98                           (Position *)parentPos);
99       });
100 }
101 
102 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
103                               Value val, PredicateBuilder &builder,
104                               DenseMap<Value, Position *> &inputs,
105                               OperationPosition *pos) {
106   assert(val.getType().isa<pdl::OperationType>() && "expected operation");
107   pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
108   OperationPosition *opPos = cast<OperationPosition>(pos);
109 
110   // Ensure getDefiningOp returns a non-null operation.
111   if (!opPos->isRoot())
112     predList.emplace_back(pos, builder.getIsNotNull());
113 
114   // Check that this is the correct root operation.
115   if (Optional<StringRef> opName = op.name())
116     predList.emplace_back(pos, builder.getOperationName(*opName));
117 
118   // Check that the operation has the proper number of operands. If there are
119   // any variable length operands, we check a minimum instead of an exact count.
120   OperandRange operands = op.operands();
121   unsigned minOperands = getNumNonRangeValues(operands);
122   if (minOperands != operands.size()) {
123     if (minOperands)
124       predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
125   } else {
126     predList.emplace_back(pos, builder.getOperandCount(minOperands));
127   }
128 
129   // Check that the operation has the proper number of results. If there are
130   // any variable length results, we check a minimum instead of an exact count.
131   OperandRange types = op.types();
132   unsigned minResults = getNumNonRangeValues(types);
133   if (minResults == types.size())
134     predList.emplace_back(pos, builder.getResultCount(types.size()));
135   else if (minResults)
136     predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
137 
138   // Recurse into any attributes, operands, or results.
139   for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
140     getTreePredicates(
141         predList, std::get<1>(it), builder, inputs,
142         builder.getAttribute(opPos,
143                              std::get<0>(it).cast<StringAttr>().getValue()));
144   }
145 
146   // Process the operands and results of the operation. For all values up to
147   // the first variable length value, we use the concrete operand/result
148   // number. After that, we use the "group" given that we can't know the
149   // concrete indices until runtime. If there is only one variadic operand
150   // group, we treat it as all of the operands/results of the operation.
151   /// Operands.
152   if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) {
153     getTreePredicates(predList, operands.front(), builder, inputs,
154                       builder.getAllOperands(opPos));
155   } else {
156     bool foundVariableLength = false;
157     for (auto operandIt : llvm::enumerate(operands)) {
158       bool isVariadic = operandIt.value().getType().isa<pdl::RangeType>();
159       foundVariableLength |= isVariadic;
160 
161       Position *pos =
162           foundVariableLength
163               ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
164               : builder.getOperand(opPos, operandIt.index());
165       getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
166     }
167   }
168   /// Results.
169   if (types.size() == 1 && types[0].getType().isa<pdl::RangeType>()) {
170     getTreePredicates(predList, types.front(), builder, inputs,
171                       builder.getType(builder.getAllResults(opPos)));
172   } else {
173     bool foundVariableLength = false;
174     for (auto &resultIt : llvm::enumerate(types)) {
175       bool isVariadic = resultIt.value().getType().isa<pdl::RangeType>();
176       foundVariableLength |= isVariadic;
177 
178       auto *resultPos =
179           foundVariableLength
180               ? builder.getResultGroup(pos, resultIt.index(), isVariadic)
181               : builder.getResult(pos, resultIt.index());
182       predList.emplace_back(resultPos, builder.getIsNotNull());
183       getTreePredicates(predList, resultIt.value(), builder, inputs,
184                         builder.getType(resultPos));
185     }
186   }
187 }
188 
189 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
190                               Value val, PredicateBuilder &builder,
191                               DenseMap<Value, Position *> &inputs,
192                               TypePosition *pos) {
193   // Check for a constraint on a constant type.
194   if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
195     if (Attribute type = typeOp.typeAttr())
196       predList.emplace_back(pos, builder.getTypeConstraint(type));
197   } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
198     if (Attribute typeAttr = typeOp.typesAttr())
199       predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
200   }
201 }
202 
203 /// Collect the tree predicates anchored at the given value.
204 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
205                               Value val, PredicateBuilder &builder,
206                               DenseMap<Value, Position *> &inputs,
207                               Position *pos) {
208   // Make sure this input value is accessible to the rewrite.
209   auto it = inputs.try_emplace(val, pos);
210   if (!it.second) {
211     // If this is an input value that has been visited in the tree, add a
212     // constraint to ensure that both instances refer to the same value.
213     if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
214             pdl::TypeOp>(val.getDefiningOp())) {
215       auto minMaxPositions =
216           std::minmax(pos, it.first->second, comparePosDepth);
217       predList.emplace_back(minMaxPositions.second,
218                             builder.getEqualTo(minMaxPositions.first));
219     }
220     return;
221   }
222 
223   TypeSwitch<Position *>(pos)
224       .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
225         getTreePredicates(predList, val, builder, inputs, pos);
226       })
227       .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
228         getOperandTreePredicates(predList, val, builder, inputs, pos);
229       })
230       .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
231 }
232 
233 /// Collect all of the predicates related to constraints within the given
234 /// pattern operation.
235 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
236                                     std::vector<PositionalPredicate> &predList,
237                                     PredicateBuilder &builder,
238                                     DenseMap<Value, Position *> &inputs) {
239   OperandRange arguments = op.args();
240   ArrayAttr parameters = op.constParamsAttr();
241 
242   std::vector<Position *> allPositions;
243   allPositions.reserve(arguments.size());
244   for (Value arg : arguments)
245     allPositions.push_back(inputs.lookup(arg));
246 
247   // Push the constraint to the furthest position.
248   Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
249                                     comparePosDepth);
250   PredicateBuilder::Predicate pred =
251       builder.getConstraint(op.name(), std::move(allPositions), parameters);
252   predList.emplace_back(pos, pred);
253 }
254 
255 static void getResultPredicates(pdl::ResultOp op,
256                                 std::vector<PositionalPredicate> &predList,
257                                 PredicateBuilder &builder,
258                                 DenseMap<Value, Position *> &inputs) {
259   Position *&resultPos = inputs[op];
260   if (resultPos)
261     return;
262 
263   // Ensure that the result isn't null.
264   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
265   resultPos = builder.getResult(parentPos, op.index());
266   predList.emplace_back(resultPos, builder.getIsNotNull());
267 }
268 
269 static void getResultPredicates(pdl::ResultsOp op,
270                                 std::vector<PositionalPredicate> &predList,
271                                 PredicateBuilder &builder,
272                                 DenseMap<Value, Position *> &inputs) {
273   Position *&resultPos = inputs[op];
274   if (resultPos)
275     return;
276 
277   // Ensure that the result isn't null if the result has an index.
278   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
279   bool isVariadic = op.getType().isa<pdl::RangeType>();
280   Optional<unsigned> index = op.index();
281   resultPos = builder.getResultGroup(parentPos, index, isVariadic);
282   if (index)
283     predList.emplace_back(resultPos, builder.getIsNotNull());
284 }
285 
286 /// Collect all of the predicates that cannot be determined via walking the
287 /// tree.
288 static void getNonTreePredicates(pdl::PatternOp pattern,
289                                  std::vector<PositionalPredicate> &predList,
290                                  PredicateBuilder &builder,
291                                  DenseMap<Value, Position *> &inputs) {
292   for (Operation &op : pattern.body().getOps()) {
293     TypeSwitch<Operation *>(&op)
294         .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
295           getConstraintPredicates(constraintOp, predList, builder, inputs);
296         })
297         .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
298           getResultPredicates(resultOp, predList, builder, inputs);
299         });
300   }
301 }
302 
303 /// Given a pattern operation, build the set of matcher predicates necessary to
304 /// match this pattern.
305 static void buildPredicateList(pdl::PatternOp pattern,
306                                PredicateBuilder &builder,
307                                std::vector<PositionalPredicate> &predList,
308                                DenseMap<Value, Position *> &valueToPosition) {
309   getTreePredicates(predList, pattern.getRewriter().root(), builder,
310                     valueToPosition, builder.getRoot());
311   getNonTreePredicates(pattern, predList, builder, valueToPosition);
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // Pattern Predicate Tree Merging
316 //===----------------------------------------------------------------------===//
317 
318 namespace {
319 
320 /// This class represents a specific predicate applied to a position, and
321 /// provides hashing and ordering operators. This class allows for computing a
322 /// frequence sum and ordering predicates based on a cost model.
323 struct OrderedPredicate {
324   OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
325       : position(ip.first), question(ip.second) {}
326   OrderedPredicate(const PositionalPredicate &ip)
327       : position(ip.position), question(ip.question) {}
328 
329   /// The position this predicate is applied to.
330   Position *position;
331 
332   /// The question that is applied by this predicate onto the position.
333   Qualifier *question;
334 
335   /// The first and second order benefit sums.
336   /// The primary sum is the number of occurrences of this predicate among all
337   /// of the patterns.
338   unsigned primary = 0;
339   /// The secondary sum is a squared summation of the primary sum of all of the
340   /// predicates within each pattern that contains this predicate. This allows
341   /// for favoring predicates that are more commonly shared within a pattern, as
342   /// opposed to those shared across patterns.
343   unsigned secondary = 0;
344 
345   /// A map between a pattern operation and the answer to the predicate question
346   /// within that pattern.
347   DenseMap<Operation *, Qualifier *> patternToAnswer;
348 
349   /// Returns true if this predicate is ordered before `rhs`, based on the cost
350   /// model.
351   bool operator<(const OrderedPredicate &rhs) const {
352     // Sort by:
353     // * higher first and secondary order sums
354     // * lower depth
355     // * lower position dependency
356     // * lower predicate dependency
357     auto *rhsPos = rhs.position;
358     return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
359                            rhsPos->getKind(), rhs.question->getKind()) >
360            std::make_tuple(rhs.primary, rhs.secondary,
361                            position->getOperationDepth(), position->getKind(),
362                            question->getKind());
363   }
364 };
365 
366 /// A DenseMapInfo for OrderedPredicate based solely on the position and
367 /// question.
368 struct OrderedPredicateDenseInfo {
369   using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;
370 
371   static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
372   static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
373   static bool isEqual(const OrderedPredicate &lhs,
374                       const OrderedPredicate &rhs) {
375     return lhs.position == rhs.position && lhs.question == rhs.question;
376   }
377   static unsigned getHashValue(const OrderedPredicate &p) {
378     return llvm::hash_combine(p.position, p.question);
379   }
380 };
381 
382 /// This class wraps a set of ordered predicates that are used within a specific
383 /// pattern operation.
384 struct OrderedPredicateList {
385   OrderedPredicateList(pdl::PatternOp pattern) : pattern(pattern) {}
386 
387   pdl::PatternOp pattern;
388   DenseSet<OrderedPredicate *> predicates;
389 };
390 } // end anonymous namespace
391 
392 /// Returns true if the given matcher refers to the same predicate as the given
393 /// ordered predicate. This means that the position and questions of the two
394 /// match.
395 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
396   return node->getPosition() == predicate->position &&
397          node->getQuestion() == predicate->question;
398 }
399 
400 /// Get or insert a child matcher for the given parent switch node, given a
401 /// predicate and parent pattern.
402 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
403                                                OrderedPredicate *predicate,
404                                                pdl::PatternOp pattern) {
405   assert(isSamePredicate(node, predicate) &&
406          "expected matcher to equal the given predicate");
407 
408   auto it = predicate->patternToAnswer.find(pattern);
409   assert(it != predicate->patternToAnswer.end() &&
410          "expected pattern to exist in predicate");
411   return node->getChildren().insert({it->second, nullptr}).first->second;
412 }
413 
414 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
415 /// order. A pattern will traverse as far as possible using common predicates
416 /// and then either diverge from the CFG or reach the end of a branch and start
417 /// creating new nodes.
418 static void propagatePattern(std::unique_ptr<MatcherNode> &node,
419                              OrderedPredicateList &list,
420                              std::vector<OrderedPredicate *>::iterator current,
421                              std::vector<OrderedPredicate *>::iterator end) {
422   if (current == end) {
423     // We've hit the end of a pattern, so create a successful result node.
424     node = std::make_unique<SuccessNode>(list.pattern, std::move(node));
425 
426     // If the pattern doesn't contain this predicate, ignore it.
427   } else if (list.predicates.find(*current) == list.predicates.end()) {
428     propagatePattern(node, list, std::next(current), end);
429 
430     // If the current matcher node is invalid, create a new one for this
431     // position and continue propagation.
432   } else if (!node) {
433     // Create a new node at this position and continue
434     node = std::make_unique<SwitchNode>((*current)->position,
435                                         (*current)->question);
436     propagatePattern(
437         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
438         list, std::next(current), end);
439 
440     // If the matcher has already been created, and it is for this predicate we
441     // continue propagation to the child.
442   } else if (isSamePredicate(node.get(), *current)) {
443     propagatePattern(
444         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
445         list, std::next(current), end);
446 
447     // If the matcher doesn't match the current predicate, insert a branch as
448     // the common set of matchers has diverged.
449   } else {
450     propagatePattern(node->getFailureNode(), list, current, end);
451   }
452 }
453 
454 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
455 /// `node` is updated in-place if it is a switch.
456 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
457   if (!node)
458     return;
459 
460   if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
461     SwitchNode::ChildMapT &children = switchNode->getChildren();
462     for (auto &it : children)
463       foldSwitchToBool(it.second);
464 
465     // If the node only contains one child, collapse it into a boolean predicate
466     // node.
467     if (children.size() == 1) {
468       auto childIt = children.begin();
469       node = std::make_unique<BoolNode>(
470           node->getPosition(), node->getQuestion(), childIt->first,
471           std::move(childIt->second), std::move(node->getFailureNode()));
472     }
473   } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
474     foldSwitchToBool(boolNode->getSuccessNode());
475   }
476 
477   foldSwitchToBool(node->getFailureNode());
478 }
479 
480 /// Insert an exit node at the end of the failure path of the `root`.
481 static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
482   while (*root)
483     root = &(*root)->getFailureNode();
484   *root = std::make_unique<ExitNode>();
485 }
486 
487 /// Given a module containing PDL pattern operations, generate a matcher tree
488 /// using the patterns within the given module and return the root matcher node.
489 std::unique_ptr<MatcherNode>
490 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
491                                  DenseMap<Value, Position *> &valueToPosition) {
492   // Collect the set of predicates contained within the pattern operations of
493   // the module.
494   SmallVector<std::pair<pdl::PatternOp, std::vector<PositionalPredicate>>, 16>
495       patternsAndPredicates;
496   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
497     std::vector<PositionalPredicate> predicateList;
498     buildPredicateList(pattern, builder, predicateList, valueToPosition);
499     patternsAndPredicates.emplace_back(pattern, std::move(predicateList));
500   }
501 
502   // Associate a pattern result with each unique predicate.
503   DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
504   for (auto &patternAndPredList : patternsAndPredicates) {
505     for (auto &predicate : patternAndPredList.second) {
506       auto it = uniqued.insert(predicate);
507       it.first->patternToAnswer.try_emplace(patternAndPredList.first,
508                                             predicate.answer);
509     }
510   }
511 
512   // Associate each pattern to a set of its ordered predicates for later lookup.
513   std::vector<OrderedPredicateList> lists;
514   lists.reserve(patternsAndPredicates.size());
515   for (auto &patternAndPredList : patternsAndPredicates) {
516     OrderedPredicateList list(patternAndPredList.first);
517     for (auto &predicate : patternAndPredList.second) {
518       OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
519       list.predicates.insert(orderedPredicate);
520 
521       // Increment the primary sum for each reference to a particular predicate.
522       ++orderedPredicate->primary;
523     }
524     lists.push_back(std::move(list));
525   }
526 
527   // For a particular pattern, get the total primary sum and add it to the
528   // secondary sum of each predicate. Square the primary sums to emphasize
529   // shared predicates within rather than across patterns.
530   for (auto &list : lists) {
531     unsigned total = 0;
532     for (auto *predicate : list.predicates)
533       total += predicate->primary * predicate->primary;
534     for (auto *predicate : list.predicates)
535       predicate->secondary += total;
536   }
537 
538   // Sort the set of predicates now that the cost primary and secondary sums
539   // have been computed.
540   std::vector<OrderedPredicate *> ordered;
541   ordered.reserve(uniqued.size());
542   for (auto &ip : uniqued)
543     ordered.push_back(&ip);
544   std::stable_sort(
545       ordered.begin(), ordered.end(),
546       [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; });
547 
548   // Build the matchers for each of the pattern predicate lists.
549   std::unique_ptr<MatcherNode> root;
550   for (OrderedPredicateList &list : lists)
551     propagatePattern(root, list, ordered.begin(), ordered.end());
552 
553   // Collapse the graph and insert the exit node.
554   foldSwitchToBool(root);
555   insertExitNode(&root);
556   return root;
557 }
558 
559 //===----------------------------------------------------------------------===//
560 // MatcherNode
561 //===----------------------------------------------------------------------===//
562 
563 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
564                          std::unique_ptr<MatcherNode> failureNode)
565     : position(p), question(q), failureNode(std::move(failureNode)),
566       matcherTypeID(matcherTypeID) {}
567 
568 //===----------------------------------------------------------------------===//
569 // BoolNode
570 //===----------------------------------------------------------------------===//
571 
572 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
573                    std::unique_ptr<MatcherNode> successNode,
574                    std::unique_ptr<MatcherNode> failureNode)
575     : MatcherNode(TypeID::get<BoolNode>(), position, question,
576                   std::move(failureNode)),
577       answer(answer), successNode(std::move(successNode)) {}
578 
579 //===----------------------------------------------------------------------===//
580 // SuccessNode
581 //===----------------------------------------------------------------------===//
582 
583 SuccessNode::SuccessNode(pdl::PatternOp pattern,
584                          std::unique_ptr<MatcherNode> failureNode)
585     : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
586                   /*question=*/nullptr, std::move(failureNode)),
587       pattern(pattern) {}
588 
589 //===----------------------------------------------------------------------===//
590 // SwitchNode
591 //===----------------------------------------------------------------------===//
592 
593 SwitchNode::SwitchNode(Position *position, Qualifier *question)
594     : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
595