1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/Attributes.h" 11 #include "mlir/IR/BuiltinAttributes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IRMapping.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/IR/OperationSupport.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Interfaces/FoldInterfaces.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include <numeric> 23 #include <optional> 24 25 using namespace mlir; 26 27 //===----------------------------------------------------------------------===// 28 // Operation 29 //===----------------------------------------------------------------------===// 30 31 /// Create a new Operation from operation state. 32 Operation *Operation::create(const OperationState &state) { 33 Operation *op = 34 create(state.location, state.name, state.types, state.operands, 35 state.attributes.getDictionary(state.getContext()), 36 state.properties, state.successors, state.regions); 37 if (LLVM_UNLIKELY(state.propertiesAttr)) { 38 assert(!state.properties); 39 LogicalResult result = 40 op->setPropertiesFromAttribute(state.propertiesAttr, 41 /*diagnostic=*/nullptr); 42 assert(result.succeeded() && "invalid properties in op creation"); 43 (void)result; 44 } 45 return op; 46 } 47 48 /// Create a new Operation with the specific fields. 
49 Operation *Operation::create(Location location, OperationName name, 50 TypeRange resultTypes, ValueRange operands, 51 NamedAttrList &&attributes, 52 OpaqueProperties properties, BlockRange successors, 53 RegionRange regions) { 54 unsigned numRegions = regions.size(); 55 Operation *op = 56 create(location, name, resultTypes, operands, std::move(attributes), 57 properties, successors, numRegions); 58 for (unsigned i = 0; i < numRegions; ++i) 59 if (regions[i]) 60 op->getRegion(i).takeBody(*regions[i]); 61 return op; 62 } 63 64 /// Create a new Operation with the specific fields. 65 Operation *Operation::create(Location location, OperationName name, 66 TypeRange resultTypes, ValueRange operands, 67 NamedAttrList &&attributes, 68 OpaqueProperties properties, BlockRange successors, 69 unsigned numRegions) { 70 // Populate default attributes. 71 name.populateDefaultAttrs(attributes); 72 73 return create(location, name, resultTypes, operands, 74 attributes.getDictionary(location.getContext()), properties, 75 successors, numRegions); 76 } 77 78 /// Overload of create that takes an existing DictionaryAttr to avoid 79 /// unnecessarily uniquing a list of attributes. 80 Operation *Operation::create(Location location, OperationName name, 81 TypeRange resultTypes, ValueRange operands, 82 DictionaryAttr attributes, 83 OpaqueProperties properties, BlockRange successors, 84 unsigned numRegions) { 85 assert(llvm::all_of(resultTypes, [](Type t) { return t; }) && 86 "unexpected null result type"); 87 88 // We only need to allocate additional memory for a subset of results. 
89 unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); 90 unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); 91 unsigned numSuccessors = successors.size(); 92 unsigned numOperands = operands.size(); 93 unsigned numResults = resultTypes.size(); 94 int opPropertiesAllocSize = llvm::alignTo<8>(name.getOpPropertyByteSize()); 95 96 // If the operation is known to have no operands, don't allocate an operand 97 // storage. 98 bool needsOperandStorage = 99 operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true; 100 101 // Compute the byte size for the operation and the operand storage. This takes 102 // into account the size of the operation, its trailing objects, and its 103 // prefixed objects. 104 size_t byteSize = 105 totalSizeToAlloc<detail::OperandStorage, detail::OpProperties, 106 BlockOperand, Region, OpOperand>( 107 needsOperandStorage ? 1 : 0, opPropertiesAllocSize, numSuccessors, 108 numRegions, numOperands); 109 size_t prefixByteSize = llvm::alignTo( 110 Operation::prefixAllocSize(numTrailingResults, numInlineResults), 111 alignof(Operation)); 112 char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); 113 void *rawMem = mallocMem + prefixByteSize; 114 115 // Create the new Operation. 116 Operation *op = ::new (rawMem) Operation( 117 location, name, numResults, numSuccessors, numRegions, 118 opPropertiesAllocSize, attributes, properties, needsOperandStorage); 119 120 assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && 121 "unexpected successors in a non-terminator operation"); 122 123 // Initialize the results. 
124 auto resultTypeIt = resultTypes.begin(); 125 for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) 126 new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); 127 for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { 128 new (op->getOutOfLineOpResult(i)) 129 detail::OutOfLineOpResult(*resultTypeIt, i); 130 } 131 132 // Initialize the regions. 133 for (unsigned i = 0; i != numRegions; ++i) 134 new (&op->getRegion(i)) Region(op); 135 136 // Initialize the operands. 137 if (needsOperandStorage) { 138 new (&op->getOperandStorage()) detail::OperandStorage( 139 op, op->getTrailingObjects<OpOperand>(), operands); 140 } 141 142 // Initialize the successors. 143 auto blockOperands = op->getBlockOperands(); 144 for (unsigned i = 0; i != numSuccessors; ++i) 145 new (&blockOperands[i]) BlockOperand(op, successors[i]); 146 147 // This must be done after properties are initalized. 148 op->setAttrs(attributes); 149 150 return op; 151 } 152 153 Operation::Operation(Location location, OperationName name, unsigned numResults, 154 unsigned numSuccessors, unsigned numRegions, 155 int fullPropertiesStorageSize, DictionaryAttr attributes, 156 OpaqueProperties properties, bool hasOperandStorage) 157 : location(location), numResults(numResults), numSuccs(numSuccessors), 158 numRegions(numRegions), hasOperandStorage(hasOperandStorage), 159 propertiesStorageSize((fullPropertiesStorageSize + 7) / 8), name(name) { 160 assert(attributes && "unexpected null attribute dictionary"); 161 assert(fullPropertiesStorageSize <= propertiesCapacity && 162 "Properties size overflow"); 163 #ifndef NDEBUG 164 if (!getDialect() && !getContext()->allowsUnregisteredDialects()) 165 llvm::report_fatal_error( 166 name.getStringRef() + 167 " created with unregistered dialect. 
If this is intended, please call " 168 "allowUnregisteredDialects() on the MLIRContext, or use " 169 "-allow-unregistered-dialect with the MLIR tool used."); 170 #endif 171 if (fullPropertiesStorageSize) 172 name.initOpProperties(getPropertiesStorage(), properties); 173 } 174 175 // Operations are deleted through the destroy() member because they are 176 // allocated via malloc. 177 Operation::~Operation() { 178 assert(block == nullptr && "operation destroyed but still in a block"); 179 #ifndef NDEBUG 180 if (!use_empty()) { 181 { 182 InFlightDiagnostic diag = 183 emitOpError("operation destroyed but still has uses"); 184 for (Operation *user : getUsers()) 185 diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; 186 } 187 llvm::report_fatal_error("operation destroyed but still has uses"); 188 } 189 #endif 190 // Explicitly run the destructors for the operands. 191 if (hasOperandStorage) 192 getOperandStorage().~OperandStorage(); 193 194 // Explicitly run the destructors for the successors. 195 for (auto &successor : getBlockOperands()) 196 successor.~BlockOperand(); 197 198 // Explicitly destroy the regions. 199 for (auto ®ion : getRegions()) 200 region.~Region(); 201 if (propertiesStorageSize) 202 name.destroyOpProperties(getPropertiesStorage()); 203 } 204 205 /// Destroy this operation or one of its subclasses. 206 void Operation::destroy() { 207 // Operations may have additional prefixed allocation, which needs to be 208 // accounted for here when computing the address to free. 209 char *rawMem = reinterpret_cast<char *>(this) - 210 llvm::alignTo(prefixAllocSize(), alignof(Operation)); 211 this->~Operation(); 212 free(rawMem); 213 } 214 215 /// Return true if this operation is a proper ancestor of the `other` 216 /// operation. 
217 bool Operation::isProperAncestor(Operation *other) { 218 while ((other = other->getParentOp())) 219 if (this == other) 220 return true; 221 return false; 222 } 223 224 /// Replace any uses of 'from' with 'to' within this operation. 225 void Operation::replaceUsesOfWith(Value from, Value to) { 226 if (from == to) 227 return; 228 for (auto &operand : getOpOperands()) 229 if (operand.get() == from) 230 operand.set(to); 231 } 232 233 /// Replace the current operands of this operation with the ones provided in 234 /// 'operands'. 235 void Operation::setOperands(ValueRange operands) { 236 if (LLVM_LIKELY(hasOperandStorage)) 237 return getOperandStorage().setOperands(this, operands); 238 assert(operands.empty() && "setting operands without an operand storage"); 239 } 240 241 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 242 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 243 /// than the range pointed to by 'start'+'length'. 244 void Operation::setOperands(unsigned start, unsigned length, 245 ValueRange operands) { 246 assert((start + length) <= getNumOperands() && 247 "invalid operand range specified"); 248 if (LLVM_LIKELY(hasOperandStorage)) 249 return getOperandStorage().setOperands(this, start, length, operands); 250 assert(operands.empty() && "setting operands without an operand storage"); 251 } 252 253 /// Insert the given operands into the operand list at the given 'index'. 
254 void Operation::insertOperands(unsigned index, ValueRange operands) { 255 if (LLVM_LIKELY(hasOperandStorage)) 256 return setOperands(index, /*length=*/0, operands); 257 assert(operands.empty() && "inserting operands without an operand storage"); 258 } 259 260 //===----------------------------------------------------------------------===// 261 // Diagnostics 262 //===----------------------------------------------------------------------===// 263 264 /// Emit an error about fatal conditions with this operation, reporting up to 265 /// any diagnostic handlers that may be listening. 266 InFlightDiagnostic Operation::emitError(const Twine &message) { 267 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 268 if (getContext()->shouldPrintOpOnDiagnostic()) { 269 diag.attachNote(getLoc()) 270 .append("see current operation: ") 271 .appendOp(*this, OpPrintingFlags().printGenericOpForm()); 272 } 273 return diag; 274 } 275 276 /// Emit a warning about this operation, reporting up to any diagnostic 277 /// handlers that may be listening. 278 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 279 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 280 if (getContext()->shouldPrintOpOnDiagnostic()) 281 diag.attachNote(getLoc()) << "see current operation: " << *this; 282 return diag; 283 } 284 285 /// Emit a remark about this operation, reporting up to any diagnostic 286 /// handlers that may be listening. 
InFlightDiagnostic Operation::emitRemark(const Twine &message) {
  InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message);
  // When the context requests it, attach the printed operation as a note.
  if (getContext()->shouldPrintOpOnDiagnostic())
    diag.attachNote(getLoc()) << "see current operation: " << *this;
  return diag;
}

