xref: /llvm-project/flang/runtime/sum.cpp (revision fc97d2e68b03bc2979395e84b645e5b3ba35aecd)
1 //===-- runtime/sum.cpp ---------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // Implements SUM for all required operand types and shapes.
10 //
11 // Real and complex SUM reductions attempt to reduce floating-point
12 // cancellation on intermediate results by using "Kahan summation"
13 // (basically the same as manual "double-double").
14 
15 #include "reduction-templates.h"
16 #include "flang/Common/float128.h"
17 #include "flang/Runtime/reduction.h"
18 #include <cfloat>
19 #include <cinttypes>
20 #include <complex>
21 
22 namespace Fortran::runtime {
23 
24 template <typename INTERMEDIATE> class IntegerSumAccumulator {
25 public:
26   explicit RT_API_ATTRS IntegerSumAccumulator(const Descriptor &array)
27       : array_{array} {}
28   void RT_API_ATTRS Reinitialize() { sum_ = 0; }
29   template <typename A>
30   RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
31     *p = static_cast<A>(sum_);
32   }
33   template <typename A>
34   RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
35     sum_ += *array_.Element<A>(at);
36     return true;
37   }
38 
39 private:
40   const Descriptor &array_;
41   INTERMEDIATE sum_{0};
42 };
43 
44 template <typename INTERMEDIATE> class RealSumAccumulator {
45 public:
46   explicit RT_API_ATTRS RealSumAccumulator(const Descriptor &array)
47       : array_{array} {}
48   void RT_API_ATTRS Reinitialize() { sum_ = correction_ = 0; }
49   template <typename A> RT_API_ATTRS A Result() const { return sum_; }
50   template <typename A>
51   RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
52     *p = Result<A>();
53   }
54   template <typename A> RT_API_ATTRS bool Accumulate(A x) {
55     // Kahan summation
56     auto next{x - correction_};
57     auto oldSum{sum_};
58     sum_ += next;
59     correction_ = (sum_ - oldSum) - next; // algebraically zero
60     return true;
61   }
62   template <typename A>
63   RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
64     return Accumulate(*array_.Element<A>(at));
65   }
66 
67 private:
68   const Descriptor &array_;
69   INTERMEDIATE sum_{0.0}, correction_{0.0};
70 };
71 
72 template <typename PART> class ComplexSumAccumulator {
73 public:
74   explicit RT_API_ATTRS ComplexSumAccumulator(const Descriptor &array)
75       : array_{array} {}
76   void RT_API_ATTRS Reinitialize() {
77     reals_.Reinitialize();
78     imaginaries_.Reinitialize();
79   }
80   template <typename A>
81   RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
82     using ResultPart = typename A::value_type;
83     *p = {reals_.template Result<ResultPart>(),
84         imaginaries_.template Result<ResultPart>()};
85   }
86   template <typename A> RT_API_ATTRS bool Accumulate(const A &z) {
87     reals_.Accumulate(z.real());
88     imaginaries_.Accumulate(z.imag());
89     return true;
90   }
91   template <typename A>
92   RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
93     return Accumulate(*array_.Element<A>(at));
94   }
95 
96 private:
97   const Descriptor &array_;
98   RealSumAccumulator<PART> reals_{array_}, imaginaries_{array_};
99 };
100 
101 extern "C" {
102 RT_EXT_API_GROUP_BEGIN
103 
104 CppTypeFor<TypeCategory::Integer, 1> RTDEF(SumInteger1)(const Descriptor &x,
105     const char *source, int line, int dim, const Descriptor *mask) {
106   return GetTotalReduction<TypeCategory::Integer, 1>(x, source, line, dim, mask,
107       IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x}, "SUM");
108 }
109 CppTypeFor<TypeCategory::Integer, 2> RTDEF(SumInteger2)(const Descriptor &x,
110     const char *source, int line, int dim, const Descriptor *mask) {
111   return GetTotalReduction<TypeCategory::Integer, 2>(x, source, line, dim, mask,
112       IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x}, "SUM");
113 }
114 CppTypeFor<TypeCategory::Integer, 4> RTDEF(SumInteger4)(const Descriptor &x,
115     const char *source, int line, int dim, const Descriptor *mask) {
116   return GetTotalReduction<TypeCategory::Integer, 4>(x, source, line, dim, mask,
117       IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x}, "SUM");
118 }
119 CppTypeFor<TypeCategory::Integer, 8> RTDEF(SumInteger8)(const Descriptor &x,
120     const char *source, int line, int dim, const Descriptor *mask) {
121   return GetTotalReduction<TypeCategory::Integer, 8>(x, source, line, dim, mask,
122       IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 8>>{x}, "SUM");
123 }
124 #ifdef __SIZEOF_INT128__
125 CppTypeFor<TypeCategory::Integer, 16> RTDEF(SumInteger16)(const Descriptor &x,
126     const char *source, int line, int dim, const Descriptor *mask) {
127   return GetTotalReduction<TypeCategory::Integer, 16>(x, source, line, dim,
128       mask, IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 16>>{x},
129       "SUM");
130 }
131 #endif
132 
133 CppTypeFor<TypeCategory::Unsigned, 1> RTDEF(SumUnsigned1)(const Descriptor &x,
134     const char *source, int line, int dim, const Descriptor *mask) {
135   return GetTotalReduction<TypeCategory::Unsigned, 1>(x, source, line, dim,
136       mask, IntegerSumAccumulator<CppTypeFor<TypeCategory::Unsigned, 4>>{x},
137       "SUM");
138 }
139 CppTypeFor<TypeCategory::Unsigned, 2> RTDEF(SumUnsigned2)(const Descriptor &x,
140     const char *source, int line, int dim, const Descriptor *mask) {
141   return GetTotalReduction<TypeCategory::Unsigned, 2>(x, source, line, dim,
142       mask, IntegerSumAccumulator<CppTypeFor<TypeCategory::Unsigned, 4>>{x},
143       "SUM");
144 }
145 CppTypeFor<TypeCategory::Unsigned, 4> RTDEF(SumUnsigned4)(const Descriptor &x,
146     const char *source, int line, int dim, const Descriptor *mask) {
147   return GetTotalReduction<TypeCategory::Unsigned, 4>(x, source, line, dim,
148       mask, IntegerSumAccumulator<CppTypeFor<TypeCategory::Unsigned, 4>>{x},
149       "SUM");
150 }
151 CppTypeFor<TypeCategory::Unsigned, 8> RTDEF(SumUnsigned8)(const Descriptor &x,
152     const char *source, int line, int dim, const Descriptor *mask) {
153   return GetTotalReduction<TypeCategory::Unsigned, 8>(x, source, line, dim,
154       mask, IntegerSumAccumulator<CppTypeFor<TypeCategory::Unsigned, 8>>{x},
155       "SUM");
156 }
157 #ifdef __SIZEOF_INT128__
158 CppTypeFor<TypeCategory::Unsigned, 16> RTDEF(SumUnsigned16)(const Descriptor &x,
159     const char *source, int line, int dim, const Descriptor *mask) {
160   return GetTotalReduction<TypeCategory::Unsigned, 16>(x, source, line, dim,
161       mask, IntegerSumAccumulator<CppTypeFor<TypeCategory::Unsigned, 16>>{x},
162       "SUM");
163 }
164 #endif
165 
166 // TODO: real/complex(2 & 3)
167 CppTypeFor<TypeCategory::Real, 4> RTDEF(SumReal4)(const Descriptor &x,
168     const char *source, int line, int dim, const Descriptor *mask) {
169   return GetTotalReduction<TypeCategory::Real, 4>(
170       x, source, line, dim, mask, RealSumAccumulator<float>{x}, "SUM");
171 }
172 CppTypeFor<TypeCategory::Real, 8> RTDEF(SumReal8)(const Descriptor &x,
173     const char *source, int line, int dim, const Descriptor *mask) {
174   return GetTotalReduction<TypeCategory::Real, 8>(
175       x, source, line, dim, mask, RealSumAccumulator<double>{x}, "SUM");
176 }
177 #if HAS_FLOAT80
178 CppTypeFor<TypeCategory::Real, 10> RTDEF(SumReal10)(const Descriptor &x,
179     const char *source, int line, int dim, const Descriptor *mask) {
180   return GetTotalReduction<TypeCategory::Real, 10>(x, source, line, dim, mask,
181       RealSumAccumulator<CppTypeFor<TypeCategory::Real, 10>>{x}, "SUM");
182 }
183 #endif
184 #if HAS_LDBL128 || HAS_FLOAT128
185 CppTypeFor<TypeCategory::Real, 16> RTDEF(SumReal16)(const Descriptor &x,
186     const char *source, int line, int dim, const Descriptor *mask) {
187   return GetTotalReduction<TypeCategory::Real, 16>(x, source, line, dim, mask,
188       RealSumAccumulator<CppTypeFor<TypeCategory::Real, 16>>{x}, "SUM");
189 }
190 #endif
191 
192 void RTDEF(CppSumComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
193     const Descriptor &x, const char *source, int line, int dim,
194     const Descriptor *mask) {
195   result = GetTotalReduction<TypeCategory::Complex, 4>(
196       x, source, line, dim, mask, ComplexSumAccumulator<float>{x}, "SUM");
197 }
198 void RTDEF(CppSumComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
199     const Descriptor &x, const char *source, int line, int dim,
200     const Descriptor *mask) {
201   result = GetTotalReduction<TypeCategory::Complex, 8>(
202       x, source, line, dim, mask, ComplexSumAccumulator<double>{x}, "SUM");
203 }
204 #if HAS_FLOAT80
205 void RTDEF(CppSumComplex10)(CppTypeFor<TypeCategory::Complex, 10> &result,
206     const Descriptor &x, const char *source, int line, int dim,
207     const Descriptor *mask) {
208   result =
209       GetTotalReduction<TypeCategory::Complex, 10>(x, source, line, dim, mask,
210           ComplexSumAccumulator<CppTypeFor<TypeCategory::Real, 10>>{x}, "SUM");
211 }
212 #endif
213 #if HAS_LDBL128 || HAS_FLOAT128
214 void RTDEF(CppSumComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
215     const Descriptor &x, const char *source, int line, int dim,
216     const Descriptor *mask) {
217   result =
218       GetTotalReduction<TypeCategory::Complex, 16>(x, source, line, dim, mask,
219           ComplexSumAccumulator<CppTypeFor<TypeCategory::Real, 16>>{x}, "SUM");
220 }
221 #endif
222 
223 void RTDEF(SumDim)(Descriptor &result, const Descriptor &x, int dim,
224     const char *source, int line, const Descriptor *mask) {
225   TypedPartialNumericReduction<IntegerSumAccumulator, RealSumAccumulator,
226       ComplexSumAccumulator, /*MIN_REAL_KIND=*/4>(
227       result, x, dim, source, line, mask, "SUM");
228 }
229 
230 RT_EXT_API_GROUP_END
231 } // extern "C"
232 } // namespace Fortran::runtime
233