//===- Predicate.h - Pattern predicates -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions for "predicates" used when converting PDL into
// a matcher tree. Predicates are composed of three different parts:
//
//  * Positions
//    - A position refers to a specific location on the input DAG, i.e. an
//      existing MLIR entity being matched. These can be attributes, operands,
//      operations, results, and types. Each position also defines a relation to
//      its parent. For example, the operand `[0] -> 1` has a parent operation
//      position `[0]`. The attribute `[0, 1] -> "myAttr"` has parent operation
//      position of `[0, 1]`. The operation `[0, 1]` has a parent operand edge
//      `[0] -> 1` (i.e. it is the defining op of operand 1). Apart from literal
//      and constraint-result positions, the only position without a parent is
//      `[0]`, which refers to the root operation.
//  * Questions
//    - A question refers to a query on a specific positional value. For
//      example, an operation name question checks the name of an operation
//      position.
//  * Answers
//    - An answer is the expected result of a question. For example, when
//      matching an operation with the name "foo.op", the question would be an
//      operation name question, with an expected answer of "foo.op".
29 // 30 //===----------------------------------------------------------------------===// 31 32 #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ 33 #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ 34 35 #include "mlir/IR/MLIRContext.h" 36 #include "mlir/IR/OperationSupport.h" 37 #include "mlir/IR/PatternMatch.h" 38 #include "mlir/IR/Types.h" 39 40 namespace mlir { 41 namespace pdl_to_pdl_interp { 42 namespace Predicates { 43 /// An enumeration of the kinds of predicates. 44 enum Kind : unsigned { 45 /// Positions, ordered by decreasing priority. 46 OperationPos, 47 OperandPos, 48 OperandGroupPos, 49 AttributePos, 50 ConstraintResultPos, 51 ResultPos, 52 ResultGroupPos, 53 TypePos, 54 AttributeLiteralPos, 55 TypeLiteralPos, 56 UsersPos, 57 ForEachPos, 58 59 // Questions, ordered by dependency and decreasing priority. 60 IsNotNullQuestion, 61 OperationNameQuestion, 62 TypeQuestion, 63 AttributeQuestion, 64 OperandCountAtLeastQuestion, 65 OperandCountQuestion, 66 ResultCountAtLeastQuestion, 67 ResultCountQuestion, 68 EqualToQuestion, 69 ConstraintQuestion, 70 71 // Answers. 72 AttributeAnswer, 73 FalseAnswer, 74 OperationNameAnswer, 75 TrueAnswer, 76 TypeAnswer, 77 UnsignedAnswer, 78 }; 79 } // namespace Predicates 80 81 /// Base class for all predicates, used to allow efficient pointer comparison. 82 template <typename ConcreteT, typename BaseT, typename Key, 83 Predicates::Kind Kind> 84 class PredicateBase : public BaseT { 85 public: 86 using KeyTy = Key; 87 using Base = PredicateBase<ConcreteT, BaseT, Key, Kind>; 88 89 template <typename KeyT> PredicateBase(KeyT && key)90 explicit PredicateBase(KeyT &&key) 91 : BaseT(Kind), key(std::forward<KeyT>(key)) {} 92 93 /// Get an instance of this position. 94 template <typename... 
Args> get(StorageUniquer & uniquer,Args &&...args)95 static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) { 96 return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...); 97 } 98 99 /// Construct an instance with the given storage allocator. 100 template <typename KeyT> construct(StorageUniquer::StorageAllocator & alloc,KeyT && key)101 static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc, 102 KeyT &&key) { 103 return new (alloc.allocate<ConcreteT>()) ConcreteT(std::forward<KeyT>(key)); 104 } 105 106 /// Utility methods required by the storage allocator. 107 bool operator==(const KeyTy &key) const { return this->key == key; } classof(const BaseT * pred)108 static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } 109 110 /// Return the key value of this predicate. getValue()111 const KeyTy &getValue() const { return key; } 112 113 protected: 114 KeyTy key; 115 }; 116 117 /// Base storage for simple predicates that only unique with the kind. 118 template <typename ConcreteT, typename BaseT, Predicates::Kind Kind> 119 class PredicateBase<ConcreteT, BaseT, void, Kind> : public BaseT { 120 public: 121 using Base = PredicateBase<ConcreteT, BaseT, void, Kind>; 122 PredicateBase()123 explicit PredicateBase() : BaseT(Kind) {} 124 get(StorageUniquer & uniquer)125 static ConcreteT *get(StorageUniquer &uniquer) { 126 return uniquer.get<ConcreteT>(); 127 } classof(const BaseT * pred)128 static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } 129 }; 130 131 //===----------------------------------------------------------------------===// 132 // Positions 133 //===----------------------------------------------------------------------===// 134 135 struct OperationPosition; 136 137 /// A position describes a value on the input IR on which a predicate may be 138 /// applied, such as an operation or attribute. 
This enables re-use between 139 /// predicates, and assists generating bytecode and memory management. 140 /// 141 /// Operation positions form the base of other positions, which are formed 142 /// relative to a parent operation. Operations are anchored at Operand nodes, 143 /// except for the root operation which is parentless. 144 class Position : public StorageUniquer::BaseStorage { 145 public: Position(Predicates::Kind kind)146 explicit Position(Predicates::Kind kind) : kind(kind) {} 147 virtual ~Position(); 148 149 /// Returns the depth of the first ancestor operation position. 150 unsigned getOperationDepth() const; 151 152 /// Returns the parent position. The root operation position has no parent. getParent()153 Position *getParent() const { return parent; } 154 155 /// Returns the kind of this position. getKind()156 Predicates::Kind getKind() const { return kind; } 157 158 protected: 159 /// Link to the parent position. 160 Position *parent = nullptr; 161 162 private: 163 /// The kind of this position. 164 Predicates::Kind kind; 165 }; 166 167 //===----------------------------------------------------------------------===// 168 // AttributePosition 169 170 /// A position describing an attribute of an operation. 171 struct AttributePosition 172 : public PredicateBase<AttributePosition, Position, 173 std::pair<OperationPosition *, StringAttr>, 174 Predicates::AttributePos> { 175 explicit AttributePosition(const KeyTy &key); 176 177 /// Returns the attribute name of this position. getNameAttributePosition178 StringAttr getName() const { return key.second; } 179 }; 180 181 //===----------------------------------------------------------------------===// 182 // AttributeLiteralPosition 183 184 /// A position describing a literal attribute. 
185 struct AttributeLiteralPosition 186 : public PredicateBase<AttributeLiteralPosition, Position, Attribute, 187 Predicates::AttributeLiteralPos> { 188 using PredicateBase::PredicateBase; 189 }; 190 191 //===----------------------------------------------------------------------===// 192 // ForEachPosition 193 194 /// A position describing an iterative choice of an operation. 195 struct ForEachPosition : public PredicateBase<ForEachPosition, Position, 196 std::pair<Position *, unsigned>, 197 Predicates::ForEachPos> { ForEachPositionForEachPosition198 explicit ForEachPosition(const KeyTy &key) : Base(key) { parent = key.first; } 199 200 /// Returns the ID, for differentiating various loops. 201 /// For upward traversals, this is the index of the root. getIDForEachPosition202 unsigned getID() const { return key.second; } 203 }; 204 205 //===----------------------------------------------------------------------===// 206 // OperandPosition 207 208 /// A position describing an operand of an operation. 209 struct OperandPosition 210 : public PredicateBase<OperandPosition, Position, 211 std::pair<OperationPosition *, unsigned>, 212 Predicates::OperandPos> { 213 explicit OperandPosition(const KeyTy &key); 214 215 /// Returns the operand number of this position. getOperandNumberOperandPosition216 unsigned getOperandNumber() const { return key.second; } 217 }; 218 219 //===----------------------------------------------------------------------===// 220 // OperandGroupPosition 221 222 /// A position describing an operand group of an operation. 223 struct OperandGroupPosition 224 : public PredicateBase< 225 OperandGroupPosition, Position, 226 std::tuple<OperationPosition *, std::optional<unsigned>, bool>, 227 Predicates::OperandGroupPos> { 228 explicit OperandGroupPosition(const KeyTy &key); 229 230 /// Returns a hash suitable for the given keytype. 
hashKeyOperandGroupPosition231 static llvm::hash_code hashKey(const KeyTy &key) { 232 return llvm::hash_value(key); 233 } 234 235 /// Returns the group number of this position. If std::nullopt, this group 236 /// refers to all operands. getOperandGroupNumberOperandGroupPosition237 std::optional<unsigned> getOperandGroupNumber() const { 238 return std::get<1>(key); 239 } 240 241 /// Returns if the operand group has unknown size. If false, the operand group 242 /// has at max one element. isVariadicOperandGroupPosition243 bool isVariadic() const { return std::get<2>(key); } 244 }; 245 246 //===----------------------------------------------------------------------===// 247 // OperationPosition 248 249 /// An operation position describes an operation node in the IR. Other position 250 /// kinds are formed with respect to an operation position. 251 struct OperationPosition : public PredicateBase<OperationPosition, Position, 252 std::pair<Position *, unsigned>, 253 Predicates::OperationPos> { OperationPositionOperationPosition254 explicit OperationPosition(const KeyTy &key) : Base(key) { 255 parent = key.first; 256 } 257 258 /// Returns a hash suitable for the given keytype. hashKeyOperationPosition259 static llvm::hash_code hashKey(const KeyTy &key) { 260 return llvm::hash_value(key); 261 } 262 263 /// Gets the root position. getRootOperationPosition264 static OperationPosition *getRoot(StorageUniquer &uniquer) { 265 return Base::get(uniquer, nullptr, 0); 266 } 267 268 /// Gets an operation position with the given parent. getOperationPosition269 static OperationPosition *get(StorageUniquer &uniquer, Position *parent) { 270 return Base::get(uniquer, parent, parent->getOperationDepth() + 1); 271 } 272 273 /// Returns the depth of this position. getDepthOperationPosition274 unsigned getDepth() const { return key.second; } 275 276 /// Returns if this operation position corresponds to the root. 
isRootOperationPosition277 bool isRoot() const { return getDepth() == 0; } 278 279 /// Returns if this operation represents an operand defining op. 280 bool isOperandDefiningOp() const; 281 }; 282 283 //===----------------------------------------------------------------------===// 284 // ConstraintPosition 285 286 struct ConstraintQuestion; 287 288 /// A position describing the result of a native constraint. It saves the 289 /// corresponding ConstraintQuestion and result index to enable referring 290 /// back to them 291 struct ConstraintPosition 292 : public PredicateBase<ConstraintPosition, Position, 293 std::pair<ConstraintQuestion *, unsigned>, 294 Predicates::ConstraintResultPos> { 295 using PredicateBase::PredicateBase; 296 297 /// Returns the ConstraintQuestion to enable keeping track of the native 298 /// constraint this position stems from. getQuestionConstraintPosition299 ConstraintQuestion *getQuestion() const { return key.first; } 300 301 // Returns the result index of this position getIndexConstraintPosition302 unsigned getIndex() const { return key.second; } 303 }; 304 305 //===----------------------------------------------------------------------===// 306 // ResultPosition 307 308 /// A position describing a result of an operation. 309 struct ResultPosition 310 : public PredicateBase<ResultPosition, Position, 311 std::pair<OperationPosition *, unsigned>, 312 Predicates::ResultPos> { ResultPositionResultPosition313 explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; } 314 315 /// Returns the result number of this position. getResultNumberResultPosition316 unsigned getResultNumber() const { return key.second; } 317 }; 318 319 //===----------------------------------------------------------------------===// 320 // ResultGroupPosition 321 322 /// A position describing a result group of an operation. 
323 struct ResultGroupPosition 324 : public PredicateBase< 325 ResultGroupPosition, Position, 326 std::tuple<OperationPosition *, std::optional<unsigned>, bool>, 327 Predicates::ResultGroupPos> { ResultGroupPositionResultGroupPosition328 explicit ResultGroupPosition(const KeyTy &key) : Base(key) { 329 parent = std::get<0>(key); 330 } 331 332 /// Returns a hash suitable for the given keytype. hashKeyResultGroupPosition333 static llvm::hash_code hashKey(const KeyTy &key) { 334 return llvm::hash_value(key); 335 } 336 337 /// Returns the group number of this position. If std::nullopt, this group 338 /// refers to all results. getResultGroupNumberResultGroupPosition339 std::optional<unsigned> getResultGroupNumber() const { 340 return std::get<1>(key); 341 } 342 343 /// Returns if the result group has unknown size. If false, the result group 344 /// has at max one element. isVariadicResultGroupPosition345 bool isVariadic() const { return std::get<2>(key); } 346 }; 347 348 //===----------------------------------------------------------------------===// 349 // TypePosition 350 351 /// A position describing the result type of an entity, i.e. an Attribute, 352 /// Operand, Result, etc. 353 struct TypePosition : public PredicateBase<TypePosition, Position, Position *, 354 Predicates::TypePos> { TypePositionTypePosition355 explicit TypePosition(const KeyTy &key) : Base(key) { 356 assert((isa<AttributePosition, OperandPosition, OperandGroupPosition, 357 ResultPosition, ResultGroupPosition>(key)) && 358 "expected parent to be an attribute, operand, or result"); 359 parent = key; 360 } 361 }; 362 363 //===----------------------------------------------------------------------===// 364 // TypeLiteralPosition 365 366 /// A position describing a literal type or type range. The value is stored as 367 /// either a TypeAttr, or an ArrayAttr of TypeAttr. 
368 struct TypeLiteralPosition 369 : public PredicateBase<TypeLiteralPosition, Position, Attribute, 370 Predicates::TypeLiteralPos> { 371 using PredicateBase::PredicateBase; 372 }; 373 374 //===----------------------------------------------------------------------===// 375 // UsersPosition 376 377 /// A position describing the users of a value or a range of values. The second 378 /// value in the key indicates whether we choose users of a representative for 379 /// a range (this is true, e.g., in the upward traversals). 380 struct UsersPosition 381 : public PredicateBase<UsersPosition, Position, std::pair<Position *, bool>, 382 Predicates::UsersPos> { UsersPositionUsersPosition383 explicit UsersPosition(const KeyTy &key) : Base(key) { parent = key.first; } 384 385 /// Returns a hash suitable for the given keytype. hashKeyUsersPosition386 static llvm::hash_code hashKey(const KeyTy &key) { 387 return llvm::hash_value(key); 388 } 389 390 /// Indicates whether to compute a range of a representative. useRepresentativeUsersPosition391 bool useRepresentative() const { return key.second; } 392 }; 393 394 //===----------------------------------------------------------------------===// 395 // Qualifiers 396 //===----------------------------------------------------------------------===// 397 398 /// An ordinal predicate consists of a "Question" and a set of acceptable 399 /// "Answers" (later converted to ordinal values). A predicate will query some 400 /// property of a positional value and decide what to do based on the result. 401 /// 402 /// This makes top-level predicate representations ordinal (SwitchOp). Later, 403 /// predicates that end up with only one acceptable answer (including all 404 /// boolean kinds) will be converted to boolean predicates (PredicateOp) in the 405 /// matcher. 406 /// 407 /// For simplicity, both are represented as "qualifiers", with a base kind and 408 /// perhaps additional properties. 
For example, all OperationName predicates ask 409 /// the same question, but GenericConstraint predicates may ask different ones. 410 class Qualifier : public StorageUniquer::BaseStorage { 411 public: Qualifier(Predicates::Kind kind)412 explicit Qualifier(Predicates::Kind kind) : kind(kind) {} 413 414 /// Returns the kind of this qualifier. getKind()415 Predicates::Kind getKind() const { return kind; } 416 417 private: 418 /// The kind of this position. 419 Predicates::Kind kind; 420 }; 421 422 //===----------------------------------------------------------------------===// 423 // Answers 424 425 /// An Answer representing an `Attribute` value. 426 struct AttributeAnswer 427 : public PredicateBase<AttributeAnswer, Qualifier, Attribute, 428 Predicates::AttributeAnswer> { 429 using Base::Base; 430 }; 431 432 /// An Answer representing an `OperationName` value. 433 struct OperationNameAnswer 434 : public PredicateBase<OperationNameAnswer, Qualifier, OperationName, 435 Predicates::OperationNameAnswer> { 436 using Base::Base; 437 }; 438 439 /// An Answer representing a boolean `true` value. 440 struct TrueAnswer 441 : PredicateBase<TrueAnswer, Qualifier, void, Predicates::TrueAnswer> { 442 using Base::Base; 443 }; 444 445 /// An Answer representing a boolean 'false' value. 446 struct FalseAnswer 447 : PredicateBase<FalseAnswer, Qualifier, void, Predicates::FalseAnswer> { 448 using Base::Base; 449 }; 450 451 /// An Answer representing a `Type` value. The value is stored as either a 452 /// TypeAttr, or an ArrayAttr of TypeAttr. 453 struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Attribute, 454 Predicates::TypeAnswer> { 455 using Base::Base; 456 }; 457 458 /// An Answer representing an unsigned value. 
459 struct UnsignedAnswer 460 : public PredicateBase<UnsignedAnswer, Qualifier, unsigned, 461 Predicates::UnsignedAnswer> { 462 using Base::Base; 463 }; 464 465 //===----------------------------------------------------------------------===// 466 // Questions 467 468 /// Compare an `Attribute` to a constant value. 469 struct AttributeQuestion 470 : public PredicateBase<AttributeQuestion, Qualifier, void, 471 Predicates::AttributeQuestion> {}; 472 473 /// Apply a parameterized constraint to multiple position values and possibly 474 /// produce results. 475 struct ConstraintQuestion 476 : public PredicateBase< 477 ConstraintQuestion, Qualifier, 478 std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>, 479 Predicates::ConstraintQuestion> { 480 using Base::Base; 481 482 /// Return the name of the constraint. getNameConstraintQuestion483 StringRef getName() const { return std::get<0>(key); } 484 485 /// Return the arguments of the constraint. getArgsConstraintQuestion486 ArrayRef<Position *> getArgs() const { return std::get<1>(key); } 487 488 /// Return the result types of the constraint. getResultTypesConstraintQuestion489 ArrayRef<Type> getResultTypes() const { return std::get<2>(key); } 490 491 /// Return the negation status of the constraint. getIsNegatedConstraintQuestion492 bool getIsNegated() const { return std::get<3>(key); } 493 494 /// Construct an instance with the given storage allocator. constructConstraintQuestion495 static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, 496 KeyTy key) { 497 return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)), 498 alloc.copyInto(std::get<1>(key)), 499 alloc.copyInto(std::get<2>(key)), 500 std::get<3>(key)}); 501 } 502 503 /// Returns a hash suitable for the given keytype. hashKeyConstraintQuestion504 static llvm::hash_code hashKey(const KeyTy &key) { 505 return llvm::hash_value(key); 506 } 507 }; 508 509 /// Compare the equality of two values. 
510 struct EqualToQuestion 511 : public PredicateBase<EqualToQuestion, Qualifier, Position *, 512 Predicates::EqualToQuestion> { 513 using Base::Base; 514 }; 515 516 /// Compare a positional value with null, i.e. check if it exists. 517 struct IsNotNullQuestion 518 : public PredicateBase<IsNotNullQuestion, Qualifier, void, 519 Predicates::IsNotNullQuestion> {}; 520 521 /// Compare the number of operands of an operation with a known value. 522 struct OperandCountQuestion 523 : public PredicateBase<OperandCountQuestion, Qualifier, void, 524 Predicates::OperandCountQuestion> {}; 525 struct OperandCountAtLeastQuestion 526 : public PredicateBase<OperandCountAtLeastQuestion, Qualifier, void, 527 Predicates::OperandCountAtLeastQuestion> {}; 528 529 /// Compare the name of an operation with a known value. 530 struct OperationNameQuestion 531 : public PredicateBase<OperationNameQuestion, Qualifier, void, 532 Predicates::OperationNameQuestion> {}; 533 534 /// Compare the number of results of an operation with a known value. 535 struct ResultCountQuestion 536 : public PredicateBase<ResultCountQuestion, Qualifier, void, 537 Predicates::ResultCountQuestion> {}; 538 struct ResultCountAtLeastQuestion 539 : public PredicateBase<ResultCountAtLeastQuestion, Qualifier, void, 540 Predicates::ResultCountAtLeastQuestion> {}; 541 542 /// Compare the type of an attribute or value with a known type. 543 struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void, 544 Predicates::TypeQuestion> {}; 545 546 //===----------------------------------------------------------------------===// 547 // PredicateUniquer 548 //===----------------------------------------------------------------------===// 549 550 /// This class provides a storage uniquer that is used to allocate predicate 551 /// instances. 552 class PredicateUniquer : public StorageUniquer { 553 public: PredicateUniquer()554 PredicateUniquer() { 555 // Register the types of Positions with the uniquer. 
556 registerParametricStorageType<AttributePosition>(); 557 registerParametricStorageType<AttributeLiteralPosition>(); 558 registerParametricStorageType<ConstraintPosition>(); 559 registerParametricStorageType<ForEachPosition>(); 560 registerParametricStorageType<OperandPosition>(); 561 registerParametricStorageType<OperandGroupPosition>(); 562 registerParametricStorageType<OperationPosition>(); 563 registerParametricStorageType<ResultPosition>(); 564 registerParametricStorageType<ResultGroupPosition>(); 565 registerParametricStorageType<TypePosition>(); 566 registerParametricStorageType<TypeLiteralPosition>(); 567 registerParametricStorageType<UsersPosition>(); 568 569 // Register the types of Questions with the uniquer. 570 registerParametricStorageType<AttributeAnswer>(); 571 registerParametricStorageType<OperationNameAnswer>(); 572 registerParametricStorageType<TypeAnswer>(); 573 registerParametricStorageType<UnsignedAnswer>(); 574 registerSingletonStorageType<FalseAnswer>(); 575 registerSingletonStorageType<TrueAnswer>(); 576 577 // Register the types of Answers with the uniquer. 578 registerParametricStorageType<ConstraintQuestion>(); 579 registerParametricStorageType<EqualToQuestion>(); 580 registerSingletonStorageType<AttributeQuestion>(); 581 registerSingletonStorageType<IsNotNullQuestion>(); 582 registerSingletonStorageType<OperandCountQuestion>(); 583 registerSingletonStorageType<OperandCountAtLeastQuestion>(); 584 registerSingletonStorageType<OperationNameQuestion>(); 585 registerSingletonStorageType<ResultCountQuestion>(); 586 registerSingletonStorageType<ResultCountAtLeastQuestion>(); 587 registerSingletonStorageType<TypeQuestion>(); 588 } 589 }; 590 591 //===----------------------------------------------------------------------===// 592 // PredicateBuilder 593 //===----------------------------------------------------------------------===// 594 595 /// This class provides utilities for constructing predicates. 
596 class PredicateBuilder { 597 public: PredicateBuilder(PredicateUniquer & uniquer,MLIRContext * ctx)598 PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx) 599 : uniquer(uniquer), ctx(ctx) {} 600 601 //===--------------------------------------------------------------------===// 602 // Positions 603 //===--------------------------------------------------------------------===// 604 605 /// Returns the root operation position. getRoot()606 Position *getRoot() { return OperationPosition::getRoot(uniquer); } 607 608 /// Returns the parent position defining the value held by the given operand. getOperandDefiningOp(Position * p)609 OperationPosition *getOperandDefiningOp(Position *p) { 610 assert((isa<OperandPosition, OperandGroupPosition>(p)) && 611 "expected operand position"); 612 return OperationPosition::get(uniquer, p); 613 } 614 615 /// Returns the operation position equivalent to the given position. getPassthroughOp(Position * p)616 OperationPosition *getPassthroughOp(Position *p) { 617 assert((isa<ForEachPosition>(p)) && "expected users position"); 618 return OperationPosition::get(uniquer, p); 619 } 620 621 // Returns a position for a new value created by a constraint. getConstraintPosition(ConstraintQuestion * q,unsigned index)622 ConstraintPosition *getConstraintPosition(ConstraintQuestion *q, 623 unsigned index) { 624 return ConstraintPosition::get(uniquer, std::make_pair(q, index)); 625 } 626 627 /// Returns an attribute position for an attribute of the given operation. getAttribute(OperationPosition * p,StringRef name)628 Position *getAttribute(OperationPosition *p, StringRef name) { 629 return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name)); 630 } 631 632 /// Returns an attribute position for the given attribute. 
getAttributeLiteral(Attribute attr)633 Position *getAttributeLiteral(Attribute attr) { 634 return AttributeLiteralPosition::get(uniquer, attr); 635 } 636 getForEach(Position * p,unsigned id)637 Position *getForEach(Position *p, unsigned id) { 638 return ForEachPosition::get(uniquer, p, id); 639 } 640 641 /// Returns an operand position for an operand of the given operation. getOperand(OperationPosition * p,unsigned operand)642 Position *getOperand(OperationPosition *p, unsigned operand) { 643 return OperandPosition::get(uniquer, p, operand); 644 } 645 646 /// Returns a position for a group of operands of the given operation. getOperandGroup(OperationPosition * p,std::optional<unsigned> group,bool isVariadic)647 Position *getOperandGroup(OperationPosition *p, std::optional<unsigned> group, 648 bool isVariadic) { 649 return OperandGroupPosition::get(uniquer, p, group, isVariadic); 650 } getAllOperands(OperationPosition * p)651 Position *getAllOperands(OperationPosition *p) { 652 return getOperandGroup(p, /*group=*/std::nullopt, /*isVariadic=*/true); 653 } 654 655 /// Returns a result position for a result of the given operation. getResult(OperationPosition * p,unsigned result)656 Position *getResult(OperationPosition *p, unsigned result) { 657 return ResultPosition::get(uniquer, p, result); 658 } 659 660 /// Returns a position for a group of results of the given operation. getResultGroup(OperationPosition * p,std::optional<unsigned> group,bool isVariadic)661 Position *getResultGroup(OperationPosition *p, std::optional<unsigned> group, 662 bool isVariadic) { 663 return ResultGroupPosition::get(uniquer, p, group, isVariadic); 664 } getAllResults(OperationPosition * p)665 Position *getAllResults(OperationPosition *p) { 666 return getResultGroup(p, /*group=*/std::nullopt, /*isVariadic=*/true); 667 } 668 669 /// Returns a type position for the given entity. 
getType(Position * p)670 Position *getType(Position *p) { return TypePosition::get(uniquer, p); } 671 672 /// Returns a type position for the given type value. The value is stored 673 /// as either a TypeAttr, or an ArrayAttr of TypeAttr. getTypeLiteral(Attribute attr)674 Position *getTypeLiteral(Attribute attr) { 675 return TypeLiteralPosition::get(uniquer, attr); 676 } 677 678 /// Returns the users of a position using the value at the given operand. getUsers(Position * p,bool useRepresentative)679 UsersPosition *getUsers(Position *p, bool useRepresentative) { 680 assert((isa<OperandPosition, OperandGroupPosition, ResultPosition, 681 ResultGroupPosition>(p)) && 682 "expected result position"); 683 return UsersPosition::get(uniquer, p, useRepresentative); 684 } 685 686 //===--------------------------------------------------------------------===// 687 // Qualifiers 688 //===--------------------------------------------------------------------===// 689 690 /// An ordinal predicate consists of a "Question" and a set of acceptable 691 /// "Answers" (later converted to ordinal values). A predicate will query some 692 /// property of a positional value and decide what to do based on the result. 693 using Predicate = std::pair<Qualifier *, Qualifier *>; 694 695 /// Create a predicate comparing an attribute to a known value. getAttributeConstraint(Attribute attr)696 Predicate getAttributeConstraint(Attribute attr) { 697 return {AttributeQuestion::get(uniquer), 698 AttributeAnswer::get(uniquer, attr)}; 699 } 700 701 /// Create a predicate checking if two values are equal. getEqualTo(Position * pos)702 Predicate getEqualTo(Position *pos) { 703 return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)}; 704 } 705 706 /// Create a predicate checking if two values are not equal. 
getNotEqualTo(Position * pos)707 Predicate getNotEqualTo(Position *pos) { 708 return {EqualToQuestion::get(uniquer, pos), FalseAnswer::get(uniquer)}; 709 } 710 711 /// Create a predicate that applies a generic constraint. getConstraint(StringRef name,ArrayRef<Position * > args,ArrayRef<Type> resultTypes,bool isNegated)712 Predicate getConstraint(StringRef name, ArrayRef<Position *> args, 713 ArrayRef<Type> resultTypes, bool isNegated) { 714 return {ConstraintQuestion::get( 715 uniquer, std::make_tuple(name, args, resultTypes, isNegated)), 716 TrueAnswer::get(uniquer)}; 717 } 718 719 /// Create a predicate comparing a value with null. getIsNotNull()720 Predicate getIsNotNull() { 721 return {IsNotNullQuestion::get(uniquer), TrueAnswer::get(uniquer)}; 722 } 723 724 /// Create a predicate comparing the number of operands of an operation to a 725 /// known value. getOperandCount(unsigned count)726 Predicate getOperandCount(unsigned count) { 727 return {OperandCountQuestion::get(uniquer), 728 UnsignedAnswer::get(uniquer, count)}; 729 } getOperandCountAtLeast(unsigned count)730 Predicate getOperandCountAtLeast(unsigned count) { 731 return {OperandCountAtLeastQuestion::get(uniquer), 732 UnsignedAnswer::get(uniquer, count)}; 733 } 734 735 /// Create a predicate comparing the name of an operation to a known value. getOperationName(StringRef name)736 Predicate getOperationName(StringRef name) { 737 return {OperationNameQuestion::get(uniquer), 738 OperationNameAnswer::get(uniquer, OperationName(name, ctx))}; 739 } 740 741 /// Create a predicate comparing the number of results of an operation to a 742 /// known value. 
getResultCount(unsigned count)743 Predicate getResultCount(unsigned count) { 744 return {ResultCountQuestion::get(uniquer), 745 UnsignedAnswer::get(uniquer, count)}; 746 } getResultCountAtLeast(unsigned count)747 Predicate getResultCountAtLeast(unsigned count) { 748 return {ResultCountAtLeastQuestion::get(uniquer), 749 UnsignedAnswer::get(uniquer, count)}; 750 } 751 752 /// Create a predicate comparing the type of an attribute or value to a known 753 /// type. The value is stored as either a TypeAttr, or an ArrayAttr of 754 /// TypeAttr. getTypeConstraint(Attribute type)755 Predicate getTypeConstraint(Attribute type) { 756 return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, type)}; 757 } 758 759 private: 760 /// The uniquer used when allocating predicate nodes. 761 PredicateUniquer &uniquer; 762 763 /// The current MLIR context. 764 MLIRContext *ctx; 765 }; 766 767 } // namespace pdl_to_pdl_interp 768 } // namespace mlir 769 770 #endif // MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ 771