1 //===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements helper classes for implementing the "Op" types. This 10 // includes the Op type, which is the base class for Op class definitions, 11 // as well as number of traits in the OpTrait namespace that provide a 12 // declarative way to specify properties of Ops. 13 // 14 // The purpose of these types are to allow light-weight implementation of 15 // concrete ops (like DimOp) with very little boilerplate. 16 // 17 //===----------------------------------------------------------------------===// 18 19 #ifndef MLIR_IR_OPDEFINITION_H 20 #define MLIR_IR_OPDEFINITION_H 21 22 #include "mlir/IR/Dialect.h" 23 #include "mlir/IR/ODSSupport.h" 24 #include "mlir/IR/Operation.h" 25 #include "llvm/Support/PointerLikeTypeTraits.h" 26 27 #include <optional> 28 #include <type_traits> 29 30 namespace mlir { 31 class Builder; 32 class OpBuilder; 33 34 /// This class implements `Optional` functionality for ParseResult. We don't 35 /// directly use Optional here, because it provides an implicit conversion 36 /// to 'bool' which we want to avoid. This class is used to implement tri-state 37 /// 'parseOptional' functions that may have a failure mode when parsing that 38 /// shouldn't be attributed to "not present". 39 class OptionalParseResult { 40 public: 41 OptionalParseResult() = default; 42 OptionalParseResult(LogicalResult result) : impl(result) {} 43 OptionalParseResult(ParseResult result) : impl(result) {} 44 OptionalParseResult(const InFlightDiagnostic &) 45 : OptionalParseResult(failure()) {} 46 OptionalParseResult(std::nullopt_t) : impl(std::nullopt) {} 47 48 /// Returns true if we contain a valid ParseResult value. 49 bool has_value() const { return impl.has_value(); } 50 51 /// Access the internal ParseResult value. 52 ParseResult value() const { return *impl; } 53 ParseResult operator*() const { return value(); } 54 55 private: 56 std::optional<ParseResult> impl; 57 }; 58 59 // These functions are out-of-line utilities, which avoids them being template 60 // instantiated/duplicated. 61 namespace impl { 62 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 63 /// region's only block if it does not have a terminator already. If the region 64 /// is empty, insert a new block first. `buildTerminatorOp` should return the 65 /// terminator operation to insert. 66 void ensureRegionTerminator( 67 Region ®ion, OpBuilder &builder, Location loc, 68 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp); 69 void ensureRegionTerminator( 70 Region ®ion, Builder &builder, Location loc, 71 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp); 72 73 } // namespace impl 74 75 /// Structure used by default as a "marker" when no "Properties" are set on an 76 /// Operation. 77 struct EmptyProperties {}; 78 79 /// Traits to detect whether an Operation defined a `Properties` type, otherwise 80 /// it'll default to `EmptyProperties`. 81 template <class Op, class = void> 82 struct PropertiesSelector { 83 using type = EmptyProperties; 84 }; 85 template <class Op> 86 struct PropertiesSelector<Op, std::void_t<typename Op::Properties>> { 87 using type = typename Op::Properties; 88 }; 89 90 /// This is the concrete base class that holds the operation pointer and has 91 /// non-generic methods that only depend on State (to avoid having them 92 /// instantiated on template types that don't affect them. 93 /// 94 /// This also has the fallback implementations of customization hooks for when 95 /// they aren't customized. 96 class OpState { 97 public: 98 /// Ops are pointer-like, so we allow conversion to bool. 99 explicit operator bool() { return getOperation() != nullptr; } 100 101 /// This implicitly converts to Operation*. 102 operator Operation *() const { return state; } 103 104 /// Shortcut of `->` to access a member of Operation. 105 Operation *operator->() const { return state; } 106 107 /// Return the operation that this refers to. 108 Operation *getOperation() { return state; } 109 110 /// Return the context this operation belongs to. 111 MLIRContext *getContext() { return getOperation()->getContext(); } 112 113 /// Print the operation to the given stream. 114 void print(raw_ostream &os, OpPrintingFlags flags = std::nullopt) { 115 state->print(os, flags); 116 } 117 void print(raw_ostream &os, AsmState &asmState) { 118 state->print(os, asmState); 119 } 120 121 /// Dump this operation. 122 void dump() { state->dump(); } 123 124 /// The source location the operation was defined or derived from. 125 Location getLoc() { return state->getLoc(); } 126 127 /// Return true if there are no users of any results of this operation. 128 bool use_empty() { return state->use_empty(); } 129 130 /// Remove this operation from its parent block and delete it. 131 void erase() { state->erase(); } 132 133 /// Emit an error with the op name prefixed, like "'dim' op " which is 134 /// convenient for verifiers. 135 InFlightDiagnostic emitOpError(const Twine &message = {}); 136 137 /// Emit an error about fatal conditions with this operation, reporting up to 138 /// any diagnostic handlers that may be listening. 139 InFlightDiagnostic emitError(const Twine &message = {}); 140 141 /// Emit a warning about this operation, reporting up to any diagnostic 142 /// handlers that may be listening. 143 InFlightDiagnostic emitWarning(const Twine &message = {}); 144 145 /// Emit a remark about this operation, reporting up to any diagnostic 146 /// handlers that may be listening. 147 InFlightDiagnostic emitRemark(const Twine &message = {}); 148 149 /// Walk the operation by calling the callback for each nested operation 150 /// (including this one), block or region, depending on the callback provided. 151 /// The order in which regions, blocks and operations the same nesting level 152 /// are visited (e.g., lexicographical or reverse lexicographical order) is 153 /// determined by 'Iterator'. The walk order for enclosing regions, blocks 154 /// and operations with respect to their nested ones is specified by 'Order' 155 /// (post-order by default). A callback on a block or operation is allowed to 156 /// erase that block or operation if either: 157 /// * the walk is in post-order, or 158 /// * the walk is in pre-order and the walk is skipped after the erasure. 159 /// See Operation::walk for more details. 160 template <WalkOrder Order = WalkOrder::PostOrder, 161 typename Iterator = ForwardIterator, typename FnT, 162 typename RetT = detail::walkResultType<FnT>> 163 std::enable_if_t<llvm::function_traits<std::decay_t<FnT>>::num_args == 1, 164 RetT> 165 walk(FnT &&callback) { 166 return state->walk<Order, Iterator>(std::forward<FnT>(callback)); 167 } 168 169 /// Generic walker with a stage aware callback. Walk the operation by calling 170 /// the callback for each nested operation (including this one) N+1 times, 171 /// where N is the number of regions attached to that operation. 172 /// 173 /// The callback method can take any of the following forms: 174 /// void(Operation *, const WalkStage &) : Walk all operation opaquely 175 /// * op.walk([](Operation *nestedOp, const WalkStage &stage) { ...}); 176 /// void(OpT, const WalkStage &) : Walk all operations of the given derived 177 /// type. 178 /// * op.walk([](ReturnOp returnOp, const WalkStage &stage) { ...}); 179 /// WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations, 180 /// but allow for interruption/skipping. 181 /// * op.walk([](... op, const WalkStage &stage) { 182 /// // Skip the walk of this op based on some invariant. 183 /// if (some_invariant) 184 /// return WalkResult::skip(); 185 /// // Interrupt, i.e cancel, the walk based on some invariant. 186 /// if (another_invariant) 187 /// return WalkResult::interrupt(); 188 /// return WalkResult::advance(); 189 /// }); 190 template <typename FnT, typename RetT = detail::walkResultType<FnT>> 191 std::enable_if_t<llvm::function_traits<std::decay_t<FnT>>::num_args == 2, 192 RetT> 193 walk(FnT &&callback) { 194 return state->walk(std::forward<FnT>(callback)); 195 } 196 197 // These are default implementations of customization hooks. 198 public: 199 /// This hook returns any canonicalization pattern rewrites that the operation 200 /// supports, for use by the canonicalization pass. 201 static void getCanonicalizationPatterns(RewritePatternSet &results, 202 MLIRContext *context) {} 203 204 /// This hook populates any unset default attrs. 205 static void populateDefaultAttrs(const OperationName &, NamedAttrList &) {} 206 207 protected: 208 /// If the concrete type didn't implement a custom verifier hook, just fall 209 /// back to this one which accepts everything. 210 LogicalResult verify() { return success(); } 211 LogicalResult verifyRegions() { return success(); } 212 213 /// Parse the custom form of an operation. Unless overridden, this method will 214 /// first try to get an operation parser from the op's dialect. Otherwise the 215 /// custom assembly form of an op is always rejected. Op implementations 216 /// should implement this to return failure. On success, they should fill in 217 /// result with the fields to use. 218 static ParseResult parse(OpAsmParser &parser, OperationState &result); 219 220 /// Print the operation. Unless overridden, this method will first try to get 221 /// an operation printer from the dialect. Otherwise, it prints the operation 222 /// in generic form. 223 static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect); 224 225 /// Parse properties as a Attribute. 226 static ParseResult genericParseProperties(OpAsmParser &parser, 227 Attribute &result); 228 229 /// Print the properties as a Attribute with names not included within 230 /// 'elidedProps' 231 static void genericPrintProperties(OpAsmPrinter &p, Attribute properties, 232 ArrayRef<StringRef> elidedProps = {}); 233 234 /// Print an operation name, eliding the dialect prefix if necessary. 235 static void printOpName(Operation *op, OpAsmPrinter &p, 236 StringRef defaultDialect); 237 238 /// Mutability management is handled by the OpWrapper/OpConstWrapper classes, 239 /// so we can cast it away here. 240 explicit OpState(Operation *state) : state(state) {} 241 242 /// For all op which don't have properties, we keep a single instance of 243 /// `EmptyProperties` to be used where a reference to a properties is needed: 244 /// this allow to bind a pointer to the reference without triggering UB. 245 static EmptyProperties &getEmptyProperties() { 246 static EmptyProperties emptyProperties; 247 return emptyProperties; 248 } 249 250 private: 251 Operation *state; 252 253 /// Allow access to internal hook implementation methods. 254 friend RegisteredOperationName; 255 }; 256 257 // Allow comparing operators. 258 inline bool operator==(OpState lhs, OpState rhs) { 259 return lhs.getOperation() == rhs.getOperation(); 260 } 261 inline bool operator!=(OpState lhs, OpState rhs) { 262 return lhs.getOperation() != rhs.getOperation(); 263 } 264 265 raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr); 266 267 /// This class represents a single result from folding an operation. 268 class OpFoldResult : public PointerUnion<Attribute, Value> { 269 using PointerUnion<Attribute, Value>::PointerUnion; 270 271 public: 272 void dump() const { llvm::errs() << *this << "\n"; } 273 274 MLIRContext *getContext() const { 275 PointerUnion pu = *this; 276 return isa<Attribute>(pu) ? cast<Attribute>(pu).getContext() 277 : cast<Value>(pu).getContext(); 278 } 279 }; 280 281 // Temporarily exit the MLIR namespace to add casting support as later code in 282 // this uses it. The CastInfo must come after the OpFoldResult definition and 283 // before any cast function calls depending on CastInfo. 284 285 } // namespace mlir 286 287 namespace llvm { 288 289 // Allow llvm::cast style functions. 290 template <typename To> 291 struct CastInfo<To, mlir::OpFoldResult> 292 : public CastInfo<To, mlir::OpFoldResult::PointerUnion> {}; 293 294 template <typename To> 295 struct CastInfo<To, const mlir::OpFoldResult> 296 : public CastInfo<To, const mlir::OpFoldResult::PointerUnion> {}; 297 298 } // namespace llvm 299 300 namespace mlir { 301 302 /// Allow printing to a stream. 303 inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) { 304 if (Value value = llvm::dyn_cast_if_present<Value>(ofr)) 305 value.print(os); 306 else 307 llvm::dyn_cast_if_present<Attribute>(ofr).print(os); 308 return os; 309 } 310 /// Allow printing to a stream. 311 inline raw_ostream &operator<<(raw_ostream &os, OpState op) { 312 op.print(os, OpPrintingFlags().useLocalScope()); 313 return os; 314 } 315 316 //===----------------------------------------------------------------------===// 317 // Operation Trait Types 318 //===----------------------------------------------------------------------===// 319 320 namespace OpTrait { 321 322 // These functions are out-of-line implementations of the methods in the 323 // corresponding trait classes. This avoids them being template 324 // instantiated/duplicated. 325 namespace impl { 326 LogicalResult foldCommutative(Operation *op, ArrayRef<Attribute> operands, 327 SmallVectorImpl<OpFoldResult> &results); 328 OpFoldResult foldIdempotent(Operation *op); 329 OpFoldResult foldInvolution(Operation *op); 330 LogicalResult verifyZeroOperands(Operation *op); 331 LogicalResult verifyOneOperand(Operation *op); 332 LogicalResult verifyNOperands(Operation *op, unsigned numOperands); 333 LogicalResult verifyIsIdempotent(Operation *op); 334 LogicalResult verifyIsInvolution(Operation *op); 335 LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands); 336 LogicalResult verifyOperandsAreFloatLike(Operation *op); 337 LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op); 338 LogicalResult verifySameTypeOperands(Operation *op); 339 LogicalResult verifyZeroRegions(Operation *op); 340 LogicalResult verifyOneRegion(Operation *op); 341 LogicalResult verifyNRegions(Operation *op, unsigned numRegions); 342 LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions); 343 LogicalResult verifyZeroResults(Operation *op); 344 LogicalResult verifyOneResult(Operation *op); 345 LogicalResult verifyNResults(Operation *op, unsigned numOperands); 346 LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands); 347 LogicalResult verifySameOperandsShape(Operation *op); 348 LogicalResult verifySameOperandsAndResultShape(Operation *op); 349 LogicalResult verifySameOperandsElementType(Operation *op); 350 LogicalResult verifySameOperandsAndResultElementType(Operation *op); 351 LogicalResult verifySameOperandsAndResultType(Operation *op); 352 LogicalResult verifySameOperandsAndResultRank(Operation *op); 353 LogicalResult verifyResultsAreBoolLike(Operation *op); 354 LogicalResult verifyResultsAreFloatLike(Operation *op); 355 LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op); 356 LogicalResult verifyIsTerminator(Operation *op); 357 LogicalResult verifyZeroSuccessors(Operation *op); 358 LogicalResult verifyOneSuccessor(Operation *op); 359 LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors); 360 LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors); 361 LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName, 362 StringRef valueGroupName, 363 size_t expectedCount); 364 LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName); 365 LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName); 366 LogicalResult verifyNoRegionArguments(Operation *op); 367 LogicalResult verifyElementwise(Operation *op); 368 LogicalResult verifyIsIsolatedFromAbove(Operation *op); 369 } // namespace impl 370 371 /// Helper class for implementing traits. Clients are not expected to interact 372 /// with this directly, so its members are all protected. 373 template <typename ConcreteType, template <typename> class TraitType> 374 class TraitBase { 375 protected: 376 /// Return the ultimate Operation being worked on. 377 Operation *getOperation() { 378 auto *concrete = static_cast<ConcreteType *>(this); 379 return concrete->getOperation(); 380 } 381 }; 382 383 //===----------------------------------------------------------------------===// 384 // Operand Traits 385 386 namespace detail { 387 /// Utility trait base that provides accessors for derived traits that have 388 /// multiple operands. 389 template <typename ConcreteType, template <typename> class TraitType> 390 struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> { 391 using operand_iterator = Operation::operand_iterator; 392 using operand_range = Operation::operand_range; 393 using operand_type_iterator = Operation::operand_type_iterator; 394 using operand_type_range = Operation::operand_type_range; 395 396 /// Return the number of operands. 397 unsigned getNumOperands() { return this->getOperation()->getNumOperands(); } 398 399 /// Return the operand at index 'i'. 400 Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); } 401 402 /// Set the operand at index 'i' to 'value'. 403 void setOperand(unsigned i, Value value) { 404 this->getOperation()->setOperand(i, value); 405 } 406 407 /// Operand iterator access. 408 operand_iterator operand_begin() { 409 return this->getOperation()->operand_begin(); 410 } 411 operand_iterator operand_end() { return this->getOperation()->operand_end(); } 412 operand_range getOperands() { return this->getOperation()->getOperands(); } 413 414 /// Operand type access. 415 operand_type_iterator operand_type_begin() { 416 return this->getOperation()->operand_type_begin(); 417 } 418 operand_type_iterator operand_type_end() { 419 return this->getOperation()->operand_type_end(); 420 } 421 operand_type_range getOperandTypes() { 422 return this->getOperation()->getOperandTypes(); 423 } 424 }; 425 } // namespace detail 426 427 /// `verifyInvariantsImpl` verifies the invariants like the types, attrs, .etc. 428 /// It should be run after core traits and before any other user defined traits. 429 /// In order to run it in the correct order, wrap it with OpInvariants trait so 430 /// that tblgen will be able to put it in the right order. 431 template <typename ConcreteType> 432 class OpInvariants : public TraitBase<ConcreteType, OpInvariants> { 433 public: 434 static LogicalResult verifyTrait(Operation *op) { 435 return cast<ConcreteType>(op).verifyInvariantsImpl(); 436 } 437 }; 438 439 /// This class provides the API for ops that are known to have no 440 /// SSA operand. 441 template <typename ConcreteType> 442 class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> { 443 public: 444 static LogicalResult verifyTrait(Operation *op) { 445 return impl::verifyZeroOperands(op); 446 } 447 448 private: 449 // Disable these. 450 void getOperand() {} 451 void setOperand() {} 452 }; 453 454 /// This class provides the API for ops that are known to have exactly one 455 /// SSA operand. 456 template <typename ConcreteType> 457 class OneOperand : public TraitBase<ConcreteType, OneOperand> { 458 public: 459 Value getOperand() { return this->getOperation()->getOperand(0); } 460 461 void setOperand(Value value) { this->getOperation()->setOperand(0, value); } 462 463 static LogicalResult verifyTrait(Operation *op) { 464 return impl::verifyOneOperand(op); 465 } 466 }; 467 468 /// This class provides the API for ops that are known to have a specified 469 /// number of operands. This is used as a trait like this: 470 /// 471 /// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> { 472 /// 473 template <unsigned N> 474 class NOperands { 475 public: 476 static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2"); 477 478 template <typename ConcreteType> 479 class Impl 480 : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> { 481 public: 482 static LogicalResult verifyTrait(Operation *op) { 483 return impl::verifyNOperands(op, N); 484 } 485 }; 486 }; 487 488 /// This class provides the API for ops that are known to have a at least a 489 /// specified number of operands. This is used as a trait like this: 490 /// 491 /// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> { 492 /// 493 template <unsigned N> 494 class AtLeastNOperands { 495 public: 496 template <typename ConcreteType> 497 class Impl : public detail::MultiOperandTraitBase<ConcreteType, 498 AtLeastNOperands<N>::Impl> { 499 public: 500 static LogicalResult verifyTrait(Operation *op) { 501 return impl::verifyAtLeastNOperands(op, N); 502 } 503 }; 504 }; 505 506 /// This class provides the API for ops which have an unknown number of 507 /// SSA operands. 508 template <typename ConcreteType> 509 class VariadicOperands 510 : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {}; 511 512 //===----------------------------------------------------------------------===// 513 // Region Traits 514 515 /// This class provides verification for ops that are known to have zero 516 /// regions. 517 template <typename ConcreteType> 518 class ZeroRegions : public TraitBase<ConcreteType, ZeroRegions> { 519 public: 520 static LogicalResult verifyTrait(Operation *op) { 521 return impl::verifyZeroRegions(op); 522 } 523 }; 524 525 namespace detail { 526 /// Utility trait base that provides accessors for derived traits that have 527 /// multiple regions. 528 template <typename ConcreteType, template <typename> class TraitType> 529 struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> { 530 using region_iterator = MutableArrayRef<Region>; 531 using region_range = RegionRange; 532 533 /// Return the number of regions. 534 unsigned getNumRegions() { return this->getOperation()->getNumRegions(); } 535 536 /// Return the region at `index`. 537 Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); } 538 539 /// Region iterator access. 540 region_iterator region_begin() { 541 return this->getOperation()->region_begin(); 542 } 543 region_iterator region_end() { return this->getOperation()->region_end(); } 544 region_range getRegions() { return this->getOperation()->getRegions(); } 545 }; 546 } // namespace detail 547 548 /// This class provides APIs for ops that are known to have a single region. 549 template <typename ConcreteType> 550 class OneRegion : public TraitBase<ConcreteType, OneRegion> { 551 public: 552 Region &getRegion() { return this->getOperation()->getRegion(0); } 553 554 /// Returns a range of operations within the region of this operation. 555 auto getOps() { return getRegion().getOps(); } 556 template <typename OpT> 557 auto getOps() { 558 return getRegion().template getOps<OpT>(); 559 } 560 561 static LogicalResult verifyTrait(Operation *op) { 562 return impl::verifyOneRegion(op); 563 } 564 }; 565 566 /// This class provides the API for ops that are known to have a specified 567 /// number of regions. 568 template <unsigned N> 569 class NRegions { 570 public: 571 static_assert(N > 1, "use ZeroRegions/OneRegion for N < 2"); 572 573 template <typename ConcreteType> 574 class Impl 575 : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> { 576 public: 577 static LogicalResult verifyTrait(Operation *op) { 578 return impl::verifyNRegions(op, N); 579 } 580 }; 581 }; 582 583 /// This class provides APIs for ops that are known to have at least a specified 584 /// number of regions. 585 template <unsigned N> 586 class AtLeastNRegions { 587 public: 588 template <typename ConcreteType> 589 class Impl : public detail::MultiRegionTraitBase<ConcreteType, 590 AtLeastNRegions<N>::Impl> { 591 public: 592 static LogicalResult verifyTrait(Operation *op) { 593 return impl::verifyAtLeastNRegions(op, N); 594 } 595 }; 596 }; 597 598 /// This class provides the API for ops which have an unknown number of 599 /// regions. 600 template <typename ConcreteType> 601 class VariadicRegions 602 : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {}; 603 604 //===----------------------------------------------------------------------===// 605 // Result Traits 606 607 /// This class provides return value APIs for ops that are known to have 608 /// zero results. 609 template <typename ConcreteType> 610 class ZeroResults : public TraitBase<ConcreteType, ZeroResults> { 611 public: 612 static LogicalResult verifyTrait(Operation *op) { 613 return impl::verifyZeroResults(op); 614 } 615 }; 616 617 namespace detail { 618 /// Utility trait base that provides accessors for derived traits that have 619 /// multiple results. 620 template <typename ConcreteType, template <typename> class TraitType> 621 struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> { 622 using result_iterator = Operation::result_iterator; 623 using result_range = Operation::result_range; 624 using result_type_iterator = Operation::result_type_iterator; 625 using result_type_range = Operation::result_type_range; 626 627 /// Return the number of results. 628 unsigned getNumResults() { return this->getOperation()->getNumResults(); } 629 630 /// Return the result at index 'i'. 631 Value getResult(unsigned i) { return this->getOperation()->getResult(i); } 632 633 /// Replace all uses of results of this operation with the provided 'values'. 634 /// 'values' may correspond to an existing operation, or a range of 'Value'. 635 template <typename ValuesT> 636 void replaceAllUsesWith(ValuesT &&values) { 637 this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values)); 638 } 639 640 /// Return the type of the `i`-th result. 641 Type getType(unsigned i) { return getResult(i).getType(); } 642 643 /// Result iterator access. 644 result_iterator result_begin() { 645 return this->getOperation()->result_begin(); 646 } 647 result_iterator result_end() { return this->getOperation()->result_end(); } 648 result_range getResults() { return this->getOperation()->getResults(); } 649 650 /// Result type access. 651 result_type_iterator result_type_begin() { 652 return this->getOperation()->result_type_begin(); 653 } 654 result_type_iterator result_type_end() { 655 return this->getOperation()->result_type_end(); 656 } 657 result_type_range getResultTypes() { 658 return this->getOperation()->getResultTypes(); 659 } 660 }; 661 } // namespace detail 662 663 /// This class provides return value APIs for ops that are known to have a 664 /// single result. ResultType is the concrete type returned by getType(). 665 template <typename ConcreteType> 666 class OneResult : public TraitBase<ConcreteType, OneResult> { 667 public: 668 /// Replace all uses of 'this' value with the new value, updating anything 669 /// in the IR that uses 'this' to use the other value instead. When this 670 /// returns there are zero uses of 'this'. 671 void replaceAllUsesWith(Value newValue) { 672 this->getOperation()->getResult(0).replaceAllUsesWith(newValue); 673 } 674 675 /// Replace all uses of 'this' value with the result of 'op'. 676 void replaceAllUsesWith(Operation *op) { 677 this->getOperation()->replaceAllUsesWith(op); 678 } 679 680 static LogicalResult verifyTrait(Operation *op) { 681 return impl::verifyOneResult(op); 682 } 683 }; 684 685 /// This trait is used for return value APIs for ops that are known to have a 686 /// specific type other than `Type`. This allows the "getType()" member to be 687 /// more specific for an op. This should be used in conjunction with OneResult, 688 /// and occur in the trait list before OneResult. 689 template <typename ResultType> 690 class OneTypedResult { 691 public: 692 /// This class provides return value APIs for ops that are known to have a 693 /// single result. ResultType is the concrete type returned by getType(). 694 template <typename ConcreteType> 695 class Impl 696 : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> { 697 public: 698 mlir::TypedValue<ResultType> getResult() { 699 return cast<mlir::TypedValue<ResultType>>( 700 this->getOperation()->getResult(0)); 701 } 702 703 /// If the operation returns a single value, then the Op can be implicitly 704 /// converted to a Value. This yields the value of the only result. 705 operator mlir::TypedValue<ResultType>() { return getResult(); } 706 707 ResultType getType() { return getResult().getType(); } 708 }; 709 }; 710 711 /// This class provides the API for ops that are known to have a specified 712 /// number of results. This is used as a trait like this: 713 /// 714 /// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> { 715 /// 716 template <unsigned N> 717 class NResults { 718 public: 719 static_assert(N > 1, "use ZeroResults/OneResult for N < 2"); 720 721 template <typename ConcreteType> 722 class Impl 723 : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> { 724 public: 725 static LogicalResult verifyTrait(Operation *op) { 726 return impl::verifyNResults(op, N); 727 } 728 }; 729 }; 730 731 /// This class provides the API for ops that are known to have at least a 732 /// specified number of results. This is used as a trait like this: 733 /// 734 /// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> { 735 /// 736 template <unsigned N> 737 class AtLeastNResults { 738 public: 739 template <typename ConcreteType> 740 class Impl : public detail::MultiResultTraitBase<ConcreteType, 741 AtLeastNResults<N>::Impl> { 742 public: 743 static LogicalResult verifyTrait(Operation *op) { 744 return impl::verifyAtLeastNResults(op, N); 745 } 746 }; 747 }; 748 749 /// This class provides the API for ops which have an unknown number of 750 /// results. 751 template <typename ConcreteType> 752 class VariadicResults 753 : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {}; 754 755 //===----------------------------------------------------------------------===// 756 // Terminator Traits 757 758 /// This class indicates that the regions associated with this op don't have 759 /// terminators. 760 template <typename ConcreteType> 761 class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {}; 762 763 /// This class provides the API for ops that are known to be terminators. 764 template <typename ConcreteType> 765 class IsTerminator : public TraitBase<ConcreteType, IsTerminator> { 766 public: 767 static LogicalResult verifyTrait(Operation *op) { 768 return impl::verifyIsTerminator(op); 769 } 770 }; 771 772 /// This class provides verification for ops that are known to have zero 773 /// successors. 774 template <typename ConcreteType> 775 class ZeroSuccessors : public TraitBase<ConcreteType, ZeroSuccessors> { 776 public: 777 static LogicalResult verifyTrait(Operation *op) { 778 return impl::verifyZeroSuccessors(op); 779 } 780 }; 781 782 namespace detail { 783 /// Utility trait base that provides accessors for derived traits that have 784 /// multiple successors. 785 template <typename ConcreteType, template <typename> class TraitType> 786 struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> { 787 using succ_iterator = Operation::succ_iterator; 788 using succ_range = SuccessorRange; 789 790 /// Return the number of successors. 791 unsigned getNumSuccessors() { 792 return this->getOperation()->getNumSuccessors(); 793 } 794 795 /// Return the successor at `index`. 796 Block *getSuccessor(unsigned i) { 797 return this->getOperation()->getSuccessor(i); 798 } 799 800 /// Set the successor at `index`. 801 void setSuccessor(Block *block, unsigned i) { 802 return this->getOperation()->setSuccessor(block, i); 803 } 804 805 /// Successor iterator access. 806 succ_iterator succ_begin() { return this->getOperation()->succ_begin(); } 807 succ_iterator succ_end() { return this->getOperation()->succ_end(); } 808 succ_range getSuccessors() { return this->getOperation()->getSuccessors(); } 809 }; 810 } // namespace detail 811 812 /// This class provides APIs for ops that are known to have a single successor. 813 template <typename ConcreteType> 814 class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> { 815 public: 816 Block *getSuccessor() { return this->getOperation()->getSuccessor(0); } 817 void setSuccessor(Block *succ) { 818 this->getOperation()->setSuccessor(succ, 0); 819 } 820 821 static LogicalResult verifyTrait(Operation *op) { 822 return impl::verifyOneSuccessor(op); 823 } 824 }; 825 826 /// This class provides the API for ops that are known to have a specified 827 /// number of successors. 828 template <unsigned N> 829 class NSuccessors { 830 public: 831 static_assert(N > 1, "use ZeroSuccessors/OneSuccessor for N < 2"); 832 833 template <typename ConcreteType> 834 class Impl : public detail::MultiSuccessorTraitBase<ConcreteType, 835 NSuccessors<N>::Impl> { 836 public: 837 static LogicalResult verifyTrait(Operation *op) { 838 return impl::verifyNSuccessors(op, N); 839 } 840 }; 841 }; 842 843 /// This class provides APIs for ops that are known to have at least a specified 844 /// number of successors. 845 template <unsigned N> 846 class AtLeastNSuccessors { 847 public: 848 template <typename ConcreteType> 849 class Impl 850 : public detail::MultiSuccessorTraitBase<ConcreteType, 851 AtLeastNSuccessors<N>::Impl> { 852 public: 853 static LogicalResult verifyTrait(Operation *op) { 854 return impl::verifyAtLeastNSuccessors(op, N); 855 } 856 }; 857 }; 858 859 /// This class provides the API for ops which have an unknown number of 860 /// successors. 861 template <typename ConcreteType> 862 class VariadicSuccessors 863 : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> { 864 }; 865 866 //===----------------------------------------------------------------------===// 867 // SingleBlock 868 869 /// This class provides APIs and verifiers for ops with regions having a single 870 /// block. 871 template <typename ConcreteType> 872 struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> { 873 public: 874 static LogicalResult verifyTrait(Operation *op) { 875 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { 876 Region ®ion = op->getRegion(i); 877 878 // Empty regions are fine. 879 if (region.empty()) 880 continue; 881 882 // Non-empty regions must contain a single basic block. 883 if (!llvm::hasSingleElement(region)) 884 return op->emitOpError("expects region #") 885 << i << " to have 0 or 1 blocks"; 886 887 if (!ConcreteType::template hasTrait<NoTerminator>()) { 888 Block &block = region.front(); 889 if (block.empty()) 890 return op->emitOpError() << "expects a non-empty block"; 891 } 892 } 893 return success(); 894 } 895 896 Block *getBody(unsigned idx = 0) { 897 Region ®ion = this->getOperation()->getRegion(idx); 898 assert(!region.empty() && "unexpected empty region"); 899 return ®ion.front(); 900 } 901 Region &getBodyRegion(unsigned idx = 0) { 902 return this->getOperation()->getRegion(idx); 903 } 904 905 //===------------------------------------------------------------------===// 906 // Single Region Utilities 907 //===------------------------------------------------------------------===// 908 909 /// The following are a set of methods only enabled when the parent 910 /// operation has a single region. Each of these methods take an additional 911 /// template parameter that represents the concrete operation so that we 912 /// can use SFINAE to disable the methods for non-single region operations. 913 template <typename OpT, typename T = void> 914 using enable_if_single_region = 915 std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>; 916 917 template <typename OpT = ConcreteType> 918 enable_if_single_region<OpT, Block::iterator> begin() { 919 return getBody()->begin(); 920 } 921 template <typename OpT = ConcreteType> 922 enable_if_single_region<OpT, Block::iterator> end() { 923 return getBody()->end(); 924 } 925 template <typename OpT = ConcreteType> 926 enable_if_single_region<OpT, Operation &> front() { 927 return *begin(); 928 } 929 930 /// Insert the operation into the back of the body. 931 template <typename OpT = ConcreteType> 932 enable_if_single_region<OpT> push_back(Operation *op) { 933 insert(Block::iterator(getBody()->end()), op); 934 } 935 936 /// Insert the operation at the given insertion point. 937 template <typename OpT = ConcreteType> 938 enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) { 939 insert(Block::iterator(insertPt), op); 940 } 941 template <typename OpT = ConcreteType> 942 enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) { 943 getBody()->getOperations().insert(insertPt, op); 944 } 945 }; 946 947 //===----------------------------------------------------------------------===// 948 // SingleBlockImplicitTerminator 949 950 /// This class provides APIs and verifiers for ops with regions having a single 951 /// block that must terminate with `TerminatorOpType`. 952 template <typename TerminatorOpType> 953 struct SingleBlockImplicitTerminator { 954 template <typename ConcreteType> 955 class Impl : public TraitBase<ConcreteType, SingleBlockImplicitTerminator< 956 TerminatorOpType>::Impl> { 957 private: 958 /// Builds a terminator operation without relying on OpBuilder APIs to avoid 959 /// cyclic header inclusion. 960 static Operation *buildTerminator(OpBuilder &builder, Location loc) { 961 OperationState state(loc, TerminatorOpType::getOperationName()); 962 TerminatorOpType::build(builder, state); 963 return Operation::create(state); 964 } 965 966 public: 967 /// The type of the operation used as the implicit terminator type. 968 using ImplicitTerminatorOpT = TerminatorOpType; 969 970 static LogicalResult verifyRegionTrait(Operation *op) { 971 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { 972 Region ®ion = op->getRegion(i); 973 // Empty regions are fine. 974 if (region.empty()) 975 continue; 976 Operation &terminator = region.front().back(); 977 if (isa<TerminatorOpType>(terminator)) 978 continue; 979 980 return op->emitOpError("expects regions to end with '" + 981 TerminatorOpType::getOperationName() + 982 "', found '" + 983 terminator.getName().getStringRef() + "'") 984 .attachNote() 985 << "in custom textual format, the absence of terminator implies " 986 "'" 987 << TerminatorOpType::getOperationName() << '\''; 988 } 989 990 return success(); 991 } 992 993 /// Ensure that the given region has the terminator required by this trait. 994 /// If OpBuilder is provided, use it to build the terminator and notify the 995 /// OpBuilder listeners accordingly. If only a Builder is provided, locally 996 /// construct an OpBuilder with no listeners; this should only be used if no 997 /// OpBuilder is available at the call site, e.g., in the parser. 998 static void ensureTerminator(Region ®ion, Builder &builder, 999 Location loc) { 1000 ::mlir::impl::ensureRegionTerminator(region, builder, loc, 1001 buildTerminator); 1002 } 1003 static void ensureTerminator(Region ®ion, OpBuilder &builder, 1004 Location loc) { 1005 ::mlir::impl::ensureRegionTerminator(region, builder, loc, 1006 buildTerminator); 1007 } 1008 }; 1009 }; 1010 1011 /// Check is an op defines the `ImplicitTerminatorOpT` member. This is intended 1012 /// to be used with `llvm::is_detected`. 1013 template <class T> 1014 using has_implicit_terminator_t = typename T::ImplicitTerminatorOpT; 1015 1016 /// Support to check if an operation has the SingleBlockImplicitTerminator 1017 /// trait. We can't just use `hasTrait` because this class is templated on a 1018 /// specific terminator op. 1019 template <class Op, bool hasTerminator = 1020 llvm::is_detected<has_implicit_terminator_t, Op>::value> 1021 struct hasSingleBlockImplicitTerminator { 1022 static constexpr bool value = std::is_base_of< 1023 typename OpTrait::SingleBlockImplicitTerminator< 1024 typename Op::ImplicitTerminatorOpT>::template Impl<Op>, 1025 Op>::value; 1026 }; 1027 template <class Op> 1028 struct hasSingleBlockImplicitTerminator<Op, false> { 1029 static constexpr bool value = false; 1030 }; 1031 1032 //===----------------------------------------------------------------------===// 1033 // Misc Traits 1034 1035 /// This class provides verification for ops that are known to have the same 1036 /// operand shape: all operands are scalars, vectors/tensors of the same 1037 /// shape. 1038 template <typename ConcreteType> 1039 class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> { 1040 public: 1041 static LogicalResult verifyTrait(Operation *op) { 1042 return impl::verifySameOperandsShape(op); 1043 } 1044 }; 1045 1046 /// This class provides verification for ops that are known to have the same 1047 /// operand and result shape: both are scalars, vectors/tensors of the same 1048 /// shape. 1049 template <typename ConcreteType> 1050 class SameOperandsAndResultShape 1051 : public TraitBase<ConcreteType, SameOperandsAndResultShape> { 1052 public: 1053 static LogicalResult verifyTrait(Operation *op) { 1054 return impl::verifySameOperandsAndResultShape(op); 1055 } 1056 }; 1057 1058 /// This class provides verification for ops that are known to have the same 1059 /// operand element type (or the type itself if it is scalar). 1060 /// 1061 template <typename ConcreteType> 1062 class SameOperandsElementType 1063 : public TraitBase<ConcreteType, SameOperandsElementType> { 1064 public: 1065 static LogicalResult verifyTrait(Operation *op) { 1066 return impl::verifySameOperandsElementType(op); 1067 } 1068 }; 1069 1070 /// This class provides verification for ops that are known to have the same 1071 /// operand and result element type (or the type itself if it is scalar). 1072 /// 1073 template <typename ConcreteType> 1074 class SameOperandsAndResultElementType 1075 : public TraitBase<ConcreteType, SameOperandsAndResultElementType> { 1076 public: 1077 static LogicalResult verifyTrait(Operation *op) { 1078 return impl::verifySameOperandsAndResultElementType(op); 1079 } 1080 }; 1081 1082 /// This class provides verification for ops that are known to have the same 1083 /// operand and result type. 1084 /// 1085 /// Note: this trait subsumes the SameOperandsAndResultShape and 1086 /// SameOperandsAndResultElementType traits. 1087 template <typename ConcreteType> 1088 class SameOperandsAndResultType 1089 : public TraitBase<ConcreteType, SameOperandsAndResultType> { 1090 public: 1091 static LogicalResult verifyTrait(Operation *op) { 1092 return impl::verifySameOperandsAndResultType(op); 1093 } 1094 }; 1095 1096 /// This class verifies that op has same ranks for all 1097 /// operands and results types, if known. 1098 template <typename ConcreteType> 1099 class SameOperandsAndResultRank 1100 : public TraitBase<ConcreteType, SameOperandsAndResultRank> { 1101 public: 1102 static LogicalResult verifyTrait(Operation *op) { 1103 return impl::verifySameOperandsAndResultRank(op); 1104 } 1105 }; 1106 1107 /// This class verifies that any results of the specified op have a boolean 1108 /// type, a vector thereof, or a tensor thereof. 1109 template <typename ConcreteType> 1110 class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> { 1111 public: 1112 static LogicalResult verifyTrait(Operation *op) { 1113 return impl::verifyResultsAreBoolLike(op); 1114 } 1115 }; 1116 1117 /// This class verifies that any results of the specified op have a floating 1118 /// point type, a vector thereof, or a tensor thereof. 1119 template <typename ConcreteType> 1120 class ResultsAreFloatLike 1121 : public TraitBase<ConcreteType, ResultsAreFloatLike> { 1122 public: 1123 static LogicalResult verifyTrait(Operation *op) { 1124 return impl::verifyResultsAreFloatLike(op); 1125 } 1126 }; 1127 1128 /// This class verifies that any results of the specified op have a signless 1129 /// integer or index type, a vector thereof, or a tensor thereof. 1130 template <typename ConcreteType> 1131 class ResultsAreSignlessIntegerLike 1132 : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> { 1133 public: 1134 static LogicalResult verifyTrait(Operation *op) { 1135 return impl::verifyResultsAreSignlessIntegerLike(op); 1136 } 1137 }; 1138 1139 /// This class adds property that the operation is commutative. 1140 template <typename ConcreteType> 1141 class IsCommutative : public TraitBase<ConcreteType, IsCommutative> { 1142 public: 1143 static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands, 1144 SmallVectorImpl<OpFoldResult> &results) { 1145 return impl::foldCommutative(op, operands, results); 1146 } 1147 }; 1148 1149 /// This class adds property that the operation is an involution. 1150 /// This means a unary to unary operation "f" that satisfies f(f(x)) = x 1151 template <typename ConcreteType> 1152 class IsInvolution : public TraitBase<ConcreteType, IsInvolution> { 1153 public: 1154 static LogicalResult verifyTrait(Operation *op) { 1155 static_assert(ConcreteType::template hasTrait<OneResult>(), 1156 "expected operation to produce one result"); 1157 static_assert(ConcreteType::template hasTrait<OneOperand>(), 1158 "expected operation to take one operand"); 1159 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(), 1160 "expected operation to preserve type"); 1161 // Involution requires the operation to be side effect free as well 1162 // but currently this check is under a FIXME and is not actually done. 1163 return impl::verifyIsInvolution(op); 1164 } 1165 1166 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) { 1167 return impl::foldInvolution(op); 1168 } 1169 }; 1170 1171 /// This class adds property that the operation is idempotent. 1172 /// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x), 1173 /// or a binary operation "g" that satisfies g(x, x) = x. 1174 template <typename ConcreteType> 1175 class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> { 1176 public: 1177 static LogicalResult verifyTrait(Operation *op) { 1178 static_assert(ConcreteType::template hasTrait<OneResult>(), 1179 "expected operation to produce one result"); 1180 static_assert(ConcreteType::template hasTrait<OneOperand>() || 1181 ConcreteType::template hasTrait<NOperands<2>::Impl>(), 1182 "expected operation to take one or two operands"); 1183 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(), 1184 "expected operation to preserve type"); 1185 // Idempotent requires the operation to be side effect free as well 1186 // but currently this check is under a FIXME and is not actually done. 1187 return impl::verifyIsIdempotent(op); 1188 } 1189 1190 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) { 1191 return impl::foldIdempotent(op); 1192 } 1193 }; 1194 1195 /// This class verifies that all operands of the specified op have a float type, 1196 /// a vector thereof, or a tensor thereof. 1197 template <typename ConcreteType> 1198 class OperandsAreFloatLike 1199 : public TraitBase<ConcreteType, OperandsAreFloatLike> { 1200 public: 1201 static LogicalResult verifyTrait(Operation *op) { 1202 return impl::verifyOperandsAreFloatLike(op); 1203 } 1204 }; 1205 1206 /// This class verifies that all operands of the specified op have a signless 1207 /// integer or index type, a vector thereof, or a tensor thereof. 1208 template <typename ConcreteType> 1209 class OperandsAreSignlessIntegerLike 1210 : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> { 1211 public: 1212 static LogicalResult verifyTrait(Operation *op) { 1213 return impl::verifyOperandsAreSignlessIntegerLike(op); 1214 } 1215 }; 1216 1217 /// This class verifies that all operands of the specified op have the same 1218 /// type. 1219 template <typename ConcreteType> 1220 class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> { 1221 public: 1222 static LogicalResult verifyTrait(Operation *op) { 1223 return impl::verifySameTypeOperands(op); 1224 } 1225 }; 1226 1227 /// This class provides the API for a sub-set of ops that are known to be 1228 /// constant-like. These are non-side effecting operations with one result and 1229 /// zero operands that can always be folded to a specific attribute value. 1230 template <typename ConcreteType> 1231 class ConstantLike : public TraitBase<ConcreteType, ConstantLike> { 1232 public: 1233 static LogicalResult verifyTrait(Operation *op) { 1234 static_assert(ConcreteType::template hasTrait<OneResult>(), 1235 "expected operation to produce one result"); 1236 static_assert(ConcreteType::template hasTrait<ZeroOperands>(), 1237 "expected operation to take zero operands"); 1238 // TODO: We should verify that the operation can always be folded, but this 1239 // requires that the attributes of the op already be verified. We should add 1240 // support for verifying traits "after" the operation to enable this use 1241 // case. 1242 return success(); 1243 } 1244 }; 1245 1246 /// This class provides the API for ops that are known to be isolated from 1247 /// above. 1248 template <typename ConcreteType> 1249 class IsIsolatedFromAbove 1250 : public TraitBase<ConcreteType, IsIsolatedFromAbove> { 1251 public: 1252 static LogicalResult verifyRegionTrait(Operation *op) { 1253 return impl::verifyIsIsolatedFromAbove(op); 1254 } 1255 }; 1256 1257 /// A trait of region holding operations that defines a new scope for polyhedral 1258 /// optimization purposes. Any SSA values of 'index' type that either dominate 1259 /// such an operation or are used at the top-level of such an operation 1260 /// automatically become valid symbols for the polyhedral scope defined by that 1261 /// operation. For more details, see `Traits.md#AffineScope`. 1262 template <typename ConcreteType> 1263 class AffineScope : public TraitBase<ConcreteType, AffineScope> { 1264 public: 1265 static LogicalResult verifyTrait(Operation *op) { 1266 static_assert(!ConcreteType::template hasTrait<ZeroRegions>(), 1267 "expected operation to have one or more regions"); 1268 return success(); 1269 } 1270 }; 1271 1272 /// A trait of region holding operations that define a new scope for automatic 1273 /// allocations, i.e., allocations that are freed when control is transferred 1274 /// back from the operation's region. Any operations performing such allocations 1275 /// (for eg. memref.alloca) will have their allocations automatically freed at 1276 /// their closest enclosing operation with this trait. 1277 template <typename ConcreteType> 1278 class AutomaticAllocationScope 1279 : public TraitBase<ConcreteType, AutomaticAllocationScope> { 1280 public: 1281 static LogicalResult verifyTrait(Operation *op) { 1282 static_assert(!ConcreteType::template hasTrait<ZeroRegions>(), 1283 "expected operation to have one or more regions"); 1284 return success(); 1285 } 1286 }; 1287 1288 /// This class provides a verifier for ops that are expecting their parent 1289 /// to be one of the given parent ops 1290 template <typename... ParentOpTypes> 1291 struct HasParent { 1292 template <typename ConcreteType> 1293 class Impl : public TraitBase<ConcreteType, Impl> { 1294 public: 1295 static LogicalResult verifyTrait(Operation *op) { 1296 if (llvm::isa_and_nonnull<ParentOpTypes...>(op->getParentOp())) 1297 return success(); 1298 1299 return op->emitOpError() 1300 << "expects parent op " 1301 << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'") 1302 << llvm::ArrayRef({ParentOpTypes::getOperationName()...}) << "'"; 1303 } 1304 1305 template <typename ParentOpType = 1306 std::tuple_element_t<0, std::tuple<ParentOpTypes...>>> 1307 std::enable_if_t<sizeof...(ParentOpTypes) == 1, ParentOpType> 1308 getParentOp() { 1309 Operation *parent = this->getOperation()->getParentOp(); 1310 return llvm::cast<ParentOpType>(parent); 1311 } 1312 }; 1313 }; 1314 1315 /// A trait for operations that have an attribute specifying operand segments. 1316 /// 1317 /// Certain operations can have multiple variadic operands and their size 1318 /// relationship is not always known statically. For such cases, we need 1319 /// a per-op-instance specification to divide the operands into logical groups 1320 /// or segments. This can be modeled by attributes. The attribute will be named 1321 /// as `operandSegmentSizes`. 1322 /// 1323 /// This trait verifies the attribute for specifying operand segments has 1324 /// the correct type (1D vector) and values (non-negative), etc. 1325 template <typename ConcreteType> 1326 class AttrSizedOperandSegments 1327 : public TraitBase<ConcreteType, AttrSizedOperandSegments> { 1328 public: 1329 static StringRef getOperandSegmentSizeAttr() { return "operandSegmentSizes"; } 1330 1331 static LogicalResult verifyTrait(Operation *op) { 1332 return ::mlir::OpTrait::impl::verifyOperandSizeAttr( 1333 op, getOperandSegmentSizeAttr()); 1334 } 1335 }; 1336 1337 /// Similar to AttrSizedOperandSegments but used for results. 1338 template <typename ConcreteType> 1339 class AttrSizedResultSegments 1340 : public TraitBase<ConcreteType, AttrSizedResultSegments> { 1341 public: 1342 static StringRef getResultSegmentSizeAttr() { return "resultSegmentSizes"; } 1343 1344 static LogicalResult verifyTrait(Operation *op) { 1345 return ::mlir::OpTrait::impl::verifyResultSizeAttr( 1346 op, getResultSegmentSizeAttr()); 1347 } 1348 }; 1349 1350 /// This trait provides a verifier for ops that are expecting their regions to 1351 /// not have any arguments 1352 template <typename ConcrentType> 1353 struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> { 1354 static LogicalResult verifyTrait(Operation *op) { 1355 return ::mlir::OpTrait::impl::verifyNoRegionArguments(op); 1356 } 1357 }; 1358 1359 // This trait is used to flag operations that consume or produce 1360 // values of `MemRef` type where those references can be 'normalized'. 1361 // TODO: Right now, the operands of an operation are either all normalizable, 1362 // or not. In the future, we may want to allow some of the operands to be 1363 // normalizable. 1364 template <typename ConcrentType> 1365 struct MemRefsNormalizable 1366 : public TraitBase<ConcrentType, MemRefsNormalizable> {}; 1367 1368 /// This trait tags element-wise ops on vectors or tensors. 1369 /// 1370 /// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this 1371 /// trait. In particular, broadcasting behavior is not allowed. 1372 /// 1373 /// An `Elementwise` op must satisfy the following properties: 1374 /// 1375 /// 1. If any result is a vector/tensor then at least one operand must also be a 1376 /// vector/tensor. 1377 /// 2. If any operand is a vector/tensor then there must be at least one result 1378 /// and all results must be vectors/tensors. 1379 /// 3. All operand and result vector/tensor types must be of the same shape. The 1380 /// shape may be dynamic in which case the op's behaviour is undefined for 1381 /// non-matching shapes. 1382 /// 4. The operation must be elementwise on its vector/tensor operands and 1383 /// results. When applied to single-element vectors/tensors, the result must 1384 /// be the same per elememnt. 1385 /// 1386 /// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new 1387 /// interface `ElementwiseTypeInterface` that describes the container types for 1388 /// which the operation is elementwise. 1389 /// 1390 /// Rationale: 1391 /// - 1. and 2. guarantee a well-defined iteration space and exclude the cases 1392 /// of 0 non-scalar operands or 0 non-scalar results, which complicate a 1393 /// generic definition of the iteration space. 1394 /// - 3. guarantees that folding can be done across scalars/vectors/tensors with 1395 /// the same pattern, as otherwise lots of special handling for type 1396 /// mismatches would be needed. 1397 /// - 4. guarantees that no error handling is needed. Higher-level dialects 1398 /// should reify any needed guards or error handling code before lowering to 1399 /// an `Elementwise` op. 1400 template <typename ConcreteType> 1401 struct Elementwise : public TraitBase<ConcreteType, Elementwise> { 1402 static LogicalResult verifyTrait(Operation *op) { 1403 return ::mlir::OpTrait::impl::verifyElementwise(op); 1404 } 1405 }; 1406 1407 /// This trait tags `Elementwise` operatons that can be systematically 1408 /// scalarized. All vector/tensor operands and results are then replaced by 1409 /// scalars of the respective element type. Semantically, this is the operation 1410 /// on a single element of the vector/tensor. 1411 /// 1412 /// Rationale: 1413 /// Allow to define the vector/tensor semantics of elementwise operations based 1414 /// on the same op's behavior on scalars. This provides a constructive procedure 1415 /// for IR transformations to, e.g., create scalar loop bodies from tensor ops. 1416 /// 1417 /// Example: 1418 /// ``` 1419 /// %tensor_select = "arith.select"(%pred_tensor, %true_val, %false_val) 1420 /// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) 1421 /// -> tensor<?xf32> 1422 /// ``` 1423 /// can be scalarized to 1424 /// 1425 /// ``` 1426 /// %scalar_select = "arith.select"(%pred, %true_val_scalar, %false_val_scalar) 1427 /// : (i1, f32, f32) -> f32 1428 /// ``` 1429 template <typename ConcreteType> 1430 struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> { 1431 static LogicalResult verifyTrait(Operation *op) { 1432 static_assert( 1433 ConcreteType::template hasTrait<Elementwise>(), 1434 "`Scalarizable` trait is only applicable to `Elementwise` ops."); 1435 return success(); 1436 } 1437 }; 1438 1439 /// This trait tags `Elementwise` operatons that can be systematically 1440 /// vectorized. All scalar operands and results are then replaced by vectors 1441 /// with the respective element type. Semantically, this is the operation on 1442 /// multiple elements simultaneously. See also `Tensorizable`. 1443 /// 1444 /// Rationale: 1445 /// Provide the reverse to `Scalarizable` which, when chained together, allows 1446 /// reasoning about the relationship between the tensor and vector case. 1447 /// Additionally, it permits reasoning about promoting scalars to vectors via 1448 /// broadcasting in cases like `%select_scalar_pred` below. 1449 template <typename ConcreteType> 1450 struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> { 1451 static LogicalResult verifyTrait(Operation *op) { 1452 static_assert( 1453 ConcreteType::template hasTrait<Elementwise>(), 1454 "`Vectorizable` trait is only applicable to `Elementwise` ops."); 1455 return success(); 1456 } 1457 }; 1458 1459 /// This trait tags `Elementwise` operatons that can be systematically 1460 /// tensorized. All scalar operands and results are then replaced by tensors 1461 /// with the respective element type. Semantically, this is the operation on 1462 /// multiple elements simultaneously. See also `Vectorizable`. 1463 /// 1464 /// Rationale: 1465 /// Provide the reverse to `Scalarizable` which, when chained together, allows 1466 /// reasoning about the relationship between the tensor and vector case. 1467 /// Additionally, it permits reasoning about promoting scalars to tensors via 1468 /// broadcasting in cases like `%select_scalar_pred` below. 1469 /// 1470 /// Examples: 1471 /// ``` 1472 /// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32 1473 /// ``` 1474 /// can be tensorized to 1475 /// ``` 1476 /// %tensor = "arith.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>) 1477 /// -> tensor<?xf32> 1478 /// ``` 1479 /// 1480 /// ``` 1481 /// %scalar_pred = "arith.select"(%pred, %true_val, %false_val) 1482 /// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> 1483 /// ``` 1484 /// can be tensorized to 1485 /// ``` 1486 /// %tensor_pred = "arith.select"(%pred, %true_val, %false_val) 1487 /// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) 1488 /// -> tensor<?xf32> 1489 /// ``` 1490 template <typename ConcreteType> 1491 struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> { 1492 static LogicalResult verifyTrait(Operation *op) { 1493 static_assert( 1494 ConcreteType::template hasTrait<Elementwise>(), 1495 "`Tensorizable` trait is only applicable to `Elementwise` ops."); 1496 return success(); 1497 } 1498 }; 1499 1500 /// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable` 1501 /// provide an easy way for scalar operations to conveniently generalize their 1502 /// behavior to vectors/tensors, and systematize conversion between these forms. 1503 bool hasElementwiseMappableTraits(Operation *op); 1504 1505 } // namespace OpTrait 1506 1507 //===----------------------------------------------------------------------===// 1508 // Internal Trait Utilities 1509 //===----------------------------------------------------------------------===// 1510 1511 namespace op_definition_impl { 1512 //===----------------------------------------------------------------------===// 1513 // Trait Existence 1514 1515 /// Returns true if this given Trait ID matches the IDs of any of the provided 1516 /// trait types `Traits`. 1517 template <template <typename T> class... Traits> 1518 inline bool hasTrait(TypeID traitID) { 1519 TypeID traitIDs[] = {TypeID::get<Traits>()...}; 1520 for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i) 1521 if (traitIDs[i] == traitID) 1522 return true; 1523 return false; 1524 } 1525 template <> 1526 inline bool hasTrait<>(TypeID traitID) { 1527 return false; 1528 } 1529 1530 //===----------------------------------------------------------------------===// 1531 // Trait Folding 1532 1533 /// Trait to check if T provides a 'foldTrait' method for single result 1534 /// operations. 1535 template <typename T, typename... Args> 1536 using has_single_result_fold_trait = decltype(T::foldTrait( 1537 std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>())); 1538 template <typename T> 1539 using detect_has_single_result_fold_trait = 1540 llvm::is_detected<has_single_result_fold_trait, T>; 1541 /// Trait to check if T provides a general 'foldTrait' method. 1542 template <typename T, typename... Args> 1543 using has_fold_trait = 1544 decltype(T::foldTrait(std::declval<Operation *>(), 1545 std::declval<ArrayRef<Attribute>>(), 1546 std::declval<SmallVectorImpl<OpFoldResult> &>())); 1547 template <typename T> 1548 using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>; 1549 /// Trait to check if T provides any `foldTrait` method. 1550 template <typename T> 1551 using detect_has_any_fold_trait = 1552 std::disjunction<detect_has_fold_trait<T>, 1553 detect_has_single_result_fold_trait<T>>; 1554 1555 /// Returns the result of folding a trait that implements a `foldTrait` function 1556 /// that is specialized for operations that have a single result. 1557 template <typename Trait> 1558 static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value, 1559 LogicalResult> 1560 foldTrait(Operation *op, ArrayRef<Attribute> operands, 1561 SmallVectorImpl<OpFoldResult> &results) { 1562 assert(op->hasTrait<OpTrait::OneResult>() && 1563 "expected trait on non single-result operation to implement the " 1564 "general `foldTrait` method"); 1565 // If a previous trait has already been folded and replaced this operation, we 1566 // fail to fold this trait. 1567 if (!results.empty()) 1568 return failure(); 1569 1570 if (OpFoldResult result = Trait::foldTrait(op, operands)) { 1571 if (llvm::dyn_cast_if_present<Value>(result) != op->getResult(0)) 1572 results.push_back(result); 1573 return success(); 1574 } 1575 return failure(); 1576 } 1577 /// Returns the result of folding a trait that implements a generalized 1578 /// `foldTrait` function that is supports any operation type. 1579 template <typename Trait> 1580 static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult> 1581 foldTrait(Operation *op, ArrayRef<Attribute> operands, 1582 SmallVectorImpl<OpFoldResult> &results) { 1583 // If a previous trait has already been folded and replaced this operation, we 1584 // fail to fold this trait. 1585 return results.empty() ? Trait::foldTrait(op, operands, results) : failure(); 1586 } 1587 template <typename Trait> 1588 static inline std::enable_if_t<!detect_has_any_fold_trait<Trait>::value, 1589 LogicalResult> 1590 foldTrait(Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) { 1591 return failure(); 1592 } 1593 1594 /// Given a tuple type containing a set of traits, return the result of folding 1595 /// the given operation. 1596 template <typename... Ts> 1597 static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands, 1598 SmallVectorImpl<OpFoldResult> &results) { 1599 return success((succeeded(foldTrait<Ts>(op, operands, results)) || ...)); 1600 } 1601 1602 //===----------------------------------------------------------------------===// 1603 // Trait Verification 1604 1605 /// Trait to check if T provides a `verifyTrait` method. 1606 template <typename T, typename... Args> 1607 using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>())); 1608 template <typename T> 1609 using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>; 1610 1611 /// Trait to check if T provides a `verifyTrait` method. 1612 template <typename T, typename... Args> 1613 using has_verify_region_trait = 1614 decltype(T::verifyRegionTrait(std::declval<Operation *>())); 1615 template <typename T> 1616 using detect_has_verify_region_trait = 1617 llvm::is_detected<has_verify_region_trait, T>; 1618 1619 /// Verify the given trait if it provides a verifier. 1620 template <typename T> 1621 std::enable_if_t<detect_has_verify_trait<T>::value, LogicalResult> 1622 verifyTrait(Operation *op) { 1623 return T::verifyTrait(op); 1624 } 1625 template <typename T> 1626 inline std::enable_if_t<!detect_has_verify_trait<T>::value, LogicalResult> 1627 verifyTrait(Operation *) { 1628 return success(); 1629 } 1630 1631 /// Given a set of traits, return the result of verifying the given operation. 1632 template <typename... Ts> 1633 LogicalResult verifyTraits(Operation *op) { 1634 return success((succeeded(verifyTrait<Ts>(op)) && ...)); 1635 } 1636 1637 /// Verify the given trait if it provides a region verifier. 1638 template <typename T> 1639 std::enable_if_t<detect_has_verify_region_trait<T>::value, LogicalResult> 1640 verifyRegionTrait(Operation *op) { 1641 return T::verifyRegionTrait(op); 1642 } 1643 template <typename T> 1644 inline std::enable_if_t<!detect_has_verify_region_trait<T>::value, 1645 LogicalResult> 1646 verifyRegionTrait(Operation *) { 1647 return success(); 1648 } 1649 1650 /// Given a set of traits, return the result of verifying the regions of the 1651 /// given operation. 1652 template <typename... Ts> 1653 LogicalResult verifyRegionTraits(Operation *op) { 1654 return success((succeeded(verifyRegionTrait<Ts>(op)) && ...)); 1655 } 1656 } // namespace op_definition_impl 1657 1658 //===----------------------------------------------------------------------===// 1659 // Operation Definition classes 1660 //===----------------------------------------------------------------------===// 1661 1662 /// This provides public APIs that all operations should have. The template 1663 /// argument 'ConcreteType' should be the concrete type by CRTP and the others 1664 /// are base classes by the policy pattern. 1665 template <typename ConcreteType, template <typename T> class... Traits> 1666 class Op : public OpState, public Traits<ConcreteType>... { 1667 public: 1668 /// Inherit getOperation from `OpState`. 1669 using OpState::getOperation; 1670 using OpState::verify; 1671 using OpState::verifyRegions; 1672 1673 /// Return if this operation contains the provided trait. 1674 template <template <typename T> class Trait> 1675 static constexpr bool hasTrait() { 1676 return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value; 1677 } 1678 1679 /// Create a deep copy of this operation. 1680 ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); } 1681 1682 /// Create a partial copy of this operation without traversing into attached 1683 /// regions. The new operation will have the same number of regions as the 1684 /// original one, but they will be left empty. 1685 ConcreteType cloneWithoutRegions() { 1686 return cast<ConcreteType>(getOperation()->cloneWithoutRegions()); 1687 } 1688 1689 /// Return true if this "op class" can match against the specified operation. 1690 static bool classof(Operation *op) { 1691 if (auto info = op->getRegisteredInfo()) 1692 return TypeID::get<ConcreteType>() == info->getTypeID(); 1693 #ifndef NDEBUG 1694 if (op->getName().getStringRef() == ConcreteType::getOperationName()) 1695 llvm::report_fatal_error( 1696 "classof on '" + ConcreteType::getOperationName() + 1697 "' failed due to the operation not being registered"); 1698 #endif 1699 return false; 1700 } 1701 /// Provide `classof` support for other OpBase derived classes, such as 1702 /// Interfaces. 1703 template <typename T> 1704 static std::enable_if_t<std::is_base_of<OpState, T>::value, bool> 1705 classof(const T *op) { 1706 return classof(const_cast<T *>(op)->getOperation()); 1707 } 1708 1709 /// Expose the type we are instantiated on to template machinery that may want 1710 /// to introspect traits on this operation. 1711 using ConcreteOpType = ConcreteType; 1712 1713 /// This is a public constructor. Any op can be initialized to null. 1714 explicit Op() : OpState(nullptr) {} 1715 Op(std::nullptr_t) : OpState(nullptr) {} 1716 1717 /// This is a public constructor to enable access via the llvm::cast family of 1718 /// methods. This should not be used directly. 1719 explicit Op(Operation *state) : OpState(state) {} 1720 1721 /// Methods for supporting PointerLikeTypeTraits. 1722 const void *getAsOpaquePointer() const { 1723 return static_cast<const void *>((Operation *)*this); 1724 } 1725 static ConcreteOpType getFromOpaquePointer(const void *pointer) { 1726 return ConcreteOpType( 1727 reinterpret_cast<Operation *>(const_cast<void *>(pointer))); 1728 } 1729 1730 /// Attach the given models as implementations of the corresponding 1731 /// interfaces for the concrete operation. 1732 template <typename... Models> 1733 static void attachInterface(MLIRContext &context) { 1734 std::optional<RegisteredOperationName> info = 1735 RegisteredOperationName::lookup(TypeID::get<ConcreteType>(), &context); 1736 if (!info) 1737 llvm::report_fatal_error( 1738 "Attempting to attach an interface to an unregistered operation " + 1739 ConcreteType::getOperationName() + "."); 1740 (checkInterfaceTarget<Models>(), ...); 1741 info->attachInterface<Models...>(); 1742 } 1743 /// Convert the provided attribute to a property and assigned it to the 1744 /// provided properties. This default implementation forwards to a free 1745 /// function `setPropertiesFromAttribute` that can be looked up with ADL in 1746 /// the namespace where the properties are defined. It can also be overridden 1747 /// in the derived ConcreteOp. 1748 template <typename PropertiesTy> 1749 static LogicalResult 1750 setPropertiesFromAttr(PropertiesTy &prop, Attribute attr, 1751 function_ref<InFlightDiagnostic()> emitError) { 1752 return setPropertiesFromAttribute(prop, attr, emitError); 1753 } 1754 /// Convert the provided properties to an attribute. This default 1755 /// implementation forwards to a free function `getPropertiesAsAttribute` that 1756 /// can be looked up with ADL in the namespace where the properties are 1757 /// defined. It can also be overridden in the derived ConcreteOp. 1758 template <typename PropertiesTy> 1759 static Attribute getPropertiesAsAttr(MLIRContext *ctx, 1760 const PropertiesTy &prop) { 1761 return getPropertiesAsAttribute(ctx, prop); 1762 } 1763 /// Hash the provided properties. This default implementation forwards to a 1764 /// free function `computeHash` that can be looked up with ADL in the 1765 /// namespace where the properties are defined. It can also be overridden in 1766 /// the derived ConcreteOp. 1767 template <typename PropertiesTy> 1768 static llvm::hash_code computePropertiesHash(const PropertiesTy &prop) { 1769 return computeHash(prop); 1770 } 1771 1772 private: 1773 /// Trait to check if T provides a 'fold' method for a single result op. 1774 template <typename T, typename... Args> 1775 using has_single_result_fold_t = 1776 decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>())); 1777 template <typename T> 1778 constexpr static bool has_single_result_fold_v = 1779 llvm::is_detected<has_single_result_fold_t, T>::value; 1780 /// Trait to check if T provides a general 'fold' method. 1781 template <typename T, typename... Args> 1782 using has_fold_t = decltype(std::declval<T>().fold( 1783 std::declval<ArrayRef<Attribute>>(), 1784 std::declval<SmallVectorImpl<OpFoldResult> &>())); 1785 template <typename T> 1786 constexpr static bool has_fold_v = llvm::is_detected<has_fold_t, T>::value; 1787 /// Trait to check if T provides a 'fold' method with a FoldAdaptor for a 1788 /// single result op. 1789 template <typename T, typename... Args> 1790 using has_fold_adaptor_single_result_fold_t = 1791 decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>())); 1792 template <class T> 1793 constexpr static bool has_fold_adaptor_single_result_v = 1794 llvm::is_detected<has_fold_adaptor_single_result_fold_t, T>::value; 1795 /// Trait to check if T provides a general 'fold' method with a FoldAdaptor. 1796 template <typename T, typename... Args> 1797 using has_fold_adaptor_fold_t = decltype(std::declval<T>().fold( 1798 std::declval<typename T::FoldAdaptor>(), 1799 std::declval<SmallVectorImpl<OpFoldResult> &>())); 1800 template <class T> 1801 constexpr static bool has_fold_adaptor_v = 1802 llvm::is_detected<has_fold_adaptor_fold_t, T>::value; 1803 1804 /// Trait to check if T provides a 'print' method. 1805 template <typename T, typename... Args> 1806 using has_print = 1807 decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>())); 1808 template <typename T> 1809 using detect_has_print = llvm::is_detected<has_print, T>; 1810 1811 /// Trait to check if printProperties(OpAsmPrinter, T, ArrayRef<StringRef>) 1812 /// exist 1813 template <typename T, typename... Args> 1814 using has_print_properties = 1815 decltype(printProperties(std::declval<OpAsmPrinter &>(), 1816 std::declval<T>(), 1817 std::declval<ArrayRef<StringRef>>())); 1818 template <typename T> 1819 using detect_has_print_properties = 1820 llvm::is_detected<has_print_properties, T>; 1821 1822 /// Trait to check if parseProperties(OpAsmParser, T) exist 1823 template <typename T, typename... Args> 1824 using has_parse_properties = decltype(parseProperties( 1825 std::declval<OpAsmParser &>(), std::declval<T &>())); 1826 template <typename T> 1827 using detect_has_parse_properties = 1828 llvm::is_detected<has_parse_properties, T>; 1829 1830 /// Trait to check if T provides a 'ConcreteEntity' type alias. 1831 template <typename T> 1832 using has_concrete_entity_t = typename T::ConcreteEntity; 1833 1834 public: 1835 /// Returns true if this operation defines a `Properties` inner type. 1836 static constexpr bool hasProperties() { 1837 return !std::is_same_v< 1838 typename ConcreteType::template InferredProperties<ConcreteType>, 1839 EmptyProperties>; 1840 } 1841 1842 private: 1843 /// A struct-wrapped type alias to T::ConcreteEntity if provided and to 1844 /// ConcreteType otherwise. This is akin to std::conditional but doesn't fail 1845 /// on the missing typedef. Useful for checking if the interface is targeting 1846 /// the right class. 1847 template <typename T, 1848 bool = llvm::is_detected<has_concrete_entity_t, T>::value> 1849 struct InterfaceTargetOrOpT { 1850 using type = typename T::ConcreteEntity; 1851 }; 1852 template <typename T> 1853 struct InterfaceTargetOrOpT<T, false> { 1854 using type = ConcreteType; 1855 }; 1856 1857 /// A hook for static assertion that the external interface model T is 1858 /// targeting the concrete type of this op. The model can also be a fallback 1859 /// model that works for every op. 1860 template <typename T> 1861 static void checkInterfaceTarget() { 1862 static_assert(std::is_same<typename InterfaceTargetOrOpT<T>::type, 1863 ConcreteType>::value, 1864 "attaching an interface to the wrong op kind"); 1865 } 1866 1867 /// Returns an interface map containing the interfaces registered to this 1868 /// operation. 1869 static detail::InterfaceMap getInterfaceMap() { 1870 return detail::InterfaceMap::template get<Traits<ConcreteType>...>(); 1871 } 1872 1873 /// Return the internal implementations of each of the OperationName 1874 /// hooks. 1875 /// Implementation of `FoldHookFn` OperationName hook. 1876 static OperationName::FoldHookFn getFoldHookFn() { 1877 // If the operation is single result and defines a `fold` method. 1878 if constexpr (llvm::is_one_of<OpTrait::OneResult<ConcreteType>, 1879 Traits<ConcreteType>...>::value && 1880 (has_single_result_fold_v<ConcreteType> || 1881 has_fold_adaptor_single_result_v<ConcreteType>)) 1882 return [](Operation *op, ArrayRef<Attribute> operands, 1883 SmallVectorImpl<OpFoldResult> &results) { 1884 return foldSingleResultHook<ConcreteType>(op, operands, results); 1885 }; 1886 // The operation is not single result and defines a `fold` method. 1887 if constexpr (has_fold_v<ConcreteType> || has_fold_adaptor_v<ConcreteType>) 1888 return [](Operation *op, ArrayRef<Attribute> operands, 1889 SmallVectorImpl<OpFoldResult> &results) { 1890 return foldHook<ConcreteType>(op, operands, results); 1891 }; 1892 // The operation does not define a `fold` method. 1893 return [](Operation *op, ArrayRef<Attribute> operands, 1894 SmallVectorImpl<OpFoldResult> &results) { 1895 // In this case, we only need to fold the traits of the operation. 1896 return op_definition_impl::foldTraits<Traits<ConcreteType>...>( 1897 op, operands, results); 1898 }; 1899 } 1900 /// Return the result of folding a single result operation that defines a 1901 /// `fold` method. 1902 template <typename ConcreteOpT> 1903 static LogicalResult 1904 foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands, 1905 SmallVectorImpl<OpFoldResult> &results) { 1906 OpFoldResult result; 1907 if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>) { 1908 result = cast<ConcreteOpT>(op).fold( 1909 typename ConcreteOpT::FoldAdaptor(operands, cast<ConcreteOpT>(op))); 1910 } else { 1911 result = cast<ConcreteOpT>(op).fold(operands); 1912 } 1913 1914 // If the fold failed or was in-place, try to fold the traits of the 1915 // operation. 1916 if (!result || 1917 llvm::dyn_cast_if_present<Value>(result) == op->getResult(0)) { 1918 if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>( 1919 op, operands, results))) 1920 return success(); 1921 return success(static_cast<bool>(result)); 1922 } 1923 results.push_back(result); 1924 return success(); 1925 } 1926 /// Return the result of folding an operation that defines a `fold` method. 1927 template <typename ConcreteOpT> 1928 static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands, 1929 SmallVectorImpl<OpFoldResult> &results) { 1930 auto result = LogicalResult::failure(); 1931 if constexpr (has_fold_adaptor_v<ConcreteOpT>) { 1932 result = cast<ConcreteOpT>(op).fold( 1933 typename ConcreteOpT::FoldAdaptor(operands, cast<ConcreteOpT>(op)), 1934 results); 1935 } else { 1936 result = cast<ConcreteOpT>(op).fold(operands, results); 1937 } 1938 1939 // If the fold failed or was in-place, try to fold the traits of the 1940 // operation. 1941 if (failed(result) || results.empty()) { 1942 if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>( 1943 op, operands, results))) 1944 return success(); 1945 } 1946 return result; 1947 } 1948 1949 /// Implementation of `GetHasTraitFn` 1950 static OperationName::HasTraitFn getHasTraitFn() { 1951 return 1952 [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); }; 1953 } 1954 /// Implementation of `PrintAssemblyFn` OperationName hook. 1955 static OperationName::PrintAssemblyFn getPrintAssemblyFn() { 1956 if constexpr (detect_has_print<ConcreteType>::value) 1957 return [](Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { 1958 OpState::printOpName(op, p, defaultDialect); 1959 return cast<ConcreteType>(op).print(p); 1960 }; 1961 return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) { 1962 return OpState::print(op, printer, defaultDialect); 1963 }; 1964 } 1965 1966 public: 1967 template <typename T> 1968 using InferredProperties = typename PropertiesSelector<T>::type; 1969 template <typename T = ConcreteType> 1970 InferredProperties<T> &getProperties() { 1971 if constexpr (!hasProperties()) 1972 return getEmptyProperties(); 1973 return *getOperation() 1974 ->getPropertiesStorageUnsafe() 1975 .template as<InferredProperties<T> *>(); 1976 } 1977 1978 /// This hook populates any unset default attrs when mapped to properties. 1979 template <typename T = ConcreteType> 1980 static void populateDefaultProperties(OperationName opName, 1981 InferredProperties<T> &properties) {} 1982 1983 /// Print the operation properties with names not included within 1984 /// 'elidedProps'. Unless overridden, this method will try to dispatch to a 1985 /// `printProperties` free-function if it exists, and otherwise by converting 1986 /// the properties to an Attribute. 1987 template <typename T> 1988 static void printProperties(MLIRContext *ctx, OpAsmPrinter &p, 1989 const T &properties, 1990 ArrayRef<StringRef> elidedProps = {}) { 1991 if constexpr (detect_has_print_properties<T>::value) 1992 return printProperties(p, properties, elidedProps); 1993 genericPrintProperties( 1994 p, ConcreteType::getPropertiesAsAttr(ctx, properties), elidedProps); 1995 } 1996 1997 /// Parses 'prop-dict' for the operation. Unless overridden, the method will 1998 /// parse the properties using the generic property dictionary using the 1999 /// '<{ ... }>' syntax. The resulting properties are stored within the 2000 /// property structure of 'result', accessible via 'getOrAddProperties'. 2001 template <typename T = ConcreteType> 2002 static ParseResult parseProperties(OpAsmParser &parser, 2003 OperationState &result) { 2004 if constexpr (detect_has_parse_properties<InferredProperties<T>>::value) { 2005 return parseProperties( 2006 parser, result.getOrAddProperties<InferredProperties<T>>()); 2007 } 2008 2009 Attribute propertyDictionary; 2010 if (genericParseProperties(parser, propertyDictionary)) 2011 return failure(); 2012 2013 // The generated 'setPropertiesFromParsedAttr', like 2014 // 'setPropertiesFromAttr', expects a 'DictionaryAttr' that is not null. 2015 // Use an empty dictionary in the case that the whole dictionary is 2016 // optional. 2017 if (!propertyDictionary) 2018 propertyDictionary = DictionaryAttr::get(result.getContext()); 2019 2020 auto emitError = [&]() { 2021 return mlir::emitError(result.location, "invalid properties ") 2022 << propertyDictionary << " for op " << result.name.getStringRef() 2023 << ": "; 2024 }; 2025 2026 // Copy the data from the dictionary attribute into the property struct of 2027 // the operation. This method is generated by ODS by default if there are 2028 // any occurrences of 'prop-dict' in the assembly format and should set 2029 // any properties that aren't parsed elsewhere. 2030 return ConcreteOpType::setPropertiesFromParsedAttr( 2031 result.getOrAddProperties<InferredProperties<T>>(), propertyDictionary, 2032 emitError); 2033 } 2034 2035 private: 2036 /// Implementation of `PopulateDefaultAttrsFn` OperationName hook. 2037 static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() { 2038 return ConcreteType::populateDefaultAttrs; 2039 } 2040 /// Implementation of `VerifyInvariantsFn` OperationName hook. 2041 static LogicalResult verifyInvariants(Operation *op) { 2042 static_assert(hasNoDataMembers(), 2043 "Op class shouldn't define new data members"); 2044 return failure( 2045 failed(op_definition_impl::verifyTraits<Traits<ConcreteType>...>(op)) || 2046 failed(cast<ConcreteType>(op).verify())); 2047 } 2048 static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() { 2049 return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants); 2050 } 2051 /// Implementation of `VerifyRegionInvariantsFn` OperationName hook. 2052 static LogicalResult verifyRegionInvariants(Operation *op) { 2053 static_assert(hasNoDataMembers(), 2054 "Op class shouldn't define new data members"); 2055 return failure( 2056 failed(op_definition_impl::verifyRegionTraits<Traits<ConcreteType>...>( 2057 op)) || 2058 failed(cast<ConcreteType>(op).verifyRegions())); 2059 } 2060 static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() { 2061 return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants); 2062 } 2063 2064 static constexpr bool hasNoDataMembers() { 2065 // Checking that the derived class does not define any member by comparing 2066 // its size to an ad-hoc EmptyOp. 2067 class EmptyOp : public Op<EmptyOp, Traits...> {}; 2068 return sizeof(ConcreteType) == sizeof(EmptyOp); 2069 } 2070 2071 /// Allow access to internal implementation methods. 2072 friend RegisteredOperationName; 2073 }; 2074 2075 /// This class represents the base of an operation interface. See the definition 2076 /// of `detail::Interface` for requirements on the `Traits` type. 2077 template <typename ConcreteType, typename Traits> 2078 class OpInterface 2079 : public detail::Interface<ConcreteType, Operation *, Traits, 2080 Op<ConcreteType>, OpTrait::TraitBase> { 2081 public: 2082 using Base = OpInterface<ConcreteType, Traits>; 2083 using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits, 2084 Op<ConcreteType>, OpTrait::TraitBase>; 2085 2086 /// Inherit the base class constructor. 2087 using InterfaceBase::InterfaceBase; 2088 2089 protected: 2090 /// Returns the impl interface instance for the given operation. 2091 static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) { 2092 OperationName name = op->getName(); 2093 2094 #ifndef NDEBUG 2095 // Check that the current interface isn't an unresolved promise for the 2096 // given operation. 2097 if (Dialect *dialect = name.getDialect()) { 2098 dialect_extension_detail::handleUseOfUndefinedPromisedInterface( 2099 *dialect, name.getTypeID(), ConcreteType::getInterfaceID(), 2100 llvm::getTypeName<ConcreteType>()); 2101 } 2102 #endif 2103 2104 // Access the raw interface from the operation info. 2105 if (std::optional<RegisteredOperationName> rInfo = 2106 name.getRegisteredInfo()) { 2107 if (auto *opIface = rInfo->getInterface<ConcreteType>()) 2108 return opIface; 2109 // Fallback to the dialect to provide it with a chance to implement this 2110 // interface for this operation. 2111 return rInfo->getDialect().getRegisteredInterfaceForOp<ConcreteType>( 2112 op->getName()); 2113 } 2114 // Fallback to the dialect to provide it with a chance to implement this 2115 // interface for this operation. 2116 if (Dialect *dialect = name.getDialect()) 2117 return dialect->getRegisteredInterfaceForOp<ConcreteType>(name); 2118 return nullptr; 2119 } 2120 2121 /// Allow access to `getInterfaceFor`. 2122 friend InterfaceBase; 2123 }; 2124 2125 } // namespace mlir 2126 2127 namespace llvm { 2128 2129 template <typename T> 2130 struct DenseMapInfo<T, 2131 std::enable_if_t<std::is_base_of<mlir::OpState, T>::value && 2132 !mlir::detail::IsInterface<T>::value>> { 2133 static inline T getEmptyKey() { 2134 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); 2135 return T::getFromOpaquePointer(pointer); 2136 } 2137 static inline T getTombstoneKey() { 2138 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); 2139 return T::getFromOpaquePointer(pointer); 2140 } 2141 static unsigned getHashValue(T val) { 2142 return hash_value(val.getAsOpaquePointer()); 2143 } 2144 static bool isEqual(T lhs, T rhs) { return lhs == rhs; } 2145 }; 2146 } // namespace llvm 2147 2148 #endif 2149