//===-- runtime/matmul.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Implements all forms of MATMUL (Fortran 2018 16.9.124)
//
// There are two main entry points; one establishes a descriptor for the
// result and allocates it, and the other expects a result descriptor that
// points to existing storage.
//
// This implementation must handle all combinations of numeric types and
// kinds (100 - 165 cases depending on the target), plus all combinations
// of logical kinds (16). A single template undergoes many instantiations
// to cover all of the valid possibilities.
//
// Places where BLAS routines could be called are marked as TODO items.

#include "flang/Runtime/matmul.h"
#include "terminator.h"
#include "tools.h"
#include "flang/Runtime/c-or-cpp.h"
#include "flang/Runtime/cpp-type.h"
#include "flang/Runtime/descriptor.h"
#include <algorithm>
#include <cstring>

namespace Fortran::runtime {

// General accumulator for any type and stride; this is not used for
// contiguous numeric cases.
34 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 35 class Accumulator { 36 public: 37 using Result = AccumulationType<RCAT, RKIND>; 38 Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} 39 void Accumulate(const SubscriptValue xAt[], const SubscriptValue yAt[]) { 40 if constexpr (RCAT == TypeCategory::Logical) { 41 sum_ = sum_ || 42 (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt)); 43 } else { 44 sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) * 45 static_cast<Result>(*y_.Element<YT>(yAt)); 46 } 47 } 48 Result GetResult() const { return sum_; } 49 50 private: 51 const Descriptor &x_, &y_; 52 Result sum_{}; 53 }; 54 55 // Contiguous numeric matrix*matrix multiplication 56 // matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols) 57 // Straightforward algorithm: 58 // DO 1 I = 1, NROWS 59 // DO 1 J = 1, NCOLS 60 // RES(I,J) = 0 61 // DO 1 K = 1, N 62 // 1 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) 63 // With loop distribution and transposition to avoid the inner sum 64 // reduction and to avoid non-unit strides: 65 // DO 1 I = 1, NROWS 66 // DO 1 J = 1, NCOLS 67 // 1 RES(I,J) = 0 68 // DO 2 K = 1, N 69 // DO 2 J = 1, NCOLS 70 // DO 2 I = 1, NROWS 71 // 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! 
loop-invariant last term 72 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 73 inline void MatrixTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product, 74 SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT x, 75 const YT *RESTRICT y, SubscriptValue n) { 76 using ResultType = CppTypeFor<RCAT, RKIND>; 77 std::memset(product, 0, rows * cols * sizeof *product); 78 const XT *RESTRICT xp0{x}; 79 for (SubscriptValue k{0}; k < n; ++k) { 80 ResultType *RESTRICT p{product}; 81 for (SubscriptValue j{0}; j < cols; ++j) { 82 const XT *RESTRICT xp{xp0}; 83 auto yv{static_cast<ResultType>(y[k + j * n])}; 84 for (SubscriptValue i{0}; i < rows; ++i) { 85 *p++ += static_cast<ResultType>(*xp++) * yv; 86 } 87 } 88 xp0 += rows; 89 } 90 } 91 92 // Contiguous numeric matrix*vector multiplication 93 // matrix(rows,n) * column vector(n) -> column vector(rows) 94 // Straightforward algorithm: 95 // DO 1 J = 1, NROWS 96 // RES(J) = 0 97 // DO 1 K = 1, N 98 // 1 RES(J) = RES(J) + X(J,K)*Y(K) 99 // With loop distribution and transposition to avoid the inner 100 // sum reduction and to avoid non-unit strides: 101 // DO 1 J = 1, NROWS 102 // 1 RES(J) = 0 103 // DO 2 K = 1, N 104 // DO 2 J = 1, NROWS 105 // 2 RES(J) = RES(J) + X(J,K)*Y(K) 106 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 107 inline void MatrixTimesVector(CppTypeFor<RCAT, RKIND> *RESTRICT product, 108 SubscriptValue rows, SubscriptValue n, const XT *RESTRICT x, 109 const YT *RESTRICT y) { 110 using ResultType = CppTypeFor<RCAT, RKIND>; 111 std::memset(product, 0, rows * sizeof *product); 112 for (SubscriptValue k{0}; k < n; ++k) { 113 ResultType *RESTRICT p{product}; 114 auto yv{static_cast<ResultType>(*y++)}; 115 for (SubscriptValue j{0}; j < rows; ++j) { 116 *p++ += static_cast<ResultType>(*x++) * yv; 117 } 118 } 119 } 120 121 // Contiguous numeric vector*matrix multiplication 122 // row vector(n) * matrix(n,cols) -> row vector(cols) 123 // Straightforward algorithm: 124 // DO 1 J 
= 1, NCOLS 125 // RES(J) = 0 126 // DO 1 K = 1, N 127 // 1 RES(J) = RES(J) + X(K)*Y(K,J) 128 // With loop distribution and transposition to avoid the inner 129 // sum reduction and one non-unit stride (the other remains): 130 // DO 1 J = 1, NCOLS 131 // 1 RES(J) = 0 132 // DO 2 K = 1, N 133 // DO 2 J = 1, NCOLS 134 // 2 RES(J) = RES(J) + X(K)*Y(K,J) 135 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 136 inline void VectorTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product, 137 SubscriptValue n, SubscriptValue cols, const XT *RESTRICT x, 138 const YT *RESTRICT y) { 139 using ResultType = CppTypeFor<RCAT, RKIND>; 140 std::memset(product, 0, cols * sizeof *product); 141 for (SubscriptValue k{0}; k < n; ++k) { 142 ResultType *RESTRICT p{product}; 143 auto xv{static_cast<ResultType>(*x++)}; 144 const YT *RESTRICT yp{&y[k]}; 145 for (SubscriptValue j{0}; j < cols; ++j) { 146 *p++ += xv * static_cast<ResultType>(*yp); 147 yp += n; 148 } 149 } 150 } 151 152 // Implements an instance of MATMUL for given argument types. 153 template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT, 154 typename YT> 155 static inline void DoMatmul( 156 std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result, 157 const Descriptor &x, const Descriptor &y, Terminator &terminator) { 158 int xRank{x.rank()}; 159 int yRank{y.rank()}; 160 int resRank{xRank + yRank - 2}; 161 if (xRank * yRank != 2 * resRank) { 162 terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank); 163 } 164 SubscriptValue extent[2]{ 165 xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(), 166 resRank == 2 ? 
y.GetDimension(1).Extent() : 0}; 167 if constexpr (IS_ALLOCATING) { 168 result.Establish( 169 RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable); 170 for (int j{0}; j < resRank; ++j) { 171 result.GetDimension(j).SetBounds(1, extent[j]); 172 } 173 if (int stat{result.Allocate()}) { 174 terminator.Crash( 175 "MATMUL: could not allocate memory for result; STAT=%d", stat); 176 } 177 } else { 178 RUNTIME_CHECK(terminator, resRank == result.rank()); 179 RUNTIME_CHECK( 180 terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND)); 181 RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]); 182 RUNTIME_CHECK(terminator, 183 resRank == 1 || result.GetDimension(1).Extent() == extent[1]); 184 } 185 SubscriptValue n{x.GetDimension(xRank - 1).Extent()}; 186 if (n != y.GetDimension(0).Extent()) { 187 terminator.Crash("MATMUL: arrays do not conform (%jd != %jd)", 188 static_cast<std::intmax_t>(n), 189 static_cast<std::intmax_t>(y.GetDimension(0).Extent())); 190 } 191 using WriteResult = 192 CppTypeFor<RCAT == TypeCategory::Logical ? 
TypeCategory::Integer : RCAT, 193 RKIND>; 194 if constexpr (RCAT != TypeCategory::Logical) { 195 if (x.IsContiguous() && y.IsContiguous() && 196 (IS_ALLOCATING || result.IsContiguous())) { 197 // Contiguous numeric matrices 198 if (resRank == 2) { // M*M -> M 199 if (std::is_same_v<XT, YT>) { 200 if constexpr (std::is_same_v<XT, float>) { 201 // TODO: call BLAS-3 SGEMM 202 } else if constexpr (std::is_same_v<XT, double>) { 203 // TODO: call BLAS-3 DGEMM 204 } else if constexpr (std::is_same_v<XT, std::complex<float>>) { 205 // TODO: call BLAS-3 CGEMM 206 } else if constexpr (std::is_same_v<XT, std::complex<double>>) { 207 // TODO: call BLAS-3 ZGEMM 208 } 209 } 210 MatrixTimesMatrix<RCAT, RKIND, XT, YT>( 211 result.template OffsetElement<WriteResult>(), extent[0], extent[1], 212 x.OffsetElement<XT>(), y.OffsetElement<YT>(), n); 213 return; 214 } else if (xRank == 2) { // M*V -> V 215 if (std::is_same_v<XT, YT>) { 216 if constexpr (std::is_same_v<XT, float>) { 217 // TODO: call BLAS-2 SGEMV(x,y) 218 } else if constexpr (std::is_same_v<XT, double>) { 219 // TODO: call BLAS-2 DGEMV(x,y) 220 } else if constexpr (std::is_same_v<XT, std::complex<float>>) { 221 // TODO: call BLAS-2 CGEMV(x,y) 222 } else if constexpr (std::is_same_v<XT, std::complex<double>>) { 223 // TODO: call BLAS-2 ZGEMV(x,y) 224 } 225 } 226 MatrixTimesVector<RCAT, RKIND, XT, YT>( 227 result.template OffsetElement<WriteResult>(), extent[0], n, 228 x.OffsetElement<XT>(), y.OffsetElement<YT>()); 229 return; 230 } else { // V*M -> V 231 if (std::is_same_v<XT, YT>) { 232 if constexpr (std::is_same_v<XT, float>) { 233 // TODO: call BLAS-2 SGEMV(y,x) 234 } else if constexpr (std::is_same_v<XT, double>) { 235 // TODO: call BLAS-2 DGEMV(y,x) 236 } else if constexpr (std::is_same_v<XT, std::complex<float>>) { 237 // TODO: call BLAS-2 CGEMV(y,x) 238 } else if constexpr (std::is_same_v<XT, std::complex<double>>) { 239 // TODO: call BLAS-2 ZGEMV(y,x) 240 } 241 } 242 VectorTimesMatrix<RCAT, RKIND, XT, YT>( 243 
result.template OffsetElement<WriteResult>(), n, extent[0], 244 x.OffsetElement<XT>(), y.OffsetElement<YT>()); 245 return; 246 } 247 } 248 } 249 // General algorithms for LOGICAL and noncontiguity 250 SubscriptValue xAt[2], yAt[2], resAt[2]; 251 x.GetLowerBounds(xAt); 252 y.GetLowerBounds(yAt); 253 result.GetLowerBounds(resAt); 254 if (resRank == 2) { // M*M -> M 255 SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]}; 256 for (SubscriptValue i{0}; i < extent[0]; ++i) { 257 for (SubscriptValue j{0}; j < extent[1]; ++j) { 258 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 259 yAt[1] = y1 + j; 260 for (SubscriptValue k{0}; k < n; ++k) { 261 xAt[1] = x1 + k; 262 yAt[0] = y0 + k; 263 accumulator.Accumulate(xAt, yAt); 264 } 265 resAt[1] = res1 + j; 266 *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); 267 } 268 ++resAt[0]; 269 ++xAt[0]; 270 } 271 } else if (xRank == 2) { // M*V -> V 272 SubscriptValue x1{xAt[1]}, y0{yAt[0]}; 273 for (SubscriptValue j{0}; j < extent[0]; ++j) { 274 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 275 for (SubscriptValue k{0}; k < n; ++k) { 276 xAt[1] = x1 + k; 277 yAt[0] = y0 + k; 278 accumulator.Accumulate(xAt, yAt); 279 } 280 *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); 281 ++resAt[0]; 282 ++xAt[0]; 283 } 284 } else { // V*M -> V 285 SubscriptValue x0{xAt[0]}, y0{yAt[0]}; 286 for (SubscriptValue j{0}; j < extent[0]; ++j) { 287 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 288 for (SubscriptValue k{0}; k < n; ++k) { 289 xAt[0] = x0 + k; 290 yAt[0] = y0 + k; 291 accumulator.Accumulate(xAt, yAt); 292 } 293 *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); 294 ++resAt[0]; 295 ++yAt[1]; 296 } 297 } 298 } 299 300 // Maps the dynamic type information from the arguments' descriptors 301 // to the right instantiation of DoMatmul() for valid combinations of 302 // types. 
303 template <bool IS_ALLOCATING> struct Matmul { 304 using ResultDescriptor = 305 std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>; 306 template <TypeCategory XCAT, int XKIND> struct MM1 { 307 template <TypeCategory YCAT, int YKIND> struct MM2 { 308 void operator()(ResultDescriptor &result, const Descriptor &x, 309 const Descriptor &y, Terminator &terminator) const { 310 if constexpr (constexpr auto resultType{ 311 GetResultType(XCAT, XKIND, YCAT, YKIND)}) { 312 if constexpr (common::IsNumericTypeCategory(resultType->first) || 313 resultType->first == TypeCategory::Logical) { 314 return DoMatmul<IS_ALLOCATING, resultType->first, 315 resultType->second, CppTypeFor<XCAT, XKIND>, 316 CppTypeFor<YCAT, YKIND>>(result, x, y, terminator); 317 } 318 } 319 terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))", 320 static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND); 321 } 322 }; 323 void operator()(ResultDescriptor &result, const Descriptor &x, 324 const Descriptor &y, Terminator &terminator, TypeCategory yCat, 325 int yKind) const { 326 ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator); 327 } 328 }; 329 void operator()(ResultDescriptor &result, const Descriptor &x, 330 const Descriptor &y, const char *sourceFile, int line) const { 331 Terminator terminator{sourceFile, line}; 332 auto xCatKind{x.type().GetCategoryAndKind()}; 333 auto yCatKind{y.type().GetCategoryAndKind()}; 334 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); 335 ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result, 336 x, y, terminator, yCatKind->first, yCatKind->second); 337 } 338 }; 339 340 extern "C" { 341 void RTNAME(Matmul)(Descriptor &result, const Descriptor &x, 342 const Descriptor &y, const char *sourceFile, int line) { 343 Matmul<true>{}(result, x, y, sourceFile, line); 344 } 345 void RTNAME(MatmulDirect)(const Descriptor &result, const Descriptor &x, 346 const Descriptor &y, const char 
*sourceFile, int line) { 347 Matmul<false>{}(result, x, y, sourceFile, line); 348 } 349 } // extern "C" 350 } // namespace Fortran::runtime 351