xref: /llvm-project/mlir/include/mlir/IR/TypeRange.h (revision 1297c1125f9c284e0cc0f2bf50d4b7ba519f7309)
1 //===- TypeRange.h ----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the TypeRange and ValueTypeRange classes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_TYPERANGE_H
14 #define MLIR_IR_TYPERANGE_H
15 
16 #include "mlir/IR/Types.h"
17 #include "mlir/IR/Value.h"
18 #include "mlir/IR/ValueRange.h"
19 #include "llvm/ADT/PointerUnion.h"
20 #include "llvm/ADT/Sequence.h"
21 
22 namespace mlir {
23 
24 //===----------------------------------------------------------------------===//
25 // TypeRange
26 
27 /// This class provides an abstraction over the various different ranges of
28 /// value types. In many cases, this prevents the need to explicitly materialize
29 /// a SmallVector/std::vector. This class should be used in places that are not
30 /// suitable for a more derived type (e.g. ArrayRef) or a template range
31 /// parameter.
32 class TypeRange : public llvm::detail::indexed_accessor_range_base<
33                       TypeRange,
34                       llvm::PointerUnion<const Value *, const Type *,
35                                          OpOperand *, detail::OpResultImpl *>,
36                       Type, Type, Type> {
37 public:
38   using RangeBaseT::RangeBaseT;
39   TypeRange(ArrayRef<Type> types = std::nullopt);
40   explicit TypeRange(OperandRange values);
41   explicit TypeRange(ResultRange values);
42   explicit TypeRange(ValueRange values);
43   template <typename ValueRangeT>
44   TypeRange(ValueTypeRange<ValueRangeT> values)
45       : TypeRange(ValueRange(ValueRangeT(values.begin().getCurrent(),
46                                          values.end().getCurrent()))) {}
47   template <typename Arg, typename = std::enable_if_t<std::is_constructible<
48                               ArrayRef<Type>, Arg>::value>>
49   TypeRange(Arg &&arg LLVM_LIFETIME_BOUND)
50       : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
51   TypeRange(std::initializer_list<Type> types LLVM_LIFETIME_BOUND)
52       : TypeRange(ArrayRef<Type>(types)) {}
53 
54 private:
55   /// The owner of the range is either:
56   /// * A pointer to the first element of an array of values.
57   /// * A pointer to the first element of an array of types.
58   /// * A pointer to the first element of an array of operands.
59   /// * A pointer to the first element of an array of results.
60   using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *,
61                                     detail::OpResultImpl *>;
62 
63   /// See `llvm::detail::indexed_accessor_range_base` for details.
64   static OwnerT offset_base(OwnerT object, ptrdiff_t index);
65   /// See `llvm::detail::indexed_accessor_range_base` for details.
66   static Type dereference_iterator(OwnerT object, ptrdiff_t index);
67 
68   /// Allow access to `offset_base` and `dereference_iterator`.
69   friend RangeBaseT;
70 };
71 
72 /// Make TypeRange hashable.
73 inline ::llvm::hash_code hash_value(TypeRange arg) {
74   return ::llvm::hash_combine_range(arg.begin(), arg.end());
75 }
76 
77 /// Emit a type range to the given output stream.
78 inline raw_ostream &operator<<(raw_ostream &os, const TypeRange &types) {
79   llvm::interleaveComma(types, os);
80   return os;
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // TypeRangeRange
85 
86 using TypeRangeRangeIterator =
87     llvm::mapped_iterator<llvm::iota_range<unsigned>::iterator,
88                           std::function<TypeRange(unsigned)>>;
89 
90 /// This class provides an abstraction for a range of TypeRange. This is useful
91 /// when accessing the types of a range of ranges, such as when using
92 /// OperandRangeRange.
93 class TypeRangeRange : public llvm::iterator_range<TypeRangeRangeIterator> {
94 public:
95   template <typename RangeT>
96   TypeRangeRange(const RangeT &range)
97       : TypeRangeRange(llvm::seq<unsigned>(0, range.size()), range) {}
98 
99 private:
100   template <typename RangeT>
101   TypeRangeRange(llvm::iota_range<unsigned> sizeRange, const RangeT &range)
102       : llvm::iterator_range<TypeRangeRangeIterator>(
103             {sizeRange.begin(), getRangeFn(range)},
104             {sizeRange.end(), nullptr}) {}
105 
106   template <typename RangeT>
107   static std::function<TypeRange(unsigned)> getRangeFn(const RangeT &range) {
108     return [=](unsigned index) -> TypeRange { return TypeRange(range[index]); };
109   }
110 };
111 
112 //===----------------------------------------------------------------------===//
113 // ValueTypeRange
114 
115 /// This class implements iteration on the types of a given range of values.
116 template <typename ValueIteratorT>
117 class ValueTypeIterator final
118     : public llvm::mapped_iterator_base<ValueTypeIterator<ValueIteratorT>,
119                                         ValueIteratorT, Type> {
120 public:
121   using llvm::mapped_iterator_base<ValueTypeIterator<ValueIteratorT>,
122                                    ValueIteratorT, Type>::mapped_iterator_base;
123 
124   /// Map the element to the iterator result type.
125   Type mapElement(Value value) const { return value.getType(); }
126 };
127 
128 /// This class implements iteration on the types of a given range of values.
129 template <typename ValueRangeT>
130 class ValueTypeRange final
131     : public llvm::iterator_range<
132           ValueTypeIterator<typename ValueRangeT::iterator>> {
133 public:
134   using llvm::iterator_range<
135       ValueTypeIterator<typename ValueRangeT::iterator>>::iterator_range;
136   template <typename Container>
137   ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {}
138 
139   /// Return the type at the given index.
140   Type operator[](size_t index) const {
141     assert(index < size() && "invalid index into type range");
142     return *(this->begin() + index);
143   }
144 
145   /// Return the size of this range.
146   size_t size() const { return llvm::size(*this); }
147 
148   /// Return first type in the range.
149   Type front() { return (*this)[0]; }
150 
151   /// Compare this range with another.
152   template <typename OtherT>
153   bool operator==(const OtherT &other) const {
154     return llvm::size(*this) == llvm::size(other) &&
155            std::equal(this->begin(), this->end(), other.begin());
156   }
157   template <typename OtherT>
158   bool operator!=(const OtherT &other) const {
159     return !(*this == other);
160   }
161 };
162 
163 template <typename RangeT>
164 inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) {
165   return lhs.size() == static_cast<size_t>(llvm::size(rhs)) &&
166          std::equal(lhs.begin(), lhs.end(), rhs.begin());
167 }
168 
169 //===----------------------------------------------------------------------===//
170 // SubElements
171 //===----------------------------------------------------------------------===//
172 
173 /// Enable TypeRange to be introspected for sub-elements.
174 template <>
175 struct AttrTypeSubElementHandler<TypeRange> {
176   static void walk(TypeRange param, AttrTypeImmediateSubElementWalker &walker) {
177     walker.walkRange(param);
178   }
179   static TypeRange replace(TypeRange param,
180                            AttrSubElementReplacements &attrRepls,
181                            TypeSubElementReplacements &typeRepls) {
182     return typeRepls.take_front(param.size());
183   }
184 };
185 
186 } // namespace mlir
187 
188 namespace llvm {
189 
190 // Provide DenseMapInfo for TypeRange.
191 template <>
192 struct DenseMapInfo<mlir::TypeRange> {
193   static mlir::TypeRange getEmptyKey() {
194     return mlir::TypeRange(getEmptyKeyPointer(), 0);
195   }
196 
197   static mlir::TypeRange getTombstoneKey() {
198     return mlir::TypeRange(getTombstoneKeyPointer(), 0);
199   }
200 
201   static unsigned getHashValue(mlir::TypeRange val) { return hash_value(val); }
202 
203   static bool isEqual(mlir::TypeRange lhs, mlir::TypeRange rhs) {
204     if (isEmptyKey(rhs))
205       return isEmptyKey(lhs);
206     if (isTombstoneKey(rhs))
207       return isTombstoneKey(lhs);
208     return lhs == rhs;
209   }
210 
211 private:
212   static const mlir::Type *getEmptyKeyPointer() {
213     return DenseMapInfo<mlir::Type *>::getEmptyKey();
214   }
215 
216   static const mlir::Type *getTombstoneKeyPointer() {
217     return DenseMapInfo<mlir::Type *>::getTombstoneKey();
218   }
219 
220   static bool isEmptyKey(mlir::TypeRange range) {
221     if (const auto *type =
222             llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase()))
223       return type == getEmptyKeyPointer();
224     return false;
225   }
226 
227   static bool isTombstoneKey(mlir::TypeRange range) {
228     if (const auto *type =
229             llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase()))
230       return type == getTombstoneKeyPointer();
231     return false;
232   }
233 };
234 
235 } // namespace llvm
236 
237 #endif // MLIR_IR_TYPERANGE_H
238