// -*- C++ -*-
// -*-===----------------------------------------------------------------------===//
//
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
//
//===----------------------------------------------------------------------===//
10
11 #ifndef _PSTL_INTERNAL_OMP_PARALLEL_SCAN_H
12 #define _PSTL_INTERNAL_OMP_PARALLEL_SCAN_H
13
14 #include "parallel_invoke.h"
15
16 namespace __pstl
17 {
18 namespace __omp_backend
19 {
20
// Returns the largest power of two that is strictly less than __m
// (and 1 when __m <= 1, where no such power of two exists).
//
// The scan helpers use the result __k to cut a run of __m tiles into a left part
// of exactly __k tiles and a right part holding the remaining __m - __k tiles.
//
// __m  number of tiles to split.
template <typename _Index>
_Index
__split(_Index __m)
{
    // ceil(__m / 2). Testing `__k < __half` is equivalent to testing `2 * __k < __m`,
    // but never doubles __k beyond __m: the doubled value cannot overflow a signed
    // _Index, nor wrap an unsigned one to zero (which would make the loop spin forever)
    // when __m is larger than half of the range of _Index.
    const _Index __half = __m / 2 + __m % 2;
    _Index __k = 1;
    while (__k < __half)
        __k *= 2;
    return __k;
}
30
31 template <typename _Index, typename _Tp, typename _Rp, typename _Cp>
32 void
__upsweep(_Index __i,_Index __m,_Index __tilesize,_Tp * __r,_Index __lastsize,_Rp __reduce,_Cp __combine)33 __upsweep(_Index __i, _Index __m, _Index __tilesize, _Tp* __r, _Index __lastsize, _Rp __reduce, _Cp __combine)
34 {
35 if (__m == 1)
36 __r[0] = __reduce(__i * __tilesize, __lastsize);
37 else
38 {
39 _Index __k = __split(__m);
40 __omp_backend::__parallel_invoke_body(
41 [=] { __omp_backend::__upsweep(__i, __k, __tilesize, __r, __tilesize, __reduce, __combine); },
42 [=] {
43 __omp_backend::__upsweep(__i + __k, __m - __k, __tilesize, __r + __k, __lastsize, __reduce, __combine);
44 });
45 if (__m == 2 * __k)
46 __r[__m - 1] = __combine(__r[__k - 1], __r[__m - 1]);
47 }
48 }
49
50 template <typename _Index, typename _Tp, typename _Cp, typename _Sp>
51 void
__downsweep(_Index __i,_Index __m,_Index __tilesize,_Tp * __r,_Index __lastsize,_Tp __initial,_Cp __combine,_Sp __scan)52 __downsweep(_Index __i, _Index __m, _Index __tilesize, _Tp* __r, _Index __lastsize, _Tp __initial, _Cp __combine,
53 _Sp __scan)
54 {
55 if (__m == 1)
56 __scan(__i * __tilesize, __lastsize, __initial);
57 else
58 {
59 const _Index __k = __split(__m);
60 __omp_backend::__parallel_invoke_body(
61 [=] { __omp_backend::__downsweep(__i, __k, __tilesize, __r, __tilesize, __initial, __combine, __scan); },
62 // Assumes that __combine never throws.
63 // TODO: Consider adding a requirement for user functors to be constant.
64 [=, &__combine]
65 {
66 __omp_backend::__downsweep(__i + __k, __m - __k, __tilesize, __r + __k, __lastsize,
67 __combine(__initial, __r[__k - 1]), __combine, __scan);
68 });
69 }
70 }
71
72 template <typename _ExecutionPolicy, typename _Index, typename _Tp, typename _Rp, typename _Cp, typename _Sp,
73 typename _Ap>
74 void
__parallel_strict_scan_body(_Index __n,_Tp __initial,_Rp __reduce,_Cp __combine,_Sp __scan,_Ap __apex)75 __parallel_strict_scan_body(_Index __n, _Tp __initial, _Rp __reduce, _Cp __combine, _Sp __scan, _Ap __apex)
76 {
77 _Index __p = omp_get_num_threads();
78 const _Index __slack = 4;
79 _Index __tilesize = (__n - 1) / (__slack * __p) + 1;
80 _Index __m = (__n - 1) / __tilesize;
81 __buffer<_Tp> __buf(__m + 1);
82 _Tp* __r = __buf.get();
83
84 __omp_backend::__upsweep(_Index(0), _Index(__m + 1), __tilesize, __r, __n - __m * __tilesize, __reduce, __combine);
85
86 std::size_t __k = __m + 1;
87 _Tp __t = __r[__k - 1];
88 while ((__k &= __k - 1))
89 {
90 __t = __combine(__r[__k - 1], __t);
91 }
92
93 __apex(__combine(__initial, __t));
94 __omp_backend::__downsweep(_Index(0), _Index(__m + 1), __tilesize, __r, __n - __m * __tilesize, __initial,
95 __combine, __scan);
96 }
97
98 template <class _ExecutionPolicy, typename _Index, typename _Tp, typename _Rp, typename _Cp, typename _Sp, typename _Ap>
99 void
__parallel_strict_scan(__pstl::__internal::__openmp_backend_tag,_ExecutionPolicy &&,_Index __n,_Tp __initial,_Rp __reduce,_Cp __combine,_Sp __scan,_Ap __apex)100 __parallel_strict_scan(__pstl::__internal::__openmp_backend_tag, _ExecutionPolicy&&, _Index __n, _Tp __initial,
101 _Rp __reduce, _Cp __combine, _Sp __scan, _Ap __apex)
102 {
103 if (__n <= __default_chunk_size)
104 {
105 _Tp __sum = __initial;
106 if (__n)
107 {
108 __sum = __combine(__sum, __reduce(_Index(0), __n));
109 }
110 __apex(__sum);
111 if (__n)
112 {
113 __scan(_Index(0), __n, __initial);
114 }
115 return;
116 }
117
118 if (omp_in_parallel())
119 {
120 __pstl::__omp_backend::__parallel_strict_scan_body<_ExecutionPolicy>(__n, __initial, __reduce, __combine,
121 __scan, __apex);
122 }
123 else
124 {
125 _PSTL_PRAGMA(omp parallel)
126 _PSTL_PRAGMA(omp single nowait)
127 {
128 __pstl::__omp_backend::__parallel_strict_scan_body<_ExecutionPolicy>(__n, __initial, __reduce, __combine,
129 __scan, __apex);
130 }
131 }
132 }
133
134 } // namespace __omp_backend
135 } // namespace __pstl
136 #endif // _PSTL_INTERNAL_OMP_PARALLEL_SCAN_H
137