1 //===-- lib/Evaluate/fold-reduction.h -------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_ 10 #define FORTRAN_EVALUATE_FOLD_REDUCTION_H_ 11 12 #include "fold-implementation.h" 13 14 namespace Fortran::evaluate { 15 16 // DOT_PRODUCT 17 template <typename T> 18 static Expr<T> FoldDotProduct( 19 FoldingContext &context, FunctionRef<T> &&funcRef) { 20 using Element = typename Constant<T>::Element; 21 auto args{funcRef.arguments()}; 22 CHECK(args.size() == 2); 23 Folder<T> folder{context}; 24 Constant<T> *va{folder.Folding(args[0])}; 25 Constant<T> *vb{folder.Folding(args[1])}; 26 if (va && vb) { 27 CHECK(va->Rank() == 1 && vb->Rank() == 1); 28 if (va->size() != vb->size()) { 29 context.messages().Say( 30 "Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US, 31 va->size(), vb->size()); 32 return MakeInvalidIntrinsic(std::move(funcRef)); 33 } 34 Element sum{}; 35 bool overflow{false}; 36 if constexpr (T::category == TypeCategory::Complex) { 37 std::vector<Element> conjugates; 38 for (const Element &x : va->values()) { 39 conjugates.emplace_back(x.CONJG()); 40 } 41 Constant<T> conjgA{ 42 std::move(conjugates), ConstantSubscripts{va->shape()}}; 43 Expr<T> products{Fold( 44 context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})}; 45 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; 46 [[maybe_unused]] Element correction{}; 47 const auto &rounding{context.targetCharacteristics().roundingMode()}; 48 for (const Element &x : cProducts.values()) { 49 if constexpr (useKahanSummation) { 50 auto next{x.Subtract(correction, rounding)}; 51 overflow |= next.flags.test(RealFlag::Overflow); 52 auto added{sum.Add(next.value, rounding)}; 53 overflow |= added.flags.test(RealFlag::Overflow); 54 correction = added.value.Subtract(sum, rounding) 55 .value.Subtract(next.value, rounding) 56 .value; 57 sum = std::move(added.value); 58 } else { 59 auto added{sum.Add(x, rounding)}; 60 overflow |= added.flags.test(RealFlag::Overflow); 61 sum = std::move(added.value); 62 } 63 } 64 } else if constexpr (T::category == TypeCategory::Logical) { 65 Expr<T> conjunctions{Fold(context, 66 Expr<T>{LogicalOperation<T::kind>{LogicalOperator::And, 67 Expr<T>{Constant<T>{*va}}, Expr<T>{Constant<T>{*vb}}}})}; 68 Constant<T> &cConjunctions{DEREF(UnwrapConstantValue<T>(conjunctions))}; 69 for (const Element &x : cConjunctions.values()) { 70 if (x.IsTrue()) { 71 sum = Element{true}; 72 break; 73 } 74 } 75 } else if constexpr (T::category == TypeCategory::Integer) { 76 Expr<T> products{ 77 Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})}; 78 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; 79 for (const Element &x : cProducts.values()) { 80 auto next{sum.AddSigned(x)}; 81 overflow |= next.overflow; 82 sum = std::move(next.value); 83 } 84 } else if constexpr (T::category == TypeCategory::Unsigned) { 85 Expr<T> products{ 86 Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})}; 87 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; 88 for (const Element &x : cProducts.values()) { 89 sum = sum.AddUnsigned(x).value; 90 } 91 } else { 92 static_assert(T::category == TypeCategory::Real); 93 Expr<T> products{ 94 Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})}; 95 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; 96 [[maybe_unused]] Element correction{}; 97 const auto &rounding{context.targetCharacteristics().roundingMode()}; 98 for (const Element &x : cProducts.values()) { 99 if constexpr (useKahanSummation) { 100 auto next{x.Subtract(correction, rounding)}; 101 overflow |= next.flags.test(RealFlag::Overflow); 102 auto added{sum.Add(next.value, rounding)}; 103 overflow |= added.flags.test(RealFlag::Overflow); 104 correction = added.value.Subtract(sum, rounding) 105 .value.Subtract(next.value, rounding) 106 .value; 107 sum = std::move(added.value); 108 } else { 109 auto added{sum.Add(x, rounding)}; 110 overflow |= added.flags.test(RealFlag::Overflow); 111 sum = std::move(added.value); 112 } 113 } 114 } 115 if (overflow && 116 context.languageFeatures().ShouldWarn( 117 common::UsageWarning::FoldingException)) { 118 context.messages().Say(common::UsageWarning::FoldingException, 119 "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US, 120 T::AsFortran()); 121 } 122 return Expr<T>{Constant<T>{std::move(sum)}}; 123 } 124 return Expr<T>{std::move(funcRef)}; 125 } 126 127 // Fold and validate a DIM= argument. Returns false on error. 128 bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &, 129 ActualArguments &, std::optional<int> dimIndex, int rank); 130 131 // Fold and validate a MASK= argument. Return null on error, absent MASK=, or 132 // non-constant MASK=. 133 Constant<LogicalResult> *GetReductionMASK( 134 std::optional<ActualArgument> &maskArg, const ConstantSubscripts &shape, 135 FoldingContext &); 136 137 // Common preprocessing for reduction transformational intrinsic function 138 // folding. If the intrinsic can have DIM= &/or MASK= arguments, extract 139 // and check them. If a MASK= is present, apply it to the array data and 140 // substitute replacement values for elements corresponding to .FALSE. in 141 // the mask. If the result is present, the intrinsic call can be folded. 142 template <typename T> struct ArrayAndMask { 143 Constant<T> array; 144 Constant<LogicalResult> mask; 145 }; 146 template <typename T> 147 static std::optional<ArrayAndMask<T>> ProcessReductionArgs( 148 FoldingContext &context, ActualArguments &arg, std::optional<int> &dim, 149 int arrayIndex, std::optional<int> dimIndex = std::nullopt, 150 std::optional<int> maskIndex = std::nullopt) { 151 if (arg.empty()) { 152 return std::nullopt; 153 } 154 Constant<T> *folded{Folder<T>{context}.Folding(arg[arrayIndex])}; 155 if (!folded || folded->Rank() < 1) { 156 return std::nullopt; 157 } 158 if (!CheckReductionDIM(dim, context, arg, dimIndex, folded->Rank())) { 159 return std::nullopt; 160 } 161 std::size_t n{folded->size()}; 162 std::vector<Scalar<LogicalResult>> maskElement; 163 if (maskIndex && static_cast<std::size_t>(*maskIndex) < arg.size() && 164 arg[*maskIndex]) { 165 if (const Constant<LogicalResult> *origMask{ 166 GetReductionMASK(arg[*maskIndex], folded->shape(), context)}) { 167 if (auto scalarMask{origMask->GetScalarValue()}) { 168 maskElement = 169 std::vector<Scalar<LogicalResult>>(n, scalarMask->IsTrue()); 170 } else { 171 maskElement = origMask->values(); 172 } 173 } else { 174 return std::nullopt; 175 } 176 } else { 177 maskElement = std::vector<Scalar<LogicalResult>>(n, true); 178 } 179 return ArrayAndMask<T>{Constant<T>(*folded), 180 Constant<LogicalResult>{ 181 std::move(maskElement), ConstantSubscripts{folded->shape()}}}; 182 } 183 184 // Generalized reduction to an array of one dimension fewer (w/ DIM=) 185 // or to a scalar (w/o DIM=). The ACCUMULATOR type must define 186 // operator()(Scalar<T> &, const ConstantSubscripts &, bool first) 187 // and Done(Scalar<T> &). 188 template <typename T, typename ACCUMULATOR, typename ARRAY> 189 static Constant<T> DoReduction(const Constant<ARRAY> &array, 190 const Constant<LogicalResult> &mask, std::optional<int> &dim, 191 const Scalar<T> &identity, ACCUMULATOR &accumulator) { 192 ConstantSubscripts at{array.lbounds()}; 193 ConstantSubscripts maskAt{mask.lbounds()}; 194 std::vector<typename Constant<T>::Element> elements; 195 ConstantSubscripts resultShape; // empty -> scalar 196 if (dim) { // DIM= is present, so result is an array 197 resultShape = array.shape(); 198 resultShape.erase(resultShape.begin() + (*dim - 1)); 199 ConstantSubscript dimExtent{array.shape().at(*dim - 1)}; 200 CHECK(dimExtent == mask.shape().at(*dim - 1)); 201 ConstantSubscript &dimAt{at[*dim - 1]}; 202 ConstantSubscript dimLbound{dimAt}; 203 ConstantSubscript &maskDimAt{maskAt[*dim - 1]}; 204 ConstantSubscript maskDimLbound{maskDimAt}; 205 for (auto n{GetSize(resultShape)}; n-- > 0; 206 array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) { 207 elements.push_back(identity); 208 if (dimExtent > 0) { 209 dimAt = dimLbound; 210 maskDimAt = maskDimLbound; 211 bool firstUnmasked{true}; 212 for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt, ++maskDimAt) { 213 if (mask.At(maskAt).IsTrue()) { 214 accumulator(elements.back(), at, firstUnmasked); 215 firstUnmasked = false; 216 } 217 } 218 --dimAt, --maskDimAt; 219 } 220 accumulator.Done(elements.back()); 221 } 222 } else { // no DIM=, result is scalar 223 elements.push_back(identity); 224 bool firstUnmasked{true}; 225 for (auto n{array.size()}; n-- > 0; 226 array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) { 227 if (mask.At(maskAt).IsTrue()) { 228 accumulator(elements.back(), at, firstUnmasked); 229 firstUnmasked = false; 230 } 231 } 232 accumulator.Done(elements.back()); 233 } 234 if constexpr (T::category == TypeCategory::Character) { 235 return {static_cast<ConstantSubscript>(identity.size()), 236 std::move(elements), std::move(resultShape)}; 237 } else { 238 return {std::move(elements), std::move(resultShape)}; 239 } 240 } 241 242 // MAXVAL & MINVAL 243 template <typename T, bool ABS = false> class MaxvalMinvalAccumulator { 244 public: 245 MaxvalMinvalAccumulator( 246 RelationalOperator opr, FoldingContext &context, const Constant<T> &array) 247 : opr_{opr}, context_{context}, array_{array} {}; 248 void operator()(Scalar<T> &element, const ConstantSubscripts &at, 249 [[maybe_unused]] bool firstUnmasked) const { 250 auto aAt{array_.At(at)}; 251 if constexpr (ABS) { 252 aAt = aAt.ABS(); 253 } 254 if constexpr (T::category == TypeCategory::Real) { 255 if (firstUnmasked || element.IsNotANumber()) { 256 // Return NaN if and only if all unmasked elements are NaNs and 257 // at least one unmasked element is visible. 258 element = aAt; 259 return; 260 } 261 } 262 Expr<LogicalResult> test{PackageRelation( 263 opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})}; 264 auto folded{GetScalarConstantValue<LogicalResult>( 265 test.Rewrite(context_, std::move(test)))}; 266 CHECK(folded.has_value()); 267 if (folded->IsTrue()) { 268 element = aAt; 269 } 270 } 271 void Done(Scalar<T> &) const {} 272 273 private: 274 RelationalOperator opr_; 275 FoldingContext &context_; 276 const Constant<T> &array_; 277 }; 278 279 template <typename T> 280 static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref, 281 RelationalOperator opr, const Scalar<T> &identity) { 282 static_assert(T::category == TypeCategory::Integer || 283 T::category == TypeCategory::Unsigned || 284 T::category == TypeCategory::Real || 285 T::category == TypeCategory::Character); 286 std::optional<int> dim; 287 if (std::optional<ArrayAndMask<T>> arrayAndMask{ 288 ProcessReductionArgs<T>(context, ref.arguments(), dim, 289 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { 290 MaxvalMinvalAccumulator<T> accumulator{opr, context, arrayAndMask->array}; 291 return Expr<T>{DoReduction<T>( 292 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}; 293 } 294 return Expr<T>{std::move(ref)}; 295 } 296 297 // PRODUCT 298 template <typename T> class ProductAccumulator { 299 public: 300 ProductAccumulator(const Constant<T> &array) : array_{array} {} 301 void operator()( 302 Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) { 303 if constexpr (T::category == TypeCategory::Integer) { 304 auto prod{element.MultiplySigned(array_.At(at))}; 305 overflow_ |= prod.SignedMultiplicationOverflowed(); 306 element = prod.lower; 307 } else if constexpr (T::category == TypeCategory::Unsigned) { 308 element = element.MultiplyUnsigned(array_.At(at)).lower; 309 } else { // Real & Complex 310 auto prod{element.Multiply(array_.At(at))}; 311 overflow_ |= prod.flags.test(RealFlag::Overflow); 312 element = prod.value; 313 } 314 } 315 bool overflow() const { return overflow_; } 316 void Done(Scalar<T> &) const {} 317 318 private: 319 const Constant<T> &array_; 320 bool overflow_{false}; 321 }; 322 323 template <typename T> 324 static Expr<T> FoldProduct( 325 FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) { 326 static_assert(T::category == TypeCategory::Integer || 327 T::category == TypeCategory::Unsigned || 328 T::category == TypeCategory::Real || 329 T::category == TypeCategory::Complex); 330 std::optional<int> dim; 331 if (std::optional<ArrayAndMask<T>> arrayAndMask{ 332 ProcessReductionArgs<T>(context, ref.arguments(), dim, 333 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { 334 ProductAccumulator accumulator{arrayAndMask->array}; 335 auto result{Expr<T>{DoReduction<T>( 336 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}}; 337 if (accumulator.overflow() && 338 context.languageFeatures().ShouldWarn( 339 common::UsageWarning::FoldingException)) { 340 context.messages().Say(common::UsageWarning::FoldingException, 341 "PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran()); 342 } 343 return result; 344 } 345 return Expr<T>{std::move(ref)}; 346 } 347 348 // SUM 349 template <typename T> class SumAccumulator { 350 using Element = typename Constant<T>::Element; 351 352 public: 353 SumAccumulator(const Constant<T> &array, Rounding rounding) 354 : array_{array}, rounding_{rounding} {} 355 void operator()( 356 Element &element, const ConstantSubscripts &at, bool /*first*/) { 357 if constexpr (T::category == TypeCategory::Integer) { 358 auto sum{element.AddSigned(array_.At(at))}; 359 overflow_ |= sum.overflow; 360 element = sum.value; 361 } else if constexpr (T::category == TypeCategory::Unsigned) { 362 element = element.AddUnsigned(array_.At(at)).value; 363 } else { // Real & Complex: use Kahan summation 364 auto next{array_.At(at).Subtract(correction_, rounding_)}; 365 overflow_ |= next.flags.test(RealFlag::Overflow); 366 auto sum{element.Add(next.value, rounding_)}; 367 overflow_ |= sum.flags.test(RealFlag::Overflow); 368 // correction = (sum - element) - next; algebraically zero 369 correction_ = sum.value.Subtract(element, rounding_) 370 .value.Subtract(next.value, rounding_) 371 .value; 372 element = sum.value; 373 } 374 } 375 bool overflow() const { return overflow_; } 376 void Done([[maybe_unused]] Element &element) { 377 if constexpr (T::category != TypeCategory::Integer && 378 T::category != TypeCategory::Unsigned) { 379 auto corrected{element.Add(correction_, rounding_)}; 380 overflow_ |= corrected.flags.test(RealFlag::Overflow); 381 correction_ = Scalar<T>{}; 382 element = corrected.value; 383 } 384 } 385 386 private: 387 const Constant<T> &array_; 388 Rounding rounding_; 389 bool overflow_{false}; 390 Element correction_{}; 391 }; 392 393 template <typename T> 394 static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) { 395 static_assert(T::category == TypeCategory::Integer || 396 T::category == TypeCategory::Unsigned || 397 T::category == TypeCategory::Real || 398 T::category == TypeCategory::Complex); 399 using Element = typename Constant<T>::Element; 400 std::optional<int> dim; 401 Element identity{}; 402 if (std::optional<ArrayAndMask<T>> arrayAndMask{ 403 ProcessReductionArgs<T>(context, ref.arguments(), dim, 404 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { 405 SumAccumulator accumulator{ 406 arrayAndMask->array, context.targetCharacteristics().roundingMode()}; 407 auto result{Expr<T>{DoReduction<T>( 408 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}}; 409 if (accumulator.overflow() && 410 context.languageFeatures().ShouldWarn( 411 common::UsageWarning::FoldingException)) { 412 context.messages().Say(common::UsageWarning::FoldingException, 413 "SUM() of %s data overflowed"_warn_en_US, T::AsFortran()); 414 } 415 return result; 416 } 417 return Expr<T>{std::move(ref)}; 418 } 419 420 // Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY 421 template <typename T> class OperationAccumulator { 422 public: 423 OperationAccumulator(const Constant<T> &array, 424 Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const) 425 : array_{array}, operation_{operation} {} 426 void operator()( 427 Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) { 428 element = (element.*operation_)(array_.At(at)); 429 } 430 void Done(Scalar<T> &) const {} 431 432 private: 433 const Constant<T> &array_; 434 Scalar<T> (Scalar<T>::*operation_)(const Scalar<T> &) const; 435 }; 436 437 } // namespace Fortran::evaluate 438 #endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_ 439