/// Return the attribute dictionary of this operation. When the operation has
/// a non-zero properties storage size, the inherent attributes populated by
/// the operation name are added to a copy of `attrs` before building the
/// dictionary; otherwise `attrs` is returned as is.
DictionaryAttr Operation::getAttrDictionary() {
  if (getPropertiesStorageSize()) {
    NamedAttrList attrsList = attrs;
    getName().populateInherentAttrs(this, attrsList);
    return attrsList.getDictionary(getContext());
  }
  return attrs;
}

/// Replace the attributes of this operation with `newAttrs`, which must be
/// non-null. When the operation has properties storage, `attrs` is first
/// reset to an empty dictionary and every entry is then routed through
/// setAttr() individually.
void Operation::setAttrs(DictionaryAttr newAttrs) {
  assert(newAttrs && "expected valid attribute dictionary");
  if (getPropertiesStorageSize()) {
    attrs = DictionaryAttr::get(getContext(), {});
    for (const NamedAttribute &attr : newAttrs)
      setAttr(attr.getName(), attr.getValue());
    return;
  }
  attrs = newAttrs;
}

/// Same as above, but taking a list of named attributes instead of an
/// existing dictionary.
void Operation::setAttrs(ArrayRef<NamedAttribute> newAttrs) {
  if (getPropertiesStorageSize()) {
    // Clear the current attributes, then set each new one individually.
    setAttrs(DictionaryAttr::get(getContext(), {}));
    for (const NamedAttribute &attr : newAttrs)
      setAttr(attr.getName(), attr.getValue());
    return;
  }
  attrs = DictionaryAttr::get(getContext(), newAttrs);
}

