1 //===-- lib/Evaluate/fold-matmul.h ----------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef FORTRAN_EVALUATE_FOLD_MATMUL_H_ 10 #define FORTRAN_EVALUATE_FOLD_MATMUL_H_ 11 12 #include "fold-implementation.h" 13 14 namespace Fortran::evaluate { 15 16 template <typename T> 17 static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) { 18 using Element = typename Constant<T>::Element; 19 auto args{funcRef.arguments()}; 20 CHECK(args.size() == 2); 21 Folder<T> folder{context}; 22 Constant<T> *ma{folder.Folding(args[0])}; 23 Constant<T> *mb{folder.Folding(args[1])}; 24 if (!ma || !mb) { 25 return Expr<T>{std::move(funcRef)}; 26 } 27 CHECK(ma->Rank() >= 1 && ma->Rank() <= 2 && mb->Rank() >= 1 && 28 mb->Rank() <= 2 && (ma->Rank() == 2 || mb->Rank() == 2)); 29 ConstantSubscript commonExtent{ma->shape().back()}; 30 if (mb->shape().front() != commonExtent) { 31 context.messages().Say( 32 "Arguments to MATMUL have distinct extents %zd and %zd on their last and first dimensions"_err_en_US, 33 commonExtent, mb->shape().front()); 34 return MakeInvalidIntrinsic(std::move(funcRef)); 35 } 36 ConstantSubscript rows{ma->Rank() == 1 ? 1 : ma->shape()[0]}; 37 ConstantSubscript columns{mb->Rank() == 1 ? 1 : mb->shape()[1]}; 38 std::vector<Element> elements; 39 elements.reserve(rows * columns); 40 bool overflow{false}; 41 [[maybe_unused]] const auto &rounding{ 42 context.targetCharacteristics().roundingMode()}; 43 // result(j,k) = SUM(A(j,:) * B(:,k)) 44 for (ConstantSubscript ci{0}; ci < columns; ++ci) { 45 for (ConstantSubscript ri{0}; ri < rows; ++ri) { 46 ConstantSubscripts aAt{ma->lbounds()}; 47 if (ma->Rank() == 2) { 48 aAt[0] += ri; 49 } 50 ConstantSubscripts bAt{mb->lbounds()}; 51 if (mb->Rank() == 2) { 52 bAt[1] += ci; 53 } 54 Element sum{}; 55 [[maybe_unused]] Element correction{}; 56 for (ConstantSubscript j{0}; j < commonExtent; ++j) { 57 Element aElt{ma->At(aAt)}; 58 Element bElt{mb->At(bAt)}; 59 if constexpr (T::category == TypeCategory::Real || 60 T::category == TypeCategory::Complex) { 61 auto product{aElt.Multiply(bElt)}; 62 overflow |= product.flags.test(RealFlag::Overflow); 63 if constexpr (useKahanSummation) { 64 auto next{product.value.Subtract(correction, rounding)}; 65 overflow |= next.flags.test(RealFlag::Overflow); 66 auto added{sum.Add(next.value, rounding)}; 67 overflow |= added.flags.test(RealFlag::Overflow); 68 correction = added.value.Subtract(sum, rounding) 69 .value.Subtract(next.value, rounding) 70 .value; 71 sum = std::move(added.value); 72 } else { 73 auto added{sum.Add(product.value)}; 74 overflow |= added.flags.test(RealFlag::Overflow); 75 sum = std::move(added.value); 76 } 77 } else if constexpr (T::category == TypeCategory::Integer) { 78 auto product{aElt.MultiplySigned(bElt)}; 79 overflow |= product.SignedMultiplicationOverflowed(); 80 auto added{sum.AddSigned(product.lower)}; 81 overflow |= added.overflow; 82 sum = std::move(added.value); 83 } else if constexpr (T::category == TypeCategory::Unsigned) { 84 sum = sum.AddUnsigned(aElt.MultiplyUnsigned(bElt).lower).value; 85 } else { 86 static_assert(T::category == TypeCategory::Logical); 87 sum = sum.OR(aElt.AND(bElt)); 88 } 89 ++aAt.back(); 90 ++bAt.front(); 91 } 92 elements.push_back(sum); 93 } 94 } 95 if (overflow && 96 context.languageFeatures().ShouldWarn( 97 common::UsageWarning::FoldingException)) { 98 context.messages().Say(common::UsageWarning::FoldingException, 99 "MATMUL of %s data overflowed during computation"_warn_en_US, 100 T::AsFortran()); 101 } 102 ConstantSubscripts shape; 103 if (ma->Rank() == 2) { 104 shape.push_back(rows); 105 } 106 if (mb->Rank() == 2) { 107 shape.push_back(columns); 108 } 109 return Expr<T>{Constant<T>{std::move(elements), std::move(shape)}}; 110 } 111 } // namespace Fortran::evaluate 112 #endif // FORTRAN_EVALUATE_FOLD_MATMUL_H_ 113