1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/Attributes.h" 11 #include "mlir/IR/BuiltinAttributes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IRMapping.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/IR/OpImplementation.h" 17 #include "mlir/IR/OperationSupport.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Interfaces/FoldInterfaces.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include <numeric> 24 #include <optional> 25 26 using namespace mlir; 27 28 //===----------------------------------------------------------------------===// 29 // Operation 30 //===----------------------------------------------------------------------===// 31 32 /// Create a new Operation from operation state. 33 Operation *Operation::create(const OperationState &state) { 34 Operation *op = 35 create(state.location, state.name, state.types, state.operands, 36 state.attributes.getDictionary(state.getContext()), 37 state.properties, state.successors, state.regions); 38 if (LLVM_UNLIKELY(state.propertiesAttr)) { 39 assert(!state.properties); 40 LogicalResult result = 41 op->setPropertiesFromAttribute(state.propertiesAttr, 42 /*diagnostic=*/nullptr); 43 assert(result.succeeded() && "invalid properties in op creation"); 44 (void)result; 45 } 46 return op; 47 } 48 49 /// Create a new Operation with the specific fields. 
Operation *Operation::create(Location location, OperationName name,
                             TypeRange resultTypes, ValueRange operands,
                             NamedAttrList &&attributes,
                             OpaqueProperties properties, BlockRange successors,
                             RegionRange regions) {
  unsigned numRegions = regions.size();
  Operation *op =
      create(location, name, resultTypes, operands, std::move(attributes),
             properties, successors, numRegions);
  // The new operation starts with `numRegions` empty regions. Move the body of
  // each provided (non-null) region into the matching slot; null entries leave
  // the corresponding region empty.
  for (unsigned i = 0; i < numRegions; ++i)
    if (regions[i])
      op->getRegion(i).takeBody(*regions[i]);
  return op;
}

/// Create a new Operation with the specific fields.
///
/// `name.populateDefaultAttrs` is first given the chance to add entries to
/// `attributes`. The resulting list is then uniqued into a DictionaryAttr and
/// forwarded to the DictionaryAttr overload. All `numRegions` regions start
/// empty.
Operation *Operation::create(Location location, OperationName name,
                             TypeRange resultTypes, ValueRange operands,
                             NamedAttrList &&attributes,
                             OpaqueProperties properties, BlockRange successors,
                             unsigned numRegions) {
  // Populate default attributes.
  name.populateDefaultAttrs(attributes);

  return create(location, name, resultTypes, operands,
                attributes.getDictionary(location.getContext()), properties,
                successors, numRegions);
}

/// Overload of create that takes an existing DictionaryAttr to avoid
/// unnecessarily uniquing a list of attributes.
///
/// This overload performs the actual allocation. A single malloc'd buffer
/// holds a prefix sized by `prefixAllocSize` for the results, then the
/// Operation object, then its trailing objects: an optional OperandStorage,
/// the properties storage, the BlockOperands, the Regions and the OpOperands.
/// Because of this layout the operation is released through `destroy()`.
Operation *Operation::create(Location location, OperationName name,
                             TypeRange resultTypes, ValueRange operands,
                             DictionaryAttr attributes,
                             OpaqueProperties properties, BlockRange successors,
                             unsigned numRegions) {
  assert(llvm::all_of(resultTypes, [](Type t) { return t; }) &&
         "unexpected null result type");

  // We only need to allocate additional memory for a subset of results.
  unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size());
  unsigned numInlineResults = OpResult::getNumInline(resultTypes.size());
  unsigned numSuccessors = successors.size();
  unsigned numOperands = operands.size();
  unsigned numResults = resultTypes.size();
  // The properties storage size is rounded up to a multiple of 8 bytes.
  int opPropertiesAllocSize = llvm::alignTo<8>(name.getOpPropertyByteSize());

  // If the operation is known to have no operands, don't allocate an operand
  // storage.
  bool needsOperandStorage =
      operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true;

  // Compute the byte size for the operation and the operand storage. This takes
  // into account the size of the operation, its trailing objects, and its
  // prefixed objects.
  size_t byteSize =
      totalSizeToAlloc<detail::OperandStorage, detail::OpProperties,
                       BlockOperand, Region, OpOperand>(
          needsOperandStorage ? 1 : 0, opPropertiesAllocSize, numSuccessors,
          numRegions, numOperands);
  size_t prefixByteSize = llvm::alignTo(
      Operation::prefixAllocSize(numTrailingResults, numInlineResults),
      alignof(Operation));
  // NOTE(review): the malloc result is not checked for null before use.
  char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize));
  // The Operation object itself starts right after the aligned prefix.
  void *rawMem = mallocMem + prefixByteSize;

  // Create the new Operation.
  Operation *op = ::new (rawMem) Operation(
      location, name, numResults, numSuccessors, numRegions,
      opPropertiesAllocSize, attributes, properties, needsOperandStorage);

  assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) &&
         "unexpected successors in a non-terminator operation");

  // Initialize the results: the first `numInlineResults` are constructed as
  // inline results, the remaining `numTrailingResults` as out-of-line results.
  auto resultTypeIt = resultTypes.begin();
  for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt)
    new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i);
  for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) {
    new (op->getOutOfLineOpResult(i))
        detail::OutOfLineOpResult(*resultTypeIt, i);
  }

  // Initialize the regions.
  for (unsigned i = 0; i != numRegions; ++i)
    new (&op->getRegion(i)) Region(op);

  // Initialize the operands.
  if (needsOperandStorage) {
    new (&op->getOperandStorage()) detail::OperandStorage(
        op, op->getTrailingObjects<OpOperand>(), operands);
  }

  // Initialize the successors.
  auto blockOperands = op->getBlockOperands();
  for (unsigned i = 0; i != numSuccessors; ++i)
    new (&blockOperands[i]) BlockOperand(op, successors[i]);

  // This must be done after properties are initialized (in the constructor):
  // setAttrs may route inherent attributes into the properties storage.
  op->setAttrs(attributes);

  return op;
}

