//===-- runtime/matmul-transpose.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Implements a fused matmul-transpose operation, MATMUL(TRANSPOSE(X), Y).
//
// There are two main entry points; one establishes a descriptor for the
// result and allocates it, and the other expects a result descriptor that
// points to existing storage.
//
// This implementation must handle all combinations of numeric types and
// kinds (100 - 165 cases depending on the target), plus all combinations
// of logical kinds (16). A single template undergoes many instantiations
// to cover all of the valid possibilities.
//
// The usefulness of this optimization should be reviewed once MATMUL is
// switched to use the faster BLAS routines.
22 23 #include "flang/Runtime/matmul-transpose.h" 24 #include "terminator.h" 25 #include "tools.h" 26 #include "flang/Common/optional.h" 27 #include "flang/Runtime/c-or-cpp.h" 28 #include "flang/Runtime/cpp-type.h" 29 #include "flang/Runtime/descriptor.h" 30 #include <cstring> 31 32 namespace { 33 using namespace Fortran::runtime; 34 35 // Contiguous numeric TRANSPOSE(matrix)*matrix multiplication 36 // TRANSPOSE(matrix(n, rows)) * matrix(n,cols) -> 37 // matrix(rows, n) * matrix(n,cols) -> matrix(rows,cols) 38 // The transpose is implemented by swapping the indices of accesses into the LHS 39 // 40 // Straightforward algorithm: 41 // DO 1 I = 1, NROWS 42 // DO 1 J = 1, NCOLS 43 // RES(I,J) = 0 44 // DO 1 K = 1, N 45 // 1 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) 46 // 47 // With loop distribution and transposition to avoid the inner sum 48 // reduction and to avoid non-unit strides: 49 // DO 1 I = 1, NROWS 50 // DO 1 J = 1, NCOLS 51 // 1 RES(I,J) = 0 52 // DO 2 J = 1, NCOLS 53 // DO 2 I = 1, NROWS 54 // DO 2 K = 1, N 55 // 2 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) ! 
loop-invariant last term 56 template <TypeCategory RCAT, int RKIND, typename XT, typename YT, 57 bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS> 58 inline static RT_API_ATTRS void MatrixTransposedTimesMatrix( 59 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 60 SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, 61 SubscriptValue n, std::size_t xColumnByteStride = 0, 62 std::size_t yColumnByteStride = 0) { 63 using ResultType = CppTypeFor<RCAT, RKIND>; 64 65 std::memset(product, 0, rows * cols * sizeof *product); 66 for (SubscriptValue j{0}; j < cols; ++j) { 67 for (SubscriptValue i{0}; i < rows; ++i) { 68 for (SubscriptValue k{0}; k < n; ++k) { 69 ResultType x_ki; 70 if constexpr (!X_HAS_STRIDED_COLUMNS) { 71 x_ki = static_cast<ResultType>(x[i * n + k]); 72 } else { 73 x_ki = static_cast<ResultType>(reinterpret_cast<const XT *>( 74 reinterpret_cast<const char *>(x) + i * xColumnByteStride)[k]); 75 } 76 ResultType y_kj; 77 if constexpr (!Y_HAS_STRIDED_COLUMNS) { 78 y_kj = static_cast<ResultType>(y[j * n + k]); 79 } else { 80 y_kj = static_cast<ResultType>(reinterpret_cast<const YT *>( 81 reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]); 82 } 83 product[j * rows + i] += x_ki * y_kj; 84 } 85 } 86 } 87 } 88 89 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 90 inline static RT_API_ATTRS void MatrixTransposedTimesMatrixHelper( 91 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 92 SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, 93 SubscriptValue n, Fortran::common::optional<std::size_t> xColumnByteStride, 94 Fortran::common::optional<std::size_t> yColumnByteStride) { 95 if (!xColumnByteStride) { 96 if (!yColumnByteStride) { 97 MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, false, false>( 98 product, rows, cols, x, y, n); 99 } else { 100 MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, false, true>( 101 product, rows, cols, x, y, n, 0, *yColumnByteStride); 102 } 103 } 
else { 104 if (!yColumnByteStride) { 105 MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, true, false>( 106 product, rows, cols, x, y, n, *xColumnByteStride); 107 } else { 108 MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, true, true>( 109 product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride); 110 } 111 } 112 } 113 114 // Contiguous numeric matrix*vector multiplication 115 // matrix(rows,n) * column vector(n) -> column vector(rows) 116 // Straightforward algorithm: 117 // DO 1 I = 1, NROWS 118 // RES(I) = 0 119 // DO 1 K = 1, N 120 // 1 RES(I) = RES(I) + X(K,I)*Y(K) 121 // With loop distribution and transposition to avoid the inner 122 // sum reduction and to avoid non-unit strides: 123 // DO 1 I = 1, NROWS 124 // 1 RES(I) = 0 125 // DO 2 I = 1, NROWS 126 // DO 2 K = 1, N 127 // 2 RES(I) = RES(I) + X(K,I)*Y(K) 128 template <TypeCategory RCAT, int RKIND, typename XT, typename YT, 129 bool X_HAS_STRIDED_COLUMNS> 130 inline static RT_API_ATTRS void MatrixTransposedTimesVector( 131 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 132 SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y, 133 std::size_t xColumnByteStride = 0) { 134 using ResultType = CppTypeFor<RCAT, RKIND>; 135 std::memset(product, 0, rows * sizeof *product); 136 for (SubscriptValue i{0}; i < rows; ++i) { 137 for (SubscriptValue k{0}; k < n; ++k) { 138 ResultType x_ki; 139 if constexpr (!X_HAS_STRIDED_COLUMNS) { 140 x_ki = static_cast<ResultType>(x[i * n + k]); 141 } else { 142 x_ki = static_cast<ResultType>(reinterpret_cast<const XT *>( 143 reinterpret_cast<const char *>(x) + i * xColumnByteStride)[k]); 144 } 145 ResultType y_k = static_cast<ResultType>(y[k]); 146 product[i] += x_ki * y_k; 147 } 148 } 149 } 150 151 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 152 inline static RT_API_ATTRS void MatrixTransposedTimesVectorHelper( 153 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, 154 SubscriptValue n, const XT *RESTRICT x, 
const YT *RESTRICT y, 155 Fortran::common::optional<std::size_t> xColumnByteStride) { 156 if (!xColumnByteStride) { 157 MatrixTransposedTimesVector<RCAT, RKIND, XT, YT, false>( 158 product, rows, n, x, y); 159 } else { 160 MatrixTransposedTimesVector<RCAT, RKIND, XT, YT, true>( 161 product, rows, n, x, y, *xColumnByteStride); 162 } 163 } 164 165 // Implements an instance of MATMUL for given argument types. 166 template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT, 167 typename YT> 168 inline static RT_API_ATTRS void DoMatmulTranspose( 169 std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result, 170 const Descriptor &x, const Descriptor &y, Terminator &terminator) { 171 int xRank{x.rank()}; 172 int yRank{y.rank()}; 173 int resRank{xRank + yRank - 2}; 174 if (xRank * yRank != 2 * resRank) { 175 terminator.Crash( 176 "MATMUL-TRANSPOSE: bad argument ranks (%d * %d)", xRank, yRank); 177 } 178 SubscriptValue extent[2]{x.GetDimension(1).Extent(), 179 resRank == 2 ? 
y.GetDimension(1).Extent() : 0}; 180 if constexpr (IS_ALLOCATING) { 181 result.Establish( 182 RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable); 183 for (int j{0}; j < resRank; ++j) { 184 result.GetDimension(j).SetBounds(1, extent[j]); 185 } 186 if (int stat{result.Allocate()}) { 187 terminator.Crash( 188 "MATMUL-TRANSPOSE: could not allocate memory for result; STAT=%d", 189 stat); 190 } 191 } else { 192 RUNTIME_CHECK(terminator, resRank == result.rank()); 193 RUNTIME_CHECK( 194 terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND)); 195 RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]); 196 RUNTIME_CHECK(terminator, 197 resRank == 1 || result.GetDimension(1).Extent() == extent[1]); 198 } 199 SubscriptValue n{x.GetDimension(0).Extent()}; 200 if (n != y.GetDimension(0).Extent()) { 201 terminator.Crash( 202 "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)", 203 static_cast<std::intmax_t>(x.GetDimension(0).Extent()), 204 static_cast<std::intmax_t>(x.GetDimension(1).Extent()), 205 static_cast<std::intmax_t>(y.GetDimension(0).Extent()), 206 static_cast<std::intmax_t>(y.GetDimension(1).Extent())); 207 } 208 using WriteResult = 209 CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT, 210 RKIND>; 211 const SubscriptValue rows{extent[0]}; 212 const SubscriptValue cols{extent[1]}; 213 if constexpr (RCAT != TypeCategory::Logical) { 214 if (x.IsContiguous(1) && y.IsContiguous(1) && 215 (IS_ALLOCATING || result.IsContiguous())) { 216 // Contiguous numeric matrices (maybe with columns 217 // separated by a stride). 218 Fortran::common::optional<std::size_t> xColumnByteStride; 219 if (!x.IsContiguous()) { 220 // X's columns are strided. 221 SubscriptValue xAt[2]{}; 222 x.GetLowerBounds(xAt); 223 xAt[1]++; 224 xColumnByteStride = x.SubscriptsToByteOffset(xAt); 225 } 226 Fortran::common::optional<std::size_t> yColumnByteStride; 227 if (!y.IsContiguous()) { 228 // Y's columns are strided. 
229 SubscriptValue yAt[2]{}; 230 y.GetLowerBounds(yAt); 231 yAt[1]++; 232 yColumnByteStride = y.SubscriptsToByteOffset(yAt); 233 } 234 if (resRank == 2) { // M*M -> M 235 // TODO: use BLAS-3 GEMM for supported types. 236 MatrixTransposedTimesMatrixHelper<RCAT, RKIND, XT, YT>( 237 result.template OffsetElement<WriteResult>(), rows, cols, 238 x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride, 239 yColumnByteStride); 240 return; 241 } 242 if (xRank == 2) { // M*V -> V 243 // TODO: use BLAS-2 GEMM for supported types. 244 MatrixTransposedTimesVectorHelper<RCAT, RKIND, XT, YT>( 245 result.template OffsetElement<WriteResult>(), rows, n, 246 x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride); 247 return; 248 } 249 // else V*M -> V (not allowed because TRANSPOSE() is only defined for rank 250 // 1 matrices 251 terminator.Crash( 252 "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)", 253 static_cast<std::intmax_t>(x.GetDimension(0).Extent()), 254 static_cast<std::intmax_t>(n), 255 static_cast<std::intmax_t>(y.GetDimension(0).Extent()), 256 static_cast<std::intmax_t>(y.GetDimension(1).Extent())); 257 return; 258 } 259 } 260 // General algorithms for LOGICAL and noncontiguity 261 SubscriptValue xLB[2], yLB[2], resLB[2]; 262 x.GetLowerBounds(xLB); 263 y.GetLowerBounds(yLB); 264 result.GetLowerBounds(resLB); 265 using ResultType = CppTypeFor<RCAT, RKIND>; 266 if (resRank == 2) { // M*M -> M 267 for (SubscriptValue i{0}; i < rows; ++i) { 268 for (SubscriptValue j{0}; j < cols; ++j) { 269 ResultType res_ij; 270 if constexpr (RCAT == TypeCategory::Logical) { 271 res_ij = false; 272 } else { 273 res_ij = 0; 274 } 275 276 for (SubscriptValue k{0}; k < n; ++k) { 277 SubscriptValue xAt[2]{k + xLB[0], i + xLB[1]}; 278 SubscriptValue yAt[2]{k + yLB[0], j + yLB[1]}; 279 if constexpr (RCAT == TypeCategory::Logical) { 280 ResultType x_ki = IsLogicalElementTrue(x, xAt); 281 ResultType y_kj = IsLogicalElementTrue(y, yAt); 282 res_ij = res_ij || 
(x_ki && y_kj); 283 } else { 284 ResultType x_ki = static_cast<ResultType>(*x.Element<XT>(xAt)); 285 ResultType y_kj = static_cast<ResultType>(*y.Element<YT>(yAt)); 286 res_ij += x_ki * y_kj; 287 } 288 } 289 SubscriptValue resAt[2]{i + resLB[0], j + resLB[1]}; 290 *result.template Element<WriteResult>(resAt) = res_ij; 291 } 292 } 293 } else if (xRank == 2) { // M*V -> V 294 for (SubscriptValue i{0}; i < rows; ++i) { 295 ResultType res_i; 296 if constexpr (RCAT == TypeCategory::Logical) { 297 res_i = false; 298 } else { 299 res_i = 0; 300 } 301 302 for (SubscriptValue k{0}; k < n; ++k) { 303 SubscriptValue xAt[2]{k + xLB[0], i + xLB[1]}; 304 SubscriptValue yAt[1]{k + yLB[0]}; 305 if constexpr (RCAT == TypeCategory::Logical) { 306 ResultType x_ki = IsLogicalElementTrue(x, xAt); 307 ResultType y_k = IsLogicalElementTrue(y, yAt); 308 res_i = res_i || (x_ki && y_k); 309 } else { 310 ResultType x_ki = static_cast<ResultType>(*x.Element<XT>(xAt)); 311 ResultType y_k = static_cast<ResultType>(*y.Element<YT>(yAt)); 312 res_i += x_ki * y_k; 313 } 314 } 315 SubscriptValue resAt[1]{i + resLB[0]}; 316 *result.template Element<WriteResult>(resAt) = res_i; 317 } 318 } else { // V*M -> V 319 // TRANSPOSE(V) not allowed by fortran standard 320 terminator.Crash( 321 "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)", 322 static_cast<std::intmax_t>(x.GetDimension(0).Extent()), 323 static_cast<std::intmax_t>(n), 324 static_cast<std::intmax_t>(y.GetDimension(0).Extent()), 325 static_cast<std::intmax_t>(y.GetDimension(1).Extent())); 326 } 327 } 328 329 template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT, 330 int YKIND> 331 struct MatmulTransposeHelper { 332 using ResultDescriptor = 333 std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>; 334 RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x, 335 const Descriptor &y, const char *sourceFile, int line) const { 336 Terminator terminator{sourceFile, line}; 337 
auto xCatKind{x.type().GetCategoryAndKind()}; 338 auto yCatKind{y.type().GetCategoryAndKind()}; 339 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); 340 RUNTIME_CHECK(terminator, xCatKind->first == XCAT); 341 RUNTIME_CHECK(terminator, yCatKind->first == YCAT); 342 if constexpr (constexpr auto resultType{ 343 GetResultType(XCAT, XKIND, YCAT, YKIND)}) { 344 return DoMatmulTranspose<IS_ALLOCATING, resultType->first, 345 resultType->second, CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>( 346 result, x, y, terminator); 347 } 348 terminator.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))", 349 static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND); 350 } 351 }; 352 } // namespace 353 354 namespace Fortran::runtime { 355 extern "C" { 356 RT_EXT_API_GROUP_BEGIN 357 358 #define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \ 359 void RTDEF(MatmulTranspose##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \ 360 const Descriptor &x, const Descriptor &y, const char *sourceFile, \ 361 int line) { \ 362 MatmulTransposeHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \ 363 YKIND>{}(result, x, y, sourceFile, line); \ 364 } 365 366 #define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \ 367 void RTDEF(MatmulTransposeDirect##XCAT##XKIND##YCAT##YKIND)( \ 368 Descriptor & result, const Descriptor &x, const Descriptor &y, \ 369 const char *sourceFile, int line) { \ 370 MatmulTransposeHelper<false, TypeCategory::XCAT, XKIND, \ 371 TypeCategory::YCAT, YKIND>{}(result, x, y, sourceFile, line); \ 372 } 373 374 #define MATMUL_FORCE_ALL_TYPES 0 375 376 #include "flang/Runtime/matmul-instances.inc" 377 378 RT_EXT_API_GROUP_END 379 } // extern "C" 380 } // namespace Fortran::runtime 381