/// Look up the inherent attribute `name` by delegating to the operation name.
std::optional<Attribute> Operation::getInherentAttr(StringRef name) {
  return getName().getInherentAttr(this, name);
}

/// Set the inherent attribute `name` to `value` by delegating to the
/// operation name.
void Operation::setInherentAttr(StringAttr name, Attribute value) {
  getName().setInherentAttr(this, name, value);
}

/// Return the properties of this operation as an attribute. If the operation
/// has no registered info, the properties storage is read as an `Attribute`.
Attribute Operation::getPropertiesAsAttribute() {
  std::optional<RegisteredOperationName> info = getRegisteredInfo();
  if (LLVM_UNLIKELY(!info))
    return *getPropertiesStorage().as<Attribute *>();
  return info->getOpPropertiesAsAttribute(this);
}

/// Set the properties of this operation from `attr`. If the operation has no
/// registered info, `attr` is written directly into the properties storage
/// and success is returned; otherwise the registered hook is invoked with
/// `diagnostic`.
LogicalResult
Operation::setPropertiesFromAttribute(Attribute attr,
                                      InFlightDiagnostic *diagnostic) {
  std::optional<RegisteredOperationName> info = getRegisteredInfo();
  if (LLVM_UNLIKELY(!info)) {
    *getPropertiesStorage().as<Attribute *>() = attr;
    return success();
  }
  return info->setOpPropertiesFromAttribute(this, attr, diagnostic);
}

/// Copy the properties `rhs` into the properties storage of this operation.
void Operation::copyProperties(OpaqueProperties rhs) {
  name.copyOpProperties(getPropertiesStorage(), rhs);
}

/// Return a hash of the properties of this operation, computed through the
/// operation name.
llvm::hash_code Operation::hashProperties() {
  return name.hashOpProperties(getPropertiesStorage());
}

//===----------------------------------------------------------------------===//
// Operation Ordering
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the ordering constants.
constexpr unsigned Operation::kInvalidOrderIdx;
constexpr unsigned Operation::kOrderStride;

/// Given an operation 'other' that is within the same parent block, return
/// whether the current operation is before 'other' in the operation list
/// of the parent block.
/// Note: This function has an average complexity of O(1), but worst case may
/// take O(N) where N is the number of operations within the parent block.
bool Operation::isBeforeInBlock(Operation *other) {
  assert(block && "Operations without parent blocks have no order.");
  assert(other && other->block == block &&
         "Expected other operation to have the same parent block.");
  // If the order of the block is already invalid, directly recompute the
  // order of the whole block.
  if (!block->isOpOrderValid()) {
    block->recomputeOpOrder();
  } else {
    // Update the order of either operation if necessary.
    updateOrderIfNecessary();
    other->updateOrderIfNecessary();
  }

  return orderIndex < other->orderIndex;
}

/// Update the order index of this operation if necessary, potentially
/// recomputing the order of the parent block.
void Operation::updateOrderIfNecessary() {
  assert(block && "expected valid parent");

  // If the order is valid for this operation there is nothing to do.
  if (hasValidOrder())
    return;
  Operation *blockFront = &block->front();
  Operation *blockBack = &block->back();

  // This method is expected to only be invoked on blocks with more than one
  // operation.
  assert(blockFront != blockBack && "expected more than one operation");

  // If the operation is at the end of the block.
  if (this == blockBack) {
    Operation *prevNode = getPrevNode();
    if (!prevNode->hasValidOrder())
      return block->recomputeOpOrder();

    // Add the stride to the previous operation.
    orderIndex = prevNode->orderIndex + kOrderStride;
    return;
  }

  // If this is the first operation try to use the next operation to compute the
  // ordering.
  if (this == blockFront) {
    Operation *nextNode = getNextNode();
    if (!nextNode->hasValidOrder())
      return block->recomputeOpOrder();
    // There is no order to give this operation.
    if (nextNode->orderIndex == 0)
      return block->recomputeOpOrder();

    // If we can't use the stride, just take the middle value left. This is safe
    // because we know there is at least one valid index to assign to.
    if (nextNode->orderIndex <= kOrderStride)
      orderIndex = (nextNode->orderIndex / 2);
    else
      orderIndex = kOrderStride;
    return;
  }

  // Otherwise, this operation is between two others. Place this operation in
  // the middle of the previous and next if possible.
  Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
  if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
    return block->recomputeOpOrder();
  unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;

  // Check to see if there is a valid order between the two. Adjacent indices
  // leave no free slot, so the whole block has to be renumbered.
  if (prevOrder + 1 == nextOrder)
    return block->recomputeOpOrder();
  orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
}

//===----------------------------------------------------------------------===//
// ilist_traits for Operation
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the node/value pointer accessors for the
// Operation ilist node options. Each one forwards to NodeAccess.

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(const_pointer n)
    -> const node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(const node_type *n)
    -> const_pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}

/// Operations in a list are deleted through Operation::destroy().
void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
  op->destroy();
}

/// Return the Block that contains this operation list. The block address is
/// recovered by subtracting the offset of the operation list member (obtained
/// through Block::getSublistAccess) from the address of this list.
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
  size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
  iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this));
  return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset);
}

