xref: /llvm-project/flang/lib/Evaluate/fold-matmul.h (revision fc97d2e68b03bc2979395e84b645e5b3ba35aecd)
1 //===-- lib/Evaluate/fold-matmul.h ----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef FORTRAN_EVALUATE_FOLD_MATMUL_H_
10 #define FORTRAN_EVALUATE_FOLD_MATMUL_H_
11 
12 #include "fold-implementation.h"
13 
14 namespace Fortran::evaluate {
15 
16 template <typename T>
17 static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
18   using Element = typename Constant<T>::Element;
19   auto args{funcRef.arguments()};
20   CHECK(args.size() == 2);
21   Folder<T> folder{context};
22   Constant<T> *ma{folder.Folding(args[0])};
23   Constant<T> *mb{folder.Folding(args[1])};
24   if (!ma || !mb) {
25     return Expr<T>{std::move(funcRef)};
26   }
27   CHECK(ma->Rank() >= 1 && ma->Rank() <= 2 && mb->Rank() >= 1 &&
28       mb->Rank() <= 2 && (ma->Rank() == 2 || mb->Rank() == 2));
29   ConstantSubscript commonExtent{ma->shape().back()};
30   if (mb->shape().front() != commonExtent) {
31     context.messages().Say(
32         "Arguments to MATMUL have distinct extents %zd and %zd on their last and first dimensions"_err_en_US,
33         commonExtent, mb->shape().front());
34     return MakeInvalidIntrinsic(std::move(funcRef));
35   }
36   ConstantSubscript rows{ma->Rank() == 1 ? 1 : ma->shape()[0]};
37   ConstantSubscript columns{mb->Rank() == 1 ? 1 : mb->shape()[1]};
38   std::vector<Element> elements;
39   elements.reserve(rows * columns);
40   bool overflow{false};
41   [[maybe_unused]] const auto &rounding{
42       context.targetCharacteristics().roundingMode()};
43   // result(j,k) = SUM(A(j,:) * B(:,k))
44   for (ConstantSubscript ci{0}; ci < columns; ++ci) {
45     for (ConstantSubscript ri{0}; ri < rows; ++ri) {
46       ConstantSubscripts aAt{ma->lbounds()};
47       if (ma->Rank() == 2) {
48         aAt[0] += ri;
49       }
50       ConstantSubscripts bAt{mb->lbounds()};
51       if (mb->Rank() == 2) {
52         bAt[1] += ci;
53       }
54       Element sum{};
55       [[maybe_unused]] Element correction{};
56       for (ConstantSubscript j{0}; j < commonExtent; ++j) {
57         Element aElt{ma->At(aAt)};
58         Element bElt{mb->At(bAt)};
59         if constexpr (T::category == TypeCategory::Real ||
60             T::category == TypeCategory::Complex) {
61           auto product{aElt.Multiply(bElt)};
62           overflow |= product.flags.test(RealFlag::Overflow);
63           if constexpr (useKahanSummation) {
64             auto next{product.value.Subtract(correction, rounding)};
65             overflow |= next.flags.test(RealFlag::Overflow);
66             auto added{sum.Add(next.value, rounding)};
67             overflow |= added.flags.test(RealFlag::Overflow);
68             correction = added.value.Subtract(sum, rounding)
69                              .value.Subtract(next.value, rounding)
70                              .value;
71             sum = std::move(added.value);
72           } else {
73             auto added{sum.Add(product.value)};
74             overflow |= added.flags.test(RealFlag::Overflow);
75             sum = std::move(added.value);
76           }
77         } else if constexpr (T::category == TypeCategory::Integer) {
78           auto product{aElt.MultiplySigned(bElt)};
79           overflow |= product.SignedMultiplicationOverflowed();
80           auto added{sum.AddSigned(product.lower)};
81           overflow |= added.overflow;
82           sum = std::move(added.value);
83         } else if constexpr (T::category == TypeCategory::Unsigned) {
84           sum = sum.AddUnsigned(aElt.MultiplyUnsigned(bElt).lower).value;
85         } else {
86           static_assert(T::category == TypeCategory::Logical);
87           sum = sum.OR(aElt.AND(bElt));
88         }
89         ++aAt.back();
90         ++bAt.front();
91       }
92       elements.push_back(sum);
93     }
94   }
95   if (overflow &&
96       context.languageFeatures().ShouldWarn(
97           common::UsageWarning::FoldingException)) {
98     context.messages().Say(common::UsageWarning::FoldingException,
99         "MATMUL of %s data overflowed during computation"_warn_en_US,
100         T::AsFortran());
101   }
102   ConstantSubscripts shape;
103   if (ma->Rank() == 2) {
104     shape.push_back(rows);
105   }
106   if (mb->Rank() == 2) {
107     shape.push_back(columns);
108   }
109   return Expr<T>{Constant<T>{std::move(elements), std::move(shape)}};
110 }
111 } // namespace Fortran::evaluate
112 #endif // FORTRAN_EVALUATE_FOLD_MATMUL_H_
113