xref: /llvm-project/flang/runtime/matmul.cpp (revision fc97d2e68b03bc2979395e84b645e5b3ba35aecd)
15e1421b2Speter klausler //===-- runtime/matmul.cpp ------------------------------------------------===//
25e1421b2Speter klausler //
35e1421b2Speter klausler // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45e1421b2Speter klausler // See https://llvm.org/LICENSE.txt for license information.
55e1421b2Speter klausler // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65e1421b2Speter klausler //
75e1421b2Speter klausler //===----------------------------------------------------------------------===//
85e1421b2Speter klausler 
95e1421b2Speter klausler // Implements all forms of MATMUL (Fortran 2018 16.9.124)
105e1421b2Speter klausler //
115e1421b2Speter klausler // There are two main entry points; one establishes a descriptor for the
125e1421b2Speter klausler // result and allocates it, and the other expects a result descriptor that
135e1421b2Speter klausler // points to existing storage.
145e1421b2Speter klausler //
155e1421b2Speter klausler // This implementation must handle all combinations of numeric types and
165e1421b2Speter klausler // kinds (100 - 165 cases depending on the target), plus all combinations
175e1421b2Speter klausler // of logical kinds (16).  A single template undergoes many instantiations
185e1421b2Speter klausler // to cover all of the valid possibilities.
195e1421b2Speter klausler //
205e1421b2Speter klausler // Places where BLAS routines could be called are marked as TODO items.
215e1421b2Speter klausler 
22830c0b90SPeter Klausler #include "flang/Runtime/matmul.h"
235e1421b2Speter klausler #include "terminator.h"
245e1421b2Speter klausler #include "tools.h"
2571e0261fSSlava Zakharin #include "flang/Common/optional.h"
26a5a493e1SPeter Klausler #include "flang/Runtime/c-or-cpp.h"
27830c0b90SPeter Klausler #include "flang/Runtime/cpp-type.h"
28830c0b90SPeter Klausler #include "flang/Runtime/descriptor.h"
29a5a493e1SPeter Klausler #include <cstring>
305e1421b2Speter klausler 
31dd220853SSlava Zakharin namespace {
32dd220853SSlava Zakharin using namespace Fortran::runtime;
335e1421b2Speter klausler 
34a5a493e1SPeter Klausler // General accumulator for any type and stride; this is not used for
35a5a493e1SPeter Klausler // contiguous numeric cases.
365e1421b2Speter klausler template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
375e1421b2Speter klausler class Accumulator {
385e1421b2Speter klausler public:
39a5a493e1SPeter Klausler   using Result = AccumulationType<RCAT, RKIND>;
40b4b23ff7SSlava Zakharin   RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y)
41b4b23ff7SSlava Zakharin       : x_{x}, y_{y} {}
42b4b23ff7SSlava Zakharin   RT_API_ATTRS void Accumulate(
43b4b23ff7SSlava Zakharin       const SubscriptValue xAt[], const SubscriptValue yAt[]) {
445e1421b2Speter klausler     if constexpr (RCAT == TypeCategory::Logical) {
455e1421b2Speter klausler       sum_ = sum_ ||
465e1421b2Speter klausler           (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt));
475e1421b2Speter klausler     } else {
485e1421b2Speter klausler       sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) *
495e1421b2Speter klausler           static_cast<Result>(*y_.Element<YT>(yAt));
505e1421b2Speter klausler     }
515e1421b2Speter klausler   }
52b4b23ff7SSlava Zakharin   RT_API_ATTRS Result GetResult() const { return sum_; }
535e1421b2Speter klausler 
545e1421b2Speter klausler private:
555e1421b2Speter klausler   const Descriptor &x_, &y_;
565e1421b2Speter klausler   Result sum_{};
575e1421b2Speter klausler };
585e1421b2Speter klausler 
59a5a493e1SPeter Klausler // Contiguous numeric matrix*matrix multiplication
60a5a493e1SPeter Klausler //   matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols)
61a5a493e1SPeter Klausler // Straightforward algorithm:
62a5a493e1SPeter Klausler //   DO 1 I = 1, NROWS
63a5a493e1SPeter Klausler //    DO 1 J = 1, NCOLS
64a5a493e1SPeter Klausler //     RES(I,J) = 0
65a5a493e1SPeter Klausler //     DO 1 K = 1, N
66a5a493e1SPeter Klausler //   1  RES(I,J) = RES(I,J) + X(I,K)*Y(K,J)
67a5a493e1SPeter Klausler // With loop distribution and transposition to avoid the inner sum
68a5a493e1SPeter Klausler // reduction and to avoid non-unit strides:
69a5a493e1SPeter Klausler //   DO 1 I = 1, NROWS
70a5a493e1SPeter Klausler //    DO 1 J = 1, NCOLS
71a5a493e1SPeter Klausler //   1 RES(I,J) = 0
72a5a493e1SPeter Klausler //   DO 2 K = 1, N
73a5a493e1SPeter Klausler //    DO 2 J = 1, NCOLS
74a5a493e1SPeter Klausler //     DO 2 I = 1, NROWS
75a5a493e1SPeter Klausler //   2  RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term
764d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
774d977174SSlava Zakharin     bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS>
78b4b23ff7SSlava Zakharin inline RT_API_ATTRS void MatrixTimesMatrix(
79b4b23ff7SSlava Zakharin     CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
80b4b23ff7SSlava Zakharin     SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
81b4b23ff7SSlava Zakharin     SubscriptValue n, std::size_t xColumnByteStride = 0,
824d977174SSlava Zakharin     std::size_t yColumnByteStride = 0) {
83a5a493e1SPeter Klausler   using ResultType = CppTypeFor<RCAT, RKIND>;
84a5a493e1SPeter Klausler   std::memset(product, 0, rows * cols * sizeof *product);
85a5a493e1SPeter Klausler   const XT *RESTRICT xp0{x};
86a5a493e1SPeter Klausler   for (SubscriptValue k{0}; k < n; ++k) {
87a5a493e1SPeter Klausler     ResultType *RESTRICT p{product};
88a5a493e1SPeter Klausler     for (SubscriptValue j{0}; j < cols; ++j) {
89a5a493e1SPeter Klausler       const XT *RESTRICT xp{xp0};
904d977174SSlava Zakharin       ResultType yv;
914d977174SSlava Zakharin       if constexpr (!Y_HAS_STRIDED_COLUMNS) {
924d977174SSlava Zakharin         yv = static_cast<ResultType>(y[k + j * n]);
934d977174SSlava Zakharin       } else {
944d977174SSlava Zakharin         yv = static_cast<ResultType>(reinterpret_cast<const YT *>(
954d977174SSlava Zakharin             reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]);
964d977174SSlava Zakharin       }
97a5a493e1SPeter Klausler       for (SubscriptValue i{0}; i < rows; ++i) {
98a5a493e1SPeter Klausler         *p++ += static_cast<ResultType>(*xp++) * yv;
99a5a493e1SPeter Klausler       }
100a5a493e1SPeter Klausler     }
1014d977174SSlava Zakharin     if constexpr (!X_HAS_STRIDED_COLUMNS) {
102a5a493e1SPeter Klausler       xp0 += rows;
1034d977174SSlava Zakharin     } else {
1044d977174SSlava Zakharin       xp0 = reinterpret_cast<const XT *>(
1054d977174SSlava Zakharin           reinterpret_cast<const char *>(xp0) + xColumnByteStride);
1064d977174SSlava Zakharin     }
1074d977174SSlava Zakharin   }
1084d977174SSlava Zakharin }
1094d977174SSlava Zakharin 
1104d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
111b4b23ff7SSlava Zakharin inline RT_API_ATTRS void MatrixTimesMatrixHelper(
112b4b23ff7SSlava Zakharin     CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
113b4b23ff7SSlava Zakharin     SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
11471e0261fSSlava Zakharin     SubscriptValue n, Fortran::common::optional<std::size_t> xColumnByteStride,
11571e0261fSSlava Zakharin     Fortran::common::optional<std::size_t> yColumnByteStride) {
1164d977174SSlava Zakharin   if (!xColumnByteStride) {
1174d977174SSlava Zakharin     if (!yColumnByteStride) {
1184d977174SSlava Zakharin       MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, false>(
1194d977174SSlava Zakharin           product, rows, cols, x, y, n);
1204d977174SSlava Zakharin     } else {
1214d977174SSlava Zakharin       MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, true>(
1224d977174SSlava Zakharin           product, rows, cols, x, y, n, 0, *yColumnByteStride);
1234d977174SSlava Zakharin     }
1244d977174SSlava Zakharin   } else {
1254d977174SSlava Zakharin     if (!yColumnByteStride) {
1264d977174SSlava Zakharin       MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, false>(
1274d977174SSlava Zakharin           product, rows, cols, x, y, n, *xColumnByteStride);
1284d977174SSlava Zakharin     } else {
1294d977174SSlava Zakharin       MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, true>(
1304d977174SSlava Zakharin           product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride);
1314d977174SSlava Zakharin     }
132a5a493e1SPeter Klausler   }
133a5a493e1SPeter Klausler }
134a5a493e1SPeter Klausler 
135a5a493e1SPeter Klausler // Contiguous numeric matrix*vector multiplication
136a5a493e1SPeter Klausler //   matrix(rows,n) * column vector(n) -> column vector(rows)
137a5a493e1SPeter Klausler // Straightforward algorithm:
138a5a493e1SPeter Klausler //   DO 1 J = 1, NROWS
139a5a493e1SPeter Klausler //    RES(J) = 0
140a5a493e1SPeter Klausler //    DO 1 K = 1, N
141a5a493e1SPeter Klausler //   1 RES(J) = RES(J) + X(J,K)*Y(K)
142a5a493e1SPeter Klausler // With loop distribution and transposition to avoid the inner
143a5a493e1SPeter Klausler // sum reduction and to avoid non-unit strides:
144a5a493e1SPeter Klausler //   DO 1 J = 1, NROWS
145a5a493e1SPeter Klausler //   1 RES(J) = 0
146a5a493e1SPeter Klausler //   DO 2 K = 1, N
147a5a493e1SPeter Klausler //    DO 2 J = 1, NROWS
148a5a493e1SPeter Klausler //   2 RES(J) = RES(J) + X(J,K)*Y(K)
1494d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
1504d977174SSlava Zakharin     bool X_HAS_STRIDED_COLUMNS>
151b4b23ff7SSlava Zakharin inline RT_API_ATTRS void MatrixTimesVector(
152b4b23ff7SSlava Zakharin     CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
153b4b23ff7SSlava Zakharin     SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
154b4b23ff7SSlava Zakharin     std::size_t xColumnByteStride = 0) {
155a5a493e1SPeter Klausler   using ResultType = CppTypeFor<RCAT, RKIND>;
156a5a493e1SPeter Klausler   std::memset(product, 0, rows * sizeof *product);
1574d977174SSlava Zakharin   [[maybe_unused]] const XT *RESTRICT xp0{x};
158a5a493e1SPeter Klausler   for (SubscriptValue k{0}; k < n; ++k) {
159a5a493e1SPeter Klausler     ResultType *RESTRICT p{product};
160a5a493e1SPeter Klausler     auto yv{static_cast<ResultType>(*y++)};
161a5a493e1SPeter Klausler     for (SubscriptValue j{0}; j < rows; ++j) {
162a5a493e1SPeter Klausler       *p++ += static_cast<ResultType>(*x++) * yv;
163a5a493e1SPeter Klausler     }
1644d977174SSlava Zakharin     if constexpr (X_HAS_STRIDED_COLUMNS) {
1654d977174SSlava Zakharin       xp0 = reinterpret_cast<const XT *>(
1664d977174SSlava Zakharin           reinterpret_cast<const char *>(xp0) + xColumnByteStride);
1674d977174SSlava Zakharin       x = xp0;
1684d977174SSlava Zakharin     }
1694d977174SSlava Zakharin   }
1704d977174SSlava Zakharin }
1714d977174SSlava Zakharin 
1724d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
173b4b23ff7SSlava Zakharin inline RT_API_ATTRS void MatrixTimesVectorHelper(
174b4b23ff7SSlava Zakharin     CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
175b4b23ff7SSlava Zakharin     SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
17671e0261fSSlava Zakharin     Fortran::common::optional<std::size_t> xColumnByteStride) {
1774d977174SSlava Zakharin   if (!xColumnByteStride) {
1784d977174SSlava Zakharin     MatrixTimesVector<RCAT, RKIND, XT, YT, false>(product, rows, n, x, y);
1794d977174SSlava Zakharin   } else {
1804d977174SSlava Zakharin     MatrixTimesVector<RCAT, RKIND, XT, YT, true>(
1814d977174SSlava Zakharin         product, rows, n, x, y, *xColumnByteStride);
182a5a493e1SPeter Klausler   }
183a5a493e1SPeter Klausler }
184a5a493e1SPeter Klausler 
185a5a493e1SPeter Klausler // Contiguous numeric vector*matrix multiplication
186a5a493e1SPeter Klausler //   row vector(n) * matrix(n,cols) -> row vector(cols)
187a5a493e1SPeter Klausler // Straightforward algorithm:
188a5a493e1SPeter Klausler //   DO 1 J = 1, NCOLS
189a5a493e1SPeter Klausler //    RES(J) = 0
190a5a493e1SPeter Klausler //    DO 1 K = 1, N
191a5a493e1SPeter Klausler //   1 RES(J) = RES(J) + X(K)*Y(K,J)
192a5a493e1SPeter Klausler // With loop distribution and transposition to avoid the inner
193a5a493e1SPeter Klausler // sum reduction and one non-unit stride (the other remains):
194a5a493e1SPeter Klausler //   DO 1 J = 1, NCOLS
195a5a493e1SPeter Klausler //   1 RES(J) = 0
196a5a493e1SPeter Klausler //   DO 2 K = 1, N
197a5a493e1SPeter Klausler //    DO 2 J = 1, NCOLS
198a5a493e1SPeter Klausler //   2 RES(J) = RES(J) + X(K)*Y(K,J)
1994d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
2004d977174SSlava Zakharin     bool Y_HAS_STRIDED_COLUMNS>
201b4b23ff7SSlava Zakharin inline RT_API_ATTRS void VectorTimesMatrix(
202b4b23ff7SSlava Zakharin     CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n,
203b4b23ff7SSlava Zakharin     SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
204b4b23ff7SSlava Zakharin     std::size_t yColumnByteStride = 0) {
205a5a493e1SPeter Klausler   using ResultType = CppTypeFor<RCAT, RKIND>;
206a5a493e1SPeter Klausler   std::memset(product, 0, cols * sizeof *product);
207a5a493e1SPeter Klausler   for (SubscriptValue k{0}; k < n; ++k) {
208a5a493e1SPeter Klausler     ResultType *RESTRICT p{product};
209a5a493e1SPeter Klausler     auto xv{static_cast<ResultType>(*x++)};
210a5a493e1SPeter Klausler     const YT *RESTRICT yp{&y[k]};
211a5a493e1SPeter Klausler     for (SubscriptValue j{0}; j < cols; ++j) {
212a5a493e1SPeter Klausler       *p++ += xv * static_cast<ResultType>(*yp);
2134d977174SSlava Zakharin       if constexpr (!Y_HAS_STRIDED_COLUMNS) {
214a5a493e1SPeter Klausler         yp += n;
2154d977174SSlava Zakharin       } else {
2164d977174SSlava Zakharin         yp = reinterpret_cast<const YT *>(
2174d977174SSlava Zakharin             reinterpret_cast<const char *>(yp) + yColumnByteStride);
218a5a493e1SPeter Klausler       }
219a5a493e1SPeter Klausler     }
220a5a493e1SPeter Klausler   }
2214d977174SSlava Zakharin }
2224d977174SSlava Zakharin 
2234d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
2244d977174SSlava Zakharin     bool SPARSE_COLUMNS = false>
225b4b23ff7SSlava Zakharin inline RT_API_ATTRS void VectorTimesMatrixHelper(
226b4b23ff7SSlava Zakharin     CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n,
227b4b23ff7SSlava Zakharin     SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
22871e0261fSSlava Zakharin     Fortran::common::optional<std::size_t> yColumnByteStride) {
2294d977174SSlava Zakharin   if (!yColumnByteStride) {
2304d977174SSlava Zakharin     VectorTimesMatrix<RCAT, RKIND, XT, YT, false>(product, n, cols, x, y);
2314d977174SSlava Zakharin   } else {
2324d977174SSlava Zakharin     VectorTimesMatrix<RCAT, RKIND, XT, YT, true>(
2334d977174SSlava Zakharin         product, n, cols, x, y, *yColumnByteStride);
2344d977174SSlava Zakharin   }
2354d977174SSlava Zakharin }
236a5a493e1SPeter Klausler 
2375e1421b2Speter klausler // Implements an instance of MATMUL for given argument types.
2385e1421b2Speter klausler template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
2395e1421b2Speter klausler     typename YT>
240b4b23ff7SSlava Zakharin static inline RT_API_ATTRS void DoMatmul(
2415e1421b2Speter klausler     std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result,
2425e1421b2Speter klausler     const Descriptor &x, const Descriptor &y, Terminator &terminator) {
2435e1421b2Speter klausler   int xRank{x.rank()};
2445e1421b2Speter klausler   int yRank{y.rank()};
2455e1421b2Speter klausler   int resRank{xRank + yRank - 2};
2465e1421b2Speter klausler   if (xRank * yRank != 2 * resRank) {
2475e1421b2Speter klausler     terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank);
2485e1421b2Speter klausler   }
2495e1421b2Speter klausler   SubscriptValue extent[2]{
2505e1421b2Speter klausler       xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(),
2515e1421b2Speter klausler       resRank == 2 ? y.GetDimension(1).Extent() : 0};
2525e1421b2Speter klausler   if constexpr (IS_ALLOCATING) {
2535e1421b2Speter klausler     result.Establish(
2545e1421b2Speter klausler         RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable);
2555e1421b2Speter klausler     for (int j{0}; j < resRank; ++j) {
2565e1421b2Speter klausler       result.GetDimension(j).SetBounds(1, extent[j]);
2575e1421b2Speter klausler     }
2585e1421b2Speter klausler     if (int stat{result.Allocate()}) {
2595e1421b2Speter klausler       terminator.Crash(
2605e1421b2Speter klausler           "MATMUL: could not allocate memory for result; STAT=%d", stat);
2615e1421b2Speter klausler     }
2625e1421b2Speter klausler   } else {
2635e1421b2Speter klausler     RUNTIME_CHECK(terminator, resRank == result.rank());
264a5a493e1SPeter Klausler     RUNTIME_CHECK(
265a5a493e1SPeter Klausler         terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND));
2665e1421b2Speter klausler     RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
2675e1421b2Speter klausler     RUNTIME_CHECK(terminator,
2685e1421b2Speter klausler         resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
2695e1421b2Speter klausler   }
2705e1421b2Speter klausler   SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
2715e1421b2Speter klausler   if (n != y.GetDimension(0).Extent()) {
272e55aa027SPete Steinfeld     // At this point, we know that there's a shape error.  There are three
273e55aa027SPete Steinfeld     // possibilities, x is rank 1, y is rank 1, or both are rank 2.
274e55aa027SPete Steinfeld     if (xRank == 1) {
275e55aa027SPete Steinfeld       terminator.Crash("MATMUL: unacceptable operand shapes (%jd, %jdx%jd)",
276e55aa027SPete Steinfeld           static_cast<std::intmax_t>(n),
277e55aa027SPete Steinfeld           static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
278e55aa027SPete Steinfeld           static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
279e55aa027SPete Steinfeld     } else if (yRank == 1) {
280e55aa027SPete Steinfeld       terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jd)",
281e55aa027SPete Steinfeld           static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
282e55aa027SPete Steinfeld           static_cast<std::intmax_t>(n),
283e55aa027SPete Steinfeld           static_cast<std::intmax_t>(y.GetDimension(0).Extent()));
284e55aa027SPete Steinfeld     } else {
285f5884fd9SPeter Klausler       terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
286f5884fd9SPeter Klausler           static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
2875e1421b2Speter klausler           static_cast<std::intmax_t>(n),
288f5884fd9SPeter Klausler           static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
289f5884fd9SPeter Klausler           static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
2905e1421b2Speter klausler     }
291e55aa027SPete Steinfeld   }
292a5a493e1SPeter Klausler   using WriteResult =
293a5a493e1SPeter Klausler       CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
294a5a493e1SPeter Klausler           RKIND>;
295a5a493e1SPeter Klausler   if constexpr (RCAT != TypeCategory::Logical) {
2964d977174SSlava Zakharin     if (x.IsContiguous(1) && y.IsContiguous(1) &&
297a5a493e1SPeter Klausler         (IS_ALLOCATING || result.IsContiguous())) {
2984d977174SSlava Zakharin       // Contiguous numeric matrices (maybe with columns
2994d977174SSlava Zakharin       // separated by a stride).
30071e0261fSSlava Zakharin       Fortran::common::optional<std::size_t> xColumnByteStride;
3014d977174SSlava Zakharin       if (!x.IsContiguous()) {
3024d977174SSlava Zakharin         // X's columns are strided.
3034d977174SSlava Zakharin         SubscriptValue xAt[2]{};
3044d977174SSlava Zakharin         x.GetLowerBounds(xAt);
3054d977174SSlava Zakharin         xAt[1]++;
3064d977174SSlava Zakharin         xColumnByteStride = x.SubscriptsToByteOffset(xAt);
3074d977174SSlava Zakharin       }
30871e0261fSSlava Zakharin       Fortran::common::optional<std::size_t> yColumnByteStride;
3094d977174SSlava Zakharin       if (!y.IsContiguous()) {
3104d977174SSlava Zakharin         // Y's columns are strided.
3114d977174SSlava Zakharin         SubscriptValue yAt[2]{};
3124d977174SSlava Zakharin         y.GetLowerBounds(yAt);
3134d977174SSlava Zakharin         yAt[1]++;
3144d977174SSlava Zakharin         yColumnByteStride = y.SubscriptsToByteOffset(yAt);
3154d977174SSlava Zakharin       }
3164d977174SSlava Zakharin       // Note that BLAS GEMM can be used for the strided
3174d977174SSlava Zakharin       // columns by setting proper leading dimension size.
3184d977174SSlava Zakharin       // This implies that the column stride is divisible
3194d977174SSlava Zakharin       // by the element size, which is usually true.
3205e1421b2Speter klausler       if (resRank == 2) { // M*M -> M
321a5a493e1SPeter Klausler         if (std::is_same_v<XT, YT>) {
3225e1421b2Speter klausler           if constexpr (std::is_same_v<XT, float>) {
3235e1421b2Speter klausler             // TODO: call BLAS-3 SGEMM
3244d977174SSlava Zakharin             // TODO: try using CUTLASS for device.
3255e1421b2Speter klausler           } else if constexpr (std::is_same_v<XT, double>) {
3265e1421b2Speter klausler             // TODO: call BLAS-3 DGEMM
327104f3c18SSlava Zakharin           } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) {
3285e1421b2Speter klausler             // TODO: call BLAS-3 CGEMM
329104f3c18SSlava Zakharin           } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) {
3305e1421b2Speter klausler             // TODO: call BLAS-3 ZGEMM
3315e1421b2Speter klausler           }
3325e1421b2Speter klausler         }
3334d977174SSlava Zakharin         MatrixTimesMatrixHelper<RCAT, RKIND, XT, YT>(
334a5a493e1SPeter Klausler             result.template OffsetElement<WriteResult>(), extent[0], extent[1],
3354d977174SSlava Zakharin             x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride,
3364d977174SSlava Zakharin             yColumnByteStride);
337a5a493e1SPeter Klausler         return;
338a5a493e1SPeter Klausler       } else if (xRank == 2) { // M*V -> V
339a5a493e1SPeter Klausler         if (std::is_same_v<XT, YT>) {
340a5a493e1SPeter Klausler           if constexpr (std::is_same_v<XT, float>) {
341a5a493e1SPeter Klausler             // TODO: call BLAS-2 SGEMV(x,y)
342a5a493e1SPeter Klausler           } else if constexpr (std::is_same_v<XT, double>) {
343a5a493e1SPeter Klausler             // TODO: call BLAS-2 DGEMV(x,y)
344104f3c18SSlava Zakharin           } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) {
345a5a493e1SPeter Klausler             // TODO: call BLAS-2 CGEMV(x,y)
346104f3c18SSlava Zakharin           } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) {
347a5a493e1SPeter Klausler             // TODO: call BLAS-2 ZGEMV(x,y)
348a5a493e1SPeter Klausler           }
349a5a493e1SPeter Klausler         }
3504d977174SSlava Zakharin         MatrixTimesVectorHelper<RCAT, RKIND, XT, YT>(
351a5a493e1SPeter Klausler             result.template OffsetElement<WriteResult>(), extent[0], n,
3524d977174SSlava Zakharin             x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride);
353a5a493e1SPeter Klausler         return;
354a5a493e1SPeter Klausler       } else { // V*M -> V
355a5a493e1SPeter Klausler         if (std::is_same_v<XT, YT>) {
356a5a493e1SPeter Klausler           if constexpr (std::is_same_v<XT, float>) {
357a5a493e1SPeter Klausler             // TODO: call BLAS-2 SGEMV(y,x)
358a5a493e1SPeter Klausler           } else if constexpr (std::is_same_v<XT, double>) {
359a5a493e1SPeter Klausler             // TODO: call BLAS-2 DGEMV(y,x)
360104f3c18SSlava Zakharin           } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) {
361a5a493e1SPeter Klausler             // TODO: call BLAS-2 CGEMV(y,x)
362104f3c18SSlava Zakharin           } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) {
363a5a493e1SPeter Klausler             // TODO: call BLAS-2 ZGEMV(y,x)
364a5a493e1SPeter Klausler           }
365a5a493e1SPeter Klausler         }
3664d977174SSlava Zakharin         VectorTimesMatrixHelper<RCAT, RKIND, XT, YT>(
367a5a493e1SPeter Klausler             result.template OffsetElement<WriteResult>(), n, extent[0],
3684d977174SSlava Zakharin             x.OffsetElement<XT>(), y.OffsetElement<YT>(), yColumnByteStride);
369a5a493e1SPeter Klausler         return;
370a5a493e1SPeter Klausler       }
371a5a493e1SPeter Klausler     }
372a5a493e1SPeter Klausler   }
373a5a493e1SPeter Klausler   // General algorithms for LOGICAL and noncontiguity
374a5a493e1SPeter Klausler   SubscriptValue xAt[2], yAt[2], resAt[2];
375a5a493e1SPeter Klausler   x.GetLowerBounds(xAt);
376a5a493e1SPeter Klausler   y.GetLowerBounds(yAt);
377a5a493e1SPeter Klausler   result.GetLowerBounds(resAt);
378a5a493e1SPeter Klausler   if (resRank == 2) { // M*M -> M
3795e1421b2Speter klausler     SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]};
3805e1421b2Speter klausler     for (SubscriptValue i{0}; i < extent[0]; ++i) {
3815e1421b2Speter klausler       for (SubscriptValue j{0}; j < extent[1]; ++j) {
3825e1421b2Speter klausler         Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
3835e1421b2Speter klausler         yAt[1] = y1 + j;
3845e1421b2Speter klausler         for (SubscriptValue k{0}; k < n; ++k) {
3855e1421b2Speter klausler           xAt[1] = x1 + k;
3865e1421b2Speter klausler           yAt[0] = y0 + k;
3875e1421b2Speter klausler           accumulator.Accumulate(xAt, yAt);
3885e1421b2Speter klausler         }
3895e1421b2Speter klausler         resAt[1] = res1 + j;
3905e1421b2Speter klausler         *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
3915e1421b2Speter klausler       }
3925e1421b2Speter klausler       ++resAt[0];
3935e1421b2Speter klausler       ++xAt[0];
3945e1421b2Speter klausler     }
395a5a493e1SPeter Klausler   } else if (xRank == 2) { // M*V -> V
3965e1421b2Speter klausler     SubscriptValue x1{xAt[1]}, y0{yAt[0]};
3975e1421b2Speter klausler     for (SubscriptValue j{0}; j < extent[0]; ++j) {
3985e1421b2Speter klausler       Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
3995e1421b2Speter klausler       for (SubscriptValue k{0}; k < n; ++k) {
4005e1421b2Speter klausler         xAt[1] = x1 + k;
4015e1421b2Speter klausler         yAt[0] = y0 + k;
4025e1421b2Speter klausler         accumulator.Accumulate(xAt, yAt);
4035e1421b2Speter klausler       }
4045e1421b2Speter klausler       *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
4055e1421b2Speter klausler       ++resAt[0];
4065e1421b2Speter klausler       ++xAt[0];
4075e1421b2Speter klausler     }
4085e1421b2Speter klausler   } else { // V*M -> V
4095e1421b2Speter klausler     SubscriptValue x0{xAt[0]}, y0{yAt[0]};
4105e1421b2Speter klausler     for (SubscriptValue j{0}; j < extent[0]; ++j) {
4115e1421b2Speter klausler       Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
4125e1421b2Speter klausler       for (SubscriptValue k{0}; k < n; ++k) {
4135e1421b2Speter klausler         xAt[0] = x0 + k;
4145e1421b2Speter klausler         yAt[0] = y0 + k;
4155e1421b2Speter klausler         accumulator.Accumulate(xAt, yAt);
4165e1421b2Speter klausler       }
4175e1421b2Speter klausler       *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
4185e1421b2Speter klausler       ++resAt[0];
4195e1421b2Speter klausler       ++yAt[1];
4205e1421b2Speter klausler     }
4215e1421b2Speter klausler   }
4225e1421b2Speter klausler }
4235e1421b2Speter klausler 
424dd220853SSlava Zakharin template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
425dd220853SSlava Zakharin     int YKIND>
426dd220853SSlava Zakharin struct MatmulHelper {
427dd220853SSlava Zakharin   using ResultDescriptor =
428dd220853SSlava Zakharin       std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
429dd220853SSlava Zakharin   RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
430dd220853SSlava Zakharin       const Descriptor &y, const char *sourceFile, int line) const {
431dd220853SSlava Zakharin     Terminator terminator{sourceFile, line};
432dd220853SSlava Zakharin     auto xCatKind{x.type().GetCategoryAndKind()};
433dd220853SSlava Zakharin     auto yCatKind{y.type().GetCategoryAndKind()};
434dd220853SSlava Zakharin     RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
435*fc97d2e6SPeter Klausler     RUNTIME_CHECK(terminator,
436*fc97d2e6SPeter Klausler         (xCatKind->first == XCAT && yCatKind->first == YCAT) ||
437*fc97d2e6SPeter Klausler             (XCAT == TypeCategory::Integer && YCAT == TypeCategory::Integer &&
438*fc97d2e6SPeter Klausler                 ((xCatKind->first == TypeCategory::Integer ||
439*fc97d2e6SPeter Klausler                      xCatKind->first == TypeCategory::Unsigned) &&
440*fc97d2e6SPeter Klausler                     (yCatKind->first == TypeCategory::Integer ||
441*fc97d2e6SPeter Klausler                         yCatKind->first == TypeCategory::Unsigned))));
442dd220853SSlava Zakharin     if constexpr (constexpr auto resultType{
443dd220853SSlava Zakharin                       GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
444dd220853SSlava Zakharin       return DoMatmul<IS_ALLOCATING, resultType->first, resultType->second,
445dd220853SSlava Zakharin           CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>(
446dd220853SSlava Zakharin           result, x, y, terminator);
447dd220853SSlava Zakharin     }
448dd220853SSlava Zakharin     terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
449dd220853SSlava Zakharin         static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
450dd220853SSlava Zakharin   }
451dd220853SSlava Zakharin };
452dd220853SSlava Zakharin } // namespace
453dd220853SSlava Zakharin 
454dd220853SSlava Zakharin namespace Fortran::runtime {
4555e1421b2Speter klausler extern "C" {
45676facde3SSlava Zakharin RT_EXT_API_GROUP_BEGIN
45776facde3SSlava Zakharin 
458dd220853SSlava Zakharin #define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
459dd220853SSlava Zakharin   void RTDEF(Matmul##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
460dd220853SSlava Zakharin       const Descriptor &x, const Descriptor &y, const char *sourceFile, \
461dd220853SSlava Zakharin       int line) { \
462dd220853SSlava Zakharin     MatmulHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
463dd220853SSlava Zakharin         YKIND>{}(result, x, y, sourceFile, line); \
464dd220853SSlava Zakharin   }
465dd220853SSlava Zakharin 
466dd220853SSlava Zakharin #define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
467dd220853SSlava Zakharin   void RTDEF(MatmulDirect##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
468dd220853SSlava Zakharin       const Descriptor &x, const Descriptor &y, const char *sourceFile, \
469dd220853SSlava Zakharin       int line) { \
470dd220853SSlava Zakharin     MatmulHelper<false, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
471dd220853SSlava Zakharin         YKIND>{}(result, x, y, sourceFile, line); \
472dd220853SSlava Zakharin   }
473dd220853SSlava Zakharin 
4748ce1aed5SSlava Zakharin #define MATMUL_FORCE_ALL_TYPES 0
4758ce1aed5SSlava Zakharin 
476dd220853SSlava Zakharin #include "flang/Runtime/matmul-instances.inc"
477dd220853SSlava Zakharin 
47876facde3SSlava Zakharin RT_EXT_API_GROUP_END
4795e1421b2Speter klausler } // extern "C"
4805e1421b2Speter klausler } // namespace Fortran::runtime
481