xref: /llvm-project/flang/lib/Evaluate/fold-reduction.h (revision fc97d2e68b03bc2979395e84b645e5b3ba35aecd)
147f18af5Speter klausler //===-- lib/Evaluate/fold-reduction.h -------------------------------------===//
247f18af5Speter klausler //
347f18af5Speter klausler // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
447f18af5Speter klausler // See https://llvm.org/LICENSE.txt for license information.
547f18af5Speter klausler // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
647f18af5Speter klausler //
747f18af5Speter klausler //===----------------------------------------------------------------------===//
847f18af5Speter klausler 
947f18af5Speter klausler #ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
1047f18af5Speter klausler #define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
1147f18af5Speter klausler 
1247f18af5Speter klausler #include "fold-implementation.h"
1347f18af5Speter klausler 
1447f18af5Speter klausler namespace Fortran::evaluate {
1547f18af5Speter klausler 
16e723c69bSPeter Klausler // DOT_PRODUCT
17e723c69bSPeter Klausler template <typename T>
18e723c69bSPeter Klausler static Expr<T> FoldDotProduct(
19e723c69bSPeter Klausler     FoldingContext &context, FunctionRef<T> &&funcRef) {
20e723c69bSPeter Klausler   using Element = typename Constant<T>::Element;
21e723c69bSPeter Klausler   auto args{funcRef.arguments()};
22e723c69bSPeter Klausler   CHECK(args.size() == 2);
23e723c69bSPeter Klausler   Folder<T> folder{context};
24e723c69bSPeter Klausler   Constant<T> *va{folder.Folding(args[0])};
25e723c69bSPeter Klausler   Constant<T> *vb{folder.Folding(args[1])};
26e723c69bSPeter Klausler   if (va && vb) {
27e723c69bSPeter Klausler     CHECK(va->Rank() == 1 && vb->Rank() == 1);
28e723c69bSPeter Klausler     if (va->size() != vb->size()) {
29e723c69bSPeter Klausler       context.messages().Say(
30e723c69bSPeter Klausler           "Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US,
31e723c69bSPeter Klausler           va->size(), vb->size());
32e723c69bSPeter Klausler       return MakeInvalidIntrinsic(std::move(funcRef));
33e723c69bSPeter Klausler     }
34e723c69bSPeter Klausler     Element sum{};
35e723c69bSPeter Klausler     bool overflow{false};
36e723c69bSPeter Klausler     if constexpr (T::category == TypeCategory::Complex) {
37e723c69bSPeter Klausler       std::vector<Element> conjugates;
38e723c69bSPeter Klausler       for (const Element &x : va->values()) {
39e723c69bSPeter Klausler         conjugates.emplace_back(x.CONJG());
40e723c69bSPeter Klausler       }
41e723c69bSPeter Klausler       Constant<T> conjgA{
42e723c69bSPeter Klausler           std::move(conjugates), ConstantSubscripts{va->shape()}};
43e723c69bSPeter Klausler       Expr<T> products{Fold(
44e723c69bSPeter Klausler           context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
45e723c69bSPeter Klausler       Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
463502d340SPeter Klausler       [[maybe_unused]] Element correction{};
47e723c69bSPeter Klausler       const auto &rounding{context.targetCharacteristics().roundingMode()};
48e723c69bSPeter Klausler       for (const Element &x : cProducts.values()) {
493502d340SPeter Klausler         if constexpr (useKahanSummation) {
503f594741SPeter Klausler           auto next{x.Subtract(correction, rounding)};
51e723c69bSPeter Klausler           overflow |= next.flags.test(RealFlag::Overflow);
52e723c69bSPeter Klausler           auto added{sum.Add(next.value, rounding)};
53e723c69bSPeter Klausler           overflow |= added.flags.test(RealFlag::Overflow);
54e723c69bSPeter Klausler           correction = added.value.Subtract(sum, rounding)
55e723c69bSPeter Klausler                            .value.Subtract(next.value, rounding)
56e723c69bSPeter Klausler                            .value;
57e723c69bSPeter Klausler           sum = std::move(added.value);
583502d340SPeter Klausler         } else {
593502d340SPeter Klausler           auto added{sum.Add(x, rounding)};
603502d340SPeter Klausler           overflow |= added.flags.test(RealFlag::Overflow);
613502d340SPeter Klausler           sum = std::move(added.value);
623502d340SPeter Klausler         }
63e723c69bSPeter Klausler       }
64e723c69bSPeter Klausler     } else if constexpr (T::category == TypeCategory::Logical) {
65e723c69bSPeter Klausler       Expr<T> conjunctions{Fold(context,
66e723c69bSPeter Klausler           Expr<T>{LogicalOperation<T::kind>{LogicalOperator::And,
67e723c69bSPeter Klausler               Expr<T>{Constant<T>{*va}}, Expr<T>{Constant<T>{*vb}}}})};
68e723c69bSPeter Klausler       Constant<T> &cConjunctions{DEREF(UnwrapConstantValue<T>(conjunctions))};
69e723c69bSPeter Klausler       for (const Element &x : cConjunctions.values()) {
70e723c69bSPeter Klausler         if (x.IsTrue()) {
71e723c69bSPeter Klausler           sum = Element{true};
72e723c69bSPeter Klausler           break;
73e723c69bSPeter Klausler         }
74e723c69bSPeter Klausler       }
75e723c69bSPeter Klausler     } else if constexpr (T::category == TypeCategory::Integer) {
76e723c69bSPeter Klausler       Expr<T> products{
77e723c69bSPeter Klausler           Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
78e723c69bSPeter Klausler       Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
79e723c69bSPeter Klausler       for (const Element &x : cProducts.values()) {
80e723c69bSPeter Klausler         auto next{sum.AddSigned(x)};
81e723c69bSPeter Klausler         overflow |= next.overflow;
82e723c69bSPeter Klausler         sum = std::move(next.value);
83e723c69bSPeter Klausler       }
84*fc97d2e6SPeter Klausler     } else if constexpr (T::category == TypeCategory::Unsigned) {
85*fc97d2e6SPeter Klausler       Expr<T> products{
86*fc97d2e6SPeter Klausler           Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
87*fc97d2e6SPeter Klausler       Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
88*fc97d2e6SPeter Klausler       for (const Element &x : cProducts.values()) {
89*fc97d2e6SPeter Klausler         sum = sum.AddUnsigned(x).value;
90*fc97d2e6SPeter Klausler       }
9139f1860dSPeter Klausler     } else {
9239f1860dSPeter Klausler       static_assert(T::category == TypeCategory::Real);
93e723c69bSPeter Klausler       Expr<T> products{
94e723c69bSPeter Klausler           Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
95e723c69bSPeter Klausler       Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
963502d340SPeter Klausler       [[maybe_unused]] Element correction{};
97e723c69bSPeter Klausler       const auto &rounding{context.targetCharacteristics().roundingMode()};
98e723c69bSPeter Klausler       for (const Element &x : cProducts.values()) {
993502d340SPeter Klausler         if constexpr (useKahanSummation) {
1003f594741SPeter Klausler           auto next{x.Subtract(correction, rounding)};
101e723c69bSPeter Klausler           overflow |= next.flags.test(RealFlag::Overflow);
102e723c69bSPeter Klausler           auto added{sum.Add(next.value, rounding)};
103e723c69bSPeter Klausler           overflow |= added.flags.test(RealFlag::Overflow);
104e723c69bSPeter Klausler           correction = added.value.Subtract(sum, rounding)
105e723c69bSPeter Klausler                            .value.Subtract(next.value, rounding)
106e723c69bSPeter Klausler                            .value;
107e723c69bSPeter Klausler           sum = std::move(added.value);
1083502d340SPeter Klausler         } else {
1093502d340SPeter Klausler           auto added{sum.Add(x, rounding)};
1103502d340SPeter Klausler           overflow |= added.flags.test(RealFlag::Overflow);
1113502d340SPeter Klausler           sum = std::move(added.value);
1123502d340SPeter Klausler         }
113e723c69bSPeter Klausler       }
114e723c69bSPeter Klausler     }
115505f6da1SPeter Klausler     if (overflow &&
116505f6da1SPeter Klausler         context.languageFeatures().ShouldWarn(
117505f6da1SPeter Klausler             common::UsageWarning::FoldingException)) {
1180f973ac7SPeter Klausler       context.messages().Say(common::UsageWarning::FoldingException,
119e723c69bSPeter Klausler           "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US,
120e723c69bSPeter Klausler           T::AsFortran());
121e723c69bSPeter Klausler     }
122e723c69bSPeter Klausler     return Expr<T>{Constant<T>{std::move(sum)}};
123e723c69bSPeter Klausler   }
124e723c69bSPeter Klausler   return Expr<T>{std::move(funcRef)};
125e723c69bSPeter Klausler }
126e723c69bSPeter Klausler 
127e723c69bSPeter Klausler // Fold and validate a DIM= argument.  Returns false on error.
12882568675SPeter Klausler bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &,
12982568675SPeter Klausler     ActualArguments &, std::optional<int> dimIndex, int rank);
13082568675SPeter Klausler 
13182568675SPeter Klausler // Fold and validate a MASK= argument.  Return null on error, absent MASK=, or
13282568675SPeter Klausler // non-constant MASK=.
13382568675SPeter Klausler Constant<LogicalResult> *GetReductionMASK(
13482568675SPeter Klausler     std::optional<ActualArgument> &maskArg, const ConstantSubscripts &shape,
13582568675SPeter Klausler     FoldingContext &);
13626aff847Speter klausler 
137503c085eSpeter klausler // Common preprocessing for reduction transformational intrinsic function
138503c085eSpeter klausler // folding.  If the intrinsic can have DIM= &/or MASK= arguments, extract
139503c085eSpeter klausler // and check them.  If a MASK= is present, apply it to the array data and
14082e1e412SPeter Klausler // substitute replacement values for elements corresponding to .FALSE. in
141503c085eSpeter klausler // the mask.  If the result is present, the intrinsic call can be folded.
14282e1e412SPeter Klausler template <typename T> struct ArrayAndMask {
14382e1e412SPeter Klausler   Constant<T> array;
14482e1e412SPeter Klausler   Constant<LogicalResult> mask;
14582e1e412SPeter Klausler };
14647f18af5Speter klausler template <typename T>
14782e1e412SPeter Klausler static std::optional<ArrayAndMask<T>> ProcessReductionArgs(
14882e1e412SPeter Klausler     FoldingContext &context, ActualArguments &arg, std::optional<int> &dim,
14982568675SPeter Klausler     int arrayIndex, std::optional<int> dimIndex = std::nullopt,
15082568675SPeter Klausler     std::optional<int> maskIndex = std::nullopt) {
15147f18af5Speter klausler   if (arg.empty()) {
152503c085eSpeter klausler     return std::nullopt;
15347f18af5Speter klausler   }
154503c085eSpeter klausler   Constant<T> *folded{Folder<T>{context}.Folding(arg[arrayIndex])};
155503c085eSpeter klausler   if (!folded || folded->Rank() < 1) {
156503c085eSpeter klausler     return std::nullopt;
15747f18af5Speter klausler   }
15882568675SPeter Klausler   if (!CheckReductionDIM(dim, context, arg, dimIndex, folded->Rank())) {
159503c085eSpeter klausler     return std::nullopt;
16047f18af5Speter klausler   }
16182e1e412SPeter Klausler   std::size_t n{folded->size()};
16282e1e412SPeter Klausler   std::vector<Scalar<LogicalResult>> maskElement;
16382568675SPeter Klausler   if (maskIndex && static_cast<std::size_t>(*maskIndex) < arg.size() &&
16482568675SPeter Klausler       arg[*maskIndex]) {
16582e1e412SPeter Klausler     if (const Constant<LogicalResult> *origMask{
16682568675SPeter Klausler             GetReductionMASK(arg[*maskIndex], folded->shape(), context)}) {
16782e1e412SPeter Klausler       if (auto scalarMask{origMask->GetScalarValue()}) {
16882e1e412SPeter Klausler         maskElement =
16982e1e412SPeter Klausler             std::vector<Scalar<LogicalResult>>(n, scalarMask->IsTrue());
17047f18af5Speter klausler       } else {
17182e1e412SPeter Klausler         maskElement = origMask->values();
17247f18af5Speter klausler       }
173503c085eSpeter klausler     } else {
174503c085eSpeter klausler       return std::nullopt;
175503c085eSpeter klausler     }
176503c085eSpeter klausler   } else {
17782e1e412SPeter Klausler     maskElement = std::vector<Scalar<LogicalResult>>(n, true);
178503c085eSpeter klausler   }
17982e1e412SPeter Klausler   return ArrayAndMask<T>{Constant<T>(*folded),
18082e1e412SPeter Klausler       Constant<LogicalResult>{
18182e1e412SPeter Klausler           std::move(maskElement), ConstantSubscripts{folded->shape()}}};
182503c085eSpeter klausler }
183503c085eSpeter klausler 
184503c085eSpeter klausler // Generalized reduction to an array of one dimension fewer (w/ DIM=)
18539f1860dSPeter Klausler // or to a scalar (w/o DIM=).  The ACCUMULATOR type must define
18682e1e412SPeter Klausler // operator()(Scalar<T> &, const ConstantSubscripts &, bool first)
18782e1e412SPeter Klausler // and Done(Scalar<T> &).
18826aff847Speter klausler template <typename T, typename ACCUMULATOR, typename ARRAY>
18926aff847Speter klausler static Constant<T> DoReduction(const Constant<ARRAY> &array,
19082e1e412SPeter Klausler     const Constant<LogicalResult> &mask, std::optional<int> &dim,
19182e1e412SPeter Klausler     const Scalar<T> &identity, ACCUMULATOR &accumulator) {
192503c085eSpeter klausler   ConstantSubscripts at{array.lbounds()};
19382e1e412SPeter Klausler   ConstantSubscripts maskAt{mask.lbounds()};
194503c085eSpeter klausler   std::vector<typename Constant<T>::Element> elements;
195503c085eSpeter klausler   ConstantSubscripts resultShape; // empty -> scalar
196503c085eSpeter klausler   if (dim) { // DIM= is present, so result is an array
197503c085eSpeter klausler     resultShape = array.shape();
198503c085eSpeter klausler     resultShape.erase(resultShape.begin() + (*dim - 1));
199503c085eSpeter klausler     ConstantSubscript dimExtent{array.shape().at(*dim - 1)};
20082e1e412SPeter Klausler     CHECK(dimExtent == mask.shape().at(*dim - 1));
201503c085eSpeter klausler     ConstantSubscript &dimAt{at[*dim - 1]};
202503c085eSpeter klausler     ConstantSubscript dimLbound{dimAt};
20382e1e412SPeter Klausler     ConstantSubscript &maskDimAt{maskAt[*dim - 1]};
20482e1e412SPeter Klausler     ConstantSubscript maskDimLbound{maskDimAt};
205503c085eSpeter klausler     for (auto n{GetSize(resultShape)}; n-- > 0;
206b685597cSPeter Klausler          array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
207b685597cSPeter Klausler       elements.push_back(identity);
208b685597cSPeter Klausler       if (dimExtent > 0) {
209503c085eSpeter klausler         dimAt = dimLbound;
21082e1e412SPeter Klausler         maskDimAt = maskDimLbound;
21182e1e412SPeter Klausler         bool firstUnmasked{true};
21282e1e412SPeter Klausler         for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt, ++maskDimAt) {
21382e1e412SPeter Klausler           if (mask.At(maskAt).IsTrue()) {
21482e1e412SPeter Klausler             accumulator(elements.back(), at, firstUnmasked);
21582e1e412SPeter Klausler             firstUnmasked = false;
21682e1e412SPeter Klausler           }
217503c085eSpeter klausler         }
218b685597cSPeter Klausler         --dimAt, --maskDimAt;
219b685597cSPeter Klausler       }
22039f1860dSPeter Klausler       accumulator.Done(elements.back());
221503c085eSpeter klausler     }
222503c085eSpeter klausler   } else { // no DIM=, result is scalar
223503c085eSpeter klausler     elements.push_back(identity);
22482e1e412SPeter Klausler     bool firstUnmasked{true};
225b685597cSPeter Klausler     for (auto n{array.size()}; n-- > 0;
226b685597cSPeter Klausler          array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
22782e1e412SPeter Klausler       if (mask.At(maskAt).IsTrue()) {
22882e1e412SPeter Klausler         accumulator(elements.back(), at, firstUnmasked);
22982e1e412SPeter Klausler         firstUnmasked = false;
23082e1e412SPeter Klausler       }
231503c085eSpeter klausler     }
23239f1860dSPeter Klausler     accumulator.Done(elements.back());
233503c085eSpeter klausler   }
234503c085eSpeter klausler   if constexpr (T::category == TypeCategory::Character) {
235503c085eSpeter klausler     return {static_cast<ConstantSubscript>(identity.size()),
236503c085eSpeter klausler         std::move(elements), std::move(resultShape)};
237503c085eSpeter klausler   } else {
238503c085eSpeter klausler     return {std::move(elements), std::move(resultShape)};
239503c085eSpeter klausler   }
240503c085eSpeter klausler }
241503c085eSpeter klausler 
242503c085eSpeter klausler // MAXVAL & MINVAL
24339f1860dSPeter Klausler template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
24439f1860dSPeter Klausler public:
24539f1860dSPeter Klausler   MaxvalMinvalAccumulator(
24639f1860dSPeter Klausler       RelationalOperator opr, FoldingContext &context, const Constant<T> &array)
24739f1860dSPeter Klausler       : opr_{opr}, context_{context}, array_{array} {};
24882e1e412SPeter Klausler   void operator()(Scalar<T> &element, const ConstantSubscripts &at,
24982e1e412SPeter Klausler       [[maybe_unused]] bool firstUnmasked) const {
25039f1860dSPeter Klausler     auto aAt{array_.At(at)};
25139f1860dSPeter Klausler     if constexpr (ABS) {
25239f1860dSPeter Klausler       aAt = aAt.ABS();
25339f1860dSPeter Klausler     }
25482e1e412SPeter Klausler     if constexpr (T::category == TypeCategory::Real) {
25582e1e412SPeter Klausler       if (firstUnmasked || element.IsNotANumber()) {
25682e1e412SPeter Klausler         // Return NaN if and only if all unmasked elements are NaNs and
25782e1e412SPeter Klausler         // at least one unmasked element is visible.
25882e1e412SPeter Klausler         element = aAt;
25982e1e412SPeter Klausler         return;
26082e1e412SPeter Klausler       }
26182e1e412SPeter Klausler     }
26239f1860dSPeter Klausler     Expr<LogicalResult> test{PackageRelation(
26339f1860dSPeter Klausler         opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})};
26439f1860dSPeter Klausler     auto folded{GetScalarConstantValue<LogicalResult>(
26539f1860dSPeter Klausler         test.Rewrite(context_, std::move(test)))};
26639f1860dSPeter Klausler     CHECK(folded.has_value());
26739f1860dSPeter Klausler     if (folded->IsTrue()) {
268b225934aSPeter Klausler       element = aAt;
26939f1860dSPeter Klausler     }
27039f1860dSPeter Klausler   }
27139f1860dSPeter Klausler   void Done(Scalar<T> &) const {}
27239f1860dSPeter Klausler 
27339f1860dSPeter Klausler private:
27439f1860dSPeter Klausler   RelationalOperator opr_;
27539f1860dSPeter Klausler   FoldingContext &context_;
27639f1860dSPeter Klausler   const Constant<T> &array_;
27739f1860dSPeter Klausler };
27839f1860dSPeter Klausler 
279503c085eSpeter klausler template <typename T>
280503c085eSpeter klausler static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
281503c085eSpeter klausler     RelationalOperator opr, const Scalar<T> &identity) {
282503c085eSpeter klausler   static_assert(T::category == TypeCategory::Integer ||
283*fc97d2e6SPeter Klausler       T::category == TypeCategory::Unsigned ||
284503c085eSpeter klausler       T::category == TypeCategory::Real ||
285503c085eSpeter klausler       T::category == TypeCategory::Character);
28682568675SPeter Klausler   std::optional<int> dim;
28782e1e412SPeter Klausler   if (std::optional<ArrayAndMask<T>> arrayAndMask{
28882e1e412SPeter Klausler           ProcessReductionArgs<T>(context, ref.arguments(), dim,
289503c085eSpeter klausler               /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
290*fc97d2e6SPeter Klausler     MaxvalMinvalAccumulator<T> accumulator{opr, context, arrayAndMask->array};
29182e1e412SPeter Klausler     return Expr<T>{DoReduction<T>(
29282e1e412SPeter Klausler         arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)};
293503c085eSpeter klausler   }
294503c085eSpeter klausler   return Expr<T>{std::move(ref)};
295503c085eSpeter klausler }
296503c085eSpeter klausler 
297503c085eSpeter klausler // PRODUCT
29839f1860dSPeter Klausler template <typename T> class ProductAccumulator {
29939f1860dSPeter Klausler public:
30039f1860dSPeter Klausler   ProductAccumulator(const Constant<T> &array) : array_{array} {}
30182e1e412SPeter Klausler   void operator()(
30282e1e412SPeter Klausler       Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
30339f1860dSPeter Klausler     if constexpr (T::category == TypeCategory::Integer) {
30439f1860dSPeter Klausler       auto prod{element.MultiplySigned(array_.At(at))};
30539f1860dSPeter Klausler       overflow_ |= prod.SignedMultiplicationOverflowed();
30639f1860dSPeter Klausler       element = prod.lower;
307*fc97d2e6SPeter Klausler     } else if constexpr (T::category == TypeCategory::Unsigned) {
308*fc97d2e6SPeter Klausler       element = element.MultiplyUnsigned(array_.At(at)).lower;
30939f1860dSPeter Klausler     } else { // Real & Complex
31039f1860dSPeter Klausler       auto prod{element.Multiply(array_.At(at))};
31139f1860dSPeter Klausler       overflow_ |= prod.flags.test(RealFlag::Overflow);
31239f1860dSPeter Klausler       element = prod.value;
31339f1860dSPeter Klausler     }
31439f1860dSPeter Klausler   }
31539f1860dSPeter Klausler   bool overflow() const { return overflow_; }
31639f1860dSPeter Klausler   void Done(Scalar<T> &) const {}
31739f1860dSPeter Klausler 
31839f1860dSPeter Klausler private:
31939f1860dSPeter Klausler   const Constant<T> &array_;
32039f1860dSPeter Klausler   bool overflow_{false};
32139f1860dSPeter Klausler };
32239f1860dSPeter Klausler 
323503c085eSpeter klausler template <typename T>
324503c085eSpeter klausler static Expr<T> FoldProduct(
325503c085eSpeter klausler     FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) {
326503c085eSpeter klausler   static_assert(T::category == TypeCategory::Integer ||
327*fc97d2e6SPeter Klausler       T::category == TypeCategory::Unsigned ||
328503c085eSpeter klausler       T::category == TypeCategory::Real ||
329503c085eSpeter klausler       T::category == TypeCategory::Complex);
33082568675SPeter Klausler   std::optional<int> dim;
33182e1e412SPeter Klausler   if (std::optional<ArrayAndMask<T>> arrayAndMask{
33282e1e412SPeter Klausler           ProcessReductionArgs<T>(context, ref.arguments(), dim,
333503c085eSpeter klausler               /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
33482e1e412SPeter Klausler     ProductAccumulator accumulator{arrayAndMask->array};
33582e1e412SPeter Klausler     auto result{Expr<T>{DoReduction<T>(
33682e1e412SPeter Klausler         arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
337505f6da1SPeter Klausler     if (accumulator.overflow() &&
338505f6da1SPeter Klausler         context.languageFeatures().ShouldWarn(
339505f6da1SPeter Klausler             common::UsageWarning::FoldingException)) {
3400f973ac7SPeter Klausler       context.messages().Say(common::UsageWarning::FoldingException,
341a53967cdSPeter Klausler           "PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran());
342503c085eSpeter klausler     }
34358d74843SPeter Klausler     return result;
344503c085eSpeter klausler   }
345503c085eSpeter klausler   return Expr<T>{std::move(ref)};
346503c085eSpeter klausler }
347503c085eSpeter klausler 
348503c085eSpeter klausler // SUM
34939f1860dSPeter Klausler template <typename T> class SumAccumulator {
35039f1860dSPeter Klausler   using Element = typename Constant<T>::Element;
35139f1860dSPeter Klausler 
35239f1860dSPeter Klausler public:
35339f1860dSPeter Klausler   SumAccumulator(const Constant<T> &array, Rounding rounding)
35439f1860dSPeter Klausler       : array_{array}, rounding_{rounding} {}
35582e1e412SPeter Klausler   void operator()(
35682e1e412SPeter Klausler       Element &element, const ConstantSubscripts &at, bool /*first*/) {
35739f1860dSPeter Klausler     if constexpr (T::category == TypeCategory::Integer) {
35839f1860dSPeter Klausler       auto sum{element.AddSigned(array_.At(at))};
35939f1860dSPeter Klausler       overflow_ |= sum.overflow;
36039f1860dSPeter Klausler       element = sum.value;
361*fc97d2e6SPeter Klausler     } else if constexpr (T::category == TypeCategory::Unsigned) {
362*fc97d2e6SPeter Klausler       element = element.AddUnsigned(array_.At(at)).value;
36339f1860dSPeter Klausler     } else { // Real & Complex: use Kahan summation
3643f594741SPeter Klausler       auto next{array_.At(at).Subtract(correction_, rounding_)};
36539f1860dSPeter Klausler       overflow_ |= next.flags.test(RealFlag::Overflow);
36639f1860dSPeter Klausler       auto sum{element.Add(next.value, rounding_)};
36739f1860dSPeter Klausler       overflow_ |= sum.flags.test(RealFlag::Overflow);
36839f1860dSPeter Klausler       // correction = (sum - element) - next; algebraically zero
36939f1860dSPeter Klausler       correction_ = sum.value.Subtract(element, rounding_)
37039f1860dSPeter Klausler                         .value.Subtract(next.value, rounding_)
37139f1860dSPeter Klausler                         .value;
37239f1860dSPeter Klausler       element = sum.value;
37339f1860dSPeter Klausler     }
37439f1860dSPeter Klausler   }
37539f1860dSPeter Klausler   bool overflow() const { return overflow_; }
37639f1860dSPeter Klausler   void Done([[maybe_unused]] Element &element) {
377*fc97d2e6SPeter Klausler     if constexpr (T::category != TypeCategory::Integer &&
378*fc97d2e6SPeter Klausler         T::category != TypeCategory::Unsigned) {
37939f1860dSPeter Klausler       auto corrected{element.Add(correction_, rounding_)};
38039f1860dSPeter Klausler       overflow_ |= corrected.flags.test(RealFlag::Overflow);
38139f1860dSPeter Klausler       correction_ = Scalar<T>{};
38239f1860dSPeter Klausler       element = corrected.value;
38339f1860dSPeter Klausler     }
38439f1860dSPeter Klausler   }
38539f1860dSPeter Klausler 
38639f1860dSPeter Klausler private:
38739f1860dSPeter Klausler   const Constant<T> &array_;
38839f1860dSPeter Klausler   Rounding rounding_;
38939f1860dSPeter Klausler   bool overflow_{false};
39039f1860dSPeter Klausler   Element correction_{};
39139f1860dSPeter Klausler };
39239f1860dSPeter Klausler 
393503c085eSpeter klausler template <typename T>
394503c085eSpeter klausler static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
395503c085eSpeter klausler   static_assert(T::category == TypeCategory::Integer ||
396*fc97d2e6SPeter Klausler       T::category == TypeCategory::Unsigned ||
397503c085eSpeter klausler       T::category == TypeCategory::Real ||
398503c085eSpeter klausler       T::category == TypeCategory::Complex);
399503c085eSpeter klausler   using Element = typename Constant<T>::Element;
40082568675SPeter Klausler   std::optional<int> dim;
40139f1860dSPeter Klausler   Element identity{};
40282e1e412SPeter Klausler   if (std::optional<ArrayAndMask<T>> arrayAndMask{
40382e1e412SPeter Klausler           ProcessReductionArgs<T>(context, ref.arguments(), dim,
404503c085eSpeter klausler               /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
40539f1860dSPeter Klausler     SumAccumulator accumulator{
40682e1e412SPeter Klausler         arrayAndMask->array, context.targetCharacteristics().roundingMode()};
40782e1e412SPeter Klausler     auto result{Expr<T>{DoReduction<T>(
40882e1e412SPeter Klausler         arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
409505f6da1SPeter Klausler     if (accumulator.overflow() &&
410505f6da1SPeter Klausler         context.languageFeatures().ShouldWarn(
411505f6da1SPeter Klausler             common::UsageWarning::FoldingException)) {
4120f973ac7SPeter Klausler       context.messages().Say(common::UsageWarning::FoldingException,
413a53967cdSPeter Klausler           "SUM() of %s data overflowed"_warn_en_US, T::AsFortran());
414503c085eSpeter klausler     }
41558d74843SPeter Klausler     return result;
416503c085eSpeter klausler   }
417503c085eSpeter klausler   return Expr<T>{std::move(ref)};
41847f18af5Speter klausler }
41947f18af5Speter klausler 
42039f1860dSPeter Klausler // Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY
42139f1860dSPeter Klausler template <typename T> class OperationAccumulator {
42239f1860dSPeter Klausler public:
42339f1860dSPeter Klausler   OperationAccumulator(const Constant<T> &array,
42439f1860dSPeter Klausler       Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const)
42539f1860dSPeter Klausler       : array_{array}, operation_{operation} {}
42682e1e412SPeter Klausler   void operator()(
42782e1e412SPeter Klausler       Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
42839f1860dSPeter Klausler     element = (element.*operation_)(array_.At(at));
42939f1860dSPeter Klausler   }
43039f1860dSPeter Klausler   void Done(Scalar<T> &) const {}
43139f1860dSPeter Klausler 
43239f1860dSPeter Klausler private:
43339f1860dSPeter Klausler   const Constant<T> &array_;
43439f1860dSPeter Klausler   Scalar<T> (Scalar<T>::*operation_)(const Scalar<T> &) const;
43539f1860dSPeter Klausler };
43639f1860dSPeter Klausler 
43747f18af5Speter klausler } // namespace Fortran::evaluate
43847f18af5Speter klausler #endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_
439