xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h (revision 8ec28af8eaff5acd0df3e53340159c034f08533d)
1 //===- Predicate.h - Pattern predicates -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions for "predicates" used when converting PDL into
10 // a matcher tree. Predicates are composed of three different parts:
11 //
12 //  * Positions
13 //    - A position refers to a specific location on the input DAG, i.e. an
14 //      existing MLIR entity being matched. These can be attributes, operands,
15 //      operations, results, and types. Each position also defines a relation to
16 //      its parent. For example, the operand `[0] -> 1` has a parent operation
17 //      position `[0]`. The attribute `[0, 1] -> "myAttr"` has parent operation
18 //      position of `[0, 1]`. The operation `[0, 1]` has a parent operand edge
19 //      `[0] -> 1` (i.e. it is the defining op of operand 1). The only position
20 //      without a parent is `[0]`, which refers to the root operation.
//  * Questions
//    - A question refers to a query on a specific positional value. For
//      example, an operation name question checks the name of an operation
//      position.
//  * Answers
//    - An answer is the expected result of a question. For example, when
//      matching an operation with the name "foo.op", the question would be an
//      operation name question, with an expected answer of "foo.op".
29 //
30 //===----------------------------------------------------------------------===//
31 
32 #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
33 #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
34 
35 #include "mlir/IR/MLIRContext.h"
36 #include "mlir/IR/OperationSupport.h"
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/IR/Types.h"
39 
40 namespace mlir {
41 namespace pdl_to_pdl_interp {
namespace Predicates {
/// An enumeration of the kinds of predicates.
///
/// Enumerator order is meaningful. Positions are listed by decreasing
/// priority, and questions by dependency and decreasing priority. The
/// underlying values presumably feed priority comparisons elsewhere; confirm
/// before reordering.
enum Kind : unsigned {
  // Positions, ordered by decreasing priority.
  OperationPos,
  OperandPos,
  OperandGroupPos,
  AttributePos,
  ConstraintResultPos,
  ResultPos,
  ResultGroupPos,
  TypePos,
  AttributeLiteralPos,
  TypeLiteralPos,
  UsersPos,
  ForEachPos,

  // Questions, ordered by dependency and decreasing priority.
  IsNotNullQuestion,
  OperationNameQuestion,
  TypeQuestion,
  AttributeQuestion,
  OperandCountAtLeastQuestion,
  OperandCountQuestion,
  ResultCountAtLeastQuestion,
  ResultCountQuestion,
  EqualToQuestion,
  ConstraintQuestion,

  // Answers.
  AttributeAnswer,
  FalseAnswer,
  OperationNameAnswer,
  TrueAnswer,
  TypeAnswer,
  UnsignedAnswer,
};
} // namespace Predicates
80 
81 /// Base class for all predicates, used to allow efficient pointer comparison.
82 template <typename ConcreteT, typename BaseT, typename Key,
83           Predicates::Kind Kind>
84 class PredicateBase : public BaseT {
85 public:
86   using KeyTy = Key;
87   using Base = PredicateBase<ConcreteT, BaseT, Key, Kind>;
88 
89   template <typename KeyT>
PredicateBase(KeyT && key)90   explicit PredicateBase(KeyT &&key)
91       : BaseT(Kind), key(std::forward<KeyT>(key)) {}
92 
93   /// Get an instance of this position.
94   template <typename... Args>
get(StorageUniquer & uniquer,Args &&...args)95   static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
96     return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
97   }
98 
99   /// Construct an instance with the given storage allocator.
100   template <typename KeyT>
construct(StorageUniquer::StorageAllocator & alloc,KeyT && key)101   static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
102                               KeyT &&key) {
103     return new (alloc.allocate<ConcreteT>()) ConcreteT(std::forward<KeyT>(key));
104   }
105 
106   /// Utility methods required by the storage allocator.
107   bool operator==(const KeyTy &key) const { return this->key == key; }
classof(const BaseT * pred)108   static bool classof(const BaseT *pred) { return pred->getKind() == Kind; }
109 
110   /// Return the key value of this predicate.
getValue()111   const KeyTy &getValue() const { return key; }
112 
113 protected:
114   KeyTy key;
115 };
116 
117 /// Base storage for simple predicates that only unique with the kind.
118 template <typename ConcreteT, typename BaseT, Predicates::Kind Kind>
119 class PredicateBase<ConcreteT, BaseT, void, Kind> : public BaseT {
120 public:
121   using Base = PredicateBase<ConcreteT, BaseT, void, Kind>;
122 
PredicateBase()123   explicit PredicateBase() : BaseT(Kind) {}
124 
get(StorageUniquer & uniquer)125   static ConcreteT *get(StorageUniquer &uniquer) {
126     return uniquer.get<ConcreteT>();
127   }
classof(const BaseT * pred)128   static bool classof(const BaseT *pred) { return pred->getKind() == Kind; }
129 };
130 
131 //===----------------------------------------------------------------------===//
132 // Positions
133 //===----------------------------------------------------------------------===//
134 
135 struct OperationPosition;
136 
137 /// A position describes a value on the input IR on which a predicate may be
138 /// applied, such as an operation or attribute. This enables re-use between
139 /// predicates, and assists generating bytecode and memory management.
140 ///
141 /// Operation positions form the base of other positions, which are formed
142 /// relative to a parent operation. Operations are anchored at Operand nodes,
143 /// except for the root operation which is parentless.
144 class Position : public StorageUniquer::BaseStorage {
145 public:
Position(Predicates::Kind kind)146   explicit Position(Predicates::Kind kind) : kind(kind) {}
147   virtual ~Position();
148 
149   /// Returns the depth of the first ancestor operation position.
150   unsigned getOperationDepth() const;
151 
152   /// Returns the parent position. The root operation position has no parent.
getParent()153   Position *getParent() const { return parent; }
154 
155   /// Returns the kind of this position.
getKind()156   Predicates::Kind getKind() const { return kind; }
157 
158 protected:
159   /// Link to the parent position.
160   Position *parent = nullptr;
161 
162 private:
163   /// The kind of this position.
164   Predicates::Kind kind;
165 };
166 
167 //===----------------------------------------------------------------------===//
168 // AttributePosition
169 
170 /// A position describing an attribute of an operation.
171 struct AttributePosition
172     : public PredicateBase<AttributePosition, Position,
173                            std::pair<OperationPosition *, StringAttr>,
174                            Predicates::AttributePos> {
175   explicit AttributePosition(const KeyTy &key);
176 
177   /// Returns the attribute name of this position.
getNameAttributePosition178   StringAttr getName() const { return key.second; }
179 };
180 
181 //===----------------------------------------------------------------------===//
182 // AttributeLiteralPosition
183 
184 /// A position describing a literal attribute.
185 struct AttributeLiteralPosition
186     : public PredicateBase<AttributeLiteralPosition, Position, Attribute,
187                            Predicates::AttributeLiteralPos> {
188   using PredicateBase::PredicateBase;
189 };
190 
191 //===----------------------------------------------------------------------===//
192 // ForEachPosition
193 
194 /// A position describing an iterative choice of an operation.
195 struct ForEachPosition : public PredicateBase<ForEachPosition, Position,
196                                               std::pair<Position *, unsigned>,
197                                               Predicates::ForEachPos> {
ForEachPositionForEachPosition198   explicit ForEachPosition(const KeyTy &key) : Base(key) { parent = key.first; }
199 
200   /// Returns the ID, for differentiating various loops.
201   /// For upward traversals, this is the index of the root.
getIDForEachPosition202   unsigned getID() const { return key.second; }
203 };
204 
205 //===----------------------------------------------------------------------===//
206 // OperandPosition
207 
208 /// A position describing an operand of an operation.
209 struct OperandPosition
210     : public PredicateBase<OperandPosition, Position,
211                            std::pair<OperationPosition *, unsigned>,
212                            Predicates::OperandPos> {
213   explicit OperandPosition(const KeyTy &key);
214 
215   /// Returns the operand number of this position.
getOperandNumberOperandPosition216   unsigned getOperandNumber() const { return key.second; }
217 };
218 
219 //===----------------------------------------------------------------------===//
220 // OperandGroupPosition
221 
222 /// A position describing an operand group of an operation.
223 struct OperandGroupPosition
224     : public PredicateBase<
225           OperandGroupPosition, Position,
226           std::tuple<OperationPosition *, std::optional<unsigned>, bool>,
227           Predicates::OperandGroupPos> {
228   explicit OperandGroupPosition(const KeyTy &key);
229 
230   /// Returns a hash suitable for the given keytype.
hashKeyOperandGroupPosition231   static llvm::hash_code hashKey(const KeyTy &key) {
232     return llvm::hash_value(key);
233   }
234 
235   /// Returns the group number of this position. If std::nullopt, this group
236   /// refers to all operands.
getOperandGroupNumberOperandGroupPosition237   std::optional<unsigned> getOperandGroupNumber() const {
238     return std::get<1>(key);
239   }
240 
241   /// Returns if the operand group has unknown size. If false, the operand group
242   /// has at max one element.
isVariadicOperandGroupPosition243   bool isVariadic() const { return std::get<2>(key); }
244 };
245 
246 //===----------------------------------------------------------------------===//
247 // OperationPosition
248 
249 /// An operation position describes an operation node in the IR. Other position
250 /// kinds are formed with respect to an operation position.
251 struct OperationPosition : public PredicateBase<OperationPosition, Position,
252                                                 std::pair<Position *, unsigned>,
253                                                 Predicates::OperationPos> {
OperationPositionOperationPosition254   explicit OperationPosition(const KeyTy &key) : Base(key) {
255     parent = key.first;
256   }
257 
258   /// Returns a hash suitable for the given keytype.
hashKeyOperationPosition259   static llvm::hash_code hashKey(const KeyTy &key) {
260     return llvm::hash_value(key);
261   }
262 
263   /// Gets the root position.
getRootOperationPosition264   static OperationPosition *getRoot(StorageUniquer &uniquer) {
265     return Base::get(uniquer, nullptr, 0);
266   }
267 
268   /// Gets an operation position with the given parent.
getOperationPosition269   static OperationPosition *get(StorageUniquer &uniquer, Position *parent) {
270     return Base::get(uniquer, parent, parent->getOperationDepth() + 1);
271   }
272 
273   /// Returns the depth of this position.
getDepthOperationPosition274   unsigned getDepth() const { return key.second; }
275 
276   /// Returns if this operation position corresponds to the root.
isRootOperationPosition277   bool isRoot() const { return getDepth() == 0; }
278 
279   /// Returns if this operation represents an operand defining op.
280   bool isOperandDefiningOp() const;
281 };
282 
283 //===----------------------------------------------------------------------===//
284 // ConstraintPosition
285 
286 struct ConstraintQuestion;
287 
288 /// A position describing the result of a native constraint. It saves the
289 /// corresponding ConstraintQuestion and result index to enable referring
290 /// back to them
291 struct ConstraintPosition
292     : public PredicateBase<ConstraintPosition, Position,
293                            std::pair<ConstraintQuestion *, unsigned>,
294                            Predicates::ConstraintResultPos> {
295   using PredicateBase::PredicateBase;
296 
297   /// Returns the ConstraintQuestion to enable keeping track of the native
298   /// constraint this position stems from.
getQuestionConstraintPosition299   ConstraintQuestion *getQuestion() const { return key.first; }
300 
301   // Returns the result index of this position
getIndexConstraintPosition302   unsigned getIndex() const { return key.second; }
303 };
304 
305 //===----------------------------------------------------------------------===//
306 // ResultPosition
307 
308 /// A position describing a result of an operation.
309 struct ResultPosition
310     : public PredicateBase<ResultPosition, Position,
311                            std::pair<OperationPosition *, unsigned>,
312                            Predicates::ResultPos> {
ResultPositionResultPosition313   explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; }
314 
315   /// Returns the result number of this position.
getResultNumberResultPosition316   unsigned getResultNumber() const { return key.second; }
317 };
318 
319 //===----------------------------------------------------------------------===//
320 // ResultGroupPosition
321 
322 /// A position describing a result group of an operation.
323 struct ResultGroupPosition
324     : public PredicateBase<
325           ResultGroupPosition, Position,
326           std::tuple<OperationPosition *, std::optional<unsigned>, bool>,
327           Predicates::ResultGroupPos> {
ResultGroupPositionResultGroupPosition328   explicit ResultGroupPosition(const KeyTy &key) : Base(key) {
329     parent = std::get<0>(key);
330   }
331 
332   /// Returns a hash suitable for the given keytype.
hashKeyResultGroupPosition333   static llvm::hash_code hashKey(const KeyTy &key) {
334     return llvm::hash_value(key);
335   }
336 
337   /// Returns the group number of this position. If std::nullopt, this group
338   /// refers to all results.
getResultGroupNumberResultGroupPosition339   std::optional<unsigned> getResultGroupNumber() const {
340     return std::get<1>(key);
341   }
342 
343   /// Returns if the result group has unknown size. If false, the result group
344   /// has at max one element.
isVariadicResultGroupPosition345   bool isVariadic() const { return std::get<2>(key); }
346 };
347 
348 //===----------------------------------------------------------------------===//
349 // TypePosition
350 
351 /// A position describing the result type of an entity, i.e. an Attribute,
352 /// Operand, Result, etc.
353 struct TypePosition : public PredicateBase<TypePosition, Position, Position *,
354                                            Predicates::TypePos> {
TypePositionTypePosition355   explicit TypePosition(const KeyTy &key) : Base(key) {
356     assert((isa<AttributePosition, OperandPosition, OperandGroupPosition,
357                 ResultPosition, ResultGroupPosition>(key)) &&
358            "expected parent to be an attribute, operand, or result");
359     parent = key;
360   }
361 };
362 
363 //===----------------------------------------------------------------------===//
364 // TypeLiteralPosition
365 
366 /// A position describing a literal type or type range. The value is stored as
367 /// either a TypeAttr, or an ArrayAttr of TypeAttr.
368 struct TypeLiteralPosition
369     : public PredicateBase<TypeLiteralPosition, Position, Attribute,
370                            Predicates::TypeLiteralPos> {
371   using PredicateBase::PredicateBase;
372 };
373 
374 //===----------------------------------------------------------------------===//
375 // UsersPosition
376 
377 /// A position describing the users of a value or a range of values. The second
378 /// value in the key indicates whether we choose users of a representative for
379 /// a range (this is true, e.g., in the upward traversals).
380 struct UsersPosition
381     : public PredicateBase<UsersPosition, Position, std::pair<Position *, bool>,
382                            Predicates::UsersPos> {
UsersPositionUsersPosition383   explicit UsersPosition(const KeyTy &key) : Base(key) { parent = key.first; }
384 
385   /// Returns a hash suitable for the given keytype.
hashKeyUsersPosition386   static llvm::hash_code hashKey(const KeyTy &key) {
387     return llvm::hash_value(key);
388   }
389 
390   /// Indicates whether to compute a range of a representative.
useRepresentativeUsersPosition391   bool useRepresentative() const { return key.second; }
392 };
393 
394 //===----------------------------------------------------------------------===//
395 // Qualifiers
396 //===----------------------------------------------------------------------===//
397 
398 /// An ordinal predicate consists of a "Question" and a set of acceptable
399 /// "Answers" (later converted to ordinal values). A predicate will query some
400 /// property of a positional value and decide what to do based on the result.
401 ///
402 /// This makes top-level predicate representations ordinal (SwitchOp). Later,
403 /// predicates that end up with only one acceptable answer (including all
404 /// boolean kinds) will be converted to boolean predicates (PredicateOp) in the
405 /// matcher.
406 ///
407 /// For simplicity, both are represented as "qualifiers", with a base kind and
408 /// perhaps additional properties. For example, all OperationName predicates ask
409 /// the same question, but GenericConstraint predicates may ask different ones.
410 class Qualifier : public StorageUniquer::BaseStorage {
411 public:
Qualifier(Predicates::Kind kind)412   explicit Qualifier(Predicates::Kind kind) : kind(kind) {}
413 
414   /// Returns the kind of this qualifier.
getKind()415   Predicates::Kind getKind() const { return kind; }
416 
417 private:
418   /// The kind of this position.
419   Predicates::Kind kind;
420 };
421 
422 //===----------------------------------------------------------------------===//
423 // Answers
424 
425 /// An Answer representing an `Attribute` value.
426 struct AttributeAnswer
427     : public PredicateBase<AttributeAnswer, Qualifier, Attribute,
428                            Predicates::AttributeAnswer> {
429   using Base::Base;
430 };
431 
432 /// An Answer representing an `OperationName` value.
433 struct OperationNameAnswer
434     : public PredicateBase<OperationNameAnswer, Qualifier, OperationName,
435                            Predicates::OperationNameAnswer> {
436   using Base::Base;
437 };
438 
439 /// An Answer representing a boolean `true` value.
440 struct TrueAnswer
441     : PredicateBase<TrueAnswer, Qualifier, void, Predicates::TrueAnswer> {
442   using Base::Base;
443 };
444 
445 /// An Answer representing a boolean 'false' value.
446 struct FalseAnswer
447     : PredicateBase<FalseAnswer, Qualifier, void, Predicates::FalseAnswer> {
448   using Base::Base;
449 };
450 
451 /// An Answer representing a `Type` value. The value is stored as either a
452 /// TypeAttr, or an ArrayAttr of TypeAttr.
453 struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Attribute,
454                                          Predicates::TypeAnswer> {
455   using Base::Base;
456 };
457 
458 /// An Answer representing an unsigned value.
459 struct UnsignedAnswer
460     : public PredicateBase<UnsignedAnswer, Qualifier, unsigned,
461                            Predicates::UnsignedAnswer> {
462   using Base::Base;
463 };
464 
465 //===----------------------------------------------------------------------===//
466 // Questions
467 
468 /// Compare an `Attribute` to a constant value.
469 struct AttributeQuestion
470     : public PredicateBase<AttributeQuestion, Qualifier, void,
471                            Predicates::AttributeQuestion> {};
472 
473 /// Apply a parameterized constraint to multiple position values and possibly
474 /// produce results.
475 struct ConstraintQuestion
476     : public PredicateBase<
477           ConstraintQuestion, Qualifier,
478           std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
479           Predicates::ConstraintQuestion> {
480   using Base::Base;
481 
482   /// Return the name of the constraint.
getNameConstraintQuestion483   StringRef getName() const { return std::get<0>(key); }
484 
485   /// Return the arguments of the constraint.
getArgsConstraintQuestion486   ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
487 
488   /// Return the result types of the constraint.
getResultTypesConstraintQuestion489   ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
490 
491   /// Return the negation status of the constraint.
getIsNegatedConstraintQuestion492   bool getIsNegated() const { return std::get<3>(key); }
493 
494   /// Construct an instance with the given storage allocator.
constructConstraintQuestion495   static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
496                                        KeyTy key) {
497     return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
498                                         alloc.copyInto(std::get<1>(key)),
499                                         alloc.copyInto(std::get<2>(key)),
500                                         std::get<3>(key)});
501   }
502 
503   /// Returns a hash suitable for the given keytype.
hashKeyConstraintQuestion504   static llvm::hash_code hashKey(const KeyTy &key) {
505     return llvm::hash_value(key);
506   }
507 };
508 
509 /// Compare the equality of two values.
510 struct EqualToQuestion
511     : public PredicateBase<EqualToQuestion, Qualifier, Position *,
512                            Predicates::EqualToQuestion> {
513   using Base::Base;
514 };
515 
516 /// Compare a positional value with null, i.e. check if it exists.
517 struct IsNotNullQuestion
518     : public PredicateBase<IsNotNullQuestion, Qualifier, void,
519                            Predicates::IsNotNullQuestion> {};
520 
521 /// Compare the number of operands of an operation with a known value.
522 struct OperandCountQuestion
523     : public PredicateBase<OperandCountQuestion, Qualifier, void,
524                            Predicates::OperandCountQuestion> {};
525 struct OperandCountAtLeastQuestion
526     : public PredicateBase<OperandCountAtLeastQuestion, Qualifier, void,
527                            Predicates::OperandCountAtLeastQuestion> {};
528 
529 /// Compare the name of an operation with a known value.
530 struct OperationNameQuestion
531     : public PredicateBase<OperationNameQuestion, Qualifier, void,
532                            Predicates::OperationNameQuestion> {};
533 
534 /// Compare the number of results of an operation with a known value.
535 struct ResultCountQuestion
536     : public PredicateBase<ResultCountQuestion, Qualifier, void,
537                            Predicates::ResultCountQuestion> {};
538 struct ResultCountAtLeastQuestion
539     : public PredicateBase<ResultCountAtLeastQuestion, Qualifier, void,
540                            Predicates::ResultCountAtLeastQuestion> {};
541 
542 /// Compare the type of an attribute or value with a known type.
543 struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void,
544                                            Predicates::TypeQuestion> {};
545 
546 //===----------------------------------------------------------------------===//
547 // PredicateUniquer
548 //===----------------------------------------------------------------------===//
549 
550 /// This class provides a storage uniquer that is used to allocate predicate
551 /// instances.
552 class PredicateUniquer : public StorageUniquer {
553 public:
PredicateUniquer()554   PredicateUniquer() {
555     // Register the types of Positions with the uniquer.
556     registerParametricStorageType<AttributePosition>();
557     registerParametricStorageType<AttributeLiteralPosition>();
558     registerParametricStorageType<ConstraintPosition>();
559     registerParametricStorageType<ForEachPosition>();
560     registerParametricStorageType<OperandPosition>();
561     registerParametricStorageType<OperandGroupPosition>();
562     registerParametricStorageType<OperationPosition>();
563     registerParametricStorageType<ResultPosition>();
564     registerParametricStorageType<ResultGroupPosition>();
565     registerParametricStorageType<TypePosition>();
566     registerParametricStorageType<TypeLiteralPosition>();
567     registerParametricStorageType<UsersPosition>();
568 
569     // Register the types of Questions with the uniquer.
570     registerParametricStorageType<AttributeAnswer>();
571     registerParametricStorageType<OperationNameAnswer>();
572     registerParametricStorageType<TypeAnswer>();
573     registerParametricStorageType<UnsignedAnswer>();
574     registerSingletonStorageType<FalseAnswer>();
575     registerSingletonStorageType<TrueAnswer>();
576 
577     // Register the types of Answers with the uniquer.
578     registerParametricStorageType<ConstraintQuestion>();
579     registerParametricStorageType<EqualToQuestion>();
580     registerSingletonStorageType<AttributeQuestion>();
581     registerSingletonStorageType<IsNotNullQuestion>();
582     registerSingletonStorageType<OperandCountQuestion>();
583     registerSingletonStorageType<OperandCountAtLeastQuestion>();
584     registerSingletonStorageType<OperationNameQuestion>();
585     registerSingletonStorageType<ResultCountQuestion>();
586     registerSingletonStorageType<ResultCountAtLeastQuestion>();
587     registerSingletonStorageType<TypeQuestion>();
588   }
589 };
590 
591 //===----------------------------------------------------------------------===//
592 // PredicateBuilder
593 //===----------------------------------------------------------------------===//
594 
595 /// This class provides utilities for constructing predicates.
596 class PredicateBuilder {
597 public:
PredicateBuilder(PredicateUniquer & uniquer,MLIRContext * ctx)598   PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx)
599       : uniquer(uniquer), ctx(ctx) {}
600 
601   //===--------------------------------------------------------------------===//
602   // Positions
603   //===--------------------------------------------------------------------===//
604 
605   /// Returns the root operation position.
getRoot()606   Position *getRoot() { return OperationPosition::getRoot(uniquer); }
607 
608   /// Returns the parent position defining the value held by the given operand.
getOperandDefiningOp(Position * p)609   OperationPosition *getOperandDefiningOp(Position *p) {
610     assert((isa<OperandPosition, OperandGroupPosition>(p)) &&
611            "expected operand position");
612     return OperationPosition::get(uniquer, p);
613   }
614 
615   /// Returns the operation position equivalent to the given position.
getPassthroughOp(Position * p)616   OperationPosition *getPassthroughOp(Position *p) {
617     assert((isa<ForEachPosition>(p)) && "expected users position");
618     return OperationPosition::get(uniquer, p);
619   }
620 
621   // Returns a position for a new value created by a constraint.
getConstraintPosition(ConstraintQuestion * q,unsigned index)622   ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
623                                             unsigned index) {
624     return ConstraintPosition::get(uniquer, std::make_pair(q, index));
625   }
626 
627   /// Returns an attribute position for an attribute of the given operation.
getAttribute(OperationPosition * p,StringRef name)628   Position *getAttribute(OperationPosition *p, StringRef name) {
629     return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
630   }
631 
632   /// Returns an attribute position for the given attribute.
getAttributeLiteral(Attribute attr)633   Position *getAttributeLiteral(Attribute attr) {
634     return AttributeLiteralPosition::get(uniquer, attr);
635   }
636 
getForEach(Position * p,unsigned id)637   Position *getForEach(Position *p, unsigned id) {
638     return ForEachPosition::get(uniquer, p, id);
639   }
640 
641   /// Returns an operand position for an operand of the given operation.
getOperand(OperationPosition * p,unsigned operand)642   Position *getOperand(OperationPosition *p, unsigned operand) {
643     return OperandPosition::get(uniquer, p, operand);
644   }
645 
646   /// Returns a position for a group of operands of the given operation.
getOperandGroup(OperationPosition * p,std::optional<unsigned> group,bool isVariadic)647   Position *getOperandGroup(OperationPosition *p, std::optional<unsigned> group,
648                             bool isVariadic) {
649     return OperandGroupPosition::get(uniquer, p, group, isVariadic);
650   }
getAllOperands(OperationPosition * p)651   Position *getAllOperands(OperationPosition *p) {
652     return getOperandGroup(p, /*group=*/std::nullopt, /*isVariadic=*/true);
653   }
654 
655   /// Returns a result position for a result of the given operation.
getResult(OperationPosition * p,unsigned result)656   Position *getResult(OperationPosition *p, unsigned result) {
657     return ResultPosition::get(uniquer, p, result);
658   }
659 
660   /// Returns a position for a group of results of the given operation.
getResultGroup(OperationPosition * p,std::optional<unsigned> group,bool isVariadic)661   Position *getResultGroup(OperationPosition *p, std::optional<unsigned> group,
662                            bool isVariadic) {
663     return ResultGroupPosition::get(uniquer, p, group, isVariadic);
664   }
getAllResults(OperationPosition * p)665   Position *getAllResults(OperationPosition *p) {
666     return getResultGroup(p, /*group=*/std::nullopt, /*isVariadic=*/true);
667   }
668 
669   /// Returns a type position for the given entity.
getType(Position * p)670   Position *getType(Position *p) { return TypePosition::get(uniquer, p); }
671 
672   /// Returns a type position for the given type value. The value is stored
673   /// as either a TypeAttr, or an ArrayAttr of TypeAttr.
getTypeLiteral(Attribute attr)674   Position *getTypeLiteral(Attribute attr) {
675     return TypeLiteralPosition::get(uniquer, attr);
676   }
677 
678   /// Returns the users of a position using the value at the given operand.
getUsers(Position * p,bool useRepresentative)679   UsersPosition *getUsers(Position *p, bool useRepresentative) {
680     assert((isa<OperandPosition, OperandGroupPosition, ResultPosition,
681                 ResultGroupPosition>(p)) &&
682            "expected result position");
683     return UsersPosition::get(uniquer, p, useRepresentative);
684   }
685 
686   //===--------------------------------------------------------------------===//
687   // Qualifiers
688   //===--------------------------------------------------------------------===//
689 
690   /// An ordinal predicate consists of a "Question" and a set of acceptable
691   /// "Answers" (later converted to ordinal values). A predicate will query some
692   /// property of a positional value and decide what to do based on the result.
693   using Predicate = std::pair<Qualifier *, Qualifier *>;
694 
695   /// Create a predicate comparing an attribute to a known value.
getAttributeConstraint(Attribute attr)696   Predicate getAttributeConstraint(Attribute attr) {
697     return {AttributeQuestion::get(uniquer),
698             AttributeAnswer::get(uniquer, attr)};
699   }
700 
701   /// Create a predicate checking if two values are equal.
getEqualTo(Position * pos)702   Predicate getEqualTo(Position *pos) {
703     return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)};
704   }
705 
706   /// Create a predicate checking if two values are not equal.
getNotEqualTo(Position * pos)707   Predicate getNotEqualTo(Position *pos) {
708     return {EqualToQuestion::get(uniquer, pos), FalseAnswer::get(uniquer)};
709   }
710 
711   /// Create a predicate that applies a generic constraint.
getConstraint(StringRef name,ArrayRef<Position * > args,ArrayRef<Type> resultTypes,bool isNegated)712   Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
713                           ArrayRef<Type> resultTypes, bool isNegated) {
714     return {ConstraintQuestion::get(
715                 uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
716             TrueAnswer::get(uniquer)};
717   }
718 
719   /// Create a predicate comparing a value with null.
getIsNotNull()720   Predicate getIsNotNull() {
721     return {IsNotNullQuestion::get(uniquer), TrueAnswer::get(uniquer)};
722   }
723 
724   /// Create a predicate comparing the number of operands of an operation to a
725   /// known value.
getOperandCount(unsigned count)726   Predicate getOperandCount(unsigned count) {
727     return {OperandCountQuestion::get(uniquer),
728             UnsignedAnswer::get(uniquer, count)};
729   }
getOperandCountAtLeast(unsigned count)730   Predicate getOperandCountAtLeast(unsigned count) {
731     return {OperandCountAtLeastQuestion::get(uniquer),
732             UnsignedAnswer::get(uniquer, count)};
733   }
734 
735   /// Create a predicate comparing the name of an operation to a known value.
getOperationName(StringRef name)736   Predicate getOperationName(StringRef name) {
737     return {OperationNameQuestion::get(uniquer),
738             OperationNameAnswer::get(uniquer, OperationName(name, ctx))};
739   }
740 
741   /// Create a predicate comparing the number of results of an operation to a
742   /// known value.
getResultCount(unsigned count)743   Predicate getResultCount(unsigned count) {
744     return {ResultCountQuestion::get(uniquer),
745             UnsignedAnswer::get(uniquer, count)};
746   }
getResultCountAtLeast(unsigned count)747   Predicate getResultCountAtLeast(unsigned count) {
748     return {ResultCountAtLeastQuestion::get(uniquer),
749             UnsignedAnswer::get(uniquer, count)};
750   }
751 
752   /// Create a predicate comparing the type of an attribute or value to a known
753   /// type. The value is stored as either a TypeAttr, or an ArrayAttr of
754   /// TypeAttr.
getTypeConstraint(Attribute type)755   Predicate getTypeConstraint(Attribute type) {
756     return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, type)};
757   }
758 
759 private:
760   /// The uniquer used when allocating predicate nodes.
761   PredicateUniquer &uniquer;
762 
763   /// The current MLIR context.
764   MLIRContext *ctx;
765 };
766 
767 } // namespace pdl_to_pdl_interp
768 } // namespace mlir
769 
#endif // MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
771