xref: /dflybsd-src/contrib/gcc-8.0/libstdc++-v3/include/parallel/partial_sum.h (revision 38fd149817dfbff97799f62fcb70be98c4e32523)
1*38fd1498Szrj // -*- C++ -*-
2*38fd1498Szrj 
3*38fd1498Szrj // Copyright (C) 2007-2018 Free Software Foundation, Inc.
4*38fd1498Szrj //
5*38fd1498Szrj // This file is part of the GNU ISO C++ Library.  This library is free
6*38fd1498Szrj // software; you can redistribute it and/or modify it under the terms
7*38fd1498Szrj // of the GNU General Public License as published by the Free Software
8*38fd1498Szrj // Foundation; either version 3, or (at your option) any later
9*38fd1498Szrj // version.
10*38fd1498Szrj 
11*38fd1498Szrj // This library is distributed in the hope that it will be useful, but
12*38fd1498Szrj // WITHOUT ANY WARRANTY; without even the implied warranty of
13*38fd1498Szrj // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14*38fd1498Szrj // General Public License for more details.
15*38fd1498Szrj 
16*38fd1498Szrj // Under Section 7 of GPL version 3, you are granted additional
17*38fd1498Szrj // permissions described in the GCC Runtime Library Exception, version
18*38fd1498Szrj // 3.1, as published by the Free Software Foundation.
19*38fd1498Szrj 
20*38fd1498Szrj // You should have received a copy of the GNU General Public License and
21*38fd1498Szrj // a copy of the GCC Runtime Library Exception along with this program;
22*38fd1498Szrj // see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
23*38fd1498Szrj // <http://www.gnu.org/licenses/>.
24*38fd1498Szrj 
25*38fd1498Szrj /** @file parallel/partial_sum.h
26*38fd1498Szrj  *  @brief Parallel implementation of std::partial_sum(), i.e. prefix
27*38fd1498Szrj *  sums.
28*38fd1498Szrj  *  This file is a GNU parallel extension to the Standard C++ Library.
29*38fd1498Szrj  */
30*38fd1498Szrj 
31*38fd1498Szrj // Written by Johannes Singler.
32*38fd1498Szrj 
33*38fd1498Szrj #ifndef _GLIBCXX_PARALLEL_PARTIAL_SUM_H
34*38fd1498Szrj #define _GLIBCXX_PARALLEL_PARTIAL_SUM_H 1
35*38fd1498Szrj 
36*38fd1498Szrj #include <omp.h>
37*38fd1498Szrj #include <new>
38*38fd1498Szrj #include <bits/stl_algobase.h>
39*38fd1498Szrj #include <parallel/parallel.h>
40*38fd1498Szrj #include <parallel/numericfwd.h>
41*38fd1498Szrj 
42*38fd1498Szrj namespace __gnu_parallel
43*38fd1498Szrj {
44*38fd1498Szrj   // Problem: there is no 0-element given.
45*38fd1498Szrj 
/** @brief Sequential prefix sum, used as the base case of the parallel
 *  routines.
 *
 *  Walks [__begin, __end) in order.  For each element @c x the running
 *  value is updated as <tt>__value = __bin_op(__value, x)</tt> and the
 *  new running value is stored at the next output position.  Each input
 *  element is read before the corresponding output position is written,
 *  so the input and output ranges may coincide element-wise (in-place).
 *
 *  @param __begin Begin iterator of input sequence.
 *  @param __end End iterator of input sequence.
 *  @param __result Begin iterator of output sequence.
 *  @param __bin_op Associative binary function; called as
 *  <tt>__bin_op(running_value, element)</tt>.
 *  @param __value Start value.  It is combined with the first element but
 *  is not itself written to the output.  It must be supplied because a
 *  neutral element of @p __bin_op is unknown in general.
 *  @return Iterator one past the last output position written
 *  (@p __result itself for an empty input). */
template<typename _IIter,
         typename _OutputIterator,
         typename _BinaryOperation>
  _OutputIterator
  __parallel_partial_sum_basecase(_IIter __begin, _IIter __end,
                                  _OutputIterator __result,
                                  _BinaryOperation __bin_op,
    typename std::iterator_traits<_IIter>::value_type __value)
  {
    // The (void) cast keeps any user-overloaded operator, out of play.
    for (; __begin != __end; ++__result, (void)++__begin)
      {
        __value = __bin_op(__value, *__begin);
        *__result = __value;
      }
    return __result;
  }
75*38fd1498Szrj 
76*38fd1498Szrj   /** @brief Parallel partial sum implementation, two-phase approach,
77*38fd1498Szrj       no recursion.
78*38fd1498Szrj       *  @param __begin Begin iterator of input sequence.
79*38fd1498Szrj       *  @param __end End iterator of input sequence.
80*38fd1498Szrj       *  @param __result Begin iterator of output sequence.
81*38fd1498Szrj       *  @param __bin_op Associative binary function.
82*38fd1498Szrj       *  @param __n Length of sequence.
83*38fd1498Szrj       *  @return End iterator of output sequence.
84*38fd1498Szrj       */
85*38fd1498Szrj   template<typename _IIter,
86*38fd1498Szrj 	   typename _OutputIterator,
87*38fd1498Szrj 	   typename _BinaryOperation>
88*38fd1498Szrj     _OutputIterator
__parallel_partial_sum_linear(_IIter __begin,_IIter __end,_OutputIterator __result,_BinaryOperation __bin_op,typename std::iterator_traits<_IIter>::difference_type __n)89*38fd1498Szrj     __parallel_partial_sum_linear(_IIter __begin, _IIter __end,
90*38fd1498Szrj 				  _OutputIterator __result,
91*38fd1498Szrj 				  _BinaryOperation __bin_op,
92*38fd1498Szrj       typename std::iterator_traits<_IIter>::difference_type __n)
93*38fd1498Szrj     {
94*38fd1498Szrj       typedef std::iterator_traits<_IIter> _TraitsType;
95*38fd1498Szrj       typedef typename _TraitsType::value_type _ValueType;
96*38fd1498Szrj       typedef typename _TraitsType::difference_type _DifferenceType;
97*38fd1498Szrj 
98*38fd1498Szrj       if (__begin == __end)
99*38fd1498Szrj 	return __result;
100*38fd1498Szrj 
101*38fd1498Szrj       _ThreadIndex __num_threads =
102*38fd1498Szrj         std::min<_DifferenceType>(__get_max_threads(), __n - 1);
103*38fd1498Szrj 
104*38fd1498Szrj       if (__num_threads < 2)
105*38fd1498Szrj 	{
106*38fd1498Szrj 	  *__result = *__begin;
107*38fd1498Szrj 	  return __parallel_partial_sum_basecase(__begin + 1, __end,
108*38fd1498Szrj 						 __result + 1, __bin_op,
109*38fd1498Szrj 						 *__begin);
110*38fd1498Szrj 	}
111*38fd1498Szrj 
112*38fd1498Szrj       _DifferenceType* __borders;
113*38fd1498Szrj       _ValueType* __sums;
114*38fd1498Szrj 
115*38fd1498Szrj       const _Settings& __s = _Settings::get();
116*38fd1498Szrj 
117*38fd1498Szrj #     pragma omp parallel num_threads(__num_threads)
118*38fd1498Szrj       {
119*38fd1498Szrj #       pragma omp single
120*38fd1498Szrj 	{
121*38fd1498Szrj 	  __num_threads = omp_get_num_threads();
122*38fd1498Szrj 
123*38fd1498Szrj 	  __borders = new _DifferenceType[__num_threads + 2];
124*38fd1498Szrj 
125*38fd1498Szrj 	  if (__s.partial_sum_dilation == 1.0f)
126*38fd1498Szrj 	    __equally_split(__n, __num_threads + 1, __borders);
127*38fd1498Szrj 	  else
128*38fd1498Szrj 	    {
129*38fd1498Szrj 	      _DifferenceType __first_part_length =
130*38fd1498Szrj 		  std::max<_DifferenceType>(1,
131*38fd1498Szrj 		    __n / (1.0f + __s.partial_sum_dilation * __num_threads));
132*38fd1498Szrj 	      _DifferenceType __chunk_length =
133*38fd1498Szrj 		  (__n - __first_part_length) / __num_threads;
134*38fd1498Szrj 	      _DifferenceType __borderstart =
135*38fd1498Szrj 		  __n - __num_threads * __chunk_length;
136*38fd1498Szrj 	      __borders[0] = 0;
137*38fd1498Szrj 	      for (_ThreadIndex __i = 1; __i < (__num_threads + 1); ++__i)
138*38fd1498Szrj 		{
139*38fd1498Szrj 		  __borders[__i] = __borderstart;
140*38fd1498Szrj 		  __borderstart += __chunk_length;
141*38fd1498Szrj 		}
142*38fd1498Szrj 	      __borders[__num_threads + 1] = __n;
143*38fd1498Szrj 	    }
144*38fd1498Szrj 
145*38fd1498Szrj 	  __sums = static_cast<_ValueType*>(::operator new(sizeof(_ValueType)
146*38fd1498Szrj                                                            * __num_threads));
147*38fd1498Szrj 	  _OutputIterator __target_end;
148*38fd1498Szrj 	} //single
149*38fd1498Szrj 
150*38fd1498Szrj         _ThreadIndex __iam = omp_get_thread_num();
151*38fd1498Szrj         if (__iam == 0)
152*38fd1498Szrj           {
153*38fd1498Szrj             *__result = *__begin;
154*38fd1498Szrj             __parallel_partial_sum_basecase(__begin + 1,
155*38fd1498Szrj 					    __begin + __borders[1],
156*38fd1498Szrj 					    __result + 1,
157*38fd1498Szrj 					    __bin_op, *__begin);
158*38fd1498Szrj             ::new(&(__sums[__iam])) _ValueType(*(__result + __borders[1] - 1));
159*38fd1498Szrj           }
160*38fd1498Szrj         else
161*38fd1498Szrj           {
162*38fd1498Szrj             ::new(&(__sums[__iam]))
163*38fd1498Szrj               _ValueType(__gnu_parallel::accumulate(
164*38fd1498Szrj                                          __begin + __borders[__iam] + 1,
165*38fd1498Szrj                                          __begin + __borders[__iam + 1],
166*38fd1498Szrj                                          *(__begin + __borders[__iam]),
167*38fd1498Szrj                                          __bin_op,
168*38fd1498Szrj                                          __gnu_parallel::sequential_tag()));
169*38fd1498Szrj           }
170*38fd1498Szrj 
171*38fd1498Szrj #       pragma omp barrier
172*38fd1498Szrj 
173*38fd1498Szrj #       pragma omp single
174*38fd1498Szrj 	__parallel_partial_sum_basecase(__sums + 1, __sums + __num_threads,
175*38fd1498Szrj 					__sums + 1, __bin_op, __sums[0]);
176*38fd1498Szrj 
177*38fd1498Szrj #       pragma omp barrier
178*38fd1498Szrj 
179*38fd1498Szrj 	// Still same team.
180*38fd1498Szrj         __parallel_partial_sum_basecase(__begin + __borders[__iam + 1],
181*38fd1498Szrj 					__begin + __borders[__iam + 2],
182*38fd1498Szrj 					__result + __borders[__iam + 1],
183*38fd1498Szrj 					__bin_op, __sums[__iam]);
184*38fd1498Szrj       } //parallel
185*38fd1498Szrj 
186*38fd1498Szrj       for (_ThreadIndex __i = 0; __i < __num_threads; ++__i)
187*38fd1498Szrj 	__sums[__i].~_ValueType();
188*38fd1498Szrj       ::operator delete(__sums);
189*38fd1498Szrj 
190*38fd1498Szrj       delete[] __borders;
191*38fd1498Szrj 
192*38fd1498Szrj       return __result + __n;
193*38fd1498Szrj     }
194*38fd1498Szrj 
195*38fd1498Szrj   /** @brief Parallel partial sum front-__end.
196*38fd1498Szrj    *  @param __begin Begin iterator of input sequence.
197*38fd1498Szrj    *  @param __end End iterator of input sequence.
198*38fd1498Szrj    *  @param __result Begin iterator of output sequence.
199*38fd1498Szrj    *  @param __bin_op Associative binary function.
200*38fd1498Szrj    *  @return End iterator of output sequence. */
201*38fd1498Szrj   template<typename _IIter,
202*38fd1498Szrj 	   typename _OutputIterator,
203*38fd1498Szrj 	   typename _BinaryOperation>
204*38fd1498Szrj     _OutputIterator
__parallel_partial_sum(_IIter __begin,_IIter __end,_OutputIterator __result,_BinaryOperation __bin_op)205*38fd1498Szrj     __parallel_partial_sum(_IIter __begin, _IIter __end,
206*38fd1498Szrj 			   _OutputIterator __result, _BinaryOperation __bin_op)
207*38fd1498Szrj     {
208*38fd1498Szrj       _GLIBCXX_CALL(__begin - __end)
209*38fd1498Szrj 
210*38fd1498Szrj       typedef std::iterator_traits<_IIter> _TraitsType;
211*38fd1498Szrj       typedef typename _TraitsType::value_type _ValueType;
212*38fd1498Szrj       typedef typename _TraitsType::difference_type _DifferenceType;
213*38fd1498Szrj 
214*38fd1498Szrj       _DifferenceType __n = __end - __begin;
215*38fd1498Szrj 
216*38fd1498Szrj       switch (_Settings::get().partial_sum_algorithm)
217*38fd1498Szrj 	{
218*38fd1498Szrj 	case LINEAR:
219*38fd1498Szrj 	  // Need an initial offset.
220*38fd1498Szrj 	  return __parallel_partial_sum_linear(__begin, __end, __result,
221*38fd1498Szrj 					       __bin_op, __n);
222*38fd1498Szrj 	default:
223*38fd1498Szrj 	  // Partial_sum algorithm not implemented.
224*38fd1498Szrj 	  _GLIBCXX_PARALLEL_ASSERT(0);
225*38fd1498Szrj 	  return __result + __n;
226*38fd1498Szrj 	}
227*38fd1498Szrj     }
228*38fd1498Szrj }
229*38fd1498Szrj 
230*38fd1498Szrj #endif /* _GLIBCXX_PARALLEL_PARTIAL_SUM_H */
231