xref: /llvm-project/offload/test/offloading/ompx_bare_shfl_down_sync.cpp (revision 506ca19dc9d6a9f0ad47b82e71525743bbe8cf85)
1 // RUN: %libomptarget-compilexx-run-and-check-generic
2 //
3 // REQUIRES: gpu
4 
5 #include <cassert>
6 #include <cmath>
7 #include <cstdint>
8 #include <cstdio>
9 #include <limits>
10 #include <ompx.h>
11 #include <type_traits>
12 
13 #pragma omp begin declare variant match(device = {arch(amdgcn)})
14 unsigned get_warp_size() { return __builtin_amdgcn_wavefrontsize(); }
15 #pragma omp end declare variant
16 
17 #pragma omp begin declare variant match(device = {arch(nvptx64)})
18 unsigned get_warp_size() { return __nvvm_read_ptx_sreg_warpsize(); }
19 #pragma omp end declare variant
20 
#pragma omp begin declare variant match(device = {kind(cpu)})
/// CPU variant: reports a warp size of one. With that value the lane test
/// in test() (`tid & (warp_size - 1)` against `warp_size - 1`) compares 0
/// with 0, so every thread is treated as the last lane of its warp.
unsigned get_warp_size() {
  constexpr unsigned single_lane = 1;
  return single_lane;
}
#pragma omp end declare variant
24 
/// Equality predicate for integral element types: an exact comparison.
///
/// \param LHS first value.
/// \param RHS second value.
/// \returns true when both values are identical.
template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
bool equal(T LHS, T RHS) {
  const bool Identical = (LHS == RHS);
  return Identical;
}
29 
/// Equality predicate for floating-point element types.
///
/// Two values are considered equal when the absolute value of their
/// difference is strictly below std::numeric_limits<T>::epsilon().
///
/// \param LHS first value.
/// \param RHS second value.
/// \returns true when |LHS - RHS| < epsilon of T.
template <typename T,
          std::enable_if_t<std::is_floating_point<T>::value, bool> = true>
bool equal(T LHS, T RHS) {
  // __builtin_fabs works on double, so the distance is evaluated as a double
  // regardless of T before being compared with T's epsilon.
  const auto Distance = __builtin_fabs(LHS - RHS);
  return Distance < std::numeric_limits<T>::epsilon();
}
35 
36 template <typename T> void test() {
37   constexpr const int num_blocks = 1;
38   constexpr const int block_size = 256;
39   constexpr const int N = num_blocks * block_size;
40   int *res = new int[N];
41 
42 #pragma omp target teams ompx_bare num_teams(num_blocks) thread_limit(block_size) \
43         map(from: res[0:N])
44   {
45     int tid = ompx_thread_id_x();
46     T val = ompx::shfl_down_sync(~0U, static_cast<T>(tid), 1);
47     int warp_size = get_warp_size();
48     if ((tid & (warp_size - 1)) != warp_size - 1)
49       res[tid] = equal(val, static_cast<T>(tid + 1));
50     else
51       res[tid] = equal(val, static_cast<T>(tid));
52   }
53 
54   for (int i = 0; i < N; ++i)
55     assert(res[i]);
56 
57   delete[] res;
58 }
59 
60 int main(int argc, char *argv[]) {
61   test<int32_t>();
62   test<int64_t>();
63   test<float>();
64   test<double>();
65   // CHECK: PASS
66   printf("PASS\n");
67 
68   return 0;
69 }
70