xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp (revision 3a833a0e0e526d4ef3f0037eaa2ace3511f216ce)
1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PredicateTree.h"
10 #include "mlir/Dialect/PDL/IR/PDL.h"
11 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
12 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/Interfaces/InferTypeOpInterface.h"
15 #include "llvm/ADT/TypeSwitch.h"
16 
17 using namespace mlir;
18 using namespace mlir::pdl_to_pdl_interp;
19 
20 //===----------------------------------------------------------------------===//
21 // Predicate List Building
22 //===----------------------------------------------------------------------===//
23 
24 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
25                               Value val, PredicateBuilder &builder,
26                               DenseMap<Value, Position *> &inputs,
27                               Position *pos);
28 
29 /// Compares the depths of two positions.
30 static bool comparePosDepth(Position *lhs, Position *rhs) {
31   return lhs->getOperationDepth() < rhs->getOperationDepth();
32 }
33 
34 /// Returns the number of non-range elements within `values`.
35 static unsigned getNumNonRangeValues(ValueRange values) {
36   return llvm::count_if(values.getTypes(),
37                         [](Type type) { return !type.isa<pdl::RangeType>(); });
38 }
39 
40 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
41                               Value val, PredicateBuilder &builder,
42                               DenseMap<Value, Position *> &inputs,
43                               AttributePosition *pos) {
44   assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type");
45   pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
46   predList.emplace_back(pos, builder.getIsNotNull());
47 
48   // If the attribute has a type or value, add a constraint.
49   if (Value type = attr.type())
50     getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
51   else if (Attribute value = attr.valueAttr())
52     predList.emplace_back(pos, builder.getAttributeConstraint(value));
53 }
54 
55 /// Collect all of the predicates for the given operand position.
56 static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
57                                      Value val, PredicateBuilder &builder,
58                                      DenseMap<Value, Position *> &inputs,
59                                      Position *pos) {
60   Type valueType = val.getType();
61   bool isVariadic = valueType.isa<pdl::RangeType>();
62 
63   // If this is a typed operand, add a type constraint.
64   TypeSwitch<Operation *>(val.getDefiningOp())
65       .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
66         // Prevent traversal into a null value if the operand has a proper
67         // index.
68         if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
69             cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
70           predList.emplace_back(pos, builder.getIsNotNull());
71 
72         if (Value type = op.type())
73           getTreePredicates(predList, type, builder, inputs,
74                             builder.getType(pos));
75       })
76       .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
77         Optional<unsigned> index = op.index();
78 
79         // Prevent traversal into a null value if the result has a proper index.
80         if (index)
81           predList.emplace_back(pos, builder.getIsNotNull());
82 
83         // Get the parent operation of this operand.
84         OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
85         predList.emplace_back(parentPos, builder.getIsNotNull());
86 
87         // Ensure that the operands match the corresponding results of the
88         // parent operation.
89         Position *resultPos = nullptr;
90         if (std::is_same<pdl::ResultOp, decltype(op)>::value)
91           resultPos = builder.getResult(parentPos, *index);
92         else
93           resultPos = builder.getResultGroup(parentPos, index, isVariadic);
94         predList.emplace_back(resultPos, builder.getEqualTo(pos));
95 
96         // Collect the predicates of the parent operation.
97         getTreePredicates(predList, op.parent(), builder, inputs, parentPos);
98       });
99 }
100 
101 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
102                               Value val, PredicateBuilder &builder,
103                               DenseMap<Value, Position *> &inputs,
104                               OperationPosition *pos) {
105   assert(val.getType().isa<pdl::OperationType>() && "expected operation");
106   pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
107   OperationPosition *opPos = cast<OperationPosition>(pos);
108 
109   // Ensure getDefiningOp returns a non-null operation.
110   if (!opPos->isRoot())
111     predList.emplace_back(pos, builder.getIsNotNull());
112 
113   // Check that this is the correct root operation.
114   if (Optional<StringRef> opName = op.name())
115     predList.emplace_back(pos, builder.getOperationName(*opName));
116 
117   // Check that the operation has the proper number of operands. If there are
118   // any variable length operands, we check a minimum instead of an exact count.
119   OperandRange operands = op.operands();
120   unsigned minOperands = getNumNonRangeValues(operands);
121   if (minOperands != operands.size()) {
122     if (minOperands)
123       predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
124   } else {
125     predList.emplace_back(pos, builder.getOperandCount(minOperands));
126   }
127 
128   // Check that the operation has the proper number of results. If there are
129   // any variable length results, we check a minimum instead of an exact count.
130   OperandRange types = op.types();
131   unsigned minResults = getNumNonRangeValues(types);
132   if (minResults == types.size())
133     predList.emplace_back(pos, builder.getResultCount(types.size()));
134   else if (minResults)
135     predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
136 
137   // Recurse into any attributes, operands, or results.
138   for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
139     getTreePredicates(
140         predList, std::get<1>(it), builder, inputs,
141         builder.getAttribute(opPos,
142                              std::get<0>(it).cast<StringAttr>().getValue()));
143   }
144 
145   // Process the operands and results of the operation. For all values up to
146   // the first variable length value, we use the concrete operand/result
147   // number. After that, we use the "group" given that we can't know the
148   // concrete indices until runtime. If there is only one variadic operand
149   // group, we treat it as all of the operands/results of the operation.
150   /// Operands.
151   if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) {
152     getTreePredicates(predList, operands.front(), builder, inputs,
153                       builder.getAllOperands(opPos));
154   } else {
155     bool foundVariableLength = false;
156     for (auto operandIt : llvm::enumerate(operands)) {
157       bool isVariadic = operandIt.value().getType().isa<pdl::RangeType>();
158       foundVariableLength |= isVariadic;
159 
160       Position *pos =
161           foundVariableLength
162               ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
163               : builder.getOperand(opPos, operandIt.index());
164       getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
165     }
166   }
167   /// Results.
168   if (types.size() == 1 && types[0].getType().isa<pdl::RangeType>()) {
169     getTreePredicates(predList, types.front(), builder, inputs,
170                       builder.getType(builder.getAllResults(opPos)));
171   } else {
172     bool foundVariableLength = false;
173     for (auto &resultIt : llvm::enumerate(types)) {
174       bool isVariadic = resultIt.value().getType().isa<pdl::RangeType>();
175       foundVariableLength |= isVariadic;
176 
177       auto *resultPos =
178           foundVariableLength
179               ? builder.getResultGroup(pos, resultIt.index(), isVariadic)
180               : builder.getResult(pos, resultIt.index());
181       predList.emplace_back(resultPos, builder.getIsNotNull());
182       getTreePredicates(predList, resultIt.value(), builder, inputs,
183                         builder.getType(resultPos));
184     }
185   }
186 }
187 
188 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
189                               Value val, PredicateBuilder &builder,
190                               DenseMap<Value, Position *> &inputs,
191                               TypePosition *pos) {
192   // Check for a constraint on a constant type.
193   if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
194     if (Attribute type = typeOp.typeAttr())
195       predList.emplace_back(pos, builder.getTypeConstraint(type));
196   } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
197     if (Attribute typeAttr = typeOp.typesAttr())
198       predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
199   }
200 }
201 
202 /// Collect the tree predicates anchored at the given value.
203 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
204                               Value val, PredicateBuilder &builder,
205                               DenseMap<Value, Position *> &inputs,
206                               Position *pos) {
207   // Make sure this input value is accessible to the rewrite.
208   auto it = inputs.try_emplace(val, pos);
209   if (!it.second) {
210     // If this is an input value that has been visited in the tree, add a
211     // constraint to ensure that both instances refer to the same value.
212     if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
213             pdl::TypeOp>(val.getDefiningOp())) {
214       auto minMaxPositions =
215           std::minmax(pos, it.first->second, comparePosDepth);
216       predList.emplace_back(minMaxPositions.second,
217                             builder.getEqualTo(minMaxPositions.first));
218     }
219     return;
220   }
221 
222   TypeSwitch<Position *>(pos)
223       .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
224         getTreePredicates(predList, val, builder, inputs, pos);
225       })
226       .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
227         getOperandTreePredicates(predList, val, builder, inputs, pos);
228       })
229       .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
230 }
231 
232 /// Collect all of the predicates related to constraints within the given
233 /// pattern operation.
234 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
235                                     std::vector<PositionalPredicate> &predList,
236                                     PredicateBuilder &builder,
237                                     DenseMap<Value, Position *> &inputs) {
238   OperandRange arguments = op.args();
239   ArrayAttr parameters = op.constParamsAttr();
240 
241   std::vector<Position *> allPositions;
242   allPositions.reserve(arguments.size());
243   for (Value arg : arguments)
244     allPositions.push_back(inputs.lookup(arg));
245 
246   // Push the constraint to the furthest position.
247   Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
248                                     comparePosDepth);
249   PredicateBuilder::Predicate pred =
250       builder.getConstraint(op.name(), std::move(allPositions), parameters);
251   predList.emplace_back(pos, pred);
252 }
253 
254 static void getResultPredicates(pdl::ResultOp op,
255                                 std::vector<PositionalPredicate> &predList,
256                                 PredicateBuilder &builder,
257                                 DenseMap<Value, Position *> &inputs) {
258   Position *&resultPos = inputs[op];
259   if (resultPos)
260     return;
261 
262   // Ensure that the result isn't null.
263   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
264   resultPos = builder.getResult(parentPos, op.index());
265   predList.emplace_back(resultPos, builder.getIsNotNull());
266 }
267 
268 static void getResultPredicates(pdl::ResultsOp op,
269                                 std::vector<PositionalPredicate> &predList,
270                                 PredicateBuilder &builder,
271                                 DenseMap<Value, Position *> &inputs) {
272   Position *&resultPos = inputs[op];
273   if (resultPos)
274     return;
275 
276   // Ensure that the result isn't null if the result has an index.
277   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
278   bool isVariadic = op.getType().isa<pdl::RangeType>();
279   Optional<unsigned> index = op.index();
280   resultPos = builder.getResultGroup(parentPos, index, isVariadic);
281   if (index)
282     predList.emplace_back(resultPos, builder.getIsNotNull());
283 }
284 
285 /// Collect all of the predicates that cannot be determined via walking the
286 /// tree.
287 static void getNonTreePredicates(pdl::PatternOp pattern,
288                                  std::vector<PositionalPredicate> &predList,
289                                  PredicateBuilder &builder,
290                                  DenseMap<Value, Position *> &inputs) {
291   for (Operation &op : pattern.body().getOps()) {
292     TypeSwitch<Operation *>(&op)
293         .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
294           getConstraintPredicates(constraintOp, predList, builder, inputs);
295         })
296         .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
297           getResultPredicates(resultOp, predList, builder, inputs);
298         });
299   }
300 }
301 
302 /// Given a pattern operation, build the set of matcher predicates necessary to
303 /// match this pattern.
304 static void buildPredicateList(pdl::PatternOp pattern,
305                                PredicateBuilder &builder,
306                                std::vector<PositionalPredicate> &predList,
307                                DenseMap<Value, Position *> &valueToPosition) {
308   getTreePredicates(predList, pattern.getRewriter().root(), builder,
309                     valueToPosition, builder.getRoot());
310   getNonTreePredicates(pattern, predList, builder, valueToPosition);
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // Pattern Predicate Tree Merging
315 //===----------------------------------------------------------------------===//
316 
317 namespace {
318 
319 /// This class represents a specific predicate applied to a position, and
320 /// provides hashing and ordering operators. This class allows for computing a
321 /// frequence sum and ordering predicates based on a cost model.
322 struct OrderedPredicate {
323   OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
324       : position(ip.first), question(ip.second) {}
325   OrderedPredicate(const PositionalPredicate &ip)
326       : position(ip.position), question(ip.question) {}
327 
328   /// The position this predicate is applied to.
329   Position *position;
330 
331   /// The question that is applied by this predicate onto the position.
332   Qualifier *question;
333 
334   /// The first and second order benefit sums.
335   /// The primary sum is the number of occurrences of this predicate among all
336   /// of the patterns.
337   unsigned primary = 0;
338   /// The secondary sum is a squared summation of the primary sum of all of the
339   /// predicates within each pattern that contains this predicate. This allows
340   /// for favoring predicates that are more commonly shared within a pattern, as
341   /// opposed to those shared across patterns.
342   unsigned secondary = 0;
343 
344   /// A map between a pattern operation and the answer to the predicate question
345   /// within that pattern.
346   DenseMap<Operation *, Qualifier *> patternToAnswer;
347 
348   /// Returns true if this predicate is ordered before `rhs`, based on the cost
349   /// model.
350   bool operator<(const OrderedPredicate &rhs) const {
351     // Sort by:
352     // * higher first and secondary order sums
353     // * lower depth
354     // * lower position dependency
355     // * lower predicate dependency
356     auto *rhsPos = rhs.position;
357     return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
358                            rhsPos->getKind(), rhs.question->getKind()) >
359            std::make_tuple(rhs.primary, rhs.secondary,
360                            position->getOperationDepth(), position->getKind(),
361                            question->getKind());
362   }
363 };
364 
365 /// A DenseMapInfo for OrderedPredicate based solely on the position and
366 /// question.
367 struct OrderedPredicateDenseInfo {
368   using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;
369 
370   static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
371   static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
372   static bool isEqual(const OrderedPredicate &lhs,
373                       const OrderedPredicate &rhs) {
374     return lhs.position == rhs.position && lhs.question == rhs.question;
375   }
376   static unsigned getHashValue(const OrderedPredicate &p) {
377     return llvm::hash_combine(p.position, p.question);
378   }
379 };
380 
381 /// This class wraps a set of ordered predicates that are used within a specific
382 /// pattern operation.
383 struct OrderedPredicateList {
384   OrderedPredicateList(pdl::PatternOp pattern) : pattern(pattern) {}
385 
386   pdl::PatternOp pattern;
387   DenseSet<OrderedPredicate *> predicates;
388 };
389 } // end anonymous namespace
390 
391 /// Returns true if the given matcher refers to the same predicate as the given
392 /// ordered predicate. This means that the position and questions of the two
393 /// match.
394 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
395   return node->getPosition() == predicate->position &&
396          node->getQuestion() == predicate->question;
397 }
398 
399 /// Get or insert a child matcher for the given parent switch node, given a
400 /// predicate and parent pattern.
401 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
402                                                OrderedPredicate *predicate,
403                                                pdl::PatternOp pattern) {
404   assert(isSamePredicate(node, predicate) &&
405          "expected matcher to equal the given predicate");
406 
407   auto it = predicate->patternToAnswer.find(pattern);
408   assert(it != predicate->patternToAnswer.end() &&
409          "expected pattern to exist in predicate");
410   return node->getChildren().insert({it->second, nullptr}).first->second;
411 }
412 
413 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
414 /// order. A pattern will traverse as far as possible using common predicates
415 /// and then either diverge from the CFG or reach the end of a branch and start
416 /// creating new nodes.
417 static void propagatePattern(std::unique_ptr<MatcherNode> &node,
418                              OrderedPredicateList &list,
419                              std::vector<OrderedPredicate *>::iterator current,
420                              std::vector<OrderedPredicate *>::iterator end) {
421   if (current == end) {
422     // We've hit the end of a pattern, so create a successful result node.
423     node = std::make_unique<SuccessNode>(list.pattern, std::move(node));
424 
425     // If the pattern doesn't contain this predicate, ignore it.
426   } else if (list.predicates.find(*current) == list.predicates.end()) {
427     propagatePattern(node, list, std::next(current), end);
428 
429     // If the current matcher node is invalid, create a new one for this
430     // position and continue propagation.
431   } else if (!node) {
432     // Create a new node at this position and continue
433     node = std::make_unique<SwitchNode>((*current)->position,
434                                         (*current)->question);
435     propagatePattern(
436         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
437         list, std::next(current), end);
438 
439     // If the matcher has already been created, and it is for this predicate we
440     // continue propagation to the child.
441   } else if (isSamePredicate(node.get(), *current)) {
442     propagatePattern(
443         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
444         list, std::next(current), end);
445 
446     // If the matcher doesn't match the current predicate, insert a branch as
447     // the common set of matchers has diverged.
448   } else {
449     propagatePattern(node->getFailureNode(), list, current, end);
450   }
451 }
452 
453 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
454 /// `node` is updated in-place if it is a switch.
455 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
456   if (!node)
457     return;
458 
459   if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
460     SwitchNode::ChildMapT &children = switchNode->getChildren();
461     for (auto &it : children)
462       foldSwitchToBool(it.second);
463 
464     // If the node only contains one child, collapse it into a boolean predicate
465     // node.
466     if (children.size() == 1) {
467       auto childIt = children.begin();
468       node = std::make_unique<BoolNode>(
469           node->getPosition(), node->getQuestion(), childIt->first,
470           std::move(childIt->second), std::move(node->getFailureNode()));
471     }
472   } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
473     foldSwitchToBool(boolNode->getSuccessNode());
474   }
475 
476   foldSwitchToBool(node->getFailureNode());
477 }
478 
479 /// Insert an exit node at the end of the failure path of the `root`.
480 static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
481   while (*root)
482     root = &(*root)->getFailureNode();
483   *root = std::make_unique<ExitNode>();
484 }
485 
486 /// Given a module containing PDL pattern operations, generate a matcher tree
487 /// using the patterns within the given module and return the root matcher node.
488 std::unique_ptr<MatcherNode>
489 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
490                                  DenseMap<Value, Position *> &valueToPosition) {
491   // Collect the set of predicates contained within the pattern operations of
492   // the module.
493   SmallVector<std::pair<pdl::PatternOp, std::vector<PositionalPredicate>>, 16>
494       patternsAndPredicates;
495   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
496     std::vector<PositionalPredicate> predicateList;
497     buildPredicateList(pattern, builder, predicateList, valueToPosition);
498     patternsAndPredicates.emplace_back(pattern, std::move(predicateList));
499   }
500 
501   // Associate a pattern result with each unique predicate.
502   DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
503   for (auto &patternAndPredList : patternsAndPredicates) {
504     for (auto &predicate : patternAndPredList.second) {
505       auto it = uniqued.insert(predicate);
506       it.first->patternToAnswer.try_emplace(patternAndPredList.first,
507                                             predicate.answer);
508     }
509   }
510 
511   // Associate each pattern to a set of its ordered predicates for later lookup.
512   std::vector<OrderedPredicateList> lists;
513   lists.reserve(patternsAndPredicates.size());
514   for (auto &patternAndPredList : patternsAndPredicates) {
515     OrderedPredicateList list(patternAndPredList.first);
516     for (auto &predicate : patternAndPredList.second) {
517       OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
518       list.predicates.insert(orderedPredicate);
519 
520       // Increment the primary sum for each reference to a particular predicate.
521       ++orderedPredicate->primary;
522     }
523     lists.push_back(std::move(list));
524   }
525 
526   // For a particular pattern, get the total primary sum and add it to the
527   // secondary sum of each predicate. Square the primary sums to emphasize
528   // shared predicates within rather than across patterns.
529   for (auto &list : lists) {
530     unsigned total = 0;
531     for (auto *predicate : list.predicates)
532       total += predicate->primary * predicate->primary;
533     for (auto *predicate : list.predicates)
534       predicate->secondary += total;
535   }
536 
537   // Sort the set of predicates now that the cost primary and secondary sums
538   // have been computed.
539   std::vector<OrderedPredicate *> ordered;
540   ordered.reserve(uniqued.size());
541   for (auto &ip : uniqued)
542     ordered.push_back(&ip);
543   std::stable_sort(
544       ordered.begin(), ordered.end(),
545       [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; });
546 
547   // Build the matchers for each of the pattern predicate lists.
548   std::unique_ptr<MatcherNode> root;
549   for (OrderedPredicateList &list : lists)
550     propagatePattern(root, list, ordered.begin(), ordered.end());
551 
552   // Collapse the graph and insert the exit node.
553   foldSwitchToBool(root);
554   insertExitNode(&root);
555   return root;
556 }
557 
558 //===----------------------------------------------------------------------===//
559 // MatcherNode
560 //===----------------------------------------------------------------------===//
561 
562 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
563                          std::unique_ptr<MatcherNode> failureNode)
564     : position(p), question(q), failureNode(std::move(failureNode)),
565       matcherTypeID(matcherTypeID) {}
566 
567 //===----------------------------------------------------------------------===//
568 // BoolNode
569 //===----------------------------------------------------------------------===//
570 
571 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
572                    std::unique_ptr<MatcherNode> successNode,
573                    std::unique_ptr<MatcherNode> failureNode)
574     : MatcherNode(TypeID::get<BoolNode>(), position, question,
575                   std::move(failureNode)),
576       answer(answer), successNode(std::move(successNode)) {}
577 
578 //===----------------------------------------------------------------------===//
579 // SuccessNode
580 //===----------------------------------------------------------------------===//
581 
582 SuccessNode::SuccessNode(pdl::PatternOp pattern,
583                          std::unique_ptr<MatcherNode> failureNode)
584     : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
585                   /*question=*/nullptr, std::move(failureNode)),
586       pattern(pattern) {}
587 
588 //===----------------------------------------------------------------------===//
589 // SwitchNode
590 //===----------------------------------------------------------------------===//
591 
592 SwitchNode::SwitchNode(Position *position, Qualifier *question)
593     : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
594