//===-- runtime/matmul.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Implements all forms of MATMUL (Fortran 2018 16.9.124)
//
// There are two main entry points; one establishes a descriptor for the
// result and allocates it, and the other expects a result descriptor that
// points to existing storage.
//
// This implementation must handle all combinations of numeric types and
// kinds (100 - 165 cases depending on the target), plus all combinations
// of logical kinds (16). A single template undergoes many instantiations
// to cover all of the valid possibilities.
//
// Places where BLAS routines could be called are marked as TODO items.

#include "flang/Runtime/matmul.h"
#include "terminator.h"
#include "tools.h"
#include "flang/Common/optional.h"
#include "flang/Runtime/c-or-cpp.h"
#include "flang/Runtime/cpp-type.h"
#include "flang/Runtime/descriptor.h"
#include <cstring>

namespace {
using namespace Fortran::runtime;

// Suppress the warnings about calling __host__-only std::complex operators,
// defined in C++ STD header files, from __device__ code.
RT_DIAG_PUSH
RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN

// General accumulator for any type and stride; this is not used for
// contiguous numeric cases.
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
class Accumulator {
public:
  using Result = AccumulationType<RCAT, RKIND>;
  RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y)
      : x_{x}, y_{y} {}
  RT_API_ATTRS void Accumulate(
      const SubscriptValue xAt[], const SubscriptValue yAt[]) {
    if constexpr (RCAT == TypeCategory::Logical) {
      sum_ = sum_ ||
          (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt));
    } else {
      sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) *
          static_cast<Result>(*y_.Element<YT>(yAt));
    }
  }
  RT_API_ATTRS Result GetResult() const { return sum_; }

private:
  const Descriptor &x_, &y_;
  Result sum_{};
};

// Contiguous numeric matrix*matrix multiplication
//   matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols)
// Straightforward algorithm:
//   DO 1 I = 1, NROWS
//    DO 1 J = 1, NCOLS
//     RES(I,J) = 0
//     DO 1 K = 1, N
//   1 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J)
// With loop distribution and transposition to avoid the inner sum
// reduction and to avoid non-unit strides:
//   DO 1 I = 1, NROWS
//    DO 1 J = 1, NCOLS
//   1 RES(I,J) = 0
//   DO 2 K = 1, N
//    DO 2 J = 1, NCOLS
//     DO 2 I = 1, NROWS
//   2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
    bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS>
inline RT_API_ATTRS void MatrixTimesMatrix(
    CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
    SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
    SubscriptValue n, std::size_t xColumnByteStride = 0,
    std::size_t yColumnByteStride = 0) {
  using ResultType = CppTypeFor<RCAT, RKIND>;
  std::memset(product, 0, rows * cols * sizeof *product);
  const XT *RESTRICT xp0{x};
  for (SubscriptValue k{0}; k < n; ++k) {
    ResultType *RESTRICT p{product};
    for (SubscriptValue j{0}; j < cols; ++j) {
      const XT *RESTRICT xp{xp0};
      ResultType yv;
      if constexpr (!Y_HAS_STRIDED_COLUMNS) {
        yv = static_cast<ResultType>(y[k + j * n]);
      } else {
        yv = static_cast<ResultType>(reinterpret_cast<const YT *>(
            reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]);
      }
      for (SubscriptValue i{0}; i < rows; ++i) {
        *p++ += static_cast<ResultType>(*xp++) * yv;
      }
    }
    if constexpr (!X_HAS_STRIDED_COLUMNS) {
      xp0 += rows;
    } else {
      xp0 = reinterpret_cast<const XT *>(
          reinterpret_cast<const char *>(xp0) + xColumnByteStride);
    }
  }
}

RT_DIAG_POP
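
// Dispatches to the MatrixTimesMatrix kernel above, mapping the presence or
// absence of the optional column byte strides onto its template parameters so
// that the fully contiguous case compiles with no stride arithmetic in the
// inner loops.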
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
inline RT_API_ATTRS void MatrixTimesMatrixHelper(
    CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
    SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
    SubscriptValue n, Fortran::common::optional<std::size_t> xColumnByteStride,
    Fortran::common::optional<std::size_t> yColumnByteStride) {
  if (!xColumnByteStride) {
    if (!yColumnByteStride) {
      MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, false>(
          product, rows, cols, x, y, n);
    } else {
      MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, true>(
          product, rows, cols, x, y, n, 0, *yColumnByteStride);
    }
  } else {
    if (!yColumnByteStride) {
      MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, false>(
          product, rows, cols, x, y, n, *xColumnByteStride);
    } else {
      MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, true>(
          product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride);
    }
  }
}

RT_DIAG_PUSH
RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN

// Contiguous numeric matrix*vector multiplication
//   matrix(rows,n) * column vector(n) -> column vector(rows)
// Straightforward algorithm:
//   DO 1 J = 1, NROWS
//    RES(J) = 0
//    DO 1 K = 1, N
//   1 RES(J) = RES(J) + X(J,K)*Y(K)
// With loop distribution and transposition to avoid the inner
// sum reduction and to avoid non-unit strides:
//   DO 1 J = 1, NROWS
//   1 RES(J) = 0
//   DO 2 K = 1, N
//    DO 2 J = 1, NROWS
//   2 RES(J) = RES(J) + X(J,K)*Y(K)
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
    bool X_HAS_STRIDED_COLUMNS>
inline RT_API_ATTRS void MatrixTimesVector(
    CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
    SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
    std::size_t xColumnByteStride = 0) {
  using ResultType = CppTypeFor<RCAT, RKIND>;
  std::memset(product, 0, rows * sizeof *product);
  [[maybe_unused]] const XT *RESTRICT xp0{x};
  for (SubscriptValue k{0}; k < n; ++k) {
    ResultType *RESTRICT p{product};
    auto yv{static_cast<ResultType>(*y++)};
    for (SubscriptValue j{0}; j < rows; ++j) {
      *p++ += static_cast<ResultType>(*x++) * yv;
    }
    if constexpr (X_HAS_STRIDED_COLUMNS) {
      xp0 = reinterpret_cast<const XT *>(
          reinterpret_cast<const char *>(xp0) + xColumnByteStride);
      x = xp0;
    }
  }
}

RT_DIAG_POP

template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
inline RT_API_ATTRS void MatrixTimesVectorHelper(
    CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
    SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
    Fortran::common::optional<std::size_t> xColumnByteStride) {
  if (!xColumnByteStride) {
    MatrixTimesVector<RCAT, RKIND, XT, YT, false>(product, rows, n, x, y);
  } else {
    MatrixTimesVector<RCAT, RKIND, XT, YT, true>(
        product, rows, n, x, y, *xColumnByteStride);
  }
}

RT_DIAG_PUSH
RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN

// Contiguous numeric vector*matrix multiplication
//   row vector(n) * matrix(n,cols) -> row vector(cols)
// Straightforward algorithm:
//   DO 1 J = 1, NCOLS
//    RES(J) = 0
//    DO 1 K = 1, N
//   1 RES(J) = RES(J) + X(K)*Y(K,J)
// With loop distribution and transposition to avoid the inner
// sum reduction and one non-unit stride (the other remains):
//   DO 1 J = 1, NCOLS
//   1 RES(J) = 0
//   DO 2 K = 1, N
//    DO 2 J = 1, NCOLS
//   2 RES(J) = RES(J) + X(K)*Y(K,J)
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
    bool Y_HAS_STRIDED_COLUMNS>
inline RT_API_ATTRS void VectorTimesMatrix(
    CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n,
    SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
    std::size_t yColumnByteStride = 0) {
  using ResultType = CppTypeFor<RCAT, RKIND>;
  std::memset(product, 0, cols * sizeof *product);
  for (SubscriptValue k{0}; k < n; ++k) {
    ResultType *RESTRICT p{product};
    auto xv{static_cast<ResultType>(*x++)};
    const YT *RESTRICT yp{&y[k]};
    for (SubscriptValue j{0}; j < cols; ++j) {
      *p++ += xv * static_cast<ResultType>(*yp);
      if constexpr (!Y_HAS_STRIDED_COLUMNS) {
        yp += n;
      } else {
        yp = reinterpret_cast<const YT *>(
            reinterpret_cast<const char *>(yp) + yColumnByteStride);
      }
    }
  }
}

RT_DIAG_POP

template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
    bool SPARSE_COLUMNS = false>
inline RT_API_ATTRS void VectorTimesMatrixHelper(
    CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n,
    SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
    Fortran::common::optional<std::size_t> yColumnByteStride) {
  if (!yColumnByteStride) {
    VectorTimesMatrix<RCAT, RKIND, XT, YT, false>(product, n, cols, x, y);
  } else {
    VectorTimesMatrix<RCAT, RKIND, XT, YT, true>(
        product, n, cols, x, y, *yColumnByteStride);
  }
}

RT_DIAG_PUSH
RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN

// Implements an instance of MATMUL for given argument types.
template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
    typename YT>
static inline RT_API_ATTRS void DoMatmul(
    std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result,
    const Descriptor &x, const Descriptor &y, Terminator &terminator) {
  int xRank{x.rank()};
  int yRank{y.rank()};
  int resRank{xRank + yRank - 2};
  if (xRank * yRank != 2 * resRank) {
    terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank);
  }
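  // Result shape per 16.9.124: matrix*matrix yields (rows of x, columns of
  // y), matrix*vector yields (rows of x), and vector*matrix yields (columns
  // of y).  A rank-1 result leaves extent[1] unused (zero).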
  SubscriptValue extent[2]{
      xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(),
      resRank == 2 ? y.GetDimension(1).Extent() : 0};
  if constexpr (IS_ALLOCATING) {
    result.Establish(
        RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable);
    for (int j{0}; j < resRank; ++j) {
      result.GetDimension(j).SetBounds(1, extent[j]);
    }
    if (int stat{result.Allocate()}) {
      terminator.Crash(
          "MATMUL: could not allocate memory for result; STAT=%d", stat);
    }
  } else {
    RUNTIME_CHECK(terminator, resRank == result.rank());
    RUNTIME_CHECK(
        terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND));
    RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
    RUNTIME_CHECK(terminator,
        resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
  }
  SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
  if (n != y.GetDimension(0).Extent()) {
    // At this point, we know that there's a shape error.  There are three
    // possibilities: x is rank 1, y is rank 1, or both are rank 2.
    if (xRank == 1) {
      terminator.Crash("MATMUL: unacceptable operand shapes (%jd, %jdx%jd)",
          static_cast<std::intmax_t>(n),
          static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
          static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
    } else if (yRank == 1) {
      terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jd)",
          static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
          static_cast<std::intmax_t>(n),
          static_cast<std::intmax_t>(y.GetDimension(0).Extent()));
    } else {
      terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
          static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
          static_cast<std::intmax_t>(n),
          static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
          static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
    }
  }
  using WriteResult =
      CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
          RKIND>;
  if constexpr (RCAT != TypeCategory::Logical) {
    if (x.IsContiguous(1) && y.IsContiguous(1) &&
        (IS_ALLOCATING || result.IsContiguous())) {
      // Contiguous numeric matrices (maybe with columns
      // separated by a stride).
      Fortran::common::optional<std::size_t> xColumnByteStride;
      if (!x.IsContiguous()) {
        // X's columns are strided.
        SubscriptValue xAt[2]{};
        x.GetLowerBounds(xAt);
        xAt[1]++;
        xColumnByteStride = x.SubscriptsToByteOffset(xAt);
      }
      Fortran::common::optional<std::size_t> yColumnByteStride;
      if (!y.IsContiguous()) {
        // Y's columns are strided.
        SubscriptValue yAt[2]{};
        y.GetLowerBounds(yAt);
        yAt[1]++;
        yColumnByteStride = y.SubscriptsToByteOffset(yAt);
      }
      // Note that BLAS GEMM can be used for the strided
      // columns by setting the proper leading dimension size.
      // This implies that the column stride is divisible
      // by the element size, which is usually true.
      if (resRank == 2) { // M*M -> M
        if (std::is_same_v<XT, YT>) {
          if constexpr (std::is_same_v<XT, float>) {
            // TODO: call BLAS-3 SGEMM
            // TODO: try using CUTLASS for device.
          } else if constexpr (std::is_same_v<XT, double>) {
            // TODO: call BLAS-3 DGEMM
          } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
            // TODO: call BLAS-3 CGEMM
          } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
            // TODO: call BLAS-3 ZGEMM
          }
        }
        MatrixTimesMatrixHelper<RCAT, RKIND, XT, YT>(
            result.template OffsetElement<WriteResult>(), extent[0], extent[1],
            x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride,
            yColumnByteStride);
        return;
      } else if (xRank == 2) { // M*V -> V
        if (std::is_same_v<XT, YT>) {
          if constexpr (std::is_same_v<XT, float>) {
            // TODO: call BLAS-2 SGEMV(x,y)
          } else if constexpr (std::is_same_v<XT, double>) {
            // TODO: call BLAS-2 DGEMV(x,y)
          } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
            // TODO: call BLAS-2 CGEMV(x,y)
          } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
            // TODO: call BLAS-2 ZGEMV(x,y)
          }
        }
        MatrixTimesVectorHelper<RCAT, RKIND, XT, YT>(
            result.template OffsetElement<WriteResult>(), extent[0], n,
            x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride);
        return;
      } else { // V*M -> V
        if (std::is_same_v<XT, YT>) {
          if constexpr (std::is_same_v<XT, float>) {
            // TODO: call BLAS-2 SGEMV(y,x)
          } else if constexpr (std::is_same_v<XT, double>) {
            // TODO: call BLAS-2 DGEMV(y,x)
          } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
            // TODO: call BLAS-2 CGEMV(y,x)
          } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
            // TODO: call BLAS-2 ZGEMV(y,x)
          }
        }
        VectorTimesMatrixHelper<RCAT, RKIND, XT, YT>(
            result.template OffsetElement<WriteResult>(), n, extent[0],
            x.OffsetElement<XT>(), y.OffsetElement<YT>(), yColumnByteStride);
        return;
      }
    }
  }
  // General algorithms for LOGICAL and noncontiguity
  SubscriptValue xAt[2], yAt[2], resAt[2];
  x.GetLowerBounds(xAt);
  y.GetLowerBounds(yAt);
  result.GetLowerBounds(resAt);
  if (resRank == 2) { // M*M -> M
    SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]};
    for (SubscriptValue i{0}; i < extent[0]; ++i) {
      for (SubscriptValue j{0}; j < extent[1]; ++j) {
        Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
        yAt[1] = y1 + j;
        for (SubscriptValue k{0}; k < n; ++k) {
          xAt[1] = x1 + k;
          yAt[0] = y0 + k;
          accumulator.Accumulate(xAt, yAt);
        }
        resAt[1] = res1 + j;
        *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
      }
      ++resAt[0];
      ++xAt[0];
    }
  } else if (xRank == 2) { // M*V -> V
    SubscriptValue x1{xAt[1]}, y0{yAt[0]};
    for (SubscriptValue j{0}; j < extent[0]; ++j) {
      Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
      for (SubscriptValue k{0}; k < n; ++k) {
        xAt[1] = x1 + k;
        yAt[0] = y0 + k;
        accumulator.Accumulate(xAt, yAt);
      }
      *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
      ++resAt[0];
      ++xAt[0];
    }
  } else { // V*M -> V
    SubscriptValue x0{xAt[0]}, y0{yAt[0]};
    for (SubscriptValue j{0}; j < extent[0]; ++j) {
      Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
      for (SubscriptValue k{0}; k < n; ++k) {
        xAt[0] = x0 + k;
        yAt[0] = y0 + k;
        accumulator.Accumulate(xAt, yAt);
      }
      *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
      ++resAt[0];
      ++yAt[1];
    }
  }
}

RT_DIAG_POP
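
// Maps the operands' dynamic type codes onto a DoMatmul instantiation.  The
// operand categories and kinds are template parameters, the result category
// and kind come from GetResultType, and operand descriptors that do not match
// the expected types are reported as fatal errors.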
template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
    int YKIND>
struct MatmulHelper {
  using ResultDescriptor =
      std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
  RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
      const Descriptor &y, const char *sourceFile, int line) const {
    Terminator terminator{sourceFile, line};
    auto xCatKind{x.type().GetCategoryAndKind()};
    auto yCatKind{y.type().GetCategoryAndKind()};
    RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
    RUNTIME_CHECK(terminator, xCatKind->first == XCAT);
    RUNTIME_CHECK(terminator, yCatKind->first == YCAT);
    if constexpr (constexpr auto resultType{
                      GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
      return DoMatmul<IS_ALLOCATING, resultType->first, resultType->second,
          CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>(
          result, x, y, terminator);
    }
    terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
        static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
  }
};
} // namespace

namespace Fortran::runtime {
extern "C" {
RT_EXT_API_GROUP_BEGIN

#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
  void RTDEF(Matmul##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
      const Descriptor &x, const Descriptor &y, const char *sourceFile, \
      int line) { \
    MatmulHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
        YKIND>{}(result, x, y, sourceFile, line); \
  }

#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
  void RTDEF(MatmulDirect##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
      const Descriptor &x, const Descriptor &y, const char *sourceFile, \
      int line) { \
    MatmulHelper<false, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
        YKIND>{}(result, x, y, sourceFile, line); \
  }

#define MATMUL_FORCE_ALL_TYPES 0

#include "flang/Runtime/matmul-instances.inc"

RT_EXT_API_GROUP_END
} // extern "C"
} // namespace Fortran::runtime
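
// Illustrative sketch only (not part of the build): given the macros above,
// an instantiation such as MATMUL_INSTANCE(Real, 4, Real, 4) is expected to
// expand to roughly
//
//   void RTDEF(MatmulReal4Real4)(Descriptor &result, const Descriptor &x,
//       const Descriptor &y, const char *sourceFile, int line) {
//     MatmulHelper<true, TypeCategory::Real, 4, TypeCategory::Real, 4>{}(
//         result, x, y, sourceFile, line);
//   }
//
// i.e. an allocating entry point that lowered code calls with an unallocated
// result descriptor, while the corresponding MATMUL_DIRECT_INSTANCE entry
// point takes a result descriptor that already addresses storage of the
// proper shape.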