14ff8ba72STom Eccles //===-- runtime/matmul-transpose.cpp --------------------------------------===// 24ff8ba72STom Eccles // 34ff8ba72STom Eccles // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 44ff8ba72STom Eccles // See https://llvm.org/LICENSE.txt for license information. 54ff8ba72STom Eccles // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 64ff8ba72STom Eccles // 74ff8ba72STom Eccles //===----------------------------------------------------------------------===// 84ff8ba72STom Eccles 94ff8ba72STom Eccles // Implements a fused matmul-transpose operation 104ff8ba72STom Eccles // 114ff8ba72STom Eccles // There are two main entry points; one establishes a descriptor for the 124ff8ba72STom Eccles // result and allocates it, and the other expects a result descriptor that 134ff8ba72STom Eccles // points to existing storage. 144ff8ba72STom Eccles // 154ff8ba72STom Eccles // This implementation must handle all combinations of numeric types and 164ff8ba72STom Eccles // kinds (100 - 165 cases depending on the target), plus all combinations 174ff8ba72STom Eccles // of logical kinds (16). A single template undergoes many instantiations 184ff8ba72STom Eccles // to cover all of the valid possibilities. 194ff8ba72STom Eccles // 204ff8ba72STom Eccles // The usefulness of this optimization should be reviewed once Matmul is swapped 214ff8ba72STom Eccles // to use the faster BLAS routines. 224ff8ba72STom Eccles 234ff8ba72STom Eccles #include "flang/Runtime/matmul-transpose.h" 244ff8ba72STom Eccles #include "terminator.h" 254ff8ba72STom Eccles #include "tools.h" 2671e0261fSSlava Zakharin #include "flang/Common/optional.h" 274ff8ba72STom Eccles #include "flang/Runtime/c-or-cpp.h" 284ff8ba72STom Eccles #include "flang/Runtime/cpp-type.h" 294ff8ba72STom Eccles #include "flang/Runtime/descriptor.h" 304ff8ba72STom Eccles #include <cstring> 314ff8ba72STom Eccles 324ff8ba72STom Eccles namespace { 334ff8ba72STom Eccles using namespace Fortran::runtime; 344ff8ba72STom Eccles 354ff8ba72STom Eccles // Contiguous numeric TRANSPOSE(matrix)*matrix multiplication 364ff8ba72STom Eccles // TRANSPOSE(matrix(n, rows)) * matrix(n,cols) -> 374ff8ba72STom Eccles // matrix(rows, n) * matrix(n,cols) -> matrix(rows,cols) 384ff8ba72STom Eccles // The transpose is implemented by swapping the indices of accesses into the LHS 394ff8ba72STom Eccles // 404ff8ba72STom Eccles // Straightforward algorithm: 414ff8ba72STom Eccles // DO 1 I = 1, NROWS 424ff8ba72STom Eccles // DO 1 J = 1, NCOLS 434ff8ba72STom Eccles // RES(I,J) = 0 444ff8ba72STom Eccles // DO 1 K = 1, N 454ff8ba72STom Eccles // 1 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) 464ff8ba72STom Eccles // 474ff8ba72STom Eccles // With loop distribution and transposition to avoid the inner sum 484ff8ba72STom Eccles // reduction and to avoid non-unit strides: 494ff8ba72STom Eccles // DO 1 I = 1, NROWS 504ff8ba72STom Eccles // DO 1 J = 1, NCOLS 514ff8ba72STom Eccles // 1 RES(I,J) = 0 524ff8ba72STom Eccles // DO 2 J = 1, NCOLS 534ff8ba72STom Eccles // DO 2 I = 1, NROWS 544ff8ba72STom Eccles // DO 2 K = 1, N 554ff8ba72STom Eccles // 2 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) ! loop-invariant last term 564d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT, 574d977174SSlava Zakharin bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS> 58b4b23ff7SSlava Zakharin inline static RT_API_ATTRS void MatrixTransposedTimesMatrix( 594ff8ba72STom Eccles CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 604ff8ba72STom Eccles SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, 614d977174SSlava Zakharin SubscriptValue n, std::size_t xColumnByteStride = 0, 624d977174SSlava Zakharin std::size_t yColumnByteStride = 0) { 634ff8ba72STom Eccles using ResultType = CppTypeFor<RCAT, RKIND>; 644ff8ba72STom Eccles 654ff8ba72STom Eccles std::memset(product, 0, rows * cols * sizeof *product); 664ff8ba72STom Eccles for (SubscriptValue j{0}; j < cols; ++j) { 674ff8ba72STom Eccles for (SubscriptValue i{0}; i < rows; ++i) { 684ff8ba72STom Eccles for (SubscriptValue k{0}; k < n; ++k) { 694d977174SSlava Zakharin ResultType x_ki; 704d977174SSlava Zakharin if constexpr (!X_HAS_STRIDED_COLUMNS) { 714d977174SSlava Zakharin x_ki = static_cast<ResultType>(x[i * n + k]); 724d977174SSlava Zakharin } else { 734d977174SSlava Zakharin x_ki = static_cast<ResultType>(reinterpret_cast<const XT *>( 744d977174SSlava Zakharin reinterpret_cast<const char *>(x) + i * xColumnByteStride)[k]); 754d977174SSlava Zakharin } 764d977174SSlava Zakharin ResultType y_kj; 774d977174SSlava Zakharin if constexpr (!Y_HAS_STRIDED_COLUMNS) { 784d977174SSlava Zakharin y_kj = static_cast<ResultType>(y[j * n + k]); 794d977174SSlava Zakharin } else { 804d977174SSlava Zakharin y_kj = static_cast<ResultType>(reinterpret_cast<const YT *>( 814d977174SSlava Zakharin reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]); 824d977174SSlava Zakharin } 834ff8ba72STom Eccles product[j * rows + i] += x_ki * y_kj; 844ff8ba72STom Eccles } 854ff8ba72STom Eccles } 864ff8ba72STom Eccles } 874ff8ba72STom Eccles } 884ff8ba72STom Eccles 894d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 90b4b23ff7SSlava Zakharin inline static RT_API_ATTRS void MatrixTransposedTimesMatrixHelper( 914d977174SSlava Zakharin CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 924d977174SSlava Zakharin SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, 9371e0261fSSlava Zakharin SubscriptValue n, Fortran::common::optional<std::size_t> xColumnByteStride, 9471e0261fSSlava Zakharin Fortran::common::optional<std::size_t> yColumnByteStride) { 954d977174SSlava Zakharin if (!xColumnByteStride) { 964d977174SSlava Zakharin if (!yColumnByteStride) { 974d977174SSlava Zakharin MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, false, false>( 984d977174SSlava Zakharin product, rows, cols, x, y, n); 994d977174SSlava Zakharin } else { 1004d977174SSlava Zakharin MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, false, true>( 1014d977174SSlava Zakharin product, rows, cols, x, y, n, 0, *yColumnByteStride); 1024d977174SSlava Zakharin } 1034d977174SSlava Zakharin } else { 1044d977174SSlava Zakharin if (!yColumnByteStride) { 1054d977174SSlava Zakharin MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, true, false>( 1064d977174SSlava Zakharin product, rows, cols, x, y, n, *xColumnByteStride); 1074d977174SSlava Zakharin } else { 1084d977174SSlava Zakharin MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, true, true>( 1094d977174SSlava Zakharin product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride); 1104d977174SSlava Zakharin } 1114d977174SSlava Zakharin } 1124d977174SSlava Zakharin } 1134d977174SSlava Zakharin 1144ff8ba72STom Eccles // Contiguous numeric matrix*vector multiplication 1154ff8ba72STom Eccles // matrix(rows,n) * column vector(n) -> column vector(rows) 1164ff8ba72STom Eccles // Straightforward algorithm: 1174ff8ba72STom Eccles // DO 1 I = 1, NROWS 1184ff8ba72STom Eccles // RES(I) = 0 1194ff8ba72STom Eccles // DO 1 K = 1, N 1204ff8ba72STom Eccles // 1 RES(I) = RES(I) + X(K,I)*Y(K) 1214ff8ba72STom Eccles // With loop distribution and transposition to avoid the inner 1224ff8ba72STom Eccles // sum reduction and to avoid non-unit strides: 1234ff8ba72STom Eccles // DO 1 I = 1, NROWS 1244ff8ba72STom Eccles // 1 RES(I) = 0 1254ff8ba72STom Eccles // DO 2 I = 1, NROWS 1264ff8ba72STom Eccles // DO 2 K = 1, N 1274ff8ba72STom Eccles // 2 RES(I) = RES(I) + X(K,I)*Y(K) 1284d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT, 1294d977174SSlava Zakharin bool X_HAS_STRIDED_COLUMNS> 130b4b23ff7SSlava Zakharin inline static RT_API_ATTRS void MatrixTransposedTimesVector( 1314ff8ba72STom Eccles CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 1324d977174SSlava Zakharin SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y, 1334d977174SSlava Zakharin std::size_t xColumnByteStride = 0) { 1344ff8ba72STom Eccles using ResultType = CppTypeFor<RCAT, RKIND>; 1354ff8ba72STom Eccles std::memset(product, 0, rows * sizeof *product); 1364ff8ba72STom Eccles for (SubscriptValue i{0}; i < rows; ++i) { 1374ff8ba72STom Eccles for (SubscriptValue k{0}; k < n; ++k) { 1384d977174SSlava Zakharin ResultType x_ki; 1394d977174SSlava Zakharin if constexpr (!X_HAS_STRIDED_COLUMNS) { 1404d977174SSlava Zakharin x_ki = static_cast<ResultType>(x[i * n + k]); 1414d977174SSlava Zakharin } else { 1424d977174SSlava Zakharin x_ki = static_cast<ResultType>(reinterpret_cast<const XT *>( 1434d977174SSlava Zakharin reinterpret_cast<const char *>(x) + i * xColumnByteStride)[k]); 1444d977174SSlava Zakharin } 1454ff8ba72STom Eccles ResultType y_k = static_cast<ResultType>(y[k]); 1464ff8ba72STom Eccles product[i] += x_ki * y_k; 1474ff8ba72STom Eccles } 1484ff8ba72STom Eccles } 1494ff8ba72STom Eccles } 1504ff8ba72STom Eccles 1514d977174SSlava Zakharin template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 152b4b23ff7SSlava Zakharin inline static RT_API_ATTRS void MatrixTransposedTimesVectorHelper( 1534d977174SSlava Zakharin CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 1544d977174SSlava Zakharin SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y, 15571e0261fSSlava Zakharin Fortran::common::optional<std::size_t> xColumnByteStride) { 1564d977174SSlava Zakharin if (!xColumnByteStride) { 1574d977174SSlava Zakharin MatrixTransposedTimesVector<RCAT, RKIND, XT, YT, false>( 1584d977174SSlava Zakharin product, rows, n, x, y); 1594d977174SSlava Zakharin } else { 1604d977174SSlava Zakharin MatrixTransposedTimesVector<RCAT, RKIND, XT, YT, true>( 1614d977174SSlava Zakharin product, rows, n, x, y, *xColumnByteStride); 1624d977174SSlava Zakharin } 1634d977174SSlava Zakharin } 1644d977174SSlava Zakharin 1654ff8ba72STom Eccles // Implements an instance of MATMUL for given argument types. 1664ff8ba72STom Eccles template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT, 1674ff8ba72STom Eccles typename YT> 168b4b23ff7SSlava Zakharin inline static RT_API_ATTRS void DoMatmulTranspose( 1694ff8ba72STom Eccles std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result, 1704ff8ba72STom Eccles const Descriptor &x, const Descriptor &y, Terminator &terminator) { 1714ff8ba72STom Eccles int xRank{x.rank()}; 1724ff8ba72STom Eccles int yRank{y.rank()}; 1734ff8ba72STom Eccles int resRank{xRank + yRank - 2}; 1744ff8ba72STom Eccles if (xRank * yRank != 2 * resRank) { 175ea7d6a1bSSlava Zakharin terminator.Crash( 176ea7d6a1bSSlava Zakharin "MATMUL-TRANSPOSE: bad argument ranks (%d * %d)", xRank, yRank); 1774ff8ba72STom Eccles } 1784ff8ba72STom Eccles SubscriptValue extent[2]{x.GetDimension(1).Extent(), 1794ff8ba72STom Eccles resRank == 2 ? y.GetDimension(1).Extent() : 0}; 1804ff8ba72STom Eccles if constexpr (IS_ALLOCATING) { 1814ff8ba72STom Eccles result.Establish( 1824ff8ba72STom Eccles RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable); 1834ff8ba72STom Eccles for (int j{0}; j < resRank; ++j) { 1844ff8ba72STom Eccles result.GetDimension(j).SetBounds(1, extent[j]); 1854ff8ba72STom Eccles } 1864ff8ba72STom Eccles if (int stat{result.Allocate()}) { 1874ff8ba72STom Eccles terminator.Crash( 188ea7d6a1bSSlava Zakharin "MATMUL-TRANSPOSE: could not allocate memory for result; STAT=%d", 189ea7d6a1bSSlava Zakharin stat); 1904ff8ba72STom Eccles } 1914ff8ba72STom Eccles } else { 1924ff8ba72STom Eccles RUNTIME_CHECK(terminator, resRank == result.rank()); 1934ff8ba72STom Eccles RUNTIME_CHECK( 1944ff8ba72STom Eccles terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND)); 1954ff8ba72STom Eccles RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]); 1964ff8ba72STom Eccles RUNTIME_CHECK(terminator, 1974ff8ba72STom Eccles resRank == 1 || result.GetDimension(1).Extent() == extent[1]); 1984ff8ba72STom Eccles } 1994ff8ba72STom Eccles SubscriptValue n{x.GetDimension(0).Extent()}; 2004ff8ba72STom Eccles if (n != y.GetDimension(0).Extent()) { 201ea7d6a1bSSlava Zakharin terminator.Crash( 202ea7d6a1bSSlava Zakharin "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)", 2034ff8ba72STom Eccles static_cast<std::intmax_t>(x.GetDimension(0).Extent()), 2044ff8ba72STom Eccles static_cast<std::intmax_t>(x.GetDimension(1).Extent()), 2054ff8ba72STom Eccles static_cast<std::intmax_t>(y.GetDimension(0).Extent()), 2064ff8ba72STom Eccles static_cast<std::intmax_t>(y.GetDimension(1).Extent())); 2074ff8ba72STom Eccles } 2084ff8ba72STom Eccles using WriteResult = 2094ff8ba72STom Eccles CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT, 2104ff8ba72STom Eccles RKIND>; 2114ff8ba72STom Eccles const SubscriptValue rows{extent[0]}; 2124ff8ba72STom Eccles const SubscriptValue cols{extent[1]}; 2134ff8ba72STom Eccles if constexpr (RCAT != TypeCategory::Logical) { 2144d977174SSlava Zakharin if (x.IsContiguous(1) && y.IsContiguous(1) && 2154ff8ba72STom Eccles (IS_ALLOCATING || result.IsContiguous())) { 2164d977174SSlava Zakharin // Contiguous numeric matrices (maybe with columns 2174d977174SSlava Zakharin // separated by a stride). 21871e0261fSSlava Zakharin Fortran::common::optional<std::size_t> xColumnByteStride; 2194d977174SSlava Zakharin if (!x.IsContiguous()) { 2204d977174SSlava Zakharin // X's columns are strided. 2214d977174SSlava Zakharin SubscriptValue xAt[2]{}; 2224d977174SSlava Zakharin x.GetLowerBounds(xAt); 2234d977174SSlava Zakharin xAt[1]++; 2244d977174SSlava Zakharin xColumnByteStride = x.SubscriptsToByteOffset(xAt); 2254d977174SSlava Zakharin } 22671e0261fSSlava Zakharin Fortran::common::optional<std::size_t> yColumnByteStride; 2274d977174SSlava Zakharin if (!y.IsContiguous()) { 2284d977174SSlava Zakharin // Y's columns are strided. 2294d977174SSlava Zakharin SubscriptValue yAt[2]{}; 2304d977174SSlava Zakharin y.GetLowerBounds(yAt); 2314d977174SSlava Zakharin yAt[1]++; 2324d977174SSlava Zakharin yColumnByteStride = y.SubscriptsToByteOffset(yAt); 2334d977174SSlava Zakharin } 2344ff8ba72STom Eccles if (resRank == 2) { // M*M -> M 2354d977174SSlava Zakharin // TODO: use BLAS-3 GEMM for supported types. 2364d977174SSlava Zakharin MatrixTransposedTimesMatrixHelper<RCAT, RKIND, XT, YT>( 2374ff8ba72STom Eccles result.template OffsetElement<WriteResult>(), rows, cols, 2384d977174SSlava Zakharin x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride, 2394d977174SSlava Zakharin yColumnByteStride); 2404ff8ba72STom Eccles return; 2414ff8ba72STom Eccles } 2424ff8ba72STom Eccles if (xRank == 2) { // M*V -> V 2434d977174SSlava Zakharin // TODO: use BLAS-2 GEMM for supported types. 2444d977174SSlava Zakharin MatrixTransposedTimesVectorHelper<RCAT, RKIND, XT, YT>( 2454ff8ba72STom Eccles result.template OffsetElement<WriteResult>(), rows, n, 2464d977174SSlava Zakharin x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride); 2474ff8ba72STom Eccles return; 2484ff8ba72STom Eccles } 2494ff8ba72STom Eccles // else V*M -> V (not allowed because TRANSPOSE() is only defined for rank 2504ff8ba72STom Eccles // 1 matrices 251ea7d6a1bSSlava Zakharin terminator.Crash( 252ea7d6a1bSSlava Zakharin "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)", 2534ff8ba72STom Eccles static_cast<std::intmax_t>(x.GetDimension(0).Extent()), 2544ff8ba72STom Eccles static_cast<std::intmax_t>(n), 2554ff8ba72STom Eccles static_cast<std::intmax_t>(y.GetDimension(0).Extent()), 2564ff8ba72STom Eccles static_cast<std::intmax_t>(y.GetDimension(1).Extent())); 2574ff8ba72STom Eccles return; 2584ff8ba72STom Eccles } 2594ff8ba72STom Eccles } 2604ff8ba72STom Eccles // General algorithms for LOGICAL and noncontiguity 2614ff8ba72STom Eccles SubscriptValue xLB[2], yLB[2], resLB[2]; 2624ff8ba72STom Eccles x.GetLowerBounds(xLB); 2634ff8ba72STom Eccles y.GetLowerBounds(yLB); 2644ff8ba72STom Eccles result.GetLowerBounds(resLB); 2654ff8ba72STom Eccles using ResultType = CppTypeFor<RCAT, RKIND>; 2664ff8ba72STom Eccles if (resRank == 2) { // M*M -> M 2674ff8ba72STom Eccles for (SubscriptValue i{0}; i < rows; ++i) { 2684ff8ba72STom Eccles for (SubscriptValue j{0}; j < cols; ++j) { 2694ff8ba72STom Eccles ResultType res_ij; 2704ff8ba72STom Eccles if constexpr (RCAT == TypeCategory::Logical) { 2714ff8ba72STom Eccles res_ij = false; 2724ff8ba72STom Eccles } else { 2734ff8ba72STom Eccles res_ij = 0; 2744ff8ba72STom Eccles } 2754ff8ba72STom Eccles 2764ff8ba72STom Eccles for (SubscriptValue k{0}; k < n; ++k) { 2774ff8ba72STom Eccles SubscriptValue xAt[2]{k + xLB[0], i + xLB[1]}; 2784ff8ba72STom Eccles SubscriptValue yAt[2]{k + yLB[0], j + yLB[1]}; 2794ff8ba72STom Eccles if constexpr (RCAT == TypeCategory::Logical) { 2804ff8ba72STom Eccles ResultType x_ki = IsLogicalElementTrue(x, xAt); 2814ff8ba72STom Eccles ResultType y_kj = IsLogicalElementTrue(y, yAt); 2824ff8ba72STom Eccles res_ij = res_ij || (x_ki && y_kj); 2834ff8ba72STom Eccles } else { 2844ff8ba72STom Eccles ResultType x_ki = static_cast<ResultType>(*x.Element<XT>(xAt)); 2854ff8ba72STom Eccles ResultType y_kj = static_cast<ResultType>(*y.Element<YT>(yAt)); 2864ff8ba72STom Eccles res_ij += x_ki * y_kj; 2874ff8ba72STom Eccles } 2884ff8ba72STom Eccles } 2894ff8ba72STom Eccles SubscriptValue resAt[2]{i + resLB[0], j + resLB[1]}; 2904ff8ba72STom Eccles *result.template Element<WriteResult>(resAt) = res_ij; 2914ff8ba72STom Eccles } 2924ff8ba72STom Eccles } 2934ff8ba72STom Eccles } else if (xRank == 2) { // M*V -> V 2944ff8ba72STom Eccles for (SubscriptValue i{0}; i < rows; ++i) { 2954ff8ba72STom Eccles ResultType res_i; 2964ff8ba72STom Eccles if constexpr (RCAT == TypeCategory::Logical) { 2974ff8ba72STom Eccles res_i = false; 2984ff8ba72STom Eccles } else { 2994ff8ba72STom Eccles res_i = 0; 3004ff8ba72STom Eccles } 3014ff8ba72STom Eccles 3024ff8ba72STom Eccles for (SubscriptValue k{0}; k < n; ++k) { 3034ff8ba72STom Eccles SubscriptValue xAt[2]{k + xLB[0], i + xLB[1]}; 3044ff8ba72STom Eccles SubscriptValue yAt[1]{k + yLB[0]}; 3054ff8ba72STom Eccles if constexpr (RCAT == TypeCategory::Logical) { 3064ff8ba72STom Eccles ResultType x_ki = IsLogicalElementTrue(x, xAt); 3074ff8ba72STom Eccles ResultType y_k = IsLogicalElementTrue(y, yAt); 3084ff8ba72STom Eccles res_i = res_i || (x_ki && y_k); 3094ff8ba72STom Eccles } else { 3104ff8ba72STom Eccles ResultType x_ki = static_cast<ResultType>(*x.Element<XT>(xAt)); 3114ff8ba72STom Eccles ResultType y_k = static_cast<ResultType>(*y.Element<YT>(yAt)); 3124ff8ba72STom Eccles res_i += x_ki * y_k; 3134ff8ba72STom Eccles } 3144ff8ba72STom Eccles } 3154ff8ba72STom Eccles SubscriptValue resAt[1]{i + resLB[0]}; 3164ff8ba72STom Eccles *result.template Element<WriteResult>(resAt) = res_i; 3174ff8ba72STom Eccles } 3184ff8ba72STom Eccles } else { // V*M -> V 3194ff8ba72STom Eccles // TRANSPOSE(V) not allowed by fortran standard 320ea7d6a1bSSlava Zakharin terminator.Crash( 321ea7d6a1bSSlava Zakharin "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)", 3224ff8ba72STom Eccles static_cast<std::intmax_t>(x.GetDimension(0).Extent()), 3234ff8ba72STom Eccles static_cast<std::intmax_t>(n), 3244ff8ba72STom Eccles static_cast<std::intmax_t>(y.GetDimension(0).Extent()), 3254ff8ba72STom Eccles static_cast<std::intmax_t>(y.GetDimension(1).Extent())); 3264ff8ba72STom Eccles } 3274ff8ba72STom Eccles } 3284ff8ba72STom Eccles 329dd220853SSlava Zakharin template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT, 330dd220853SSlava Zakharin int YKIND> 331dd220853SSlava Zakharin struct MatmulTransposeHelper { 332dd220853SSlava Zakharin using ResultDescriptor = 333dd220853SSlava Zakharin std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>; 334dd220853SSlava Zakharin RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x, 335dd220853SSlava Zakharin const Descriptor &y, const char *sourceFile, int line) const { 336dd220853SSlava Zakharin Terminator terminator{sourceFile, line}; 337dd220853SSlava Zakharin auto xCatKind{x.type().GetCategoryAndKind()}; 338dd220853SSlava Zakharin auto yCatKind{y.type().GetCategoryAndKind()}; 339dd220853SSlava Zakharin RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); 340dd220853SSlava Zakharin RUNTIME_CHECK(terminator, xCatKind->first == XCAT); 341dd220853SSlava Zakharin RUNTIME_CHECK(terminator, yCatKind->first == YCAT); 342dd220853SSlava Zakharin if constexpr (constexpr auto resultType{ 343dd220853SSlava Zakharin GetResultType(XCAT, XKIND, YCAT, YKIND)}) { 344dd220853SSlava Zakharin return DoMatmulTranspose<IS_ALLOCATING, resultType->first, 345dd220853SSlava Zakharin resultType->second, CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>( 346dd220853SSlava Zakharin result, x, y, terminator); 347dd220853SSlava Zakharin } 348dd220853SSlava Zakharin terminator.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))", 349dd220853SSlava Zakharin static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND); 350dd220853SSlava Zakharin } 351dd220853SSlava Zakharin }; 3524ff8ba72STom Eccles } // namespace 3534ff8ba72STom Eccles 3544ff8ba72STom Eccles namespace Fortran::runtime { 3554ff8ba72STom Eccles extern "C" { 35676facde3SSlava Zakharin RT_EXT_API_GROUP_BEGIN 35776facde3SSlava Zakharin 358dd220853SSlava Zakharin #define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \ 359dd220853SSlava Zakharin void RTDEF(MatmulTranspose##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \ 360dd220853SSlava Zakharin const Descriptor &x, const Descriptor &y, const char *sourceFile, \ 361dd220853SSlava Zakharin int line) { \ 362dd220853SSlava Zakharin MatmulTransposeHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \ 363dd220853SSlava Zakharin YKIND>{}(result, x, y, sourceFile, line); \ 364dd220853SSlava Zakharin } 365dd220853SSlava Zakharin 366dd220853SSlava Zakharin #define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \ 367dd220853SSlava Zakharin void RTDEF(MatmulTransposeDirect##XCAT##XKIND##YCAT##YKIND)( \ 368dd220853SSlava Zakharin Descriptor & result, const Descriptor &x, const Descriptor &y, \ 369dd220853SSlava Zakharin const char *sourceFile, int line) { \ 370dd220853SSlava Zakharin MatmulTransposeHelper<false, TypeCategory::XCAT, XKIND, \ 371dd220853SSlava Zakharin TypeCategory::YCAT, YKIND>{}(result, x, y, sourceFile, line); \ 372dd220853SSlava Zakharin } 373dd220853SSlava Zakharin 374*8ce1aed5SSlava Zakharin #define MATMUL_FORCE_ALL_TYPES 0 375*8ce1aed5SSlava Zakharin 376dd220853SSlava Zakharin #include "flang/Runtime/matmul-instances.inc" 377dd220853SSlava Zakharin 37876facde3SSlava Zakharin RT_EXT_API_GROUP_END 3794ff8ba72STom Eccles } // extern "C" 3804ff8ba72STom Eccles } // namespace Fortran::runtime 381