/// Construct the Operation object inside memory already allocated by
/// `create`. Apart from the scalar fields, only the properties storage is
/// initialized here; `create` placement-constructs the results, regions,
/// operands and successors afterwards.
Operation::Operation(Location location, OperationName name, unsigned numResults,
                     unsigned numSuccessors, unsigned numRegions,
                     int fullPropertiesStorageSize, DictionaryAttr attributes,
                     OpaqueProperties properties, bool hasOperandStorage)
    : location(location), numResults(numResults), numSuccs(numSuccessors),
      numRegions(numRegions), hasOperandStorage(hasOperandStorage),
      // Stored in units of 8 bytes (rounded up).
      propertiesStorageSize((fullPropertiesStorageSize + 7) / 8), name(name) {
  assert(attributes && "unexpected null attribute dictionary");
  assert(fullPropertiesStorageSize <= propertiesCapacity &&
         "Properties size overflow");
#ifndef NDEBUG
  // Debug builds refuse ops from unregistered dialects unless the context
  // explicitly allows them.
  if (!getDialect() && !getContext()->allowsUnregisteredDialects())
    llvm::report_fatal_error(
        name.getStringRef() +
        " created with unregistered dialect. If this is intended, please call "
        "allowUnregisteredDialects() on the MLIRContext, or use "
        "-allow-unregistered-dialect with the MLIR tool used.");
#endif
  if (fullPropertiesStorageSize)
    name.initOpProperties(getPropertiesStorage(), properties);
}

// Operations are deleted through the destroy() member because they are
// allocated via malloc.
Operation::~Operation() {
  assert(block == nullptr && "operation destroyed but still in a block");
#ifndef NDEBUG
  if (!use_empty()) {
    {
      // Scope the diagnostic so it is emitted before the fatal error below.
      InFlightDiagnostic diag =
          emitOpError("operation destroyed but still has uses");
      for (Operation *user : getUsers())
        diag.attachNote(user->getLoc()) << "- use: " << *user << "\n";
    }
    llvm::report_fatal_error("operation destroyed but still has uses");
  }
#endif
  // Explicitly run the destructors for the operands.
  if (hasOperandStorage)
    getOperandStorage().~OperandStorage();

  // Explicitly run the destructors for the successors.
  for (auto &successor : getBlockOperands())
    successor.~BlockOperand();

  // Explicitly destroy the regions.
  for (auto &region : getRegions())
    region.~Region();
  if (propertiesStorageSize)
    name.destroyOpProperties(getPropertiesStorage());
}

/// Destroy this operation or one of its subclasses.
void Operation::destroy() {
  // Operations may have additional prefixed allocation, which needs to be
  // accounted for here when computing the address to free.
  char *rawMem = reinterpret_cast<char *>(this) -
                 llvm::alignTo(prefixAllocSize(), alignof(Operation));
  this->~Operation();
  free(rawMem);
}

