147f18af5Speter klausler //===-- lib/Evaluate/fold-reduction.h -------------------------------------===// 247f18af5Speter klausler // 347f18af5Speter klausler // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 447f18af5Speter klausler // See https://llvm.org/LICENSE.txt for license information. 547f18af5Speter klausler // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 647f18af5Speter klausler // 747f18af5Speter klausler //===----------------------------------------------------------------------===// 847f18af5Speter klausler 947f18af5Speter klausler #ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_ 1047f18af5Speter klausler #define FORTRAN_EVALUATE_FOLD_REDUCTION_H_ 1147f18af5Speter klausler 1247f18af5Speter klausler #include "fold-implementation.h" 1347f18af5Speter klausler 1447f18af5Speter klausler namespace Fortran::evaluate { 1547f18af5Speter klausler 16e723c69bSPeter Klausler // DOT_PRODUCT 17e723c69bSPeter Klausler template <typename T> 18e723c69bSPeter Klausler static Expr<T> FoldDotProduct( 19e723c69bSPeter Klausler FoldingContext &context, FunctionRef<T> &&funcRef) { 20e723c69bSPeter Klausler using Element = typename Constant<T>::Element; 21e723c69bSPeter Klausler auto args{funcRef.arguments()}; 22e723c69bSPeter Klausler CHECK(args.size() == 2); 23e723c69bSPeter Klausler Folder<T> folder{context}; 24e723c69bSPeter Klausler Constant<T> *va{folder.Folding(args[0])}; 25e723c69bSPeter Klausler Constant<T> *vb{folder.Folding(args[1])}; 26e723c69bSPeter Klausler if (va && vb) { 27e723c69bSPeter Klausler CHECK(va->Rank() == 1 && vb->Rank() == 1); 28e723c69bSPeter Klausler if (va->size() != vb->size()) { 29e723c69bSPeter Klausler context.messages().Say( 30e723c69bSPeter Klausler "Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US, 31e723c69bSPeter Klausler va->size(), vb->size()); 32e723c69bSPeter Klausler return MakeInvalidIntrinsic(std::move(funcRef)); 33e723c69bSPeter Klausler } 34e723c69bSPeter Klausler Element sum{}; 35e723c69bSPeter Klausler bool overflow{false}; 36e723c69bSPeter Klausler if constexpr (T::category == TypeCategory::Complex) { 37e723c69bSPeter Klausler std::vector<Element> conjugates; 38e723c69bSPeter Klausler for (const Element &x : va->values()) { 39e723c69bSPeter Klausler conjugates.emplace_back(x.CONJG()); 40e723c69bSPeter Klausler } 41e723c69bSPeter Klausler Constant<T> conjgA{ 42e723c69bSPeter Klausler std::move(conjugates), ConstantSubscripts{va->shape()}}; 43e723c69bSPeter Klausler Expr<T> products{Fold( 44e723c69bSPeter Klausler context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})}; 45e723c69bSPeter Klausler Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; 463502d340SPeter Klausler [[maybe_unused]] Element correction{}; 47e723c69bSPeter Klausler const auto &rounding{context.targetCharacteristics().roundingMode()}; 48e723c69bSPeter Klausler for (const Element &x : cProducts.values()) { 493502d340SPeter Klausler if constexpr (useKahanSummation) { 503f594741SPeter Klausler auto next{x.Subtract(correction, rounding)}; 51e723c69bSPeter Klausler overflow |= next.flags.test(RealFlag::Overflow); 52e723c69bSPeter Klausler auto added{sum.Add(next.value, rounding)}; 53e723c69bSPeter Klausler overflow |= added.flags.test(RealFlag::Overflow); 54e723c69bSPeter Klausler correction = added.value.Subtract(sum, rounding) 55e723c69bSPeter Klausler .value.Subtract(next.value, rounding) 56e723c69bSPeter Klausler .value; 57e723c69bSPeter Klausler sum = std::move(added.value); 583502d340SPeter Klausler } else { 593502d340SPeter Klausler auto added{sum.Add(x, rounding)}; 603502d340SPeter Klausler overflow |= added.flags.test(RealFlag::Overflow); 613502d340SPeter Klausler sum = std::move(added.value); 623502d340SPeter Klausler } 63e723c69bSPeter Klausler } 64e723c69bSPeter Klausler } else if constexpr (T::category == TypeCategory::Logical) { 65e723c69bSPeter Klausler Expr<T> conjunctions{Fold(context, 66e723c69bSPeter Klausler Expr<T>{LogicalOperation<T::kind>{LogicalOperator::And, 67e723c69bSPeter Klausler Expr<T>{Constant<T>{*va}}, Expr<T>{Constant<T>{*vb}}}})}; 68e723c69bSPeter Klausler Constant<T> &cConjunctions{DEREF(UnwrapConstantValue<T>(conjunctions))}; 69e723c69bSPeter Klausler for (const Element &x : cConjunctions.values()) { 70e723c69bSPeter Klausler if (x.IsTrue()) { 71e723c69bSPeter Klausler sum = Element{true}; 72e723c69bSPeter Klausler break; 73e723c69bSPeter Klausler } 74e723c69bSPeter Klausler } 75e723c69bSPeter Klausler } else if constexpr (T::category == TypeCategory::Integer) { 76e723c69bSPeter Klausler Expr<T> products{ 77e723c69bSPeter Klausler Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})}; 78e723c69bSPeter Klausler Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; 79e723c69bSPeter Klausler for (const Element &x : cProducts.values()) { 80e723c69bSPeter Klausler auto next{sum.AddSigned(x)}; 81e723c69bSPeter Klausler overflow |= next.overflow; 82e723c69bSPeter Klausler sum = std::move(next.value); 83e723c69bSPeter Klausler } 84*fc97d2e6SPeter Klausler } else if constexpr (T::category == TypeCategory::Unsigned) { 85*fc97d2e6SPeter Klausler Expr<T> products{ 86*fc97d2e6SPeter Klausler Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})}; 87*fc97d2e6SPeter Klausler Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; 88*fc97d2e6SPeter Klausler for (const Element &x : cProducts.values()) { 89*fc97d2e6SPeter Klausler sum = sum.AddUnsigned(x).value; 90*fc97d2e6SPeter Klausler } 9139f1860dSPeter Klausler } else { 9239f1860dSPeter Klausler static_assert(T::category == TypeCategory::Real); 93e723c69bSPeter Klausler Expr<T> products{ 94e723c69bSPeter Klausler Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})}; 95e723c69bSPeter Klausler Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; 963502d340SPeter Klausler [[maybe_unused]] Element correction{}; 97e723c69bSPeter Klausler const auto &rounding{context.targetCharacteristics().roundingMode()}; 98e723c69bSPeter Klausler for (const Element &x : cProducts.values()) { 993502d340SPeter Klausler if constexpr (useKahanSummation) { 1003f594741SPeter Klausler auto next{x.Subtract(correction, rounding)}; 101e723c69bSPeter Klausler overflow |= next.flags.test(RealFlag::Overflow); 102e723c69bSPeter Klausler auto added{sum.Add(next.value, rounding)}; 103e723c69bSPeter Klausler overflow |= added.flags.test(RealFlag::Overflow); 104e723c69bSPeter Klausler correction = added.value.Subtract(sum, rounding) 105e723c69bSPeter Klausler .value.Subtract(next.value, rounding) 106e723c69bSPeter Klausler .value; 107e723c69bSPeter Klausler sum = std::move(added.value); 1083502d340SPeter Klausler } else { 1093502d340SPeter Klausler auto added{sum.Add(x, rounding)}; 1103502d340SPeter Klausler overflow |= added.flags.test(RealFlag::Overflow); 1113502d340SPeter Klausler sum = std::move(added.value); 1123502d340SPeter Klausler } 113e723c69bSPeter Klausler } 114e723c69bSPeter Klausler } 115505f6da1SPeter Klausler if (overflow && 116505f6da1SPeter Klausler context.languageFeatures().ShouldWarn( 117505f6da1SPeter Klausler common::UsageWarning::FoldingException)) { 1180f973ac7SPeter Klausler context.messages().Say(common::UsageWarning::FoldingException, 119e723c69bSPeter Klausler "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US, 120e723c69bSPeter Klausler T::AsFortran()); 121e723c69bSPeter Klausler } 122e723c69bSPeter Klausler return Expr<T>{Constant<T>{std::move(sum)}}; 123e723c69bSPeter Klausler } 124e723c69bSPeter Klausler return Expr<T>{std::move(funcRef)}; 125e723c69bSPeter Klausler } 126e723c69bSPeter Klausler 127e723c69bSPeter Klausler // Fold and validate a DIM= argument. Returns false on error. 12882568675SPeter Klausler bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &, 12982568675SPeter Klausler ActualArguments &, std::optional<int> dimIndex, int rank); 13082568675SPeter Klausler 13182568675SPeter Klausler // Fold and validate a MASK= argument. Return null on error, absent MASK=, or 13282568675SPeter Klausler // non-constant MASK=. 13382568675SPeter Klausler Constant<LogicalResult> *GetReductionMASK( 13482568675SPeter Klausler std::optional<ActualArgument> &maskArg, const ConstantSubscripts &shape, 13582568675SPeter Klausler FoldingContext &); 13626aff847Speter klausler 137503c085eSpeter klausler // Common preprocessing for reduction transformational intrinsic function 138503c085eSpeter klausler // folding. If the intrinsic can have DIM= &/or MASK= arguments, extract 139503c085eSpeter klausler // and check them. If a MASK= is present, apply it to the array data and 14082e1e412SPeter Klausler // substitute replacement values for elements corresponding to .FALSE. in 141503c085eSpeter klausler // the mask. If the result is present, the intrinsic call can be folded. 14282e1e412SPeter Klausler template <typename T> struct ArrayAndMask { 14382e1e412SPeter Klausler Constant<T> array; 14482e1e412SPeter Klausler Constant<LogicalResult> mask; 14582e1e412SPeter Klausler }; 14647f18af5Speter klausler template <typename T> 14782e1e412SPeter Klausler static std::optional<ArrayAndMask<T>> ProcessReductionArgs( 14882e1e412SPeter Klausler FoldingContext &context, ActualArguments &arg, std::optional<int> &dim, 14982568675SPeter Klausler int arrayIndex, std::optional<int> dimIndex = std::nullopt, 15082568675SPeter Klausler std::optional<int> maskIndex = std::nullopt) { 15147f18af5Speter klausler if (arg.empty()) { 152503c085eSpeter klausler return std::nullopt; 15347f18af5Speter klausler } 154503c085eSpeter klausler Constant<T> *folded{Folder<T>{context}.Folding(arg[arrayIndex])}; 155503c085eSpeter klausler if (!folded || folded->Rank() < 1) { 156503c085eSpeter klausler return std::nullopt; 15747f18af5Speter klausler } 15882568675SPeter Klausler if (!CheckReductionDIM(dim, context, arg, dimIndex, folded->Rank())) { 159503c085eSpeter klausler return std::nullopt; 16047f18af5Speter klausler } 16182e1e412SPeter Klausler std::size_t n{folded->size()}; 16282e1e412SPeter Klausler std::vector<Scalar<LogicalResult>> maskElement; 16382568675SPeter Klausler if (maskIndex && static_cast<std::size_t>(*maskIndex) < arg.size() && 16482568675SPeter Klausler arg[*maskIndex]) { 16582e1e412SPeter Klausler if (const Constant<LogicalResult> *origMask{ 16682568675SPeter Klausler GetReductionMASK(arg[*maskIndex], folded->shape(), context)}) { 16782e1e412SPeter Klausler if (auto scalarMask{origMask->GetScalarValue()}) { 16882e1e412SPeter Klausler maskElement = 16982e1e412SPeter Klausler std::vector<Scalar<LogicalResult>>(n, scalarMask->IsTrue()); 17047f18af5Speter klausler } else { 17182e1e412SPeter Klausler maskElement = origMask->values(); 17247f18af5Speter klausler } 173503c085eSpeter klausler } else { 174503c085eSpeter klausler return std::nullopt; 175503c085eSpeter klausler } 176503c085eSpeter klausler } else { 17782e1e412SPeter Klausler maskElement = std::vector<Scalar<LogicalResult>>(n, true); 178503c085eSpeter klausler } 17982e1e412SPeter Klausler return ArrayAndMask<T>{Constant<T>(*folded), 18082e1e412SPeter Klausler Constant<LogicalResult>{ 18182e1e412SPeter Klausler std::move(maskElement), ConstantSubscripts{folded->shape()}}}; 182503c085eSpeter klausler } 183503c085eSpeter klausler 184503c085eSpeter klausler // Generalized reduction to an array of one dimension fewer (w/ DIM=) 18539f1860dSPeter Klausler // or to a scalar (w/o DIM=). The ACCUMULATOR type must define 18682e1e412SPeter Klausler // operator()(Scalar<T> &, const ConstantSubscripts &, bool first) 18782e1e412SPeter Klausler // and Done(Scalar<T> &). 18826aff847Speter klausler template <typename T, typename ACCUMULATOR, typename ARRAY> 18926aff847Speter klausler static Constant<T> DoReduction(const Constant<ARRAY> &array, 19082e1e412SPeter Klausler const Constant<LogicalResult> &mask, std::optional<int> &dim, 19182e1e412SPeter Klausler const Scalar<T> &identity, ACCUMULATOR &accumulator) { 192503c085eSpeter klausler ConstantSubscripts at{array.lbounds()}; 19382e1e412SPeter Klausler ConstantSubscripts maskAt{mask.lbounds()}; 194503c085eSpeter klausler std::vector<typename Constant<T>::Element> elements; 195503c085eSpeter klausler ConstantSubscripts resultShape; // empty -> scalar 196503c085eSpeter klausler if (dim) { // DIM= is present, so result is an array 197503c085eSpeter klausler resultShape = array.shape(); 198503c085eSpeter klausler resultShape.erase(resultShape.begin() + (*dim - 1)); 199503c085eSpeter klausler ConstantSubscript dimExtent{array.shape().at(*dim - 1)}; 20082e1e412SPeter Klausler CHECK(dimExtent == mask.shape().at(*dim - 1)); 201503c085eSpeter klausler ConstantSubscript &dimAt{at[*dim - 1]}; 202503c085eSpeter klausler ConstantSubscript dimLbound{dimAt}; 20382e1e412SPeter Klausler ConstantSubscript &maskDimAt{maskAt[*dim - 1]}; 20482e1e412SPeter Klausler ConstantSubscript maskDimLbound{maskDimAt}; 205503c085eSpeter klausler for (auto n{GetSize(resultShape)}; n-- > 0; 206b685597cSPeter Klausler array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) { 207b685597cSPeter Klausler elements.push_back(identity); 208b685597cSPeter Klausler if (dimExtent > 0) { 209503c085eSpeter klausler dimAt = dimLbound; 21082e1e412SPeter Klausler maskDimAt = maskDimLbound; 21182e1e412SPeter Klausler bool firstUnmasked{true}; 21282e1e412SPeter Klausler for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt, ++maskDimAt) { 21382e1e412SPeter Klausler if (mask.At(maskAt).IsTrue()) { 21482e1e412SPeter Klausler accumulator(elements.back(), at, firstUnmasked); 21582e1e412SPeter Klausler firstUnmasked = false; 21682e1e412SPeter Klausler } 217503c085eSpeter klausler } 218b685597cSPeter Klausler --dimAt, --maskDimAt; 219b685597cSPeter Klausler } 22039f1860dSPeter Klausler accumulator.Done(elements.back()); 221503c085eSpeter klausler } 222503c085eSpeter klausler } else { // no DIM=, result is scalar 223503c085eSpeter klausler elements.push_back(identity); 22482e1e412SPeter Klausler bool firstUnmasked{true}; 225b685597cSPeter Klausler for (auto n{array.size()}; n-- > 0; 226b685597cSPeter Klausler array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) { 22782e1e412SPeter Klausler if (mask.At(maskAt).IsTrue()) { 22882e1e412SPeter Klausler accumulator(elements.back(), at, firstUnmasked); 22982e1e412SPeter Klausler firstUnmasked = false; 23082e1e412SPeter Klausler } 231503c085eSpeter klausler } 23239f1860dSPeter Klausler accumulator.Done(elements.back()); 233503c085eSpeter klausler } 234503c085eSpeter klausler if constexpr (T::category == TypeCategory::Character) { 235503c085eSpeter klausler return {static_cast<ConstantSubscript>(identity.size()), 236503c085eSpeter klausler std::move(elements), std::move(resultShape)}; 237503c085eSpeter klausler } else { 238503c085eSpeter klausler return {std::move(elements), std::move(resultShape)}; 239503c085eSpeter klausler } 240503c085eSpeter klausler } 241503c085eSpeter klausler 242503c085eSpeter klausler // MAXVAL & MINVAL 24339f1860dSPeter Klausler template <typename T, bool ABS = false> class MaxvalMinvalAccumulator { 24439f1860dSPeter Klausler public: 24539f1860dSPeter Klausler MaxvalMinvalAccumulator( 24639f1860dSPeter Klausler RelationalOperator opr, FoldingContext &context, const Constant<T> &array) 24739f1860dSPeter Klausler : opr_{opr}, context_{context}, array_{array} {}; 24882e1e412SPeter Klausler void operator()(Scalar<T> &element, const ConstantSubscripts &at, 24982e1e412SPeter Klausler [[maybe_unused]] bool firstUnmasked) const { 25039f1860dSPeter Klausler auto aAt{array_.At(at)}; 25139f1860dSPeter Klausler if constexpr (ABS) { 25239f1860dSPeter Klausler aAt = aAt.ABS(); 25339f1860dSPeter Klausler } 25482e1e412SPeter Klausler if constexpr (T::category == TypeCategory::Real) { 25582e1e412SPeter Klausler if (firstUnmasked || element.IsNotANumber()) { 25682e1e412SPeter Klausler // Return NaN if and only if all unmasked elements are NaNs and 25782e1e412SPeter Klausler // at least one unmasked element is visible. 25882e1e412SPeter Klausler element = aAt; 25982e1e412SPeter Klausler return; 26082e1e412SPeter Klausler } 26182e1e412SPeter Klausler } 26239f1860dSPeter Klausler Expr<LogicalResult> test{PackageRelation( 26339f1860dSPeter Klausler opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})}; 26439f1860dSPeter Klausler auto folded{GetScalarConstantValue<LogicalResult>( 26539f1860dSPeter Klausler test.Rewrite(context_, std::move(test)))}; 26639f1860dSPeter Klausler CHECK(folded.has_value()); 26739f1860dSPeter Klausler if (folded->IsTrue()) { 268b225934aSPeter Klausler element = aAt; 26939f1860dSPeter Klausler } 27039f1860dSPeter Klausler } 27139f1860dSPeter Klausler void Done(Scalar<T> &) const {} 27239f1860dSPeter Klausler 27339f1860dSPeter Klausler private: 27439f1860dSPeter Klausler RelationalOperator opr_; 27539f1860dSPeter Klausler FoldingContext &context_; 27639f1860dSPeter Klausler const Constant<T> &array_; 27739f1860dSPeter Klausler }; 27839f1860dSPeter Klausler 279503c085eSpeter klausler template <typename T> 280503c085eSpeter klausler static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref, 281503c085eSpeter klausler RelationalOperator opr, const Scalar<T> &identity) { 282503c085eSpeter klausler static_assert(T::category == TypeCategory::Integer || 283*fc97d2e6SPeter Klausler T::category == TypeCategory::Unsigned || 284503c085eSpeter klausler T::category == TypeCategory::Real || 285503c085eSpeter klausler T::category == TypeCategory::Character); 28682568675SPeter Klausler std::optional<int> dim; 28782e1e412SPeter Klausler if (std::optional<ArrayAndMask<T>> arrayAndMask{ 28882e1e412SPeter Klausler ProcessReductionArgs<T>(context, ref.arguments(), dim, 289503c085eSpeter klausler /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { 290*fc97d2e6SPeter Klausler MaxvalMinvalAccumulator<T> accumulator{opr, context, arrayAndMask->array}; 29182e1e412SPeter Klausler return Expr<T>{DoReduction<T>( 29282e1e412SPeter Klausler arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}; 293503c085eSpeter klausler } 294503c085eSpeter klausler return Expr<T>{std::move(ref)}; 295503c085eSpeter klausler } 296503c085eSpeter klausler 297503c085eSpeter klausler // PRODUCT 29839f1860dSPeter Klausler template <typename T> class ProductAccumulator { 29939f1860dSPeter Klausler public: 30039f1860dSPeter Klausler ProductAccumulator(const Constant<T> &array) : array_{array} {} 30182e1e412SPeter Klausler void operator()( 30282e1e412SPeter Klausler Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) { 30339f1860dSPeter Klausler if constexpr (T::category == TypeCategory::Integer) { 30439f1860dSPeter Klausler auto prod{element.MultiplySigned(array_.At(at))}; 30539f1860dSPeter Klausler overflow_ |= prod.SignedMultiplicationOverflowed(); 30639f1860dSPeter Klausler element = prod.lower; 307*fc97d2e6SPeter Klausler } else if constexpr (T::category == TypeCategory::Unsigned) { 308*fc97d2e6SPeter Klausler element = element.MultiplyUnsigned(array_.At(at)).lower; 30939f1860dSPeter Klausler } else { // Real & Complex 31039f1860dSPeter Klausler auto prod{element.Multiply(array_.At(at))}; 31139f1860dSPeter Klausler overflow_ |= prod.flags.test(RealFlag::Overflow); 31239f1860dSPeter Klausler element = prod.value; 31339f1860dSPeter Klausler } 31439f1860dSPeter Klausler } 31539f1860dSPeter Klausler bool overflow() const { return overflow_; } 31639f1860dSPeter Klausler void Done(Scalar<T> &) const {} 31739f1860dSPeter Klausler 31839f1860dSPeter Klausler private: 31939f1860dSPeter Klausler const Constant<T> &array_; 32039f1860dSPeter Klausler bool overflow_{false}; 32139f1860dSPeter Klausler }; 32239f1860dSPeter Klausler 323503c085eSpeter klausler template <typename T> 324503c085eSpeter klausler static Expr<T> FoldProduct( 325503c085eSpeter klausler FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) { 326503c085eSpeter klausler static_assert(T::category == TypeCategory::Integer || 327*fc97d2e6SPeter Klausler T::category == TypeCategory::Unsigned || 328503c085eSpeter klausler T::category == TypeCategory::Real || 329503c085eSpeter klausler T::category == TypeCategory::Complex); 33082568675SPeter Klausler std::optional<int> dim; 33182e1e412SPeter Klausler if (std::optional<ArrayAndMask<T>> arrayAndMask{ 33282e1e412SPeter Klausler ProcessReductionArgs<T>(context, ref.arguments(), dim, 333503c085eSpeter klausler /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { 33482e1e412SPeter Klausler ProductAccumulator accumulator{arrayAndMask->array}; 33582e1e412SPeter Klausler auto result{Expr<T>{DoReduction<T>( 33682e1e412SPeter Klausler arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}}; 337505f6da1SPeter Klausler if (accumulator.overflow() && 338505f6da1SPeter Klausler context.languageFeatures().ShouldWarn( 339505f6da1SPeter Klausler common::UsageWarning::FoldingException)) { 3400f973ac7SPeter Klausler context.messages().Say(common::UsageWarning::FoldingException, 341a53967cdSPeter Klausler "PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran()); 342503c085eSpeter klausler } 34358d74843SPeter Klausler return result; 344503c085eSpeter klausler } 345503c085eSpeter klausler return Expr<T>{std::move(ref)}; 346503c085eSpeter klausler } 347503c085eSpeter klausler 348503c085eSpeter klausler // SUM 34939f1860dSPeter Klausler template <typename T> class SumAccumulator { 35039f1860dSPeter Klausler using Element = typename Constant<T>::Element; 35139f1860dSPeter Klausler 35239f1860dSPeter Klausler public: 35339f1860dSPeter Klausler SumAccumulator(const Constant<T> &array, Rounding rounding) 35439f1860dSPeter Klausler : array_{array}, rounding_{rounding} {} 35582e1e412SPeter Klausler void operator()( 35682e1e412SPeter Klausler Element &element, const ConstantSubscripts &at, bool /*first*/) { 35739f1860dSPeter Klausler if constexpr (T::category == TypeCategory::Integer) { 35839f1860dSPeter Klausler auto sum{element.AddSigned(array_.At(at))}; 35939f1860dSPeter Klausler overflow_ |= sum.overflow; 36039f1860dSPeter Klausler element = sum.value; 361*fc97d2e6SPeter Klausler } else if constexpr (T::category == TypeCategory::Unsigned) { 362*fc97d2e6SPeter Klausler element = element.AddUnsigned(array_.At(at)).value; 36339f1860dSPeter Klausler } else { // Real & Complex: use Kahan summation 3643f594741SPeter Klausler auto next{array_.At(at).Subtract(correction_, rounding_)}; 36539f1860dSPeter Klausler overflow_ |= next.flags.test(RealFlag::Overflow); 36639f1860dSPeter Klausler auto sum{element.Add(next.value, rounding_)}; 36739f1860dSPeter Klausler overflow_ |= sum.flags.test(RealFlag::Overflow); 36839f1860dSPeter Klausler // correction = (sum - element) - next; algebraically zero 36939f1860dSPeter Klausler correction_ = sum.value.Subtract(element, rounding_) 37039f1860dSPeter Klausler .value.Subtract(next.value, rounding_) 37139f1860dSPeter Klausler .value; 37239f1860dSPeter Klausler element = sum.value; 37339f1860dSPeter Klausler } 37439f1860dSPeter Klausler } 37539f1860dSPeter Klausler bool overflow() const { return overflow_; } 37639f1860dSPeter Klausler void Done([[maybe_unused]] Element &element) { 377*fc97d2e6SPeter Klausler if constexpr (T::category != TypeCategory::Integer && 378*fc97d2e6SPeter Klausler T::category != TypeCategory::Unsigned) { 37939f1860dSPeter Klausler auto corrected{element.Add(correction_, rounding_)}; 38039f1860dSPeter Klausler overflow_ |= corrected.flags.test(RealFlag::Overflow); 38139f1860dSPeter Klausler correction_ = Scalar<T>{}; 38239f1860dSPeter Klausler element = corrected.value; 38339f1860dSPeter Klausler } 38439f1860dSPeter Klausler } 38539f1860dSPeter Klausler 38639f1860dSPeter Klausler private: 38739f1860dSPeter Klausler const Constant<T> &array_; 38839f1860dSPeter Klausler Rounding rounding_; 38939f1860dSPeter Klausler bool overflow_{false}; 39039f1860dSPeter Klausler Element correction_{}; 39139f1860dSPeter Klausler }; 39239f1860dSPeter Klausler 393503c085eSpeter klausler template <typename T> 394503c085eSpeter klausler static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) { 395503c085eSpeter klausler static_assert(T::category == TypeCategory::Integer || 396*fc97d2e6SPeter Klausler T::category == TypeCategory::Unsigned || 397503c085eSpeter klausler T::category == TypeCategory::Real || 398503c085eSpeter klausler T::category == TypeCategory::Complex); 399503c085eSpeter klausler using Element = typename Constant<T>::Element; 40082568675SPeter Klausler std::optional<int> dim; 40139f1860dSPeter Klausler Element identity{}; 40282e1e412SPeter Klausler if (std::optional<ArrayAndMask<T>> arrayAndMask{ 40382e1e412SPeter Klausler ProcessReductionArgs<T>(context, ref.arguments(), dim, 404503c085eSpeter klausler /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { 40539f1860dSPeter Klausler SumAccumulator accumulator{ 40682e1e412SPeter Klausler arrayAndMask->array, context.targetCharacteristics().roundingMode()}; 40782e1e412SPeter Klausler auto result{Expr<T>{DoReduction<T>( 40882e1e412SPeter Klausler arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}}; 409505f6da1SPeter Klausler if (accumulator.overflow() && 410505f6da1SPeter Klausler context.languageFeatures().ShouldWarn( 411505f6da1SPeter Klausler common::UsageWarning::FoldingException)) { 4120f973ac7SPeter Klausler context.messages().Say(common::UsageWarning::FoldingException, 413a53967cdSPeter Klausler "SUM() of %s data overflowed"_warn_en_US, T::AsFortran()); 414503c085eSpeter klausler } 41558d74843SPeter Klausler return result; 416503c085eSpeter klausler } 417503c085eSpeter klausler return Expr<T>{std::move(ref)}; 41847f18af5Speter klausler } 41947f18af5Speter klausler 42039f1860dSPeter Klausler // Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY 42139f1860dSPeter Klausler template <typename T> class OperationAccumulator { 42239f1860dSPeter Klausler public: 42339f1860dSPeter Klausler OperationAccumulator(const Constant<T> &array, 42439f1860dSPeter Klausler Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const) 42539f1860dSPeter Klausler : array_{array}, operation_{operation} {} 42682e1e412SPeter Klausler void operator()( 42782e1e412SPeter Klausler Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) { 42839f1860dSPeter Klausler element = (element.*operation_)(array_.At(at)); 42939f1860dSPeter Klausler } 43039f1860dSPeter Klausler void Done(Scalar<T> &) const {} 43139f1860dSPeter Klausler 43239f1860dSPeter Klausler private: 43339f1860dSPeter Klausler const Constant<T> &array_; 43439f1860dSPeter Klausler Scalar<T> (Scalar<T>::*operation_)(const Scalar<T> &) const; 43539f1860dSPeter Klausler }; 43639f1860dSPeter Klausler 43747f18af5Speter klausler } // namespace Fortran::evaluate 43847f18af5Speter klausler #endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_ 439