1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/Attributes.h" 11 #include "mlir/IR/BuiltinAttributes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IRMapping.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/IR/OpImplementation.h" 17 #include "mlir/IR/OperationSupport.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Interfaces/FoldInterfaces.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include <numeric> 24 #include <optional> 25 26 using namespace mlir; 27 28 //===----------------------------------------------------------------------===// 29 // Operation 30 //===----------------------------------------------------------------------===// 31 32 /// Create a new Operation from operation state. 33 Operation *Operation::create(const OperationState &state) { 34 Operation *op = 35 create(state.location, state.name, state.types, state.operands, 36 state.attributes.getDictionary(state.getContext()), 37 state.properties, state.successors, state.regions); 38 if (LLVM_UNLIKELY(state.propertiesAttr)) { 39 assert(!state.properties); 40 LogicalResult result = 41 op->setPropertiesFromAttribute(state.propertiesAttr, 42 /*diagnostic=*/nullptr); 43 assert(result.succeeded() && "invalid properties in op creation"); 44 (void)result; 45 } 46 return op; 47 } 48 49 /// Create a new Operation with the specific fields. 
50 Operation *Operation::create(Location location, OperationName name, 51 TypeRange resultTypes, ValueRange operands, 52 NamedAttrList &&attributes, 53 OpaqueProperties properties, BlockRange successors, 54 RegionRange regions) { 55 unsigned numRegions = regions.size(); 56 Operation *op = 57 create(location, name, resultTypes, operands, std::move(attributes), 58 properties, successors, numRegions); 59 for (unsigned i = 0; i < numRegions; ++i) 60 if (regions[i]) 61 op->getRegion(i).takeBody(*regions[i]); 62 return op; 63 } 64 65 /// Create a new Operation with the specific fields. 66 Operation *Operation::create(Location location, OperationName name, 67 TypeRange resultTypes, ValueRange operands, 68 NamedAttrList &&attributes, 69 OpaqueProperties properties, BlockRange successors, 70 unsigned numRegions) { 71 // Populate default attributes. 72 name.populateDefaultAttrs(attributes); 73 74 return create(location, name, resultTypes, operands, 75 attributes.getDictionary(location.getContext()), properties, 76 successors, numRegions); 77 } 78 79 /// Overload of create that takes an existing DictionaryAttr to avoid 80 /// unnecessarily uniquing a list of attributes. 81 Operation *Operation::create(Location location, OperationName name, 82 TypeRange resultTypes, ValueRange operands, 83 DictionaryAttr attributes, 84 OpaqueProperties properties, BlockRange successors, 85 unsigned numRegions) { 86 assert(llvm::all_of(resultTypes, [](Type t) { return t; }) && 87 "unexpected null result type"); 88 89 // We only need to allocate additional memory for a subset of results. 
90 unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); 91 unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); 92 unsigned numSuccessors = successors.size(); 93 unsigned numOperands = operands.size(); 94 unsigned numResults = resultTypes.size(); 95 int opPropertiesAllocSize = llvm::alignTo<8>(name.getOpPropertyByteSize()); 96 97 // If the operation is known to have no operands, don't allocate an operand 98 // storage. 99 bool needsOperandStorage = 100 operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true; 101 102 // Compute the byte size for the operation and the operand storage. This takes 103 // into account the size of the operation, its trailing objects, and its 104 // prefixed objects. 105 size_t byteSize = 106 totalSizeToAlloc<detail::OperandStorage, detail::OpProperties, 107 BlockOperand, Region, OpOperand>( 108 needsOperandStorage ? 1 : 0, opPropertiesAllocSize, numSuccessors, 109 numRegions, numOperands); 110 size_t prefixByteSize = llvm::alignTo( 111 Operation::prefixAllocSize(numTrailingResults, numInlineResults), 112 alignof(Operation)); 113 char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); 114 void *rawMem = mallocMem + prefixByteSize; 115 116 // Create the new Operation. 117 Operation *op = ::new (rawMem) Operation( 118 location, name, numResults, numSuccessors, numRegions, 119 opPropertiesAllocSize, attributes, properties, needsOperandStorage); 120 121 assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && 122 "unexpected successors in a non-terminator operation"); 123 124 // Initialize the results. 
125 auto resultTypeIt = resultTypes.begin(); 126 for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) 127 new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); 128 for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { 129 new (op->getOutOfLineOpResult(i)) 130 detail::OutOfLineOpResult(*resultTypeIt, i); 131 } 132 133 // Initialize the regions. 134 for (unsigned i = 0; i != numRegions; ++i) 135 new (&op->getRegion(i)) Region(op); 136 137 // Initialize the operands. 138 if (needsOperandStorage) { 139 new (&op->getOperandStorage()) detail::OperandStorage( 140 op, op->getTrailingObjects<OpOperand>(), operands); 141 } 142 143 // Initialize the successors. 144 auto blockOperands = op->getBlockOperands(); 145 for (unsigned i = 0; i != numSuccessors; ++i) 146 new (&blockOperands[i]) BlockOperand(op, successors[i]); 147 148 // This must be done after properties are initalized. 149 op->setAttrs(attributes); 150 151 return op; 152 } 153 154 Operation::Operation(Location location, OperationName name, unsigned numResults, 155 unsigned numSuccessors, unsigned numRegions, 156 int fullPropertiesStorageSize, DictionaryAttr attributes, 157 OpaqueProperties properties, bool hasOperandStorage) 158 : location(location), numResults(numResults), numSuccs(numSuccessors), 159 numRegions(numRegions), hasOperandStorage(hasOperandStorage), 160 propertiesStorageSize((fullPropertiesStorageSize + 7) / 8), name(name) { 161 assert(attributes && "unexpected null attribute dictionary"); 162 assert(fullPropertiesStorageSize <= propertiesCapacity && 163 "Properties size overflow"); 164 #ifndef NDEBUG 165 if (!getDialect() && !getContext()->allowsUnregisteredDialects()) 166 llvm::report_fatal_error( 167 name.getStringRef() + 168 " created with unregistered dialect. 
If this is intended, please call " 169 "allowUnregisteredDialects() on the MLIRContext, or use " 170 "-allow-unregistered-dialect with the MLIR tool used."); 171 #endif 172 if (fullPropertiesStorageSize) 173 name.initOpProperties(getPropertiesStorage(), properties); 174 } 175 176 // Operations are deleted through the destroy() member because they are 177 // allocated via malloc. 178 Operation::~Operation() { 179 assert(block == nullptr && "operation destroyed but still in a block"); 180 #ifndef NDEBUG 181 if (!use_empty()) { 182 { 183 InFlightDiagnostic diag = 184 emitOpError("operation destroyed but still has uses"); 185 for (Operation *user : getUsers()) 186 diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; 187 } 188 llvm::report_fatal_error("operation destroyed but still has uses"); 189 } 190 #endif 191 // Explicitly run the destructors for the operands. 192 if (hasOperandStorage) 193 getOperandStorage().~OperandStorage(); 194 195 // Explicitly run the destructors for the successors. 196 for (auto &successor : getBlockOperands()) 197 successor.~BlockOperand(); 198 199 // Explicitly destroy the regions. 200 for (auto ®ion : getRegions()) 201 region.~Region(); 202 if (propertiesStorageSize) 203 name.destroyOpProperties(getPropertiesStorage()); 204 } 205 206 /// Destroy this operation or one of its subclasses. 207 void Operation::destroy() { 208 // Operations may have additional prefixed allocation, which needs to be 209 // accounted for here when computing the address to free. 210 char *rawMem = reinterpret_cast<char *>(this) - 211 llvm::alignTo(prefixAllocSize(), alignof(Operation)); 212 this->~Operation(); 213 free(rawMem); 214 } 215 216 /// Return true if this operation is a proper ancestor of the `other` 217 /// operation. 
218 bool Operation::isProperAncestor(Operation *other) { 219 while ((other = other->getParentOp())) 220 if (this == other) 221 return true; 222 return false; 223 } 224 225 /// Replace any uses of 'from' with 'to' within this operation. 226 void Operation::replaceUsesOfWith(Value from, Value to) { 227 if (from == to) 228 return; 229 for (auto &operand : getOpOperands()) 230 if (operand.get() == from) 231 operand.set(to); 232 } 233 234 /// Replace the current operands of this operation with the ones provided in 235 /// 'operands'. 236 void Operation::setOperands(ValueRange operands) { 237 if (LLVM_LIKELY(hasOperandStorage)) 238 return getOperandStorage().setOperands(this, operands); 239 assert(operands.empty() && "setting operands without an operand storage"); 240 } 241 242 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 243 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 244 /// than the range pointed to by 'start'+'length'. 245 void Operation::setOperands(unsigned start, unsigned length, 246 ValueRange operands) { 247 assert((start + length) <= getNumOperands() && 248 "invalid operand range specified"); 249 if (LLVM_LIKELY(hasOperandStorage)) 250 return getOperandStorage().setOperands(this, start, length, operands); 251 assert(operands.empty() && "setting operands without an operand storage"); 252 } 253 254 /// Insert the given operands into the operand list at the given 'index'. 
255 void Operation::insertOperands(unsigned index, ValueRange operands) { 256 if (LLVM_LIKELY(hasOperandStorage)) 257 return setOperands(index, /*length=*/0, operands); 258 assert(operands.empty() && "inserting operands without an operand storage"); 259 } 260 261 //===----------------------------------------------------------------------===// 262 // Diagnostics 263 //===----------------------------------------------------------------------===// 264 265 /// Emit an error about fatal conditions with this operation, reporting up to 266 /// any diagnostic handlers that may be listening. 267 InFlightDiagnostic Operation::emitError(const Twine &message) { 268 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 269 if (getContext()->shouldPrintOpOnDiagnostic()) { 270 diag.attachNote(getLoc()) 271 .append("see current operation: ") 272 .appendOp(*this, OpPrintingFlags().printGenericOpForm()); 273 } 274 return diag; 275 } 276 277 /// Emit a warning about this operation, reporting up to any diagnostic 278 /// handlers that may be listening. 279 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 280 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 281 if (getContext()->shouldPrintOpOnDiagnostic()) 282 diag.attachNote(getLoc()) << "see current operation: " << *this; 283 return diag; 284 } 285 286 /// Emit a remark about this operation, reporting up to any diagnostic 287 /// handlers that may be listening. 
288 InFlightDiagnostic Operation::emitRemark(const Twine &message) { 289 InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message); 290 if (getContext()->shouldPrintOpOnDiagnostic()) 291 diag.attachNote(getLoc()) << "see current operation: " << *this; 292 return diag; 293 } 294 295 DictionaryAttr Operation::getAttrDictionary() { 296 if (getPropertiesStorageSize()) { 297 NamedAttrList attrsList = attrs; 298 getName().populateInherentAttrs(this, attrsList); 299 return attrsList.getDictionary(getContext()); 300 } 301 return attrs; 302 } 303 304 void Operation::setAttrs(DictionaryAttr newAttrs) { 305 assert(newAttrs && "expected valid attribute dictionary"); 306 if (getPropertiesStorageSize()) { 307 // We're spliting the providing DictionaryAttr by removing the inherentAttr 308 // which will be stored in the properties. 309 SmallVector<NamedAttribute> discardableAttrs; 310 discardableAttrs.reserve(newAttrs.size()); 311 for (NamedAttribute attr : newAttrs) { 312 if (getInherentAttr(attr.getName())) 313 setInherentAttr(attr.getName(), attr.getValue()); 314 else 315 discardableAttrs.push_back(attr); 316 } 317 if (discardableAttrs.size() != newAttrs.size()) 318 newAttrs = DictionaryAttr::get(getContext(), discardableAttrs); 319 } 320 attrs = newAttrs; 321 } 322 void Operation::setAttrs(ArrayRef<NamedAttribute> newAttrs) { 323 if (getPropertiesStorageSize()) { 324 // We're spliting the providing array of attributes by removing the inherentAttr 325 // which will be stored in the properties. 
326 SmallVector<NamedAttribute> discardableAttrs; 327 discardableAttrs.reserve(newAttrs.size()); 328 for (NamedAttribute attr : newAttrs) { 329 if (getInherentAttr(attr.getName())) 330 setInherentAttr(attr.getName(), attr.getValue()); 331 else 332 discardableAttrs.push_back(attr); 333 } 334 attrs = DictionaryAttr::get(getContext(), discardableAttrs); 335 return; 336 } 337 attrs = DictionaryAttr::get(getContext(), newAttrs); 338 } 339 340 std::optional<Attribute> Operation::getInherentAttr(StringRef name) { 341 return getName().getInherentAttr(this, name); 342 } 343 344 void Operation::setInherentAttr(StringAttr name, Attribute value) { 345 getName().setInherentAttr(this, name, value); 346 } 347 348 Attribute Operation::getPropertiesAsAttribute() { 349 std::optional<RegisteredOperationName> info = getRegisteredInfo(); 350 if (LLVM_UNLIKELY(!info)) 351 return *getPropertiesStorage().as<Attribute *>(); 352 return info->getOpPropertiesAsAttribute(this); 353 } 354 LogicalResult Operation::setPropertiesFromAttribute( 355 Attribute attr, function_ref<InFlightDiagnostic()> emitError) { 356 std::optional<RegisteredOperationName> info = getRegisteredInfo(); 357 if (LLVM_UNLIKELY(!info)) { 358 *getPropertiesStorage().as<Attribute *>() = attr; 359 return success(); 360 } 361 return info->setOpPropertiesFromAttribute( 362 this->getName(), this->getPropertiesStorage(), attr, emitError); 363 } 364 365 void Operation::copyProperties(OpaqueProperties rhs) { 366 name.copyOpProperties(getPropertiesStorage(), rhs); 367 } 368 369 llvm::hash_code Operation::hashProperties() { 370 return name.hashOpProperties(getPropertiesStorage()); 371 } 372 373 //===----------------------------------------------------------------------===// 374 // Operation Ordering 375 //===----------------------------------------------------------------------===// 376 377 constexpr unsigned Operation::kInvalidOrderIdx; 378 constexpr unsigned Operation::kOrderStride; 379 380 /// Given an operation 'other' that is 
/// within the same parent block, return
/// whether the current operation is before 'other' in the operation list
/// of the parent block.
/// Note: This function has an average complexity of O(1), but worst case may
/// take O(N) where N is the number of operations within the parent block.
bool Operation::isBeforeInBlock(Operation *other) {
  assert(block && "Operations without parent blocks have no order.");
  assert(other && other->block == block &&
         "Expected other operation to have the same parent block.");
  // If the order of the block is already invalid, directly recompute the
  // order of the whole parent block.
  if (!block->isOpOrderValid()) {
    block->recomputeOpOrder();
  } else {
    // Update the order of either operation if necessary.
    updateOrderIfNecessary();
    other->updateOrderIfNecessary();
  }

  // Both order indices are now valid and comparable.
  return orderIndex < other->orderIndex;
}

