1 //===- TypeRange.h ----------------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the TypeRange and ValueTypeRange classes. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_IR_TYPERANGE_H 14 #define MLIR_IR_TYPERANGE_H 15 16 #include "mlir/IR/Types.h" 17 #include "mlir/IR/Value.h" 18 #include "mlir/IR/ValueRange.h" 19 #include "llvm/ADT/PointerUnion.h" 20 #include "llvm/ADT/Sequence.h" 21 22 namespace mlir { 23 24 //===----------------------------------------------------------------------===// 25 // TypeRange 26 27 /// This class provides an abstraction over the various different ranges of 28 /// value types. In many cases, this prevents the need to explicitly materialize 29 /// a SmallVector/std::vector. This class should be used in places that are not 30 /// suitable for a more derived type (e.g. ArrayRef) or a template range 31 /// parameter. 32 class TypeRange : public llvm::detail::indexed_accessor_range_base< 33 TypeRange, 34 llvm::PointerUnion<const Value *, const Type *, 35 OpOperand *, detail::OpResultImpl *>, 36 Type, Type, Type> { 37 public: 38 using RangeBaseT::RangeBaseT; 39 TypeRange(ArrayRef<Type> types = std::nullopt); 40 explicit TypeRange(OperandRange values); 41 explicit TypeRange(ResultRange values); 42 explicit TypeRange(ValueRange values); 43 template <typename ValueRangeT> 44 TypeRange(ValueTypeRange<ValueRangeT> values) 45 : TypeRange(ValueRange(ValueRangeT(values.begin().getCurrent(), 46 values.end().getCurrent()))) {} 47 template <typename Arg, typename = std::enable_if_t<std::is_constructible< 48 ArrayRef<Type>, Arg>::value>> 49 TypeRange(Arg &&arg LLVM_LIFETIME_BOUND) 50 : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {} 51 TypeRange(std::initializer_list<Type> types LLVM_LIFETIME_BOUND) 52 : TypeRange(ArrayRef<Type>(types)) {} 53 54 private: 55 /// The owner of the range is either: 56 /// * A pointer to the first element of an array of values. 57 /// * A pointer to the first element of an array of types. 58 /// * A pointer to the first element of an array of operands. 59 /// * A pointer to the first element of an array of results. 60 using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *, 61 detail::OpResultImpl *>; 62 63 /// See `llvm::detail::indexed_accessor_range_base` for details. 64 static OwnerT offset_base(OwnerT object, ptrdiff_t index); 65 /// See `llvm::detail::indexed_accessor_range_base` for details. 66 static Type dereference_iterator(OwnerT object, ptrdiff_t index); 67 68 /// Allow access to `offset_base` and `dereference_iterator`. 69 friend RangeBaseT; 70 }; 71 72 /// Make TypeRange hashable. 73 inline ::llvm::hash_code hash_value(TypeRange arg) { 74 return ::llvm::hash_combine_range(arg.begin(), arg.end()); 75 } 76 77 /// Emit a type range to the given output stream. 78 inline raw_ostream &operator<<(raw_ostream &os, const TypeRange &types) { 79 llvm::interleaveComma(types, os); 80 return os; 81 } 82 83 //===----------------------------------------------------------------------===// 84 // TypeRangeRange 85 86 using TypeRangeRangeIterator = 87 llvm::mapped_iterator<llvm::iota_range<unsigned>::iterator, 88 std::function<TypeRange(unsigned)>>; 89 90 /// This class provides an abstraction for a range of TypeRange. This is useful 91 /// when accessing the types of a range of ranges, such as when using 92 /// OperandRangeRange. 93 class TypeRangeRange : public llvm::iterator_range<TypeRangeRangeIterator> { 94 public: 95 template <typename RangeT> 96 TypeRangeRange(const RangeT &range) 97 : TypeRangeRange(llvm::seq<unsigned>(0, range.size()), range) {} 98 99 private: 100 template <typename RangeT> 101 TypeRangeRange(llvm::iota_range<unsigned> sizeRange, const RangeT &range) 102 : llvm::iterator_range<TypeRangeRangeIterator>( 103 {sizeRange.begin(), getRangeFn(range)}, 104 {sizeRange.end(), nullptr}) {} 105 106 template <typename RangeT> 107 static std::function<TypeRange(unsigned)> getRangeFn(const RangeT &range) { 108 return [=](unsigned index) -> TypeRange { return TypeRange(range[index]); }; 109 } 110 }; 111 112 //===----------------------------------------------------------------------===// 113 // ValueTypeRange 114 115 /// This class implements iteration on the types of a given range of values. 116 template <typename ValueIteratorT> 117 class ValueTypeIterator final 118 : public llvm::mapped_iterator_base<ValueTypeIterator<ValueIteratorT>, 119 ValueIteratorT, Type> { 120 public: 121 using llvm::mapped_iterator_base<ValueTypeIterator<ValueIteratorT>, 122 ValueIteratorT, Type>::mapped_iterator_base; 123 124 /// Map the element to the iterator result type. 125 Type mapElement(Value value) const { return value.getType(); } 126 }; 127 128 /// This class implements iteration on the types of a given range of values. 129 template <typename ValueRangeT> 130 class ValueTypeRange final 131 : public llvm::iterator_range< 132 ValueTypeIterator<typename ValueRangeT::iterator>> { 133 public: 134 using llvm::iterator_range< 135 ValueTypeIterator<typename ValueRangeT::iterator>>::iterator_range; 136 template <typename Container> 137 ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {} 138 139 /// Return the type at the given index. 140 Type operator[](size_t index) const { 141 assert(index < size() && "invalid index into type range"); 142 return *(this->begin() + index); 143 } 144 145 /// Return the size of this range. 146 size_t size() const { return llvm::size(*this); } 147 148 /// Return first type in the range. 149 Type front() { return (*this)[0]; } 150 151 /// Compare this range with another. 152 template <typename OtherT> 153 bool operator==(const OtherT &other) const { 154 return llvm::size(*this) == llvm::size(other) && 155 std::equal(this->begin(), this->end(), other.begin()); 156 } 157 template <typename OtherT> 158 bool operator!=(const OtherT &other) const { 159 return !(*this == other); 160 } 161 }; 162 163 template <typename RangeT> 164 inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) { 165 return lhs.size() == static_cast<size_t>(llvm::size(rhs)) && 166 std::equal(lhs.begin(), lhs.end(), rhs.begin()); 167 } 168 169 //===----------------------------------------------------------------------===// 170 // SubElements 171 //===----------------------------------------------------------------------===// 172 173 /// Enable TypeRange to be introspected for sub-elements. 174 template <> 175 struct AttrTypeSubElementHandler<TypeRange> { 176 static void walk(TypeRange param, AttrTypeImmediateSubElementWalker &walker) { 177 walker.walkRange(param); 178 } 179 static TypeRange replace(TypeRange param, 180 AttrSubElementReplacements &attrRepls, 181 TypeSubElementReplacements &typeRepls) { 182 return typeRepls.take_front(param.size()); 183 } 184 }; 185 186 } // namespace mlir 187 188 namespace llvm { 189 190 // Provide DenseMapInfo for TypeRange. 191 template <> 192 struct DenseMapInfo<mlir::TypeRange> { 193 static mlir::TypeRange getEmptyKey() { 194 return mlir::TypeRange(getEmptyKeyPointer(), 0); 195 } 196 197 static mlir::TypeRange getTombstoneKey() { 198 return mlir::TypeRange(getTombstoneKeyPointer(), 0); 199 } 200 201 static unsigned getHashValue(mlir::TypeRange val) { return hash_value(val); } 202 203 static bool isEqual(mlir::TypeRange lhs, mlir::TypeRange rhs) { 204 if (isEmptyKey(rhs)) 205 return isEmptyKey(lhs); 206 if (isTombstoneKey(rhs)) 207 return isTombstoneKey(lhs); 208 return lhs == rhs; 209 } 210 211 private: 212 static const mlir::Type *getEmptyKeyPointer() { 213 return DenseMapInfo<mlir::Type *>::getEmptyKey(); 214 } 215 216 static const mlir::Type *getTombstoneKeyPointer() { 217 return DenseMapInfo<mlir::Type *>::getTombstoneKey(); 218 } 219 220 static bool isEmptyKey(mlir::TypeRange range) { 221 if (const auto *type = 222 llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase())) 223 return type == getEmptyKeyPointer(); 224 return false; 225 } 226 227 static bool isTombstoneKey(mlir::TypeRange range) { 228 if (const auto *type = 229 llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase())) 230 return type == getTombstoneKeyPointer(); 231 return false; 232 } 233 }; 234 235 } // namespace llvm 236 237 #endif // MLIR_IR_TYPERANGE_H 238