xref: /llvm-project/libcxx/test/std/containers/views/mdspan/CustomTestLayouts.h (revision 5e19fd172063c8957a35c7fa3596620f79ebba97)
1 // -*- C++ -*-
2 //===----------------------------------------------------------------------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //                        Kokkos v. 4.0
9 //       Copyright (2022) National Technology & Engineering
10 //               Solutions of Sandia, LLC (NTESS).
11 //
12 // Under the terms of Contract DE-NA0003525 with NTESS,
13 // the U.S. Government retains certain rights in this software.
14 //
15 //===---------------------------------------------------------------------===//
16 
17 #ifndef TEST_STD_CONTAINERS_VIEWS_MDSPAN_CUSTOM_TEST_LAYOUTS_H
18 #define TEST_STD_CONTAINERS_VIEWS_MDSPAN_CUSTOM_TEST_LAYOUTS_H
19 
20 #include <algorithm>
21 #include <array>
22 #include <cassert>
23 #include <cinttypes>
24 #include <concepts>
25 #include <cstddef>
26 #include <limits>
27 #include <mdspan>
28 #include <span> // dynamic_extent
29 #include <type_traits>
30 #include <utility>
31 
// Layout that wraps indices to test some idiosyncratic behavior
// - basically it is a layout_left where each index i is first wrapped, i.e. replaced by i % Wrap
// - only accepts integers as indices
// - is_always_strided and is_always_unique are false
// - is_strided and is_unique are true if no extent is larger than Wrap
// - not default constructible
// - not extents constructible (except from an rvalue extents when Wrap == 8)
// - not trivially copyable
// - does not check dynamic to static extent conversion in converting ctor
// - checks via side effects that mdspan::swap calls the mapping's swap via ADL
42 
// Tag that selects the (extents, tag) constructor of layout_wrapping_integral's
// mapping, which deliberately lacks a plain lvalue-extents constructor.
struct not_extents_constructible_tag {};

// Layout policy whose mapping reduces every index modulo Wrap before
// linearizing it; the mapping itself is defined out of line below.
template <size_t Wrap>
struct layout_wrapping_integral {
  template <class Extents>
  class mapping;
};
51 
// Mapping of layout_wrapping_integral: a layout_left linearization applied to
// indices that have first been reduced modulo Wrap.
template <size_t WrapArg>
template <class Extents>
class layout_wrapping_integral<WrapArg>::mapping {
  // The wrap modulus, converted to this mapping's index type.
  static constexpr typename Extents::index_type Wrap = static_cast<typename Extents::index_type>(WrapArg);

public:
  using extents_type = Extents;
  using index_type   = typename extents_type::index_type;
  using size_type    = typename extents_type::size_type;
  using rank_type    = typename extents_type::rank_type;
  // Spelled with the original template argument so that layout_type::mapping<extents_type>
  // always names this class, even if WrapArg does not round-trip through index_type.
  using layout_type = layout_wrapping_integral<WrapArg>;

private:
  // Extent r after wrapping, i.e. min(ext.extent(r), Wrap): the number of distinct
  // values a wrapped (non-negative) index of dimension r can take.
  static constexpr index_type wrapped_extent(const extents_type& ext, rank_type r) noexcept {
    return ext.extent(r) < Wrap ? ext.extent(r) : Wrap;
  }

  // True iff no extent exceeds Wrap, so wrapping never folds two valid indices of
  // one dimension onto the same value.
  constexpr bool no_extent_exceeds_wrap() const noexcept {
    for (rank_type r = 0; r < extents_type::rank(); r++) {
      if (extents_.extent(r) > Wrap)
        return false;
    }
    return true;
  }

  // Builds an extents_type from the dynamic extents of `other_ext`.
  // Deliberately does not check other_ext's values against our static extents.
  template <class OtherExtents>
  static constexpr extents_type dynamic_extents_from(const OtherExtents& other_ext) noexcept {
    std::array<index_type, extents_type::rank_dynamic()> dyn_extents{};
    rank_type count = 0;
    for (rank_type r = 0; r < extents_type::rank(); r++) {
      if (extents_type::static_extent(r) == std::dynamic_extent)
        dyn_extents[count++] = other_ext.extent(r);
    }
    return extents_type(dyn_extents);
  }