/// This is a trait method invoked when an operation is added to a block. We
/// keep the block pointer up to date.
485 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 486 assert(!op->getBlock() && "already in an operation block!"); 487 op->block = getContainingBlock(); 488 489 // Invalidate the order on the operation. 490 op->orderIndex = Operation::kInvalidOrderIdx; 491 } 492 493 /// This is a trait method invoked when an operation is removed from a block. 494 /// We keep the block pointer up to date. 495 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 496 assert(op->block && "not already in an operation block!"); 497 op->block = nullptr; 498 } 499 500 /// This is a trait method invoked when an operation is moved from one block 501 /// to another. We keep the block pointer up to date. 502 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 503 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 504 Block *curParent = getContainingBlock(); 505 506 // Invalidate the ordering of the parent block. 507 curParent->invalidateOpOrder(); 508 509 // If we are transferring operations within the same block, the block 510 // pointer doesn't need to be updated. 511 if (curParent == otherList.getContainingBlock()) 512 return; 513 514 // Update the 'block' member of each operation. 515 for (; first != last; ++first) 516 first->block = curParent; 517 } 518 519 /// Remove this operation (and its descendants) from its Block and delete 520 /// all of them. 521 void Operation::erase() { 522 if (auto *parent = getBlock()) 523 parent->getOperations().erase(this); 524 else 525 destroy(); 526 } 527 528 /// Remove the operation from its parent block, but don't delete it. 529 void Operation::remove() { 530 if (Block *parent = getBlock()) 531 parent->getOperations().remove(this); 532 } 533 534 /// Unlink this operation from its current block and insert it right before 535 /// `existingOp` which may be in the same or another block in the same 536 /// function. 
537 void Operation::moveBefore(Operation *existingOp) { 538 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 539 } 540 541 /// Unlink this operation from its current basic block and insert it right 542 /// before `iterator` in the specified basic block. 543 void Operation::moveBefore(Block *block, 544 llvm::iplist<Operation>::iterator iterator) { 545 block->getOperations().splice(iterator, getBlock()->getOperations(), 546 getIterator()); 547 } 548 549 /// Unlink this operation from its current block and insert it right after 550 /// `existingOp` which may be in the same or another block in the same function. 551 void Operation::moveAfter(Operation *existingOp) { 552 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 553 } 554 555 /// Unlink this operation from its current block and insert it right after 556 /// `iterator` in the specified block. 557 void Operation::moveAfter(Block *block, 558 llvm::iplist<Operation>::iterator iterator) { 559 assert(iterator != block->end() && "cannot move after end of block"); 560 moveBefore(block, std::next(iterator)); 561 } 562 563 /// This drops all operand uses from this operation, which is an essential 564 /// step in breaking cyclic dependences between references when they are to 565 /// be deleted. 566 void Operation::dropAllReferences() { 567 for (auto &op : getOpOperands()) 568 op.drop(); 569 570 for (auto ®ion : getRegions()) 571 region.dropAllReferences(); 572 573 for (auto &dest : getBlockOperands()) 574 dest.drop(); 575 } 576 577 /// This drops all uses of any values defined by this operation or its nested 578 /// regions, wherever they are located. 
579 void Operation::dropAllDefinedValueUses() { 580 dropAllUses(); 581 582 for (auto ®ion : getRegions()) 583 for (auto &block : region) 584 block.dropAllDefinedValueUses(); 585 } 586 587 void Operation::setSuccessor(Block *block, unsigned index) { 588 assert(index < getNumSuccessors()); 589 getBlockOperands()[index].set(block); 590 } 591 592 /// Attempt to fold this operation using the Op's registered foldHook. 593 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 594 SmallVectorImpl<OpFoldResult> &results) { 595 // If we have a registered operation definition matching this one, use it to 596 // try to constant fold the operation. 597 if (succeeded(name.foldHook(this, operands, results))) 598 return success(); 599 600 // Otherwise, fall back on the dialect hook to handle it. 601 Dialect *dialect = getDialect(); 602 if (!dialect) 603 return failure(); 604 605 auto *interface = dyn_cast<DialectFoldInterface>(dialect); 606 if (!interface) 607 return failure(); 608 609 return interface->fold(this, operands, results); 610 } 611 612 /// Emit an error with the op name prefixed, like "'dim' op " which is 613 /// convenient for verifiers. 
614 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 615 return emitError() << "'" << getName() << "' op " << message; 616 } 617 618 //===----------------------------------------------------------------------===// 619 // Operation Cloning 620 //===----------------------------------------------------------------------===// 621 622 Operation::CloneOptions::CloneOptions() 623 : cloneRegionsFlag(false), cloneOperandsFlag(false) {} 624 625 Operation::CloneOptions::CloneOptions(bool cloneRegions, bool cloneOperands) 626 : cloneRegionsFlag(cloneRegions), cloneOperandsFlag(cloneOperands) {} 627 628 Operation::CloneOptions Operation::CloneOptions::all() { 629 return CloneOptions().cloneRegions().cloneOperands(); 630 } 631 632 Operation::CloneOptions &Operation::CloneOptions::cloneRegions(bool enable) { 633 cloneRegionsFlag = enable; 634 return *this; 635 } 636 637 Operation::CloneOptions &Operation::CloneOptions::cloneOperands(bool enable) { 638 cloneOperandsFlag = enable; 639 return *this; 640 } 641 642 /// Create a deep copy of this operation but keep the operation regions empty. 643 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 644 /// to contain the results. The `mapResults` flag specifies whether the results 645 /// of the cloned operation should be added to the map. 646 Operation *Operation::cloneWithoutRegions(IRMapping &mapper) { 647 return clone(mapper, CloneOptions::all().cloneRegions(false)); 648 } 649 650 Operation *Operation::cloneWithoutRegions() { 651 IRMapping mapper; 652 return cloneWithoutRegions(mapper); 653 } 654 655 /// Create a deep copy of this operation, remapping any operands that use 656 /// values outside of the operation using the map that is provided (leaving 657 /// them alone if no entry is present). Replaces references to cloned 658 /// sub-operations to the corresponding operation that is copied, and adds 659 /// those mappings to the map. 
660 Operation *Operation::clone(IRMapping &mapper, CloneOptions options) { 661 SmallVector<Value, 8> operands; 662 SmallVector<Block *, 2> successors; 663 664 // Remap the operands. 665 if (options.shouldCloneOperands()) { 666 operands.reserve(getNumOperands()); 667 for (auto opValue : getOperands()) 668 operands.push_back(mapper.lookupOrDefault(opValue)); 669 } 670 671 // Remap the successors. 672 successors.reserve(getNumSuccessors()); 673 for (Block *successor : getSuccessors()) 674 successors.push_back(mapper.lookupOrDefault(successor)); 675 676 // Create the new operation. 677 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 678 getPropertiesStorage(), successors, getNumRegions()); 679 mapper.map(this, newOp); 680 681 // Clone the regions. 682 if (options.shouldCloneRegions()) { 683 for (unsigned i = 0; i != numRegions; ++i) 684 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 685 } 686 687 // Remember the mapping of any results. 688 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 689 mapper.map(getResult(i), newOp->getResult(i)); 690 691 return newOp; 692 } 693 694 Operation *Operation::clone(CloneOptions options) { 695 IRMapping mapper; 696 return clone(mapper, options); 697 } 698 699 //===----------------------------------------------------------------------===// 700 // OpState trait class. 701 //===----------------------------------------------------------------------===// 702 703 // The fallback for the parser is to try for a dialect operation parser. 704 // Otherwise, reject the custom assembly form. 705 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 706 if (auto parseFn = result.name.getDialect()->getParseOperationHook( 707 result.name.getStringRef())) 708 return (*parseFn)(parser, result); 709 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 710 } 711 712 // The fallback for the printer is to try for a dialect operation printer. 
713 // Otherwise, it prints the generic form. 714 void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { 715 if (auto printFn = op->getDialect()->getOperationPrinter(op)) { 716 printOpName(op, p, defaultDialect); 717 printFn(op, p); 718 } else { 719 p.printGenericOp(op); 720 } 721 } 722 723 /// Print an operation name, eliding the dialect prefix if necessary and doesn't 724 /// lead to ambiguities. 725 void OpState::printOpName(Operation *op, OpAsmPrinter &p, 726 StringRef defaultDialect) { 727 StringRef name = op->getName().getStringRef(); 728 if (name.startswith((defaultDialect + ".").str()) && name.count('.') == 1) 729 name = name.drop_front(defaultDialect.size() + 1); 730 p.getStream() << name; 731 } 732 733 /// Parse properties as a Attribute. 734 ParseResult OpState::genericParseProperties(OpAsmParser &parser, 735 Attribute &result) { 736 if (parser.parseLess() || parser.parseAttribute(result) || 737 parser.parseGreater()) 738 return failure(); 739 return success(); 740 } 741 742 /// Print the properties as a Attribute. 743 void OpState::genericPrintProperties(OpAsmPrinter &p, Attribute properties) { 744 p << "<" << properties << ">"; 745 } 746 747 /// Emit an error about fatal conditions with this operation, reporting up to 748 /// any diagnostic handlers that may be listening. 749 InFlightDiagnostic OpState::emitError(const Twine &message) { 750 return getOperation()->emitError(message); 751 } 752 753 /// Emit an error with the op name prefixed, like "'dim' op " which is 754 /// convenient for verifiers. 755 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 756 return getOperation()->emitOpError(message); 757 } 758 759 /// Emit a warning about this operation, reporting up to any diagnostic 760 /// handlers that may be listening. 
InFlightDiagnostic OpState::emitWarning(const Twine &message) {
  return getOperation()->emitWarning(message);
}

/// Emit a remark about this operation, reporting up to any diagnostic
/// handlers that may be listening.
InFlightDiagnostic OpState::emitRemark(const Twine &message) {
  return getOperation()->emitRemark(message);
}

//===----------------------------------------------------------------------===//
// Op Trait implementations
//===----------------------------------------------------------------------===//

/// Fold helper for idempotent operations. With exactly one operand, the
/// operand is returned when it is defined by an operation with the same name
/// as `op`. Otherwise, operand 0 is returned when operands 0 and 1 are the
/// same value. An empty result is returned when neither case applies.
OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
  if (op->getNumOperands() == 1) {
    auto *argumentOp = op->getOperand(0).getDefiningOp();
    if (argumentOp && op->getName() == argumentOp->getName()) {
      // Replace the outer operation output with the inner operation.
      return op->getOperand(0);
    }
    // NOTE(review): the branch below reads operands 0 and 1 for any operand
    // count other than one; presumably callers guarantee exactly two operands
    // there -- confirm.
  } else if (op->getOperand(0) == op->getOperand(1)) {
    return op->getOperand(0);
  }

  return {};
}

