xref: /freebsd-src/contrib/llvm-project/libcxx/include/__random/discrete_distribution.h (revision 4824e7fd18a1223177218d4aec1b3c6c5c4a444e)
1*4824e7fdSDimitry Andric //===----------------------------------------------------------------------===//
2*4824e7fdSDimitry Andric //
3*4824e7fdSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*4824e7fdSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*4824e7fdSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*4824e7fdSDimitry Andric //
7*4824e7fdSDimitry Andric //===----------------------------------------------------------------------===//
8*4824e7fdSDimitry Andric 
9*4824e7fdSDimitry Andric #ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
10*4824e7fdSDimitry Andric #define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
11*4824e7fdSDimitry Andric 
12*4824e7fdSDimitry Andric #include <__algorithm/upper_bound.h>
13*4824e7fdSDimitry Andric #include <__config>
14*4824e7fdSDimitry Andric #include <__random/uniform_real_distribution.h>
15*4824e7fdSDimitry Andric #include <cstddef>
16*4824e7fdSDimitry Andric #include <iosfwd>
17*4824e7fdSDimitry Andric #include <numeric>
18*4824e7fdSDimitry Andric #include <vector>
19*4824e7fdSDimitry Andric 
20*4824e7fdSDimitry Andric #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
21*4824e7fdSDimitry Andric #pragma GCC system_header
22*4824e7fdSDimitry Andric #endif
23*4824e7fdSDimitry Andric 
24*4824e7fdSDimitry Andric _LIBCPP_PUSH_MACROS
25*4824e7fdSDimitry Andric #include <__undef_macros>
26*4824e7fdSDimitry Andric 
27*4824e7fdSDimitry Andric _LIBCPP_BEGIN_NAMESPACE_STD
28*4824e7fdSDimitry Andric 
29*4824e7fdSDimitry Andric template<class _IntType = int>
30*4824e7fdSDimitry Andric class _LIBCPP_TEMPLATE_VIS discrete_distribution
31*4824e7fdSDimitry Andric {
32*4824e7fdSDimitry Andric public:
33*4824e7fdSDimitry Andric     // types
34*4824e7fdSDimitry Andric     typedef _IntType result_type;
35*4824e7fdSDimitry Andric 
36*4824e7fdSDimitry Andric     class _LIBCPP_TEMPLATE_VIS param_type
37*4824e7fdSDimitry Andric     {
38*4824e7fdSDimitry Andric         vector<double> __p_;
39*4824e7fdSDimitry Andric     public:
40*4824e7fdSDimitry Andric         typedef discrete_distribution distribution_type;
41*4824e7fdSDimitry Andric 
42*4824e7fdSDimitry Andric         _LIBCPP_INLINE_VISIBILITY
43*4824e7fdSDimitry Andric         param_type() {}
44*4824e7fdSDimitry Andric         template<class _InputIterator>
45*4824e7fdSDimitry Andric             _LIBCPP_INLINE_VISIBILITY
46*4824e7fdSDimitry Andric             param_type(_InputIterator __f, _InputIterator __l)
47*4824e7fdSDimitry Andric             : __p_(__f, __l) {__init();}
48*4824e7fdSDimitry Andric #ifndef _LIBCPP_CXX03_LANG
49*4824e7fdSDimitry Andric         _LIBCPP_INLINE_VISIBILITY
50*4824e7fdSDimitry Andric         param_type(initializer_list<double> __wl)
51*4824e7fdSDimitry Andric             : __p_(__wl.begin(), __wl.end()) {__init();}
52*4824e7fdSDimitry Andric #endif // _LIBCPP_CXX03_LANG
53*4824e7fdSDimitry Andric         template<class _UnaryOperation>
54*4824e7fdSDimitry Andric             param_type(size_t __nw, double __xmin, double __xmax,
55*4824e7fdSDimitry Andric                        _UnaryOperation __fw);
56*4824e7fdSDimitry Andric 
57*4824e7fdSDimitry Andric         vector<double> probabilities() const;
58*4824e7fdSDimitry Andric 
59*4824e7fdSDimitry Andric         friend _LIBCPP_INLINE_VISIBILITY
60*4824e7fdSDimitry Andric             bool operator==(const param_type& __x, const param_type& __y)
61*4824e7fdSDimitry Andric             {return __x.__p_ == __y.__p_;}
62*4824e7fdSDimitry Andric         friend _LIBCPP_INLINE_VISIBILITY
63*4824e7fdSDimitry Andric             bool operator!=(const param_type& __x, const param_type& __y)
64*4824e7fdSDimitry Andric             {return !(__x == __y);}
65*4824e7fdSDimitry Andric 
66*4824e7fdSDimitry Andric     private:
67*4824e7fdSDimitry Andric         void __init();
68*4824e7fdSDimitry Andric 
69*4824e7fdSDimitry Andric         friend class discrete_distribution;
70*4824e7fdSDimitry Andric 
71*4824e7fdSDimitry Andric         template <class _CharT, class _Traits, class _IT>
72*4824e7fdSDimitry Andric         friend
73*4824e7fdSDimitry Andric         basic_ostream<_CharT, _Traits>&
74*4824e7fdSDimitry Andric         operator<<(basic_ostream<_CharT, _Traits>& __os,
75*4824e7fdSDimitry Andric                    const discrete_distribution<_IT>& __x);
76*4824e7fdSDimitry Andric 
77*4824e7fdSDimitry Andric         template <class _CharT, class _Traits, class _IT>
78*4824e7fdSDimitry Andric         friend
79*4824e7fdSDimitry Andric         basic_istream<_CharT, _Traits>&
80*4824e7fdSDimitry Andric         operator>>(basic_istream<_CharT, _Traits>& __is,
81*4824e7fdSDimitry Andric                    discrete_distribution<_IT>& __x);
82*4824e7fdSDimitry Andric     };
83*4824e7fdSDimitry Andric 
84*4824e7fdSDimitry Andric private:
85*4824e7fdSDimitry Andric     param_type __p_;
86*4824e7fdSDimitry Andric 
87*4824e7fdSDimitry Andric public:
88*4824e7fdSDimitry Andric     // constructor and reset functions
89*4824e7fdSDimitry Andric     _LIBCPP_INLINE_VISIBILITY
90*4824e7fdSDimitry Andric     discrete_distribution() {}
91*4824e7fdSDimitry Andric     template<class _InputIterator>
92*4824e7fdSDimitry Andric         _LIBCPP_INLINE_VISIBILITY
93*4824e7fdSDimitry Andric         discrete_distribution(_InputIterator __f, _InputIterator __l)
94*4824e7fdSDimitry Andric             : __p_(__f, __l) {}
95*4824e7fdSDimitry Andric #ifndef _LIBCPP_CXX03_LANG
96*4824e7fdSDimitry Andric     _LIBCPP_INLINE_VISIBILITY
97*4824e7fdSDimitry Andric     discrete_distribution(initializer_list<double> __wl)
98*4824e7fdSDimitry Andric         : __p_(__wl) {}
99*4824e7fdSDimitry Andric #endif // _LIBCPP_CXX03_LANG
100*4824e7fdSDimitry Andric     template<class _UnaryOperation>
101*4824e7fdSDimitry Andric         _LIBCPP_INLINE_VISIBILITY
102*4824e7fdSDimitry Andric         discrete_distribution(size_t __nw, double __xmin, double __xmax,
103*4824e7fdSDimitry Andric                               _UnaryOperation __fw)
104*4824e7fdSDimitry Andric         : __p_(__nw, __xmin, __xmax, __fw) {}
105*4824e7fdSDimitry Andric     _LIBCPP_INLINE_VISIBILITY
106*4824e7fdSDimitry Andric     explicit discrete_distribution(const param_type& __p)
107*4824e7fdSDimitry Andric         : __p_(__p) {}
108*4824e7fdSDimitry Andric     _LIBCPP_INLINE_VISIBILITY
109*4824e7fdSDimitry Andric     void reset() {}
110*4824e7fdSDimitry Andric 
111*4824e7fdSDimitry Andric     // generating functions
112*4824e7fdSDimitry Andric     template<class _URNG>
113*4824e7fdSDimitry Andric         _LIBCPP_INLINE_VISIBILITY
114*4824e7fdSDimitry Andric         result_type operator()(_URNG& __g)
115*4824e7fdSDimitry Andric         {return (*this)(__g, __p_);}
116*4824e7fdSDimitry Andric     template<class _URNG> result_type operator()(_URNG& __g, const param_type& __p);
117*4824e7fdSDimitry Andric 
118*4824e7fdSDimitry Andric     // property functions
119*4824e7fdSDimitry Andric     _LIBCPP_INLINE_VISIBILITY
120*4824e7fdSDimitry Andric     vector<double> probabilities() const {return __p_.probabilities();}
121*4824e7fdSDimitry Andric 
122*4824e7fdSDimitry Andric     _LIBCPP_INLINE_VISIBILITY
123*4824e7fdSDimitry Andric     param_type param() const {return __p_;}
124*4824e7fdSDimitry Andric     _LIBCPP_INLINE_VISIBILITY
125*4824e7fdSDimitry Andric     void param(const param_type& __p) {__p_ = __p;}
126*4824e7fdSDimitry Andric 
127*4824e7fdSDimitry Andric     _LIBCPP_INLINE_VISIBILITY
128*4824e7fdSDimitry Andric     result_type min() const {return 0;}
129*4824e7fdSDimitry Andric     _LIBCPP_INLINE_VISIBILITY
130*4824e7fdSDimitry Andric     result_type max() const {return __p_.__p_.size();}
131*4824e7fdSDimitry Andric 
132*4824e7fdSDimitry Andric     friend _LIBCPP_INLINE_VISIBILITY
133*4824e7fdSDimitry Andric         bool operator==(const discrete_distribution& __x,
134*4824e7fdSDimitry Andric                         const discrete_distribution& __y)
135*4824e7fdSDimitry Andric         {return __x.__p_ == __y.__p_;}
136*4824e7fdSDimitry Andric     friend _LIBCPP_INLINE_VISIBILITY
137*4824e7fdSDimitry Andric         bool operator!=(const discrete_distribution& __x,
138*4824e7fdSDimitry Andric                         const discrete_distribution& __y)
139*4824e7fdSDimitry Andric         {return !(__x == __y);}
140*4824e7fdSDimitry Andric 
141*4824e7fdSDimitry Andric     template <class _CharT, class _Traits, class _IT>
142*4824e7fdSDimitry Andric     friend
143*4824e7fdSDimitry Andric     basic_ostream<_CharT, _Traits>&
144*4824e7fdSDimitry Andric     operator<<(basic_ostream<_CharT, _Traits>& __os,
145*4824e7fdSDimitry Andric                const discrete_distribution<_IT>& __x);
146*4824e7fdSDimitry Andric 
147*4824e7fdSDimitry Andric     template <class _CharT, class _Traits, class _IT>
148*4824e7fdSDimitry Andric     friend
149*4824e7fdSDimitry Andric     basic_istream<_CharT, _Traits>&
150*4824e7fdSDimitry Andric     operator>>(basic_istream<_CharT, _Traits>& __is,
151*4824e7fdSDimitry Andric                discrete_distribution<_IT>& __x);
152*4824e7fdSDimitry Andric };
153*4824e7fdSDimitry Andric 
154*4824e7fdSDimitry Andric template<class _IntType>
155*4824e7fdSDimitry Andric template<class _UnaryOperation>
156*4824e7fdSDimitry Andric discrete_distribution<_IntType>::param_type::param_type(size_t __nw,
157*4824e7fdSDimitry Andric                                                         double __xmin,
158*4824e7fdSDimitry Andric                                                         double __xmax,
159*4824e7fdSDimitry Andric                                                         _UnaryOperation __fw)
160*4824e7fdSDimitry Andric {
161*4824e7fdSDimitry Andric     if (__nw > 1)
162*4824e7fdSDimitry Andric     {
163*4824e7fdSDimitry Andric         __p_.reserve(__nw - 1);
164*4824e7fdSDimitry Andric         double __d = (__xmax - __xmin) / __nw;
165*4824e7fdSDimitry Andric         double __d2 = __d / 2;
166*4824e7fdSDimitry Andric         for (size_t __k = 0; __k < __nw; ++__k)
167*4824e7fdSDimitry Andric             __p_.push_back(__fw(__xmin + __k * __d + __d2));
168*4824e7fdSDimitry Andric         __init();
169*4824e7fdSDimitry Andric     }
170*4824e7fdSDimitry Andric }
171*4824e7fdSDimitry Andric 
172*4824e7fdSDimitry Andric template<class _IntType>
173*4824e7fdSDimitry Andric void
174*4824e7fdSDimitry Andric discrete_distribution<_IntType>::param_type::__init()
175*4824e7fdSDimitry Andric {
176*4824e7fdSDimitry Andric     if (!__p_.empty())
177*4824e7fdSDimitry Andric     {
178*4824e7fdSDimitry Andric         if (__p_.size() > 1)
179*4824e7fdSDimitry Andric         {
180*4824e7fdSDimitry Andric             double __s = _VSTD::accumulate(__p_.begin(), __p_.end(), 0.0);
181*4824e7fdSDimitry Andric             for (vector<double>::iterator __i = __p_.begin(), __e = __p_.end(); __i < __e; ++__i)
182*4824e7fdSDimitry Andric                 *__i /= __s;
183*4824e7fdSDimitry Andric             vector<double> __t(__p_.size() - 1);
184*4824e7fdSDimitry Andric             _VSTD::partial_sum(__p_.begin(), __p_.end() - 1, __t.begin());
185*4824e7fdSDimitry Andric             swap(__p_, __t);
186*4824e7fdSDimitry Andric         }
187*4824e7fdSDimitry Andric         else
188*4824e7fdSDimitry Andric         {
189*4824e7fdSDimitry Andric             __p_.clear();
190*4824e7fdSDimitry Andric             __p_.shrink_to_fit();
191*4824e7fdSDimitry Andric         }
192*4824e7fdSDimitry Andric     }
193*4824e7fdSDimitry Andric }
194*4824e7fdSDimitry Andric 
195*4824e7fdSDimitry Andric template<class _IntType>
196*4824e7fdSDimitry Andric vector<double>
197*4824e7fdSDimitry Andric discrete_distribution<_IntType>::param_type::probabilities() const
198*4824e7fdSDimitry Andric {
199*4824e7fdSDimitry Andric     size_t __n = __p_.size();
200*4824e7fdSDimitry Andric     vector<double> __p(__n+1);
201*4824e7fdSDimitry Andric     _VSTD::adjacent_difference(__p_.begin(), __p_.end(), __p.begin());
202*4824e7fdSDimitry Andric     if (__n > 0)
203*4824e7fdSDimitry Andric         __p[__n] = 1 - __p_[__n-1];
204*4824e7fdSDimitry Andric     else
205*4824e7fdSDimitry Andric         __p[0] = 1;
206*4824e7fdSDimitry Andric     return __p;
207*4824e7fdSDimitry Andric }
208*4824e7fdSDimitry Andric 
209*4824e7fdSDimitry Andric template<class _IntType>
210*4824e7fdSDimitry Andric template<class _URNG>
211*4824e7fdSDimitry Andric _IntType
212*4824e7fdSDimitry Andric discrete_distribution<_IntType>::operator()(_URNG& __g, const param_type& __p)
213*4824e7fdSDimitry Andric {
214*4824e7fdSDimitry Andric     uniform_real_distribution<double> __gen;
215*4824e7fdSDimitry Andric     return static_cast<_IntType>(
216*4824e7fdSDimitry Andric            _VSTD::upper_bound(__p.__p_.begin(), __p.__p_.end(), __gen(__g)) -
217*4824e7fdSDimitry Andric                                                               __p.__p_.begin());
218*4824e7fdSDimitry Andric }
219*4824e7fdSDimitry Andric 
220*4824e7fdSDimitry Andric template <class _CharT, class _Traits, class _IT>
221*4824e7fdSDimitry Andric basic_ostream<_CharT, _Traits>&
222*4824e7fdSDimitry Andric operator<<(basic_ostream<_CharT, _Traits>& __os,
223*4824e7fdSDimitry Andric            const discrete_distribution<_IT>& __x)
224*4824e7fdSDimitry Andric {
225*4824e7fdSDimitry Andric     __save_flags<_CharT, _Traits> __lx(__os);
226*4824e7fdSDimitry Andric     typedef basic_ostream<_CharT, _Traits> _OStream;
227*4824e7fdSDimitry Andric     __os.flags(_OStream::dec | _OStream::left | _OStream::fixed |
228*4824e7fdSDimitry Andric                _OStream::scientific);
229*4824e7fdSDimitry Andric     _CharT __sp = __os.widen(' ');
230*4824e7fdSDimitry Andric     __os.fill(__sp);
231*4824e7fdSDimitry Andric     size_t __n = __x.__p_.__p_.size();
232*4824e7fdSDimitry Andric     __os << __n;
233*4824e7fdSDimitry Andric     for (size_t __i = 0; __i < __n; ++__i)
234*4824e7fdSDimitry Andric         __os << __sp << __x.__p_.__p_[__i];
235*4824e7fdSDimitry Andric     return __os;
236*4824e7fdSDimitry Andric }
237*4824e7fdSDimitry Andric 
238*4824e7fdSDimitry Andric template <class _CharT, class _Traits, class _IT>
239*4824e7fdSDimitry Andric basic_istream<_CharT, _Traits>&
240*4824e7fdSDimitry Andric operator>>(basic_istream<_CharT, _Traits>& __is,
241*4824e7fdSDimitry Andric            discrete_distribution<_IT>& __x)
242*4824e7fdSDimitry Andric {
243*4824e7fdSDimitry Andric     __save_flags<_CharT, _Traits> __lx(__is);
244*4824e7fdSDimitry Andric     typedef basic_istream<_CharT, _Traits> _Istream;
245*4824e7fdSDimitry Andric     __is.flags(_Istream::dec | _Istream::skipws);
246*4824e7fdSDimitry Andric     size_t __n;
247*4824e7fdSDimitry Andric     __is >> __n;
248*4824e7fdSDimitry Andric     vector<double> __p(__n);
249*4824e7fdSDimitry Andric     for (size_t __i = 0; __i < __n; ++__i)
250*4824e7fdSDimitry Andric         __is >> __p[__i];
251*4824e7fdSDimitry Andric     if (!__is.fail())
252*4824e7fdSDimitry Andric         swap(__x.__p_.__p_, __p);
253*4824e7fdSDimitry Andric     return __is;
254*4824e7fdSDimitry Andric }
255*4824e7fdSDimitry Andric 
256*4824e7fdSDimitry Andric _LIBCPP_END_NAMESPACE_STD
257*4824e7fdSDimitry Andric 
258*4824e7fdSDimitry Andric _LIBCPP_POP_MACROS
259*4824e7fdSDimitry Andric 
260*4824e7fdSDimitry Andric #endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
261