1 //===- OperationSupport.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains out-of-line implementations of the support types that 10 // Operation and related classes build on top of. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/OperationSupport.h" 15 #include "mlir/IR/BuiltinAttributes.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/OpDefinition.h" 18 #include "llvm/ADT/BitVector.h" 19 #include "llvm/Support/SHA1.h" 20 #include <numeric> 21 #include <optional> 22 23 using namespace mlir; 24 25 //===----------------------------------------------------------------------===// 26 // NamedAttrList 27 //===----------------------------------------------------------------------===// 28 29 NamedAttrList::NamedAttrList(ArrayRef<NamedAttribute> attributes) { 30 assign(attributes.begin(), attributes.end()); 31 } 32 33 NamedAttrList::NamedAttrList(DictionaryAttr attributes) 34 : NamedAttrList(attributes ? attributes.getValue() 35 : ArrayRef<NamedAttribute>()) { 36 dictionarySorted.setPointerAndInt(attributes, true); 37 } 38 39 NamedAttrList::NamedAttrList(const_iterator inStart, const_iterator inEnd) { 40 assign(inStart, inEnd); 41 } 42 43 ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; } 44 45 std::optional<NamedAttribute> NamedAttrList::findDuplicate() const { 46 std::optional<NamedAttribute> duplicate = 47 DictionaryAttr::findDuplicate(attrs, isSorted()); 48 // DictionaryAttr::findDuplicate will sort the list, so reset the sorted 49 // state. 
50 if (!isSorted()) 51 dictionarySorted.setPointerAndInt(nullptr, true); 52 return duplicate; 53 } 54 55 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const { 56 if (!isSorted()) { 57 DictionaryAttr::sortInPlace(attrs); 58 dictionarySorted.setPointerAndInt(nullptr, true); 59 } 60 if (!dictionarySorted.getPointer()) 61 dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs)); 62 return llvm::cast<DictionaryAttr>(dictionarySorted.getPointer()); 63 } 64 65 /// Replaces the attributes with new list of attributes. 66 void NamedAttrList::assign(const_iterator inStart, const_iterator inEnd) { 67 DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs); 68 dictionarySorted.setPointerAndInt(nullptr, true); 69 } 70 71 void NamedAttrList::push_back(NamedAttribute newAttribute) { 72 if (isSorted()) 73 dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute); 74 dictionarySorted.setPointer(nullptr); 75 attrs.push_back(newAttribute); 76 } 77 78 /// Return the specified attribute if present, null otherwise. 79 Attribute NamedAttrList::get(StringRef name) const { 80 auto it = findAttr(*this, name); 81 return it.second ? it.first->getValue() : Attribute(); 82 } 83 Attribute NamedAttrList::get(StringAttr name) const { 84 auto it = findAttr(*this, name); 85 return it.second ? it.first->getValue() : Attribute(); 86 } 87 88 /// Return the specified named attribute if present, std::nullopt otherwise. 89 std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const { 90 auto it = findAttr(*this, name); 91 return it.second ? *it.first : std::optional<NamedAttribute>(); 92 } 93 std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const { 94 auto it = findAttr(*this, name); 95 return it.second ? *it.first : std::optional<NamedAttribute>(); 96 } 97 98 /// If the an attribute exists with the specified name, change it to the new 99 /// value. 
Otherwise, add a new attribute with the specified name/value. 100 Attribute NamedAttrList::set(StringAttr name, Attribute value) { 101 assert(value && "attributes may never be null"); 102 103 // Look for an existing attribute with the given name, and set its value 104 // in-place. Return the previous value of the attribute, if there was one. 105 auto it = findAttr(*this, name); 106 if (it.second) { 107 // Update the existing attribute by swapping out the old value for the new 108 // value. Return the old value. 109 Attribute oldValue = it.first->getValue(); 110 if (it.first->getValue() != value) { 111 it.first->setValue(value); 112 113 // If the attributes have changed, the dictionary is invalidated. 114 dictionarySorted.setPointer(nullptr); 115 } 116 return oldValue; 117 } 118 // Perform a string lookup to insert the new attribute into its sorted 119 // position. 120 if (isSorted()) 121 it = findAttr(*this, name.strref()); 122 attrs.insert(it.first, {name, value}); 123 // Invalidate the dictionary. Return null as there was no previous value. 124 dictionarySorted.setPointer(nullptr); 125 return Attribute(); 126 } 127 128 Attribute NamedAttrList::set(StringRef name, Attribute value) { 129 assert(value && "attributes may never be null"); 130 return set(mlir::StringAttr::get(value.getContext(), name), value); 131 } 132 133 Attribute 134 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) { 135 // Erasing does not affect the sorted property. 136 Attribute attr = it->getValue(); 137 attrs.erase(it); 138 dictionarySorted.setPointer(nullptr); 139 return attr; 140 } 141 142 Attribute NamedAttrList::erase(StringAttr name) { 143 auto it = findAttr(*this, name); 144 return it.second ? eraseImpl(it.first) : Attribute(); 145 } 146 147 Attribute NamedAttrList::erase(StringRef name) { 148 auto it = findAttr(*this, name); 149 return it.second ? 
eraseImpl(it.first) : Attribute(); 150 } 151 152 NamedAttrList & 153 NamedAttrList::operator=(const SmallVectorImpl<NamedAttribute> &rhs) { 154 assign(rhs.begin(), rhs.end()); 155 return *this; 156 } 157 158 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; } 159 160 //===----------------------------------------------------------------------===// 161 // OperationState 162 //===----------------------------------------------------------------------===// 163 164 OperationState::OperationState(Location location, StringRef name) 165 : location(location), name(name, location->getContext()) {} 166 167 OperationState::OperationState(Location location, OperationName name) 168 : location(location), name(name) {} 169 170 OperationState::OperationState(Location location, OperationName name, 171 ValueRange operands, TypeRange types, 172 ArrayRef<NamedAttribute> attributes, 173 BlockRange successors, 174 MutableArrayRef<std::unique_ptr<Region>> regions) 175 : location(location), name(name), 176 operands(operands.begin(), operands.end()), 177 types(types.begin(), types.end()), 178 attributes(attributes.begin(), attributes.end()), 179 successors(successors.begin(), successors.end()) { 180 for (std::unique_ptr<Region> &r : regions) 181 this->regions.push_back(std::move(r)); 182 } 183 OperationState::OperationState(Location location, StringRef name, 184 ValueRange operands, TypeRange types, 185 ArrayRef<NamedAttribute> attributes, 186 BlockRange successors, 187 MutableArrayRef<std::unique_ptr<Region>> regions) 188 : OperationState(location, OperationName(name, location.getContext()), 189 operands, types, attributes, successors, regions) {} 190 191 OperationState::~OperationState() { 192 if (properties) 193 propertiesDeleter(properties); 194 } 195 196 LogicalResult OperationState::setProperties( 197 Operation *op, function_ref<InFlightDiagnostic()> emitError) const { 198 if (LLVM_UNLIKELY(propertiesAttr)) { 199 assert(!properties); 200 return 
op->setPropertiesFromAttribute(propertiesAttr, emitError); 201 } 202 if (properties) 203 propertiesSetter(op->getPropertiesStorage(), properties); 204 return success(); 205 } 206 207 void OperationState::addOperands(ValueRange newOperands) { 208 operands.append(newOperands.begin(), newOperands.end()); 209 } 210 211 void OperationState::addSuccessors(BlockRange newSuccessors) { 212 successors.append(newSuccessors.begin(), newSuccessors.end()); 213 } 214 215 Region *OperationState::addRegion() { 216 regions.emplace_back(new Region); 217 return regions.back().get(); 218 } 219 220 void OperationState::addRegion(std::unique_ptr<Region> &®ion) { 221 regions.push_back(std::move(region)); 222 } 223 224 void OperationState::addRegions( 225 MutableArrayRef<std::unique_ptr<Region>> regions) { 226 for (std::unique_ptr<Region> ®ion : regions) 227 addRegion(std::move(region)); 228 } 229 230 //===----------------------------------------------------------------------===// 231 // OperandStorage 232 //===----------------------------------------------------------------------===// 233 234 detail::OperandStorage::OperandStorage(Operation *owner, 235 OpOperand *trailingOperands, 236 ValueRange values) 237 : isStorageDynamic(false), operandStorage(trailingOperands) { 238 numOperands = capacity = values.size(); 239 for (unsigned i = 0; i < numOperands; ++i) 240 new (&operandStorage[i]) OpOperand(owner, values[i]); 241 } 242 243 detail::OperandStorage::~OperandStorage() { 244 for (auto &operand : getOperands()) 245 operand.~OpOperand(); 246 247 // If the storage is dynamic, deallocate it. 248 if (isStorageDynamic) 249 free(operandStorage); 250 } 251 252 /// Replace the operands contained in the storage with the ones provided in 253 /// 'values'. 
254 void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) { 255 MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size()); 256 for (unsigned i = 0, e = values.size(); i != e; ++i) 257 storageOperands[i].set(values[i]); 258 } 259 260 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 261 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 262 /// than the range pointed to by 'start'+'length'. 263 void detail::OperandStorage::setOperands(Operation *owner, unsigned start, 264 unsigned length, ValueRange operands) { 265 // If the new size is the same, we can update inplace. 266 unsigned newSize = operands.size(); 267 if (newSize == length) { 268 MutableArrayRef<OpOperand> storageOperands = getOperands(); 269 for (unsigned i = 0, e = length; i != e; ++i) 270 storageOperands[start + i].set(operands[i]); 271 return; 272 } 273 // If the new size is greater, remove the extra operands and set the rest 274 // inplace. 275 if (newSize < length) { 276 eraseOperands(start + operands.size(), length - newSize); 277 setOperands(owner, start, newSize, operands); 278 return; 279 } 280 // Otherwise, the new size is greater so we need to grow the storage. 281 auto storageOperands = resize(owner, size() + (newSize - length)); 282 283 // Shift operands to the right to make space for the new operands. 284 unsigned rotateSize = storageOperands.size() - (start + length); 285 auto rbegin = storageOperands.rbegin(); 286 std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize); 287 288 // Update the operands inplace. 289 for (unsigned i = 0, e = operands.size(); i != e; ++i) 290 storageOperands[start + i].set(operands[i]); 291 } 292 293 /// Erase an operand held by the storage. 
294 void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) { 295 MutableArrayRef<OpOperand> operands = getOperands(); 296 assert((start + length) <= operands.size()); 297 numOperands -= length; 298 299 // Shift all operands down if the operand to remove is not at the end. 300 if (start != numOperands) { 301 auto *indexIt = std::next(operands.begin(), start); 302 std::rotate(indexIt, std::next(indexIt, length), operands.end()); 303 } 304 for (unsigned i = 0; i != length; ++i) 305 operands[numOperands + i].~OpOperand(); 306 } 307 308 void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) { 309 MutableArrayRef<OpOperand> operands = getOperands(); 310 assert(eraseIndices.size() == operands.size()); 311 312 // Check that at least one operand is erased. 313 int firstErasedIndice = eraseIndices.find_first(); 314 if (firstErasedIndice == -1) 315 return; 316 317 // Shift all of the removed operands to the end, and destroy them. 318 numOperands = firstErasedIndice; 319 for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i) 320 if (!eraseIndices.test(i)) 321 operands[numOperands++] = std::move(operands[i]); 322 for (OpOperand &operand : operands.drop_front(numOperands)) 323 operand.~OpOperand(); 324 } 325 326 /// Resize the storage to the given size. Returns the array containing the new 327 /// operands. 328 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner, 329 unsigned newSize) { 330 // If the number of operands is less than or equal to the current amount, we 331 // can just update in place. 332 MutableArrayRef<OpOperand> origOperands = getOperands(); 333 if (newSize <= numOperands) { 334 // If the number of new size is less than the current, remove any extra 335 // operands. 
336 for (unsigned i = newSize; i != numOperands; ++i) 337 origOperands[i].~OpOperand(); 338 numOperands = newSize; 339 return origOperands.take_front(newSize); 340 } 341 342 // If the new size is within the original inline capacity, grow inplace. 343 if (newSize <= capacity) { 344 OpOperand *opBegin = origOperands.data(); 345 for (unsigned e = newSize; numOperands != e; ++numOperands) 346 new (&opBegin[numOperands]) OpOperand(owner); 347 return MutableArrayRef<OpOperand>(opBegin, newSize); 348 } 349 350 // Otherwise, we need to allocate a new storage. 351 unsigned newCapacity = 352 std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize); 353 OpOperand *newOperandStorage = 354 reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity)); 355 356 // Move the current operands to the new storage. 357 MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize); 358 std::uninitialized_move(origOperands.begin(), origOperands.end(), 359 newOperands.begin()); 360 361 // Destroy the original operands. 362 for (auto &operand : origOperands) 363 operand.~OpOperand(); 364 365 // Initialize any new operands. 366 for (unsigned e = newSize; numOperands != e; ++numOperands) 367 new (&newOperands[numOperands]) OpOperand(owner); 368 369 // If the current storage is dynamic, free it. 370 if (isStorageDynamic) 371 free(operandStorage); 372 373 // Update the storage representation to use the new dynamic storage. 
374 operandStorage = newOperandStorage; 375 capacity = newCapacity; 376 isStorageDynamic = true; 377 return newOperands; 378 } 379 380 //===----------------------------------------------------------------------===// 381 // Operation Value-Iterators 382 //===----------------------------------------------------------------------===// 383 384 //===----------------------------------------------------------------------===// 385 // OperandRange 386 387 unsigned OperandRange::getBeginOperandIndex() const { 388 assert(!empty() && "range must not be empty"); 389 return base->getOperandNumber(); 390 } 391 392 OperandRangeRange OperandRange::split(DenseI32ArrayAttr segmentSizes) const { 393 return OperandRangeRange(*this, segmentSizes); 394 } 395 396 //===----------------------------------------------------------------------===// 397 // OperandRangeRange 398 399 OperandRangeRange::OperandRangeRange(OperandRange operands, 400 Attribute operandSegments) 401 : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0, 402 llvm::cast<DenseI32ArrayAttr>(operandSegments).size()) { 403 } 404 405 OperandRange OperandRangeRange::join() const { 406 const OwnerT &owner = getBase(); 407 ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second); 408 return OperandRange(owner.first, 409 std::accumulate(sizeData.begin(), sizeData.end(), 0)); 410 } 411 412 OperandRange OperandRangeRange::dereference(const OwnerT &object, 413 ptrdiff_t index) { 414 ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second); 415 uint32_t startIndex = 416 std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); 417 return OperandRange(object.first + startIndex, *(sizeData.begin() + index)); 418 } 419 420 //===----------------------------------------------------------------------===// 421 // MutableOperandRange 422 423 /// Construct a new mutable range from the given operand, operand start index, 424 /// and range length. 
425 MutableOperandRange::MutableOperandRange( 426 Operation *owner, unsigned start, unsigned length, 427 ArrayRef<OperandSegment> operandSegments) 428 : owner(owner), start(start), length(length), 429 operandSegments(operandSegments) { 430 assert((start + length) <= owner->getNumOperands() && "invalid range"); 431 } 432 MutableOperandRange::MutableOperandRange(Operation *owner) 433 : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {} 434 435 /// Construct a new mutable range for the given OpOperand. 436 MutableOperandRange::MutableOperandRange(OpOperand &opOperand) 437 : MutableOperandRange(opOperand.getOwner(), 438 /*start=*/opOperand.getOperandNumber(), 439 /*length=*/1) {} 440 441 /// Slice this range into a sub range, with the additional operand segment. 442 MutableOperandRange 443 MutableOperandRange::slice(unsigned subStart, unsigned subLen, 444 std::optional<OperandSegment> segment) const { 445 assert((subStart + subLen) <= length && "invalid sub-range"); 446 MutableOperandRange subSlice(owner, start + subStart, subLen, 447 operandSegments); 448 if (segment) 449 subSlice.operandSegments.push_back(*segment); 450 return subSlice; 451 } 452 453 /// Append the given values to the range. 454 void MutableOperandRange::append(ValueRange values) { 455 if (values.empty()) 456 return; 457 owner->insertOperands(start + length, values); 458 updateLength(length + values.size()); 459 } 460 461 /// Assign this range to the given values. 462 void MutableOperandRange::assign(ValueRange values) { 463 owner->setOperands(start, length, values); 464 if (length != values.size()) 465 updateLength(/*newLength=*/values.size()); 466 } 467 468 /// Assign the range to the given value. 469 void MutableOperandRange::assign(Value value) { 470 if (length == 1) { 471 owner->setOperand(start, value); 472 } else { 473 owner->setOperands(start, length, value); 474 updateLength(/*newLength=*/1); 475 } 476 } 477 478 /// Erase the operands within the given sub-range. 
479 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) { 480 assert((subStart + subLen) <= length && "invalid sub-range"); 481 if (length == 0) 482 return; 483 owner->eraseOperands(start + subStart, subLen); 484 updateLength(length - subLen); 485 } 486 487 /// Clear this range and erase all of the operands. 488 void MutableOperandRange::clear() { 489 if (length != 0) { 490 owner->eraseOperands(start, length); 491 updateLength(/*newLength=*/0); 492 } 493 } 494 495 /// Explicit conversion to an OperandRange. 496 OperandRange MutableOperandRange::getAsOperandRange() const { 497 return owner->getOperands().slice(start, length); 498 } 499 500 /// Allow implicit conversion to an OperandRange. 501 MutableOperandRange::operator OperandRange() const { 502 return getAsOperandRange(); 503 } 504 505 MutableOperandRange::operator MutableArrayRef<OpOperand>() const { 506 return owner->getOpOperands().slice(start, length); 507 } 508 509 MutableOperandRangeRange 510 MutableOperandRange::split(NamedAttribute segmentSizes) const { 511 return MutableOperandRangeRange(*this, segmentSizes); 512 } 513 514 /// Update the length of this range to the one provided. 515 void MutableOperandRange::updateLength(unsigned newLength) { 516 int32_t diff = int32_t(newLength) - int32_t(length); 517 length = newLength; 518 519 // Update any of the provided segment attributes. 
520 for (OperandSegment &segment : operandSegments) { 521 auto attr = llvm::cast<DenseI32ArrayAttr>(segment.second.getValue()); 522 SmallVector<int32_t, 8> segments(attr.asArrayRef()); 523 segments[segment.first] += diff; 524 segment.second.setValue( 525 DenseI32ArrayAttr::get(attr.getContext(), segments)); 526 owner->setAttr(segment.second.getName(), segment.second.getValue()); 527 } 528 } 529 530 OpOperand &MutableOperandRange::operator[](unsigned index) const { 531 assert(index < length && "index is out of bounds"); 532 return owner->getOpOperand(start + index); 533 } 534 535 MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const { 536 return owner->getOpOperands().slice(start, length).begin(); 537 } 538 539 MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const { 540 return owner->getOpOperands().slice(start, length).end(); 541 } 542 543 //===----------------------------------------------------------------------===// 544 // MutableOperandRangeRange 545 546 MutableOperandRangeRange::MutableOperandRangeRange( 547 const MutableOperandRange &operands, NamedAttribute operandSegmentAttr) 548 : MutableOperandRangeRange( 549 OwnerT(operands, operandSegmentAttr), 0, 550 llvm::cast<DenseI32ArrayAttr>(operandSegmentAttr.getValue()).size()) { 551 } 552 553 MutableOperandRange MutableOperandRangeRange::join() const { 554 return getBase().first; 555 } 556 557 MutableOperandRangeRange::operator OperandRangeRange() const { 558 return OperandRangeRange(getBase().first, getBase().second.getValue()); 559 } 560 561 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object, 562 ptrdiff_t index) { 563 ArrayRef<int32_t> sizeData = 564 llvm::cast<DenseI32ArrayAttr>(object.second.getValue()); 565 uint32_t startIndex = 566 std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); 567 return object.first.slice( 568 startIndex, *(sizeData.begin() + index), 569 MutableOperandRange::OperandSegment(index, object.second)); 570 } 571 
572 //===----------------------------------------------------------------------===// 573 // ResultRange 574 575 ResultRange::ResultRange(OpResult result) 576 : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()), 577 1) {} 578 579 ResultRange::use_range ResultRange::getUses() const { 580 return {use_begin(), use_end()}; 581 } 582 ResultRange::use_iterator ResultRange::use_begin() const { 583 return use_iterator(*this); 584 } 585 ResultRange::use_iterator ResultRange::use_end() const { 586 return use_iterator(*this, /*end=*/true); 587 } 588 ResultRange::user_range ResultRange::getUsers() { 589 return {user_begin(), user_end()}; 590 } 591 ResultRange::user_iterator ResultRange::user_begin() { 592 return user_iterator(use_begin()); 593 } 594 ResultRange::user_iterator ResultRange::user_end() { 595 return user_iterator(use_end()); 596 } 597 598 ResultRange::UseIterator::UseIterator(ResultRange results, bool end) 599 : it(end ? results.end() : results.begin()), endIt(results.end()) { 600 // Only initialize current use if there are results/can be uses. 601 if (it != endIt) 602 skipOverResultsWithNoUsers(); 603 } 604 605 ResultRange::UseIterator &ResultRange::UseIterator::operator++() { 606 // We increment over uses, if we reach the last use then move to next 607 // result. 608 if (use != (*it).use_end()) 609 ++use; 610 if (use == (*it).use_end()) { 611 ++it; 612 skipOverResultsWithNoUsers(); 613 } 614 return *this; 615 } 616 617 void ResultRange::UseIterator::skipOverResultsWithNoUsers() { 618 while (it != endIt && (*it).use_empty()) 619 ++it; 620 621 // If we are at the last result, then set use to first use of 622 // first result (sentinel value used for end). 
623 if (it == endIt) 624 use = {}; 625 else 626 use = (*it).use_begin(); 627 } 628 629 void ResultRange::replaceAllUsesWith(Operation *op) { 630 replaceAllUsesWith(op->getResults()); 631 } 632 633 void ResultRange::replaceUsesWithIf( 634 Operation *op, function_ref<bool(OpOperand &)> shouldReplace) { 635 replaceUsesWithIf(op->getResults(), shouldReplace); 636 } 637 638 //===----------------------------------------------------------------------===// 639 // ValueRange 640 641 ValueRange::ValueRange(ArrayRef<Value> values) 642 : ValueRange(values.data(), values.size()) {} 643 ValueRange::ValueRange(OperandRange values) 644 : ValueRange(values.begin().getBase(), values.size()) {} 645 ValueRange::ValueRange(ResultRange values) 646 : ValueRange(values.getBase(), values.size()) {} 647 648 /// See `llvm::detail::indexed_accessor_range_base` for details. 649 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner, 650 ptrdiff_t index) { 651 if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner)) 652 return {value + index}; 653 if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner)) 654 return {operand + index}; 655 return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index); 656 } 657 /// See `llvm::detail::indexed_accessor_range_base` for details. 
658 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) { 659 if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner)) 660 return value[index]; 661 if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner)) 662 return operand[index].get(); 663 return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index); 664 } 665 666 //===----------------------------------------------------------------------===// 667 // Operation Equivalency 668 //===----------------------------------------------------------------------===// 669 670 llvm::hash_code OperationEquivalence::computeHash( 671 Operation *op, function_ref<llvm::hash_code(Value)> hashOperands, 672 function_ref<llvm::hash_code(Value)> hashResults, Flags flags) { 673 // Hash operations based upon their: 674 // - Operation Name 675 // - Attributes 676 // - Result Types 677 llvm::hash_code hash = 678 llvm::hash_combine(op->getName(), op->getRawDictionaryAttrs(), 679 op->getResultTypes(), op->hashProperties()); 680 681 // - Location if required 682 if (!(flags & Flags::IgnoreLocations)) 683 hash = llvm::hash_combine(hash, op->getLoc()); 684 685 // - Operands 686 if (op->hasTrait<mlir::OpTrait::IsCommutative>() && 687 op->getNumOperands() > 0) { 688 size_t operandHash = hashOperands(op->getOperand(0)); 689 for (auto operand : op->getOperands().drop_front()) 690 operandHash += hashOperands(operand); 691 hash = llvm::hash_combine(hash, operandHash); 692 } else { 693 for (Value operand : op->getOperands()) 694 hash = llvm::hash_combine(hash, hashOperands(operand)); 695 } 696 697 // - Results 698 for (Value result : op->getResults()) 699 hash = llvm::hash_combine(hash, hashResults(result)); 700 return hash; 701 } 702 703 /*static*/ bool OperationEquivalence::isRegionEquivalentTo( 704 Region *lhs, Region *rhs, 705 function_ref<LogicalResult(Value, Value)> checkEquivalent, 706 function_ref<void(Value, Value)> markEquivalent, 707 OperationEquivalence::Flags flags, 708 
function_ref<LogicalResult(ValueRange, ValueRange)> 709 checkCommutativeEquivalent) { 710 DenseMap<Block *, Block *> blocksMap; 711 auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) { 712 // Check block arguments. 713 if (lBlock.getNumArguments() != rBlock.getNumArguments()) 714 return false; 715 716 // Map the two blocks. 717 auto insertion = blocksMap.insert({&lBlock, &rBlock}); 718 if (insertion.first->getSecond() != &rBlock) 719 return false; 720 721 for (auto argPair : 722 llvm::zip(lBlock.getArguments(), rBlock.getArguments())) { 723 Value curArg = std::get<0>(argPair); 724 Value otherArg = std::get<1>(argPair); 725 if (curArg.getType() != otherArg.getType()) 726 return false; 727 if (!(flags & OperationEquivalence::IgnoreLocations) && 728 curArg.getLoc() != otherArg.getLoc()) 729 return false; 730 // Corresponding bbArgs are equivalent. 731 if (markEquivalent) 732 markEquivalent(curArg, otherArg); 733 } 734 735 auto opsEquivalent = [&](Operation &lOp, Operation &rOp) { 736 // Check for op equality (recursively). 737 if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent, 738 markEquivalent, flags, 739 checkCommutativeEquivalent)) 740 return false; 741 // Check successor mapping. 742 for (auto successorsPair : 743 llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) { 744 Block *curSuccessor = std::get<0>(successorsPair); 745 Block *otherSuccessor = std::get<1>(successorsPair); 746 auto insertion = blocksMap.insert({curSuccessor, otherSuccessor}); 747 if (insertion.first->getSecond() != otherSuccessor) 748 return false; 749 } 750 return true; 751 }; 752 return llvm::all_of_zip(lBlock, rBlock, opsEquivalent); 753 }; 754 return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent); 755 } 756 757 // Value equivalence cache to be used with `isRegionEquivalentTo` and 758 // `isEquivalentTo`. 
759 struct ValueEquivalenceCache { 760 DenseMap<Value, Value> equivalentValues; 761 LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) { 762 return success(lhsValue == rhsValue || 763 equivalentValues.lookup(lhsValue) == rhsValue); 764 } 765 LogicalResult checkCommutativeEquivalent(ValueRange lhsRange, 766 ValueRange rhsRange) { 767 // Handle simple case where sizes mismatch. 768 if (lhsRange.size() != rhsRange.size()) 769 return failure(); 770 771 // Handle where operands in order are equivalent. 772 auto lhsIt = lhsRange.begin(); 773 auto rhsIt = rhsRange.begin(); 774 for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) { 775 if (failed(checkEquivalent(*lhsIt, *rhsIt))) 776 break; 777 } 778 if (lhsIt == lhsRange.end()) 779 return success(); 780 781 // Handle another simple case where operands are just a permutation. 782 // Note: This is not sufficient, this handles simple cases relatively 783 // cheaply. 784 auto sortValues = [](ValueRange values) { 785 SmallVector<Value> sortedValues = llvm::to_vector(values); 786 llvm::sort(sortedValues, [](Value a, Value b) { 787 return a.getAsOpaquePointer() < b.getAsOpaquePointer(); 788 }); 789 return sortedValues; 790 }; 791 auto lhsSorted = sortValues({lhsIt, lhsRange.end()}); 792 auto rhsSorted = sortValues({rhsIt, rhsRange.end()}); 793 return success(lhsSorted == rhsSorted); 794 } 795 void markEquivalent(Value lhsResult, Value rhsResult) { 796 auto insertion = equivalentValues.insert({lhsResult, rhsResult}); 797 // Make sure that the value was not already marked equivalent to some other 798 // value. 
799 (void)insertion; 800 assert(insertion.first->second == rhsResult && 801 "inconsistent OperationEquivalence state"); 802 } 803 }; 804 805 /*static*/ bool 806 OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs, 807 OperationEquivalence::Flags flags) { 808 ValueEquivalenceCache cache; 809 return isRegionEquivalentTo( 810 lhs, rhs, 811 [&](Value lhsValue, Value rhsValue) -> LogicalResult { 812 return cache.checkEquivalent(lhsValue, rhsValue); 813 }, 814 [&](Value lhsResult, Value rhsResult) { 815 cache.markEquivalent(lhsResult, rhsResult); 816 }, 817 flags, 818 [&](ValueRange lhs, ValueRange rhs) -> LogicalResult { 819 return cache.checkCommutativeEquivalent(lhs, rhs); 820 }); 821 } 822 823 /*static*/ bool OperationEquivalence::isEquivalentTo( 824 Operation *lhs, Operation *rhs, 825 function_ref<LogicalResult(Value, Value)> checkEquivalent, 826 function_ref<void(Value, Value)> markEquivalent, Flags flags, 827 function_ref<LogicalResult(ValueRange, ValueRange)> 828 checkCommutativeEquivalent) { 829 if (lhs == rhs) 830 return true; 831 832 // 1. Compare the operation properties. 833 if (lhs->getName() != rhs->getName() || 834 lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs() || 835 lhs->getNumRegions() != rhs->getNumRegions() || 836 lhs->getNumSuccessors() != rhs->getNumSuccessors() || 837 lhs->getNumOperands() != rhs->getNumOperands() || 838 lhs->getNumResults() != rhs->getNumResults() || 839 !lhs->getName().compareOpProperties(lhs->getPropertiesStorage(), 840 rhs->getPropertiesStorage())) 841 return false; 842 if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc()) 843 return false; 844 845 // 2. Compare operands. 846 if (checkCommutativeEquivalent && 847 lhs->hasTrait<mlir::OpTrait::IsCommutative>()) { 848 auto lhsRange = lhs->getOperands(); 849 auto rhsRange = rhs->getOperands(); 850 if (failed(checkCommutativeEquivalent(lhsRange, rhsRange))) 851 return false; 852 } else { 853 // Check pair wise for equivalence. 
854 for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) { 855 Value curArg = std::get<0>(operandPair); 856 Value otherArg = std::get<1>(operandPair); 857 if (curArg == otherArg) 858 continue; 859 if (curArg.getType() != otherArg.getType()) 860 return false; 861 if (failed(checkEquivalent(curArg, otherArg))) 862 return false; 863 } 864 } 865 866 // 3. Compare result types and mark results as equivalent. 867 for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) { 868 Value curArg = std::get<0>(resultPair); 869 Value otherArg = std::get<1>(resultPair); 870 if (curArg.getType() != otherArg.getType()) 871 return false; 872 if (markEquivalent) 873 markEquivalent(curArg, otherArg); 874 } 875 876 // 4. Compare regions. 877 for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions())) 878 if (!isRegionEquivalentTo(&std::get<0>(regionPair), 879 &std::get<1>(regionPair), checkEquivalent, 880 markEquivalent, flags)) 881 return false; 882 883 return true; 884 } 885 886 /*static*/ bool OperationEquivalence::isEquivalentTo(Operation *lhs, 887 Operation *rhs, 888 Flags flags) { 889 ValueEquivalenceCache cache; 890 return OperationEquivalence::isEquivalentTo( 891 lhs, rhs, 892 [&](Value lhsValue, Value rhsValue) -> LogicalResult { 893 return cache.checkEquivalent(lhsValue, rhsValue); 894 }, 895 [&](Value lhsResult, Value rhsResult) { 896 cache.markEquivalent(lhsResult, rhsResult); 897 }, 898 flags, 899 [&](ValueRange lhs, ValueRange rhs) -> LogicalResult { 900 return cache.checkCommutativeEquivalent(lhs, rhs); 901 }); 902 } 903 904 //===----------------------------------------------------------------------===// 905 // OperationFingerPrint 906 //===----------------------------------------------------------------------===// 907 908 template <typename T> 909 static void addDataToHash(llvm::SHA1 &hasher, const T &data) { 910 hasher.update( 911 ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T))); 912 } 913 914 
/// Compute a SHA1 fingerprint of `topOp` (and of every nested operation when
/// `includeNested` is set) from the pointer identities and IR components
/// listed below.
OperationFingerPrint::OperationFingerPrint(Operation *topOp,
                                           bool includeNested) {
  llvm::SHA1 hasher;

  // Helper function that hashes an operation based on its mutable bits:
  auto addOperationToHash = [&](Operation *op) {
    //   - Operation pointer
    addDataToHash(hasher, op);
    //   - Parent operation pointer (to take into account the nesting structure)
    if (op != topOp)
      addDataToHash(hasher, op->getParentOp());
    //   - Attributes
    addDataToHash(hasher, op->getRawDictionaryAttrs());
    //   - Properties
    addDataToHash(hasher, op->hashProperties());
    //   - Blocks in Regions
    for (Region &region : op->getRegions()) {
      for (Block &block : region) {
        addDataToHash(hasher, &block);
        for (BlockArgument arg : block.getArguments())
          addDataToHash(hasher, arg);
      }
    }
    //   - Location
    addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
    //   - Operands
    for (Value operand : op->getOperands())
      addDataToHash(hasher, operand);
    //   - Successors
    for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
      addDataToHash(hasher, op->getSuccessor(i));
    //   - Result types
    for (Type t : op->getResultTypes())
      addDataToHash(hasher, t);
  };

  if (includeNested)
    topOp->walk(addOperationToHash);
  else
    addOperationToHash(topOp);

  hash = hasher.result();
}