/// Update the order index of this operation if necessary, potentially
/// recomputing the order of the parent block. A new index is derived locally
/// from the neighbors' indices when possible; whenever there is no free index
/// available (or a neighbor's index is itself invalid), the whole block is
/// renumbered via `recomputeOpOrder()`.
void Operation::updateOrderIfNecessary() {
  assert(block && "expected valid parent");

  // If the order is valid for this operation there is nothing to do.
  if (hasValidOrder())
    return;
  Operation *blockFront = &block->front();
  Operation *blockBack = &block->back();

  // This method is expected to only be invoked on blocks with more than one
  // operation.
  assert(blockFront != blockBack && "expected more than one operation");

  // If the operation is at the end of the block.
  if (this == blockBack) {
    Operation *prevNode = getPrevNode();
    if (!prevNode->hasValidOrder())
      return block->recomputeOpOrder();

    // Add the stride to the previous operation.
    // NOTE(review): no guard against unsigned wrap-around of
    // `prevNode->orderIndex + kOrderStride`; presumably unreachable in
    // practice — confirm.
    orderIndex = prevNode->orderIndex + kOrderStride;
    return;
  }

  // If this is the first operation try to use the next operation to compute the
  // ordering.
  if (this == blockFront) {
    Operation *nextNode = getNextNode();
    if (!nextNode->hasValidOrder())
      return block->recomputeOpOrder();
    // There is no order to give this operation.
    if (nextNode->orderIndex == 0)
      return block->recomputeOpOrder();

    // If we can't use the stride, take half of the next operation's index.
    // This is safe because the next index is at least 1, so there is at least
    // one valid index (0) below it to assign to.
    if (nextNode->orderIndex <= kOrderStride)
      orderIndex = (nextNode->orderIndex / 2);
    else
      orderIndex = kOrderStride;
    return;
  }

  // Otherwise, this operation is between two others. Place this operation in
  // the middle of the previous and next if possible.
  Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
  if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
    return block->recomputeOpOrder();
  unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;

  // Check to see if there is a valid order between the two; adjacent indices
  // leave no room, so the block must be renumbered.
  if (prevOrder + 1 == nextOrder)
    return block->recomputeOpOrder();
  orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
}

