150e0b298Speter klausler //===-- runtime/dot-product.cpp -------------------------------------------===// 250e0b298Speter klausler // 350e0b298Speter klausler // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 450e0b298Speter klausler // See https://llvm.org/LICENSE.txt for license information. 550e0b298Speter klausler // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 650e0b298Speter klausler // 750e0b298Speter klausler //===----------------------------------------------------------------------===// 850e0b298Speter klausler 94daa33f6SPeter Klausler #include "float.h" 1050e0b298Speter klausler #include "terminator.h" 1150e0b298Speter klausler #include "tools.h" 12478e0b58SPeter Steinfeld #include "flang/Common/float128.h" 13830c0b90SPeter Klausler #include "flang/Runtime/cpp-type.h" 14830c0b90SPeter Klausler #include "flang/Runtime/descriptor.h" 15830c0b90SPeter Klausler #include "flang/Runtime/reduction.h" 164daa33f6SPeter Klausler #include <cfloat> 1750e0b298Speter klausler #include <cinttypes> 1850e0b298Speter klausler 1950e0b298Speter klausler namespace Fortran::runtime { 2050e0b298Speter klausler 21a5a493e1SPeter Klausler // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first 22a5a493e1SPeter Klausler // argument; MATMUL does not. 23a5a493e1SPeter Klausler 24a5a493e1SPeter Klausler // General accumulator for any type and stride; this is not used for 25a5a493e1SPeter Klausler // contiguous numeric vectors. 26a5a493e1SPeter Klausler template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 275e1421b2Speter klausler class Accumulator { 285e1421b2Speter klausler public: 29a5a493e1SPeter Klausler using Result = AccumulationType<RCAT, RKIND>; 3076facde3SSlava Zakharin RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y) 3176facde3SSlava Zakharin : x_{x}, y_{y} {} 3276facde3SSlava Zakharin RT_API_ATTRS void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) { 33a5a493e1SPeter Klausler if constexpr (RCAT == TypeCategory::Logical) { 345e1421b2Speter klausler sum_ = sum_ || 355e1421b2Speter klausler (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt)); 365e1421b2Speter klausler } else { 37a5a493e1SPeter Klausler const XT &xElement{*x_.Element<XT>(&xAt)}; 38a5a493e1SPeter Klausler const YT &yElement{*y_.Element<YT>(&yAt)}; 39a5a493e1SPeter Klausler if constexpr (RCAT == TypeCategory::Complex) { 40104f3c18SSlava Zakharin sum_ += rtcmplx::conj(static_cast<Result>(xElement)) * 41a5a493e1SPeter Klausler static_cast<Result>(yElement); 42a5a493e1SPeter Klausler } else { 43a5a493e1SPeter Klausler sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement); 44a5a493e1SPeter Klausler } 455e1421b2Speter klausler } 465e1421b2Speter klausler } 4776facde3SSlava Zakharin RT_API_ATTRS Result GetResult() const { return sum_; } 485e1421b2Speter klausler 495e1421b2Speter klausler private: 505e1421b2Speter klausler const Descriptor &x_, &y_; 515e1421b2Speter klausler Result sum_{}; 525e1421b2Speter klausler }; 535e1421b2Speter klausler 54a5a493e1SPeter Klausler template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 5576facde3SSlava Zakharin static inline RT_API_ATTRS CppTypeFor<RCAT, RKIND> DoDotProduct( 565e1421b2Speter klausler const Descriptor &x, const Descriptor &y, Terminator &terminator) { 57a5a493e1SPeter Klausler using Result = CppTypeFor<RCAT, RKIND>; 5850e0b298Speter klausler RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1); 5950e0b298Speter klausler SubscriptValue n{x.GetDimension(0).Extent()}; 6050e0b298Speter klausler if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) { 6150e0b298Speter klausler terminator.Crash( 6250e0b298Speter klausler "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd", 6350e0b298Speter klausler static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN)); 6450e0b298Speter klausler } 65a5a493e1SPeter Klausler if constexpr (RCAT != TypeCategory::Logical) { 66a5a493e1SPeter Klausler if (x.GetDimension(0).ByteStride() == sizeof(XT) && 67a5a493e1SPeter Klausler y.GetDimension(0).ByteStride() == sizeof(YT)) { 68a5a493e1SPeter Klausler // Contiguous numeric vectors 695e1421b2Speter klausler if constexpr (std::is_same_v<XT, YT>) { 70a5a493e1SPeter Klausler // Contiguous homogeneous numeric vectors 715e1421b2Speter klausler if constexpr (std::is_same_v<XT, float>) { 725e1421b2Speter klausler // TODO: call BLAS-1 SDOT or SDSDOT 735e1421b2Speter klausler } else if constexpr (std::is_same_v<XT, double>) { 745e1421b2Speter klausler // TODO: call BLAS-1 DDOT 75104f3c18SSlava Zakharin } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) { 765e1421b2Speter klausler // TODO: call BLAS-1 CDOTC 77104f3c18SSlava Zakharin } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) { 785e1421b2Speter klausler // TODO: call BLAS-1 ZDOTC 795e1421b2Speter klausler } 805e1421b2Speter klausler } 81a5a493e1SPeter Klausler XT *xp{x.OffsetElement<XT>(0)}; 82a5a493e1SPeter Klausler YT *yp{y.OffsetElement<YT>(0)}; 83a5a493e1SPeter Klausler using AccumType = AccumulationType<RCAT, RKIND>; 84a5a493e1SPeter Klausler AccumType accum{}; 85a5a493e1SPeter Klausler if constexpr (RCAT == TypeCategory::Complex) { 86a5a493e1SPeter Klausler for (SubscriptValue j{0}; j < n; ++j) { 87104f3c18SSlava Zakharin // conj() may instantiate its argument twice, 8876facde3SSlava Zakharin // so xp has to be incremented separately. 8976facde3SSlava Zakharin // This is a workaround for an alleged bug in clang, 9076facde3SSlava Zakharin // that shows up as: 9176facde3SSlava Zakharin // warning: multiple unsequenced modifications to 'xp' 92104f3c18SSlava Zakharin accum += rtcmplx::conj(static_cast<AccumType>(*xp)) * 93a5a493e1SPeter Klausler static_cast<AccumType>(*yp++); 9476facde3SSlava Zakharin xp++; 95a5a493e1SPeter Klausler } 96a5a493e1SPeter Klausler } else { 97a5a493e1SPeter Klausler for (SubscriptValue j{0}; j < n; ++j) { 98a5a493e1SPeter Klausler accum += 99a5a493e1SPeter Klausler static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++); 100a5a493e1SPeter Klausler } 101a5a493e1SPeter Klausler } 102a5a493e1SPeter Klausler return static_cast<Result>(accum); 103a5a493e1SPeter Klausler } 104a5a493e1SPeter Klausler } 105a5a493e1SPeter Klausler // Non-contiguous, heterogeneous, & LOGICAL cases 10650e0b298Speter klausler SubscriptValue xAt{x.GetDimension(0).LowerBound()}; 10750e0b298Speter klausler SubscriptValue yAt{y.GetDimension(0).LowerBound()}; 108a5a493e1SPeter Klausler Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 10950e0b298Speter klausler for (SubscriptValue j{0}; j < n; ++j) { 110a5a493e1SPeter Klausler accumulator.AccumulateIndexed(xAt++, yAt++); 11150e0b298Speter klausler } 112a5a493e1SPeter Klausler return static_cast<Result>(accumulator.GetResult()); 11350e0b298Speter klausler } 11450e0b298Speter klausler 1155e1421b2Speter klausler template <TypeCategory RCAT, int RKIND> struct DotProduct { 11650e0b298Speter klausler using Result = CppTypeFor<RCAT, RKIND>; 11750e0b298Speter klausler template <TypeCategory XCAT, int XKIND> struct DP1 { 11850e0b298Speter klausler template <TypeCategory YCAT, int YKIND> struct DP2 { 11976facde3SSlava Zakharin RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y, 12050e0b298Speter klausler Terminator &terminator) const { 12150e0b298Speter klausler if constexpr (constexpr auto resultType{ 12250e0b298Speter klausler GetResultType(XCAT, XKIND, YCAT, YKIND)}) { 12350e0b298Speter klausler if constexpr (resultType->first == RCAT && 124f6aac0ddSpeter klausler (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) { 125a5a493e1SPeter Klausler return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>, 1265e1421b2Speter klausler CppTypeFor<YCAT, YKIND>>(x, y, terminator); 12750e0b298Speter klausler } 12850e0b298Speter klausler } 12950e0b298Speter klausler terminator.Crash( 13050e0b298Speter klausler "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))", 13150e0b298Speter klausler static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND, 13250e0b298Speter klausler static_cast<int>(YCAT), YKIND); 13350e0b298Speter klausler } 13450e0b298Speter klausler }; 13576facde3SSlava Zakharin RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y, 13650e0b298Speter klausler Terminator &terminator, TypeCategory yCat, int yKind) const { 13750e0b298Speter klausler return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator); 13850e0b298Speter klausler } 13950e0b298Speter klausler }; 14076facde3SSlava Zakharin RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y, 14150e0b298Speter klausler const char *source, int line) const { 14250e0b298Speter klausler Terminator terminator{source, line}; 143a5a493e1SPeter Klausler if (RCAT != TypeCategory::Logical && x.type() == y.type()) { 144a5a493e1SPeter Klausler // No conversions needed, operands and result have same known type 145a5a493e1SPeter Klausler return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}( 146a5a493e1SPeter Klausler x, y, terminator); 147a5a493e1SPeter Klausler } else { 14850e0b298Speter klausler auto xCatKind{x.type().GetCategoryAndKind()}; 14950e0b298Speter klausler auto yCatKind{y.type().GetCategoryAndKind()}; 15050e0b298Speter klausler RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); 151a5a493e1SPeter Klausler return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, 152a5a493e1SPeter Klausler terminator, x, y, terminator, yCatKind->first, yCatKind->second); 153a5a493e1SPeter Klausler } 15450e0b298Speter klausler } 15550e0b298Speter klausler }; 15650e0b298Speter klausler 15750e0b298Speter klausler extern "C" { 15876facde3SSlava Zakharin RT_EXT_API_GROUP_BEGIN 15976facde3SSlava Zakharin 16076facde3SSlava Zakharin CppTypeFor<TypeCategory::Integer, 1> RTDEF(DotProductInteger1)( 16150e0b298Speter klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 162a5a493e1SPeter Klausler return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line); 16350e0b298Speter klausler } 16476facde3SSlava Zakharin CppTypeFor<TypeCategory::Integer, 2> RTDEF(DotProductInteger2)( 16550e0b298Speter klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 166a5a493e1SPeter Klausler return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line); 16750e0b298Speter klausler } 16876facde3SSlava Zakharin CppTypeFor<TypeCategory::Integer, 4> RTDEF(DotProductInteger4)( 16950e0b298Speter klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 170a5a493e1SPeter Klausler return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line); 17150e0b298Speter klausler } 17276facde3SSlava Zakharin CppTypeFor<TypeCategory::Integer, 8> RTDEF(DotProductInteger8)( 17350e0b298Speter klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 1745e1421b2Speter klausler return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line); 17550e0b298Speter klausler } 17650e0b298Speter klausler #ifdef __SIZEOF_INT128__ 17776facde3SSlava Zakharin CppTypeFor<TypeCategory::Integer, 16> RTDEF(DotProductInteger16)( 17850e0b298Speter klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 1795e1421b2Speter klausler return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line); 18050e0b298Speter klausler } 18150e0b298Speter klausler #endif 18250e0b298Speter klausler 183*fc97d2e6SPeter Klausler CppTypeFor<TypeCategory::Unsigned, 1> RTDEF(DotProductUnsigned1)( 184*fc97d2e6SPeter Klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 185*fc97d2e6SPeter Klausler return DotProduct<TypeCategory::Unsigned, 1>{}(x, y, source, line); 186*fc97d2e6SPeter Klausler } 187*fc97d2e6SPeter Klausler CppTypeFor<TypeCategory::Unsigned, 2> RTDEF(DotProductUnsigned2)( 188*fc97d2e6SPeter Klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 189*fc97d2e6SPeter Klausler return DotProduct<TypeCategory::Unsigned, 2>{}(x, y, source, line); 190*fc97d2e6SPeter Klausler } 191*fc97d2e6SPeter Klausler CppTypeFor<TypeCategory::Unsigned, 4> RTDEF(DotProductUnsigned4)( 192*fc97d2e6SPeter Klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 193*fc97d2e6SPeter Klausler return DotProduct<TypeCategory::Unsigned, 4>{}(x, y, source, line); 194*fc97d2e6SPeter Klausler } 195*fc97d2e6SPeter Klausler CppTypeFor<TypeCategory::Unsigned, 8> RTDEF(DotProductUnsigned8)( 196*fc97d2e6SPeter Klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 197*fc97d2e6SPeter Klausler return DotProduct<TypeCategory::Unsigned, 8>{}(x, y, source, line); 198*fc97d2e6SPeter Klausler } 199*fc97d2e6SPeter Klausler #ifdef __SIZEOF_INT128__ 200*fc97d2e6SPeter Klausler CppTypeFor<TypeCategory::Unsigned, 16> RTDEF(DotProductUnsigned16)( 201*fc97d2e6SPeter Klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 202*fc97d2e6SPeter Klausler return DotProduct<TypeCategory::Unsigned, 16>{}(x, y, source, line); 203*fc97d2e6SPeter Klausler } 204*fc97d2e6SPeter Klausler #endif 205*fc97d2e6SPeter Klausler 20650e0b298Speter klausler // TODO: REAL/COMPLEX(2 & 3) 207a5a493e1SPeter Klausler // Intermediate results and operations are at least 64 bits 20876facde3SSlava Zakharin CppTypeFor<TypeCategory::Real, 4> RTDEF(DotProductReal4)( 20950e0b298Speter klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 210a5a493e1SPeter Klausler return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line); 21150e0b298Speter klausler } 21276facde3SSlava Zakharin CppTypeFor<TypeCategory::Real, 8> RTDEF(DotProductReal8)( 21350e0b298Speter klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 2145e1421b2Speter klausler return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line); 21550e0b298Speter klausler } 216104f3c18SSlava Zakharin #if HAS_FLOAT80 21776facde3SSlava Zakharin CppTypeFor<TypeCategory::Real, 10> RTDEF(DotProductReal10)( 21850e0b298Speter klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 2195e1421b2Speter klausler return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line); 22050e0b298Speter klausler } 2214daa33f6SPeter Klausler #endif 222fc51c7f0SSlava Zakharin #if HAS_LDBL128 || HAS_FLOAT128 22376facde3SSlava Zakharin CppTypeFor<TypeCategory::Real, 16> RTDEF(DotProductReal16)( 22450e0b298Speter klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 2255e1421b2Speter klausler return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line); 22650e0b298Speter klausler } 22750e0b298Speter klausler #endif 22850e0b298Speter klausler 22976facde3SSlava Zakharin void RTDEF(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result, 23050e0b298Speter klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 231f8a9f43eSSlava Zakharin result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line); 23250e0b298Speter klausler } 23376facde3SSlava Zakharin void RTDEF(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result, 23450e0b298Speter klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 2355e1421b2Speter klausler result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line); 23650e0b298Speter klausler } 237104f3c18SSlava Zakharin #if HAS_FLOAT80 23876facde3SSlava Zakharin void RTDEF(CppDotProductComplex10)( 239f8a9f43eSSlava Zakharin CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x, 240f8a9f43eSSlava Zakharin const Descriptor &y, const char *source, int line) { 2415e1421b2Speter klausler result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line); 24250e0b298Speter klausler } 243f8a9f43eSSlava Zakharin #endif 244fc51c7f0SSlava Zakharin #if HAS_LDBL128 || HAS_FLOAT128 24576facde3SSlava Zakharin void RTDEF(CppDotProductComplex16)( 246f8a9f43eSSlava Zakharin CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x, 247f8a9f43eSSlava Zakharin const Descriptor &y, const char *source, int line) { 2485e1421b2Speter klausler result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line); 24950e0b298Speter klausler } 25050e0b298Speter klausler #endif 25150e0b298Speter klausler 25276facde3SSlava Zakharin bool RTDEF(DotProductLogical)( 25350e0b298Speter klausler const Descriptor &x, const Descriptor &y, const char *source, int line) { 2545e1421b2Speter klausler return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line); 25550e0b298Speter klausler } 25676facde3SSlava Zakharin 25776facde3SSlava Zakharin RT_EXT_API_GROUP_END 25850e0b298Speter klausler } // extern "C" 25950e0b298Speter klausler } // namespace Fortran::runtime 260