  // Returns whether required_span_size() for `ext` (product of wrapped extents)
  // fits in index_type. Not called anywhere in this class.
  static constexpr bool required_span_size_is_representable(const extents_type& ext) {
    if constexpr (extents_type::rank() == 0) {
      return true;
    } else {
      // Clamp every factor, including the first, exactly like required_span_size().
      index_type prod = wrapped_extent(ext, 0);
      for (rank_type r = 1; r < extents_type::rank(); r++) {
        if (__builtin_mul_overflow(prod, wrapped_extent(ext, r), &prod))
          return false;
      }
      return true;
    }
  }

public:
  // Not default constructible (by design).
  constexpr mapping() noexcept = delete;
  // User-provided so that the mapping is not trivially copyable (by design).
  constexpr mapping(const mapping& other) noexcept : extents_(other.extents()) {}
  // Only the Wrap == 8 variant is constructible from extents alone, and only from an rvalue.
  constexpr mapping(extents_type&& ext) noexcept
    requires(Wrap == 8)
      : extents_(ext) {}
  // Extents constructor available for every Wrap, selected via the tag.
  constexpr mapping(const extents_type& ext, not_extents_constructible_tag) noexcept : extents_(ext) {}

  // Converting constructor from an lvalue mapping (Wrap != 8 variant).
  template <class OtherExtents>
    requires(std::is_constructible_v<extents_type, OtherExtents> && (Wrap != 8))
  constexpr explicit(!std::is_convertible_v<OtherExtents, extents_type>)
      mapping(const mapping<OtherExtents>& other) noexcept
      : extents_(dynamic_extents_from(other.extents())) {}

  // Converting constructor from an rvalue mapping only (Wrap == 8 variant).
  template <class OtherExtents>
    requires(std::is_constructible_v<extents_type, OtherExtents> && (Wrap == 8))
  constexpr explicit(!std::is_convertible_v<OtherExtents, extents_type>)
      mapping(mapping<OtherExtents>&& other) noexcept
      : extents_(dynamic_extents_from(other.extents())) {}

  constexpr mapping& operator=(const mapping& other) noexcept {
    extents_ = other.extents_;
    return *this;
  }

  constexpr const extents_type& extents() const noexcept { return extents_; }

  // Product of the wrapped extents; 1 for rank 0.
  constexpr index_type required_span_size() const noexcept {
    index_type size = 1;
    for (rank_type r = 0; r < extents_type::rank(); r++)
      size *= wrapped_extent(extents_, r);
    return size;
  }

  // Maps (i0, ..., i{n-1}) to w0 + E'0 * (w1 + E'1 * (... + E'{n-2} * w{n-1})),
  // with wk = ik % Wrap and E'k = min(extent(k), Wrap): layout_left over wrapped indices.
  template <std::integral... Indices>
    requires((sizeof...(Indices) == extents_type::rank()) && (std::is_convertible_v<Indices, index_type> && ...) &&
             (std::is_nothrow_constructible_v<index_type, Indices> && ...))
  constexpr index_type operator()(Indices... idx) const noexcept {
    std::array<index_type, extents_type::rank()> idx_a{static_cast<index_type>(static_cast<index_type>(idx) % Wrap)...};
    return [&]<size_t... Pos>(std::index_sequence<Pos...>) {
      index_type res = 0;
      // Horner-style accumulation from the last dimension down to the first.
      ((res = idx_a[extents_type::rank() - 1 - Pos] + wrapped_extent(extents_, extents_type::rank() - 1 - Pos) * res),
       ...);
      return res;
    }(std::make_index_sequence<sizeof...(Indices)>());
  }

  static constexpr bool is_always_unique() noexcept { return false; }
  static constexpr bool is_always_exhaustive() noexcept { return true; }
  static constexpr bool is_always_strided() noexcept { return false; }

  constexpr bool is_unique() const noexcept { return no_extent_exceeds_wrap(); }
  static constexpr bool is_exhaustive() noexcept { return true; }
  constexpr bool is_strided() const noexcept { return no_extent_exceeds_wrap(); }

