1 //===- BuiltinAttributeInterfaces.h - Builtin Attr Interfaces ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H 10 #define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H 11 12 #include "mlir/IR/AffineMap.h" 13 #include "mlir/IR/Attributes.h" 14 #include "mlir/IR/BuiltinTypeInterfaces.h" 15 #include "mlir/IR/Types.h" 16 #include "llvm/Support/raw_ostream.h" 17 #include <complex> 18 #include <optional> 19 20 namespace mlir { 21 22 //===----------------------------------------------------------------------===// 23 // ElementsAttr 24 //===----------------------------------------------------------------------===// 25 namespace detail { 26 /// This class provides support for indexing into the element range of an 27 /// ElementsAttr. It is used to opaquely wrap either a contiguous range, via 28 /// `ElementsAttrIndexer::contiguous`, or a non-contiguous range, via 29 /// `ElementsAttrIndexer::nonContiguous`, A contiguous range is an array-like 30 /// range, where all of the elements are layed out sequentially in memory. A 31 /// non-contiguous range implies no contiguity, and elements may even be 32 /// materialized when indexing, such as the case for a mapped_range. 33 struct ElementsAttrIndexer { 34 public: ElementsAttrIndexerElementsAttrIndexer35 ElementsAttrIndexer() 36 : ElementsAttrIndexer(/*isContiguous=*/true, /*isSplat=*/true) {} ElementsAttrIndexerElementsAttrIndexer37 ElementsAttrIndexer(ElementsAttrIndexer &&rhs) 38 : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) { 39 if (isContiguous) 40 conState = rhs.conState; 41 else 42 new (&nonConState) NonContiguousState(std::move(rhs.nonConState)); 43 } ElementsAttrIndexerElementsAttrIndexer44 ElementsAttrIndexer(const ElementsAttrIndexer &rhs) 45 : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) { 46 if (isContiguous) 47 conState = rhs.conState; 48 else 49 new (&nonConState) NonContiguousState(rhs.nonConState); 50 } ~ElementsAttrIndexerElementsAttrIndexer51 ~ElementsAttrIndexer() { 52 if (!isContiguous) 53 nonConState.~NonContiguousState(); 54 } 55 56 /// Construct an indexer for a non-contiguous range starting at the given 57 /// iterator. A non-contiguous range implies no contiguity, and elements may 58 /// even be materialized when indexing, such as the case for a mapped_range. 59 template <typename IteratorT> nonContiguousElementsAttrIndexer60 static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator) { 61 ElementsAttrIndexer indexer(/*isContiguous=*/false, isSplat); 62 new (&indexer.nonConState) 63 NonContiguousState(std::forward<IteratorT>(iterator)); 64 return indexer; 65 } 66 67 // Construct an indexer for a contiguous range starting at the given element 68 // pointer. A contiguous range is an array-like range, where all of the 69 // elements are layed out sequentially in memory. 70 template <typename T> contiguousElementsAttrIndexer71 static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr) { 72 ElementsAttrIndexer indexer(/*isContiguous=*/true, isSplat); 73 new (&indexer.conState) ContiguousState(firstEltPtr); 74 return indexer; 75 } 76 77 /// Access the element at the given index. 78 template <typename T> atElementsAttrIndexer79 T at(uint64_t index) const { 80 if (isSplat) 81 index = 0; 82 return isContiguous ? conState.at<T>(index) : nonConState.at<T>(index); 83 } 84 85 private: ElementsAttrIndexerElementsAttrIndexer86 ElementsAttrIndexer(bool isContiguous, bool isSplat) 87 : isContiguous(isContiguous), isSplat(isSplat), conState(nullptr) {} 88 89 /// This class contains all of the state necessary to index a contiguous 90 /// range. 91 class ContiguousState { 92 public: ContiguousStateElementsAttrIndexer93 ContiguousState(const void *firstEltPtr) : firstEltPtr(firstEltPtr) {} 94 95 /// Access the element at the given index. 96 template <typename T> atElementsAttrIndexer97 const T &at(uint64_t index) const { 98 return *(reinterpret_cast<const T *>(firstEltPtr) + index); 99 } 100 101 private: 102 const void *firstEltPtr; 103 }; 104 105 /// This class contains all of the state necessary to index a non-contiguous 106 /// range. 107 class NonContiguousState { 108 private: 109 /// This class is used to represent the abstract base of an opaque iterator. 110 /// This allows for all iterator and element types to be completely 111 /// type-erased. 112 struct OpaqueIteratorBase { 113 virtual ~OpaqueIteratorBase() = default; 114 virtual std::unique_ptr<OpaqueIteratorBase> clone() const = 0; 115 }; 116 /// This class is used to represent the abstract base of an opaque iterator 117 /// that iterates over elements of type `T`. This allows for all iterator 118 /// types to be completely type-erased. 119 template <typename T> 120 struct OpaqueIteratorValueBase : public OpaqueIteratorBase { 121 virtual T at(uint64_t index) = 0; 122 }; 123 /// This class is used to represent an opaque handle to an iterator of type 124 /// `IteratorT` that iterates over elements of type `T`. 125 template <typename IteratorT, typename T> 126 struct OpaqueIterator : public OpaqueIteratorValueBase<T> { 127 template <typename ItTy, typename FuncTy, typename FuncReturnTy> isMappedIteratorTestFnElementsAttrIndexer::OpaqueIterator128 static void isMappedIteratorTestFn( 129 llvm::mapped_iterator<ItTy, FuncTy, FuncReturnTy>) {} 130 template <typename U, typename... Args> 131 using is_mapped_iterator = 132 decltype(isMappedIteratorTestFn(std::declval<U>())); 133 template <typename U> 134 using detect_is_mapped_iterator = 135 llvm::is_detected<is_mapped_iterator, U>; 136 137 /// Access the element within the iterator at the given index. 138 template <typename ItT> 139 static std::enable_if_t<!detect_is_mapped_iterator<ItT>::value, T> atImplElementsAttrIndexer::OpaqueIterator140 atImpl(ItT &&it, uint64_t index) { 141 return *std::next(it, index); 142 } 143 template <typename ItT> 144 static std::enable_if_t<detect_is_mapped_iterator<ItT>::value, T> atImplElementsAttrIndexer::OpaqueIterator145 atImpl(ItT &&it, uint64_t index) { 146 // Special case mapped_iterator to avoid copying the function. 147 return it.getFunction()(*std::next(it.getCurrent(), index)); 148 } 149 150 public: 151 template <typename U> OpaqueIteratorElementsAttrIndexer::OpaqueIterator152 OpaqueIterator(U &&iterator) : iterator(std::forward<U>(iterator)) {} cloneElementsAttrIndexer::OpaqueIterator153 std::unique_ptr<OpaqueIteratorBase> clone() const final { 154 return std::make_unique<OpaqueIterator<IteratorT, T>>(iterator); 155 } 156 157 /// Access the element at the given index. atElementsAttrIndexer::OpaqueIterator158 T at(uint64_t index) final { return atImpl(iterator, index); } 159 160 private: 161 IteratorT iterator; 162 }; 163 164 public: 165 /// Construct the state with the given iterator type. 166 template <typename IteratorT, typename T = typename llvm::remove_cvref_t< 167 decltype(*std::declval<IteratorT>())>> NonContiguousStateElementsAttrIndexer168 NonContiguousState(IteratorT iterator) 169 : iterator(std::make_unique<OpaqueIterator<IteratorT, T>>(iterator)) {} NonContiguousStateElementsAttrIndexer170 NonContiguousState(const NonContiguousState &other) 171 : iterator(other.iterator->clone()) {} 172 NonContiguousState(NonContiguousState &&other) = default; 173 174 /// Access the element at the given index. 175 template <typename T> atElementsAttrIndexer176 T at(uint64_t index) const { 177 auto *valueIt = static_cast<OpaqueIteratorValueBase<T> *>(iterator.get()); 178 return valueIt->at(index); 179 } 180 181 /// The opaque iterator state. 182 std::unique_ptr<OpaqueIteratorBase> iterator; 183 }; 184 185 /// A boolean indicating if this range is contiguous or not. 186 bool isContiguous; 187 /// A boolean indicating if this range is a splat. 188 bool isSplat; 189 /// The underlying range state. 190 union { 191 ContiguousState conState; 192 NonContiguousState nonConState; 193 }; 194 }; 195 196 /// This class implements a generic iterator for ElementsAttr. 197 template <typename T> 198 class ElementsAttrIterator 199 : public llvm::iterator_facade_base<ElementsAttrIterator<T>, 200 std::random_access_iterator_tag, T, 201 std::ptrdiff_t, T, T> { 202 public: ElementsAttrIterator(ElementsAttrIndexer indexer,size_t dataIndex)203 ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex) 204 : indexer(std::move(indexer)), index(dataIndex) {} 205 206 // Boilerplate iterator methods. 207 ptrdiff_t operator-(const ElementsAttrIterator &rhs) const { 208 return index - rhs.index; 209 } 210 bool operator==(const ElementsAttrIterator &rhs) const { 211 return index == rhs.index; 212 } 213 bool operator<(const ElementsAttrIterator &rhs) const { 214 return index < rhs.index; 215 } 216 ElementsAttrIterator &operator+=(ptrdiff_t offset) { 217 index += offset; 218 return *this; 219 } 220 ElementsAttrIterator &operator-=(ptrdiff_t offset) { 221 index -= offset; 222 return *this; 223 } 224 225 /// Return the value at the current iterator position. 226 T operator*() const { return indexer.at<T>(index); } 227 228 private: 229 ElementsAttrIndexer indexer; 230 ptrdiff_t index; 231 }; 232 233 /// This class provides iterator utilities for an ElementsAttr range. 234 template <typename IteratorT> 235 class ElementsAttrRange : public llvm::iterator_range<IteratorT> { 236 public: 237 using reference = typename IteratorT::reference; 238 ElementsAttrRange(ShapedType shapeType,const llvm::iterator_range<IteratorT> & range)239 ElementsAttrRange(ShapedType shapeType, 240 const llvm::iterator_range<IteratorT> &range) 241 : llvm::iterator_range<IteratorT>(range), shapeType(shapeType) {} ElementsAttrRange(ShapedType shapeType,IteratorT beginIt,IteratorT endIt)242 ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt) 243 : ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {} 244 245 /// Return the value at the given index. 246 reference operator[](ArrayRef<uint64_t> index) const; 247 reference operator[](uint64_t index) const { 248 return *std::next(this->begin(), index); 249 } 250 251 /// Return the size of this range. size()252 size_t size() const { return llvm::size(*this); } 253 254 private: 255 /// The shaped type of the parent ElementsAttr. 256 ShapedType shapeType; 257 }; 258 259 } // namespace detail 260 261 //===----------------------------------------------------------------------===// 262 // MemRefLayoutAttrInterface 263 //===----------------------------------------------------------------------===// 264 265 namespace detail { 266 267 // Verify the affine map 'm' can be used as a layout specification 268 // for memref with 'shape'. 269 LogicalResult 270 verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape, 271 function_ref<InFlightDiagnostic()> emitError); 272 273 } // namespace detail 274 275 } // namespace mlir 276 277 //===----------------------------------------------------------------------===// 278 // Tablegen Interface Declarations 279 //===----------------------------------------------------------------------===// 280 281 #include "mlir/IR/BuiltinAttributeInterfaces.h.inc" 282 283 //===----------------------------------------------------------------------===// 284 // ElementsAttr 285 //===----------------------------------------------------------------------===// 286 287 namespace mlir { 288 namespace detail { 289 /// Return the value at the given index. 290 template <typename IteratorT> 291 auto ElementsAttrRange<IteratorT>::operator[](ArrayRef<uint64_t> index) const 292 -> reference { 293 // Skip to the element corresponding to the flattened index. 294 return (*this)[ElementsAttr::getFlattenedIndex(shapeType, index)]; 295 } 296 } // namespace detail 297 298 /// Return the elements of this attribute as a value of type 'T'. 299 template <typename T> 300 auto ElementsAttr::value_begin() const -> DefaultValueCheckT<T, iterator<T>> { 301 if (std::optional<iterator<T>> iterator = try_value_begin<T>()) 302 return std::move(*iterator); 303 llvm::errs() 304 << "ElementsAttr does not provide iteration facilities for type `" 305 << llvm::getTypeName<T>() << "`, see attribute: " << *this << "\n"; 306 llvm_unreachable("invalid `T` for ElementsAttr::getValues"); 307 } 308 template <typename T> 309 auto ElementsAttr::try_value_begin() const 310 -> DefaultValueCheckT<T, std::optional<iterator<T>>> { 311 FailureOr<detail::ElementsAttrIndexer> indexer = 312 getValuesImpl(TypeID::get<T>()); 313 if (failed(indexer)) 314 return std::nullopt; 315 return iterator<T>(std::move(*indexer), 0); 316 } 317 } // namespace mlir. 318 319 #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_H 320