xref: /llvm-project/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- BuiltinAttributeInterfaces.h - Builtin Attr Interfaces ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
10 #define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
11 
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/BuiltinTypeInterfaces.h"
15 #include "mlir/IR/Types.h"
16 #include "llvm/Support/raw_ostream.h"
17 #include <complex>
18 #include <optional>
19 
20 namespace mlir {
21 
22 //===----------------------------------------------------------------------===//
23 // ElementsAttr
24 //===----------------------------------------------------------------------===//
25 namespace detail {
26 /// This class provides support for indexing into the element range of an
27 /// ElementsAttr. It is used to opaquely wrap either a contiguous range, via
28 /// `ElementsAttrIndexer::contiguous`, or a non-contiguous range, via
29 /// `ElementsAttrIndexer::nonContiguous`, A contiguous range is an array-like
30 /// range, where all of the elements are layed out sequentially in memory. A
31 /// non-contiguous range implies no contiguity, and elements may even be
32 /// materialized when indexing, such as the case for a mapped_range.
33 struct ElementsAttrIndexer {
34 public:
ElementsAttrIndexerElementsAttrIndexer35   ElementsAttrIndexer()
36       : ElementsAttrIndexer(/*isContiguous=*/true, /*isSplat=*/true) {}
ElementsAttrIndexerElementsAttrIndexer37   ElementsAttrIndexer(ElementsAttrIndexer &&rhs)
38       : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
39     if (isContiguous)
40       conState = rhs.conState;
41     else
42       new (&nonConState) NonContiguousState(std::move(rhs.nonConState));
43   }
ElementsAttrIndexerElementsAttrIndexer44   ElementsAttrIndexer(const ElementsAttrIndexer &rhs)
45       : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
46     if (isContiguous)
47       conState = rhs.conState;
48     else
49       new (&nonConState) NonContiguousState(rhs.nonConState);
50   }
~ElementsAttrIndexerElementsAttrIndexer51   ~ElementsAttrIndexer() {
52     if (!isContiguous)
53       nonConState.~NonContiguousState();
54   }
55 
56   /// Construct an indexer for a non-contiguous range starting at the given
57   /// iterator. A non-contiguous range implies no contiguity, and elements may
58   /// even be materialized when indexing, such as the case for a mapped_range.
59   template <typename IteratorT>
nonContiguousElementsAttrIndexer60   static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator) {
61     ElementsAttrIndexer indexer(/*isContiguous=*/false, isSplat);
62     new (&indexer.nonConState)
63         NonContiguousState(std::forward<IteratorT>(iterator));
64     return indexer;
65   }
66 
67   // Construct an indexer for a contiguous range starting at the given element
68   // pointer. A contiguous range is an array-like range, where all of the
69   // elements are layed out sequentially in memory.
70   template <typename T>
contiguousElementsAttrIndexer71   static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr) {
72     ElementsAttrIndexer indexer(/*isContiguous=*/true, isSplat);
73     new (&indexer.conState) ContiguousState(firstEltPtr);
74     return indexer;
75   }
76 
77   /// Access the element at the given index.
78   template <typename T>
atElementsAttrIndexer79   T at(uint64_t index) const {
80     if (isSplat)
81       index = 0;
82     return isContiguous ? conState.at<T>(index) : nonConState.at<T>(index);
83   }
84 
85 private:
ElementsAttrIndexerElementsAttrIndexer86   ElementsAttrIndexer(bool isContiguous, bool isSplat)
87       : isContiguous(isContiguous), isSplat(isSplat), conState(nullptr) {}
88 
89   /// This class contains all of the state necessary to index a contiguous
90   /// range.
91   class ContiguousState {
92   public:
ContiguousStateElementsAttrIndexer93     ContiguousState(const void *firstEltPtr) : firstEltPtr(firstEltPtr) {}
94 
95     /// Access the element at the given index.
96     template <typename T>
atElementsAttrIndexer97     const T &at(uint64_t index) const {
98       return *(reinterpret_cast<const T *>(firstEltPtr) + index);
99     }
100 
101   private:
102     const void *firstEltPtr;
103   };
104 
105   /// This class contains all of the state necessary to index a non-contiguous
106   /// range.
107   class NonContiguousState {
108   private:
109     /// This class is used to represent the abstract base of an opaque iterator.
110     /// This allows for all iterator and element types to be completely
111     /// type-erased.
112     struct OpaqueIteratorBase {
113       virtual ~OpaqueIteratorBase() = default;
114       virtual std::unique_ptr<OpaqueIteratorBase> clone() const = 0;
115     };
116     /// This class is used to represent the abstract base of an opaque iterator
117     /// that iterates over elements of type `T`. This allows for all iterator
118     /// types to be completely type-erased.
119     template <typename T>
120     struct OpaqueIteratorValueBase : public OpaqueIteratorBase {
121       virtual T at(uint64_t index) = 0;
122     };
123     /// This class is used to represent an opaque handle to an iterator of type
124     /// `IteratorT` that iterates over elements of type `T`.
125     template <typename IteratorT, typename T>
126     struct OpaqueIterator : public OpaqueIteratorValueBase<T> {
127       template <typename ItTy, typename FuncTy, typename FuncReturnTy>
isMappedIteratorTestFnElementsAttrIndexer::OpaqueIterator128       static void isMappedIteratorTestFn(
129           llvm::mapped_iterator<ItTy, FuncTy, FuncReturnTy>) {}
130       template <typename U, typename... Args>
131       using is_mapped_iterator =
132           decltype(isMappedIteratorTestFn(std::declval<U>()));
133       template <typename U>
134       using detect_is_mapped_iterator =
135           llvm::is_detected<is_mapped_iterator, U>;
136 
137       /// Access the element within the iterator at the given index.
138       template <typename ItT>
139       static std::enable_if_t<!detect_is_mapped_iterator<ItT>::value, T>
atImplElementsAttrIndexer::OpaqueIterator140       atImpl(ItT &&it, uint64_t index) {
141         return *std::next(it, index);
142       }
143       template <typename ItT>
144       static std::enable_if_t<detect_is_mapped_iterator<ItT>::value, T>
atImplElementsAttrIndexer::OpaqueIterator145       atImpl(ItT &&it, uint64_t index) {
146         // Special case mapped_iterator to avoid copying the function.
147         return it.getFunction()(*std::next(it.getCurrent(), index));
148       }
149 
150     public:
151       template <typename U>
OpaqueIteratorElementsAttrIndexer::OpaqueIterator152       OpaqueIterator(U &&iterator) : iterator(std::forward<U>(iterator)) {}
cloneElementsAttrIndexer::OpaqueIterator153       std::unique_ptr<OpaqueIteratorBase> clone() const final {
154         return std::make_unique<OpaqueIterator<IteratorT, T>>(iterator);
155       }
156 
157       /// Access the element at the given index.
atElementsAttrIndexer::OpaqueIterator158       T at(uint64_t index) final { return atImpl(iterator, index); }
159 
160     private:
161       IteratorT iterator;
162     };
163 
164   public:
165     /// Construct the state with the given iterator type.
166     template <typename IteratorT, typename T = typename llvm::remove_cvref_t<
167                                       decltype(*std::declval<IteratorT>())>>
NonContiguousStateElementsAttrIndexer168     NonContiguousState(IteratorT iterator)
169         : iterator(std::make_unique<OpaqueIterator<IteratorT, T>>(iterator)) {}
NonContiguousStateElementsAttrIndexer170     NonContiguousState(const NonContiguousState &other)
171         : iterator(other.iterator->clone()) {}
172     NonContiguousState(NonContiguousState &&other) = default;
173 
174     /// Access the element at the given index.
175     template <typename T>
atElementsAttrIndexer176     T at(uint64_t index) const {
177       auto *valueIt = static_cast<OpaqueIteratorValueBase<T> *>(iterator.get());
178       return valueIt->at(index);
179     }
180 
181     /// The opaque iterator state.
182     std::unique_ptr<OpaqueIteratorBase> iterator;
183   };
184 
185   /// A boolean indicating if this range is contiguous or not.
186   bool isContiguous;
187   /// A boolean indicating if this range is a splat.
188   bool isSplat;
189   /// The underlying range state.
190   union {
191     ContiguousState conState;
192     NonContiguousState nonConState;
193   };
194 };
195 
196 /// This class implements a generic iterator for ElementsAttr.
197 template <typename T>
198 class ElementsAttrIterator
199     : public llvm::iterator_facade_base<ElementsAttrIterator<T>,
200                                         std::random_access_iterator_tag, T,
201                                         std::ptrdiff_t, T, T> {
202 public:
ElementsAttrIterator(ElementsAttrIndexer indexer,size_t dataIndex)203   ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex)
204       : indexer(std::move(indexer)), index(dataIndex) {}
205 
206   // Boilerplate iterator methods.
207   ptrdiff_t operator-(const ElementsAttrIterator &rhs) const {
208     return index - rhs.index;
209   }
210   bool operator==(const ElementsAttrIterator &rhs) const {
211     return index == rhs.index;
212   }
213   bool operator<(const ElementsAttrIterator &rhs) const {
214     return index < rhs.index;
215   }
216   ElementsAttrIterator &operator+=(ptrdiff_t offset) {
217     index += offset;
218     return *this;
219   }
220   ElementsAttrIterator &operator-=(ptrdiff_t offset) {
221     index -= offset;
222     return *this;
223   }
224 
225   /// Return the value at the current iterator position.
226   T operator*() const { return indexer.at<T>(index); }
227 
228 private:
229   ElementsAttrIndexer indexer;
230   ptrdiff_t index;
231 };
232 
233 /// This class provides iterator utilities for an ElementsAttr range.
234 template <typename IteratorT>
235 class ElementsAttrRange : public llvm::iterator_range<IteratorT> {
236 public:
237   using reference = typename IteratorT::reference;
238 
ElementsAttrRange(ShapedType shapeType,const llvm::iterator_range<IteratorT> & range)239   ElementsAttrRange(ShapedType shapeType,
240                     const llvm::iterator_range<IteratorT> &range)
241       : llvm::iterator_range<IteratorT>(range), shapeType(shapeType) {}
ElementsAttrRange(ShapedType shapeType,IteratorT beginIt,IteratorT endIt)242   ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt)
243       : ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {}
244 
245   /// Return the value at the given index.
246   reference operator[](ArrayRef<uint64_t> index) const;
247   reference operator[](uint64_t index) const {
248     return *std::next(this->begin(), index);
249   }
250 
251   /// Return the size of this range.
size()252   size_t size() const { return llvm::size(*this); }
253 
254 private:
255   /// The shaped type of the parent ElementsAttr.
256   ShapedType shapeType;
257 };
258 
259 } // namespace detail
260 
261 //===----------------------------------------------------------------------===//
262 // MemRefLayoutAttrInterface
263 //===----------------------------------------------------------------------===//
264 
265 namespace detail {
266 
267 // Verify the affine map 'm' can be used as a layout specification
268 // for memref with 'shape'.
269 LogicalResult
270 verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
271                         function_ref<InFlightDiagnostic()> emitError);
272 
273 } // namespace detail
274 
275 } // namespace mlir
276 
277 //===----------------------------------------------------------------------===//
278 // Tablegen Interface Declarations
279 //===----------------------------------------------------------------------===//
280 
281 #include "mlir/IR/BuiltinAttributeInterfaces.h.inc"
282 
283 //===----------------------------------------------------------------------===//
284 // ElementsAttr
285 //===----------------------------------------------------------------------===//
286 
287 namespace mlir {
288 namespace detail {
289 /// Return the value at the given index.
290 template <typename IteratorT>
291 auto ElementsAttrRange<IteratorT>::operator[](ArrayRef<uint64_t> index) const
292     -> reference {
293   // Skip to the element corresponding to the flattened index.
294   return (*this)[ElementsAttr::getFlattenedIndex(shapeType, index)];
295 }
296 } // namespace detail
297 
298 /// Return the elements of this attribute as a value of type 'T'.
299 template <typename T>
300 auto ElementsAttr::value_begin() const -> DefaultValueCheckT<T, iterator<T>> {
301   if (std::optional<iterator<T>> iterator = try_value_begin<T>())
302     return std::move(*iterator);
303   llvm::errs()
304       << "ElementsAttr does not provide iteration facilities for type `"
305       << llvm::getTypeName<T>() << "`, see attribute: " << *this << "\n";
306   llvm_unreachable("invalid `T` for ElementsAttr::getValues");
307 }
308 template <typename T>
309 auto ElementsAttr::try_value_begin() const
310     -> DefaultValueCheckT<T, std::optional<iterator<T>>> {
311   FailureOr<detail::ElementsAttrIndexer> indexer =
312       getValuesImpl(TypeID::get<T>());
313   if (failed(indexer))
314     return std::nullopt;
315   return iterator<T>(std::move(*indexer), 0);
316 }
317 } // namespace mlir.
318 
319 #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
320