xref: /llvm-project/pstl/include/pstl/internal/omp/parallel_scan.h (revision 843c12d6a0cdfd64c5a92e24eb58ba9ee17ca1ee)
1 // -*- C++ -*-
2 // -*-===----------------------------------------------------------------------===//
3 //
4 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 //
6 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
7 // See https://llvm.org/LICENSE.txt for license information.
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #ifndef _PSTL_INTERNAL_OMP_PARALLEL_SCAN_H
12 #define _PSTL_INTERNAL_OMP_PARALLEL_SCAN_H
13 
14 #include "parallel_invoke.h"
15 
16 namespace __pstl
17 {
18 namespace __omp_backend
19 {
20 
// Returns the largest power of two strictly less than __m when __m >= 2, and 1 otherwise.
//
// __upsweep / __downsweep use it to cut a run of __m tiles into a power-of-two left part
// [0, k) and a remainder [k, __m), so every left part is a complete binary tree of tiles.
//
// The loop condition is the overflow-free form of `2 * __k < __m`. Doubling __k is only
// attempted while __k < __m - __k, so 2 * __k never exceeds __m and cannot overflow _Index.
// The old form overflowed when __m was in the top half of _Index's range: signed UB, or an
// endless loop once __k wrapped to 0 for unsigned types.
template <typename _Index>
_Index
__split(_Index __m)
{
    _Index __k = 1;
    // Invariant: __k is a power of two and __k <= max(__m, 1).
    while (__k < __m && __m - __k > __k)
        __k *= 2;
    return __k;
}
30 
31 template <typename _Index, typename _Tp, typename _Rp, typename _Cp>
32 void
__upsweep(_Index __i,_Index __m,_Index __tilesize,_Tp * __r,_Index __lastsize,_Rp __reduce,_Cp __combine)33 __upsweep(_Index __i, _Index __m, _Index __tilesize, _Tp* __r, _Index __lastsize, _Rp __reduce, _Cp __combine)
34 {
35     if (__m == 1)
36         __r[0] = __reduce(__i * __tilesize, __lastsize);
37     else
38     {
39         _Index __k = __split(__m);
40         __omp_backend::__parallel_invoke_body(
41             [=] { __omp_backend::__upsweep(__i, __k, __tilesize, __r, __tilesize, __reduce, __combine); },
42             [=] {
43                 __omp_backend::__upsweep(__i + __k, __m - __k, __tilesize, __r + __k, __lastsize, __reduce, __combine);
44             });
45         if (__m == 2 * __k)
46             __r[__m - 1] = __combine(__r[__k - 1], __r[__m - 1]);
47     }
48 }
49 
50 template <typename _Index, typename _Tp, typename _Cp, typename _Sp>
51 void
__downsweep(_Index __i,_Index __m,_Index __tilesize,_Tp * __r,_Index __lastsize,_Tp __initial,_Cp __combine,_Sp __scan)52 __downsweep(_Index __i, _Index __m, _Index __tilesize, _Tp* __r, _Index __lastsize, _Tp __initial, _Cp __combine,
53             _Sp __scan)
54 {
55     if (__m == 1)
56         __scan(__i * __tilesize, __lastsize, __initial);
57     else
58     {
59         const _Index __k = __split(__m);
60         __omp_backend::__parallel_invoke_body(
61             [=] { __omp_backend::__downsweep(__i, __k, __tilesize, __r, __tilesize, __initial, __combine, __scan); },
62             // Assumes that __combine never throws.
63             // TODO: Consider adding a requirement for user functors to be constant.
64             [=, &__combine]
65             {
66                 __omp_backend::__downsweep(__i + __k, __m - __k, __tilesize, __r + __k, __lastsize,
67                                            __combine(__initial, __r[__k - 1]), __combine, __scan);
68             });
69     }
70 }
71 
72 template <typename _ExecutionPolicy, typename _Index, typename _Tp, typename _Rp, typename _Cp, typename _Sp,
73           typename _Ap>
74 void
__parallel_strict_scan_body(_Index __n,_Tp __initial,_Rp __reduce,_Cp __combine,_Sp __scan,_Ap __apex)75 __parallel_strict_scan_body(_Index __n, _Tp __initial, _Rp __reduce, _Cp __combine, _Sp __scan, _Ap __apex)
76 {
77     _Index __p = omp_get_num_threads();
78     const _Index __slack = 4;
79     _Index __tilesize = (__n - 1) / (__slack * __p) + 1;
80     _Index __m = (__n - 1) / __tilesize;
81     __buffer<_Tp> __buf(__m + 1);
82     _Tp* __r = __buf.get();
83 
84     __omp_backend::__upsweep(_Index(0), _Index(__m + 1), __tilesize, __r, __n - __m * __tilesize, __reduce, __combine);
85 
86     std::size_t __k = __m + 1;
87     _Tp __t = __r[__k - 1];
88     while ((__k &= __k - 1))
89     {
90         __t = __combine(__r[__k - 1], __t);
91     }
92 
93     __apex(__combine(__initial, __t));
94     __omp_backend::__downsweep(_Index(0), _Index(__m + 1), __tilesize, __r, __n - __m * __tilesize, __initial,
95                                __combine, __scan);
96 }
97 
98 template <class _ExecutionPolicy, typename _Index, typename _Tp, typename _Rp, typename _Cp, typename _Sp, typename _Ap>
99 void
__parallel_strict_scan(__pstl::__internal::__openmp_backend_tag,_ExecutionPolicy &&,_Index __n,_Tp __initial,_Rp __reduce,_Cp __combine,_Sp __scan,_Ap __apex)100 __parallel_strict_scan(__pstl::__internal::__openmp_backend_tag, _ExecutionPolicy&&, _Index __n, _Tp __initial,
101                        _Rp __reduce, _Cp __combine, _Sp __scan, _Ap __apex)
102 {
103     if (__n <= __default_chunk_size)
104     {
105         _Tp __sum = __initial;
106         if (__n)
107         {
108             __sum = __combine(__sum, __reduce(_Index(0), __n));
109         }
110         __apex(__sum);
111         if (__n)
112         {
113             __scan(_Index(0), __n, __initial);
114         }
115         return;
116     }
117 
118     if (omp_in_parallel())
119     {
120         __pstl::__omp_backend::__parallel_strict_scan_body<_ExecutionPolicy>(__n, __initial, __reduce, __combine,
121                                                                              __scan, __apex);
122     }
123     else
124     {
125         _PSTL_PRAGMA(omp parallel)
126         _PSTL_PRAGMA(omp single nowait)
127         {
128             __pstl::__omp_backend::__parallel_strict_scan_body<_ExecutionPolicy>(__n, __initial, __reduce, __combine,
129                                                                                  __scan, __apex);
130         }
131     }
132 }
133 
134 } // namespace __omp_backend
135 } // namespace __pstl
136 #endif // _PSTL_INTERNAL_OMP_PARALLEL_SCAN_H
137