1 //===-- runtime/matmul.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 // Implements all forms of MATMUL (Fortran 2018 16.9.124) 10 // 11 // There are two main entry points; one establishes a descriptor for the 12 // result and allocates it, and the other expects a result descriptor that 13 // points to existing storage. 14 // 15 // This implementation must handle all combinations of numeric types and 16 // kinds (100 - 165 cases depending on the target), plus all combinations 17 // of logical kinds (16). A single template undergoes many instantiations 18 // to cover all of the valid possibilities. 19 // 20 // Places where BLAS routines could be called are marked as TODO items. 21 22 #include "flang/Runtime/matmul.h" 23 #include "terminator.h" 24 #include "tools.h" 25 #include "flang/Common/optional.h" 26 #include "flang/Runtime/c-or-cpp.h" 27 #include "flang/Runtime/cpp-type.h" 28 #include "flang/Runtime/descriptor.h" 29 #include <cstring> 30 31 namespace { 32 using namespace Fortran::runtime; 33 34 // General accumulator for any type and stride; this is not used for 35 // contiguous numeric cases. 36 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 37 class Accumulator { 38 public: 39 using Result = AccumulationType<RCAT, RKIND>; 40 RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y) 41 : x_{x}, y_{y} {} 42 RT_API_ATTRS void Accumulate( 43 const SubscriptValue xAt[], const SubscriptValue yAt[]) { 44 if constexpr (RCAT == TypeCategory::Logical) { 45 sum_ = sum_ || 46 (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt)); 47 } else { 48 sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) * 49 static_cast<Result>(*y_.Element<YT>(yAt)); 50 } 51 } 52 RT_API_ATTRS Result GetResult() const { return sum_; } 53 54 private: 55 const Descriptor &x_, &y_; 56 Result sum_{}; 57 }; 58 59 // Contiguous numeric matrix*matrix multiplication 60 // matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols) 61 // Straightforward algorithm: 62 // DO 1 I = 1, NROWS 63 // DO 1 J = 1, NCOLS 64 // RES(I,J) = 0 65 // DO 1 K = 1, N 66 // 1 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) 67 // With loop distribution and transposition to avoid the inner sum 68 // reduction and to avoid non-unit strides: 69 // DO 1 I = 1, NROWS 70 // DO 1 J = 1, NCOLS 71 // 1 RES(I,J) = 0 72 // DO 2 K = 1, N 73 // DO 2 J = 1, NCOLS 74 // DO 2 I = 1, NROWS 75 // 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term 76 template <TypeCategory RCAT, int RKIND, typename XT, typename YT, 77 bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS> 78 inline RT_API_ATTRS void MatrixTimesMatrix( 79 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 80 SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, 81 SubscriptValue n, std::size_t xColumnByteStride = 0, 82 std::size_t yColumnByteStride = 0) { 83 using ResultType = CppTypeFor<RCAT, RKIND>; 84 std::memset(product, 0, rows * cols * sizeof *product); 85 const XT *RESTRICT xp0{x}; 86 for (SubscriptValue k{0}; k < n; ++k) { 87 ResultType *RESTRICT p{product}; 88 for (SubscriptValue j{0}; j < cols; ++j) { 89 const XT *RESTRICT xp{xp0}; 90 ResultType yv; 91 if constexpr (!Y_HAS_STRIDED_COLUMNS) { 92 yv = static_cast<ResultType>(y[k + j * n]); 93 } else { 94 yv = static_cast<ResultType>(reinterpret_cast<const YT *>( 95 reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]); 96 } 97 for (SubscriptValue i{0}; i < rows; ++i) { 98 *p++ += static_cast<ResultType>(*xp++) * yv; 99 } 100 } 101 if constexpr (!X_HAS_STRIDED_COLUMNS) { 102 xp0 += rows; 103 } else { 104 xp0 = reinterpret_cast<const XT *>( 105 reinterpret_cast<const char *>(xp0) + xColumnByteStride); 106 } 107 } 108 } 109 110 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 111 inline RT_API_ATTRS void MatrixTimesMatrixHelper( 112 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 113 SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, 114 SubscriptValue n, Fortran::common::optional<std::size_t> xColumnByteStride, 115 Fortran::common::optional<std::size_t> yColumnByteStride) { 116 if (!xColumnByteStride) { 117 if (!yColumnByteStride) { 118 MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, false>( 119 product, rows, cols, x, y, n); 120 } else { 121 MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, true>( 122 product, rows, cols, x, y, n, 0, *yColumnByteStride); 123 } 124 } else { 125 if (!yColumnByteStride) { 126 MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, false>( 127 product, rows, cols, x, y, n, *xColumnByteStride); 128 } else { 129 MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, true>( 130 product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride); 131 } 132 } 133 } 134 135 // Contiguous numeric matrix*vector multiplication 136 // matrix(rows,n) * column vector(n) -> column vector(rows) 137 // Straightforward algorithm: 138 // DO 1 J = 1, NROWS 139 // RES(J) = 0 140 // DO 1 K = 1, N 141 // 1 RES(J) = RES(J) + X(J,K)*Y(K) 142 // With loop distribution and transposition to avoid the inner 143 // sum reduction and to avoid non-unit strides: 144 // DO 1 J = 1, NROWS 145 // 1 RES(J) = 0 146 // DO 2 K = 1, N 147 // DO 2 J = 1, NROWS 148 // 2 RES(J) = RES(J) + X(J,K)*Y(K) 149 template <TypeCategory RCAT, int RKIND, typename XT, typename YT, 150 bool X_HAS_STRIDED_COLUMNS> 151 inline RT_API_ATTRS void MatrixTimesVector( 152 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 153 SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y, 154 std::size_t xColumnByteStride = 0) { 155 using ResultType = CppTypeFor<RCAT, RKIND>; 156 std::memset(product, 0, rows * sizeof *product); 157 [[maybe_unused]] const XT *RESTRICT xp0{x}; 158 for (SubscriptValue k{0}; k < n; ++k) { 159 ResultType *RESTRICT p{product}; 160 auto yv{static_cast<ResultType>(*y++)}; 161 for (SubscriptValue j{0}; j < rows; ++j) { 162 *p++ += static_cast<ResultType>(*x++) * yv; 163 } 164 if constexpr (X_HAS_STRIDED_COLUMNS) { 165 xp0 = reinterpret_cast<const XT *>( 166 reinterpret_cast<const char *>(xp0) + xColumnByteStride); 167 x = xp0; 168 } 169 } 170 } 171 172 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 173 inline RT_API_ATTRS void MatrixTimesVectorHelper( 174 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 175 SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y, 176 Fortran::common::optional<std::size_t> xColumnByteStride) { 177 if (!xColumnByteStride) { 178 MatrixTimesVector<RCAT, RKIND, XT, YT, false>(product, rows, n, x, y); 179 } else { 180 MatrixTimesVector<RCAT, RKIND, XT, YT, true>( 181 product, rows, n, x, y, *xColumnByteStride); 182 } 183 } 184 185 // Contiguous numeric vector*matrix multiplication 186 // row vector(n) * matrix(n,cols) -> row vector(cols) 187 // Straightforward algorithm: 188 // DO 1 J = 1, NCOLS 189 // RES(J) = 0 190 // DO 1 K = 1, N 191 // 1 RES(J) = RES(J) + X(K)*Y(K,J) 192 // With loop distribution and transposition to avoid the inner 193 // sum reduction and one non-unit stride (the other remains): 194 // DO 1 J = 1, NCOLS 195 // 1 RES(J) = 0 196 // DO 2 K = 1, N 197 // DO 2 J = 1, NCOLS 198 // 2 RES(J) = RES(J) + X(K)*Y(K,J) 199 template <TypeCategory RCAT, int RKIND, typename XT, typename YT, 200 bool Y_HAS_STRIDED_COLUMNS> 201 inline RT_API_ATTRS void VectorTimesMatrix( 202 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n, 203 SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, 204 std::size_t yColumnByteStride = 0) { 205 using ResultType = CppTypeFor<RCAT, RKIND>; 206 std::memset(product, 0, cols * sizeof *product); 207 for (SubscriptValue k{0}; k < n; ++k) { 208 ResultType *RESTRICT p{product}; 209 auto xv{static_cast<ResultType>(*x++)}; 210 const YT *RESTRICT yp{&y[k]}; 211 for (SubscriptValue j{0}; j < cols; ++j) { 212 *p++ += xv * static_cast<ResultType>(*yp); 213 if constexpr (!Y_HAS_STRIDED_COLUMNS) { 214 yp += n; 215 } else { 216 yp = reinterpret_cast<const YT *>( 217 reinterpret_cast<const char *>(yp) + yColumnByteStride); 218 } 219 } 220 } 221 } 222 223 template <TypeCategory RCAT, int RKIND, typename XT, typename YT, 224 bool SPARSE_COLUMNS = false> 225 inline RT_API_ATTRS void VectorTimesMatrixHelper( 226 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n, 227 SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, 228 Fortran::common::optional<std::size_t> yColumnByteStride) { 229 if (!yColumnByteStride) { 230 VectorTimesMatrix<RCAT, RKIND, XT, YT, false>(product, n, cols, x, y); 231 } else { 232 VectorTimesMatrix<RCAT, RKIND, XT, YT, true>( 233 product, n, cols, x, y, *yColumnByteStride); 234 } 235 } 236 237 // Implements an instance of MATMUL for given argument types. 238 template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT, 239 typename YT> 240 static inline RT_API_ATTRS void DoMatmul( 241 std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result, 242 const Descriptor &x, const Descriptor &y, Terminator &terminator) { 243 int xRank{x.rank()}; 244 int yRank{y.rank()}; 245 int resRank{xRank + yRank - 2}; 246 if (xRank * yRank != 2 * resRank) { 247 terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank); 248 } 249 SubscriptValue extent[2]{ 250 xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(), 251 resRank == 2 ? y.GetDimension(1).Extent() : 0}; 252 if constexpr (IS_ALLOCATING) { 253 result.Establish( 254 RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable); 255 for (int j{0}; j < resRank; ++j) { 256 result.GetDimension(j).SetBounds(1, extent[j]); 257 } 258 if (int stat{result.Allocate()}) { 259 terminator.Crash( 260 "MATMUL: could not allocate memory for result; STAT=%d", stat); 261 } 262 } else { 263 RUNTIME_CHECK(terminator, resRank == result.rank()); 264 RUNTIME_CHECK( 265 terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND)); 266 RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]); 267 RUNTIME_CHECK(terminator, 268 resRank == 1 || result.GetDimension(1).Extent() == extent[1]); 269 } 270 SubscriptValue n{x.GetDimension(xRank - 1).Extent()}; 271 if (n != y.GetDimension(0).Extent()) { 272 // At this point, we know that there's a shape error. There are three 273 // possibilities, x is rank 1, y is rank 1, or both are rank 2. 274 if (xRank == 1) { 275 terminator.Crash("MATMUL: unacceptable operand shapes (%jd, %jdx%jd)", 276 static_cast<std::intmax_t>(n), 277 static_cast<std::intmax_t>(y.GetDimension(0).Extent()), 278 static_cast<std::intmax_t>(y.GetDimension(1).Extent())); 279 } else if (yRank == 1) { 280 terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jd)", 281 static_cast<std::intmax_t>(x.GetDimension(0).Extent()), 282 static_cast<std::intmax_t>(n), 283 static_cast<std::intmax_t>(y.GetDimension(0).Extent())); 284 } else { 285 terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)", 286 static_cast<std::intmax_t>(x.GetDimension(0).Extent()), 287 static_cast<std::intmax_t>(n), 288 static_cast<std::intmax_t>(y.GetDimension(0).Extent()), 289 static_cast<std::intmax_t>(y.GetDimension(1).Extent())); 290 } 291 } 292 using WriteResult = 293 CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT, 294 RKIND>; 295 if constexpr (RCAT != TypeCategory::Logical) { 296 if (x.IsContiguous(1) && y.IsContiguous(1) && 297 (IS_ALLOCATING || result.IsContiguous())) { 298 // Contiguous numeric matrices (maybe with columns 299 // separated by a stride). 300 Fortran::common::optional<std::size_t> xColumnByteStride; 301 if (!x.IsContiguous()) { 302 // X's columns are strided. 303 SubscriptValue xAt[2]{}; 304 x.GetLowerBounds(xAt); 305 xAt[1]++; 306 xColumnByteStride = x.SubscriptsToByteOffset(xAt); 307 } 308 Fortran::common::optional<std::size_t> yColumnByteStride; 309 if (!y.IsContiguous()) { 310 // Y's columns are strided. 311 SubscriptValue yAt[2]{}; 312 y.GetLowerBounds(yAt); 313 yAt[1]++; 314 yColumnByteStride = y.SubscriptsToByteOffset(yAt); 315 } 316 // Note that BLAS GEMM can be used for the strided 317 // columns by setting proper leading dimension size. 318 // This implies that the column stride is divisible 319 // by the element size, which is usually true. 320 if (resRank == 2) { // M*M -> M 321 if (std::is_same_v<XT, YT>) { 322 if constexpr (std::is_same_v<XT, float>) { 323 // TODO: call BLAS-3 SGEMM 324 // TODO: try using CUTLASS for device. 325 } else if constexpr (std::is_same_v<XT, double>) { 326 // TODO: call BLAS-3 DGEMM 327 } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) { 328 // TODO: call BLAS-3 CGEMM 329 } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) { 330 // TODO: call BLAS-3 ZGEMM 331 } 332 } 333 MatrixTimesMatrixHelper<RCAT, RKIND, XT, YT>( 334 result.template OffsetElement<WriteResult>(), extent[0], extent[1], 335 x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride, 336 yColumnByteStride); 337 return; 338 } else if (xRank == 2) { // M*V -> V 339 if (std::is_same_v<XT, YT>) { 340 if constexpr (std::is_same_v<XT, float>) { 341 // TODO: call BLAS-2 SGEMV(x,y) 342 } else if constexpr (std::is_same_v<XT, double>) { 343 // TODO: call BLAS-2 DGEMV(x,y) 344 } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) { 345 // TODO: call BLAS-2 CGEMV(x,y) 346 } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) { 347 // TODO: call BLAS-2 ZGEMV(x,y) 348 } 349 } 350 MatrixTimesVectorHelper<RCAT, RKIND, XT, YT>( 351 result.template OffsetElement<WriteResult>(), extent[0], n, 352 x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride); 353 return; 354 } else { // V*M -> V 355 if (std::is_same_v<XT, YT>) { 356 if constexpr (std::is_same_v<XT, float>) { 357 // TODO: call BLAS-2 SGEMV(y,x) 358 } else if constexpr (std::is_same_v<XT, double>) { 359 // TODO: call BLAS-2 DGEMV(y,x) 360 } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) { 361 // TODO: call BLAS-2 CGEMV(y,x) 362 } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) { 363 // TODO: call BLAS-2 ZGEMV(y,x) 364 } 365 } 366 VectorTimesMatrixHelper<RCAT, RKIND, XT, YT>( 367 result.template OffsetElement<WriteResult>(), n, extent[0], 368 x.OffsetElement<XT>(), y.OffsetElement<YT>(), yColumnByteStride); 369 return; 370 } 371 } 372 } 373 // General algorithms for LOGICAL and noncontiguity 374 SubscriptValue xAt[2], yAt[2], resAt[2]; 375 x.GetLowerBounds(xAt); 376 y.GetLowerBounds(yAt); 377 result.GetLowerBounds(resAt); 378 if (resRank == 2) { // M*M -> M 379 SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]}; 380 for (SubscriptValue i{0}; i < extent[0]; ++i) { 381 for (SubscriptValue j{0}; j < extent[1]; ++j) { 382 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 383 yAt[1] = y1 + j; 384 for (SubscriptValue k{0}; k < n; ++k) { 385 xAt[1] = x1 + k; 386 yAt[0] = y0 + k; 387 accumulator.Accumulate(xAt, yAt); 388 } 389 resAt[1] = res1 + j; 390 *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); 391 } 392 ++resAt[0]; 393 ++xAt[0]; 394 } 395 } else if (xRank == 2) { // M*V -> V 396 SubscriptValue x1{xAt[1]}, y0{yAt[0]}; 397 for (SubscriptValue j{0}; j < extent[0]; ++j) { 398 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 399 for (SubscriptValue k{0}; k < n; ++k) { 400 xAt[1] = x1 + k; 401 yAt[0] = y0 + k; 402 accumulator.Accumulate(xAt, yAt); 403 } 404 *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); 405 ++resAt[0]; 406 ++xAt[0]; 407 } 408 } else { // V*M -> V 409 SubscriptValue x0{xAt[0]}, y0{yAt[0]}; 410 for (SubscriptValue j{0}; j < extent[0]; ++j) { 411 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 412 for (SubscriptValue k{0}; k < n; ++k) { 413 xAt[0] = x0 + k; 414 yAt[0] = y0 + k; 415 accumulator.Accumulate(xAt, yAt); 416 } 417 *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); 418 ++resAt[0]; 419 ++yAt[1]; 420 } 421 } 422 } 423 424 template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT, 425 int YKIND> 426 struct MatmulHelper { 427 using ResultDescriptor = 428 std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>; 429 RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x, 430 const Descriptor &y, const char *sourceFile, int line) const { 431 Terminator terminator{sourceFile, line}; 432 auto xCatKind{x.type().GetCategoryAndKind()}; 433 auto yCatKind{y.type().GetCategoryAndKind()}; 434 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); 435 RUNTIME_CHECK(terminator, 436 (xCatKind->first == XCAT && yCatKind->first == YCAT) || 437 (XCAT == TypeCategory::Integer && YCAT == TypeCategory::Integer && 438 ((xCatKind->first == TypeCategory::Integer || 439 xCatKind->first == TypeCategory::Unsigned) && 440 (yCatKind->first == TypeCategory::Integer || 441 yCatKind->first == TypeCategory::Unsigned)))); 442 if constexpr (constexpr auto resultType{ 443 GetResultType(XCAT, XKIND, YCAT, YKIND)}) { 444 return DoMatmul<IS_ALLOCATING, resultType->first, resultType->second, 445 CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>( 446 result, x, y, terminator); 447 } 448 terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))", 449 static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND); 450 } 451 }; 452 } // namespace 453 454 namespace Fortran::runtime { 455 extern "C" { 456 RT_EXT_API_GROUP_BEGIN 457 458 #define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \ 459 void RTDEF(Matmul##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \ 460 const Descriptor &x, const Descriptor &y, const char *sourceFile, \ 461 int line) { \ 462 MatmulHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \ 463 YKIND>{}(result, x, y, sourceFile, line); \ 464 } 465 466 #define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \ 467 void RTDEF(MatmulDirect##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \ 468 const Descriptor &x, const Descriptor &y, const char *sourceFile, \ 469 int line) { \ 470 MatmulHelper<false, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \ 471 YKIND>{}(result, x, y, sourceFile, line); \ 472 } 473 474 #define MATMUL_FORCE_ALL_TYPES 0 475 476 #include "flang/Runtime/matmul-instances.inc" 477 478 RT_EXT_API_GROUP_END 479 } // extern "C" 480 } // namespace Fortran::runtime 481