xref: /llvm-project/flang/runtime/matmul.cpp (revision a5a493e1920572ca78b77fcdaef68c26e96d25e7)
1 //===-- runtime/matmul.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // Implements all forms of MATMUL (Fortran 2018 16.9.124)
10 //
11 // There are two main entry points; one establishes a descriptor for the
12 // result and allocates it, and the other expects a result descriptor that
13 // points to existing storage.
14 //
15 // This implementation must handle all combinations of numeric types and
16 // kinds (100 - 165 cases depending on the target), plus all combinations
17 // of logical kinds (16).  A single template undergoes many instantiations
18 // to cover all of the valid possibilities.
19 //
20 // Places where BLAS routines could be called are marked as TODO items.
21 
22 #include "flang/Runtime/matmul.h"
23 #include "terminator.h"
24 #include "tools.h"
25 #include "flang/Runtime/c-or-cpp.h"
26 #include "flang/Runtime/cpp-type.h"
27 #include "flang/Runtime/descriptor.h"
28 #include <cstring>
29 
30 namespace Fortran::runtime {
31 
32 // General accumulator for any type and stride; this is not used for
33 // contiguous numeric cases.
34 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
35 class Accumulator {
36 public:
37   using Result = AccumulationType<RCAT, RKIND>;
38   Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
39   void Accumulate(const SubscriptValue xAt[], const SubscriptValue yAt[]) {
40     if constexpr (RCAT == TypeCategory::Logical) {
41       sum_ = sum_ ||
42           (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt));
43     } else {
44       sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) *
45           static_cast<Result>(*y_.Element<YT>(yAt));
46     }
47   }
48   Result GetResult() const { return sum_; }
49 
50 private:
51   const Descriptor &x_, &y_;
52   Result sum_{};
53 };
54 
55 // Contiguous numeric matrix*matrix multiplication
56 //   matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols)
57 // Straightforward algorithm:
58 //   DO 1 I = 1, NROWS
59 //    DO 1 J = 1, NCOLS
60 //     RES(I,J) = 0
61 //     DO 1 K = 1, N
62 //   1  RES(I,J) = RES(I,J) + X(I,K)*Y(K,J)
63 // With loop distribution and transposition to avoid the inner sum
64 // reduction and to avoid non-unit strides:
65 //   DO 1 I = 1, NROWS
66 //    DO 1 J = 1, NCOLS
67 //   1 RES(I,J) = 0
68 //   DO 2 K = 1, N
69 //    DO 2 J = 1, NCOLS
70 //     DO 2 I = 1, NROWS
71 //   2  RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term
72 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
73 inline void MatrixTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
74     SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT x,
75     const YT *RESTRICT y, SubscriptValue n) {
76   using ResultType = CppTypeFor<RCAT, RKIND>;
77   std::memset(product, 0, rows * cols * sizeof *product);
78   const XT *RESTRICT xp0{x};
79   for (SubscriptValue k{0}; k < n; ++k) {
80     ResultType *RESTRICT p{product};
81     for (SubscriptValue j{0}; j < cols; ++j) {
82       const XT *RESTRICT xp{xp0};
83       auto yv{static_cast<ResultType>(y[k + j * n])};
84       for (SubscriptValue i{0}; i < rows; ++i) {
85         *p++ += static_cast<ResultType>(*xp++) * yv;
86       }
87     }
88     xp0 += rows;
89   }
90 }
91 
92 // Contiguous numeric matrix*vector multiplication
93 //   matrix(rows,n) * column vector(n) -> column vector(rows)
94 // Straightforward algorithm:
95 //   DO 1 J = 1, NROWS
96 //    RES(J) = 0
97 //    DO 1 K = 1, N
98 //   1 RES(J) = RES(J) + X(J,K)*Y(K)
99 // With loop distribution and transposition to avoid the inner
100 // sum reduction and to avoid non-unit strides:
101 //   DO 1 J = 1, NROWS
102 //   1 RES(J) = 0
103 //   DO 2 K = 1, N
104 //    DO 2 J = 1, NROWS
105 //   2 RES(J) = RES(J) + X(J,K)*Y(K)
106 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
107 inline void MatrixTimesVector(CppTypeFor<RCAT, RKIND> *RESTRICT product,
108     SubscriptValue rows, SubscriptValue n, const XT *RESTRICT x,
109     const YT *RESTRICT y) {
110   using ResultType = CppTypeFor<RCAT, RKIND>;
111   std::memset(product, 0, rows * sizeof *product);
112   for (SubscriptValue k{0}; k < n; ++k) {
113     ResultType *RESTRICT p{product};
114     auto yv{static_cast<ResultType>(*y++)};
115     for (SubscriptValue j{0}; j < rows; ++j) {
116       *p++ += static_cast<ResultType>(*x++) * yv;
117     }
118   }
119 }
120 
121 // Contiguous numeric vector*matrix multiplication
122 //   row vector(n) * matrix(n,cols) -> row vector(cols)
123 // Straightforward algorithm:
124 //   DO 1 J = 1, NCOLS
125 //    RES(J) = 0
126 //    DO 1 K = 1, N
127 //   1 RES(J) = RES(J) + X(K)*Y(K,J)
128 // With loop distribution and transposition to avoid the inner
129 // sum reduction and one non-unit stride (the other remains):
130 //   DO 1 J = 1, NCOLS
131 //   1 RES(J) = 0
132 //   DO 2 K = 1, N
133 //    DO 2 J = 1, NCOLS
134 //   2 RES(J) = RES(J) + X(K)*Y(K,J)
135 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
136 inline void VectorTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
137     SubscriptValue n, SubscriptValue cols, const XT *RESTRICT x,
138     const YT *RESTRICT y) {
139   using ResultType = CppTypeFor<RCAT, RKIND>;
140   std::memset(product, 0, cols * sizeof *product);
141   for (SubscriptValue k{0}; k < n; ++k) {
142     ResultType *RESTRICT p{product};
143     auto xv{static_cast<ResultType>(*x++)};
144     const YT *RESTRICT yp{&y[k]};
145     for (SubscriptValue j{0}; j < cols; ++j) {
146       *p++ += xv * static_cast<ResultType>(*yp);
147       yp += n;
148     }
149   }
150 }
151 
152 // Implements an instance of MATMUL for given argument types.
153 template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
154     typename YT>
155 static inline void DoMatmul(
156     std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result,
157     const Descriptor &x, const Descriptor &y, Terminator &terminator) {
158   int xRank{x.rank()};
159   int yRank{y.rank()};
160   int resRank{xRank + yRank - 2};
161   if (xRank * yRank != 2 * resRank) {
162     terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank);
163   }
164   SubscriptValue extent[2]{
165       xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(),
166       resRank == 2 ? y.GetDimension(1).Extent() : 0};
167   if constexpr (IS_ALLOCATING) {
168     result.Establish(
169         RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable);
170     for (int j{0}; j < resRank; ++j) {
171       result.GetDimension(j).SetBounds(1, extent[j]);
172     }
173     if (int stat{result.Allocate()}) {
174       terminator.Crash(
175           "MATMUL: could not allocate memory for result; STAT=%d", stat);
176     }
177   } else {
178     RUNTIME_CHECK(terminator, resRank == result.rank());
179     RUNTIME_CHECK(
180         terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND));
181     RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
182     RUNTIME_CHECK(terminator,
183         resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
184   }
185   SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
186   if (n != y.GetDimension(0).Extent()) {
187     terminator.Crash("MATMUL: arrays do not conform (%jd != %jd)",
188         static_cast<std::intmax_t>(n),
189         static_cast<std::intmax_t>(y.GetDimension(0).Extent()));
190   }
191   using WriteResult =
192       CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
193           RKIND>;
194   if constexpr (RCAT != TypeCategory::Logical) {
195     if (x.IsContiguous() && y.IsContiguous() &&
196         (IS_ALLOCATING || result.IsContiguous())) {
197       // Contiguous numeric matrices
198       if (resRank == 2) { // M*M -> M
199         if (std::is_same_v<XT, YT>) {
200           if constexpr (std::is_same_v<XT, float>) {
201             // TODO: call BLAS-3 SGEMM
202           } else if constexpr (std::is_same_v<XT, double>) {
203             // TODO: call BLAS-3 DGEMM
204           } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
205             // TODO: call BLAS-3 CGEMM
206           } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
207             // TODO: call BLAS-3 ZGEMM
208           }
209         }
210         MatrixTimesMatrix<RCAT, RKIND, XT, YT>(
211             result.template OffsetElement<WriteResult>(), extent[0], extent[1],
212             x.OffsetElement<XT>(), y.OffsetElement<YT>(), n);
213         return;
214       } else if (xRank == 2) { // M*V -> V
215         if (std::is_same_v<XT, YT>) {
216           if constexpr (std::is_same_v<XT, float>) {
217             // TODO: call BLAS-2 SGEMV(x,y)
218           } else if constexpr (std::is_same_v<XT, double>) {
219             // TODO: call BLAS-2 DGEMV(x,y)
220           } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
221             // TODO: call BLAS-2 CGEMV(x,y)
222           } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
223             // TODO: call BLAS-2 ZGEMV(x,y)
224           }
225         }
226         MatrixTimesVector<RCAT, RKIND, XT, YT>(
227             result.template OffsetElement<WriteResult>(), extent[0], n,
228             x.OffsetElement<XT>(), y.OffsetElement<YT>());
229         return;
230       } else { // V*M -> V
231         if (std::is_same_v<XT, YT>) {
232           if constexpr (std::is_same_v<XT, float>) {
233             // TODO: call BLAS-2 SGEMV(y,x)
234           } else if constexpr (std::is_same_v<XT, double>) {
235             // TODO: call BLAS-2 DGEMV(y,x)
236           } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
237             // TODO: call BLAS-2 CGEMV(y,x)
238           } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
239             // TODO: call BLAS-2 ZGEMV(y,x)
240           }
241         }
242         VectorTimesMatrix<RCAT, RKIND, XT, YT>(
243             result.template OffsetElement<WriteResult>(), n, extent[0],
244             x.OffsetElement<XT>(), y.OffsetElement<YT>());
245         return;
246       }
247     }
248   }
249   // General algorithms for LOGICAL and noncontiguity
250   SubscriptValue xAt[2], yAt[2], resAt[2];
251   x.GetLowerBounds(xAt);
252   y.GetLowerBounds(yAt);
253   result.GetLowerBounds(resAt);
254   if (resRank == 2) { // M*M -> M
255     SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]};
256     for (SubscriptValue i{0}; i < extent[0]; ++i) {
257       for (SubscriptValue j{0}; j < extent[1]; ++j) {
258         Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
259         yAt[1] = y1 + j;
260         for (SubscriptValue k{0}; k < n; ++k) {
261           xAt[1] = x1 + k;
262           yAt[0] = y0 + k;
263           accumulator.Accumulate(xAt, yAt);
264         }
265         resAt[1] = res1 + j;
266         *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
267       }
268       ++resAt[0];
269       ++xAt[0];
270     }
271   } else if (xRank == 2) { // M*V -> V
272     SubscriptValue x1{xAt[1]}, y0{yAt[0]};
273     for (SubscriptValue j{0}; j < extent[0]; ++j) {
274       Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
275       for (SubscriptValue k{0}; k < n; ++k) {
276         xAt[1] = x1 + k;
277         yAt[0] = y0 + k;
278         accumulator.Accumulate(xAt, yAt);
279       }
280       *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
281       ++resAt[0];
282       ++xAt[0];
283     }
284   } else { // V*M -> V
285     SubscriptValue x0{xAt[0]}, y0{yAt[0]};
286     for (SubscriptValue j{0}; j < extent[0]; ++j) {
287       Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
288       for (SubscriptValue k{0}; k < n; ++k) {
289         xAt[0] = x0 + k;
290         yAt[0] = y0 + k;
291         accumulator.Accumulate(xAt, yAt);
292       }
293       *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
294       ++resAt[0];
295       ++yAt[1];
296     }
297   }
298 }
299 
300 // Maps the dynamic type information from the arguments' descriptors
301 // to the right instantiation of DoMatmul() for valid combinations of
302 // types.
303 template <bool IS_ALLOCATING> struct Matmul {
304   using ResultDescriptor =
305       std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
306   template <TypeCategory XCAT, int XKIND> struct MM1 {
307     template <TypeCategory YCAT, int YKIND> struct MM2 {
308       void operator()(ResultDescriptor &result, const Descriptor &x,
309           const Descriptor &y, Terminator &terminator) const {
310         if constexpr (constexpr auto resultType{
311                           GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
312           if constexpr (common::IsNumericTypeCategory(resultType->first) ||
313               resultType->first == TypeCategory::Logical) {
314             return DoMatmul<IS_ALLOCATING, resultType->first,
315                 resultType->second, CppTypeFor<XCAT, XKIND>,
316                 CppTypeFor<YCAT, YKIND>>(result, x, y, terminator);
317           }
318         }
319         terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
320             static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
321       }
322     };
323     void operator()(ResultDescriptor &result, const Descriptor &x,
324         const Descriptor &y, Terminator &terminator, TypeCategory yCat,
325         int yKind) const {
326       ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator);
327     }
328   };
329   void operator()(ResultDescriptor &result, const Descriptor &x,
330       const Descriptor &y, const char *sourceFile, int line) const {
331     Terminator terminator{sourceFile, line};
332     auto xCatKind{x.type().GetCategoryAndKind()};
333     auto yCatKind{y.type().GetCategoryAndKind()};
334     RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
335     ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result,
336         x, y, terminator, yCatKind->first, yCatKind->second);
337   }
338 };
339 
340 extern "C" {
341 void RTNAME(Matmul)(Descriptor &result, const Descriptor &x,
342     const Descriptor &y, const char *sourceFile, int line) {
343   Matmul<true>{}(result, x, y, sourceFile, line);
344 }
345 void RTNAME(MatmulDirect)(const Descriptor &result, const Descriptor &x,
346     const Descriptor &y, const char *sourceFile, int line) {
347   Matmul<false>{}(result, x, y, sourceFile, line);
348 }
349 } // extern "C"
350 } // namespace Fortran::runtime
351