// -*- C++ -*-
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
//===---------------------------------------------------------------------===//

#ifndef TEST_STD_CONTAINERS_VIEWS_MDSPAN_CUSTOM_TEST_LAYOUTS_H
#define TEST_STD_CONTAINERS_VIEWS_MDSPAN_CUSTOM_TEST_LAYOUTS_H

#include <algorithm>
#include <array>
#include <cassert>
#include <cinttypes>
#include <concepts>
#include <cstddef>
#include <limits>
#include <mdspan>
#include <span> // dynamic_extent
#include <type_traits>
#include <utility>

// Layout that wraps indices to test some idiosyncratic behavior
// - basically it is a layout_left where indices are first wrapped, i.e. i%Wrap
// - only accepts integers as indices
// - is_always_strided and is_always_unique are false
// - is_strided and is_unique are true if no extent is larger than Wrap
//   (then no valid index is ever wrapped)
// - not default constructible
// - not extents constructible
// - not trivially copyable
// - does not check dynamic to static extent conversion in converting ctor
// - checks via side-effects that mdspan::swap calls the mapping's swap via ADL

struct not_extents_constructible_tag {};

template <size_t Wrap>
class layout_wrapping_integral {
public:
  template <class Extents>
  class mapping;
};

template <size_t WrapArg>
template <class Extents>
class layout_wrapping_integral<WrapArg>::mapping {
  static constexpr typename Extents::index_type Wrap = static_cast<typename Extents::index_type>(WrapArg);

public:
  using extents_type = Extents;
  using index_type   = typename extents_type::index_type;
  using size_type    = typename extents_type::size_type;
  using rank_type    = typename extents_type::rank_type;
  // Use the original size_t argument so that layout_type::mapping<Extents> is
  // exactly this class even if WrapArg does not round-trip through index_type.
  using layout_type = layout_wrapping_integral<WrapArg>;

private:
  // Number of distinct positions dimension r contributes once indices are
  // wrapped: min(ext.extent(r), Wrap). Shared by every size/offset computation.
  static constexpr index_type wrapped_extent(const extents_type& ext, rank_type r) noexcept {
    return ext.extent(r) < Wrap ? ext.extent(r) : Wrap;
  }

  // True if required_span_size() for ext does not overflow index_type.
  // Uses the same wrapped extents as required_span_size() for every dimension;
  // for rank 0 the empty product is 1, which is always representable.
  static constexpr bool required_span_size_is_representable(const extents_type& ext) {
    index_type prod = 1;
    for (rank_type r = 0; r < extents_type::rank(); r++) {
      if (__builtin_mul_overflow(prod, wrapped_extent(ext, r), &prod))
        return false;
    }
    return true;
  }

  // True if no extent exceeds Wrap, i.e. no valid index gets wrapped.
  constexpr bool no_index_wraps() const noexcept {
    for (rank_type r = 0; r < extents_type::rank(); r++) {
      if (extents_.extent(r) > Wrap)
        return false;
    }
    return true;
  }

  // Builds extents_type from the dynamic extents of other_ext only. Static
  // extents come from extents_type and are deliberately NOT checked against
  // other_ext (see "does not check dynamic to static extent conversion").
  template <class OtherExtents>
  static constexpr extents_type unchecked_convert(const OtherExtents& other_ext) noexcept {
    std::array<index_type, extents_type::rank_dynamic()> dyn_extents{};
    rank_type count = 0;
    for (rank_type r = 0; r < extents_type::rank(); r++) {
      if (extents_type::static_extent(r) == std::dynamic_extent)
        dyn_extents[count++] = other_ext.extent(r);
    }
    return extents_type(dyn_extents);
  }

public:
  constexpr mapping() noexcept = delete;
  // User-provided on purpose: keeps the mapping from being trivially copyable.
  constexpr mapping(const mapping& other) noexcept : extents_(other.extents()) {}
  // Only Wrap == 8 mappings accept an extents rvalue.
  constexpr mapping(extents_type&& ext) noexcept
    requires(Wrap == 8)
      : extents_(ext) {}
  constexpr mapping(const extents_type& ext, not_extents_constructible_tag) noexcept : extents_(ext) {}

  // Converting ctor from an lvalue (Wrap != 8).
  template <class OtherExtents>
    requires(std::is_constructible_v<extents_type, OtherExtents> && (Wrap != 8))
  constexpr explicit(!std::is_convertible_v<OtherExtents, extents_type>)
      mapping(const mapping<OtherExtents>& other) noexcept
      : extents_(unchecked_convert(other.extents())) {}

  // Converting ctor from an rvalue (Wrap == 8).
  template <class OtherExtents>
    requires(std::is_constructible_v<extents_type, OtherExtents> && (Wrap == 8))
  constexpr explicit(!std::is_convertible_v<OtherExtents, extents_type>)
      mapping(mapping<OtherExtents>&& other) noexcept
      : extents_(unchecked_convert(other.extents())) {}

