//===-- runtime/matmul.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Implements all forms of MATMUL (Fortran 2018 16.9.124)
//
// There are two main entry points; one establishes a descriptor for the
// result and allocates it, and the other expects a result descriptor that
// points to existing storage.
//
// This implementation must handle all combinations of numeric types and
// kinds (100 - 165 cases depending on the target), plus all combinations
// of logical kinds (16).  A single template undergoes many instantiations
// to cover all of the valid possibilities.
//
// Places where BLAS routines could be called are marked as TODO items.
215e1421b2Speter klausler 22830c0b90SPeter Klausler #include "flang/Runtime/matmul.h" 235e1421b2Speter klausler #include "terminator.h" 245e1421b2Speter klausler #include "tools.h" 2571e0261fSSlava Zakharin #include "flang/Common/optional.h" 26a5a493e1SPeter Klausler #include "flang/Runtime/c-or-cpp.h" 27830c0b90SPeter Klausler #include "flang/Runtime/cpp-type.h" 28830c0b90SPeter Klausler #include "flang/Runtime/descriptor.h" 29a5a493e1SPeter Klausler #include <cstring> 305e1421b2Speter klausler 31dd220853SSlava Zakharin namespace { 32dd220853SSlava Zakharin using namespace Fortran::runtime; 335e1421b2Speter klausler 34a5a493e1SPeter Klausler // General accumulator for any type and stride; this is not used for 35a5a493e1SPeter Klausler // contiguous numeric cases. 365e1421b2Speter klausler template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 375e1421b2Speter klausler class Accumulator { 385e1421b2Speter klausler public: 39a5a493e1SPeter Klausler using Result = AccumulationType<RCAT, RKIND>; 40b4b23ff7SSlava Zakharin RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y) 41b4b23ff7SSlava Zakharin : x_{x}, y_{y} {} 42b4b23ff7SSlava Zakharin RT_API_ATTRS void Accumulate( 43b4b23ff7SSlava Zakharin const SubscriptValue xAt[], const SubscriptValue yAt[]) { 445e1421b2Speter klausler if constexpr (RCAT == TypeCategory::Logical) { 455e1421b2Speter klausler sum_ = sum_ || 465e1421b2Speter klausler (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt)); 475e1421b2Speter klausler } else { 485e1421b2Speter klausler sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) * 495e1421b2Speter klausler static_cast<Result>(*y_.Element<YT>(yAt)); 505e1421b2Speter klausler } 515e1421b2Speter klausler } 52b4b23ff7SSlava Zakharin RT_API_ATTRS Result GetResult() const { return sum_; } 535e1421b2Speter klausler 545e1421b2Speter klausler private: 555e1421b2Speter klausler const Descriptor &x_, &y_; 565e1421b2Speter klausler Result sum_{}; 575e1421b2Speter 
klausler }; 585e1421b2Speter klausler 59a5a493e1SPeter Klausler // Contiguous numeric matrix*matrix multiplication 60a5a493e1SPeter Klausler // matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols) 61a5a493e1SPeter Klausler // Straightforward algorithm: 62a5a493e1SPeter Klausler // DO 1 I = 1, NROWS 63a5a493e1SPeter Klausler // DO 1 J = 1, NCOLS 64a5a493e1SPeter Klausler // RES(I,J) = 0 65a5a493e1SPeter Klausler // DO 1 K = 1, N 66a5a493e1SPeter Klausler // 1 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) 67a5a493e1SPeter Klausler // With loop distribution and transposition to avoid the inner sum 68a5a493e1SPeter Klausler // reduction and to avoid non-unit strides: 69a5a493e1SPeter Klausler // DO 1 I = 1, NROWS 70a5a493e1SPeter Klausler // DO 1 J = 1, NCOLS 71a5a493e1SPeter Klausler // 1 RES(I,J) = 0 72a5a493e1SPeter Klausler // DO 2 K = 1, N 73a5a493e1SPeter Klausler // DO 2 J = 1, NCOLS 74a5a493e1SPeter Klausler // DO 2 I = 1, NROWS 75a5a493e1SPeter Klausler // 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! 
loop-invariant last term 764d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT, 774d977174SSlava Zakharin bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS> 78b4b23ff7SSlava Zakharin inline RT_API_ATTRS void MatrixTimesMatrix( 79b4b23ff7SSlava Zakharin CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 80b4b23ff7SSlava Zakharin SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, 81b4b23ff7SSlava Zakharin SubscriptValue n, std::size_t xColumnByteStride = 0, 824d977174SSlava Zakharin std::size_t yColumnByteStride = 0) { 83a5a493e1SPeter Klausler using ResultType = CppTypeFor<RCAT, RKIND>; 84a5a493e1SPeter Klausler std::memset(product, 0, rows * cols * sizeof *product); 85a5a493e1SPeter Klausler const XT *RESTRICT xp0{x}; 86a5a493e1SPeter Klausler for (SubscriptValue k{0}; k < n; ++k) { 87a5a493e1SPeter Klausler ResultType *RESTRICT p{product}; 88a5a493e1SPeter Klausler for (SubscriptValue j{0}; j < cols; ++j) { 89a5a493e1SPeter Klausler const XT *RESTRICT xp{xp0}; 904d977174SSlava Zakharin ResultType yv; 914d977174SSlava Zakharin if constexpr (!Y_HAS_STRIDED_COLUMNS) { 924d977174SSlava Zakharin yv = static_cast<ResultType>(y[k + j * n]); 934d977174SSlava Zakharin } else { 944d977174SSlava Zakharin yv = static_cast<ResultType>(reinterpret_cast<const YT *>( 954d977174SSlava Zakharin reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]); 964d977174SSlava Zakharin } 97a5a493e1SPeter Klausler for (SubscriptValue i{0}; i < rows; ++i) { 98a5a493e1SPeter Klausler *p++ += static_cast<ResultType>(*xp++) * yv; 99a5a493e1SPeter Klausler } 100a5a493e1SPeter Klausler } 1014d977174SSlava Zakharin if constexpr (!X_HAS_STRIDED_COLUMNS) { 102a5a493e1SPeter Klausler xp0 += rows; 1034d977174SSlava Zakharin } else { 1044d977174SSlava Zakharin xp0 = reinterpret_cast<const XT *>( 1054d977174SSlava Zakharin reinterpret_cast<const char *>(xp0) + xColumnByteStride); 1064d977174SSlava Zakharin } 1074d977174SSlava 
Zakharin } 1084d977174SSlava Zakharin } 1094d977174SSlava Zakharin 1104d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 111b4b23ff7SSlava Zakharin inline RT_API_ATTRS void MatrixTimesMatrixHelper( 112b4b23ff7SSlava Zakharin CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 113b4b23ff7SSlava Zakharin SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, 11471e0261fSSlava Zakharin SubscriptValue n, Fortran::common::optional<std::size_t> xColumnByteStride, 11571e0261fSSlava Zakharin Fortran::common::optional<std::size_t> yColumnByteStride) { 1164d977174SSlava Zakharin if (!xColumnByteStride) { 1174d977174SSlava Zakharin if (!yColumnByteStride) { 1184d977174SSlava Zakharin MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, false>( 1194d977174SSlava Zakharin product, rows, cols, x, y, n); 1204d977174SSlava Zakharin } else { 1214d977174SSlava Zakharin MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, true>( 1224d977174SSlava Zakharin product, rows, cols, x, y, n, 0, *yColumnByteStride); 1234d977174SSlava Zakharin } 1244d977174SSlava Zakharin } else { 1254d977174SSlava Zakharin if (!yColumnByteStride) { 1264d977174SSlava Zakharin MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, false>( 1274d977174SSlava Zakharin product, rows, cols, x, y, n, *xColumnByteStride); 1284d977174SSlava Zakharin } else { 1294d977174SSlava Zakharin MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, true>( 1304d977174SSlava Zakharin product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride); 1314d977174SSlava Zakharin } 132a5a493e1SPeter Klausler } 133a5a493e1SPeter Klausler } 134a5a493e1SPeter Klausler 135a5a493e1SPeter Klausler // Contiguous numeric matrix*vector multiplication 136a5a493e1SPeter Klausler // matrix(rows,n) * column vector(n) -> column vector(rows) 137a5a493e1SPeter Klausler // Straightforward algorithm: 138a5a493e1SPeter Klausler // DO 1 J = 1, NROWS 139a5a493e1SPeter Klausler // RES(J) = 0 140a5a493e1SPeter Klausler // DO 
1 K = 1, N 141a5a493e1SPeter Klausler // 1 RES(J) = RES(J) + X(J,K)*Y(K) 142a5a493e1SPeter Klausler // With loop distribution and transposition to avoid the inner 143a5a493e1SPeter Klausler // sum reduction and to avoid non-unit strides: 144a5a493e1SPeter Klausler // DO 1 J = 1, NROWS 145a5a493e1SPeter Klausler // 1 RES(J) = 0 146a5a493e1SPeter Klausler // DO 2 K = 1, N 147a5a493e1SPeter Klausler // DO 2 J = 1, NROWS 148a5a493e1SPeter Klausler // 2 RES(J) = RES(J) + X(J,K)*Y(K) 1494d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT, 1504d977174SSlava Zakharin bool X_HAS_STRIDED_COLUMNS> 151b4b23ff7SSlava Zakharin inline RT_API_ATTRS void MatrixTimesVector( 152b4b23ff7SSlava Zakharin CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 153b4b23ff7SSlava Zakharin SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y, 154b4b23ff7SSlava Zakharin std::size_t xColumnByteStride = 0) { 155a5a493e1SPeter Klausler using ResultType = CppTypeFor<RCAT, RKIND>; 156a5a493e1SPeter Klausler std::memset(product, 0, rows * sizeof *product); 1574d977174SSlava Zakharin [[maybe_unused]] const XT *RESTRICT xp0{x}; 158a5a493e1SPeter Klausler for (SubscriptValue k{0}; k < n; ++k) { 159a5a493e1SPeter Klausler ResultType *RESTRICT p{product}; 160a5a493e1SPeter Klausler auto yv{static_cast<ResultType>(*y++)}; 161a5a493e1SPeter Klausler for (SubscriptValue j{0}; j < rows; ++j) { 162a5a493e1SPeter Klausler *p++ += static_cast<ResultType>(*x++) * yv; 163a5a493e1SPeter Klausler } 1644d977174SSlava Zakharin if constexpr (X_HAS_STRIDED_COLUMNS) { 1654d977174SSlava Zakharin xp0 = reinterpret_cast<const XT *>( 1664d977174SSlava Zakharin reinterpret_cast<const char *>(xp0) + xColumnByteStride); 1674d977174SSlava Zakharin x = xp0; 1684d977174SSlava Zakharin } 1694d977174SSlava Zakharin } 1704d977174SSlava Zakharin } 1714d977174SSlava Zakharin 1724d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 
173b4b23ff7SSlava Zakharin inline RT_API_ATTRS void MatrixTimesVectorHelper( 174b4b23ff7SSlava Zakharin CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 175b4b23ff7SSlava Zakharin SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y, 17671e0261fSSlava Zakharin Fortran::common::optional<std::size_t> xColumnByteStride) { 1774d977174SSlava Zakharin if (!xColumnByteStride) { 1784d977174SSlava Zakharin MatrixTimesVector<RCAT, RKIND, XT, YT, false>(product, rows, n, x, y); 1794d977174SSlava Zakharin } else { 1804d977174SSlava Zakharin MatrixTimesVector<RCAT, RKIND, XT, YT, true>( 1814d977174SSlava Zakharin product, rows, n, x, y, *xColumnByteStride); 182a5a493e1SPeter Klausler } 183a5a493e1SPeter Klausler } 184a5a493e1SPeter Klausler 185a5a493e1SPeter Klausler // Contiguous numeric vector*matrix multiplication 186a5a493e1SPeter Klausler // row vector(n) * matrix(n,cols) -> row vector(cols) 187a5a493e1SPeter Klausler // Straightforward algorithm: 188a5a493e1SPeter Klausler // DO 1 J = 1, NCOLS 189a5a493e1SPeter Klausler // RES(J) = 0 190a5a493e1SPeter Klausler // DO 1 K = 1, N 191a5a493e1SPeter Klausler // 1 RES(J) = RES(J) + X(K)*Y(K,J) 192a5a493e1SPeter Klausler // With loop distribution and transposition to avoid the inner 193a5a493e1SPeter Klausler // sum reduction and one non-unit stride (the other remains): 194a5a493e1SPeter Klausler // DO 1 J = 1, NCOLS 195a5a493e1SPeter Klausler // 1 RES(J) = 0 196a5a493e1SPeter Klausler // DO 2 K = 1, N 197a5a493e1SPeter Klausler // DO 2 J = 1, NCOLS 198a5a493e1SPeter Klausler // 2 RES(J) = RES(J) + X(K)*Y(K,J) 1994d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT, 2004d977174SSlava Zakharin bool Y_HAS_STRIDED_COLUMNS> 201b4b23ff7SSlava Zakharin inline RT_API_ATTRS void VectorTimesMatrix( 202b4b23ff7SSlava Zakharin CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n, 203b4b23ff7SSlava Zakharin SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, 
204b4b23ff7SSlava Zakharin std::size_t yColumnByteStride = 0) { 205a5a493e1SPeter Klausler using ResultType = CppTypeFor<RCAT, RKIND>; 206a5a493e1SPeter Klausler std::memset(product, 0, cols * sizeof *product); 207a5a493e1SPeter Klausler for (SubscriptValue k{0}; k < n; ++k) { 208a5a493e1SPeter Klausler ResultType *RESTRICT p{product}; 209a5a493e1SPeter Klausler auto xv{static_cast<ResultType>(*x++)}; 210a5a493e1SPeter Klausler const YT *RESTRICT yp{&y[k]}; 211a5a493e1SPeter Klausler for (SubscriptValue j{0}; j < cols; ++j) { 212a5a493e1SPeter Klausler *p++ += xv * static_cast<ResultType>(*yp); 2134d977174SSlava Zakharin if constexpr (!Y_HAS_STRIDED_COLUMNS) { 214a5a493e1SPeter Klausler yp += n; 2154d977174SSlava Zakharin } else { 2164d977174SSlava Zakharin yp = reinterpret_cast<const YT *>( 2174d977174SSlava Zakharin reinterpret_cast<const char *>(yp) + yColumnByteStride); 218a5a493e1SPeter Klausler } 219a5a493e1SPeter Klausler } 220a5a493e1SPeter Klausler } 2214d977174SSlava Zakharin } 2224d977174SSlava Zakharin 2234d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT, 2244d977174SSlava Zakharin bool SPARSE_COLUMNS = false> 225b4b23ff7SSlava Zakharin inline RT_API_ATTRS void VectorTimesMatrixHelper( 226b4b23ff7SSlava Zakharin CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n, 227b4b23ff7SSlava Zakharin SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, 22871e0261fSSlava Zakharin Fortran::common::optional<std::size_t> yColumnByteStride) { 2294d977174SSlava Zakharin if (!yColumnByteStride) { 2304d977174SSlava Zakharin VectorTimesMatrix<RCAT, RKIND, XT, YT, false>(product, n, cols, x, y); 2314d977174SSlava Zakharin } else { 2324d977174SSlava Zakharin VectorTimesMatrix<RCAT, RKIND, XT, YT, true>( 2334d977174SSlava Zakharin product, n, cols, x, y, *yColumnByteStride); 2344d977174SSlava Zakharin } 2354d977174SSlava Zakharin } 236a5a493e1SPeter Klausler 2375e1421b2Speter klausler // Implements an instance of 
MATMUL for given argument types. 2385e1421b2Speter klausler template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT, 2395e1421b2Speter klausler typename YT> 240b4b23ff7SSlava Zakharin static inline RT_API_ATTRS void DoMatmul( 2415e1421b2Speter klausler std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result, 2425e1421b2Speter klausler const Descriptor &x, const Descriptor &y, Terminator &terminator) { 2435e1421b2Speter klausler int xRank{x.rank()}; 2445e1421b2Speter klausler int yRank{y.rank()}; 2455e1421b2Speter klausler int resRank{xRank + yRank - 2}; 2465e1421b2Speter klausler if (xRank * yRank != 2 * resRank) { 2475e1421b2Speter klausler terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank); 2485e1421b2Speter klausler } 2495e1421b2Speter klausler SubscriptValue extent[2]{ 2505e1421b2Speter klausler xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(), 2515e1421b2Speter klausler resRank == 2 ? y.GetDimension(1).Extent() : 0}; 2525e1421b2Speter klausler if constexpr (IS_ALLOCATING) { 2535e1421b2Speter klausler result.Establish( 2545e1421b2Speter klausler RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable); 2555e1421b2Speter klausler for (int j{0}; j < resRank; ++j) { 2565e1421b2Speter klausler result.GetDimension(j).SetBounds(1, extent[j]); 2575e1421b2Speter klausler } 2585e1421b2Speter klausler if (int stat{result.Allocate()}) { 2595e1421b2Speter klausler terminator.Crash( 2605e1421b2Speter klausler "MATMUL: could not allocate memory for result; STAT=%d", stat); 2615e1421b2Speter klausler } 2625e1421b2Speter klausler } else { 2635e1421b2Speter klausler RUNTIME_CHECK(terminator, resRank == result.rank()); 264a5a493e1SPeter Klausler RUNTIME_CHECK( 265a5a493e1SPeter Klausler terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND)); 2665e1421b2Speter klausler RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]); 2675e1421b2Speter klausler 
RUNTIME_CHECK(terminator, 2685e1421b2Speter klausler resRank == 1 || result.GetDimension(1).Extent() == extent[1]); 2695e1421b2Speter klausler } 2705e1421b2Speter klausler SubscriptValue n{x.GetDimension(xRank - 1).Extent()}; 2715e1421b2Speter klausler if (n != y.GetDimension(0).Extent()) { 272e55aa027SPete Steinfeld // At this point, we know that there's a shape error. There are three 273e55aa027SPete Steinfeld // possibilities, x is rank 1, y is rank 1, or both are rank 2. 274e55aa027SPete Steinfeld if (xRank == 1) { 275e55aa027SPete Steinfeld terminator.Crash("MATMUL: unacceptable operand shapes (%jd, %jdx%jd)", 276e55aa027SPete Steinfeld static_cast<std::intmax_t>(n), 277e55aa027SPete Steinfeld static_cast<std::intmax_t>(y.GetDimension(0).Extent()), 278e55aa027SPete Steinfeld static_cast<std::intmax_t>(y.GetDimension(1).Extent())); 279e55aa027SPete Steinfeld } else if (yRank == 1) { 280e55aa027SPete Steinfeld terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jd)", 281e55aa027SPete Steinfeld static_cast<std::intmax_t>(x.GetDimension(0).Extent()), 282e55aa027SPete Steinfeld static_cast<std::intmax_t>(n), 283e55aa027SPete Steinfeld static_cast<std::intmax_t>(y.GetDimension(0).Extent())); 284e55aa027SPete Steinfeld } else { 285f5884fd9SPeter Klausler terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)", 286f5884fd9SPeter Klausler static_cast<std::intmax_t>(x.GetDimension(0).Extent()), 2875e1421b2Speter klausler static_cast<std::intmax_t>(n), 288f5884fd9SPeter Klausler static_cast<std::intmax_t>(y.GetDimension(0).Extent()), 289f5884fd9SPeter Klausler static_cast<std::intmax_t>(y.GetDimension(1).Extent())); 2905e1421b2Speter klausler } 291e55aa027SPete Steinfeld } 292a5a493e1SPeter Klausler using WriteResult = 293a5a493e1SPeter Klausler CppTypeFor<RCAT == TypeCategory::Logical ? 
TypeCategory::Integer : RCAT, 294a5a493e1SPeter Klausler RKIND>; 295a5a493e1SPeter Klausler if constexpr (RCAT != TypeCategory::Logical) { 2964d977174SSlava Zakharin if (x.IsContiguous(1) && y.IsContiguous(1) && 297a5a493e1SPeter Klausler (IS_ALLOCATING || result.IsContiguous())) { 2984d977174SSlava Zakharin // Contiguous numeric matrices (maybe with columns 2994d977174SSlava Zakharin // separated by a stride). 30071e0261fSSlava Zakharin Fortran::common::optional<std::size_t> xColumnByteStride; 3014d977174SSlava Zakharin if (!x.IsContiguous()) { 3024d977174SSlava Zakharin // X's columns are strided. 3034d977174SSlava Zakharin SubscriptValue xAt[2]{}; 3044d977174SSlava Zakharin x.GetLowerBounds(xAt); 3054d977174SSlava Zakharin xAt[1]++; 3064d977174SSlava Zakharin xColumnByteStride = x.SubscriptsToByteOffset(xAt); 3074d977174SSlava Zakharin } 30871e0261fSSlava Zakharin Fortran::common::optional<std::size_t> yColumnByteStride; 3094d977174SSlava Zakharin if (!y.IsContiguous()) { 3104d977174SSlava Zakharin // Y's columns are strided. 3114d977174SSlava Zakharin SubscriptValue yAt[2]{}; 3124d977174SSlava Zakharin y.GetLowerBounds(yAt); 3134d977174SSlava Zakharin yAt[1]++; 3144d977174SSlava Zakharin yColumnByteStride = y.SubscriptsToByteOffset(yAt); 3154d977174SSlava Zakharin } 3164d977174SSlava Zakharin // Note that BLAS GEMM can be used for the strided 3174d977174SSlava Zakharin // columns by setting proper leading dimension size. 3184d977174SSlava Zakharin // This implies that the column stride is divisible 3194d977174SSlava Zakharin // by the element size, which is usually true. 3205e1421b2Speter klausler if (resRank == 2) { // M*M -> M 321a5a493e1SPeter Klausler if (std::is_same_v<XT, YT>) { 3225e1421b2Speter klausler if constexpr (std::is_same_v<XT, float>) { 3235e1421b2Speter klausler // TODO: call BLAS-3 SGEMM 3244d977174SSlava Zakharin // TODO: try using CUTLASS for device. 
3255e1421b2Speter klausler } else if constexpr (std::is_same_v<XT, double>) { 3265e1421b2Speter klausler // TODO: call BLAS-3 DGEMM 327104f3c18SSlava Zakharin } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) { 3285e1421b2Speter klausler // TODO: call BLAS-3 CGEMM 329104f3c18SSlava Zakharin } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) { 3305e1421b2Speter klausler // TODO: call BLAS-3 ZGEMM 3315e1421b2Speter klausler } 3325e1421b2Speter klausler } 3334d977174SSlava Zakharin MatrixTimesMatrixHelper<RCAT, RKIND, XT, YT>( 334a5a493e1SPeter Klausler result.template OffsetElement<WriteResult>(), extent[0], extent[1], 3354d977174SSlava Zakharin x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride, 3364d977174SSlava Zakharin yColumnByteStride); 337a5a493e1SPeter Klausler return; 338a5a493e1SPeter Klausler } else if (xRank == 2) { // M*V -> V 339a5a493e1SPeter Klausler if (std::is_same_v<XT, YT>) { 340a5a493e1SPeter Klausler if constexpr (std::is_same_v<XT, float>) { 341a5a493e1SPeter Klausler // TODO: call BLAS-2 SGEMV(x,y) 342a5a493e1SPeter Klausler } else if constexpr (std::is_same_v<XT, double>) { 343a5a493e1SPeter Klausler // TODO: call BLAS-2 DGEMV(x,y) 344104f3c18SSlava Zakharin } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) { 345a5a493e1SPeter Klausler // TODO: call BLAS-2 CGEMV(x,y) 346104f3c18SSlava Zakharin } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) { 347a5a493e1SPeter Klausler // TODO: call BLAS-2 ZGEMV(x,y) 348a5a493e1SPeter Klausler } 349a5a493e1SPeter Klausler } 3504d977174SSlava Zakharin MatrixTimesVectorHelper<RCAT, RKIND, XT, YT>( 351a5a493e1SPeter Klausler result.template OffsetElement<WriteResult>(), extent[0], n, 3524d977174SSlava Zakharin x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride); 353a5a493e1SPeter Klausler return; 354a5a493e1SPeter Klausler } else { // V*M -> V 355a5a493e1SPeter Klausler if (std::is_same_v<XT, YT>) { 356a5a493e1SPeter 
Klausler if constexpr (std::is_same_v<XT, float>) { 357a5a493e1SPeter Klausler // TODO: call BLAS-2 SGEMV(y,x) 358a5a493e1SPeter Klausler } else if constexpr (std::is_same_v<XT, double>) { 359a5a493e1SPeter Klausler // TODO: call BLAS-2 DGEMV(y,x) 360104f3c18SSlava Zakharin } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) { 361a5a493e1SPeter Klausler // TODO: call BLAS-2 CGEMV(y,x) 362104f3c18SSlava Zakharin } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) { 363a5a493e1SPeter Klausler // TODO: call BLAS-2 ZGEMV(y,x) 364a5a493e1SPeter Klausler } 365a5a493e1SPeter Klausler } 3664d977174SSlava Zakharin VectorTimesMatrixHelper<RCAT, RKIND, XT, YT>( 367a5a493e1SPeter Klausler result.template OffsetElement<WriteResult>(), n, extent[0], 3684d977174SSlava Zakharin x.OffsetElement<XT>(), y.OffsetElement<YT>(), yColumnByteStride); 369a5a493e1SPeter Klausler return; 370a5a493e1SPeter Klausler } 371a5a493e1SPeter Klausler } 372a5a493e1SPeter Klausler } 373a5a493e1SPeter Klausler // General algorithms for LOGICAL and noncontiguity 374a5a493e1SPeter Klausler SubscriptValue xAt[2], yAt[2], resAt[2]; 375a5a493e1SPeter Klausler x.GetLowerBounds(xAt); 376a5a493e1SPeter Klausler y.GetLowerBounds(yAt); 377a5a493e1SPeter Klausler result.GetLowerBounds(resAt); 378a5a493e1SPeter Klausler if (resRank == 2) { // M*M -> M 3795e1421b2Speter klausler SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]}; 3805e1421b2Speter klausler for (SubscriptValue i{0}; i < extent[0]; ++i) { 3815e1421b2Speter klausler for (SubscriptValue j{0}; j < extent[1]; ++j) { 3825e1421b2Speter klausler Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 3835e1421b2Speter klausler yAt[1] = y1 + j; 3845e1421b2Speter klausler for (SubscriptValue k{0}; k < n; ++k) { 3855e1421b2Speter klausler xAt[1] = x1 + k; 3865e1421b2Speter klausler yAt[0] = y0 + k; 3875e1421b2Speter klausler accumulator.Accumulate(xAt, yAt); 3885e1421b2Speter klausler } 3895e1421b2Speter klausler 
resAt[1] = res1 + j; 3905e1421b2Speter klausler *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); 3915e1421b2Speter klausler } 3925e1421b2Speter klausler ++resAt[0]; 3935e1421b2Speter klausler ++xAt[0]; 3945e1421b2Speter klausler } 395a5a493e1SPeter Klausler } else if (xRank == 2) { // M*V -> V 3965e1421b2Speter klausler SubscriptValue x1{xAt[1]}, y0{yAt[0]}; 3975e1421b2Speter klausler for (SubscriptValue j{0}; j < extent[0]; ++j) { 3985e1421b2Speter klausler Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 3995e1421b2Speter klausler for (SubscriptValue k{0}; k < n; ++k) { 4005e1421b2Speter klausler xAt[1] = x1 + k; 4015e1421b2Speter klausler yAt[0] = y0 + k; 4025e1421b2Speter klausler accumulator.Accumulate(xAt, yAt); 4035e1421b2Speter klausler } 4045e1421b2Speter klausler *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); 4055e1421b2Speter klausler ++resAt[0]; 4065e1421b2Speter klausler ++xAt[0]; 4075e1421b2Speter klausler } 4085e1421b2Speter klausler } else { // V*M -> V 4095e1421b2Speter klausler SubscriptValue x0{xAt[0]}, y0{yAt[0]}; 4105e1421b2Speter klausler for (SubscriptValue j{0}; j < extent[0]; ++j) { 4115e1421b2Speter klausler Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 4125e1421b2Speter klausler for (SubscriptValue k{0}; k < n; ++k) { 4135e1421b2Speter klausler xAt[0] = x0 + k; 4145e1421b2Speter klausler yAt[0] = y0 + k; 4155e1421b2Speter klausler accumulator.Accumulate(xAt, yAt); 4165e1421b2Speter klausler } 4175e1421b2Speter klausler *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); 4185e1421b2Speter klausler ++resAt[0]; 4195e1421b2Speter klausler ++yAt[1]; 4205e1421b2Speter klausler } 4215e1421b2Speter klausler } 4225e1421b2Speter klausler } 4235e1421b2Speter klausler 424dd220853SSlava Zakharin template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT, 425dd220853SSlava Zakharin int YKIND> 426dd220853SSlava Zakharin struct MatmulHelper { 427dd220853SSlava 
Zakharin using ResultDescriptor = 428dd220853SSlava Zakharin std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>; 429dd220853SSlava Zakharin RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x, 430dd220853SSlava Zakharin const Descriptor &y, const char *sourceFile, int line) const { 431dd220853SSlava Zakharin Terminator terminator{sourceFile, line}; 432dd220853SSlava Zakharin auto xCatKind{x.type().GetCategoryAndKind()}; 433dd220853SSlava Zakharin auto yCatKind{y.type().GetCategoryAndKind()}; 434dd220853SSlava Zakharin RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); 435*fc97d2e6SPeter Klausler RUNTIME_CHECK(terminator, 436*fc97d2e6SPeter Klausler (xCatKind->first == XCAT && yCatKind->first == YCAT) || 437*fc97d2e6SPeter Klausler (XCAT == TypeCategory::Integer && YCAT == TypeCategory::Integer && 438*fc97d2e6SPeter Klausler ((xCatKind->first == TypeCategory::Integer || 439*fc97d2e6SPeter Klausler xCatKind->first == TypeCategory::Unsigned) && 440*fc97d2e6SPeter Klausler (yCatKind->first == TypeCategory::Integer || 441*fc97d2e6SPeter Klausler yCatKind->first == TypeCategory::Unsigned)))); 442dd220853SSlava Zakharin if constexpr (constexpr auto resultType{ 443dd220853SSlava Zakharin GetResultType(XCAT, XKIND, YCAT, YKIND)}) { 444dd220853SSlava Zakharin return DoMatmul<IS_ALLOCATING, resultType->first, resultType->second, 445dd220853SSlava Zakharin CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>( 446dd220853SSlava Zakharin result, x, y, terminator); 447dd220853SSlava Zakharin } 448dd220853SSlava Zakharin terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))", 449dd220853SSlava Zakharin static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND); 450dd220853SSlava Zakharin } 451dd220853SSlava Zakharin }; 452dd220853SSlava Zakharin } // namespace 453dd220853SSlava Zakharin 454dd220853SSlava Zakharin namespace Fortran::runtime { 4555e1421b2Speter klausler extern "C" { 45676facde3SSlava Zakharin 
RT_EXT_API_GROUP_BEGIN 45776facde3SSlava Zakharin 458dd220853SSlava Zakharin #define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \ 459dd220853SSlava Zakharin void RTDEF(Matmul##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \ 460dd220853SSlava Zakharin const Descriptor &x, const Descriptor &y, const char *sourceFile, \ 461dd220853SSlava Zakharin int line) { \ 462dd220853SSlava Zakharin MatmulHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \ 463dd220853SSlava Zakharin YKIND>{}(result, x, y, sourceFile, line); \ 464dd220853SSlava Zakharin } 465dd220853SSlava Zakharin 466dd220853SSlava Zakharin #define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \ 467dd220853SSlava Zakharin void RTDEF(MatmulDirect##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \ 468dd220853SSlava Zakharin const Descriptor &x, const Descriptor &y, const char *sourceFile, \ 469dd220853SSlava Zakharin int line) { \ 470dd220853SSlava Zakharin MatmulHelper<false, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \ 471dd220853SSlava Zakharin YKIND>{}(result, x, y, sourceFile, line); \ 472dd220853SSlava Zakharin } 473dd220853SSlava Zakharin 4748ce1aed5SSlava Zakharin #define MATMUL_FORCE_ALL_TYPES 0 4758ce1aed5SSlava Zakharin 476dd220853SSlava Zakharin #include "flang/Runtime/matmul-instances.inc" 477dd220853SSlava Zakharin 47876facde3SSlava Zakharin RT_EXT_API_GROUP_END 4795e1421b2Speter klausler } // extern "C" 4805e1421b2Speter klausler } // namespace Fortran::runtime 481