1 //===-- runtime/dot-product.cpp -------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "float.h" 10 #include "terminator.h" 11 #include "tools.h" 12 #include "flang/Common/float128.h" 13 #include "flang/Runtime/cpp-type.h" 14 #include "flang/Runtime/descriptor.h" 15 #include "flang/Runtime/reduction.h" 16 #include <cfloat> 17 #include <cinttypes> 18 19 namespace Fortran::runtime { 20 21 // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first 22 // argument; MATMUL does not. 23 24 // General accumulator for any type and stride; this is not used for 25 // contiguous numeric vectors. 26 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 27 class Accumulator { 28 public: 29 using Result = AccumulationType<RCAT, RKIND>; 30 RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y) 31 : x_{x}, y_{y} {} 32 RT_API_ATTRS void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) { 33 if constexpr (RCAT == TypeCategory::Logical) { 34 sum_ = sum_ || 35 (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt)); 36 } else { 37 const XT &xElement{*x_.Element<XT>(&xAt)}; 38 const YT &yElement{*y_.Element<YT>(&yAt)}; 39 if constexpr (RCAT == TypeCategory::Complex) { 40 sum_ += rtcmplx::conj(static_cast<Result>(xElement)) * 41 static_cast<Result>(yElement); 42 } else { 43 sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement); 44 } 45 } 46 } 47 RT_API_ATTRS Result GetResult() const { return sum_; } 48 49 private: 50 const Descriptor &x_, &y_; 51 Result sum_{}; 52 }; 53 54 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 55 static inline RT_API_ATTRS CppTypeFor<RCAT, RKIND> DoDotProduct( 56 const Descriptor &x, const Descriptor &y, Terminator &terminator) { 57 using Result = CppTypeFor<RCAT, RKIND>; 58 RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1); 59 SubscriptValue n{x.GetDimension(0).Extent()}; 60 if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) { 61 terminator.Crash( 62 "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd", 63 static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN)); 64 } 65 if constexpr (RCAT != TypeCategory::Logical) { 66 if (x.GetDimension(0).ByteStride() == sizeof(XT) && 67 y.GetDimension(0).ByteStride() == sizeof(YT)) { 68 // Contiguous numeric vectors 69 if constexpr (std::is_same_v<XT, YT>) { 70 // Contiguous homogeneous numeric vectors 71 if constexpr (std::is_same_v<XT, float>) { 72 // TODO: call BLAS-1 SDOT or SDSDOT 73 } else if constexpr (std::is_same_v<XT, double>) { 74 // TODO: call BLAS-1 DDOT 75 } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) { 76 // TODO: call BLAS-1 CDOTC 77 } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) { 78 // TODO: call BLAS-1 ZDOTC 79 } 80 } 81 XT *xp{x.OffsetElement<XT>(0)}; 82 YT *yp{y.OffsetElement<YT>(0)}; 83 using AccumType = AccumulationType<RCAT, RKIND>; 84 AccumType accum{}; 85 if constexpr (RCAT == TypeCategory::Complex) { 86 for (SubscriptValue j{0}; j < n; ++j) { 87 // conj() may instantiate its argument twice, 88 // so xp has to be incremented separately. 89 // This is a workaround for an alleged bug in clang, 90 // that shows up as: 91 // warning: multiple unsequenced modifications to 'xp' 92 accum += rtcmplx::conj(static_cast<AccumType>(*xp)) * 93 static_cast<AccumType>(*yp++); 94 xp++; 95 } 96 } else { 97 for (SubscriptValue j{0}; j < n; ++j) { 98 accum += 99 static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++); 100 } 101 } 102 return static_cast<Result>(accum); 103 } 104 } 105 // Non-contiguous, heterogeneous, & LOGICAL cases 106 SubscriptValue xAt{x.GetDimension(0).LowerBound()}; 107 SubscriptValue yAt{y.GetDimension(0).LowerBound()}; 108 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 109 for (SubscriptValue j{0}; j < n; ++j) { 110 accumulator.AccumulateIndexed(xAt++, yAt++); 111 } 112 return static_cast<Result>(accumulator.GetResult()); 113 } 114 115 template <TypeCategory RCAT, int RKIND> struct DotProduct { 116 using Result = CppTypeFor<RCAT, RKIND>; 117 template <TypeCategory XCAT, int XKIND> struct DP1 { 118 template <TypeCategory YCAT, int YKIND> struct DP2 { 119 RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y, 120 Terminator &terminator) const { 121 if constexpr (constexpr auto resultType{ 122 GetResultType(XCAT, XKIND, YCAT, YKIND)}) { 123 if constexpr (resultType->first == RCAT && 124 (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) { 125 return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>, 126 CppTypeFor<YCAT, YKIND>>(x, y, terminator); 127 } 128 } 129 terminator.Crash( 130 "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))", 131 static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND, 132 static_cast<int>(YCAT), YKIND); 133 } 134 }; 135 RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y, 136 Terminator &terminator, TypeCategory yCat, int yKind) const { 137 return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator); 138 } 139 }; 140 RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y, 141 const char *source, int line) const { 142 Terminator terminator{source, line}; 143 if (RCAT != TypeCategory::Logical && x.type() == y.type()) { 144 // No conversions needed, operands and result have same known type 145 return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}( 146 x, y, terminator); 147 } else { 148 auto xCatKind{x.type().GetCategoryAndKind()}; 149 auto yCatKind{y.type().GetCategoryAndKind()}; 150 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); 151 return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, 152 terminator, x, y, terminator, yCatKind->first, yCatKind->second); 153 } 154 } 155 }; 156 157 extern "C" { 158 RT_EXT_API_GROUP_BEGIN 159 160 CppTypeFor<TypeCategory::Integer, 1> RTDEF(DotProductInteger1)( 161 const Descriptor &x, const Descriptor &y, const char *source, int line) { 162 return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line); 163 } 164 CppTypeFor<TypeCategory::Integer, 2> RTDEF(DotProductInteger2)( 165 const Descriptor &x, const Descriptor &y, const char *source, int line) { 166 return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line); 167 } 168 CppTypeFor<TypeCategory::Integer, 4> RTDEF(DotProductInteger4)( 169 const Descriptor &x, const Descriptor &y, const char *source, int line) { 170 return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line); 171 } 172 CppTypeFor<TypeCategory::Integer, 8> RTDEF(DotProductInteger8)( 173 const Descriptor &x, const Descriptor &y, const char *source, int line) { 174 return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line); 175 } 176 #ifdef __SIZEOF_INT128__ 177 CppTypeFor<TypeCategory::Integer, 16> RTDEF(DotProductInteger16)( 178 const Descriptor &x, const Descriptor &y, const char *source, int line) { 179 return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line); 180 } 181 #endif 182 183 CppTypeFor<TypeCategory::Unsigned, 1> RTDEF(DotProductUnsigned1)( 184 const Descriptor &x, const Descriptor &y, const char *source, int line) { 185 return DotProduct<TypeCategory::Unsigned, 1>{}(x, y, source, line); 186 } 187 CppTypeFor<TypeCategory::Unsigned, 2> RTDEF(DotProductUnsigned2)( 188 const Descriptor &x, const Descriptor &y, const char *source, int line) { 189 return DotProduct<TypeCategory::Unsigned, 2>{}(x, y, source, line); 190 } 191 CppTypeFor<TypeCategory::Unsigned, 4> RTDEF(DotProductUnsigned4)( 192 const Descriptor &x, const Descriptor &y, const char *source, int line) { 193 return DotProduct<TypeCategory::Unsigned, 4>{}(x, y, source, line); 194 } 195 CppTypeFor<TypeCategory::Unsigned, 8> RTDEF(DotProductUnsigned8)( 196 const Descriptor &x, const Descriptor &y, const char *source, int line) { 197 return DotProduct<TypeCategory::Unsigned, 8>{}(x, y, source, line); 198 } 199 #ifdef __SIZEOF_INT128__ 200 CppTypeFor<TypeCategory::Unsigned, 16> RTDEF(DotProductUnsigned16)( 201 const Descriptor &x, const Descriptor &y, const char *source, int line) { 202 return DotProduct<TypeCategory::Unsigned, 16>{}(x, y, source, line); 203 } 204 #endif 205 206 // TODO: REAL/COMPLEX(2 & 3) 207 // Intermediate results and operations are at least 64 bits 208 CppTypeFor<TypeCategory::Real, 4> RTDEF(DotProductReal4)( 209 const Descriptor &x, const Descriptor &y, const char *source, int line) { 210 return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line); 211 } 212 CppTypeFor<TypeCategory::Real, 8> RTDEF(DotProductReal8)( 213 const Descriptor &x, const Descriptor &y, const char *source, int line) { 214 return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line); 215 } 216 #if HAS_FLOAT80 217 CppTypeFor<TypeCategory::Real, 10> RTDEF(DotProductReal10)( 218 const Descriptor &x, const Descriptor &y, const char *source, int line) { 219 return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line); 220 } 221 #endif 222 #if HAS_LDBL128 || HAS_FLOAT128 223 CppTypeFor<TypeCategory::Real, 16> RTDEF(DotProductReal16)( 224 const Descriptor &x, const Descriptor &y, const char *source, int line) { 225 return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line); 226 } 227 #endif 228 229 void RTDEF(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result, 230 const Descriptor &x, const Descriptor &y, const char *source, int line) { 231 result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line); 232 } 233 void RTDEF(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result, 234 const Descriptor &x, const Descriptor &y, const char *source, int line) { 235 result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line); 236 } 237 #if HAS_FLOAT80 238 void RTDEF(CppDotProductComplex10)( 239 CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x, 240 const Descriptor &y, const char *source, int line) { 241 result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line); 242 } 243 #endif 244 #if HAS_LDBL128 || HAS_FLOAT128 245 void RTDEF(CppDotProductComplex16)( 246 CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x, 247 const Descriptor &y, const char *source, int line) { 248 result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line); 249 } 250 #endif 251 252 bool RTDEF(DotProductLogical)( 253 const Descriptor &x, const Descriptor &y, const char *source, int line) { 254 return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line); 255 } 256 257 RT_EXT_API_GROUP_END 258 } // extern "C" 259 } // namespace Fortran::runtime 260