  constexpr mapping& operator=(const mapping& other) noexcept {
    extents_ = other.extents_;
    return *this;
  }

  constexpr const extents_type& extents() const noexcept { return extents_; }

  // Product of the wrapped extents (1 for rank 0, 0 if any extent is 0).
  constexpr index_type required_span_size() const noexcept {
    index_type size = 1;
    for (rank_type r = 0; r < extents_type::rank(); r++)
      size *= wrapped_extent(extents_, r);
    return size;
  }

  // layout_left linearization of the wrapped indices:
  // (i0%Wrap) + w0*((i1%Wrap) + w1*(...)), with w_r = wrapped_extent(r).
  template <std::integral... Indices>
    requires((sizeof...(Indices) == extents_type::rank()) && (std::is_convertible_v<Indices, index_type> && ...) &&
             (std::is_nothrow_constructible_v<index_type, Indices> && ...))
  constexpr index_type operator()(Indices... idx) const noexcept {
    std::array<index_type, extents_type::rank()> idx_a{static_cast<index_type>(static_cast<index_type>(idx) % Wrap)...};
    return [&]<size_t... Pos>(std::index_sequence<Pos...>) {
      index_type res = 0;
      // Comma fold evaluates left to right: innermost (last) dimension first.
      ((res = idx_a[extents_type::rank() - 1 - Pos] + wrapped_extent(extents_, extents_type::rank() - 1 - Pos) * res),
       ...);
      return res;
    }(std::make_index_sequence<sizeof...(Indices)>());
  }

  static constexpr bool is_always_unique() noexcept { return false; }
  static constexpr bool is_always_exhaustive() noexcept { return true; }
  static constexpr bool is_always_strided() noexcept { return false; }

  constexpr bool is_unique() const noexcept { return no_index_wraps(); }
  static constexpr bool is_exhaustive() noexcept { return true; }
  constexpr bool is_strided() const noexcept { return no_index_wraps(); }

  // layout_left stride, consistent with operator(): product of the (wrapped)
  // extents of all dimensions before r. Meaningful when is_strided() is true.
  constexpr index_type stride(rank_type r) const noexcept
    requires(extents_type::rank() > 0)
  {
    index_type s = 1;
    for (rank_type i = 0; i < r; i++)
      s *= wrapped_extent(extents_, i);
    return s;
  }

  template <class OtherExtents>
    requires(OtherExtents::rank() == extents_type::rank())
  friend constexpr bool operator==(const mapping& lhs, const mapping<OtherExtents>& rhs) noexcept {
    return lhs.extents() == rhs.extents();
  }

  // Hidden friend found by ADL; the counter lets tests observe that it was called.
  friend constexpr void swap(mapping& x, mapping& y) noexcept {
    using std::swap;
    swap(x.extents_, y.extents_);
    if (!std::is_constant_evaluated()) {
      swap_counter()++;
    }
  }

  static int& swap_counter() {
    static int value = 0;
    return value;
  }

private:
  extents_type extents_{};
};

// Builds a layout_left mapping for exts.
template <class Extents>
constexpr auto construct_mapping(std::layout_left, Extents exts) {
  return std::layout_left::mapping<Extents>(exts);
}

// Builds a layout_right mapping for exts.
template <class Extents>
constexpr auto construct_mapping(std::layout_right, Extents exts) {
  return std::layout_right::mapping<Extents>(exts);
}

// Builds a layout_wrapping_integral mapping for exts via the tagged ctor, since
// that mapping is not (generally) constructible from extents alone.
template <size_t Wraps, class Extents>
constexpr auto construct_mapping(layout_wrapping_integral<Wraps>, Extents exts) {
  return typename layout_wrapping_integral<Wraps>::template mapping<Extents>(exts, not_extents_constructible_tag{});
}

// This layout does not check convertibility of extents for its conversion ctor
// Allows triggering mdspan's ctor static assertion on convertibility of extents
// It also allows for negative strides and offsets via runtime arguments
class always_convertible_layout {
public:
  template <class Extents>
  class mapping;
};

template <class Extents>
class always_convertible_layout::mapping {
public:
  using extents_type = Extents;
  using index_type   = typename extents_type::index_type;
  using size_type    = typename extents_type::size_type;
  using rank_type    = typename extents_type::rank_type;
  using layout_type  = always_convertible_layout;

private:
  // True if the product of all extents does not overflow index_type
  // (rank 0: empty product 1, always representable).
  static constexpr bool required_span_size_is_representable(const extents_type& ext) {
    index_type prod = 1;
    for (rank_type r = 0; r < extents_type::rank(); r++) {
      if (__builtin_mul_overflow(prod, ext.extent(r), &prod))
        return false;
    }
    return true;
  }

