xref: /llvm-project/flang/runtime/dot-product.cpp (revision fc97d2e68b03bc2979395e84b645e5b3ba35aecd)
1 //===-- runtime/dot-product.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "float.h"
10 #include "terminator.h"
11 #include "tools.h"
12 #include "flang/Common/float128.h"
13 #include "flang/Runtime/cpp-type.h"
14 #include "flang/Runtime/descriptor.h"
15 #include "flang/Runtime/reduction.h"
16 #include <cfloat>
17 #include <cinttypes>
18 
19 namespace Fortran::runtime {
20 
21 // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
22 // argument; MATMUL does not.
23 
24 // General accumulator for any type and stride; this is not used for
25 // contiguous numeric vectors.
26 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
27 class Accumulator {
28 public:
29   using Result = AccumulationType<RCAT, RKIND>;
30   RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y)
31       : x_{x}, y_{y} {}
32   RT_API_ATTRS void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) {
33     if constexpr (RCAT == TypeCategory::Logical) {
34       sum_ = sum_ ||
35           (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
36     } else {
37       const XT &xElement{*x_.Element<XT>(&xAt)};
38       const YT &yElement{*y_.Element<YT>(&yAt)};
39       if constexpr (RCAT == TypeCategory::Complex) {
40         sum_ += rtcmplx::conj(static_cast<Result>(xElement)) *
41             static_cast<Result>(yElement);
42       } else {
43         sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement);
44       }
45     }
46   }
47   RT_API_ATTRS Result GetResult() const { return sum_; }
48 
49 private:
50   const Descriptor &x_, &y_;
51   Result sum_{};
52 };
53 
54 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
55 static inline RT_API_ATTRS CppTypeFor<RCAT, RKIND> DoDotProduct(
56     const Descriptor &x, const Descriptor &y, Terminator &terminator) {
57   using Result = CppTypeFor<RCAT, RKIND>;
58   RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
59   SubscriptValue n{x.GetDimension(0).Extent()};
60   if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
61     terminator.Crash(
62         "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
63         static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
64   }
65   if constexpr (RCAT != TypeCategory::Logical) {
66     if (x.GetDimension(0).ByteStride() == sizeof(XT) &&
67         y.GetDimension(0).ByteStride() == sizeof(YT)) {
68       // Contiguous numeric vectors
69       if constexpr (std::is_same_v<XT, YT>) {
70         // Contiguous homogeneous numeric vectors
71         if constexpr (std::is_same_v<XT, float>) {
72           // TODO: call BLAS-1 SDOT or SDSDOT
73         } else if constexpr (std::is_same_v<XT, double>) {
74           // TODO: call BLAS-1 DDOT
75         } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) {
76           // TODO: call BLAS-1 CDOTC
77         } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) {
78           // TODO: call BLAS-1 ZDOTC
79         }
80       }
81       XT *xp{x.OffsetElement<XT>(0)};
82       YT *yp{y.OffsetElement<YT>(0)};
83       using AccumType = AccumulationType<RCAT, RKIND>;
84       AccumType accum{};
85       if constexpr (RCAT == TypeCategory::Complex) {
86         for (SubscriptValue j{0}; j < n; ++j) {
87           // conj() may instantiate its argument twice,
88           // so xp has to be incremented separately.
89           // This is a workaround for an alleged bug in clang,
90           // that shows up as:
91           //   warning: multiple unsequenced modifications to 'xp'
92           accum += rtcmplx::conj(static_cast<AccumType>(*xp)) *
93               static_cast<AccumType>(*yp++);
94           xp++;
95         }
96       } else {
97         for (SubscriptValue j{0}; j < n; ++j) {
98           accum +=
99               static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++);
100         }
101       }
102       return static_cast<Result>(accum);
103     }
104   }
105   // Non-contiguous, heterogeneous, & LOGICAL cases
106   SubscriptValue xAt{x.GetDimension(0).LowerBound()};
107   SubscriptValue yAt{y.GetDimension(0).LowerBound()};
108   Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
109   for (SubscriptValue j{0}; j < n; ++j) {
110     accumulator.AccumulateIndexed(xAt++, yAt++);
111   }
112   return static_cast<Result>(accumulator.GetResult());
113 }
114 
115 template <TypeCategory RCAT, int RKIND> struct DotProduct {
116   using Result = CppTypeFor<RCAT, RKIND>;
117   template <TypeCategory XCAT, int XKIND> struct DP1 {
118     template <TypeCategory YCAT, int YKIND> struct DP2 {
119       RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
120           Terminator &terminator) const {
121         if constexpr (constexpr auto resultType{
122                           GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
123           if constexpr (resultType->first == RCAT &&
124               (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) {
125             return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>,
126                 CppTypeFor<YCAT, YKIND>>(x, y, terminator);
127           }
128         }
129         terminator.Crash(
130             "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
131             static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
132             static_cast<int>(YCAT), YKIND);
133       }
134     };
135     RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
136         Terminator &terminator, TypeCategory yCat, int yKind) const {
137       return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
138     }
139   };
140   RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
141       const char *source, int line) const {
142     Terminator terminator{source, line};
143     if (RCAT != TypeCategory::Logical && x.type() == y.type()) {
144       // No conversions needed, operands and result have same known type
145       return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}(
146           x, y, terminator);
147     } else {
148       auto xCatKind{x.type().GetCategoryAndKind()};
149       auto yCatKind{y.type().GetCategoryAndKind()};
150       RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
151       return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second,
152           terminator, x, y, terminator, yCatKind->first, yCatKind->second);
153     }
154   }
155 };
156 
157 extern "C" {
158 RT_EXT_API_GROUP_BEGIN
159 
160 CppTypeFor<TypeCategory::Integer, 1> RTDEF(DotProductInteger1)(
161     const Descriptor &x, const Descriptor &y, const char *source, int line) {
162   return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
163 }
164 CppTypeFor<TypeCategory::Integer, 2> RTDEF(DotProductInteger2)(
165     const Descriptor &x, const Descriptor &y, const char *source, int line) {
166   return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
167 }
168 CppTypeFor<TypeCategory::Integer, 4> RTDEF(DotProductInteger4)(
169     const Descriptor &x, const Descriptor &y, const char *source, int line) {
170   return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
171 }
172 CppTypeFor<TypeCategory::Integer, 8> RTDEF(DotProductInteger8)(
173     const Descriptor &x, const Descriptor &y, const char *source, int line) {
174   return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
175 }
176 #ifdef __SIZEOF_INT128__
177 CppTypeFor<TypeCategory::Integer, 16> RTDEF(DotProductInteger16)(
178     const Descriptor &x, const Descriptor &y, const char *source, int line) {
179   return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
180 }
181 #endif
182 
183 CppTypeFor<TypeCategory::Unsigned, 1> RTDEF(DotProductUnsigned1)(
184     const Descriptor &x, const Descriptor &y, const char *source, int line) {
185   return DotProduct<TypeCategory::Unsigned, 1>{}(x, y, source, line);
186 }
187 CppTypeFor<TypeCategory::Unsigned, 2> RTDEF(DotProductUnsigned2)(
188     const Descriptor &x, const Descriptor &y, const char *source, int line) {
189   return DotProduct<TypeCategory::Unsigned, 2>{}(x, y, source, line);
190 }
191 CppTypeFor<TypeCategory::Unsigned, 4> RTDEF(DotProductUnsigned4)(
192     const Descriptor &x, const Descriptor &y, const char *source, int line) {
193   return DotProduct<TypeCategory::Unsigned, 4>{}(x, y, source, line);
194 }
195 CppTypeFor<TypeCategory::Unsigned, 8> RTDEF(DotProductUnsigned8)(
196     const Descriptor &x, const Descriptor &y, const char *source, int line) {
197   return DotProduct<TypeCategory::Unsigned, 8>{}(x, y, source, line);
198 }
199 #ifdef __SIZEOF_INT128__
200 CppTypeFor<TypeCategory::Unsigned, 16> RTDEF(DotProductUnsigned16)(
201     const Descriptor &x, const Descriptor &y, const char *source, int line) {
202   return DotProduct<TypeCategory::Unsigned, 16>{}(x, y, source, line);
203 }
204 #endif
205 
206 // TODO: REAL/COMPLEX(2 & 3)
207 // Intermediate results and operations are at least 64 bits
208 CppTypeFor<TypeCategory::Real, 4> RTDEF(DotProductReal4)(
209     const Descriptor &x, const Descriptor &y, const char *source, int line) {
210   return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
211 }
212 CppTypeFor<TypeCategory::Real, 8> RTDEF(DotProductReal8)(
213     const Descriptor &x, const Descriptor &y, const char *source, int line) {
214   return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
215 }
216 #if HAS_FLOAT80
217 CppTypeFor<TypeCategory::Real, 10> RTDEF(DotProductReal10)(
218     const Descriptor &x, const Descriptor &y, const char *source, int line) {
219   return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
220 }
221 #endif
222 #if HAS_LDBL128 || HAS_FLOAT128
223 CppTypeFor<TypeCategory::Real, 16> RTDEF(DotProductReal16)(
224     const Descriptor &x, const Descriptor &y, const char *source, int line) {
225   return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line);
226 }
227 #endif
228 
229 void RTDEF(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
230     const Descriptor &x, const Descriptor &y, const char *source, int line) {
231   result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line);
232 }
233 void RTDEF(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
234     const Descriptor &x, const Descriptor &y, const char *source, int line) {
235   result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
236 }
237 #if HAS_FLOAT80
238 void RTDEF(CppDotProductComplex10)(
239     CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x,
240     const Descriptor &y, const char *source, int line) {
241   result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
242 }
243 #endif
244 #if HAS_LDBL128 || HAS_FLOAT128
245 void RTDEF(CppDotProductComplex16)(
246     CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x,
247     const Descriptor &y, const char *source, int line) {
248   result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
249 }
250 #endif
251 
252 bool RTDEF(DotProductLogical)(
253     const Descriptor &x, const Descriptor &y, const char *source, int line) {
254   return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line);
255 }
256 
257 RT_EXT_API_GROUP_END
258 } // extern "C"
259 } // namespace Fortran::runtime
260