  // layout_left stride consistent with operator(): product of the wrapped extents
  // of all dimensions before r. Meaningful when is_strided() is true.
  constexpr index_type stride(rank_type r) const noexcept
    requires(extents_type::rank() > 0)
  {
    index_type s = 1;
    for (rank_type i = 0; i < r; i++)
      s *= wrapped_extent(extents_, i);
    return s;
  }

  template <class OtherExtents>
    requires(OtherExtents::rank() == extents_type::rank())
  friend constexpr bool operator==(const mapping& lhs, const mapping<OtherExtents>& rhs) noexcept {
    return lhs.extents() == rhs.extents();
  }

  // Hidden friend found via ADL; counts calls at runtime so tests can verify
  // that mdspan::swap dispatches here.
  friend constexpr void swap(mapping& x, mapping& y) noexcept {
    using std::swap;
    swap(x.extents_, y.extents_);
    if (!std::is_constant_evaluated()) {
      swap_counter()++;
    }
  }

  static int& swap_counter() {
    static int value = 0;
    return value;
  }

private:
  extents_type extents_{};
};
193 
// Builds a std::layout_left mapping over `exts`.
template <class Extents>
constexpr auto construct_mapping(std::layout_left, Extents exts) {
  using mapping_t = std::layout_left::mapping<Extents>;
  return mapping_t(exts);
}
198 
// Builds a std::layout_right mapping over `exts`.
template <class Extents>
constexpr auto construct_mapping(std::layout_right, Extents exts) {
  using mapping_t = std::layout_right::mapping<Extents>;
  return mapping_t(exts);
}
203 
// Builds a layout_wrapping_integral mapping over `exts`; the tag is required
// because that mapping is not (generally) constructible from extents alone.
template <size_t Wraps, class Extents>
constexpr auto construct_mapping(layout_wrapping_integral<Wraps>, Extents exts) {
  using mapping_t = typename layout_wrapping_integral<Wraps>::template mapping<Extents>;
  return mapping_t(exts, not_extents_constructible_tag{});
}
208 
// This layout does not check convertibility of extents in its converting ctor.
// This allows triggering mdspan's ctor static assertion on convertibility of extents.
// It also allows for negative strides and offsets via runtime arguments.
// Layout policy whose mapping converts from a mapping of any other extents;
// the mapping itself is defined out of line below.
struct always_convertible_layout {
  template <class Extents>
  class mapping;
};
217 
// Mapping of always_convertible_layout: layout_left linearization, scaled by
// `scaling_` and shifted by `offset_` (both runtime values, possibly negative).
template <class Extents>
class always_convertible_layout::mapping {
public:
  using extents_type = Extents;
  using index_type   = typename extents_type::index_type;
  using size_type    = typename extents_type::size_type;
  using rank_type    = typename extents_type::rank_type;
  using layout_type  = always_convertible_layout;

private:
  // Returns whether the product of all extents of `ext` fits in index_type.
  // Not called anywhere in this class.
  static constexpr bool required_span_size_is_representable(const extents_type& ext) {
    if constexpr (extents_type::rank() == 0) {
      return true;
    } else {
      index_type prod = ext.extent(0);
      for (rank_type r = 1; r < extents_type::rank(); r++) {
        if (__builtin_mul_overflow(prod, ext.extent(r), &prod))
          return false;
      }
      return true;
    }
  }

  // Compares offset and scaling with a mapping of possibly different extents.
  // This must be a member: mapping<OtherExtents> befriends the other mapping
  // classes (see the friend declaration below), not the hidden-friend operator==.
  template <class OtherExtents>
  constexpr bool same_offset_and_scaling(const mapping<OtherExtents>& other) const noexcept {
    return offset_ == other.offset_ && scaling_ == other.scaling_;
  }

public:
  constexpr mapping() noexcept = delete;
  constexpr mapping(const mapping& other) noexcept
      : extents_(other.extents_), offset_(other.offset_), scaling_(other.scaling_) {}
  constexpr mapping(const extents_type& ext, index_type offset = 0, index_type scaling = 1) noexcept
      : extents_(ext), offset_(offset), scaling_(scaling) {}