/// Fold helper for involutions. When operand 0 of `op` is defined by an
/// operation with the same name, operand 0 of that defining operation is
/// returned; otherwise an empty result is returned.
OpFoldResult OpTrait::impl::foldInvolution(Operation *op) {
  auto *argumentOp = op->getOperand(0).getDefiningOp();
  if (argumentOp && op->getName() == argumentOp->getName()) {
    // Replace the outer involutions output with inner's input.
    return argumentOp->getOperand(0);
  }

  return {};
}

/// Verify that `op` has no operands.
LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) {
  if (op->getNumOperands() != 0)
    return op->emitOpError() << "requires zero operands";
  return success();
}

/// Verify that `op` has exactly one operand.
LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) {
  if (op->getNumOperands() != 1)
    return op->emitOpError() << "requires a single operand";
  return success();
}

/// Verify that `op` has exactly `numOperands` operands.
LogicalResult OpTrait::impl::verifyNOperands(Operation *op,
                                             unsigned numOperands) {
  if (op->getNumOperands() != numOperands) {
    return op->emitOpError() << "expected " << numOperands
                             << " operands, but found " << op->getNumOperands();
  }
  return success();
}

/// Verify that `op` has at least `numOperands` operands.
LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op,
                                                    unsigned numOperands) {
  if (op->getNumOperands() < numOperands)
    return op->emitOpError()
           << "expected " << numOperands << " or more operands, but found "
           << op->getNumOperands();
  return success();
}