//===----------------------------------------------------------------------===//
// ilist_traits for Operation
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the node <-> value pointer conversions for the
// intrusive list of Operations; each one forwards to `NodeAccess` with the
// list's computed node options.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(const_pointer n)
    -> const node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(const node_type *n)
    -> const_pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}

/// Invoked by the list when a node is deleted: destroy and free the
/// operation through `Operation::destroy()`.
void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
  op->destroy();
}

/// Recover the Block that owns this operation list. The byte offset of the
/// list member inside Block (obtained via `Block::getSublistAccess`) is
/// subtracted from the address of this list object.
/// NOTE(review): the offset is computed through a member access on a null
/// Block pointer (an offsetof-style idiom), which is technically undefined
/// behavior in standard C++.
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
  size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
  iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this));
  return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset);
}

/// This is a trait method invoked when an operation is added to a block. We
/// keep the block pointer up to date.
502 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 503 assert(!op->getBlock() && "already in an operation block!"); 504 op->block = getContainingBlock(); 505 506 // Invalidate the order on the operation. 507 op->orderIndex = Operation::kInvalidOrderIdx; 508 } 509 510 /// This is a trait method invoked when an operation is removed from a block. 511 /// We keep the block pointer up to date. 512 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 513 assert(op->block && "not already in an operation block!"); 514 op->block = nullptr; 515 } 516 517 /// This is a trait method invoked when an operation is moved from one block 518 /// to another. We keep the block pointer up to date. 519 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 520 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 521 Block *curParent = getContainingBlock(); 522 523 // Invalidate the ordering of the parent block. 524 curParent->invalidateOpOrder(); 525 526 // If we are transferring operations within the same block, the block 527 // pointer doesn't need to be updated. 528 if (curParent == otherList.getContainingBlock()) 529 return; 530 531 // Update the 'block' member of each operation. 532 for (; first != last; ++first) 533 first->block = curParent; 534 } 535 536 /// Remove this operation (and its descendants) from its Block and delete 537 /// all of them. 538 void Operation::erase() { 539 if (auto *parent = getBlock()) 540 parent->getOperations().erase(this); 541 else 542 destroy(); 543 } 544 545 /// Remove the operation from its parent block, but don't delete it. 546 void Operation::remove() { 547 if (Block *parent = getBlock()) 548 parent->getOperations().remove(this); 549 } 550 551 /// Unlink this operation from its current block and insert it right before 552 /// `existingOp` which may be in the same or another block in the same 553 /// function. 
554 void Operation::moveBefore(Operation *existingOp) { 555 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 556 } 557 558 /// Unlink this operation from its current basic block and insert it right 559 /// before `iterator` in the specified basic block. 560 void Operation::moveBefore(Block *block, 561 llvm::iplist<Operation>::iterator iterator) { 562 block->getOperations().splice(iterator, getBlock()->getOperations(), 563 getIterator()); 564 } 565 566 /// Unlink this operation from its current block and insert it right after 567 /// `existingOp` which may be in the same or another block in the same function. 568 void Operation::moveAfter(Operation *existingOp) { 569 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 570 } 571 572 /// Unlink this operation from its current block and insert it right after 573 /// `iterator` in the specified block. 574 void Operation::moveAfter(Block *block, 575 llvm::iplist<Operation>::iterator iterator) { 576 assert(iterator != block->end() && "cannot move after end of block"); 577 moveBefore(block, std::next(iterator)); 578 } 579 580 /// This drops all operand uses from this operation, which is an essential 581 /// step in breaking cyclic dependences between references when they are to 582 /// be deleted. 583 void Operation::dropAllReferences() { 584 for (auto &op : getOpOperands()) 585 op.drop(); 586 587 for (auto ®ion : getRegions()) 588 region.dropAllReferences(); 589 590 for (auto &dest : getBlockOperands()) 591 dest.drop(); 592 } 593 594 /// This drops all uses of any values defined by this operation or its nested 595 /// regions, wherever they are located. 
596 void Operation::dropAllDefinedValueUses() { 597 dropAllUses(); 598 599 for (auto ®ion : getRegions()) 600 for (auto &block : region) 601 block.dropAllDefinedValueUses(); 602 } 603 604 void Operation::setSuccessor(Block *block, unsigned index) { 605 assert(index < getNumSuccessors()); 606 getBlockOperands()[index].set(block); 607 } 608 609 #ifndef NDEBUG 610 /// Assert that the folded results (in case of values) have the same type as 611 /// the results of the given op. 612 static void checkFoldResultTypes(Operation *op, 613 SmallVectorImpl<OpFoldResult> &results) { 614 if (!results.empty()) 615 for (auto [ofr, opResult] : llvm::zip_equal(results, op->getResults())) 616 if (auto value = ofr.dyn_cast<Value>()) 617 assert(value.getType() == opResult.getType() && 618 "folder produced value of incorrect type"); 619 } 620 #endif // NDEBUG 621 622 /// Attempt to fold this operation using the Op's registered foldHook. 623 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 624 SmallVectorImpl<OpFoldResult> &results) { 625 // If we have a registered operation definition matching this one, use it to 626 // try to constant fold the operation. 627 if (succeeded(name.foldHook(this, operands, results))) { 628 #ifndef NDEBUG 629 checkFoldResultTypes(this, results); 630 #endif // NDEBUG 631 return success(); 632 } 633 634 // Otherwise, fall back on the dialect hook to handle it. 635 Dialect *dialect = getDialect(); 636 if (!dialect) 637 return failure(); 638 639 auto *interface = dyn_cast<DialectFoldInterface>(dialect); 640 if (!interface) 641 return failure(); 642 643 LogicalResult status = interface->fold(this, operands, results); 644 #ifndef NDEBUG 645 if (succeeded(status)) 646 checkFoldResultTypes(this, results); 647 #endif // NDEBUG 648 return status; 649 } 650 651 LogicalResult Operation::fold(SmallVectorImpl<OpFoldResult> &results) { 652 // Check if any operands are constants. 
653 SmallVector<Attribute> constants; 654 constants.assign(getNumOperands(), Attribute()); 655 for (unsigned i = 0, e = getNumOperands(); i != e; ++i) 656 matchPattern(getOperand(i), m_Constant(&constants[i])); 657 return fold(constants, results); 658 } 659 660 /// Emit an error with the op name prefixed, like "'dim' op " which is 661 /// convenient for verifiers. 662 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 663 return emitError() << "'" << getName() << "' op " << message; 664 } 665 666 //===----------------------------------------------------------------------===// 667 // Operation Cloning 668 //===----------------------------------------------------------------------===// 669 670 Operation::CloneOptions::CloneOptions() 671 : cloneRegionsFlag(false), cloneOperandsFlag(false) {} 672 673 Operation::CloneOptions::CloneOptions(bool cloneRegions, bool cloneOperands) 674 : cloneRegionsFlag(cloneRegions), cloneOperandsFlag(cloneOperands) {} 675 676 Operation::CloneOptions Operation::CloneOptions::all() { 677 return CloneOptions().cloneRegions().cloneOperands(); 678 } 679 680 Operation::CloneOptions &Operation::CloneOptions::cloneRegions(bool enable) { 681 cloneRegionsFlag = enable; 682 return *this; 683 } 684 685 Operation::CloneOptions &Operation::CloneOptions::cloneOperands(bool enable) { 686 cloneOperandsFlag = enable; 687 return *this; 688 } 689 690 /// Create a deep copy of this operation but keep the operation regions empty. 691 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 692 /// to contain the results. The `mapResults` flag specifies whether the results 693 /// of the cloned operation should be added to the map. 
694 Operation *Operation::cloneWithoutRegions(IRMapping &mapper) { 695 return clone(mapper, CloneOptions::all().cloneRegions(false)); 696 } 697 698 Operation *Operation::cloneWithoutRegions() { 699 IRMapping mapper; 700 return cloneWithoutRegions(mapper); 701 } 702 703 /// Create a deep copy of this operation, remapping any operands that use 704 /// values outside of the operation using the map that is provided (leaving 705 /// them alone if no entry is present). Replaces references to cloned 706 /// sub-operations to the corresponding operation that is copied, and adds 707 /// those mappings to the map. 708 Operation *Operation::clone(IRMapping &mapper, CloneOptions options) { 709 SmallVector<Value, 8> operands; 710 SmallVector<Block *, 2> successors; 711 712 // Remap the operands. 713 if (options.shouldCloneOperands()) { 714 operands.reserve(getNumOperands()); 715 for (auto opValue : getOperands()) 716 operands.push_back(mapper.lookupOrDefault(opValue)); 717 } 718 719 // Remap the successors. 720 successors.reserve(getNumSuccessors()); 721 for (Block *successor : getSuccessors()) 722 successors.push_back(mapper.lookupOrDefault(successor)); 723 724 // Create the new operation. 725 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 726 getPropertiesStorage(), successors, getNumRegions()); 727 mapper.map(this, newOp); 728 729 // Clone the regions. 730 if (options.shouldCloneRegions()) { 731 for (unsigned i = 0; i != numRegions; ++i) 732 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 733 } 734 735 // Remember the mapping of any results. 736 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 737 mapper.map(getResult(i), newOp->getResult(i)); 738 739 return newOp; 740 } 741 742 Operation *Operation::clone(CloneOptions options) { 743 IRMapping mapper; 744 return clone(mapper, options); 745 } 746 747 //===----------------------------------------------------------------------===// 748 // OpState trait class. 
749 //===----------------------------------------------------------------------===// 750 751 // The fallback for the parser is to try for a dialect operation parser. 752 // Otherwise, reject the custom assembly form. 753 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 754 if (auto parseFn = result.name.getDialect()->getParseOperationHook( 755 result.name.getStringRef())) 756 return (*parseFn)(parser, result); 757 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 758 } 759 760 // The fallback for the printer is to try for a dialect operation printer. 761 // Otherwise, it prints the generic form. 762 void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { 763 if (auto printFn = op->getDialect()->getOperationPrinter(op)) { 764 printOpName(op, p, defaultDialect); 765 printFn(op, p); 766 } else { 767 p.printGenericOp(op); 768 } 769 } 770 771 /// Print an operation name, eliding the dialect prefix if necessary and doesn't 772 /// lead to ambiguities. 773 void OpState::printOpName(Operation *op, OpAsmPrinter &p, 774 StringRef defaultDialect) { 775 StringRef name = op->getName().getStringRef(); 776 if (name.starts_with((defaultDialect + ".").str()) && name.count('.') == 1) 777 name = name.drop_front(defaultDialect.size() + 1); 778 p.getStream() << name; 779 } 780 781 /// Parse properties as a Attribute. 782 ParseResult OpState::genericParseProperties(OpAsmParser &parser, 783 Attribute &result) { 784 if (parser.parseLess() || parser.parseAttribute(result) || 785 parser.parseGreater()) 786 return failure(); 787 return success(); 788 } 789 790 /// Print the properties as a Attribute. 791 void OpState::genericPrintProperties(OpAsmPrinter &p, Attribute properties) { 792 p << "<" << properties << ">"; 793 } 794 795 /// Emit an error about fatal conditions with this operation, reporting up to 796 /// any diagnostic handlers that may be listening. 
797 InFlightDiagnostic OpState::emitError(const Twine &message) { 798 return getOperation()->emitError(message); 799 } 800 801 /// Emit an error with the op name prefixed, like "'dim' op " which is 802 /// convenient for verifiers. 803 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 804 return getOperation()->emitOpError(message); 805 } 806 807 /// Emit a warning about this operation, reporting up to any diagnostic 808 /// handlers that may be listening. 809 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 810 return getOperation()->emitWarning(message); 811 } 812 813 /// Emit a remark about this operation, reporting up to any diagnostic 814 /// handlers that may be listening. 815 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 816 return getOperation()->emitRemark(message); 817 } 818 819 //===----------------------------------------------------------------------===// 820 // Op Trait implementations 821 //===----------------------------------------------------------------------===// 822 823 LogicalResult 824 OpTrait::impl::foldCommutative(Operation *op, ArrayRef<Attribute> operands, 825 SmallVectorImpl<OpFoldResult> &results) { 826 // Nothing to fold if there are not at least 2 operands. 827 if (op->getNumOperands() < 2) 828 return failure(); 829 // Move all constant operands to the end. 830 OpOperand *operandsBegin = op->getOpOperands().begin(); 831 auto isNonConstant = [&](OpOperand &o) { 832 return !static_cast<bool>(operands[std::distance(operandsBegin, &o)]); 833 }; 834 auto *firstConstantIt = llvm::find_if_not(op->getOpOperands(), isNonConstant); 835 auto *newConstantIt = std::stable_partition( 836 firstConstantIt, op->getOpOperands().end(), isNonConstant); 837 // Return success if the op was modified. 
838 return success(firstConstantIt != newConstantIt); 839 } 840 841 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { 842 if (op->getNumOperands() == 1) { 843 auto *argumentOp = op->getOperand(0).getDefiningOp(); 844 if (argumentOp && op->getName() == argumentOp->getName()) { 845 // Replace the outer operation output with the inner operation. 846 return op->getOperand(0); 847 } 848 } else if (op->getOperand(0) == op->getOperand(1)) { 849 return op->getOperand(0); 850 } 851 852 return {}; 853 } 854 855 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { 856 auto *argumentOp = op->getOperand(0).getDefiningOp(); 857 if (argumentOp && op->getName() == argumentOp->getName()) { 858 // Replace the outer involutions output with inner's input. 859 return argumentOp->getOperand(0); 860 } 861 862 return {}; 863 } 864 865 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 866 if (op->getNumOperands() != 0) 867 return op->emitOpError() << "requires zero operands"; 868 return success(); 869 } 870 871 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 872 if (op->getNumOperands() != 1) 873 return op->emitOpError() << "requires a single operand"; 874 return success(); 875 } 876 877 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 878 unsigned numOperands) { 879 if (op->getNumOperands() != numOperands) { 880 return op->emitOpError() << "expected " << numOperands 881 << " operands, but found " << op->getNumOperands(); 882 } 883 return success(); 884 } 885 886 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 887 unsigned numOperands) { 888 if (op->getNumOperands() < numOperands) 889 return op->emitOpError() 890 << "expected " << numOperands << " or more operands, but found " 891 << op->getNumOperands(); 892 return success(); 893 } 894 895 /// If this is a vector type, or a tensor type, return the scalar element type 896 /// that it is built around, otherwise return the type unmodified. 
897 static Type getTensorOrVectorElementType(Type type) { 898 if (auto vec = llvm::dyn_cast<VectorType>(type)) 899 return vec.getElementType(); 900 901 // Look through tensor<vector<...>> to find the underlying element type. 902 if (auto tensor = llvm::dyn_cast<TensorType>(type)) 903 return getTensorOrVectorElementType(tensor.getElementType()); 904 return type; 905 } 906 907 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { 908 // FIXME: Add back check for no side effects on operation. 909 // Currently adding it would cause the shared library build 910 // to fail since there would be a dependency of IR on SideEffectInterfaces 911 // which is cyclical. 912 return success(); 913 } 914 915 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { 916 // FIXME: Add back check for no side effects on operation. 917 // Currently adding it would cause the shared library build 918 // to fail since there would be a dependency of IR on SideEffectInterfaces 919 // which is cyclical. 920 return success(); 921 } 922 923 LogicalResult 924 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 925 for (auto opType : op->getOperandTypes()) { 926 auto type = getTensorOrVectorElementType(opType); 927 if (!type.isSignlessIntOrIndex()) 928 return op->emitOpError() << "requires an integer or index type"; 929 } 930 return success(); 931 } 932 933 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 934 for (auto opType : op->getOperandTypes()) { 935 auto type = getTensorOrVectorElementType(opType); 936 if (!llvm::isa<FloatType>(type)) 937 return op->emitOpError("requires a float type"); 938 } 939 return success(); 940 } 941 942 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 943 // Zero or one operand always have the "same" type. 
944 unsigned nOperands = op->getNumOperands(); 945 if (nOperands < 2) 946 return success(); 947 948 auto type = op->getOperand(0).getType(); 949 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 950 if (opType != type) 951 return op->emitOpError() << "requires all operands to have the same type"; 952 return success(); 953 } 954 955 LogicalResult OpTrait::impl::verifyZeroRegions(Operation *op) { 956 if (op->getNumRegions() != 0) 957 return op->emitOpError() << "requires zero regions"; 958 return success(); 959 } 960 961 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 962 if (op->getNumRegions() != 1) 963 return op->emitOpError() << "requires one region"; 964 return success(); 965 } 966 967 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 968 unsigned numRegions) { 969 if (op->getNumRegions() != numRegions) 970 return op->emitOpError() << "expected " << numRegions << " regions"; 971 return success(); 972 } 973 974 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 975 unsigned numRegions) { 976 if (op->getNumRegions() < numRegions) 977 return op->emitOpError() << "expected " << numRegions << " or more regions"; 978 return success(); 979 } 980 981 LogicalResult OpTrait::impl::verifyZeroResults(Operation *op) { 982 if (op->getNumResults() != 0) 983 return op->emitOpError() << "requires zero results"; 984 return success(); 985 } 986 987 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 988 if (op->getNumResults() != 1) 989 return op->emitOpError() << "requires one result"; 990 return success(); 991 } 992 993 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 994 unsigned numOperands) { 995 if (op->getNumResults() != numOperands) 996 return op->emitOpError() << "expected " << numOperands << " results"; 997 return success(); 998 } 999 1000 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 1001 unsigned numOperands) { 1002 if (op->getNumResults() < numOperands) 1003 return 
op->emitOpError() 1004 << "expected " << numOperands << " or more results"; 1005 return success(); 1006 } 1007 1008 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 1009 if (failed(verifyAtLeastNOperands(op, 1))) 1010 return failure(); 1011 1012 if (failed(verifyCompatibleShapes(op->getOperandTypes()))) 1013 return op->emitOpError() << "requires the same shape for all operands"; 1014 1015 return success(); 1016 } 1017 1018 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 1019 if (failed(verifyAtLeastNOperands(op, 1)) || 1020 failed(verifyAtLeastNResults(op, 1))) 1021 return failure(); 1022 1023 SmallVector<Type, 8> types(op->getOperandTypes()); 1024 types.append(llvm::to_vector<4>(op->getResultTypes())); 1025 1026 if (failed(verifyCompatibleShapes(types))) 1027 return op->emitOpError() 1028 << "requires the same shape for all operands and results"; 1029 1030 return success(); 1031 } 1032 1033 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 1034 if (failed(verifyAtLeastNOperands(op, 1))) 1035 return failure(); 1036 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 1037 1038 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 1039 if (getElementTypeOrSelf(operand) != elementType) 1040 return op->emitOpError("requires the same element type for all operands"); 1041 } 1042 1043 return success(); 1044 } 1045 1046 LogicalResult 1047 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 1048 if (failed(verifyAtLeastNOperands(op, 1)) || 1049 failed(verifyAtLeastNResults(op, 1))) 1050 return failure(); 1051 1052 auto elementType = getElementTypeOrSelf(op->getResult(0)); 1053 1054 // Verify result element type matches first result's element type. 
1055 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 1056 if (getElementTypeOrSelf(result) != elementType) 1057 return op->emitOpError( 1058 "requires the same element type for all operands and results"); 1059 } 1060 1061 // Verify operand's element type matches first result's element type. 1062 for (auto operand : op->getOperands()) { 1063 if (getElementTypeOrSelf(operand) != elementType) 1064 return op->emitOpError( 1065 "requires the same element type for all operands and results"); 1066 } 1067 1068 return success(); 1069 } 1070 1071 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 1072 if (failed(verifyAtLeastNOperands(op, 1)) || 1073 failed(verifyAtLeastNResults(op, 1))) 1074 return failure(); 1075 1076 auto type = op->getResult(0).getType(); 1077 auto elementType = getElementTypeOrSelf(type); 1078 Attribute encoding = nullptr; 1079 if (auto rankedType = dyn_cast<RankedTensorType>(type)) 1080 encoding = rankedType.getEncoding(); 1081 for (auto resultType : llvm::drop_begin(op->getResultTypes())) { 1082 if (getElementTypeOrSelf(resultType) != elementType || 1083 failed(verifyCompatibleShape(resultType, type))) 1084 return op->emitOpError() 1085 << "requires the same type for all operands and results"; 1086 if (encoding) 1087 if (auto rankedType = dyn_cast<RankedTensorType>(resultType); 1088 encoding != rankedType.getEncoding()) 1089 return op->emitOpError() 1090 << "requires the same encoding for all operands and results"; 1091 } 1092 for (auto opType : op->getOperandTypes()) { 1093 if (getElementTypeOrSelf(opType) != elementType || 1094 failed(verifyCompatibleShape(opType, type))) 1095 return op->emitOpError() 1096 << "requires the same type for all operands and results"; 1097 if (encoding) 1098 if (auto rankedType = dyn_cast<RankedTensorType>(opType); 1099 encoding != rankedType.getEncoding()) 1100 return op->emitOpError() 1101 << "requires the same encoding for all operands and results"; 1102 } 1103 return success(); 
1104 } 1105 1106 LogicalResult OpTrait::impl::verifySameOperandsAndResultRank(Operation *op) { 1107 if (failed(verifyAtLeastNOperands(op, 1))) 1108 return failure(); 1109 1110 // delegate function that returns true if type is a shaped type with known 1111 // rank 1112 auto hasRank = [](const Type type) { 1113 if (auto shaped_type = dyn_cast<ShapedType>(type)) 1114 return shaped_type.hasRank(); 1115 1116 return false; 1117 }; 1118 1119 auto rankedOperandTypes = 1120 llvm::make_filter_range(op->getOperandTypes(), hasRank); 1121 auto rankedResultTypes = 1122 llvm::make_filter_range(op->getResultTypes(), hasRank); 1123 1124 // If all operands and results are unranked, then no further verification. 1125 if (rankedOperandTypes.empty() && rankedResultTypes.empty()) 1126 return success(); 1127 1128 // delegate function that returns rank of shaped type with known rank 1129 auto getRank = [](const Type type) { 1130 return type.cast<ShapedType>().getRank(); 1131 }; 1132 1133 auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin()) 1134 : getRank(*rankedResultTypes.begin()); 1135 1136 for (const auto type : rankedOperandTypes) { 1137 if (rank != getRank(type)) { 1138 return op->emitOpError("operands don't have matching ranks"); 1139 } 1140 } 1141 1142 for (const auto type : rankedResultTypes) { 1143 if (rank != getRank(type)) { 1144 return op->emitOpError("result type has different rank than operands"); 1145 } 1146 } 1147 1148 return success(); 1149 } 1150 1151 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 1152 Block *block = op->getBlock(); 1153 // Verify that the operation is at the end of the respective parent block. 
1154 if (!block || &block->back() != op) 1155 return op->emitOpError("must be the last operation in the parent block"); 1156 return success(); 1157 } 1158 1159 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 1160 auto *parent = op->getParentRegion(); 1161 1162 // Verify that the operands lines up with the BB arguments in the successor. 1163 for (Block *succ : op->getSuccessors()) 1164 if (succ->getParent() != parent) 1165 return op->emitError("reference to block defined in another region"); 1166 return success(); 1167 } 1168 1169 LogicalResult OpTrait::impl::verifyZeroSuccessors(Operation *op) { 1170 if (op->getNumSuccessors() != 0) { 1171 return op->emitOpError("requires 0 successors but found ") 1172 << op->getNumSuccessors(); 1173 } 1174 return success(); 1175 } 1176 1177 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 1178 if (op->getNumSuccessors() != 1) { 1179 return op->emitOpError("requires 1 successor but found ") 1180 << op->getNumSuccessors(); 1181 } 1182 return verifyTerminatorSuccessors(op); 1183 } 1184 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 1185 unsigned numSuccessors) { 1186 if (op->getNumSuccessors() != numSuccessors) { 1187 return op->emitOpError("requires ") 1188 << numSuccessors << " successors but found " 1189 << op->getNumSuccessors(); 1190 } 1191 return verifyTerminatorSuccessors(op); 1192 } 1193 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 1194 unsigned numSuccessors) { 1195 if (op->getNumSuccessors() < numSuccessors) { 1196 return op->emitOpError("requires at least ") 1197 << numSuccessors << " successors but found " 1198 << op->getNumSuccessors(); 1199 } 1200 return verifyTerminatorSuccessors(op); 1201 } 1202 1203 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 1204 for (auto resultType : op->getResultTypes()) { 1205 auto elementType = getTensorOrVectorElementType(resultType); 1206 bool isBoolType = elementType.isInteger(1); 1207 if 
(!isBoolType) 1208 return op->emitOpError() << "requires a bool result type"; 1209 } 1210 1211 return success(); 1212 } 1213 1214 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 1215 for (auto resultType : op->getResultTypes()) 1216 if (!llvm::isa<FloatType>(getTensorOrVectorElementType(resultType))) 1217 return op->emitOpError() << "requires a floating point type"; 1218 1219 return success(); 1220 } 1221 1222 LogicalResult 1223 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 1224 for (auto resultType : op->getResultTypes()) 1225 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 1226 return op->emitOpError() << "requires an integer or index type"; 1227 return success(); 1228 } 1229 1230 LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, 1231 StringRef attrName, 1232 StringRef valueGroupName, 1233 size_t expectedCount) { 1234 auto sizeAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrName); 1235 if (!sizeAttr) 1236 return op->emitOpError("requires dense i32 array attribute '") 1237 << attrName << "'"; 1238 1239 ArrayRef<int32_t> sizes = sizeAttr.asArrayRef(); 1240 if (llvm::any_of(sizes, [](int32_t element) { return element < 0; })) 1241 return op->emitOpError("'") 1242 << attrName << "' attribute cannot have negative elements"; 1243 1244 size_t totalCount = 1245 std::accumulate(sizes.begin(), sizes.end(), 0, 1246 [](unsigned all, int32_t one) { return all + one; }); 1247 1248 if (totalCount != expectedCount) 1249 return op->emitOpError() 1250 << valueGroupName << " count (" << expectedCount 1251 << ") does not match with the total size (" << totalCount 1252 << ") specified in attribute '" << attrName << "'"; 1253 return success(); 1254 } 1255 1256 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 1257 StringRef attrName) { 1258 return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); 1259 } 1260 1261 LogicalResult 
OpTrait::impl::verifyResultSizeAttr(Operation *op, 1262 StringRef attrName) { 1263 return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); 1264 } 1265 1266 LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1267 for (Region ®ion : op->getRegions()) { 1268 if (region.empty()) 1269 continue; 1270 1271 if (region.getNumArguments() != 0) { 1272 if (op->getNumRegions() > 1) 1273 return op->emitOpError("region #") 1274 << region.getRegionNumber() << " should have no arguments"; 1275 return op->emitOpError("region should have no arguments"); 1276 } 1277 } 1278 return success(); 1279 } 1280 1281 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { 1282 auto isMappableType = [](Type type) { 1283 return llvm::isa<VectorType, TensorType>(type); 1284 }; 1285 auto resultMappableTypes = llvm::to_vector<1>( 1286 llvm::make_filter_range(op->getResultTypes(), isMappableType)); 1287 auto operandMappableTypes = llvm::to_vector<2>( 1288 llvm::make_filter_range(op->getOperandTypes(), isMappableType)); 1289 1290 // If the op only has scalar operand/result types, then we have nothing to 1291 // check. 
1292 if (resultMappableTypes.empty() && operandMappableTypes.empty()) 1293 return success(); 1294 1295 if (!resultMappableTypes.empty() && operandMappableTypes.empty()) 1296 return op->emitOpError("if a result is non-scalar, then at least one " 1297 "operand must be non-scalar"); 1298 1299 assert(!operandMappableTypes.empty()); 1300 1301 if (resultMappableTypes.empty()) 1302 return op->emitOpError("if an operand is non-scalar, then there must be at " 1303 "least one non-scalar result"); 1304 1305 if (resultMappableTypes.size() != op->getNumResults()) 1306 return op->emitOpError( 1307 "if an operand is non-scalar, then all results must be non-scalar"); 1308 1309 SmallVector<Type, 4> types = llvm::to_vector<2>( 1310 llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); 1311 TypeID expectedBaseTy = types.front().getTypeID(); 1312 if (!llvm::all_of(types, 1313 [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || 1314 failed(verifyCompatibleShapes(types))) { 1315 return op->emitOpError() << "all non-scalar operands/results must have the " 1316 "same shape and base type"; 1317 } 1318 1319 return success(); 1320 } 1321 1322 /// Check for any values used by operations regions attached to the 1323 /// specified "IsIsolatedFromAbove" operation defined outside of it. 1324 LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { 1325 assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1326 "Intended to check IsolatedFromAbove ops"); 1327 1328 // List of regions to analyze. Each region is processed independently, with 1329 // respect to the common `limit` region, so we can look at them in any order. 1330 // Therefore, use a simple vector and push/pop back the current region. 1331 SmallVector<Region *, 8> pendingRegions; 1332 for (auto ®ion : isolatedOp->getRegions()) { 1333 pendingRegions.push_back(®ion); 1334 1335 // Traverse all operations in the region. 
1336 while (!pendingRegions.empty()) { 1337 for (Operation &op : pendingRegions.pop_back_val()->getOps()) { 1338 for (Value operand : op.getOperands()) { 1339 // Check that any value that is used by an operation is defined in the 1340 // same region as either an operation result. 1341 auto *operandRegion = operand.getParentRegion(); 1342 if (!operandRegion) 1343 return op.emitError("operation's operand is unlinked"); 1344 if (!region.isAncestor(operandRegion)) { 1345 return op.emitOpError("using value defined outside the region") 1346 .attachNote(isolatedOp->getLoc()) 1347 << "required by region isolation constraints"; 1348 } 1349 } 1350 1351 // Schedule any regions in the operation for further checking. Don't 1352 // recurse into other IsolatedFromAbove ops, because they will check 1353 // themselves. 1354 if (op.getNumRegions() && 1355 !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { 1356 for (Region &subRegion : op.getRegions()) 1357 pendingRegions.push_back(&subRegion); 1358 } 1359 } 1360 } 1361 } 1362 1363 return success(); 1364 } 1365 1366 bool OpTrait::hasElementwiseMappableTraits(Operation *op) { 1367 return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && 1368 op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); 1369 } 1370 1371 //===----------------------------------------------------------------------===// 1372 // Misc. utils 1373 //===----------------------------------------------------------------------===// 1374 1375 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1376 /// region's only block if it does not have a terminator already. If the region 1377 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1378 /// terminator operation to insert. 
1379 void impl::ensureRegionTerminator( 1380 Region ®ion, OpBuilder &builder, Location loc, 1381 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1382 OpBuilder::InsertionGuard guard(builder); 1383 if (region.empty()) 1384 builder.createBlock(®ion); 1385 1386 Block &block = region.back(); 1387 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) 1388 return; 1389 1390 builder.setInsertionPointToEnd(&block); 1391 builder.insert(buildTerminatorOp(builder, loc)); 1392 } 1393 1394 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1395 /// function. 1396 void impl::ensureRegionTerminator( 1397 Region ®ion, Builder &builder, Location loc, 1398 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1399 OpBuilder opBuilder(builder.getContext()); 1400 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1401 } 1402