  // Deliberately unconstrained and implicit: accepts any other mapping of this layout.
  // Matching ranks copy the dynamic extents (static extents are not checked);
  // mismatched ranks fall back to default-constructed extents.
  template <class OtherExtents>
  constexpr mapping(const mapping<OtherExtents>& other) noexcept {
    if constexpr (extents_type::rank() == OtherExtents::rank()) {
      std::array<index_type, extents_type::rank_dynamic()> dyn_extents{};
      rank_type count = 0;
      for (rank_type r = 0; r < extents_type::rank(); r++) {
        if (extents_type::static_extent(r) == std::dynamic_extent) {
          dyn_extents[count++] = other.extents().extent(r);
        }
      }
      extents_ = extents_type(dyn_extents);
    } else {
      extents_ = extents_type();
    }
    offset_  = other.offset_;
    scaling_ = other.scaling_;
  }

  constexpr mapping& operator=(const mapping& other) noexcept {
    extents_ = other.extents_;
    offset_  = other.offset_;
    scaling_ = other.scaling_;
    return *this;
  }

  constexpr const extents_type& extents() const noexcept { return extents_; }

  constexpr index_type required_span_size() const noexcept {
    index_type size = 1;
    for (size_t r = 0; r < extents_type::rank(); r++)
      size *= extents_.extent(r);
    return std::max(size * scaling_ + offset_, offset_);
  }

  // Returns offset_ + scaling_ * (layout_left linear index of idx...).
  template <std::integral... Indices>
    requires((sizeof...(Indices) == extents_type::rank()) && (std::is_convertible_v<Indices, index_type> && ...) &&
             (std::is_nothrow_constructible_v<index_type, Indices> && ...))
  constexpr index_type operator()(Indices... idx) const noexcept {
    std::array<index_type, extents_type::rank()> idx_a{static_cast<index_type>(idx)...};
    index_type linear = [&]<size_t... Pos>(std::index_sequence<Pos...>) {
      index_type res = 0;
      ((res = idx_a[extents_type::rank() - 1 - Pos] + extents_.extent(extents_type::rank() - 1 - Pos) * res), ...);
      return res;
    }(std::make_index_sequence<sizeof...(Indices)>());
    return offset_ + scaling_ * linear;
  }

  static constexpr bool is_always_unique() noexcept { return true; }
  static constexpr bool is_always_exhaustive() noexcept { return true; }
  static constexpr bool is_always_strided() noexcept { return true; }

  static constexpr bool is_unique() noexcept { return true; }
  static constexpr bool is_exhaustive() noexcept { return true; }
  static constexpr bool is_strided() noexcept { return true; }

  // layout_left stride (product of the extents before r), times scaling_.
  constexpr index_type stride(rank_type r) const noexcept
    requires(extents_type::rank() > 0)
  {
    index_type s = 1;
    for (rank_type i = 0; i < r; i++)
      s *= extents_.extent(i);
    return s * scaling_;
  }

  template <class OtherExtents>
    requires(OtherExtents::rank() == extents_type::rank())
  friend constexpr bool operator==(const mapping& lhs, const mapping<OtherExtents>& rhs) noexcept {
    return lhs.extents() == rhs.extents() && lhs.same_offset_and_scaling(rhs);
  }

  // Hidden friend found via ADL. Exchanges all state (extents, offset, scaling)
  // and counts calls at runtime so tests can verify mdspan::swap dispatches here.
  friend constexpr void swap(mapping& x, mapping& y) noexcept {
    using std::swap;
    swap(x.extents_, y.extents_);
    swap(x.offset_, y.offset_);
    swap(x.scaling_, y.scaling_);
    if (!std::is_constant_evaluated()) {
      swap_counter()++;
    }
  }

  static int& swap_counter() {
    static int value = 0;
    return value;
  }

private:
  // Gives every mapping<OtherExtents> access to our private state
  // (used by the converting constructor and same_offset_and_scaling).
  template <class>
  friend class mapping;

  extents_type extents_{};
  index_type offset_{};
  index_type scaling_{};
};
339 #endif // TEST_STD_CONTAINERS_VIEWS_MDSPAN_CUSTOM_TEST_LAYOUTS_H
340