1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/Attributes.h" 11 #include "mlir/IR/BuiltinAttributes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/IRMapping.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/IR/OpImplementation.h" 17 #include "mlir/IR/OperationSupport.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Interfaces/FoldInterfaces.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/Support/ErrorHandling.h" 24 #include <numeric> 25 #include <optional> 26 27 using namespace mlir; 28 29 //===----------------------------------------------------------------------===// 30 // Operation 31 //===----------------------------------------------------------------------===// 32 33 /// Create a new Operation from operation state. 34 Operation *Operation::create(const OperationState &state) { 35 Operation *op = 36 create(state.location, state.name, state.types, state.operands, 37 state.attributes.getDictionary(state.getContext()), 38 state.properties, state.successors, state.regions); 39 if (LLVM_UNLIKELY(state.propertiesAttr)) { 40 assert(!state.properties); 41 LogicalResult result = 42 op->setPropertiesFromAttribute(state.propertiesAttr, 43 /*diagnostic=*/nullptr); 44 assert(result.succeeded() && "invalid properties in op creation"); 45 (void)result; 46 } 47 return op; 48 } 49 50 /// Create a new Operation with the specific fields. 51 Operation *Operation::create(Location location, OperationName name, 52 TypeRange resultTypes, ValueRange operands, 53 NamedAttrList &&attributes, 54 OpaqueProperties properties, BlockRange successors, 55 RegionRange regions) { 56 unsigned numRegions = regions.size(); 57 Operation *op = 58 create(location, name, resultTypes, operands, std::move(attributes), 59 properties, successors, numRegions); 60 for (unsigned i = 0; i < numRegions; ++i) 61 if (regions[i]) 62 op->getRegion(i).takeBody(*regions[i]); 63 return op; 64 } 65 66 /// Create a new Operation with the specific fields. 67 Operation *Operation::create(Location location, OperationName name, 68 TypeRange resultTypes, ValueRange operands, 69 NamedAttrList &&attributes, 70 OpaqueProperties properties, BlockRange successors, 71 unsigned numRegions) { 72 // Populate default attributes. 73 name.populateDefaultAttrs(attributes); 74 75 return create(location, name, resultTypes, operands, 76 attributes.getDictionary(location.getContext()), properties, 77 successors, numRegions); 78 } 79 80 /// Overload of create that takes an existing DictionaryAttr to avoid 81 /// unnecessarily uniquing a list of attributes. 82 Operation *Operation::create(Location location, OperationName name, 83 TypeRange resultTypes, ValueRange operands, 84 DictionaryAttr attributes, 85 OpaqueProperties properties, BlockRange successors, 86 unsigned numRegions) { 87 assert(llvm::all_of(resultTypes, [](Type t) { return t; }) && 88 "unexpected null result type"); 89 90 // We only need to allocate additional memory for a subset of results. 91 unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); 92 unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); 93 unsigned numSuccessors = successors.size(); 94 unsigned numOperands = operands.size(); 95 unsigned numResults = resultTypes.size(); 96 int opPropertiesAllocSize = llvm::alignTo<8>(name.getOpPropertyByteSize()); 97 98 // If the operation is known to have no operands, don't allocate an operand 99 // storage. 100 bool needsOperandStorage = 101 operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true; 102 103 // Compute the byte size for the operation and the operand storage. This takes 104 // into account the size of the operation, its trailing objects, and its 105 // prefixed objects. 106 size_t byteSize = 107 totalSizeToAlloc<detail::OperandStorage, detail::OpProperties, 108 BlockOperand, Region, OpOperand>( 109 needsOperandStorage ? 1 : 0, opPropertiesAllocSize, numSuccessors, 110 numRegions, numOperands); 111 size_t prefixByteSize = llvm::alignTo( 112 Operation::prefixAllocSize(numTrailingResults, numInlineResults), 113 alignof(Operation)); 114 char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); 115 void *rawMem = mallocMem + prefixByteSize; 116 117 // Create the new Operation. 118 Operation *op = ::new (rawMem) Operation( 119 location, name, numResults, numSuccessors, numRegions, 120 opPropertiesAllocSize, attributes, properties, needsOperandStorage); 121 122 assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && 123 "unexpected successors in a non-terminator operation"); 124 125 // Initialize the results. 126 auto resultTypeIt = resultTypes.begin(); 127 for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) 128 new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); 129 for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { 130 new (op->getOutOfLineOpResult(i)) 131 detail::OutOfLineOpResult(*resultTypeIt, i); 132 } 133 134 // Initialize the regions. 135 for (unsigned i = 0; i != numRegions; ++i) 136 new (&op->getRegion(i)) Region(op); 137 138 // Initialize the operands. 139 if (needsOperandStorage) { 140 new (&op->getOperandStorage()) detail::OperandStorage( 141 op, op->getTrailingObjects<OpOperand>(), operands); 142 } 143 144 // Initialize the successors. 145 auto blockOperands = op->getBlockOperands(); 146 for (unsigned i = 0; i != numSuccessors; ++i) 147 new (&blockOperands[i]) BlockOperand(op, successors[i]); 148 149 // This must be done after properties are initalized. 150 op->setAttrs(attributes); 151 152 return op; 153 } 154 155 Operation::Operation(Location location, OperationName name, unsigned numResults, 156 unsigned numSuccessors, unsigned numRegions, 157 int fullPropertiesStorageSize, DictionaryAttr attributes, 158 OpaqueProperties properties, bool hasOperandStorage) 159 : location(location), numResults(numResults), numSuccs(numSuccessors), 160 numRegions(numRegions), hasOperandStorage(hasOperandStorage), 161 propertiesStorageSize((fullPropertiesStorageSize + 7) / 8), name(name) { 162 assert(attributes && "unexpected null attribute dictionary"); 163 assert(fullPropertiesStorageSize <= propertiesCapacity && 164 "Properties size overflow"); 165 #ifndef NDEBUG 166 if (!getDialect() && !getContext()->allowsUnregisteredDialects()) 167 llvm::report_fatal_error( 168 name.getStringRef() + 169 " created with unregistered dialect. If this is intended, please call " 170 "allowUnregisteredDialects() on the MLIRContext, or use " 171 "-allow-unregistered-dialect with the MLIR tool used."); 172 #endif 173 if (fullPropertiesStorageSize) 174 name.initOpProperties(getPropertiesStorage(), properties); 175 } 176 177 // Operations are deleted through the destroy() member because they are 178 // allocated via malloc. 179 Operation::~Operation() { 180 assert(block == nullptr && "operation destroyed but still in a block"); 181 #ifndef NDEBUG 182 if (!use_empty()) { 183 { 184 InFlightDiagnostic diag = 185 emitOpError("operation destroyed but still has uses"); 186 for (Operation *user : getUsers()) 187 diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; 188 } 189 llvm::report_fatal_error("operation destroyed but still has uses"); 190 } 191 #endif 192 // Explicitly run the destructors for the operands. 193 if (hasOperandStorage) 194 getOperandStorage().~OperandStorage(); 195 196 // Explicitly run the destructors for the successors. 197 for (auto &successor : getBlockOperands()) 198 successor.~BlockOperand(); 199 200 // Explicitly destroy the regions. 201 for (auto ®ion : getRegions()) 202 region.~Region(); 203 if (propertiesStorageSize) 204 name.destroyOpProperties(getPropertiesStorage()); 205 } 206 207 /// Destroy this operation or one of its subclasses. 208 void Operation::destroy() { 209 // Operations may have additional prefixed allocation, which needs to be 210 // accounted for here when computing the address to free. 211 char *rawMem = reinterpret_cast<char *>(this) - 212 llvm::alignTo(prefixAllocSize(), alignof(Operation)); 213 this->~Operation(); 214 free(rawMem); 215 } 216 217 /// Return true if this operation is a proper ancestor of the `other` 218 /// operation. 219 bool Operation::isProperAncestor(Operation *other) { 220 while ((other = other->getParentOp())) 221 if (this == other) 222 return true; 223 return false; 224 } 225 226 /// Replace any uses of 'from' with 'to' within this operation. 227 void Operation::replaceUsesOfWith(Value from, Value to) { 228 if (from == to) 229 return; 230 for (auto &operand : getOpOperands()) 231 if (operand.get() == from) 232 operand.set(to); 233 } 234 235 /// Replace the current operands of this operation with the ones provided in 236 /// 'operands'. 237 void Operation::setOperands(ValueRange operands) { 238 if (LLVM_LIKELY(hasOperandStorage)) 239 return getOperandStorage().setOperands(this, operands); 240 assert(operands.empty() && "setting operands without an operand storage"); 241 } 242 243 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 244 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 245 /// than the range pointed to by 'start'+'length'. 246 void Operation::setOperands(unsigned start, unsigned length, 247 ValueRange operands) { 248 assert((start + length) <= getNumOperands() && 249 "invalid operand range specified"); 250 if (LLVM_LIKELY(hasOperandStorage)) 251 return getOperandStorage().setOperands(this, start, length, operands); 252 assert(operands.empty() && "setting operands without an operand storage"); 253 } 254 255 /// Insert the given operands into the operand list at the given 'index'. 256 void Operation::insertOperands(unsigned index, ValueRange operands) { 257 if (LLVM_LIKELY(hasOperandStorage)) 258 return setOperands(index, /*length=*/0, operands); 259 assert(operands.empty() && "inserting operands without an operand storage"); 260 } 261 262 //===----------------------------------------------------------------------===// 263 // Diagnostics 264 //===----------------------------------------------------------------------===// 265 266 /// Emit an error about fatal conditions with this operation, reporting up to 267 /// any diagnostic handlers that may be listening. 268 InFlightDiagnostic Operation::emitError(const Twine &message) { 269 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 270 if (getContext()->shouldPrintOpOnDiagnostic()) { 271 diag.attachNote(getLoc()) 272 .append("see current operation: ") 273 .appendOp(*this, OpPrintingFlags().printGenericOpForm()); 274 } 275 return diag; 276 } 277 278 /// Emit a warning about this operation, reporting up to any diagnostic 279 /// handlers that may be listening. 280 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 281 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 282 if (getContext()->shouldPrintOpOnDiagnostic()) 283 diag.attachNote(getLoc()) << "see current operation: " << *this; 284 return diag; 285 } 286 287 /// Emit a remark about this operation, reporting up to any diagnostic 288 /// handlers that may be listening. 289 InFlightDiagnostic Operation::emitRemark(const Twine &message) { 290 InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message); 291 if (getContext()->shouldPrintOpOnDiagnostic()) 292 diag.attachNote(getLoc()) << "see current operation: " << *this; 293 return diag; 294 } 295 296 DictionaryAttr Operation::getAttrDictionary() { 297 if (getPropertiesStorageSize()) { 298 NamedAttrList attrsList = attrs; 299 getName().populateInherentAttrs(this, attrsList); 300 return attrsList.getDictionary(getContext()); 301 } 302 return attrs; 303 } 304 305 void Operation::setAttrs(DictionaryAttr newAttrs) { 306 assert(newAttrs && "expected valid attribute dictionary"); 307 if (getPropertiesStorageSize()) { 308 // We're spliting the providing DictionaryAttr by removing the inherentAttr 309 // which will be stored in the properties. 310 SmallVector<NamedAttribute> discardableAttrs; 311 discardableAttrs.reserve(newAttrs.size()); 312 for (NamedAttribute attr : newAttrs) { 313 if (getInherentAttr(attr.getName())) 314 setInherentAttr(attr.getName(), attr.getValue()); 315 else 316 discardableAttrs.push_back(attr); 317 } 318 if (discardableAttrs.size() != newAttrs.size()) 319 newAttrs = DictionaryAttr::get(getContext(), discardableAttrs); 320 } 321 attrs = newAttrs; 322 } 323 void Operation::setAttrs(ArrayRef<NamedAttribute> newAttrs) { 324 if (getPropertiesStorageSize()) { 325 // We're spliting the providing array of attributes by removing the inherentAttr 326 // which will be stored in the properties. 327 SmallVector<NamedAttribute> discardableAttrs; 328 discardableAttrs.reserve(newAttrs.size()); 329 for (NamedAttribute attr : newAttrs) { 330 if (getInherentAttr(attr.getName())) 331 setInherentAttr(attr.getName(), attr.getValue()); 332 else 333 discardableAttrs.push_back(attr); 334 } 335 attrs = DictionaryAttr::get(getContext(), discardableAttrs); 336 return; 337 } 338 attrs = DictionaryAttr::get(getContext(), newAttrs); 339 } 340 341 std::optional<Attribute> Operation::getInherentAttr(StringRef name) { 342 return getName().getInherentAttr(this, name); 343 } 344 345 void Operation::setInherentAttr(StringAttr name, Attribute value) { 346 getName().setInherentAttr(this, name, value); 347 } 348 349 Attribute Operation::getPropertiesAsAttribute() { 350 std::optional<RegisteredOperationName> info = getRegisteredInfo(); 351 if (LLVM_UNLIKELY(!info)) 352 return *getPropertiesStorage().as<Attribute *>(); 353 return info->getOpPropertiesAsAttribute(this); 354 } 355 LogicalResult Operation::setPropertiesFromAttribute( 356 Attribute attr, function_ref<InFlightDiagnostic()> emitError) { 357 std::optional<RegisteredOperationName> info = getRegisteredInfo(); 358 if (LLVM_UNLIKELY(!info)) { 359 *getPropertiesStorage().as<Attribute *>() = attr; 360 return success(); 361 } 362 return info->setOpPropertiesFromAttribute( 363 this->getName(), this->getPropertiesStorage(), attr, emitError); 364 } 365 366 void Operation::copyProperties(OpaqueProperties rhs) { 367 name.copyOpProperties(getPropertiesStorage(), rhs); 368 } 369 370 llvm::hash_code Operation::hashProperties() { 371 return name.hashOpProperties(getPropertiesStorage()); 372 } 373 374 //===----------------------------------------------------------------------===// 375 // Operation Ordering 376 //===----------------------------------------------------------------------===// 377 378 constexpr unsigned Operation::kInvalidOrderIdx; 379 constexpr unsigned Operation::kOrderStride; 380 381 /// Given an operation 'other' that is within the same parent block, return 382 /// whether the current operation is before 'other' in the operation list 383 /// of the parent block. 384 /// Note: This function has an average complexity of O(1), but worst case may 385 /// take O(N) where N is the number of operations within the parent block. 386 bool Operation::isBeforeInBlock(Operation *other) { 387 assert(block && "Operations without parent blocks have no order."); 388 assert(other && other->block == block && 389 "Expected other operation to have the same parent block."); 390 // If the order of the block is already invalid, directly recompute the 391 // parent. 392 if (!block->isOpOrderValid()) { 393 block->recomputeOpOrder(); 394 } else { 395 // Update the order either operation if necessary. 396 updateOrderIfNecessary(); 397 other->updateOrderIfNecessary(); 398 } 399 400 return orderIndex < other->orderIndex; 401 } 402 403 /// Update the order index of this operation of this operation if necessary, 404 /// potentially recomputing the order of the parent block. 405 void Operation::updateOrderIfNecessary() { 406 assert(block && "expected valid parent"); 407 408 // If the order is valid for this operation there is nothing to do. 409 if (hasValidOrder() || llvm::hasSingleElement(*block)) 410 return; 411 Operation *blockFront = &block->front(); 412 Operation *blockBack = &block->back(); 413 414 // This method is expected to only be invoked on blocks with more than one 415 // operation. 416 assert(blockFront != blockBack && "expected more than one operation"); 417 418 // If the operation is at the end of the block. 419 if (this == blockBack) { 420 Operation *prevNode = getPrevNode(); 421 if (!prevNode->hasValidOrder()) 422 return block->recomputeOpOrder(); 423 424 // Add the stride to the previous operation. 425 orderIndex = prevNode->orderIndex + kOrderStride; 426 return; 427 } 428 429 // If this is the first operation try to use the next operation to compute the 430 // ordering. 431 if (this == blockFront) { 432 Operation *nextNode = getNextNode(); 433 if (!nextNode->hasValidOrder()) 434 return block->recomputeOpOrder(); 435 // There is no order to give this operation. 436 if (nextNode->orderIndex == 0) 437 return block->recomputeOpOrder(); 438 439 // If we can't use the stride, just take the middle value left. This is safe 440 // because we know there is at least one valid index to assign to. 441 if (nextNode->orderIndex <= kOrderStride) 442 orderIndex = (nextNode->orderIndex / 2); 443 else 444 orderIndex = kOrderStride; 445 return; 446 } 447 448 // Otherwise, this operation is between two others. Place this operation in 449 // the middle of the previous and next if possible. 450 Operation *prevNode = getPrevNode(), *nextNode = getNextNode(); 451 if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder()) 452 return block->recomputeOpOrder(); 453 unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex; 454 455 // Check to see if there is a valid order between the two. 456 if (prevOrder + 1 == nextOrder) 457 return block->recomputeOpOrder(); 458 orderIndex = prevOrder + ((nextOrder - prevOrder) / 2); 459 } 460 461 //===----------------------------------------------------------------------===// 462 // ilist_traits for Operation 463 //===----------------------------------------------------------------------===// 464 465 auto llvm::ilist_detail::SpecificNodeAccess< 466 typename llvm::ilist_detail::compute_node_options< 467 ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * { 468 return NodeAccess::getNodePtr<OptionsT>(n); 469 } 470 471 auto llvm::ilist_detail::SpecificNodeAccess< 472 typename llvm::ilist_detail::compute_node_options< 473 ::mlir::Operation>::type>::getNodePtr(const_pointer n) 474 -> const node_type * { 475 return NodeAccess::getNodePtr<OptionsT>(n); 476 } 477 478 auto llvm::ilist_detail::SpecificNodeAccess< 479 typename llvm::ilist_detail::compute_node_options< 480 ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer { 481 return NodeAccess::getValuePtr<OptionsT>(n); 482 } 483 484 auto llvm::ilist_detail::SpecificNodeAccess< 485 typename llvm::ilist_detail::compute_node_options< 486 ::mlir::Operation>::type>::getValuePtr(const node_type *n) 487 -> const_pointer { 488 return NodeAccess::getValuePtr<OptionsT>(n); 489 } 490 491 void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) { 492 op->destroy(); 493 } 494 495 Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() { 496 size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); 497 iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this)); 498 return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset); 499 } 500 501 /// This is a trait method invoked when an operation is added to a block. We 502 /// keep the block pointer up to date. 503 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 504 assert(!op->getBlock() && "already in an operation block!"); 505 op->block = getContainingBlock(); 506 507 // Invalidate the order on the operation. 508 op->orderIndex = Operation::kInvalidOrderIdx; 509 } 510 511 /// This is a trait method invoked when an operation is removed from a block. 512 /// We keep the block pointer up to date. 513 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 514 assert(op->block && "not already in an operation block!"); 515 op->block = nullptr; 516 } 517 518 /// This is a trait method invoked when an operation is moved from one block 519 /// to another. We keep the block pointer up to date. 520 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 521 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 522 Block *curParent = getContainingBlock(); 523 524 // Invalidate the ordering of the parent block. 525 curParent->invalidateOpOrder(); 526 527 // If we are transferring operations within the same block, the block 528 // pointer doesn't need to be updated. 529 if (curParent == otherList.getContainingBlock()) 530 return; 531 532 // Update the 'block' member of each operation. 533 for (; first != last; ++first) 534 first->block = curParent; 535 } 536 537 /// Remove this operation (and its descendants) from its Block and delete 538 /// all of them. 539 void Operation::erase() { 540 if (auto *parent = getBlock()) 541 parent->getOperations().erase(this); 542 else 543 destroy(); 544 } 545 546 /// Remove the operation from its parent block, but don't delete it. 547 void Operation::remove() { 548 if (Block *parent = getBlock()) 549 parent->getOperations().remove(this); 550 } 551 552 /// Unlink this operation from its current block and insert it right before 553 /// `existingOp` which may be in the same or another block in the same 554 /// function. 555 void Operation::moveBefore(Operation *existingOp) { 556 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 557 } 558 559 /// Unlink this operation from its current basic block and insert it right 560 /// before `iterator` in the specified basic block. 561 void Operation::moveBefore(Block *block, 562 llvm::iplist<Operation>::iterator iterator) { 563 block->getOperations().splice(iterator, getBlock()->getOperations(), 564 getIterator()); 565 } 566 567 /// Unlink this operation from its current block and insert it right after 568 /// `existingOp` which may be in the same or another block in the same function. 569 void Operation::moveAfter(Operation *existingOp) { 570 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 571 } 572 573 /// Unlink this operation from its current block and insert it right after 574 /// `iterator` in the specified block. 575 void Operation::moveAfter(Block *block, 576 llvm::iplist<Operation>::iterator iterator) { 577 assert(iterator != block->end() && "cannot move after end of block"); 578 moveBefore(block, std::next(iterator)); 579 } 580 581 /// This drops all operand uses from this operation, which is an essential 582 /// step in breaking cyclic dependences between references when they are to 583 /// be deleted. 584 void Operation::dropAllReferences() { 585 for (auto &op : getOpOperands()) 586 op.drop(); 587 588 for (auto ®ion : getRegions()) 589 region.dropAllReferences(); 590 591 for (auto &dest : getBlockOperands()) 592 dest.drop(); 593 } 594 595 /// This drops all uses of any values defined by this operation or its nested 596 /// regions, wherever they are located. 597 void Operation::dropAllDefinedValueUses() { 598 dropAllUses(); 599 600 for (auto ®ion : getRegions()) 601 for (auto &block : region) 602 block.dropAllDefinedValueUses(); 603 } 604 605 void Operation::setSuccessor(Block *block, unsigned index) { 606 assert(index < getNumSuccessors()); 607 getBlockOperands()[index].set(block); 608 } 609 610 #ifndef NDEBUG 611 /// Assert that the folded results (in case of values) have the same type as 612 /// the results of the given op. 613 static void checkFoldResultTypes(Operation *op, 614 SmallVectorImpl<OpFoldResult> &results) { 615 if (results.empty()) 616 return; 617 618 for (auto [ofr, opResult] : llvm::zip_equal(results, op->getResults())) { 619 if (auto value = dyn_cast<Value>(ofr)) { 620 if (value.getType() != opResult.getType()) { 621 op->emitOpError() << "folder produced a value of incorrect type: " 622 << value.getType() 623 << ", expected: " << opResult.getType(); 624 assert(false && "incorrect fold result type"); 625 } 626 } 627 } 628 } 629 #endif // NDEBUG 630 631 /// Attempt to fold this operation using the Op's registered foldHook. 632 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 633 SmallVectorImpl<OpFoldResult> &results) { 634 // If we have a registered operation definition matching this one, use it to 635 // try to constant fold the operation. 636 if (succeeded(name.foldHook(this, operands, results))) { 637 #ifndef NDEBUG 638 checkFoldResultTypes(this, results); 639 #endif // NDEBUG 640 return success(); 641 } 642 643 // Otherwise, fall back on the dialect hook to handle it. 644 Dialect *dialect = getDialect(); 645 if (!dialect) 646 return failure(); 647 648 auto *interface = dyn_cast<DialectFoldInterface>(dialect); 649 if (!interface) 650 return failure(); 651 652 LogicalResult status = interface->fold(this, operands, results); 653 #ifndef NDEBUG 654 if (succeeded(status)) 655 checkFoldResultTypes(this, results); 656 #endif // NDEBUG 657 return status; 658 } 659 660 LogicalResult Operation::fold(SmallVectorImpl<OpFoldResult> &results) { 661 // Check if any operands are constants. 662 SmallVector<Attribute> constants; 663 constants.assign(getNumOperands(), Attribute()); 664 for (unsigned i = 0, e = getNumOperands(); i != e; ++i) 665 matchPattern(getOperand(i), m_Constant(&constants[i])); 666 return fold(constants, results); 667 } 668 669 /// Emit an error with the op name prefixed, like "'dim' op " which is 670 /// convenient for verifiers. 671 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 672 return emitError() << "'" << getName() << "' op " << message; 673 } 674 675 //===----------------------------------------------------------------------===// 676 // Operation Cloning 677 //===----------------------------------------------------------------------===// 678 679 Operation::CloneOptions::CloneOptions() 680 : cloneRegionsFlag(false), cloneOperandsFlag(false) {} 681 682 Operation::CloneOptions::CloneOptions(bool cloneRegions, bool cloneOperands) 683 : cloneRegionsFlag(cloneRegions), cloneOperandsFlag(cloneOperands) {} 684 685 Operation::CloneOptions Operation::CloneOptions::all() { 686 return CloneOptions().cloneRegions().cloneOperands(); 687 } 688 689 Operation::CloneOptions &Operation::CloneOptions::cloneRegions(bool enable) { 690 cloneRegionsFlag = enable; 691 return *this; 692 } 693 694 Operation::CloneOptions &Operation::CloneOptions::cloneOperands(bool enable) { 695 cloneOperandsFlag = enable; 696 return *this; 697 } 698 699 /// Create a deep copy of this operation but keep the operation regions empty. 700 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 701 /// to contain the results. The `mapResults` flag specifies whether the results 702 /// of the cloned operation should be added to the map. 703 Operation *Operation::cloneWithoutRegions(IRMapping &mapper) { 704 return clone(mapper, CloneOptions::all().cloneRegions(false)); 705 } 706 707 Operation *Operation::cloneWithoutRegions() { 708 IRMapping mapper; 709 return cloneWithoutRegions(mapper); 710 } 711 712 /// Create a deep copy of this operation, remapping any operands that use 713 /// values outside of the operation using the map that is provided (leaving 714 /// them alone if no entry is present). Replaces references to cloned 715 /// sub-operations to the corresponding operation that is copied, and adds 716 /// those mappings to the map. 717 Operation *Operation::clone(IRMapping &mapper, CloneOptions options) { 718 SmallVector<Value, 8> operands; 719 SmallVector<Block *, 2> successors; 720 721 // Remap the operands. 722 if (options.shouldCloneOperands()) { 723 operands.reserve(getNumOperands()); 724 for (auto opValue : getOperands()) 725 operands.push_back(mapper.lookupOrDefault(opValue)); 726 } 727 728 // Remap the successors. 729 successors.reserve(getNumSuccessors()); 730 for (Block *successor : getSuccessors()) 731 successors.push_back(mapper.lookupOrDefault(successor)); 732 733 // Create the new operation. 734 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 735 getPropertiesStorage(), successors, getNumRegions()); 736 mapper.map(this, newOp); 737 738 // Clone the regions. 739 if (options.shouldCloneRegions()) { 740 for (unsigned i = 0; i != numRegions; ++i) 741 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 742 } 743 744 // Remember the mapping of any results. 745 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 746 mapper.map(getResult(i), newOp->getResult(i)); 747 748 return newOp; 749 } 750 751 Operation *Operation::clone(CloneOptions options) { 752 IRMapping mapper; 753 return clone(mapper, options); 754 } 755 756 //===----------------------------------------------------------------------===// 757 // OpState trait class. 758 //===----------------------------------------------------------------------===// 759 760 // The fallback for the parser is to try for a dialect operation parser. 761 // Otherwise, reject the custom assembly form. 762 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 763 if (auto parseFn = result.name.getDialect()->getParseOperationHook( 764 result.name.getStringRef())) 765 return (*parseFn)(parser, result); 766 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 767 } 768 769 // The fallback for the printer is to try for a dialect operation printer. 770 // Otherwise, it prints the generic form. 771 void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { 772 if (auto printFn = op->getDialect()->getOperationPrinter(op)) { 773 printOpName(op, p, defaultDialect); 774 printFn(op, p); 775 } else { 776 p.printGenericOp(op); 777 } 778 } 779 780 /// Print an operation name, eliding the dialect prefix if necessary and doesn't 781 /// lead to ambiguities. 782 void OpState::printOpName(Operation *op, OpAsmPrinter &p, 783 StringRef defaultDialect) { 784 StringRef name = op->getName().getStringRef(); 785 if (name.starts_with((defaultDialect + ".").str()) && name.count('.') == 1) 786 name = name.drop_front(defaultDialect.size() + 1); 787 p.getStream() << name; 788 } 789 790 /// Parse properties as a Attribute. 791 ParseResult OpState::genericParseProperties(OpAsmParser &parser, 792 Attribute &result) { 793 if (succeeded(parser.parseOptionalLess())) { // The less is optional. 794 if (parser.parseAttribute(result) || parser.parseGreater()) 795 return failure(); 796 } 797 return success(); 798 } 799 800 /// Print the properties as a Attribute with names not included within 801 /// 'elidedProps' 802 void OpState::genericPrintProperties(OpAsmPrinter &p, Attribute properties, 803 ArrayRef<StringRef> elidedProps) { 804 if (!properties) 805 return; 806 auto dictAttr = dyn_cast_or_null<::mlir::DictionaryAttr>(properties); 807 if (dictAttr && !elidedProps.empty()) { 808 ArrayRef<NamedAttribute> attrs = dictAttr.getValue(); 809 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedProps.begin(), 810 elidedProps.end()); 811 bool atLeastOneAttr = llvm::any_of(attrs, [&](NamedAttribute attr) { 812 return !elidedAttrsSet.contains(attr.getName().strref()); 813 }); 814 if (atLeastOneAttr) { 815 p << "<"; 816 p.printOptionalAttrDict(dictAttr.getValue(), elidedProps); 817 p << ">"; 818 } 819 } else { 820 p << "<" << properties << ">"; 821 } 822 } 823 824 /// Emit an error about fatal conditions with this operation, reporting up to 825 /// any diagnostic handlers that may be listening. 826 InFlightDiagnostic OpState::emitError(const Twine &message) { 827 return getOperation()->emitError(message); 828 } 829 830 /// Emit an error with the op name prefixed, like "'dim' op " which is 831 /// convenient for verifiers. 832 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 833 return getOperation()->emitOpError(message); 834 } 835 836 /// Emit a warning about this operation, reporting up to any diagnostic 837 /// handlers that may be listening. 838 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 839 return getOperation()->emitWarning(message); 840 } 841 842 /// Emit a remark about this operation, reporting up to any diagnostic 843 /// handlers that may be listening. 844 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 845 return getOperation()->emitRemark(message); 846 } 847 848 //===----------------------------------------------------------------------===// 849 // Op Trait implementations 850 //===----------------------------------------------------------------------===// 851 852 LogicalResult 853 OpTrait::impl::foldCommutative(Operation *op, ArrayRef<Attribute> operands, 854 SmallVectorImpl<OpFoldResult> &results) { 855 // Nothing to fold if there are not at least 2 operands. 856 if (op->getNumOperands() < 2) 857 return failure(); 858 // Move all constant operands to the end. 859 OpOperand *operandsBegin = op->getOpOperands().begin(); 860 auto isNonConstant = [&](OpOperand &o) { 861 return !static_cast<bool>(operands[std::distance(operandsBegin, &o)]); 862 }; 863 auto *firstConstantIt = llvm::find_if_not(op->getOpOperands(), isNonConstant); 864 auto *newConstantIt = std::stable_partition( 865 firstConstantIt, op->getOpOperands().end(), isNonConstant); 866 // Return success if the op was modified. 867 return success(firstConstantIt != newConstantIt); 868 } 869 870 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { 871 if (op->getNumOperands() == 1) { 872 auto *argumentOp = op->getOperand(0).getDefiningOp(); 873 if (argumentOp && op->getName() == argumentOp->getName()) { 874 // Replace the outer operation output with the inner operation. 875 return op->getOperand(0); 876 } 877 } else if (op->getOperand(0) == op->getOperand(1)) { 878 return op->getOperand(0); 879 } 880 881 return {}; 882 } 883 884 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { 885 auto *argumentOp = op->getOperand(0).getDefiningOp(); 886 if (argumentOp && op->getName() == argumentOp->getName()) { 887 // Replace the outer involutions output with inner's input. 888 return argumentOp->getOperand(0); 889 } 890 891 return {}; 892 } 893 894 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 895 if (op->getNumOperands() != 0) 896 return op->emitOpError() << "requires zero operands"; 897 return success(); 898 } 899 900 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 901 if (op->getNumOperands() != 1) 902 return op->emitOpError() << "requires a single operand"; 903 return success(); 904 } 905 906 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 907 unsigned numOperands) { 908 if (op->getNumOperands() != numOperands) { 909 return op->emitOpError() << "expected " << numOperands 910 << " operands, but found " << op->getNumOperands(); 911 } 912 return success(); 913 } 914 915 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 916 unsigned numOperands) { 917 if (op->getNumOperands() < numOperands) 918 return op->emitOpError() 919 << "expected " << numOperands << " or more operands, but found " 920 << op->getNumOperands(); 921 return success(); 922 } 923 924 /// If this is a vector type, or a tensor type, return the scalar element type 925 /// that it is built around, otherwise return the type unmodified. 926 static Type getTensorOrVectorElementType(Type type) { 927 if (auto vec = llvm::dyn_cast<VectorType>(type)) 928 return vec.getElementType(); 929 930 // Look through tensor<vector<...>> to find the underlying element type. 931 if (auto tensor = llvm::dyn_cast<TensorType>(type)) 932 return getTensorOrVectorElementType(tensor.getElementType()); 933 return type; 934 } 935 936 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { 937 // FIXME: Add back check for no side effects on operation. 938 // Currently adding it would cause the shared library build 939 // to fail since there would be a dependency of IR on SideEffectInterfaces 940 // which is cyclical. 941 return success(); 942 } 943 944 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { 945 // FIXME: Add back check for no side effects on operation. 946 // Currently adding it would cause the shared library build 947 // to fail since there would be a dependency of IR on SideEffectInterfaces 948 // which is cyclical. 949 return success(); 950 } 951 952 LogicalResult 953 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 954 for (auto opType : op->getOperandTypes()) { 955 auto type = getTensorOrVectorElementType(opType); 956 if (!type.isSignlessIntOrIndex()) 957 return op->emitOpError() << "requires an integer or index type"; 958 } 959 return success(); 960 } 961 962 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 963 for (auto opType : op->getOperandTypes()) { 964 auto type = getTensorOrVectorElementType(opType); 965 if (!llvm::isa<FloatType>(type)) 966 return op->emitOpError("requires a float type"); 967 } 968 return success(); 969 } 970 971 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 972 // Zero or one operand always have the "same" type. 973 unsigned nOperands = op->getNumOperands(); 974 if (nOperands < 2) 975 return success(); 976 977 auto type = op->getOperand(0).getType(); 978 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 979 if (opType != type) 980 return op->emitOpError() << "requires all operands to have the same type"; 981 return success(); 982 } 983 984 LogicalResult OpTrait::impl::verifyZeroRegions(Operation *op) { 985 if (op->getNumRegions() != 0) 986 return op->emitOpError() << "requires zero regions"; 987 return success(); 988 } 989 990 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 991 if (op->getNumRegions() != 1) 992 return op->emitOpError() << "requires one region"; 993 return success(); 994 } 995 996 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 997 unsigned numRegions) { 998 if (op->getNumRegions() != numRegions) 999 return op->emitOpError() << "expected " << numRegions << " regions"; 1000 return success(); 1001 } 1002 1003 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 1004 unsigned numRegions) { 1005 if (op->getNumRegions() < numRegions) 1006 return op->emitOpError() << "expected " << numRegions << " or more regions"; 1007 return success(); 1008 } 1009 1010 LogicalResult OpTrait::impl::verifyZeroResults(Operation *op) { 1011 if (op->getNumResults() != 0) 1012 return op->emitOpError() << "requires zero results"; 1013 return success(); 1014 } 1015 1016 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 1017 if (op->getNumResults() != 1) 1018 return op->emitOpError() << "requires one result"; 1019 return success(); 1020 } 1021 1022 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 1023 unsigned numOperands) { 1024 if (op->getNumResults() != numOperands) 1025 return op->emitOpError() << "expected " << numOperands << " results"; 1026 return success(); 1027 } 1028 1029 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 1030 unsigned numOperands) { 1031 if (op->getNumResults() < numOperands) 1032 return op->emitOpError() 1033 << "expected " << numOperands << " or more results"; 1034 return success(); 1035 } 1036 1037 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 1038 if (failed(verifyAtLeastNOperands(op, 1))) 1039 return failure(); 1040 1041 if (failed(verifyCompatibleShapes(op->getOperandTypes()))) 1042 return op->emitOpError() << "requires the same shape for all operands"; 1043 1044 return success(); 1045 } 1046 1047 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 1048 if (failed(verifyAtLeastNOperands(op, 1)) || 1049 failed(verifyAtLeastNResults(op, 1))) 1050 return failure(); 1051 1052 SmallVector<Type, 8> types(op->getOperandTypes()); 1053 types.append(llvm::to_vector<4>(op->getResultTypes())); 1054 1055 if (failed(verifyCompatibleShapes(types))) 1056 return op->emitOpError() 1057 << "requires the same shape for all operands and results"; 1058 1059 return success(); 1060 } 1061 1062 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 1063 if (failed(verifyAtLeastNOperands(op, 1))) 1064 return failure(); 1065 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 1066 1067 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 1068 if (getElementTypeOrSelf(operand) != elementType) 1069 return op->emitOpError("requires the same element type for all operands"); 1070 } 1071 1072 return success(); 1073 } 1074 1075 LogicalResult 1076 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 1077 if (failed(verifyAtLeastNOperands(op, 1)) || 1078 failed(verifyAtLeastNResults(op, 1))) 1079 return failure(); 1080 1081 auto elementType = getElementTypeOrSelf(op->getResult(0)); 1082 1083 // Verify result element type matches first result's element type. 1084 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 1085 if (getElementTypeOrSelf(result) != elementType) 1086 return op->emitOpError( 1087 "requires the same element type for all operands and results"); 1088 } 1089 1090 // Verify operand's element type matches first result's element type. 1091 for (auto operand : op->getOperands()) { 1092 if (getElementTypeOrSelf(operand) != elementType) 1093 return op->emitOpError( 1094 "requires the same element type for all operands and results"); 1095 } 1096 1097 return success(); 1098 } 1099 1100 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 1101 if (failed(verifyAtLeastNOperands(op, 1)) || 1102 failed(verifyAtLeastNResults(op, 1))) 1103 return failure(); 1104 1105 auto type = op->getResult(0).getType(); 1106 auto elementType = getElementTypeOrSelf(type); 1107 Attribute encoding = nullptr; 1108 if (auto rankedType = dyn_cast<RankedTensorType>(type)) 1109 encoding = rankedType.getEncoding(); 1110 for (auto resultType : llvm::drop_begin(op->getResultTypes())) { 1111 if (getElementTypeOrSelf(resultType) != elementType || 1112 failed(verifyCompatibleShape(resultType, type))) 1113 return op->emitOpError() 1114 << "requires the same type for all operands and results"; 1115 if (encoding) 1116 if (auto rankedType = dyn_cast<RankedTensorType>(resultType); 1117 encoding != rankedType.getEncoding()) 1118 return op->emitOpError() 1119 << "requires the same encoding for all operands and results"; 1120 } 1121 for (auto opType : op->getOperandTypes()) { 1122 if (getElementTypeOrSelf(opType) != elementType || 1123 failed(verifyCompatibleShape(opType, type))) 1124 return op->emitOpError() 1125 << "requires the same type for all operands and results"; 1126 if (encoding) 1127 if (auto rankedType = dyn_cast<RankedTensorType>(opType); 1128 encoding != rankedType.getEncoding()) 1129 return op->emitOpError() 1130 << "requires the same encoding for all operands and results"; 1131 } 1132 return success(); 1133 } 1134 1135 LogicalResult OpTrait::impl::verifySameOperandsAndResultRank(Operation *op) { 1136 if (failed(verifyAtLeastNOperands(op, 1))) 1137 return failure(); 1138 1139 // delegate function that returns true if type is a shaped type with known 1140 // rank 1141 auto hasRank = [](const Type type) { 1142 if (auto shapedType = dyn_cast<ShapedType>(type)) 1143 return shapedType.hasRank(); 1144 1145 return false; 1146 }; 1147 1148 auto rankedOperandTypes = 1149 llvm::make_filter_range(op->getOperandTypes(), hasRank); 1150 auto rankedResultTypes = 1151 llvm::make_filter_range(op->getResultTypes(), hasRank); 1152 1153 // If all operands and results are unranked, then no further verification. 1154 if (rankedOperandTypes.empty() && rankedResultTypes.empty()) 1155 return success(); 1156 1157 // delegate function that returns rank of shaped type with known rank 1158 auto getRank = [](const Type type) { 1159 return cast<ShapedType>(type).getRank(); 1160 }; 1161 1162 auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin()) 1163 : getRank(*rankedResultTypes.begin()); 1164 1165 for (const auto type : rankedOperandTypes) { 1166 if (rank != getRank(type)) { 1167 return op->emitOpError("operands don't have matching ranks"); 1168 } 1169 } 1170 1171 for (const auto type : rankedResultTypes) { 1172 if (rank != getRank(type)) { 1173 return op->emitOpError("result type has different rank than operands"); 1174 } 1175 } 1176 1177 return success(); 1178 } 1179 1180 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 1181 Block *block = op->getBlock(); 1182 // Verify that the operation is at the end of the respective parent block. 1183 if (!block || &block->back() != op) 1184 return op->emitOpError("must be the last operation in the parent block"); 1185 return success(); 1186 } 1187 1188 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 1189 auto *parent = op->getParentRegion(); 1190 1191 // Verify that the operands lines up with the BB arguments in the successor. 1192 for (Block *succ : op->getSuccessors()) 1193 if (succ->getParent() != parent) 1194 return op->emitError("reference to block defined in another region"); 1195 return success(); 1196 } 1197 1198 LogicalResult OpTrait::impl::verifyZeroSuccessors(Operation *op) { 1199 if (op->getNumSuccessors() != 0) { 1200 return op->emitOpError("requires 0 successors but found ") 1201 << op->getNumSuccessors(); 1202 } 1203 return success(); 1204 } 1205 1206 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 1207 if (op->getNumSuccessors() != 1) { 1208 return op->emitOpError("requires 1 successor but found ") 1209 << op->getNumSuccessors(); 1210 } 1211 return verifyTerminatorSuccessors(op); 1212 } 1213 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 1214 unsigned numSuccessors) { 1215 if (op->getNumSuccessors() != numSuccessors) { 1216 return op->emitOpError("requires ") 1217 << numSuccessors << " successors but found " 1218 << op->getNumSuccessors(); 1219 } 1220 return verifyTerminatorSuccessors(op); 1221 } 1222 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 1223 unsigned numSuccessors) { 1224 if (op->getNumSuccessors() < numSuccessors) { 1225 return op->emitOpError("requires at least ") 1226 << numSuccessors << " successors but found " 1227 << op->getNumSuccessors(); 1228 } 1229 return verifyTerminatorSuccessors(op); 1230 } 1231 1232 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 1233 for (auto resultType : op->getResultTypes()) { 1234 auto elementType = getTensorOrVectorElementType(resultType); 1235 bool isBoolType = elementType.isInteger(1); 1236 if (!isBoolType) 1237 return op->emitOpError() << "requires a bool result type"; 1238 } 1239 1240 return success(); 1241 } 1242 1243 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 1244 for (auto resultType : op->getResultTypes()) 1245 if (!llvm::isa<FloatType>(getTensorOrVectorElementType(resultType))) 1246 return op->emitOpError() << "requires a floating point type"; 1247 1248 return success(); 1249 } 1250 1251 LogicalResult 1252 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 1253 for (auto resultType : op->getResultTypes()) 1254 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 1255 return op->emitOpError() << "requires an integer or index type"; 1256 return success(); 1257 } 1258 1259 LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, 1260 StringRef attrName, 1261 StringRef valueGroupName, 1262 size_t expectedCount) { 1263 auto sizeAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrName); 1264 if (!sizeAttr) 1265 return op->emitOpError("requires dense i32 array attribute '") 1266 << attrName << "'"; 1267 1268 ArrayRef<int32_t> sizes = sizeAttr.asArrayRef(); 1269 if (llvm::any_of(sizes, [](int32_t element) { return element < 0; })) 1270 return op->emitOpError("'") 1271 << attrName << "' attribute cannot have negative elements"; 1272 1273 size_t totalCount = 1274 std::accumulate(sizes.begin(), sizes.end(), 0, 1275 [](unsigned all, int32_t one) { return all + one; }); 1276 1277 if (totalCount != expectedCount) 1278 return op->emitOpError() 1279 << valueGroupName << " count (" << expectedCount 1280 << ") does not match with the total size (" << totalCount 1281 << ") specified in attribute '" << attrName << "'"; 1282 return success(); 1283 } 1284 1285 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 1286 StringRef attrName) { 1287 return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); 1288 } 1289 1290 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, 1291 StringRef attrName) { 1292 return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); 1293 } 1294 1295 LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1296 for (Region ®ion : op->getRegions()) { 1297 if (region.empty()) 1298 continue; 1299 1300 if (region.getNumArguments() != 0) { 1301 if (op->getNumRegions() > 1) 1302 return op->emitOpError("region #") 1303 << region.getRegionNumber() << " should have no arguments"; 1304 return op->emitOpError("region should have no arguments"); 1305 } 1306 } 1307 return success(); 1308 } 1309 1310 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { 1311 auto isMappableType = llvm::IsaPred<VectorType, TensorType>; 1312 auto resultMappableTypes = 1313 llvm::filter_to_vector<1>(op->getResultTypes(), isMappableType); 1314 auto operandMappableTypes = 1315 llvm::filter_to_vector<2>(op->getOperandTypes(), isMappableType); 1316 1317 // If the op only has scalar operand/result types, then we have nothing to 1318 // check. 1319 if (resultMappableTypes.empty() && operandMappableTypes.empty()) 1320 return success(); 1321 1322 if (!resultMappableTypes.empty() && operandMappableTypes.empty()) 1323 return op->emitOpError("if a result is non-scalar, then at least one " 1324 "operand must be non-scalar"); 1325 1326 assert(!operandMappableTypes.empty()); 1327 1328 if (resultMappableTypes.empty()) 1329 return op->emitOpError("if an operand is non-scalar, then there must be at " 1330 "least one non-scalar result"); 1331 1332 if (resultMappableTypes.size() != op->getNumResults()) 1333 return op->emitOpError( 1334 "if an operand is non-scalar, then all results must be non-scalar"); 1335 1336 SmallVector<Type, 4> types = llvm::to_vector<2>( 1337 llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); 1338 TypeID expectedBaseTy = types.front().getTypeID(); 1339 if (!llvm::all_of(types, 1340 [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || 1341 failed(verifyCompatibleShapes(types))) { 1342 return op->emitOpError() << "all non-scalar operands/results must have the " 1343 "same shape and base type"; 1344 } 1345 1346 return success(); 1347 } 1348 1349 /// Check for any values used by operations regions attached to the 1350 /// specified "IsIsolatedFromAbove" operation defined outside of it. 1351 LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { 1352 assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1353 "Intended to check IsolatedFromAbove ops"); 1354 1355 // List of regions to analyze. Each region is processed independently, with 1356 // respect to the common `limit` region, so we can look at them in any order. 1357 // Therefore, use a simple vector and push/pop back the current region. 1358 SmallVector<Region *, 8> pendingRegions; 1359 for (auto ®ion : isolatedOp->getRegions()) { 1360 pendingRegions.push_back(®ion); 1361 1362 // Traverse all operations in the region. 1363 while (!pendingRegions.empty()) { 1364 for (Operation &op : pendingRegions.pop_back_val()->getOps()) { 1365 for (Value operand : op.getOperands()) { 1366 // Check that any value that is used by an operation is defined in the 1367 // same region as either an operation result. 1368 auto *operandRegion = operand.getParentRegion(); 1369 if (!operandRegion) 1370 return op.emitError("operation's operand is unlinked"); 1371 if (!region.isAncestor(operandRegion)) { 1372 return op.emitOpError("using value defined outside the region") 1373 .attachNote(isolatedOp->getLoc()) 1374 << "required by region isolation constraints"; 1375 } 1376 } 1377 1378 // Schedule any regions in the operation for further checking. Don't 1379 // recurse into other IsolatedFromAbove ops, because they will check 1380 // themselves. 1381 if (op.getNumRegions() && 1382 !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { 1383 for (Region &subRegion : op.getRegions()) 1384 pendingRegions.push_back(&subRegion); 1385 } 1386 } 1387 } 1388 } 1389 1390 return success(); 1391 } 1392 1393 bool OpTrait::hasElementwiseMappableTraits(Operation *op) { 1394 return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && 1395 op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); 1396 } 1397 1398 //===----------------------------------------------------------------------===// 1399 // Misc. utils 1400 //===----------------------------------------------------------------------===// 1401 1402 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1403 /// region's only block if it does not have a terminator already. If the region 1404 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1405 /// terminator operation to insert. 1406 void impl::ensureRegionTerminator( 1407 Region ®ion, OpBuilder &builder, Location loc, 1408 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1409 OpBuilder::InsertionGuard guard(builder); 1410 if (region.empty()) 1411 builder.createBlock(®ion); 1412 1413 Block &block = region.back(); 1414 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) 1415 return; 1416 1417 builder.setInsertionPointToEnd(&block); 1418 builder.insert(buildTerminatorOp(builder, loc)); 1419 } 1420 1421 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1422 /// function. 1423 void impl::ensureRegionTerminator( 1424 Region ®ion, Builder &builder, Location loc, 1425 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1426 OpBuilder opBuilder(builder.getContext()); 1427 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1428 } 1429