Lines Matching +full:matrix +full:- +full:matrix +full:- +full:transpose
1 //===-- runtime/matmul-transpose.cpp --------------------------------------===//
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Implements a fused matmul-transpose operation
16 // kinds (100 - 165 cases depending on the target), plus all combinations
23 #include "flang/Runtime/matmul-transpose.h"
27 #include "flang/Runtime/c-or-cpp.h"
28 #include "flang/Runtime/cpp-type.h"
35 // Contiguous numeric TRANSPOSE(matrix)*matrix multiplication
36 // TRANSPOSE(matrix(n, rows)) * matrix(n,cols) ->
37 // matrix(rows, n) * matrix(n,cols) -> matrix(rows,cols)
38 // The transpose is implemented by swapping the indices of accesses into the LHS
48 // reduction and to avoid non-unit strides:
55 // 2 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) ! loop-invariant last term
114 // Contiguous numeric matrix*vector multiplication
115 // matrix(rows,n) * column vector(n) -> column vector(rows)
122 // sum reduction and to avoid non-unit strides:
173 int resRank{xRank + yRank - 2};
176 "MATMUL-TRANSPOSE: bad argument ranks (%d * %d)", xRank, yRank);
188 "MATMUL-TRANSPOSE: could not allocate memory for result; STAT=%d",
202 "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
234 if (resRank == 2) { // M*M -> M
235 // TODO: use BLAS-3 GEMM for supported types.
242 if (xRank == 2) { // M*V -> V
243 // TODO: use BLAS-2 GEMV for supported types.
249 // else V*M -> V (not allowed because TRANSPOSE() is only defined for rank
252 "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
266 if (resRank == 2) { // M*M -> M
293 } else if (xRank == 2) { // M*V -> V
318 } else { // V*M -> V
319 // TRANSPOSE(V) is not allowed by the Fortran standard
321 "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
340 RUNTIME_CHECK(terminator, xCatKind->first == XCAT);
341 RUNTIME_CHECK(terminator, yCatKind->first == YCAT);
344 return DoMatmulTranspose<IS_ALLOCATING, resultType->first,
345 resultType->second, CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>(
348 terminator.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))",
376 #include "flang/Runtime/matmul-instances.inc"