/// Return true if this operation is a proper ancestor of the `other`
/// operation.
218 bool Operation::isProperAncestor(Operation *other) { 219 while ((other = other->getParentOp())) 220 if (this == other) 221 return true; 222 return false; 223 } 224 225 /// Replace any uses of 'from' with 'to' within this operation. 226 void Operation::replaceUsesOfWith(Value from, Value to) { 227 if (from == to) 228 return; 229 for (auto &operand : getOpOperands()) 230 if (operand.get() == from) 231 operand.set(to); 232 } 233 234 /// Replace the current operands of this operation with the ones provided in 235 /// 'operands'. 236 void Operation::setOperands(ValueRange operands) { 237 if (LLVM_LIKELY(hasOperandStorage)) 238 return getOperandStorage().setOperands(this, operands); 239 assert(operands.empty() && "setting operands without an operand storage"); 240 } 241 242 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 243 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 244 /// than the range pointed to by 'start'+'length'. 245 void Operation::setOperands(unsigned start, unsigned length, 246 ValueRange operands) { 247 assert((start + length) <= getNumOperands() && 248 "invalid operand range specified"); 249 if (LLVM_LIKELY(hasOperandStorage)) 250 return getOperandStorage().setOperands(this, start, length, operands); 251 assert(operands.empty() && "setting operands without an operand storage"); 252 } 253 254 /// Insert the given operands into the operand list at the given 'index'. 
255 void Operation::insertOperands(unsigned index, ValueRange operands) { 256 if (LLVM_LIKELY(hasOperandStorage)) 257 return setOperands(index, /*length=*/0, operands); 258 assert(operands.empty() && "inserting operands without an operand storage"); 259 } 260 261 //===----------------------------------------------------------------------===// 262 // Diagnostics 263 //===----------------------------------------------------------------------===// 264 265 /// Emit an error about fatal conditions with this operation, reporting up to 266 /// any diagnostic handlers that may be listening. 267 InFlightDiagnostic Operation::emitError(const Twine &message) { 268 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 269 if (getContext()->shouldPrintOpOnDiagnostic()) { 270 diag.attachNote(getLoc()) 271 .append("see current operation: ") 272 .appendOp(*this, OpPrintingFlags().printGenericOpForm()); 273 } 274 return diag; 275 } 276 277 /// Emit a warning about this operation, reporting up to any diagnostic 278 /// handlers that may be listening. 279 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 280 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 281 if (getContext()->shouldPrintOpOnDiagnostic()) 282 diag.attachNote(getLoc()) << "see current operation: " << *this; 283 return diag; 284 } 285 286 /// Emit a remark about this operation, reporting up to any diagnostic 287 /// handlers that may be listening. 
288 InFlightDiagnostic Operation::emitRemark(const Twine &message) { 289 InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message); 290 if (getContext()->shouldPrintOpOnDiagnostic()) 291 diag.attachNote(getLoc()) << "see current operation: " << *this; 292 return diag; 293 } 294 295 DictionaryAttr Operation::getAttrDictionary() { 296 if (getPropertiesStorageSize()) { 297 NamedAttrList attrsList = attrs; 298 getName().populateInherentAttrs(this, attrsList); 299 return attrsList.getDictionary(getContext()); 300 } 301 return attrs; 302 } 303 304 void Operation::setAttrs(DictionaryAttr newAttrs) { 305 assert(newAttrs && "expected valid attribute dictionary"); 306 if (getPropertiesStorageSize()) { 307 // We're spliting the providing DictionaryAttr by removing the inherentAttr 308 // which will be stored in the properties. 309 SmallVector<NamedAttribute> discardableAttrs; 310 discardableAttrs.reserve(newAttrs.size()); 311 for (NamedAttribute attr : newAttrs) { 312 if (std::optional<Attribute> inherentAttr = 313 getInherentAttr(attr.getName())) 314 setInherentAttr(attr.getName(), attr.getValue()); 315 else 316 discardableAttrs.push_back(attr); 317 } 318 if (discardableAttrs.size() != newAttrs.size()) 319 newAttrs = DictionaryAttr::get(getContext(), discardableAttrs); 320 } 321 attrs = newAttrs; 322 } 323 void Operation::setAttrs(ArrayRef<NamedAttribute> newAttrs) { 324 if (getPropertiesStorageSize()) { 325 // We're spliting the providing array of attributes by removing the inherentAttr 326 // which will be stored in the properties. 
327 SmallVector<NamedAttribute> discardableAttrs; 328 discardableAttrs.reserve(newAttrs.size()); 329 for (NamedAttribute attr : newAttrs) { 330 if (std::optional<Attribute> inherentAttr = 331 getInherentAttr(attr.getName())) 332 setInherentAttr(attr.getName(), attr.getValue()); 333 else 334 discardableAttrs.push_back(attr); 335 } 336 attrs = DictionaryAttr::get(getContext(), discardableAttrs); 337 return; 338 } 339 attrs = DictionaryAttr::get(getContext(), newAttrs); 340 } 341 342 std::optional<Attribute> Operation::getInherentAttr(StringRef name) { 343 return getName().getInherentAttr(this, name); 344 } 345 346 void Operation::setInherentAttr(StringAttr name, Attribute value) { 347 getName().setInherentAttr(this, name, value); 348 } 349 350 Attribute Operation::getPropertiesAsAttribute() { 351 std::optional<RegisteredOperationName> info = getRegisteredInfo(); 352 if (LLVM_UNLIKELY(!info)) 353 return *getPropertiesStorage().as<Attribute *>(); 354 return info->getOpPropertiesAsAttribute(this); 355 } 356 LogicalResult 357 Operation::setPropertiesFromAttribute(Attribute attr, 358 InFlightDiagnostic *diagnostic) { 359 std::optional<RegisteredOperationName> info = getRegisteredInfo(); 360 if (LLVM_UNLIKELY(!info)) { 361 *getPropertiesStorage().as<Attribute *>() = attr; 362 return success(); 363 } 364 return info->setOpPropertiesFromAttribute( 365 this->getName(), this->getPropertiesStorage(), attr, diagnostic); 366 } 367 368 void Operation::copyProperties(OpaqueProperties rhs) { 369 name.copyOpProperties(getPropertiesStorage(), rhs); 370 } 371 372 llvm::hash_code Operation::hashProperties() { 373 return name.hashOpProperties(getPropertiesStorage()); 374 } 375 376 //===----------------------------------------------------------------------===// 377 // Operation Ordering 378 //===----------------------------------------------------------------------===// 379 380 constexpr unsigned Operation::kInvalidOrderIdx; 381 constexpr unsigned Operation::kOrderStride; 382 383 /// 
/// Given an operation 'other' that is within the same parent block, return
/// whether the current operation is before 'other' in the operation list
/// of the parent block.
/// Note: This function has an average complexity of O(1), but worst case may
/// take O(N) where N is the number of operations within the parent block.
bool Operation::isBeforeInBlock(Operation *other) {
  assert(block && "Operations without parent blocks have no order.");
  assert(other && other->block == block &&
         "Expected other operation to have the same parent block.");
  // If the order of the block is already invalid, directly recompute the
  // parent.
  if (!block->isOpOrderValid()) {
    block->recomputeOpOrder();
  } else {
    // Update the order of either operation if necessary.
    updateOrderIfNecessary();
    other->updateOrderIfNecessary();
  }

  return orderIndex < other->orderIndex;
}

/// Update the order index of this operation if necessary, potentially
/// recomputing the order of the parent block.
///
/// A missing index is derived from the neighbours when possible:
/// `kOrderStride` past the previous op for the last operation, a value below
/// the next op for the first operation, and the midpoint between both
/// neighbours otherwise. When no such index exists (or a neighbour has no
/// valid index either), the whole block is renumbered via `recomputeOpOrder`.
void Operation::updateOrderIfNecessary() {
  assert(block && "expected valid parent");

  // If the order is valid for this operation there is nothing to do.
  if (hasValidOrder())
    return;
  Operation *blockFront = &block->front();
  Operation *blockBack = &block->back();

  // This method is expected to only be invoked on blocks with more than one
  // operation.
  assert(blockFront != blockBack && "expected more than one operation");

  // If the operation is at the end of the block.
  if (this == blockBack) {
    Operation *prevNode = getPrevNode();
    if (!prevNode->hasValidOrder())
      return block->recomputeOpOrder();

    // Add the stride to the previous operation.
    orderIndex = prevNode->orderIndex + kOrderStride;
    return;
  }

  // If this is the first operation try to use the next operation to compute the
  // ordering.
  if (this == blockFront) {
    Operation *nextNode = getNextNode();
    if (!nextNode->hasValidOrder())
      return block->recomputeOpOrder();
    // There is no order to give this operation.
    if (nextNode->orderIndex == 0)
      return block->recomputeOpOrder();

    // If we can't use the stride, just take the middle value left. This is safe
    // because we know there is at least one valid index to assign to.
    if (nextNode->orderIndex <= kOrderStride)
      orderIndex = (nextNode->orderIndex / 2);
    else
      orderIndex = kOrderStride;
    return;
  }

  // Otherwise, this operation is between two others. Place this operation in
  // the middle of the previous and next if possible.
  Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
  if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
    return block->recomputeOpOrder();
  unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;

  // Check to see if there is a valid order between the two.
  if (prevOrder + 1 == nextOrder)
    return block->recomputeOpOrder();
  orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
}