/// If this is a vector type, or a tensor type, return the scalar element type
/// that it is built around, otherwise return the type unmodified.
static Type getTensorOrVectorElementType(Type type) {
  if (auto vec = llvm::dyn_cast<VectorType>(type))
    return vec.getElementType();

  // Look through tensor<vector<...>> to find the underlying element type.
  if (auto tensor = llvm::dyn_cast<TensorType>(type))
    return getTensorOrVectorElementType(tensor.getElementType());
  return type;
}

/// Verifier for idempotent operations; currently always succeeds.
LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) {
  // FIXME: Add back check for no side effects on operation.
  // Currently adding it would cause the shared library build
  // to fail since there would be a dependency of IR on SideEffectInterfaces
  // which is cyclical.
  return success();
}

/// Verifier for involutions; currently always succeeds.
LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) {
  // FIXME: Add back check for no side effects on operation.
  // Currently adding it would cause the shared library build
  // to fail since there would be a dependency of IR on SideEffectInterfaces
  // which is cyclical.
  return success();
}

/// Verify that the element type of every operand of `op` (looking through
/// vector and tensor types) is a signless integer or index type.
LogicalResult
OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) {
  for (auto opType : op->getOperandTypes()) {
    auto type = getTensorOrVectorElementType(opType);
    if (!type.isSignlessIntOrIndex())
      return op->emitOpError() << "requires an integer or index type";
  }
  return success();
}

/// Verify that the element type of every operand of `op` (looking through
/// vector and tensor types) is a float type.
LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) {
  for (auto opType : op->getOperandTypes()) {
    auto type = getTensorOrVectorElementType(opType);
    if (!llvm::isa<FloatType>(type))
      return op->emitOpError("requires a float type");
  }
  return success();
}

/// Verify that all operands of `op` have the same type as operand 0.
LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) {
  // Zero or one operand always have the "same" type.
  unsigned nOperands = op->getNumOperands();
  if (nOperands < 2)
    return success();

  // Compare every remaining operand type against the type of the first one.
  auto type = op->getOperand(0).getType();
  for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
    if (opType != type)
      return op->emitOpError() << "requires all operands to have the same type";
  return success();
}

/// Verify that `op` has no regions.
LogicalResult OpTrait::impl::verifyZeroRegions(Operation *op) {
  if (op->getNumRegions() != 0)
    return op->emitOpError() << "requires zero regions";
  return success();
}

/// Verify that `op` has exactly one region.
LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) {
  if (op->getNumRegions() != 1)
    return op->emitOpError() << "requires one region";
  return success();
}

/// Verify that `op` has exactly `numRegions` regions.
LogicalResult OpTrait::impl::verifyNRegions(Operation *op,
                                            unsigned numRegions) {
  if (op->getNumRegions() != numRegions)
    return op->emitOpError() << "expected " << numRegions << " regions";
  return success();
}

/// Verify that `op` has at least `numRegions` regions.
LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op,
                                                   unsigned numRegions) {
  if (op->getNumRegions() < numRegions)
    return op->emitOpError() << "expected " << numRegions << " or more regions";
  return success();
}

/// Verify that `op` has no results.
LogicalResult OpTrait::impl::verifyZeroResults(Operation *op) {
  if (op->getNumResults() != 0)
    return op->emitOpError() << "requires zero results";
  return success();
}

/// Verify that `op` has exactly one result.
LogicalResult OpTrait::impl::verifyOneResult(Operation *op) {
  if (op->getNumResults() != 1)
    return op->emitOpError() << "requires one result";
  return success();
}

/// Verify that `op` has exactly `numOperands` results. Despite its name, the
/// parameter is compared against the number of results.
LogicalResult OpTrait::impl::verifyNResults(Operation *op,
                                            unsigned numOperands) {
  if (op->getNumResults() != numOperands)
    return op->emitOpError() << "expected " << numOperands << " results";
  return success();
}

LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op,
                                                   unsigned numOperands) {
  if (op->getNumResults() < numOperands)
    return op->emitOpError()
938 << "expected " << numOperands << " or more results"; 939 return success(); 940 } 941 942 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 943 if (failed(verifyAtLeastNOperands(op, 1))) 944 return failure(); 945 946 if (failed(verifyCompatibleShapes(op->getOperandTypes()))) 947 return op->emitOpError() << "requires the same shape for all operands"; 948 949 return success(); 950 } 951 952 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 953 if (failed(verifyAtLeastNOperands(op, 1)) || 954 failed(verifyAtLeastNResults(op, 1))) 955 return failure(); 956 957 SmallVector<Type, 8> types(op->getOperandTypes()); 958 types.append(llvm::to_vector<4>(op->getResultTypes())); 959 960 if (failed(verifyCompatibleShapes(types))) 961 return op->emitOpError() 962 << "requires the same shape for all operands and results"; 963 964 return success(); 965 } 966 967 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 968 if (failed(verifyAtLeastNOperands(op, 1))) 969 return failure(); 970 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 971 972 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 973 if (getElementTypeOrSelf(operand) != elementType) 974 return op->emitOpError("requires the same element type for all operands"); 975 } 976 977 return success(); 978 } 979 980 LogicalResult 981 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 982 if (failed(verifyAtLeastNOperands(op, 1)) || 983 failed(verifyAtLeastNResults(op, 1))) 984 return failure(); 985 986 auto elementType = getElementTypeOrSelf(op->getResult(0)); 987 988 // Verify result element type matches first result's element type. 
989 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 990 if (getElementTypeOrSelf(result) != elementType) 991 return op->emitOpError( 992 "requires the same element type for all operands and results"); 993 } 994 995 // Verify operand's element type matches first result's element type. 996 for (auto operand : op->getOperands()) { 997 if (getElementTypeOrSelf(operand) != elementType) 998 return op->emitOpError( 999 "requires the same element type for all operands and results"); 1000 } 1001 1002 return success(); 1003 } 1004 1005 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 1006 if (failed(verifyAtLeastNOperands(op, 1)) || 1007 failed(verifyAtLeastNResults(op, 1))) 1008 return failure(); 1009 1010 auto type = op->getResult(0).getType(); 1011 auto elementType = getElementTypeOrSelf(type); 1012 Attribute encoding = nullptr; 1013 if (auto rankedType = dyn_cast<RankedTensorType>(type)) 1014 encoding = rankedType.getEncoding(); 1015 for (auto resultType : llvm::drop_begin(op->getResultTypes())) { 1016 if (getElementTypeOrSelf(resultType) != elementType || 1017 failed(verifyCompatibleShape(resultType, type))) 1018 return op->emitOpError() 1019 << "requires the same type for all operands and results"; 1020 if (encoding) 1021 if (auto rankedType = dyn_cast<RankedTensorType>(resultType); 1022 encoding != rankedType.getEncoding()) 1023 return op->emitOpError() 1024 << "requires the same encoding for all operands and results"; 1025 } 1026 for (auto opType : op->getOperandTypes()) { 1027 if (getElementTypeOrSelf(opType) != elementType || 1028 failed(verifyCompatibleShape(opType, type))) 1029 return op->emitOpError() 1030 << "requires the same type for all operands and results"; 1031 if (encoding) 1032 if (auto rankedType = dyn_cast<RankedTensorType>(opType); 1033 encoding != rankedType.getEncoding()) 1034 return op->emitOpError() 1035 << "requires the same encoding for all operands and results"; 1036 } 1037 return success(); 1038 } 1039 
1040 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 1041 Block *block = op->getBlock(); 1042 // Verify that the operation is at the end of the respective parent block. 1043 if (!block || &block->back() != op) 1044 return op->emitOpError("must be the last operation in the parent block"); 1045 return success(); 1046 } 1047 1048 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 1049 auto *parent = op->getParentRegion(); 1050 1051 // Verify that the operands lines up with the BB arguments in the successor. 1052 for (Block *succ : op->getSuccessors()) 1053 if (succ->getParent() != parent) 1054 return op->emitError("reference to block defined in another region"); 1055 return success(); 1056 } 1057 1058 LogicalResult OpTrait::impl::verifyZeroSuccessors(Operation *op) { 1059 if (op->getNumSuccessors() != 0) { 1060 return op->emitOpError("requires 0 successors but found ") 1061 << op->getNumSuccessors(); 1062 } 1063 return success(); 1064 } 1065 1066 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 1067 if (op->getNumSuccessors() != 1) { 1068 return op->emitOpError("requires 1 successor but found ") 1069 << op->getNumSuccessors(); 1070 } 1071 return verifyTerminatorSuccessors(op); 1072 } 1073 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 1074 unsigned numSuccessors) { 1075 if (op->getNumSuccessors() != numSuccessors) { 1076 return op->emitOpError("requires ") 1077 << numSuccessors << " successors but found " 1078 << op->getNumSuccessors(); 1079 } 1080 return verifyTerminatorSuccessors(op); 1081 } 1082 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 1083 unsigned numSuccessors) { 1084 if (op->getNumSuccessors() < numSuccessors) { 1085 return op->emitOpError("requires at least ") 1086 << numSuccessors << " successors but found " 1087 << op->getNumSuccessors(); 1088 } 1089 return verifyTerminatorSuccessors(op); 1090 } 1091 1092 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation 
*op) { 1093 for (auto resultType : op->getResultTypes()) { 1094 auto elementType = getTensorOrVectorElementType(resultType); 1095 bool isBoolType = elementType.isInteger(1); 1096 if (!isBoolType) 1097 return op->emitOpError() << "requires a bool result type"; 1098 } 1099 1100 return success(); 1101 } 1102 1103 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 1104 for (auto resultType : op->getResultTypes()) 1105 if (!llvm::isa<FloatType>(getTensorOrVectorElementType(resultType))) 1106 return op->emitOpError() << "requires a floating point type"; 1107 1108 return success(); 1109 } 1110 1111 LogicalResult 1112 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 1113 for (auto resultType : op->getResultTypes()) 1114 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 1115 return op->emitOpError() << "requires an integer or index type"; 1116 return success(); 1117 } 1118 1119 LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, 1120 StringRef attrName, 1121 StringRef valueGroupName, 1122 size_t expectedCount) { 1123 auto sizeAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrName); 1124 if (!sizeAttr) 1125 return op->emitOpError("requires dense i32 array attribute '") 1126 << attrName << "'"; 1127 1128 ArrayRef<int32_t> sizes = sizeAttr.asArrayRef(); 1129 if (llvm::any_of(sizes, [](int32_t element) { return element < 0; })) 1130 return op->emitOpError("'") 1131 << attrName << "' attribute cannot have negative elements"; 1132 1133 size_t totalCount = 1134 std::accumulate(sizes.begin(), sizes.end(), 0, 1135 [](unsigned all, int32_t one) { return all + one; }); 1136 1137 if (totalCount != expectedCount) 1138 return op->emitOpError() 1139 << valueGroupName << " count (" << expectedCount 1140 << ") does not match with the total size (" << totalCount 1141 << ") specified in attribute '" << attrName << "'"; 1142 return success(); 1143 } 1144 1145 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 
1146 StringRef attrName) { 1147 return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); 1148 } 1149 1150 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, 1151 StringRef attrName) { 1152 return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); 1153 } 1154 1155 LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1156 for (Region ®ion : op->getRegions()) { 1157 if (region.empty()) 1158 continue; 1159 1160 if (region.getNumArguments() != 0) { 1161 if (op->getNumRegions() > 1) 1162 return op->emitOpError("region #") 1163 << region.getRegionNumber() << " should have no arguments"; 1164 return op->emitOpError("region should have no arguments"); 1165 } 1166 } 1167 return success(); 1168 } 1169 1170 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { 1171 auto isMappableType = [](Type type) { 1172 return llvm::isa<VectorType, TensorType>(type); 1173 }; 1174 auto resultMappableTypes = llvm::to_vector<1>( 1175 llvm::make_filter_range(op->getResultTypes(), isMappableType)); 1176 auto operandMappableTypes = llvm::to_vector<2>( 1177 llvm::make_filter_range(op->getOperandTypes(), isMappableType)); 1178 1179 // If the op only has scalar operand/result types, then we have nothing to 1180 // check. 
1181 if (resultMappableTypes.empty() && operandMappableTypes.empty()) 1182 return success(); 1183 1184 if (!resultMappableTypes.empty() && operandMappableTypes.empty()) 1185 return op->emitOpError("if a result is non-scalar, then at least one " 1186 "operand must be non-scalar"); 1187 1188 assert(!operandMappableTypes.empty()); 1189 1190 if (resultMappableTypes.empty()) 1191 return op->emitOpError("if an operand is non-scalar, then there must be at " 1192 "least one non-scalar result"); 1193 1194 if (resultMappableTypes.size() != op->getNumResults()) 1195 return op->emitOpError( 1196 "if an operand is non-scalar, then all results must be non-scalar"); 1197 1198 SmallVector<Type, 4> types = llvm::to_vector<2>( 1199 llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); 1200 TypeID expectedBaseTy = types.front().getTypeID(); 1201 if (!llvm::all_of(types, 1202 [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || 1203 failed(verifyCompatibleShapes(types))) { 1204 return op->emitOpError() << "all non-scalar operands/results must have the " 1205 "same shape and base type"; 1206 } 1207 1208 return success(); 1209 } 1210 1211 /// Check for any values used by operations regions attached to the 1212 /// specified "IsIsolatedFromAbove" operation defined outside of it. 1213 LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { 1214 assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1215 "Intended to check IsolatedFromAbove ops"); 1216 1217 // List of regions to analyze. Each region is processed independently, with 1218 // respect to the common `limit` region, so we can look at them in any order. 1219 // Therefore, use a simple vector and push/pop back the current region. 1220 SmallVector<Region *, 8> pendingRegions; 1221 for (auto ®ion : isolatedOp->getRegions()) { 1222 pendingRegions.push_back(®ion); 1223 1224 // Traverse all operations in the region. 
1225 while (!pendingRegions.empty()) { 1226 for (Operation &op : pendingRegions.pop_back_val()->getOps()) { 1227 for (Value operand : op.getOperands()) { 1228 // Check that any value that is used by an operation is defined in the 1229 // same region as either an operation result. 1230 auto *operandRegion = operand.getParentRegion(); 1231 if (!operandRegion) 1232 return op.emitError("operation's operand is unlinked"); 1233 if (!region.isAncestor(operandRegion)) { 1234 return op.emitOpError("using value defined outside the region") 1235 .attachNote(isolatedOp->getLoc()) 1236 << "required by region isolation constraints"; 1237 } 1238 } 1239 1240 // Schedule any regions in the operation for further checking. Don't 1241 // recurse into other IsolatedFromAbove ops, because they will check 1242 // themselves. 1243 if (op.getNumRegions() && 1244 !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { 1245 for (Region &subRegion : op.getRegions()) 1246 pendingRegions.push_back(&subRegion); 1247 } 1248 } 1249 } 1250 } 1251 1252 return success(); 1253 } 1254 1255 bool OpTrait::hasElementwiseMappableTraits(Operation *op) { 1256 return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && 1257 op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); 1258 } 1259 1260 //===----------------------------------------------------------------------===// 1261 // CastOpInterface 1262 //===----------------------------------------------------------------------===// 1263 1264 /// Attempt to fold the given cast operation. 1265 LogicalResult 1266 impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, 1267 SmallVectorImpl<OpFoldResult> &foldResults) { 1268 OperandRange operands = op->getOperands(); 1269 if (operands.empty()) 1270 return failure(); 1271 ResultRange results = op->getResults(); 1272 1273 // Check for the case where the input and output types match 1-1. 
1274 if (operands.getTypes() == results.getTypes()) { 1275 foldResults.append(operands.begin(), operands.end()); 1276 return success(); 1277 } 1278 1279 return failure(); 1280 } 1281 1282 /// Attempt to verify the given cast operation. 1283 LogicalResult impl::verifyCastInterfaceOp( 1284 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { 1285 auto resultTypes = op->getResultTypes(); 1286 if (resultTypes.empty()) 1287 return op->emitOpError() 1288 << "expected at least one result for cast operation"; 1289 1290 auto operandTypes = op->getOperandTypes(); 1291 if (!areCastCompatible(operandTypes, resultTypes)) { 1292 InFlightDiagnostic diag = op->emitOpError("operand type"); 1293 if (operandTypes.empty()) 1294 diag << "s []"; 1295 else if (llvm::size(operandTypes) == 1) 1296 diag << " " << *operandTypes.begin(); 1297 else 1298 diag << "s " << operandTypes; 1299 return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ") 1300 << resultTypes << " are cast incompatible"; 1301 } 1302 1303 return success(); 1304 } 1305 1306 //===----------------------------------------------------------------------===// 1307 // Misc. utils 1308 //===----------------------------------------------------------------------===// 1309 1310 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1311 /// region's only block if it does not have a terminator already. If the region 1312 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1313 /// terminator operation to insert. 
1314 void impl::ensureRegionTerminator( 1315 Region ®ion, OpBuilder &builder, Location loc, 1316 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1317 OpBuilder::InsertionGuard guard(builder); 1318 if (region.empty()) 1319 builder.createBlock(®ion); 1320 1321 Block &block = region.back(); 1322 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) 1323 return; 1324 1325 builder.setInsertionPointToEnd(&block); 1326 builder.insert(buildTerminatorOp(builder, loc)); 1327 } 1328 1329 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1330 /// function. 1331 void impl::ensureRegionTerminator( 1332 Region ®ion, Builder &builder, Location loc, 1333 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1334 OpBuilder opBuilder(builder.getContext()); 1335 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1336 } 1337