xref: /llvm-project/flang/runtime/dot-product.cpp (revision fc97d2e68b03bc2979395e84b645e5b3ba35aecd)
150e0b298Speter klausler //===-- runtime/dot-product.cpp -------------------------------------------===//
250e0b298Speter klausler //
350e0b298Speter klausler // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
450e0b298Speter klausler // See https://llvm.org/LICENSE.txt for license information.
550e0b298Speter klausler // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
650e0b298Speter klausler //
750e0b298Speter klausler //===----------------------------------------------------------------------===//
850e0b298Speter klausler 
94daa33f6SPeter Klausler #include "float.h"
1050e0b298Speter klausler #include "terminator.h"
1150e0b298Speter klausler #include "tools.h"
12478e0b58SPeter Steinfeld #include "flang/Common/float128.h"
13830c0b90SPeter Klausler #include "flang/Runtime/cpp-type.h"
14830c0b90SPeter Klausler #include "flang/Runtime/descriptor.h"
15830c0b90SPeter Klausler #include "flang/Runtime/reduction.h"
164daa33f6SPeter Klausler #include <cfloat>
1750e0b298Speter klausler #include <cinttypes>
1850e0b298Speter klausler 
1950e0b298Speter klausler namespace Fortran::runtime {
2050e0b298Speter klausler 
// Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
// argument; MATMUL does not.
23a5a493e1SPeter Klausler 
24a5a493e1SPeter Klausler // General accumulator for any type and stride; this is not used for
25a5a493e1SPeter Klausler // contiguous numeric vectors.
26a5a493e1SPeter Klausler template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
275e1421b2Speter klausler class Accumulator {
285e1421b2Speter klausler public:
29a5a493e1SPeter Klausler   using Result = AccumulationType<RCAT, RKIND>;
3076facde3SSlava Zakharin   RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y)
3176facde3SSlava Zakharin       : x_{x}, y_{y} {}
3276facde3SSlava Zakharin   RT_API_ATTRS void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) {
33a5a493e1SPeter Klausler     if constexpr (RCAT == TypeCategory::Logical) {
345e1421b2Speter klausler       sum_ = sum_ ||
355e1421b2Speter klausler           (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
365e1421b2Speter klausler     } else {
37a5a493e1SPeter Klausler       const XT &xElement{*x_.Element<XT>(&xAt)};
38a5a493e1SPeter Klausler       const YT &yElement{*y_.Element<YT>(&yAt)};
39a5a493e1SPeter Klausler       if constexpr (RCAT == TypeCategory::Complex) {
40104f3c18SSlava Zakharin         sum_ += rtcmplx::conj(static_cast<Result>(xElement)) *
41a5a493e1SPeter Klausler             static_cast<Result>(yElement);
42a5a493e1SPeter Klausler       } else {
43a5a493e1SPeter Klausler         sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement);
44a5a493e1SPeter Klausler       }
455e1421b2Speter klausler     }
465e1421b2Speter klausler   }
4776facde3SSlava Zakharin   RT_API_ATTRS Result GetResult() const { return sum_; }
485e1421b2Speter klausler 
495e1421b2Speter klausler private:
505e1421b2Speter klausler   const Descriptor &x_, &y_;
515e1421b2Speter klausler   Result sum_{};
525e1421b2Speter klausler };
535e1421b2Speter klausler 
54a5a493e1SPeter Klausler template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
5576facde3SSlava Zakharin static inline RT_API_ATTRS CppTypeFor<RCAT, RKIND> DoDotProduct(
565e1421b2Speter klausler     const Descriptor &x, const Descriptor &y, Terminator &terminator) {
57a5a493e1SPeter Klausler   using Result = CppTypeFor<RCAT, RKIND>;
5850e0b298Speter klausler   RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
5950e0b298Speter klausler   SubscriptValue n{x.GetDimension(0).Extent()};
6050e0b298Speter klausler   if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
6150e0b298Speter klausler     terminator.Crash(
6250e0b298Speter klausler         "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
6350e0b298Speter klausler         static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
6450e0b298Speter klausler   }
65a5a493e1SPeter Klausler   if constexpr (RCAT != TypeCategory::Logical) {
66a5a493e1SPeter Klausler     if (x.GetDimension(0).ByteStride() == sizeof(XT) &&
67a5a493e1SPeter Klausler         y.GetDimension(0).ByteStride() == sizeof(YT)) {
68a5a493e1SPeter Klausler       // Contiguous numeric vectors
695e1421b2Speter klausler       if constexpr (std::is_same_v<XT, YT>) {
70a5a493e1SPeter Klausler         // Contiguous homogeneous numeric vectors
715e1421b2Speter klausler         if constexpr (std::is_same_v<XT, float>) {
725e1421b2Speter klausler           // TODO: call BLAS-1 SDOT or SDSDOT
735e1421b2Speter klausler         } else if constexpr (std::is_same_v<XT, double>) {
745e1421b2Speter klausler           // TODO: call BLAS-1 DDOT
75104f3c18SSlava Zakharin         } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) {
765e1421b2Speter klausler           // TODO: call BLAS-1 CDOTC
77104f3c18SSlava Zakharin         } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) {
785e1421b2Speter klausler           // TODO: call BLAS-1 ZDOTC
795e1421b2Speter klausler         }
805e1421b2Speter klausler       }
81a5a493e1SPeter Klausler       XT *xp{x.OffsetElement<XT>(0)};
82a5a493e1SPeter Klausler       YT *yp{y.OffsetElement<YT>(0)};
83a5a493e1SPeter Klausler       using AccumType = AccumulationType<RCAT, RKIND>;
84a5a493e1SPeter Klausler       AccumType accum{};
85a5a493e1SPeter Klausler       if constexpr (RCAT == TypeCategory::Complex) {
86a5a493e1SPeter Klausler         for (SubscriptValue j{0}; j < n; ++j) {
87104f3c18SSlava Zakharin           // conj() may instantiate its argument twice,
8876facde3SSlava Zakharin           // so xp has to be incremented separately.
8976facde3SSlava Zakharin           // This is a workaround for an alleged bug in clang,
9076facde3SSlava Zakharin           // that shows up as:
9176facde3SSlava Zakharin           //   warning: multiple unsequenced modifications to 'xp'
92104f3c18SSlava Zakharin           accum += rtcmplx::conj(static_cast<AccumType>(*xp)) *
93a5a493e1SPeter Klausler               static_cast<AccumType>(*yp++);
9476facde3SSlava Zakharin           xp++;
95a5a493e1SPeter Klausler         }
96a5a493e1SPeter Klausler       } else {
97a5a493e1SPeter Klausler         for (SubscriptValue j{0}; j < n; ++j) {
98a5a493e1SPeter Klausler           accum +=
99a5a493e1SPeter Klausler               static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++);
100a5a493e1SPeter Klausler         }
101a5a493e1SPeter Klausler       }
102a5a493e1SPeter Klausler       return static_cast<Result>(accum);
103a5a493e1SPeter Klausler     }
104a5a493e1SPeter Klausler   }
105a5a493e1SPeter Klausler   // Non-contiguous, heterogeneous, & LOGICAL cases
10650e0b298Speter klausler   SubscriptValue xAt{x.GetDimension(0).LowerBound()};
10750e0b298Speter klausler   SubscriptValue yAt{y.GetDimension(0).LowerBound()};
108a5a493e1SPeter Klausler   Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
10950e0b298Speter klausler   for (SubscriptValue j{0}; j < n; ++j) {
110a5a493e1SPeter Klausler     accumulator.AccumulateIndexed(xAt++, yAt++);
11150e0b298Speter klausler   }
112a5a493e1SPeter Klausler   return static_cast<Result>(accumulator.GetResult());
11350e0b298Speter klausler }
11450e0b298Speter klausler 
1155e1421b2Speter klausler template <TypeCategory RCAT, int RKIND> struct DotProduct {
11650e0b298Speter klausler   using Result = CppTypeFor<RCAT, RKIND>;
11750e0b298Speter klausler   template <TypeCategory XCAT, int XKIND> struct DP1 {
11850e0b298Speter klausler     template <TypeCategory YCAT, int YKIND> struct DP2 {
11976facde3SSlava Zakharin       RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
12050e0b298Speter klausler           Terminator &terminator) const {
12150e0b298Speter klausler         if constexpr (constexpr auto resultType{
12250e0b298Speter klausler                           GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
12350e0b298Speter klausler           if constexpr (resultType->first == RCAT &&
124f6aac0ddSpeter klausler               (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) {
125a5a493e1SPeter Klausler             return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>,
1265e1421b2Speter klausler                 CppTypeFor<YCAT, YKIND>>(x, y, terminator);
12750e0b298Speter klausler           }
12850e0b298Speter klausler         }
12950e0b298Speter klausler         terminator.Crash(
13050e0b298Speter klausler             "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
13150e0b298Speter klausler             static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
13250e0b298Speter klausler             static_cast<int>(YCAT), YKIND);
13350e0b298Speter klausler       }
13450e0b298Speter klausler     };
13576facde3SSlava Zakharin     RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
13650e0b298Speter klausler         Terminator &terminator, TypeCategory yCat, int yKind) const {
13750e0b298Speter klausler       return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
13850e0b298Speter klausler     }
13950e0b298Speter klausler   };
14076facde3SSlava Zakharin   RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
14150e0b298Speter klausler       const char *source, int line) const {
14250e0b298Speter klausler     Terminator terminator{source, line};
143a5a493e1SPeter Klausler     if (RCAT != TypeCategory::Logical && x.type() == y.type()) {
144a5a493e1SPeter Klausler       // No conversions needed, operands and result have same known type
145a5a493e1SPeter Klausler       return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}(
146a5a493e1SPeter Klausler           x, y, terminator);
147a5a493e1SPeter Klausler     } else {
14850e0b298Speter klausler       auto xCatKind{x.type().GetCategoryAndKind()};
14950e0b298Speter klausler       auto yCatKind{y.type().GetCategoryAndKind()};
15050e0b298Speter klausler       RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
151a5a493e1SPeter Klausler       return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second,
152a5a493e1SPeter Klausler           terminator, x, y, terminator, yCatKind->first, yCatKind->second);
153a5a493e1SPeter Klausler     }
15450e0b298Speter klausler   }
15550e0b298Speter klausler };
15650e0b298Speter klausler 
15750e0b298Speter klausler extern "C" {
15876facde3SSlava Zakharin RT_EXT_API_GROUP_BEGIN
15976facde3SSlava Zakharin 
16076facde3SSlava Zakharin CppTypeFor<TypeCategory::Integer, 1> RTDEF(DotProductInteger1)(
16150e0b298Speter klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
162a5a493e1SPeter Klausler   return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
16350e0b298Speter klausler }
16476facde3SSlava Zakharin CppTypeFor<TypeCategory::Integer, 2> RTDEF(DotProductInteger2)(
16550e0b298Speter klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
166a5a493e1SPeter Klausler   return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
16750e0b298Speter klausler }
16876facde3SSlava Zakharin CppTypeFor<TypeCategory::Integer, 4> RTDEF(DotProductInteger4)(
16950e0b298Speter klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
170a5a493e1SPeter Klausler   return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
17150e0b298Speter klausler }
17276facde3SSlava Zakharin CppTypeFor<TypeCategory::Integer, 8> RTDEF(DotProductInteger8)(
17350e0b298Speter klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
1745e1421b2Speter klausler   return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
17550e0b298Speter klausler }
17650e0b298Speter klausler #ifdef __SIZEOF_INT128__
17776facde3SSlava Zakharin CppTypeFor<TypeCategory::Integer, 16> RTDEF(DotProductInteger16)(
17850e0b298Speter klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
1795e1421b2Speter klausler   return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
18050e0b298Speter klausler }
18150e0b298Speter klausler #endif
18250e0b298Speter klausler 
183*fc97d2e6SPeter Klausler CppTypeFor<TypeCategory::Unsigned, 1> RTDEF(DotProductUnsigned1)(
184*fc97d2e6SPeter Klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
185*fc97d2e6SPeter Klausler   return DotProduct<TypeCategory::Unsigned, 1>{}(x, y, source, line);
186*fc97d2e6SPeter Klausler }
187*fc97d2e6SPeter Klausler CppTypeFor<TypeCategory::Unsigned, 2> RTDEF(DotProductUnsigned2)(
188*fc97d2e6SPeter Klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
189*fc97d2e6SPeter Klausler   return DotProduct<TypeCategory::Unsigned, 2>{}(x, y, source, line);
190*fc97d2e6SPeter Klausler }
191*fc97d2e6SPeter Klausler CppTypeFor<TypeCategory::Unsigned, 4> RTDEF(DotProductUnsigned4)(
192*fc97d2e6SPeter Klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
193*fc97d2e6SPeter Klausler   return DotProduct<TypeCategory::Unsigned, 4>{}(x, y, source, line);
194*fc97d2e6SPeter Klausler }
195*fc97d2e6SPeter Klausler CppTypeFor<TypeCategory::Unsigned, 8> RTDEF(DotProductUnsigned8)(
196*fc97d2e6SPeter Klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
197*fc97d2e6SPeter Klausler   return DotProduct<TypeCategory::Unsigned, 8>{}(x, y, source, line);
198*fc97d2e6SPeter Klausler }
199*fc97d2e6SPeter Klausler #ifdef __SIZEOF_INT128__
200*fc97d2e6SPeter Klausler CppTypeFor<TypeCategory::Unsigned, 16> RTDEF(DotProductUnsigned16)(
201*fc97d2e6SPeter Klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
202*fc97d2e6SPeter Klausler   return DotProduct<TypeCategory::Unsigned, 16>{}(x, y, source, line);
203*fc97d2e6SPeter Klausler }
204*fc97d2e6SPeter Klausler #endif
205*fc97d2e6SPeter Klausler 
20650e0b298Speter klausler // TODO: REAL/COMPLEX(2 & 3)
207a5a493e1SPeter Klausler // Intermediate results and operations are at least 64 bits
20876facde3SSlava Zakharin CppTypeFor<TypeCategory::Real, 4> RTDEF(DotProductReal4)(
20950e0b298Speter klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
210a5a493e1SPeter Klausler   return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
21150e0b298Speter klausler }
21276facde3SSlava Zakharin CppTypeFor<TypeCategory::Real, 8> RTDEF(DotProductReal8)(
21350e0b298Speter klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
2145e1421b2Speter klausler   return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
21550e0b298Speter klausler }
216104f3c18SSlava Zakharin #if HAS_FLOAT80
21776facde3SSlava Zakharin CppTypeFor<TypeCategory::Real, 10> RTDEF(DotProductReal10)(
21850e0b298Speter klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
2195e1421b2Speter klausler   return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
22050e0b298Speter klausler }
2214daa33f6SPeter Klausler #endif
222fc51c7f0SSlava Zakharin #if HAS_LDBL128 || HAS_FLOAT128
22376facde3SSlava Zakharin CppTypeFor<TypeCategory::Real, 16> RTDEF(DotProductReal16)(
22450e0b298Speter klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
2255e1421b2Speter klausler   return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line);
22650e0b298Speter klausler }
22750e0b298Speter klausler #endif
22850e0b298Speter klausler 
22976facde3SSlava Zakharin void RTDEF(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
23050e0b298Speter klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
231f8a9f43eSSlava Zakharin   result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line);
23250e0b298Speter klausler }
23376facde3SSlava Zakharin void RTDEF(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
23450e0b298Speter klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
2355e1421b2Speter klausler   result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
23650e0b298Speter klausler }
237104f3c18SSlava Zakharin #if HAS_FLOAT80
23876facde3SSlava Zakharin void RTDEF(CppDotProductComplex10)(
239f8a9f43eSSlava Zakharin     CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x,
240f8a9f43eSSlava Zakharin     const Descriptor &y, const char *source, int line) {
2415e1421b2Speter klausler   result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
24250e0b298Speter klausler }
243f8a9f43eSSlava Zakharin #endif
244fc51c7f0SSlava Zakharin #if HAS_LDBL128 || HAS_FLOAT128
24576facde3SSlava Zakharin void RTDEF(CppDotProductComplex16)(
246f8a9f43eSSlava Zakharin     CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x,
247f8a9f43eSSlava Zakharin     const Descriptor &y, const char *source, int line) {
2485e1421b2Speter klausler   result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
24950e0b298Speter klausler }
25050e0b298Speter klausler #endif
25150e0b298Speter klausler 
25276facde3SSlava Zakharin bool RTDEF(DotProductLogical)(
25350e0b298Speter klausler     const Descriptor &x, const Descriptor &y, const char *source, int line) {
2545e1421b2Speter klausler   return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line);
25550e0b298Speter klausler }
25676facde3SSlava Zakharin 
25776facde3SSlava Zakharin RT_EXT_API_GROUP_END
25850e0b298Speter klausler } // extern "C"
25950e0b298Speter klausler } // namespace Fortran::runtime
260