//===----------------------------------------------------------------------===//
// ilist_traits for Operation
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the node <-> value pointer conversions for the
// intrusive list of operations; each one forwards to llvm's NodeAccess.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(const_pointer n)
    -> const node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(const node_type *n)
    -> const_pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}

/// Deleting a node from the list destroys the operation via `destroy()`.
void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
  op->destroy();
}

/// Return the Block owning this operation list. The list is a member of
/// Block (reached through the pointer-to-member `Block::getSublistAccess`), so
/// the Block's address is recovered by subtracting that member's offset from
/// the address of this list.
/// NOTE(review): the offset is computed offsetof-style through a null Block
/// pointer, which is formally undefined behavior.
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
  size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
  iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this));
  return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset);
}

/// This is a trait method invoked when an operation is added to a block. We
/// keep the block pointer up to date.
505 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 506 assert(!op->getBlock() && "already in an operation block!"); 507 op->block = getContainingBlock(); 508 509 // Invalidate the order on the operation. 510 op->orderIndex = Operation::kInvalidOrderIdx; 511 } 512 513 /// This is a trait method invoked when an operation is removed from a block. 514 /// We keep the block pointer up to date. 515 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 516 assert(op->block && "not already in an operation block!"); 517 op->block = nullptr; 518 } 519 520 /// This is a trait method invoked when an operation is moved from one block 521 /// to another. We keep the block pointer up to date. 522 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 523 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 524 Block *curParent = getContainingBlock(); 525 526 // Invalidate the ordering of the parent block. 527 curParent->invalidateOpOrder(); 528 529 // If we are transferring operations within the same block, the block 530 // pointer doesn't need to be updated. 531 if (curParent == otherList.getContainingBlock()) 532 return; 533 534 // Update the 'block' member of each operation. 535 for (; first != last; ++first) 536 first->block = curParent; 537 } 538 539 /// Remove this operation (and its descendants) from its Block and delete 540 /// all of them. 541 void Operation::erase() { 542 if (auto *parent = getBlock()) 543 parent->getOperations().erase(this); 544 else 545 destroy(); 546 } 547 548 /// Remove the operation from its parent block, but don't delete it. 549 void Operation::remove() { 550 if (Block *parent = getBlock()) 551 parent->getOperations().remove(this); 552 } 553 554 /// Unlink this operation from its current block and insert it right before 555 /// `existingOp` which may be in the same or another block in the same 556 /// function. 
557 void Operation::moveBefore(Operation *existingOp) { 558 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 559 } 560 561 /// Unlink this operation from its current basic block and insert it right 562 /// before `iterator` in the specified basic block. 563 void Operation::moveBefore(Block *block, 564 llvm::iplist<Operation>::iterator iterator) { 565 block->getOperations().splice(iterator, getBlock()->getOperations(), 566 getIterator()); 567 } 568 569 /// Unlink this operation from its current block and insert it right after 570 /// `existingOp` which may be in the same or another block in the same function. 571 void Operation::moveAfter(Operation *existingOp) { 572 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 573 } 574 575 /// Unlink this operation from its current block and insert it right after 576 /// `iterator` in the specified block. 577 void Operation::moveAfter(Block *block, 578 llvm::iplist<Operation>::iterator iterator) { 579 assert(iterator != block->end() && "cannot move after end of block"); 580 moveBefore(block, std::next(iterator)); 581 } 582 583 /// This drops all operand uses from this operation, which is an essential 584 /// step in breaking cyclic dependences between references when they are to 585 /// be deleted. 586 void Operation::dropAllReferences() { 587 for (auto &op : getOpOperands()) 588 op.drop(); 589 590 for (auto ®ion : getRegions()) 591 region.dropAllReferences(); 592 593 for (auto &dest : getBlockOperands()) 594 dest.drop(); 595 } 596 597 /// This drops all uses of any values defined by this operation or its nested 598 /// regions, wherever they are located. 
599 void Operation::dropAllDefinedValueUses() { 600 dropAllUses(); 601 602 for (auto ®ion : getRegions()) 603 for (auto &block : region) 604 block.dropAllDefinedValueUses(); 605 } 606 607 void Operation::setSuccessor(Block *block, unsigned index) { 608 assert(index < getNumSuccessors()); 609 getBlockOperands()[index].set(block); 610 } 611 612 /// Attempt to fold this operation using the Op's registered foldHook. 613 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 614 SmallVectorImpl<OpFoldResult> &results) { 615 // If we have a registered operation definition matching this one, use it to 616 // try to constant fold the operation. 617 if (succeeded(name.foldHook(this, operands, results))) 618 return success(); 619 620 // Otherwise, fall back on the dialect hook to handle it. 621 Dialect *dialect = getDialect(); 622 if (!dialect) 623 return failure(); 624 625 auto *interface = dyn_cast<DialectFoldInterface>(dialect); 626 if (!interface) 627 return failure(); 628 629 return interface->fold(this, operands, results); 630 } 631 632 LogicalResult Operation::fold(SmallVectorImpl<OpFoldResult> &results) { 633 // Check if any operands are constants. 634 SmallVector<Attribute> constants; 635 constants.assign(getNumOperands(), Attribute()); 636 for (unsigned i = 0, e = getNumOperands(); i != e; ++i) 637 matchPattern(getOperand(i), m_Constant(&constants[i])); 638 return fold(constants, results); 639 } 640 641 /// Emit an error with the op name prefixed, like "'dim' op " which is 642 /// convenient for verifiers. 
643 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 644 return emitError() << "'" << getName() << "' op " << message; 645 } 646 647 //===----------------------------------------------------------------------===// 648 // Operation Cloning 649 //===----------------------------------------------------------------------===// 650 651 Operation::CloneOptions::CloneOptions() 652 : cloneRegionsFlag(false), cloneOperandsFlag(false) {} 653 654 Operation::CloneOptions::CloneOptions(bool cloneRegions, bool cloneOperands) 655 : cloneRegionsFlag(cloneRegions), cloneOperandsFlag(cloneOperands) {} 656 657 Operation::CloneOptions Operation::CloneOptions::all() { 658 return CloneOptions().cloneRegions().cloneOperands(); 659 } 660 661 Operation::CloneOptions &Operation::CloneOptions::cloneRegions(bool enable) { 662 cloneRegionsFlag = enable; 663 return *this; 664 } 665 666 Operation::CloneOptions &Operation::CloneOptions::cloneOperands(bool enable) { 667 cloneOperandsFlag = enable; 668 return *this; 669 } 670 671 /// Create a deep copy of this operation but keep the operation regions empty. 672 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 673 /// to contain the results. The `mapResults` flag specifies whether the results 674 /// of the cloned operation should be added to the map. 675 Operation *Operation::cloneWithoutRegions(IRMapping &mapper) { 676 return clone(mapper, CloneOptions::all().cloneRegions(false)); 677 } 678 679 Operation *Operation::cloneWithoutRegions() { 680 IRMapping mapper; 681 return cloneWithoutRegions(mapper); 682 } 683 684 /// Create a deep copy of this operation, remapping any operands that use 685 /// values outside of the operation using the map that is provided (leaving 686 /// them alone if no entry is present). Replaces references to cloned 687 /// sub-operations to the corresponding operation that is copied, and adds 688 /// those mappings to the map. 
689 Operation *Operation::clone(IRMapping &mapper, CloneOptions options) { 690 SmallVector<Value, 8> operands; 691 SmallVector<Block *, 2> successors; 692 693 // Remap the operands. 694 if (options.shouldCloneOperands()) { 695 operands.reserve(getNumOperands()); 696 for (auto opValue : getOperands()) 697 operands.push_back(mapper.lookupOrDefault(opValue)); 698 } 699 700 // Remap the successors. 701 successors.reserve(getNumSuccessors()); 702 for (Block *successor : getSuccessors()) 703 successors.push_back(mapper.lookupOrDefault(successor)); 704 705 // Create the new operation. 706 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 707 getPropertiesStorage(), successors, getNumRegions()); 708 mapper.map(this, newOp); 709 710 // Clone the regions. 711 if (options.shouldCloneRegions()) { 712 for (unsigned i = 0; i != numRegions; ++i) 713 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 714 } 715 716 // Remember the mapping of any results. 717 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 718 mapper.map(getResult(i), newOp->getResult(i)); 719 720 return newOp; 721 } 722 723 Operation *Operation::clone(CloneOptions options) { 724 IRMapping mapper; 725 return clone(mapper, options); 726 } 727 728 //===----------------------------------------------------------------------===// 729 // OpState trait class. 730 //===----------------------------------------------------------------------===// 731 732 // The fallback for the parser is to try for a dialect operation parser. 733 // Otherwise, reject the custom assembly form. 734 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 735 if (auto parseFn = result.name.getDialect()->getParseOperationHook( 736 result.name.getStringRef())) 737 return (*parseFn)(parser, result); 738 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 739 } 740 741 // The fallback for the printer is to try for a dialect operation printer. 
742 // Otherwise, it prints the generic form. 743 void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { 744 if (auto printFn = op->getDialect()->getOperationPrinter(op)) { 745 printOpName(op, p, defaultDialect); 746 printFn(op, p); 747 } else { 748 p.printGenericOp(op); 749 } 750 } 751 752 /// Print an operation name, eliding the dialect prefix if necessary and doesn't 753 /// lead to ambiguities. 754 void OpState::printOpName(Operation *op, OpAsmPrinter &p, 755 StringRef defaultDialect) { 756 StringRef name = op->getName().getStringRef(); 757 if (name.startswith((defaultDialect + ".").str()) && name.count('.') == 1) 758 name = name.drop_front(defaultDialect.size() + 1); 759 p.getStream() << name; 760 } 761 762 /// Parse properties as a Attribute. 763 ParseResult OpState::genericParseProperties(OpAsmParser &parser, 764 Attribute &result) { 765 if (parser.parseLess() || parser.parseAttribute(result) || 766 parser.parseGreater()) 767 return failure(); 768 return success(); 769 } 770 771 /// Print the properties as a Attribute. 772 void OpState::genericPrintProperties(OpAsmPrinter &p, Attribute properties) { 773 p << "<" << properties << ">"; 774 } 775 776 /// Emit an error about fatal conditions with this operation, reporting up to 777 /// any diagnostic handlers that may be listening. 778 InFlightDiagnostic OpState::emitError(const Twine &message) { 779 return getOperation()->emitError(message); 780 } 781 782 /// Emit an error with the op name prefixed, like "'dim' op " which is 783 /// convenient for verifiers. 784 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 785 return getOperation()->emitOpError(message); 786 } 787 788 /// Emit a warning about this operation, reporting up to any diagnostic 789 /// handlers that may be listening. 
790 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 791 return getOperation()->emitWarning(message); 792 } 793 794 /// Emit a remark about this operation, reporting up to any diagnostic 795 /// handlers that may be listening. 796 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 797 return getOperation()->emitRemark(message); 798 } 799 800 //===----------------------------------------------------------------------===// 801 // Op Trait implementations 802 //===----------------------------------------------------------------------===// 803 804 LogicalResult 805 OpTrait::impl::foldCommutative(Operation *op, ArrayRef<Attribute> operands, 806 SmallVectorImpl<OpFoldResult> &results) { 807 // Nothing to fold if there are not at least 2 operands. 808 if (op->getNumOperands() < 2) 809 return failure(); 810 // Move all constant operands to the end. 811 OpOperand *operandsBegin = op->getOpOperands().begin(); 812 auto isNonConstant = [&](OpOperand &o) { 813 return !static_cast<bool>(operands[std::distance(operandsBegin, &o)]); 814 }; 815 auto *firstConstantIt = llvm::find_if_not(op->getOpOperands(), isNonConstant); 816 auto *newConstantIt = std::stable_partition( 817 firstConstantIt, op->getOpOperands().end(), isNonConstant); 818 // Return success if the op was modified. 819 return success(firstConstantIt != newConstantIt); 820 } 821 822 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { 823 if (op->getNumOperands() == 1) { 824 auto *argumentOp = op->getOperand(0).getDefiningOp(); 825 if (argumentOp && op->getName() == argumentOp->getName()) { 826 // Replace the outer operation output with the inner operation. 
827 return op->getOperand(0); 828 } 829 } else if (op->getOperand(0) == op->getOperand(1)) { 830 return op->getOperand(0); 831 } 832 833 return {}; 834 } 835 836 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { 837 auto *argumentOp = op->getOperand(0).getDefiningOp(); 838 if (argumentOp && op->getName() == argumentOp->getName()) { 839 // Replace the outer involutions output with inner's input. 840 return argumentOp->getOperand(0); 841 } 842 843 return {}; 844 } 845 846 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 847 if (op->getNumOperands() != 0) 848 return op->emitOpError() << "requires zero operands"; 849 return success(); 850 } 851 852 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 853 if (op->getNumOperands() != 1) 854 return op->emitOpError() << "requires a single operand"; 855 return success(); 856 } 857 858 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 859 unsigned numOperands) { 860 if (op->getNumOperands() != numOperands) { 861 return op->emitOpError() << "expected " << numOperands 862 << " operands, but found " << op->getNumOperands(); 863 } 864 return success(); 865 } 866 867 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 868 unsigned numOperands) { 869 if (op->getNumOperands() < numOperands) 870 return op->emitOpError() 871 << "expected " << numOperands << " or more operands, but found " 872 << op->getNumOperands(); 873 return success(); 874 } 875 876 /// If this is a vector type, or a tensor type, return the scalar element type 877 /// that it is built around, otherwise return the type unmodified. 878 static Type getTensorOrVectorElementType(Type type) { 879 if (auto vec = llvm::dyn_cast<VectorType>(type)) 880 return vec.getElementType(); 881 882 // Look through tensor<vector<...>> to find the underlying element type. 
883 if (auto tensor = llvm::dyn_cast<TensorType>(type)) 884 return getTensorOrVectorElementType(tensor.getElementType()); 885 return type; 886 } 887 888 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { 889 // FIXME: Add back check for no side effects on operation. 890 // Currently adding it would cause the shared library build 891 // to fail since there would be a dependency of IR on SideEffectInterfaces 892 // which is cyclical. 893 return success(); 894 } 895 896 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { 897 // FIXME: Add back check for no side effects on operation. 898 // Currently adding it would cause the shared library build 899 // to fail since there would be a dependency of IR on SideEffectInterfaces 900 // which is cyclical. 901 return success(); 902 } 903 904 LogicalResult 905 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 906 for (auto opType : op->getOperandTypes()) { 907 auto type = getTensorOrVectorElementType(opType); 908 if (!type.isSignlessIntOrIndex()) 909 return op->emitOpError() << "requires an integer or index type"; 910 } 911 return success(); 912 } 913 914 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 915 for (auto opType : op->getOperandTypes()) { 916 auto type = getTensorOrVectorElementType(opType); 917 if (!llvm::isa<FloatType>(type)) 918 return op->emitOpError("requires a float type"); 919 } 920 return success(); 921 } 922 923 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 924 // Zero or one operand always have the "same" type. 
925 unsigned nOperands = op->getNumOperands(); 926 if (nOperands < 2) 927 return success(); 928 929 auto type = op->getOperand(0).getType(); 930 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 931 if (opType != type) 932 return op->emitOpError() << "requires all operands to have the same type"; 933 return success(); 934 } 935 936 LogicalResult OpTrait::impl::verifyZeroRegions(Operation *op) { 937 if (op->getNumRegions() != 0) 938 return op->emitOpError() << "requires zero regions"; 939 return success(); 940 } 941 942 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 943 if (op->getNumRegions() != 1) 944 return op->emitOpError() << "requires one region"; 945 return success(); 946 } 947 948 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 949 unsigned numRegions) { 950 if (op->getNumRegions() != numRegions) 951 return op->emitOpError() << "expected " << numRegions << " regions"; 952 return success(); 953 } 954 955 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 956 unsigned numRegions) { 957 if (op->getNumRegions() < numRegions) 958 return op->emitOpError() << "expected " << numRegions << " or more regions"; 959 return success(); 960 } 961 962 LogicalResult OpTrait::impl::verifyZeroResults(Operation *op) { 963 if (op->getNumResults() != 0) 964 return op->emitOpError() << "requires zero results"; 965 return success(); 966 } 967 968 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 969 if (op->getNumResults() != 1) 970 return op->emitOpError() << "requires one result"; 971 return success(); 972 } 973 974 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 975 unsigned numOperands) { 976 if (op->getNumResults() != numOperands) 977 return op->emitOpError() << "expected " << numOperands << " results"; 978 return success(); 979 } 980 981 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 982 unsigned numOperands) { 983 if (op->getNumResults() < numOperands) 984 return op->emitOpError() 
985 << "expected " << numOperands << " or more results"; 986 return success(); 987 } 988 989 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 990 if (failed(verifyAtLeastNOperands(op, 1))) 991 return failure(); 992 993 if (failed(verifyCompatibleShapes(op->getOperandTypes()))) 994 return op->emitOpError() << "requires the same shape for all operands"; 995 996 return success(); 997 } 998 999 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 1000 if (failed(verifyAtLeastNOperands(op, 1)) || 1001 failed(verifyAtLeastNResults(op, 1))) 1002 return failure(); 1003 1004 SmallVector<Type, 8> types(op->getOperandTypes()); 1005 types.append(llvm::to_vector<4>(op->getResultTypes())); 1006 1007 if (failed(verifyCompatibleShapes(types))) 1008 return op->emitOpError() 1009 << "requires the same shape for all operands and results"; 1010 1011 return success(); 1012 } 1013 1014 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 1015 if (failed(verifyAtLeastNOperands(op, 1))) 1016 return failure(); 1017 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 1018 1019 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 1020 if (getElementTypeOrSelf(operand) != elementType) 1021 return op->emitOpError("requires the same element type for all operands"); 1022 } 1023 1024 return success(); 1025 } 1026 1027 LogicalResult 1028 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 1029 if (failed(verifyAtLeastNOperands(op, 1)) || 1030 failed(verifyAtLeastNResults(op, 1))) 1031 return failure(); 1032 1033 auto elementType = getElementTypeOrSelf(op->getResult(0)); 1034 1035 // Verify result element type matches first result's element type. 
1036 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 1037 if (getElementTypeOrSelf(result) != elementType) 1038 return op->emitOpError( 1039 "requires the same element type for all operands and results"); 1040 } 1041 1042 // Verify operand's element type matches first result's element type. 1043 for (auto operand : op->getOperands()) { 1044 if (getElementTypeOrSelf(operand) != elementType) 1045 return op->emitOpError( 1046 "requires the same element type for all operands and results"); 1047 } 1048 1049 return success(); 1050 } 1051 1052 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 1053 if (failed(verifyAtLeastNOperands(op, 1)) || 1054 failed(verifyAtLeastNResults(op, 1))) 1055 return failure(); 1056 1057 auto type = op->getResult(0).getType(); 1058 auto elementType = getElementTypeOrSelf(type); 1059 Attribute encoding = nullptr; 1060 if (auto rankedType = dyn_cast<RankedTensorType>(type)) 1061 encoding = rankedType.getEncoding(); 1062 for (auto resultType : llvm::drop_begin(op->getResultTypes())) { 1063 if (getElementTypeOrSelf(resultType) != elementType || 1064 failed(verifyCompatibleShape(resultType, type))) 1065 return op->emitOpError() 1066 << "requires the same type for all operands and results"; 1067 if (encoding) 1068 if (auto rankedType = dyn_cast<RankedTensorType>(resultType); 1069 encoding != rankedType.getEncoding()) 1070 return op->emitOpError() 1071 << "requires the same encoding for all operands and results"; 1072 } 1073 for (auto opType : op->getOperandTypes()) { 1074 if (getElementTypeOrSelf(opType) != elementType || 1075 failed(verifyCompatibleShape(opType, type))) 1076 return op->emitOpError() 1077 << "requires the same type for all operands and results"; 1078 if (encoding) 1079 if (auto rankedType = dyn_cast<RankedTensorType>(opType); 1080 encoding != rankedType.getEncoding()) 1081 return op->emitOpError() 1082 << "requires the same encoding for all operands and results"; 1083 } 1084 return success(); 
1085 } 1086 1087 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 1088 Block *block = op->getBlock(); 1089 // Verify that the operation is at the end of the respective parent block. 1090 if (!block || &block->back() != op) 1091 return op->emitOpError("must be the last operation in the parent block"); 1092 return success(); 1093 } 1094 1095 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 1096 auto *parent = op->getParentRegion(); 1097 1098 // Verify that the operands lines up with the BB arguments in the successor. 1099 for (Block *succ : op->getSuccessors()) 1100 if (succ->getParent() != parent) 1101 return op->emitError("reference to block defined in another region"); 1102 return success(); 1103 } 1104 1105 LogicalResult OpTrait::impl::verifyZeroSuccessors(Operation *op) { 1106 if (op->getNumSuccessors() != 0) { 1107 return op->emitOpError("requires 0 successors but found ") 1108 << op->getNumSuccessors(); 1109 } 1110 return success(); 1111 } 1112 1113 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 1114 if (op->getNumSuccessors() != 1) { 1115 return op->emitOpError("requires 1 successor but found ") 1116 << op->getNumSuccessors(); 1117 } 1118 return verifyTerminatorSuccessors(op); 1119 } 1120 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 1121 unsigned numSuccessors) { 1122 if (op->getNumSuccessors() != numSuccessors) { 1123 return op->emitOpError("requires ") 1124 << numSuccessors << " successors but found " 1125 << op->getNumSuccessors(); 1126 } 1127 return verifyTerminatorSuccessors(op); 1128 } 1129 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 1130 unsigned numSuccessors) { 1131 if (op->getNumSuccessors() < numSuccessors) { 1132 return op->emitOpError("requires at least ") 1133 << numSuccessors << " successors but found " 1134 << op->getNumSuccessors(); 1135 } 1136 return verifyTerminatorSuccessors(op); 1137 } 1138 1139 LogicalResult 
OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 1140 for (auto resultType : op->getResultTypes()) { 1141 auto elementType = getTensorOrVectorElementType(resultType); 1142 bool isBoolType = elementType.isInteger(1); 1143 if (!isBoolType) 1144 return op->emitOpError() << "requires a bool result type"; 1145 } 1146 1147 return success(); 1148 } 1149 1150 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 1151 for (auto resultType : op->getResultTypes()) 1152 if (!llvm::isa<FloatType>(getTensorOrVectorElementType(resultType))) 1153 return op->emitOpError() << "requires a floating point type"; 1154 1155 return success(); 1156 } 1157 1158 LogicalResult 1159 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 1160 for (auto resultType : op->getResultTypes()) 1161 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 1162 return op->emitOpError() << "requires an integer or index type"; 1163 return success(); 1164 } 1165 1166 LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, 1167 StringRef attrName, 1168 StringRef valueGroupName, 1169 size_t expectedCount) { 1170 auto sizeAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrName); 1171 if (!sizeAttr) 1172 return op->emitOpError("requires dense i32 array attribute '") 1173 << attrName << "'"; 1174 1175 ArrayRef<int32_t> sizes = sizeAttr.asArrayRef(); 1176 if (llvm::any_of(sizes, [](int32_t element) { return element < 0; })) 1177 return op->emitOpError("'") 1178 << attrName << "' attribute cannot have negative elements"; 1179 1180 size_t totalCount = 1181 std::accumulate(sizes.begin(), sizes.end(), 0, 1182 [](unsigned all, int32_t one) { return all + one; }); 1183 1184 if (totalCount != expectedCount) 1185 return op->emitOpError() 1186 << valueGroupName << " count (" << expectedCount 1187 << ") does not match with the total size (" << totalCount 1188 << ") specified in attribute '" << attrName << "'"; 1189 return success(); 1190 } 1191 1192 LogicalResult 
OpTrait::impl::verifyOperandSizeAttr(Operation *op, 1193 StringRef attrName) { 1194 return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); 1195 } 1196 1197 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, 1198 StringRef attrName) { 1199 return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); 1200 } 1201 1202 LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1203 for (Region ®ion : op->getRegions()) { 1204 if (region.empty()) 1205 continue; 1206 1207 if (region.getNumArguments() != 0) { 1208 if (op->getNumRegions() > 1) 1209 return op->emitOpError("region #") 1210 << region.getRegionNumber() << " should have no arguments"; 1211 return op->emitOpError("region should have no arguments"); 1212 } 1213 } 1214 return success(); 1215 } 1216 1217 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { 1218 auto isMappableType = [](Type type) { 1219 return llvm::isa<VectorType, TensorType>(type); 1220 }; 1221 auto resultMappableTypes = llvm::to_vector<1>( 1222 llvm::make_filter_range(op->getResultTypes(), isMappableType)); 1223 auto operandMappableTypes = llvm::to_vector<2>( 1224 llvm::make_filter_range(op->getOperandTypes(), isMappableType)); 1225 1226 // If the op only has scalar operand/result types, then we have nothing to 1227 // check. 
1228 if (resultMappableTypes.empty() && operandMappableTypes.empty()) 1229 return success(); 1230 1231 if (!resultMappableTypes.empty() && operandMappableTypes.empty()) 1232 return op->emitOpError("if a result is non-scalar, then at least one " 1233 "operand must be non-scalar"); 1234 1235 assert(!operandMappableTypes.empty()); 1236 1237 if (resultMappableTypes.empty()) 1238 return op->emitOpError("if an operand is non-scalar, then there must be at " 1239 "least one non-scalar result"); 1240 1241 if (resultMappableTypes.size() != op->getNumResults()) 1242 return op->emitOpError( 1243 "if an operand is non-scalar, then all results must be non-scalar"); 1244 1245 SmallVector<Type, 4> types = llvm::to_vector<2>( 1246 llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); 1247 TypeID expectedBaseTy = types.front().getTypeID(); 1248 if (!llvm::all_of(types, 1249 [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || 1250 failed(verifyCompatibleShapes(types))) { 1251 return op->emitOpError() << "all non-scalar operands/results must have the " 1252 "same shape and base type"; 1253 } 1254 1255 return success(); 1256 } 1257 1258 /// Check for any values used by operations regions attached to the 1259 /// specified "IsIsolatedFromAbove" operation defined outside of it. 1260 LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { 1261 assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1262 "Intended to check IsolatedFromAbove ops"); 1263 1264 // List of regions to analyze. Each region is processed independently, with 1265 // respect to the common `limit` region, so we can look at them in any order. 1266 // Therefore, use a simple vector and push/pop back the current region. 1267 SmallVector<Region *, 8> pendingRegions; 1268 for (auto ®ion : isolatedOp->getRegions()) { 1269 pendingRegions.push_back(®ion); 1270 1271 // Traverse all operations in the region. 
1272 while (!pendingRegions.empty()) { 1273 for (Operation &op : pendingRegions.pop_back_val()->getOps()) { 1274 for (Value operand : op.getOperands()) { 1275 // Check that any value that is used by an operation is defined in the 1276 // same region as either an operation result. 1277 auto *operandRegion = operand.getParentRegion(); 1278 if (!operandRegion) 1279 return op.emitError("operation's operand is unlinked"); 1280 if (!region.isAncestor(operandRegion)) { 1281 return op.emitOpError("using value defined outside the region") 1282 .attachNote(isolatedOp->getLoc()) 1283 << "required by region isolation constraints"; 1284 } 1285 } 1286 1287 // Schedule any regions in the operation for further checking. Don't 1288 // recurse into other IsolatedFromAbove ops, because they will check 1289 // themselves. 1290 if (op.getNumRegions() && 1291 !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { 1292 for (Region &subRegion : op.getRegions()) 1293 pendingRegions.push_back(&subRegion); 1294 } 1295 } 1296 } 1297 } 1298 1299 return success(); 1300 } 1301 1302 bool OpTrait::hasElementwiseMappableTraits(Operation *op) { 1303 return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && 1304 op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); 1305 } 1306 1307 //===----------------------------------------------------------------------===// 1308 // Misc. utils 1309 //===----------------------------------------------------------------------===// 1310 1311 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1312 /// region's only block if it does not have a terminator already. If the region 1313 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1314 /// terminator operation to insert. 
1315 void impl::ensureRegionTerminator( 1316 Region ®ion, OpBuilder &builder, Location loc, 1317 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1318 OpBuilder::InsertionGuard guard(builder); 1319 if (region.empty()) 1320 builder.createBlock(®ion); 1321 1322 Block &block = region.back(); 1323 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) 1324 return; 1325 1326 builder.setInsertionPointToEnd(&block); 1327 builder.insert(buildTerminatorOp(builder, loc)); 1328 } 1329 1330 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1331 /// function. 1332 void impl::ensureRegionTerminator( 1333 Region ®ion, Builder &builder, Location loc, 1334 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1335 OpBuilder opBuilder(builder.getContext()); 1336 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1337 } 1338