xref: /llvm-project/flang/runtime/matmul-transpose.cpp (revision 4ff8ba72b58328ebf6e8eb8c10a428eece73c89f)
1 //===-- runtime/matmul-transpose.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // Implements a fused matmul-transpose operation
10 //
11 // There are two main entry points; one establishes a descriptor for the
12 // result and allocates it, and the other expects a result descriptor that
13 // points to existing storage.
14 //
15 // This implementation must handle all combinations of numeric types and
16 // kinds (100 - 165 cases depending on the target), plus all combinations
17 // of logical kinds (16).  A single template undergoes many instantiations
18 // to cover all of the valid possibilities.
19 //
20 // The usefulness of this optimization should be reviewed once Matmul is swapped
21 // to use the faster BLAS routines.
22 
23 #include "flang/Runtime/matmul-transpose.h"
24 #include "terminator.h"
25 #include "tools.h"
26 #include "flang/Runtime/c-or-cpp.h"
27 #include "flang/Runtime/cpp-type.h"
28 #include "flang/Runtime/descriptor.h"
29 #include <cstring>
30 
31 namespace {
32 using namespace Fortran::runtime;
33 
34 // Contiguous numeric TRANSPOSE(matrix)*matrix multiplication
35 //   TRANSPOSE(matrix(n, rows)) * matrix(n,cols) ->
36 //             matrix(rows, n)  * matrix(n,cols) -> matrix(rows,cols)
37 // The transpose is implemented by swapping the indices of accesses into the LHS
38 //
39 // Straightforward algorithm:
40 //   DO 1 I = 1, NROWS
41 //    DO 1 J = 1, NCOLS
42 //     RES(I,J) = 0
43 //     DO 1 K = 1, N
44 //   1  RES(I,J) = RES(I,J) + X(K,I)*Y(K,J)
45 //
46 // With loop distribution and transposition to avoid the inner sum
47 // reduction and to avoid non-unit strides:
48 //   DO 1 I = 1, NROWS
49 //    DO 1 J = 1, NCOLS
50 //   1 RES(I,J) = 0
51 //   DO 2 J = 1, NCOLS
52 //    DO 2 I = 1, NROWS
53 //     DO 2 K = 1, N
54 //   2  RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) ! loop-invariant last term
55 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
56 inline static void MatrixTransposedTimesMatrix(
57     CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
58     SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
59     SubscriptValue n) {
60   using ResultType = CppTypeFor<RCAT, RKIND>;
61 
62   std::memset(product, 0, rows * cols * sizeof *product);
63   for (SubscriptValue j{0}; j < cols; ++j) {
64     for (SubscriptValue i{0}; i < rows; ++i) {
65       for (SubscriptValue k{0}; k < n; ++k) {
66         ResultType x_ki = static_cast<ResultType>(x[i * n + k]);
67         ResultType y_kj = static_cast<ResultType>(y[j * n + k]);
68         product[j * rows + i] += x_ki * y_kj;
69       }
70     }
71   }
72 }
73 
74 // Contiguous numeric matrix*vector multiplication
75 //   matrix(rows,n) * column vector(n) -> column vector(rows)
76 // Straightforward algorithm:
77 //   DO 1 I = 1, NROWS
78 //    RES(I) = 0
79 //    DO 1 K = 1, N
80 //   1 RES(I) = RES(I) + X(K,I)*Y(K)
81 // With loop distribution and transposition to avoid the inner
82 // sum reduction and to avoid non-unit strides:
83 //   DO 1 I = 1, NROWS
84 //   1 RES(I) = 0
85 //   DO 2 I = 1, NROWS
86 //    DO 2 K = 1, N
87 //   2 RES(I) = RES(I) + X(K,I)*Y(K)
88 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
89 inline static void MatrixTransposedTimesVector(
90     CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
91     SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y) {
92   using ResultType = CppTypeFor<RCAT, RKIND>;
93   std::memset(product, 0, rows * sizeof *product);
94   for (SubscriptValue i{0}; i < rows; ++i) {
95     for (SubscriptValue k{0}; k < n; ++k) {
96       ResultType x_ki = static_cast<ResultType>(x[i * n + k]);
97       ResultType y_k = static_cast<ResultType>(y[k]);
98       product[i] += x_ki * y_k;
99     }
100   }
101 }
102 
103 // Implements an instance of MATMUL for given argument types.
104 template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
105     typename YT>
106 inline static void DoMatmulTranspose(
107     std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result,
108     const Descriptor &x, const Descriptor &y, Terminator &terminator) {
109   int xRank{x.rank()};
110   int yRank{y.rank()};
111   int resRank{xRank + yRank - 2};
112   if (xRank * yRank != 2 * resRank) {
113     terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank);
114   }
115   SubscriptValue extent[2]{x.GetDimension(1).Extent(),
116       resRank == 2 ? y.GetDimension(1).Extent() : 0};
117   if constexpr (IS_ALLOCATING) {
118     result.Establish(
119         RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable);
120     for (int j{0}; j < resRank; ++j) {
121       result.GetDimension(j).SetBounds(1, extent[j]);
122     }
123     if (int stat{result.Allocate()}) {
124       terminator.Crash(
125           "MATMUL: could not allocate memory for result; STAT=%d", stat);
126     }
127   } else {
128     RUNTIME_CHECK(terminator, resRank == result.rank());
129     RUNTIME_CHECK(
130         terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND));
131     RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
132     RUNTIME_CHECK(terminator,
133         resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
134   }
135   SubscriptValue n{x.GetDimension(0).Extent()};
136   if (n != y.GetDimension(0).Extent()) {
137     terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
138         static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
139         static_cast<std::intmax_t>(x.GetDimension(1).Extent()),
140         static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
141         static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
142   }
143   using WriteResult =
144       CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
145           RKIND>;
146   const SubscriptValue rows{extent[0]};
147   const SubscriptValue cols{extent[1]};
148   if constexpr (RCAT != TypeCategory::Logical) {
149     if (x.IsContiguous() && y.IsContiguous() &&
150         (IS_ALLOCATING || result.IsContiguous())) {
151       // Contiguous numeric matrices
152       if (resRank == 2) { // M*M -> M
153         MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT>(
154             result.template OffsetElement<WriteResult>(), rows, cols,
155             x.OffsetElement<XT>(), y.OffsetElement<YT>(), n);
156         return;
157       }
158       if (xRank == 2) { // M*V -> V
159         MatrixTransposedTimesVector<RCAT, RKIND, XT, YT>(
160             result.template OffsetElement<WriteResult>(), rows, n,
161             x.OffsetElement<XT>(), y.OffsetElement<YT>());
162         return;
163       }
164       // else V*M -> V (not allowed because TRANSPOSE() is only defined for rank
165       // 1 matrices
166       terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
167           static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
168           static_cast<std::intmax_t>(n),
169           static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
170           static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
171       return;
172     }
173   }
174   // General algorithms for LOGICAL and noncontiguity
175   SubscriptValue xLB[2], yLB[2], resLB[2];
176   x.GetLowerBounds(xLB);
177   y.GetLowerBounds(yLB);
178   result.GetLowerBounds(resLB);
179   using ResultType = CppTypeFor<RCAT, RKIND>;
180   if (resRank == 2) { // M*M -> M
181     for (SubscriptValue i{0}; i < rows; ++i) {
182       for (SubscriptValue j{0}; j < cols; ++j) {
183         ResultType res_ij;
184         if constexpr (RCAT == TypeCategory::Logical) {
185           res_ij = false;
186         } else {
187           res_ij = 0;
188         }
189 
190         for (SubscriptValue k{0}; k < n; ++k) {
191           SubscriptValue xAt[2]{k + xLB[0], i + xLB[1]};
192           SubscriptValue yAt[2]{k + yLB[0], j + yLB[1]};
193           if constexpr (RCAT == TypeCategory::Logical) {
194             ResultType x_ki = IsLogicalElementTrue(x, xAt);
195             ResultType y_kj = IsLogicalElementTrue(y, yAt);
196             res_ij = res_ij || (x_ki && y_kj);
197           } else {
198             ResultType x_ki = static_cast<ResultType>(*x.Element<XT>(xAt));
199             ResultType y_kj = static_cast<ResultType>(*y.Element<YT>(yAt));
200             res_ij += x_ki * y_kj;
201           }
202         }
203         SubscriptValue resAt[2]{i + resLB[0], j + resLB[1]};
204         *result.template Element<WriteResult>(resAt) = res_ij;
205       }
206     }
207   } else if (xRank == 2) { // M*V -> V
208     for (SubscriptValue i{0}; i < rows; ++i) {
209       ResultType res_i;
210       if constexpr (RCAT == TypeCategory::Logical) {
211         res_i = false;
212       } else {
213         res_i = 0;
214       }
215 
216       for (SubscriptValue k{0}; k < n; ++k) {
217         SubscriptValue xAt[2]{k + xLB[0], i + xLB[1]};
218         SubscriptValue yAt[1]{k + yLB[0]};
219         if constexpr (RCAT == TypeCategory::Logical) {
220           ResultType x_ki = IsLogicalElementTrue(x, xAt);
221           ResultType y_k = IsLogicalElementTrue(y, yAt);
222           res_i = res_i || (x_ki && y_k);
223         } else {
224           ResultType x_ki = static_cast<ResultType>(*x.Element<XT>(xAt));
225           ResultType y_k = static_cast<ResultType>(*y.Element<YT>(yAt));
226           res_i += x_ki * y_k;
227         }
228       }
229       SubscriptValue resAt[1]{i + resLB[0]};
230       *result.template Element<WriteResult>(resAt) = res_i;
231     }
232   } else { // V*M -> V
233     // TRANSPOSE(V) not allowed by fortran standard
234     terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
235         static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
236         static_cast<std::intmax_t>(n),
237         static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
238         static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
239   }
240 }
241 
242 // Maps the dynamic type information from the arguments' descriptors
243 // to the right instantiation of DoMatmul() for valid combinations of
244 // types.
245 template <bool IS_ALLOCATING> struct MatmulTranspose {
246   using ResultDescriptor =
247       std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
248   template <TypeCategory XCAT, int XKIND> struct MM1 {
249     template <TypeCategory YCAT, int YKIND> struct MM2 {
250       void operator()(ResultDescriptor &result, const Descriptor &x,
251           const Descriptor &y, Terminator &terminator) const {
252         if constexpr (constexpr auto resultType{
253                           GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
254           if constexpr (Fortran::common::IsNumericTypeCategory(
255                             resultType->first) ||
256               resultType->first == TypeCategory::Logical) {
257             return DoMatmulTranspose<IS_ALLOCATING, resultType->first,
258                 resultType->second, CppTypeFor<XCAT, XKIND>,
259                 CppTypeFor<YCAT, YKIND>>(result, x, y, terminator);
260           }
261         }
262         terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
263             static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
264       }
265     };
266     void operator()(ResultDescriptor &result, const Descriptor &x,
267         const Descriptor &y, Terminator &terminator, TypeCategory yCat,
268         int yKind) const {
269       ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator);
270     }
271   };
272   void operator()(ResultDescriptor &result, const Descriptor &x,
273       const Descriptor &y, const char *sourceFile, int line) const {
274     Terminator terminator{sourceFile, line};
275     auto xCatKind{x.type().GetCategoryAndKind()};
276     auto yCatKind{y.type().GetCategoryAndKind()};
277     RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
278     ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result,
279         x, y, terminator, yCatKind->first, yCatKind->second);
280   }
281 };
282 } // namespace
283 
284 namespace Fortran::runtime {
285 extern "C" {
286 void RTNAME(MatmulTranspose)(Descriptor &result, const Descriptor &x,
287     const Descriptor &y, const char *sourceFile, int line) {
288   MatmulTranspose<true>{}(result, x, y, sourceFile, line);
289 }
290 void RTNAME(MatmulTransposeDirect)(const Descriptor &result,
291     const Descriptor &x, const Descriptor &y, const char *sourceFile,
292     int line) {
293   MatmulTranspose<false>{}(result, x, y, sourceFile, line);
294 }
295 } // extern "C"
296 } // namespace Fortran::runtime
297