xref: /llvm-project/flang/lib/Evaluate/fold-reduction.h (revision fc97d2e68b03bc2979395e84b645e5b3ba35aecd)
1 //===-- lib/Evaluate/fold-reduction.h -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
10 #define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
11 
12 #include "fold-implementation.h"
13 
14 namespace Fortran::evaluate {
15 
16 // DOT_PRODUCT
17 template <typename T>
18 static Expr<T> FoldDotProduct(
19     FoldingContext &context, FunctionRef<T> &&funcRef) {
20   using Element = typename Constant<T>::Element;
21   auto args{funcRef.arguments()};
22   CHECK(args.size() == 2);
23   Folder<T> folder{context};
24   Constant<T> *va{folder.Folding(args[0])};
25   Constant<T> *vb{folder.Folding(args[1])};
26   if (va && vb) {
27     CHECK(va->Rank() == 1 && vb->Rank() == 1);
28     if (va->size() != vb->size()) {
29       context.messages().Say(
30           "Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US,
31           va->size(), vb->size());
32       return MakeInvalidIntrinsic(std::move(funcRef));
33     }
34     Element sum{};
35     bool overflow{false};
36     if constexpr (T::category == TypeCategory::Complex) {
37       std::vector<Element> conjugates;
38       for (const Element &x : va->values()) {
39         conjugates.emplace_back(x.CONJG());
40       }
41       Constant<T> conjgA{
42           std::move(conjugates), ConstantSubscripts{va->shape()}};
43       Expr<T> products{Fold(
44           context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
45       Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
46       [[maybe_unused]] Element correction{};
47       const auto &rounding{context.targetCharacteristics().roundingMode()};
48       for (const Element &x : cProducts.values()) {
49         if constexpr (useKahanSummation) {
50           auto next{x.Subtract(correction, rounding)};
51           overflow |= next.flags.test(RealFlag::Overflow);
52           auto added{sum.Add(next.value, rounding)};
53           overflow |= added.flags.test(RealFlag::Overflow);
54           correction = added.value.Subtract(sum, rounding)
55                            .value.Subtract(next.value, rounding)
56                            .value;
57           sum = std::move(added.value);
58         } else {
59           auto added{sum.Add(x, rounding)};
60           overflow |= added.flags.test(RealFlag::Overflow);
61           sum = std::move(added.value);
62         }
63       }
64     } else if constexpr (T::category == TypeCategory::Logical) {
65       Expr<T> conjunctions{Fold(context,
66           Expr<T>{LogicalOperation<T::kind>{LogicalOperator::And,
67               Expr<T>{Constant<T>{*va}}, Expr<T>{Constant<T>{*vb}}}})};
68       Constant<T> &cConjunctions{DEREF(UnwrapConstantValue<T>(conjunctions))};
69       for (const Element &x : cConjunctions.values()) {
70         if (x.IsTrue()) {
71           sum = Element{true};
72           break;
73         }
74       }
75     } else if constexpr (T::category == TypeCategory::Integer) {
76       Expr<T> products{
77           Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
78       Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
79       for (const Element &x : cProducts.values()) {
80         auto next{sum.AddSigned(x)};
81         overflow |= next.overflow;
82         sum = std::move(next.value);
83       }
84     } else if constexpr (T::category == TypeCategory::Unsigned) {
85       Expr<T> products{
86           Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
87       Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
88       for (const Element &x : cProducts.values()) {
89         sum = sum.AddUnsigned(x).value;
90       }
91     } else {
92       static_assert(T::category == TypeCategory::Real);
93       Expr<T> products{
94           Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
95       Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
96       [[maybe_unused]] Element correction{};
97       const auto &rounding{context.targetCharacteristics().roundingMode()};
98       for (const Element &x : cProducts.values()) {
99         if constexpr (useKahanSummation) {
100           auto next{x.Subtract(correction, rounding)};
101           overflow |= next.flags.test(RealFlag::Overflow);
102           auto added{sum.Add(next.value, rounding)};
103           overflow |= added.flags.test(RealFlag::Overflow);
104           correction = added.value.Subtract(sum, rounding)
105                            .value.Subtract(next.value, rounding)
106                            .value;
107           sum = std::move(added.value);
108         } else {
109           auto added{sum.Add(x, rounding)};
110           overflow |= added.flags.test(RealFlag::Overflow);
111           sum = std::move(added.value);
112         }
113       }
114     }
115     if (overflow &&
116         context.languageFeatures().ShouldWarn(
117             common::UsageWarning::FoldingException)) {
118       context.messages().Say(common::UsageWarning::FoldingException,
119           "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US,
120           T::AsFortran());
121     }
122     return Expr<T>{Constant<T>{std::move(sum)}};
123   }
124   return Expr<T>{std::move(funcRef)};
125 }
126 
127 // Fold and validate a DIM= argument.  Returns false on error.
// Declaration only; the definition is not part of this header.
// As called from ProcessReductionArgs(): `dim` is the caller's DIM= slot,
// the ActualArguments are the intrinsic's argument list, `dimIndex` is the
// position of DIM= in that list (if the intrinsic has one), and `rank` is
// the rank of the folded array.
// NOTE(review): presumably a present `dim` is 1-based on success, because
// DoReduction() indexes with `*dim - 1` -- confirm against the definition.
128 bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &,
129     ActualArguments &, std::optional<int> dimIndex, int rank);
130 
131 // Fold and validate a MASK= argument.  Return null on error, absent MASK=, or
132 // non-constant MASK=.
// Declaration only; the definition is not part of this header.
// As called from ProcessReductionArgs(): `maskArg` is the MASK= actual
// argument and `shape` is the shape of the folded array being reduced.
// ProcessReductionArgs() treats a null result as "cannot fold".
133 Constant<LogicalResult> *GetReductionMASK(
134     std::optional<ActualArgument> &maskArg, const ConstantSubscripts &shape,
135     FoldingContext &);
136 
137 // Common preprocessing for reduction transformational intrinsic function
138 // folding.  If the intrinsic can have DIM= &/or MASK= arguments, extract
139 // and check them.  The constant array is returned along with a mask of the
140 // same shape: a scalar MASK= is broadcast, and an absent MASK= yields all
141 // .TRUE. values.  If the result is present, the intrinsic call can be folded.
142 template <typename T> struct ArrayAndMask {
143   Constant<T> array;
144   Constant<LogicalResult> mask;
145 };
146 template <typename T>
147 static std::optional<ArrayAndMask<T>> ProcessReductionArgs(
148     FoldingContext &context, ActualArguments &arg, std::optional<int> &dim,
149     int arrayIndex, std::optional<int> dimIndex = std::nullopt,
150     std::optional<int> maskIndex = std::nullopt) {
151   if (arg.empty()) {
152     return std::nullopt;
153   }
154   Constant<T> *folded{Folder<T>{context}.Folding(arg[arrayIndex])};
155   if (!folded || folded->Rank() < 1) {
156     return std::nullopt;
157   }
158   if (!CheckReductionDIM(dim, context, arg, dimIndex, folded->Rank())) {
159     return std::nullopt;
160   }
161   std::size_t n{folded->size()};
162   std::vector<Scalar<LogicalResult>> maskElement;
163   if (maskIndex && static_cast<std::size_t>(*maskIndex) < arg.size() &&
164       arg[*maskIndex]) {
165     if (const Constant<LogicalResult> *origMask{
166             GetReductionMASK(arg[*maskIndex], folded->shape(), context)}) {
167       if (auto scalarMask{origMask->GetScalarValue()}) {
168         maskElement =
169             std::vector<Scalar<LogicalResult>>(n, scalarMask->IsTrue());
170       } else {
171         maskElement = origMask->values();
172       }
173     } else {
174       return std::nullopt;
175     }
176   } else {
177     maskElement = std::vector<Scalar<LogicalResult>>(n, true);
178   }
179   return ArrayAndMask<T>{Constant<T>(*folded),
180       Constant<LogicalResult>{
181           std::move(maskElement), ConstantSubscripts{folded->shape()}}};
182 }
183 
184 // Generalized reduction to an array of one dimension fewer (w/ DIM=)
185 // or to a scalar (w/o DIM=).  The ACCUMULATOR type must define
186 // operator()(Scalar<T> &, const ConstantSubscripts &, bool first)
187 // and Done(Scalar<T> &).
188 template <typename T, typename ACCUMULATOR, typename ARRAY>
189 static Constant<T> DoReduction(const Constant<ARRAY> &array,
190     const Constant<LogicalResult> &mask, std::optional<int> &dim,
191     const Scalar<T> &identity, ACCUMULATOR &accumulator) {
192   ConstantSubscripts at{array.lbounds()};
193   ConstantSubscripts maskAt{mask.lbounds()};
194   std::vector<typename Constant<T>::Element> elements;
195   ConstantSubscripts resultShape; // empty -> scalar
196   if (dim) { // DIM= is present, so result is an array
197     resultShape = array.shape();
198     resultShape.erase(resultShape.begin() + (*dim - 1));
199     ConstantSubscript dimExtent{array.shape().at(*dim - 1)};
200     CHECK(dimExtent == mask.shape().at(*dim - 1));
201     ConstantSubscript &dimAt{at[*dim - 1]};
202     ConstantSubscript dimLbound{dimAt};
203     ConstantSubscript &maskDimAt{maskAt[*dim - 1]};
204     ConstantSubscript maskDimLbound{maskDimAt};
205     for (auto n{GetSize(resultShape)}; n-- > 0;
206          array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
207       elements.push_back(identity);
208       if (dimExtent > 0) {
209         dimAt = dimLbound;
210         maskDimAt = maskDimLbound;
211         bool firstUnmasked{true};
212         for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt, ++maskDimAt) {
213           if (mask.At(maskAt).IsTrue()) {
214             accumulator(elements.back(), at, firstUnmasked);
215             firstUnmasked = false;
216           }
217         }
218         --dimAt, --maskDimAt;
219       }
220       accumulator.Done(elements.back());
221     }
222   } else { // no DIM=, result is scalar
223     elements.push_back(identity);
224     bool firstUnmasked{true};
225     for (auto n{array.size()}; n-- > 0;
226          array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
227       if (mask.At(maskAt).IsTrue()) {
228         accumulator(elements.back(), at, firstUnmasked);
229         firstUnmasked = false;
230       }
231     }
232     accumulator.Done(elements.back());
233   }
234   if constexpr (T::category == TypeCategory::Character) {
235     return {static_cast<ConstantSubscript>(identity.size()),
236         std::move(elements), std::move(resultShape)};
237   } else {
238     return {std::move(elements), std::move(resultShape)};
239   }
240 }
241 
242 // MAXVAL & MINVAL
243 template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
244 public:
245   MaxvalMinvalAccumulator(
246       RelationalOperator opr, FoldingContext &context, const Constant<T> &array)
247       : opr_{opr}, context_{context}, array_{array} {};
248   void operator()(Scalar<T> &element, const ConstantSubscripts &at,
249       [[maybe_unused]] bool firstUnmasked) const {
250     auto aAt{array_.At(at)};
251     if constexpr (ABS) {
252       aAt = aAt.ABS();
253     }
254     if constexpr (T::category == TypeCategory::Real) {
255       if (firstUnmasked || element.IsNotANumber()) {
256         // Return NaN if and only if all unmasked elements are NaNs and
257         // at least one unmasked element is visible.
258         element = aAt;
259         return;
260       }
261     }
262     Expr<LogicalResult> test{PackageRelation(
263         opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})};
264     auto folded{GetScalarConstantValue<LogicalResult>(
265         test.Rewrite(context_, std::move(test)))};
266     CHECK(folded.has_value());
267     if (folded->IsTrue()) {
268       element = aAt;
269     }
270   }
271   void Done(Scalar<T> &) const {}
272 
273 private:
274   RelationalOperator opr_;
275   FoldingContext &context_;
276   const Constant<T> &array_;
277 };
278 
279 template <typename T>
280 static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
281     RelationalOperator opr, const Scalar<T> &identity) {
282   static_assert(T::category == TypeCategory::Integer ||
283       T::category == TypeCategory::Unsigned ||
284       T::category == TypeCategory::Real ||
285       T::category == TypeCategory::Character);
286   std::optional<int> dim;
287   if (std::optional<ArrayAndMask<T>> arrayAndMask{
288           ProcessReductionArgs<T>(context, ref.arguments(), dim,
289               /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
290     MaxvalMinvalAccumulator<T> accumulator{opr, context, arrayAndMask->array};
291     return Expr<T>{DoReduction<T>(
292         arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)};
293   }
294   return Expr<T>{std::move(ref)};
295 }
296 
297 // PRODUCT
298 template <typename T> class ProductAccumulator {
299 public:
300   ProductAccumulator(const Constant<T> &array) : array_{array} {}
301   void operator()(
302       Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
303     if constexpr (T::category == TypeCategory::Integer) {
304       auto prod{element.MultiplySigned(array_.At(at))};
305       overflow_ |= prod.SignedMultiplicationOverflowed();
306       element = prod.lower;
307     } else if constexpr (T::category == TypeCategory::Unsigned) {
308       element = element.MultiplyUnsigned(array_.At(at)).lower;
309     } else { // Real & Complex
310       auto prod{element.Multiply(array_.At(at))};
311       overflow_ |= prod.flags.test(RealFlag::Overflow);
312       element = prod.value;
313     }
314   }
315   bool overflow() const { return overflow_; }
316   void Done(Scalar<T> &) const {}
317 
318 private:
319   const Constant<T> &array_;
320   bool overflow_{false};
321 };
322 
323 template <typename T>
324 static Expr<T> FoldProduct(
325     FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) {
326   static_assert(T::category == TypeCategory::Integer ||
327       T::category == TypeCategory::Unsigned ||
328       T::category == TypeCategory::Real ||
329       T::category == TypeCategory::Complex);
330   std::optional<int> dim;
331   if (std::optional<ArrayAndMask<T>> arrayAndMask{
332           ProcessReductionArgs<T>(context, ref.arguments(), dim,
333               /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
334     ProductAccumulator accumulator{arrayAndMask->array};
335     auto result{Expr<T>{DoReduction<T>(
336         arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
337     if (accumulator.overflow() &&
338         context.languageFeatures().ShouldWarn(
339             common::UsageWarning::FoldingException)) {
340       context.messages().Say(common::UsageWarning::FoldingException,
341           "PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran());
342     }
343     return result;
344   }
345   return Expr<T>{std::move(ref)};
346 }
347 
348 // SUM
349 template <typename T> class SumAccumulator {
350   using Element = typename Constant<T>::Element;
351 
352 public:
353   SumAccumulator(const Constant<T> &array, Rounding rounding)
354       : array_{array}, rounding_{rounding} {}
355   void operator()(
356       Element &element, const ConstantSubscripts &at, bool /*first*/) {
357     if constexpr (T::category == TypeCategory::Integer) {
358       auto sum{element.AddSigned(array_.At(at))};
359       overflow_ |= sum.overflow;
360       element = sum.value;
361     } else if constexpr (T::category == TypeCategory::Unsigned) {
362       element = element.AddUnsigned(array_.At(at)).value;
363     } else { // Real & Complex: use Kahan summation
364       auto next{array_.At(at).Subtract(correction_, rounding_)};
365       overflow_ |= next.flags.test(RealFlag::Overflow);
366       auto sum{element.Add(next.value, rounding_)};
367       overflow_ |= sum.flags.test(RealFlag::Overflow);
368       // correction = (sum - element) - next; algebraically zero
369       correction_ = sum.value.Subtract(element, rounding_)
370                         .value.Subtract(next.value, rounding_)
371                         .value;
372       element = sum.value;
373     }
374   }
375   bool overflow() const { return overflow_; }
376   void Done([[maybe_unused]] Element &element) {
377     if constexpr (T::category != TypeCategory::Integer &&
378         T::category != TypeCategory::Unsigned) {
379       auto corrected{element.Add(correction_, rounding_)};
380       overflow_ |= corrected.flags.test(RealFlag::Overflow);
381       correction_ = Scalar<T>{};
382       element = corrected.value;
383     }
384   }
385 
386 private:
387   const Constant<T> &array_;
388   Rounding rounding_;
389   bool overflow_{false};
390   Element correction_{};
391 };
392 
393 template <typename T>
394 static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
395   static_assert(T::category == TypeCategory::Integer ||
396       T::category == TypeCategory::Unsigned ||
397       T::category == TypeCategory::Real ||
398       T::category == TypeCategory::Complex);
399   using Element = typename Constant<T>::Element;
400   std::optional<int> dim;
401   Element identity{};
402   if (std::optional<ArrayAndMask<T>> arrayAndMask{
403           ProcessReductionArgs<T>(context, ref.arguments(), dim,
404               /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
405     SumAccumulator accumulator{
406         arrayAndMask->array, context.targetCharacteristics().roundingMode()};
407     auto result{Expr<T>{DoReduction<T>(
408         arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
409     if (accumulator.overflow() &&
410         context.languageFeatures().ShouldWarn(
411             common::UsageWarning::FoldingException)) {
412       context.messages().Say(common::UsageWarning::FoldingException,
413           "SUM() of %s data overflowed"_warn_en_US, T::AsFortran());
414     }
415     return result;
416   }
417   return Expr<T>{std::move(ref)};
418 }
419 
420 // Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY
421 template <typename T> class OperationAccumulator {
422 public:
423   OperationAccumulator(const Constant<T> &array,
424       Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const)
425       : array_{array}, operation_{operation} {}
426   void operator()(
427       Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
428     element = (element.*operation_)(array_.At(at));
429   }
430   void Done(Scalar<T> &) const {}
431 
432 private:
433   const Constant<T> &array_;
434   Scalar<T> (Scalar<T>::*operation_)(const Scalar<T> &) const;
435 };
436 
437 } // namespace Fortran::evaluate
438 #endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_
439