xref: /llvm-project/mlir/lib/IR/OperationSupport.cpp (revision 98de5dfe6a8cbb70f21de545acec4710a77294ed)
1 //===- OperationSupport.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/OperationSupport.h"
15 #include "mlir/IR/BuiltinAttributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/ADT/BitVector.h"
19 #include "llvm/Support/SHA1.h"
20 #include <numeric>
21 #include <optional>
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // NamedAttrList
27 //===----------------------------------------------------------------------===//
28 
29 NamedAttrList::NamedAttrList(ArrayRef<NamedAttribute> attributes) {
30   assign(attributes.begin(), attributes.end());
31 }
32 
33 NamedAttrList::NamedAttrList(DictionaryAttr attributes)
34     : NamedAttrList(attributes ? attributes.getValue()
35                                : ArrayRef<NamedAttribute>()) {
36   dictionarySorted.setPointerAndInt(attributes, true);
37 }
38 
39 NamedAttrList::NamedAttrList(const_iterator inStart, const_iterator inEnd) {
40   assign(inStart, inEnd);
41 }
42 
43 ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; }
44 
45 std::optional<NamedAttribute> NamedAttrList::findDuplicate() const {
46   std::optional<NamedAttribute> duplicate =
47       DictionaryAttr::findDuplicate(attrs, isSorted());
48   // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
49   // state.
50   if (!isSorted())
51     dictionarySorted.setPointerAndInt(nullptr, true);
52   return duplicate;
53 }
54 
55 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
56   if (!isSorted()) {
57     DictionaryAttr::sortInPlace(attrs);
58     dictionarySorted.setPointerAndInt(nullptr, true);
59   }
60   if (!dictionarySorted.getPointer())
61     dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
62   return llvm::cast<DictionaryAttr>(dictionarySorted.getPointer());
63 }
64 
65 /// Replaces the attributes with new list of attributes.
66 void NamedAttrList::assign(const_iterator inStart, const_iterator inEnd) {
67   DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
68   dictionarySorted.setPointerAndInt(nullptr, true);
69 }
70 
71 void NamedAttrList::push_back(NamedAttribute newAttribute) {
72   if (isSorted())
73     dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
74   dictionarySorted.setPointer(nullptr);
75   attrs.push_back(newAttribute);
76 }
77 
78 /// Return the specified attribute if present, null otherwise.
79 Attribute NamedAttrList::get(StringRef name) const {
80   auto it = findAttr(*this, name);
81   return it.second ? it.first->getValue() : Attribute();
82 }
83 Attribute NamedAttrList::get(StringAttr name) const {
84   auto it = findAttr(*this, name);
85   return it.second ? it.first->getValue() : Attribute();
86 }
87 
88 /// Return the specified named attribute if present, std::nullopt otherwise.
89 std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
90   auto it = findAttr(*this, name);
91   return it.second ? *it.first : std::optional<NamedAttribute>();
92 }
93 std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
94   auto it = findAttr(*this, name);
95   return it.second ? *it.first : std::optional<NamedAttribute>();
96 }
97 
98 /// If the an attribute exists with the specified name, change it to the new
99 /// value.  Otherwise, add a new attribute with the specified name/value.
100 Attribute NamedAttrList::set(StringAttr name, Attribute value) {
101   assert(value && "attributes may never be null");
102 
103   // Look for an existing attribute with the given name, and set its value
104   // in-place. Return the previous value of the attribute, if there was one.
105   auto it = findAttr(*this, name);
106   if (it.second) {
107     // Update the existing attribute by swapping out the old value for the new
108     // value. Return the old value.
109     Attribute oldValue = it.first->getValue();
110     if (it.first->getValue() != value) {
111       it.first->setValue(value);
112 
113       // If the attributes have changed, the dictionary is invalidated.
114       dictionarySorted.setPointer(nullptr);
115     }
116     return oldValue;
117   }
118   // Perform a string lookup to insert the new attribute into its sorted
119   // position.
120   if (isSorted())
121     it = findAttr(*this, name.strref());
122   attrs.insert(it.first, {name, value});
123   // Invalidate the dictionary. Return null as there was no previous value.
124   dictionarySorted.setPointer(nullptr);
125   return Attribute();
126 }
127 
128 Attribute NamedAttrList::set(StringRef name, Attribute value) {
129   assert(value && "attributes may never be null");
130   return set(mlir::StringAttr::get(value.getContext(), name), value);
131 }
132 
133 Attribute
134 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
135   // Erasing does not affect the sorted property.
136   Attribute attr = it->getValue();
137   attrs.erase(it);
138   dictionarySorted.setPointer(nullptr);
139   return attr;
140 }
141 
142 Attribute NamedAttrList::erase(StringAttr name) {
143   auto it = findAttr(*this, name);
144   return it.second ? eraseImpl(it.first) : Attribute();
145 }
146 
147 Attribute NamedAttrList::erase(StringRef name) {
148   auto it = findAttr(*this, name);
149   return it.second ? eraseImpl(it.first) : Attribute();
150 }
151 
152 NamedAttrList &
153 NamedAttrList::operator=(const SmallVectorImpl<NamedAttribute> &rhs) {
154   assign(rhs.begin(), rhs.end());
155   return *this;
156 }
157 
158 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
159 
160 //===----------------------------------------------------------------------===//
161 // OperationState
162 //===----------------------------------------------------------------------===//
163 
164 OperationState::OperationState(Location location, StringRef name)
165     : location(location), name(name, location->getContext()) {}
166 
167 OperationState::OperationState(Location location, OperationName name)
168     : location(location), name(name) {}
169 
170 OperationState::OperationState(Location location, OperationName name,
171                                ValueRange operands, TypeRange types,
172                                ArrayRef<NamedAttribute> attributes,
173                                BlockRange successors,
174                                MutableArrayRef<std::unique_ptr<Region>> regions)
175     : location(location), name(name),
176       operands(operands.begin(), operands.end()),
177       types(types.begin(), types.end()),
178       attributes(attributes.begin(), attributes.end()),
179       successors(successors.begin(), successors.end()) {
180   for (std::unique_ptr<Region> &r : regions)
181     this->regions.push_back(std::move(r));
182 }
183 OperationState::OperationState(Location location, StringRef name,
184                                ValueRange operands, TypeRange types,
185                                ArrayRef<NamedAttribute> attributes,
186                                BlockRange successors,
187                                MutableArrayRef<std::unique_ptr<Region>> regions)
188     : OperationState(location, OperationName(name, location.getContext()),
189                      operands, types, attributes, successors, regions) {}
190 
191 OperationState::~OperationState() {
192   if (properties)
193     propertiesDeleter(properties);
194 }
195 
196 LogicalResult OperationState::setProperties(
197     Operation *op, function_ref<InFlightDiagnostic()> emitError) const {
198   if (LLVM_UNLIKELY(propertiesAttr)) {
199     assert(!properties);
200     return op->setPropertiesFromAttribute(propertiesAttr, emitError);
201   }
202   if (properties)
203     propertiesSetter(op->getPropertiesStorage(), properties);
204   return success();
205 }
206 
207 void OperationState::addOperands(ValueRange newOperands) {
208   operands.append(newOperands.begin(), newOperands.end());
209 }
210 
211 void OperationState::addSuccessors(BlockRange newSuccessors) {
212   successors.append(newSuccessors.begin(), newSuccessors.end());
213 }
214 
215 Region *OperationState::addRegion() {
216   regions.emplace_back(new Region);
217   return regions.back().get();
218 }
219 
220 void OperationState::addRegion(std::unique_ptr<Region> &&region) {
221   regions.push_back(std::move(region));
222 }
223 
224 void OperationState::addRegions(
225     MutableArrayRef<std::unique_ptr<Region>> regions) {
226   for (std::unique_ptr<Region> &region : regions)
227     addRegion(std::move(region));
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // OperandStorage
232 //===----------------------------------------------------------------------===//
233 
234 detail::OperandStorage::OperandStorage(Operation *owner,
235                                        OpOperand *trailingOperands,
236                                        ValueRange values)
237     : isStorageDynamic(false), operandStorage(trailingOperands) {
238   numOperands = capacity = values.size();
239   for (unsigned i = 0; i < numOperands; ++i)
240     new (&operandStorage[i]) OpOperand(owner, values[i]);
241 }
242 
243 detail::OperandStorage::~OperandStorage() {
244   for (auto &operand : getOperands())
245     operand.~OpOperand();
246 
247   // If the storage is dynamic, deallocate it.
248   if (isStorageDynamic)
249     free(operandStorage);
250 }
251 
252 /// Replace the operands contained in the storage with the ones provided in
253 /// 'values'.
254 void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) {
255   MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
256   for (unsigned i = 0, e = values.size(); i != e; ++i)
257     storageOperands[i].set(values[i]);
258 }
259 
260 /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
261 /// with the ones provided in 'operands'. 'operands' may be smaller or larger
262 /// than the range pointed to by 'start'+'length'.
263 void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
264                                          unsigned length, ValueRange operands) {
265   // If the new size is the same, we can update inplace.
266   unsigned newSize = operands.size();
267   if (newSize == length) {
268     MutableArrayRef<OpOperand> storageOperands = getOperands();
269     for (unsigned i = 0, e = length; i != e; ++i)
270       storageOperands[start + i].set(operands[i]);
271     return;
272   }
273   // If the new size is greater, remove the extra operands and set the rest
274   // inplace.
275   if (newSize < length) {
276     eraseOperands(start + operands.size(), length - newSize);
277     setOperands(owner, start, newSize, operands);
278     return;
279   }
280   // Otherwise, the new size is greater so we need to grow the storage.
281   auto storageOperands = resize(owner, size() + (newSize - length));
282 
283   // Shift operands to the right to make space for the new operands.
284   unsigned rotateSize = storageOperands.size() - (start + length);
285   auto rbegin = storageOperands.rbegin();
286   std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
287 
288   // Update the operands inplace.
289   for (unsigned i = 0, e = operands.size(); i != e; ++i)
290     storageOperands[start + i].set(operands[i]);
291 }
292 
293 /// Erase an operand held by the storage.
294 void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
295   MutableArrayRef<OpOperand> operands = getOperands();
296   assert((start + length) <= operands.size());
297   numOperands -= length;
298 
299   // Shift all operands down if the operand to remove is not at the end.
300   if (start != numOperands) {
301     auto *indexIt = std::next(operands.begin(), start);
302     std::rotate(indexIt, std::next(indexIt, length), operands.end());
303   }
304   for (unsigned i = 0; i != length; ++i)
305     operands[numOperands + i].~OpOperand();
306 }
307 
308 void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
309   MutableArrayRef<OpOperand> operands = getOperands();
310   assert(eraseIndices.size() == operands.size());
311 
312   // Check that at least one operand is erased.
313   int firstErasedIndice = eraseIndices.find_first();
314   if (firstErasedIndice == -1)
315     return;
316 
317   // Shift all of the removed operands to the end, and destroy them.
318   numOperands = firstErasedIndice;
319   for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
320     if (!eraseIndices.test(i))
321       operands[numOperands++] = std::move(operands[i]);
322   for (OpOperand &operand : operands.drop_front(numOperands))
323     operand.~OpOperand();
324 }
325 
326 /// Resize the storage to the given size. Returns the array containing the new
327 /// operands.
328 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
329                                                           unsigned newSize) {
330   // If the number of operands is less than or equal to the current amount, we
331   // can just update in place.
332   MutableArrayRef<OpOperand> origOperands = getOperands();
333   if (newSize <= numOperands) {
334     // If the number of new size is less than the current, remove any extra
335     // operands.
336     for (unsigned i = newSize; i != numOperands; ++i)
337       origOperands[i].~OpOperand();
338     numOperands = newSize;
339     return origOperands.take_front(newSize);
340   }
341 
342   // If the new size is within the original inline capacity, grow inplace.
343   if (newSize <= capacity) {
344     OpOperand *opBegin = origOperands.data();
345     for (unsigned e = newSize; numOperands != e; ++numOperands)
346       new (&opBegin[numOperands]) OpOperand(owner);
347     return MutableArrayRef<OpOperand>(opBegin, newSize);
348   }
349 
350   // Otherwise, we need to allocate a new storage.
351   unsigned newCapacity =
352       std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize);
353   OpOperand *newOperandStorage =
354       reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity));
355 
356   // Move the current operands to the new storage.
357   MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
358   std::uninitialized_move(origOperands.begin(), origOperands.end(),
359                           newOperands.begin());
360 
361   // Destroy the original operands.
362   for (auto &operand : origOperands)
363     operand.~OpOperand();
364 
365   // Initialize any new operands.
366   for (unsigned e = newSize; numOperands != e; ++numOperands)
367     new (&newOperands[numOperands]) OpOperand(owner);
368 
369   // If the current storage is dynamic, free it.
370   if (isStorageDynamic)
371     free(operandStorage);
372 
373   // Update the storage representation to use the new dynamic storage.
374   operandStorage = newOperandStorage;
375   capacity = newCapacity;
376   isStorageDynamic = true;
377   return newOperands;
378 }
379 
380 //===----------------------------------------------------------------------===//
381 // Operation Value-Iterators
382 //===----------------------------------------------------------------------===//
383 
384 //===----------------------------------------------------------------------===//
385 // OperandRange
386 
387 unsigned OperandRange::getBeginOperandIndex() const {
388   assert(!empty() && "range must not be empty");
389   return base->getOperandNumber();
390 }
391 
392 OperandRangeRange OperandRange::split(DenseI32ArrayAttr segmentSizes) const {
393   return OperandRangeRange(*this, segmentSizes);
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // OperandRangeRange
398 
399 OperandRangeRange::OperandRangeRange(OperandRange operands,
400                                      Attribute operandSegments)
401     : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
402                         llvm::cast<DenseI32ArrayAttr>(operandSegments).size()) {
403 }
404 
405 OperandRange OperandRangeRange::join() const {
406   const OwnerT &owner = getBase();
407   ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second);
408   return OperandRange(owner.first,
409                       std::accumulate(sizeData.begin(), sizeData.end(), 0));
410 }
411 
412 OperandRange OperandRangeRange::dereference(const OwnerT &object,
413                                             ptrdiff_t index) {
414   ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second);
415   uint32_t startIndex =
416       std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
417   return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
418 }
419 
420 //===----------------------------------------------------------------------===//
421 // MutableOperandRange
422 
423 /// Construct a new mutable range from the given operand, operand start index,
424 /// and range length.
425 MutableOperandRange::MutableOperandRange(
426     Operation *owner, unsigned start, unsigned length,
427     ArrayRef<OperandSegment> operandSegments)
428     : owner(owner), start(start), length(length),
429       operandSegments(operandSegments) {
430   assert((start + length) <= owner->getNumOperands() && "invalid range");
431 }
432 MutableOperandRange::MutableOperandRange(Operation *owner)
433     : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
434 
435 /// Construct a new mutable range for the given OpOperand.
436 MutableOperandRange::MutableOperandRange(OpOperand &opOperand)
437     : MutableOperandRange(opOperand.getOwner(),
438                           /*start=*/opOperand.getOperandNumber(),
439                           /*length=*/1) {}
440 
441 /// Slice this range into a sub range, with the additional operand segment.
442 MutableOperandRange
443 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
444                            std::optional<OperandSegment> segment) const {
445   assert((subStart + subLen) <= length && "invalid sub-range");
446   MutableOperandRange subSlice(owner, start + subStart, subLen,
447                                operandSegments);
448   if (segment)
449     subSlice.operandSegments.push_back(*segment);
450   return subSlice;
451 }
452 
453 /// Append the given values to the range.
454 void MutableOperandRange::append(ValueRange values) {
455   if (values.empty())
456     return;
457   owner->insertOperands(start + length, values);
458   updateLength(length + values.size());
459 }
460 
461 /// Assign this range to the given values.
462 void MutableOperandRange::assign(ValueRange values) {
463   owner->setOperands(start, length, values);
464   if (length != values.size())
465     updateLength(/*newLength=*/values.size());
466 }
467 
468 /// Assign the range to the given value.
469 void MutableOperandRange::assign(Value value) {
470   if (length == 1) {
471     owner->setOperand(start, value);
472   } else {
473     owner->setOperands(start, length, value);
474     updateLength(/*newLength=*/1);
475   }
476 }
477 
478 /// Erase the operands within the given sub-range.
479 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
480   assert((subStart + subLen) <= length && "invalid sub-range");
481   if (length == 0)
482     return;
483   owner->eraseOperands(start + subStart, subLen);
484   updateLength(length - subLen);
485 }
486 
487 /// Clear this range and erase all of the operands.
488 void MutableOperandRange::clear() {
489   if (length != 0) {
490     owner->eraseOperands(start, length);
491     updateLength(/*newLength=*/0);
492   }
493 }
494 
495 /// Explicit conversion to an OperandRange.
496 OperandRange MutableOperandRange::getAsOperandRange() const {
497   return owner->getOperands().slice(start, length);
498 }
499 
500 /// Allow implicit conversion to an OperandRange.
501 MutableOperandRange::operator OperandRange() const {
502   return getAsOperandRange();
503 }
504 
505 MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
506   return owner->getOpOperands().slice(start, length);
507 }
508 
509 MutableOperandRangeRange
510 MutableOperandRange::split(NamedAttribute segmentSizes) const {
511   return MutableOperandRangeRange(*this, segmentSizes);
512 }
513 
514 /// Update the length of this range to the one provided.
515 void MutableOperandRange::updateLength(unsigned newLength) {
516   int32_t diff = int32_t(newLength) - int32_t(length);
517   length = newLength;
518 
519   // Update any of the provided segment attributes.
520   for (OperandSegment &segment : operandSegments) {
521     auto attr = llvm::cast<DenseI32ArrayAttr>(segment.second.getValue());
522     SmallVector<int32_t, 8> segments(attr.asArrayRef());
523     segments[segment.first] += diff;
524     segment.second.setValue(
525         DenseI32ArrayAttr::get(attr.getContext(), segments));
526     owner->setAttr(segment.second.getName(), segment.second.getValue());
527   }
528 }
529 
530 OpOperand &MutableOperandRange::operator[](unsigned index) const {
531   assert(index < length && "index is out of bounds");
532   return owner->getOpOperand(start + index);
533 }
534 
535 MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
536   return owner->getOpOperands().slice(start, length).begin();
537 }
538 
539 MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
540   return owner->getOpOperands().slice(start, length).end();
541 }
542 
543 //===----------------------------------------------------------------------===//
544 // MutableOperandRangeRange
545 
546 MutableOperandRangeRange::MutableOperandRangeRange(
547     const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
548     : MutableOperandRangeRange(
549           OwnerT(operands, operandSegmentAttr), 0,
550           llvm::cast<DenseI32ArrayAttr>(operandSegmentAttr.getValue()).size()) {
551 }
552 
553 MutableOperandRange MutableOperandRangeRange::join() const {
554   return getBase().first;
555 }
556 
557 MutableOperandRangeRange::operator OperandRangeRange() const {
558   return OperandRangeRange(getBase().first, getBase().second.getValue());
559 }
560 
561 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
562                                                           ptrdiff_t index) {
563   ArrayRef<int32_t> sizeData =
564       llvm::cast<DenseI32ArrayAttr>(object.second.getValue());
565   uint32_t startIndex =
566       std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
567   return object.first.slice(
568       startIndex, *(sizeData.begin() + index),
569       MutableOperandRange::OperandSegment(index, object.second));
570 }
571 
572 //===----------------------------------------------------------------------===//
573 // ResultRange
574 
575 ResultRange::ResultRange(OpResult result)
576     : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
577                   1) {}
578 
579 ResultRange::use_range ResultRange::getUses() const {
580   return {use_begin(), use_end()};
581 }
582 ResultRange::use_iterator ResultRange::use_begin() const {
583   return use_iterator(*this);
584 }
585 ResultRange::use_iterator ResultRange::use_end() const {
586   return use_iterator(*this, /*end=*/true);
587 }
588 ResultRange::user_range ResultRange::getUsers() {
589   return {user_begin(), user_end()};
590 }
591 ResultRange::user_iterator ResultRange::user_begin() {
592   return user_iterator(use_begin());
593 }
594 ResultRange::user_iterator ResultRange::user_end() {
595   return user_iterator(use_end());
596 }
597 
598 ResultRange::UseIterator::UseIterator(ResultRange results, bool end)
599     : it(end ? results.end() : results.begin()), endIt(results.end()) {
600   // Only initialize current use if there are results/can be uses.
601   if (it != endIt)
602     skipOverResultsWithNoUsers();
603 }
604 
605 ResultRange::UseIterator &ResultRange::UseIterator::operator++() {
606   // We increment over uses, if we reach the last use then move to next
607   // result.
608   if (use != (*it).use_end())
609     ++use;
610   if (use == (*it).use_end()) {
611     ++it;
612     skipOverResultsWithNoUsers();
613   }
614   return *this;
615 }
616 
617 void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
618   while (it != endIt && (*it).use_empty())
619     ++it;
620 
621   // If we are at the last result, then set use to first use of
622   // first result (sentinel value used for end).
623   if (it == endIt)
624     use = {};
625   else
626     use = (*it).use_begin();
627 }
628 
629 void ResultRange::replaceAllUsesWith(Operation *op) {
630   replaceAllUsesWith(op->getResults());
631 }
632 
633 void ResultRange::replaceUsesWithIf(
634     Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
635   replaceUsesWithIf(op->getResults(), shouldReplace);
636 }
637 
638 //===----------------------------------------------------------------------===//
639 // ValueRange
640 
641 ValueRange::ValueRange(ArrayRef<Value> values)
642     : ValueRange(values.data(), values.size()) {}
643 ValueRange::ValueRange(OperandRange values)
644     : ValueRange(values.begin().getBase(), values.size()) {}
645 ValueRange::ValueRange(ResultRange values)
646     : ValueRange(values.getBase(), values.size()) {}
647 
648 /// See `llvm::detail::indexed_accessor_range_base` for details.
649 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
650                                            ptrdiff_t index) {
651   if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
652     return {value + index};
653   if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
654     return {operand + index};
655   return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
656 }
657 /// See `llvm::detail::indexed_accessor_range_base` for details.
658 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
659   if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
660     return value[index];
661   if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
662     return operand[index].get();
663   return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
664 }
665 
666 //===----------------------------------------------------------------------===//
667 // Operation Equivalency
668 //===----------------------------------------------------------------------===//
669 
670 llvm::hash_code OperationEquivalence::computeHash(
671     Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
672     function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
673   // Hash operations based upon their:
674   //   - Operation Name
675   //   - Attributes
676   //   - Result Types
677   llvm::hash_code hash =
678       llvm::hash_combine(op->getName(), op->getRawDictionaryAttrs(),
679                          op->getResultTypes(), op->hashProperties());
680 
681   //   - Location if required
682   if (!(flags & Flags::IgnoreLocations))
683     hash = llvm::hash_combine(hash, op->getLoc());
684 
685   //   - Operands
686   if (op->hasTrait<mlir::OpTrait::IsCommutative>() &&
687       op->getNumOperands() > 0) {
688     size_t operandHash = hashOperands(op->getOperand(0));
689     for (auto operand : op->getOperands().drop_front())
690       operandHash += hashOperands(operand);
691     hash = llvm::hash_combine(hash, operandHash);
692   } else {
693     for (Value operand : op->getOperands())
694       hash = llvm::hash_combine(hash, hashOperands(operand));
695   }
696 
697   //   - Results
698   for (Value result : op->getResults())
699     hash = llvm::hash_combine(hash, hashResults(result));
700   return hash;
701 }
702 
703 /*static*/ bool OperationEquivalence::isRegionEquivalentTo(
704     Region *lhs, Region *rhs,
705     function_ref<LogicalResult(Value, Value)> checkEquivalent,
706     function_ref<void(Value, Value)> markEquivalent,
707     OperationEquivalence::Flags flags,
708     function_ref<LogicalResult(ValueRange, ValueRange)>
709         checkCommutativeEquivalent) {
710   DenseMap<Block *, Block *> blocksMap;
711   auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
712     // Check block arguments.
713     if (lBlock.getNumArguments() != rBlock.getNumArguments())
714       return false;
715 
716     // Map the two blocks.
717     auto insertion = blocksMap.insert({&lBlock, &rBlock});
718     if (insertion.first->getSecond() != &rBlock)
719       return false;
720 
721     for (auto argPair :
722          llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
723       Value curArg = std::get<0>(argPair);
724       Value otherArg = std::get<1>(argPair);
725       if (curArg.getType() != otherArg.getType())
726         return false;
727       if (!(flags & OperationEquivalence::IgnoreLocations) &&
728           curArg.getLoc() != otherArg.getLoc())
729         return false;
730       // Corresponding bbArgs are equivalent.
731       if (markEquivalent)
732         markEquivalent(curArg, otherArg);
733     }
734 
735     auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
736       // Check for op equality (recursively).
737       if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
738                                                 markEquivalent, flags,
739                                                 checkCommutativeEquivalent))
740         return false;
741       // Check successor mapping.
742       for (auto successorsPair :
743            llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
744         Block *curSuccessor = std::get<0>(successorsPair);
745         Block *otherSuccessor = std::get<1>(successorsPair);
746         auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
747         if (insertion.first->getSecond() != otherSuccessor)
748           return false;
749       }
750       return true;
751     };
752     return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
753   };
754   return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
755 }
756 
757 // Value equivalence cache to be used with `isRegionEquivalentTo` and
758 // `isEquivalentTo`.
759 struct ValueEquivalenceCache {
760   DenseMap<Value, Value> equivalentValues;
761   LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
762     return success(lhsValue == rhsValue ||
763                    equivalentValues.lookup(lhsValue) == rhsValue);
764   }
765   LogicalResult checkCommutativeEquivalent(ValueRange lhsRange,
766                                            ValueRange rhsRange) {
767     // Handle simple case where sizes mismatch.
768     if (lhsRange.size() != rhsRange.size())
769       return failure();
770 
771     // Handle where operands in order are equivalent.
772     auto lhsIt = lhsRange.begin();
773     auto rhsIt = rhsRange.begin();
774     for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
775       if (failed(checkEquivalent(*lhsIt, *rhsIt)))
776         break;
777     }
778     if (lhsIt == lhsRange.end())
779       return success();
780 
781     // Handle another simple case where operands are just a permutation.
782     // Note: This is not sufficient, this handles simple cases relatively
783     // cheaply.
784     auto sortValues = [](ValueRange values) {
785       SmallVector<Value> sortedValues = llvm::to_vector(values);
786       llvm::sort(sortedValues, [](Value a, Value b) {
787         return a.getAsOpaquePointer() < b.getAsOpaquePointer();
788       });
789       return sortedValues;
790     };
791     auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
792     auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
793     return success(lhsSorted == rhsSorted);
794   }
795   void markEquivalent(Value lhsResult, Value rhsResult) {
796     auto insertion = equivalentValues.insert({lhsResult, rhsResult});
797     // Make sure that the value was not already marked equivalent to some other
798     // value.
799     (void)insertion;
800     assert(insertion.first->second == rhsResult &&
801            "inconsistent OperationEquivalence state");
802   }
803 };
804 
805 /*static*/ bool
806 OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
807                                            OperationEquivalence::Flags flags) {
808   ValueEquivalenceCache cache;
809   return isRegionEquivalentTo(
810       lhs, rhs,
811       [&](Value lhsValue, Value rhsValue) -> LogicalResult {
812         return cache.checkEquivalent(lhsValue, rhsValue);
813       },
814       [&](Value lhsResult, Value rhsResult) {
815         cache.markEquivalent(lhsResult, rhsResult);
816       },
817       flags,
818       [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
819         return cache.checkCommutativeEquivalent(lhs, rhs);
820       });
821 }
822 
823 /*static*/ bool OperationEquivalence::isEquivalentTo(
824     Operation *lhs, Operation *rhs,
825     function_ref<LogicalResult(Value, Value)> checkEquivalent,
826     function_ref<void(Value, Value)> markEquivalent, Flags flags,
827     function_ref<LogicalResult(ValueRange, ValueRange)>
828         checkCommutativeEquivalent) {
829   if (lhs == rhs)
830     return true;
831 
832   // 1. Compare the operation properties.
833   if (lhs->getName() != rhs->getName() ||
834       lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs() ||
835       lhs->getNumRegions() != rhs->getNumRegions() ||
836       lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
837       lhs->getNumOperands() != rhs->getNumOperands() ||
838       lhs->getNumResults() != rhs->getNumResults() ||
839       !lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
840                                           rhs->getPropertiesStorage()))
841     return false;
842   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
843     return false;
844 
845   // 2. Compare operands.
846   if (checkCommutativeEquivalent &&
847       lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
848     auto lhsRange = lhs->getOperands();
849     auto rhsRange = rhs->getOperands();
850     if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
851       return false;
852   } else {
853     // Check pair wise for equivalence.
854     for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
855       Value curArg = std::get<0>(operandPair);
856       Value otherArg = std::get<1>(operandPair);
857       if (curArg == otherArg)
858         continue;
859       if (curArg.getType() != otherArg.getType())
860         return false;
861       if (failed(checkEquivalent(curArg, otherArg)))
862         return false;
863     }
864   }
865 
866   // 3. Compare result types and mark results as equivalent.
867   for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
868     Value curArg = std::get<0>(resultPair);
869     Value otherArg = std::get<1>(resultPair);
870     if (curArg.getType() != otherArg.getType())
871       return false;
872     if (markEquivalent)
873       markEquivalent(curArg, otherArg);
874   }
875 
876   // 4. Compare regions.
877   for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
878     if (!isRegionEquivalentTo(&std::get<0>(regionPair),
879                               &std::get<1>(regionPair), checkEquivalent,
880                               markEquivalent, flags))
881       return false;
882 
883   return true;
884 }
885 
886 /*static*/ bool OperationEquivalence::isEquivalentTo(Operation *lhs,
887                                                      Operation *rhs,
888                                                      Flags flags) {
889   ValueEquivalenceCache cache;
890   return OperationEquivalence::isEquivalentTo(
891       lhs, rhs,
892       [&](Value lhsValue, Value rhsValue) -> LogicalResult {
893         return cache.checkEquivalent(lhsValue, rhsValue);
894       },
895       [&](Value lhsResult, Value rhsResult) {
896         cache.markEquivalent(lhsResult, rhsResult);
897       },
898       flags,
899       [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
900         return cache.checkCommutativeEquivalent(lhs, rhs);
901       });
902 }
903 
904 //===----------------------------------------------------------------------===//
905 // OperationFingerPrint
906 //===----------------------------------------------------------------------===//
907 
908 template <typename T>
909 static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
910   hasher.update(
911       ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
912 }
913 
914 OperationFingerPrint::OperationFingerPrint(Operation *topOp,
915                                            bool includeNested) {
916   llvm::SHA1 hasher;
917 
918   // Helper function that hashes an operation based on its mutable bits:
919   auto addOperationToHash = [&](Operation *op) {
920     //   - Operation pointer
921     addDataToHash(hasher, op);
922     //   - Parent operation pointer (to take into account the nesting structure)
923     if (op != topOp)
924       addDataToHash(hasher, op->getParentOp());
925     //   - Attributes
926     addDataToHash(hasher, op->getRawDictionaryAttrs());
927     //   - Properties
928     addDataToHash(hasher, op->hashProperties());
929     //   - Blocks in Regions
930     for (Region &region : op->getRegions()) {
931       for (Block &block : region) {
932         addDataToHash(hasher, &block);
933         for (BlockArgument arg : block.getArguments())
934           addDataToHash(hasher, arg);
935       }
936     }
937     //   - Location
938     addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
939     //   - Operands
940     for (Value operand : op->getOperands())
941       addDataToHash(hasher, operand);
942     //   - Successors
943     for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
944       addDataToHash(hasher, op->getSuccessor(i));
945     //   - Result types
946     for (Type t : op->getResultTypes())
947       addDataToHash(hasher, t);
948   };
949 
950   if (includeNested)
951     topOp->walk(addOperationToHash);
952   else
953     addOperationToHash(topOp);
954 
955   hash = hasher.result();
956 }
957