  // Member function (not the friend operator==) so that reading other's private
  // state relies only on the befriending of all mapping specializations below.
  // cmp_equal keeps the comparison correct for mixed signed/unsigned index types.
  template <class OtherExtents>
  constexpr bool equals(const mapping<OtherExtents>& other) const noexcept {
    return extents_ == other.extents() && std::cmp_equal(offset_, other.offset_) &&
           std::cmp_equal(scaling_, other.scaling_);
  }

public:
  constexpr mapping() noexcept = delete;
  constexpr mapping(const mapping& other) noexcept
      : extents_(other.extents_), offset_(other.offset_), scaling_(other.scaling_) {}
  constexpr mapping(const extents_type& ext, index_type offset = 0, index_type scaling = 1) noexcept
      : extents_(ext), offset_(offset), scaling_(scaling) {}

  // Accepts any other mapping of this layout: dynamic extents are copied
  // unchecked when ranks match, otherwise extents are default constructed.
  template <class OtherExtents>
  constexpr mapping(const mapping<OtherExtents>& other) noexcept {
    if constexpr (extents_type::rank() == OtherExtents::rank()) {
      std::array<index_type, extents_type::rank_dynamic()> dyn_extents{};
      rank_type count = 0;
      for (rank_type r = 0; r < extents_type::rank(); r++) {
        if (extents_type::static_extent(r) == std::dynamic_extent) {
          dyn_extents[count++] = other.extents().extent(r);
        }
      }
      extents_ = extents_type(dyn_extents);
    } else {
      extents_ = extents_type();
    }
    offset_  = other.offset_;
    scaling_ = other.scaling_;
  }

  constexpr mapping& operator=(const mapping& other) noexcept {
    extents_ = other.extents_;
    offset_  = other.offset_;
    scaling_ = other.scaling_;
    return *this;
  }

  constexpr const extents_type& extents() const noexcept { return extents_; }

  // NOTE(review): for scaling_ > 1 this exceeds the minimal span size, and for
  // negative scaling_ it can be one short of (largest offset + 1); callers
  // presumably do not need an exact value — confirm before relying on it.
  constexpr index_type required_span_size() const noexcept {
    index_type size = 1;
    for (rank_type r = 0; r < extents_type::rank(); r++)
      size *= extents_.extent(r);
    // Narrow first so std::max deduces a single type even for index types
    // smaller than int (arithmetic on them promotes to int).
    const index_type scaled_end = static_cast<index_type>(size * scaling_ + offset_);
    return std::max(scaled_end, offset_);
  }

  // offset_ + scaling_ * (layout_left linearization of idx).
  template <std::integral... Indices>
    requires((sizeof...(Indices) == extents_type::rank()) && (std::is_convertible_v<Indices, index_type> && ...) &&
             (std::is_nothrow_constructible_v<index_type, Indices> && ...))
  constexpr index_type operator()(Indices... idx) const noexcept {
    std::array<index_type, extents_type::rank()> idx_a{static_cast<index_type>(idx)...};
    return offset_ +
           scaling_ * ([&]<size_t... Pos>(std::index_sequence<Pos...>) {
             index_type res = 0;
             ((res = idx_a[extents_type::rank() - 1 - Pos] + extents_.extent(extents_type::rank() - 1 - Pos) * res),
              ...);
             return res;
           }(std::make_index_sequence<sizeof...(Indices)>()));
  }

  static constexpr bool is_always_unique() noexcept { return true; }
  static constexpr bool is_always_exhaustive() noexcept { return true; }
  static constexpr bool is_always_strided() noexcept { return true; }

  static constexpr bool is_unique() noexcept { return true; }
  static constexpr bool is_exhaustive() noexcept { return true; }
  static constexpr bool is_strided() noexcept { return true; }

  // layout_left stride scaled by scaling_ (may be negative).
  constexpr index_type stride(rank_type r) const noexcept
    requires(extents_type::rank() > 0)
  {
    index_type s = 1;
    for (rank_type i = 0; i < r; i++)
      s *= extents_.extent(i);
    return s * scaling_;
  }

  template <class OtherExtents>
    requires(OtherExtents::rank() == extents_type::rank())
  friend constexpr bool operator==(const mapping& lhs, const mapping<OtherExtents>& rhs) noexcept {
    return lhs.equals(rhs);
  }

  // Swaps the complete state (extents, offset and scaling); the counter lets
  // tests observe that this ADL swap was called.
  friend constexpr void swap(mapping& x, mapping& y) noexcept {
    using std::swap;
    swap(x.extents_, y.extents_);
    swap(x.offset_, y.offset_);
    swap(x.scaling_, y.scaling_);
    if (!std::is_constant_evaluated()) {
      swap_counter()++;
    }
  }

  static int& swap_counter() {
    static int value = 0;
    return value;
  }

private:
  // All specializations may read each other's private state (converting ctor, equals).
  template <class>
  friend class mapping;

  extents_type extents_{};
  index_type offset_{};
  index_type scaling_{};
};
#endif // TEST_STD_CONTAINERS_VIEWS_MDSPAN_CUSTOM_TEST_LAYOUTS_H