1 // -*- C++ -*- 2 //===----------------------------------------------------------------------===// 3 // 4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 5 // See https://llvm.org/LICENSE.txt for license information. 6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 // 8 // Kokkos v. 4.0 9 // Copyright (2022) National Technology & Engineering 10 // Solutions of Sandia, LLC (NTESS). 11 // 12 // Under the terms of Contract DE-NA0003525 with NTESS, 13 // the U.S. Government retains certain rights in this software. 14 // 15 //===---------------------------------------------------------------------===// 16 17 #ifndef _LIBCPP___MDSPAN_LAYOUT_STRIDE_H 18 #define _LIBCPP___MDSPAN_LAYOUT_STRIDE_H 19 20 #include <__assert> 21 #include <__concepts/same_as.h> 22 #include <__config> 23 #include <__fwd/mdspan.h> 24 #include <__mdspan/extents.h> 25 #include <__type_traits/common_type.h> 26 #include <__type_traits/is_constructible.h> 27 #include <__type_traits/is_convertible.h> 28 #include <__type_traits/is_integral.h> 29 #include <__type_traits/is_nothrow_constructible.h> 30 #include <__type_traits/is_same.h> 31 #include <__utility/as_const.h> 32 #include <__utility/integer_sequence.h> 33 #include <__utility/swap.h> 34 #include <array> 35 #include <limits> 36 #include <span> 37 38 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) 39 # pragma GCC system_header 40 #endif 41 42 _LIBCPP_PUSH_MACROS 43 #include <__undef_macros> 44 45 _LIBCPP_BEGIN_NAMESPACE_STD 46 47 #if _LIBCPP_STD_VER >= 23 48 49 namespace __mdspan_detail { 50 template <class _Layout, class _Mapping> 51 constexpr bool __is_mapping_of = 52 is_same_v<typename _Layout::template mapping<typename _Mapping::extents_type>, _Mapping>; 53 54 template <class _Mapping> 55 concept __layout_mapping_alike = requires { 56 requires __is_mapping_of<typename _Mapping::layout_type, _Mapping>; 57 requires __is_extents_v<typename _Mapping::extents_type>; 58 { _Mapping::is_always_strided() } -> same_as<bool>; 59 { _Mapping::is_always_exhaustive() } -> same_as<bool>; 60 { _Mapping::is_always_unique() } -> same_as<bool>; 61 bool_constant<_Mapping::is_always_strided()>::value; 62 bool_constant<_Mapping::is_always_exhaustive()>::value; 63 bool_constant<_Mapping::is_always_unique()>::value; 64 }; 65 } // namespace __mdspan_detail 66 67 template <class _Extents> 68 class layout_stride::mapping { 69 public: 70 static_assert(__mdspan_detail::__is_extents<_Extents>::value, 71 "layout_stride::mapping template argument must be a specialization of extents."); 72 73 using extents_type = _Extents; 74 using index_type = typename extents_type::index_type; 75 using size_type = typename extents_type::size_type; 76 using rank_type = typename extents_type::rank_type; 77 using layout_type = layout_stride; 78 79 private: 80 static constexpr rank_type __rank_ = extents_type::rank(); 81 82 // Used for default construction check and mandates 83 _LIBCPP_HIDE_FROM_ABI static constexpr bool __required_span_size_is_representable(const extents_type& __ext) { 84 if constexpr (__rank_ == 0) 85 return true; 86 87 index_type __prod = __ext.extent(0); 88 for (rank_type __r = 1; __r < __rank_; __r++) { 89 bool __overflowed = __builtin_mul_overflow(__prod, __ext.extent(__r), &__prod); 90 if (__overflowed) 91 return false; 92 } 93 return true; 94 } 95 96 template <class _OtherIndexType> 97 _LIBCPP_HIDE_FROM_ABI static constexpr bool 98 __required_span_size_is_representable(const extents_type& __ext, span<_OtherIndexType, __rank_> __strides) { 99 if constexpr (__rank_ == 0) 100 return true; 101 102 index_type __size = 1; 103 for (rank_type __r = 0; __r < __rank_; __r++) { 104 // We can only check correct conversion of _OtherIndexType if it is an integral 105 if constexpr (is_integral_v<_OtherIndexType>) { 106 using _CommonType = common_type_t<index_type, _OtherIndexType>; 107 if (static_cast<_CommonType>(__strides[__r]) > static_cast<_CommonType>(numeric_limits<index_type>::max())) 108 return false; 109 } 110 if (__ext.extent(__r) == static_cast<index_type>(0)) 111 return true; 112 index_type __prod = (__ext.extent(__r) - 1); 113 bool __overflowed_mul = __builtin_mul_overflow(__prod, static_cast<index_type>(__strides[__r]), &__prod); 114 if (__overflowed_mul) 115 return false; 116 bool __overflowed_add = __builtin_add_overflow(__size, __prod, &__size); 117 if (__overflowed_add) 118 return false; 119 } 120 return true; 121 } 122 123 // compute offset of a strided layout mapping 124 template <class _StridedMapping> 125 _LIBCPP_HIDE_FROM_ABI static constexpr index_type __offset(const _StridedMapping& __mapping) { 126 if constexpr (_StridedMapping::extents_type::rank() == 0) { 127 return static_cast<index_type>(__mapping()); 128 } else if (__mapping.required_span_size() == static_cast<typename _StridedMapping::index_type>(0)) { 129 return static_cast<index_type>(0); 130 } else { 131 return [&]<size_t... _Pos>(index_sequence<_Pos...>) { 132 return static_cast<index_type>(__mapping((_Pos ? 0 : 0)...)); 133 }(make_index_sequence<__rank_>()); 134 } 135 } 136 137 // compute the permutation for sorting the stride array 138 // we never actually sort the stride array 139 _LIBCPP_HIDE_FROM_ABI constexpr void __bubble_sort_by_strides(array<rank_type, __rank_>& __permute) const { 140 for (rank_type __i = __rank_ - 1; __i > 0; __i--) { 141 for (rank_type __r = 0; __r < __i; __r++) { 142 if (__strides_[__permute[__r]] > __strides_[__permute[__r + 1]]) { 143 swap(__permute[__r], __permute[__r + 1]); 144 } else { 145 // if two strides are the same then one of the associated extents must be 1 or 0 146 // both could be, but you can't have one larger than 1 come first 147 if ((__strides_[__permute[__r]] == __strides_[__permute[__r + 1]]) && 148 (__extents_.extent(__permute[__r]) > static_cast<index_type>(1))) 149 swap(__permute[__r], __permute[__r + 1]); 150 } 151 } 152 } 153 } 154 155 static_assert(extents_type::rank_dynamic() > 0 || __required_span_size_is_representable(extents_type()), 156 "layout_stride::mapping product of static extents must be representable as index_type."); 157 158 public: 159 // [mdspan.layout.stride.cons], constructors 160 _LIBCPP_HIDE_FROM_ABI constexpr mapping() noexcept : __extents_(extents_type()) { 161 // Note the nominal precondition is covered by above static assert since 162 // if rank_dynamic is != 0 required_span_size is zero for default construction 163 if constexpr (__rank_ > 0) { 164 index_type __stride = 1; 165 for (rank_type __r = __rank_ - 1; __r > static_cast<rank_type>(0); __r--) { 166 __strides_[__r] = __stride; 167 __stride *= __extents_.extent(__r); 168 } 169 __strides_[0] = __stride; 170 } 171 } 172 173 _LIBCPP_HIDE_FROM_ABI constexpr mapping(const mapping&) noexcept = default; 174 175 template <class _OtherIndexType> 176 requires(is_convertible_v<const _OtherIndexType&, index_type> && 177 is_nothrow_constructible_v<index_type, const _OtherIndexType&>) 178 _LIBCPP_HIDE_FROM_ABI constexpr mapping(const extents_type& __ext, span<_OtherIndexType, __rank_> __strides) noexcept 179 : __extents_(__ext), __strides_([&]<size_t... _Pos>(index_sequence<_Pos...>) { 180 return __mdspan_detail::__possibly_empty_array<index_type, __rank_>{ 181 static_cast<index_type>(std::as_const(__strides[_Pos]))...}; 182 }(make_index_sequence<__rank_>())) { 183 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( 184 ([&]<size_t... _Pos>(index_sequence<_Pos...>) { 185 // For integrals we can do a pre-conversion check, for other types not 186 if constexpr (is_integral_v<_OtherIndexType>) { 187 return ((__strides[_Pos] > static_cast<_OtherIndexType>(0)) && ... && true); 188 } else { 189 return ((static_cast<index_type>(__strides[_Pos]) > static_cast<index_type>(0)) && ... && true); 190 } 191 }(make_index_sequence<__rank_>())), 192 "layout_stride::mapping ctor: all strides must be greater than 0"); 193 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( 194 __required_span_size_is_representable(__ext, __strides), 195 "layout_stride::mapping ctor: required span size is not representable as index_type."); 196 if constexpr (__rank_ > 1) { 197 _LIBCPP_ASSERT_UNCATEGORIZED( 198 ([&]<size_t... _Pos>(index_sequence<_Pos...>) { 199 // basically sort the dimensions based on strides and extents, sorting is represented in permute array 200 array<rank_type, __rank_> __permute{_Pos...}; 201 __bubble_sort_by_strides(__permute); 202 203 // check that this permutations represents a growing set 204 for (rank_type __i = 1; __i < __rank_; __i++) 205 if (static_cast<index_type>(__strides[__permute[__i]]) < 206 static_cast<index_type>(__strides[__permute[__i - 1]]) * __extents_.extent(__permute[__i - 1])) 207 return false; 208 return true; 209 }(make_index_sequence<__rank_>())), 210 "layout_stride::mapping ctor: the provided extents and strides lead to a non-unique mapping"); 211 } 212 } 213 214 template <class _OtherIndexType> 215 requires(is_convertible_v<const _OtherIndexType&, index_type> && 216 is_nothrow_constructible_v<index_type, const _OtherIndexType&>) 217 _LIBCPP_HIDE_FROM_ABI constexpr mapping(const extents_type& __ext, 218 const array<_OtherIndexType, __rank_>& __strides) noexcept 219 : mapping(__ext, span(__strides)) {} 220 221 template <class _StridedLayoutMapping> 222 requires(__mdspan_detail::__layout_mapping_alike<_StridedLayoutMapping> && 223 is_constructible_v<extents_type, typename _StridedLayoutMapping::extents_type> && 224 _StridedLayoutMapping::is_always_unique() && _StridedLayoutMapping::is_always_strided()) 225 _LIBCPP_HIDE_FROM_ABI constexpr explicit( 226 !(is_convertible_v<typename _StridedLayoutMapping::extents_type, extents_type> && 227 (__mdspan_detail::__is_mapping_of<layout_left, _StridedLayoutMapping> || 228 __mdspan_detail::__is_mapping_of<layout_right, _StridedLayoutMapping> || 229 __mdspan_detail::__is_mapping_of<layout_stride, _StridedLayoutMapping>))) 230 mapping(const _StridedLayoutMapping& __other) noexcept 231 : __extents_(__other.extents()), __strides_([&]<size_t... _Pos>(index_sequence<_Pos...>) { 232 // stride() only compiles for rank > 0 233 if constexpr (__rank_ > 0) { 234 return __mdspan_detail::__possibly_empty_array<index_type, __rank_>{ 235 static_cast<index_type>(__other.stride(_Pos))...}; 236 } else { 237 return __mdspan_detail::__possibly_empty_array<index_type, 0>{}; 238 } 239 }(make_index_sequence<__rank_>())) { 240 // stride() only compiles for rank > 0 241 if constexpr (__rank_ > 0) { 242 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( 243 ([&]<size_t... _Pos>(index_sequence<_Pos...>) { 244 return ((static_cast<index_type>(__other.stride(_Pos)) > static_cast<index_type>(0)) && ... && true); 245 }(make_index_sequence<__rank_>())), 246 "layout_stride::mapping converting ctor: all strides must be greater than 0"); 247 } 248 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( 249 __mdspan_detail::__is_representable_as<index_type>(__other.required_span_size()), 250 "layout_stride::mapping converting ctor: other.required_span_size() must be representable as index_type."); 251 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(static_cast<index_type>(0) == __offset(__other), 252 "layout_stride::mapping converting ctor: base offset of mapping must be zero."); 253 } 254 255 _LIBCPP_HIDE_FROM_ABI constexpr mapping& operator=(const mapping&) noexcept = default; 256 257 // [mdspan.layout.stride.obs], observers 258 _LIBCPP_HIDE_FROM_ABI constexpr const extents_type& extents() const noexcept { return __extents_; } 259 260 _LIBCPP_HIDE_FROM_ABI constexpr array<index_type, __rank_> strides() const noexcept { 261 return [&]<size_t... _Pos>(index_sequence<_Pos...>) { 262 return array<index_type, __rank_>{__strides_[_Pos]...}; 263 }(make_index_sequence<__rank_>()); 264 } 265 266 _LIBCPP_HIDE_FROM_ABI constexpr index_type required_span_size() const noexcept { 267 if constexpr (__rank_ == 0) { 268 return static_cast<index_type>(1); 269 } else { 270 return [&]<size_t... _Pos>(index_sequence<_Pos...>) { 271 if ((__extents_.extent(_Pos) * ... * 1) == 0) 272 return static_cast<index_type>(0); 273 else 274 return static_cast<index_type>( 275 static_cast<index_type>(1) + 276 (((__extents_.extent(_Pos) - static_cast<index_type>(1)) * __strides_[_Pos]) + ... + 277 static_cast<index_type>(0))); 278 }(make_index_sequence<__rank_>()); 279 } 280 } 281 282 template <class... _Indices> 283 requires((sizeof...(_Indices) == __rank_) && (is_convertible_v<_Indices, index_type> && ...) && 284 (is_nothrow_constructible_v<index_type, _Indices> && ...)) 285 _LIBCPP_HIDE_FROM_ABI constexpr index_type operator()(_Indices... __idx) const noexcept { 286 // Mappings are generally meant to be used for accessing allocations and are meant to guarantee to never 287 // return a value exceeding required_span_size(), which is used to know how large an allocation one needs 288 // Thus, this is a canonical point in multi-dimensional data structures to make invalid element access checks 289 // However, mdspan does check this on its own, so for now we avoid double checking in hardened mode 290 _LIBCPP_ASSERT_UNCATEGORIZED(__mdspan_detail::__is_multidimensional_index_in(__extents_, __idx...), 291 "layout_stride::mapping: out of bounds indexing"); 292 return [&]<size_t... _Pos>(index_sequence<_Pos...>) { 293 return ((static_cast<index_type>(__idx) * __strides_[_Pos]) + ... + index_type(0)); 294 }(make_index_sequence<sizeof...(_Indices)>()); 295 } 296 297 _LIBCPP_HIDE_FROM_ABI static constexpr bool is_always_unique() noexcept { return true; } 298 _LIBCPP_HIDE_FROM_ABI static constexpr bool is_always_exhaustive() noexcept { return false; } 299 _LIBCPP_HIDE_FROM_ABI static constexpr bool is_always_strided() noexcept { return true; } 300 301 _LIBCPP_HIDE_FROM_ABI static constexpr bool is_unique() noexcept { return true; } 302 // The answer of this function is fairly complex in the case where one or more 303 // extents are zero. 304 // Technically it is meaningless to query is_exhaustive() in that case, but unfortunately 305 // the way the standard defines this function, we can't give a simple true or false then. 306 _LIBCPP_HIDE_FROM_ABI constexpr bool is_exhaustive() const noexcept { 307 if constexpr (__rank_ == 0) 308 return true; 309 else { 310 index_type __span_size = required_span_size(); 311 if (__span_size == static_cast<index_type>(0)) { 312 if constexpr (__rank_ == 1) 313 return __strides_[0] == 1; 314 else { 315 rank_type __r_largest = 0; 316 for (rank_type __r = 1; __r < __rank_; __r++) 317 if (__strides_[__r] > __strides_[__r_largest]) 318 __r_largest = __r; 319 for (rank_type __r = 0; __r < __rank_; __r++) 320 if (__extents_.extent(__r) == 0 && __r != __r_largest) 321 return false; 322 return true; 323 } 324 } else { 325 return required_span_size() == [&]<size_t... _Pos>(index_sequence<_Pos...>) { 326 return (__extents_.extent(_Pos) * ... * static_cast<index_type>(1)); 327 }(make_index_sequence<__rank_>()); 328 } 329 } 330 } 331 _LIBCPP_HIDE_FROM_ABI static constexpr bool is_strided() noexcept { return true; } 332 333 // according to the standard layout_stride does not have a constraint on stride(r) for rank>0 334 // it still has the precondition though 335 _LIBCPP_HIDE_FROM_ABI constexpr index_type stride(rank_type __r) const noexcept { 336 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(__r < __rank_, "layout_stride::mapping::stride(): invalid rank index"); 337 return __strides_[__r]; 338 } 339 340 template <class _OtherMapping> 341 requires(__mdspan_detail::__layout_mapping_alike<_OtherMapping> && 342 (_OtherMapping::extents_type::rank() == __rank_) && _OtherMapping::is_always_strided()) 343 _LIBCPP_HIDE_FROM_ABI friend constexpr bool operator==(const mapping& __lhs, const _OtherMapping& __rhs) noexcept { 344 if (__offset(__rhs)) 345 return false; 346 if constexpr (__rank_ == 0) 347 return true; 348 else { 349 return __lhs.extents() == __rhs.extents() && [&]<size_t... _Pos>(index_sequence<_Pos...>) { 350 // avoid warning when comparing signed and unsigner integers and pick the wider of two types 351 using _CommonType = common_type_t<index_type, typename _OtherMapping::index_type>; 352 return ((static_cast<_CommonType>(__lhs.stride(_Pos)) == static_cast<_CommonType>(__rhs.stride(_Pos))) && ... && 353 true); 354 }(make_index_sequence<__rank_>()); 355 } 356 } 357 358 private: 359 _LIBCPP_NO_UNIQUE_ADDRESS extents_type __extents_{}; 360 _LIBCPP_NO_UNIQUE_ADDRESS __mdspan_detail::__possibly_empty_array<index_type, __rank_> __strides_{}; 361 }; 362 363 #endif // _LIBCPP_STD_VER >= 23 364 365 _LIBCPP_END_NAMESPACE_STD 366 367 _LIBCPP_POP_MACROS 368 369 #endif // _LIBCPP___MDSPAN_LAYOUT_STRIDE_H 370