1 //===-- MPCUtils.h ----------------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H 10 #define LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H 11 12 #include "src/__support/CPP/type_traits.h" 13 #include "src/__support/complex_type.h" 14 #include "src/__support/macros/config.h" 15 #include "src/__support/macros/properties/complex_types.h" 16 #include "src/__support/macros/properties/types.h" 17 #include "test/UnitTest/RoundingModeUtils.h" 18 #include "test/UnitTest/Test.h" 19 20 #include <stdint.h> 21 22 namespace LIBC_NAMESPACE_DECL { 23 namespace testing { 24 namespace mpc { 25 26 enum class Operation { 27 // Operations which take a single complex floating point number as input 28 // and produce a single floating point number as output which has the same 29 // floating point type as the real/imaginary part of the input. 30 BeginUnaryOperationsSingleOutputDifferentOutputType, 31 Carg, 32 Cabs, 33 EndUnaryOperationsSingleOutputDifferentOutputType, 34 35 // Operations which take a single complex floating point number as input 36 // and produce a single complex floating point number of the same kind 37 // as output. 38 BeginUnaryOperationsSingleOutputSameOutputType, 39 Cproj, 40 Csqrt, 41 Clog, 42 Cexp, 43 Csinh, 44 Ccosh, 45 Ctanh, 46 Casinh, 47 Cacosh, 48 Catanh, 49 Csin, 50 Ccos, 51 Ctan, 52 Casin, 53 Cacos, 54 Catan, 55 EndUnaryOperationsSingleOutputSameOutputType, 56 57 // Operations which take two complex floating point numbers as input 58 // and produce a single complex floating point number of the same kind 59 // as output. 60 BeginBinaryOperationsSingleOutput, 61 Cpow, 62 EndBinaryOperationsSingleOutput, 63 }; 64 65 using LIBC_NAMESPACE::fputil::testing::RoundingMode; 66 67 template <typename T> struct BinaryInput { 68 static_assert(LIBC_NAMESPACE::cpp::is_complex_v<T>, 69 "Template parameter of BinaryInput must be a complex floating " 70 "point type."); 71 72 using Type = T; 73 T x, y; 74 }; 75 76 namespace internal { 77 78 template <typename InputType, typename OutputType> 79 bool compare_unary_operation_single_output_same_type(Operation op, 80 InputType input, 81 OutputType libc_output, 82 double ulp_tolerance, 83 RoundingMode rounding); 84 85 template <typename InputType, typename OutputType> 86 bool compare_unary_operation_single_output_different_type( 87 Operation op, InputType input, OutputType libc_output, double ulp_tolerance, 88 RoundingMode rounding); 89 90 template <typename InputType, typename OutputType> 91 bool compare_binary_operation_one_output(Operation op, 92 const BinaryInput<InputType> &input, 93 OutputType libc_output, 94 double ulp_tolerance, 95 RoundingMode rounding); 96 97 template <typename InputType, typename OutputType> 98 void explain_unary_operation_single_output_same_type_error( 99 Operation op, InputType input, OutputType match_value, double ulp_tolerance, 100 RoundingMode rounding); 101 102 template <typename InputType, typename OutputType> 103 void explain_unary_operation_single_output_different_type_error( 104 Operation op, InputType input, OutputType match_value, double ulp_tolerance, 105 RoundingMode rounding); 106 107 template <typename InputType, typename OutputType> 108 void explain_binary_operation_one_output_error( 109 Operation op, const BinaryInput<InputType> &input, OutputType match_value, 110 double ulp_tolerance, RoundingMode rounding); 111 112 template <Operation op, typename InputType, typename OutputType> 113 class MPCMatcher : public testing::Matcher<OutputType> { 114 private: 115 InputType input; 116 OutputType match_value; 117 double ulp_tolerance; 118 RoundingMode rounding; 119 120 public: 121 MPCMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding) 122 : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {} 123 124 bool match(OutputType libcResult) { 125 match_value = libcResult; 126 return match(input, match_value); 127 } 128 129 void explainError() override { // NOLINT 130 explain_error(input, match_value); 131 } 132 133 private: 134 template <typename InType, typename OutType> 135 bool match(InType in, OutType out) { 136 if (cpp::is_same_v<InType, OutType>) { 137 return compare_unary_operation_single_output_same_type( 138 op, in, out, ulp_tolerance, rounding); 139 } else { 140 return compare_unary_operation_single_output_different_type( 141 op, in, out, ulp_tolerance, rounding); 142 } 143 } 144 145 template <typename T, typename U> 146 bool match(const BinaryInput<T> &in, U out) { 147 return compare_binary_operation_one_output(op, in, out, ulp_tolerance, 148 rounding); 149 } 150 151 template <typename InType, typename OutType> 152 void explain_error(InType in, OutType out) { 153 if (cpp::is_same_v<InType, OutType>) { 154 explain_unary_operation_single_output_same_type_error( 155 op, in, out, ulp_tolerance, rounding); 156 } else { 157 explain_unary_operation_single_output_different_type_error( 158 op, in, out, ulp_tolerance, rounding); 159 } 160 } 161 162 template <typename T, typename U> 163 void explain_error(const BinaryInput<T> &in, U out) { 164 explain_binary_operation_one_output_error(op, in, out, ulp_tolerance, 165 rounding); 166 } 167 }; 168 169 } // namespace internal 170 171 // Return true if the input and ouput types for the operation op are valid 172 // types. 173 template <Operation op, typename InputType, typename OutputType> 174 constexpr bool is_valid_operation() { 175 return (Operation::BeginBinaryOperationsSingleOutput < op && 176 op < Operation::EndBinaryOperationsSingleOutput && 177 cpp::is_complex_type_same<InputType, OutputType> && 178 cpp::is_complex_v<InputType>) || 179 (Operation::BeginUnaryOperationsSingleOutputSameOutputType < op && 180 op < Operation::EndUnaryOperationsSingleOutputSameOutputType && 181 cpp::is_complex_type_same<InputType, OutputType> && 182 cpp::is_complex_v<InputType>) || 183 (Operation::BeginUnaryOperationsSingleOutputDifferentOutputType < op && 184 op < Operation::EndUnaryOperationsSingleOutputDifferentOutputType && 185 cpp::is_same_v<make_real_t<InputType>, OutputType> && 186 cpp::is_complex_v<InputType>); 187 } 188 189 template <Operation op, typename InputType, typename OutputType> 190 cpp::enable_if_t<is_valid_operation<op, InputType, OutputType>(), 191 internal::MPCMatcher<op, InputType, OutputType>> 192 get_mpc_matcher(InputType input, [[maybe_unused]] OutputType output, 193 double ulp_tolerance, RoundingMode rounding) { 194 return internal::MPCMatcher<op, InputType, OutputType>(input, ulp_tolerance, 195 rounding); 196 } 197 198 } // namespace mpc 199 } // namespace testing 200 } // namespace LIBC_NAMESPACE_DECL 201 202 #define EXPECT_MPC_MATCH_DEFAULT(op, input, match_value, ulp_tolerance) \ 203 EXPECT_THAT(match_value, \ 204 LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>( \ 205 input, match_value, ulp_tolerance, \ 206 LIBC_NAMESPACE::fputil::testing::RoundingMode::Nearest)) 207 208 #define EXPECT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ 209 rounding) \ 210 EXPECT_THAT(match_value, LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>( \ 211 input, match_value, ulp_tolerance, rounding)) 212 213 #define EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value, \ 214 ulp_tolerance, rounding) \ 215 { \ 216 MPCRND::ForceRoundingMode __r(rounding); \ 217 if (__r.success) { \ 218 EXPECT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ 219 rounding); \ 220 } \ 221 } 222 223 #define EXPECT_MPC_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance) \ 224 { \ 225 namespace MPCRND = LIBC_NAMESPACE::fputil::testing; \ 226 for (int i = 0; i < 4; i++) { \ 227 MPCRND::RoundingMode r_mode = static_cast<MPCRND::RoundingMode>(i); \ 228 EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value, \ 229 ulp_tolerance, r_mode); \ 230 } \ 231 } 232 233 #define TEST_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ 234 rounding) \ 235 LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(input, match_value, \ 236 ulp_tolerance, rounding) \ 237 .match(match_value) 238 239 #define ASSERT_MPC_MATCH_DEFAULT(op, input, match_value, ulp_tolerance) \ 240 ASSERT_THAT(match_value, \ 241 LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>( \ 242 input, match_value, ulp_tolerance, \ 243 LIBC_NAMESPACE::fputil::testing::RoundingMode::Nearest)) 244 245 #define ASSERT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ 246 rounding) \ 247 ASSERT_THAT(match_value, LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>( \ 248 input, match_value, ulp_tolerance, rounding)) 249 250 #define ASSERT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value, \ 251 ulp_tolerance, rounding) \ 252 { \ 253 MPCRND::ForceRoundingMode __r(rounding); \ 254 if (__r.success) { \ 255 ASSERT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ 256 rounding); \ 257 } \ 258 } 259 260 #define ASSERT_MPC_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance) \ 261 { \ 262 namespace MPCRND = LIBC_NAMESPACE::fputil::testing; \ 263 for (int i = 0; i < 4; i++) { \ 264 MPCRND::RoundingMode r_mode = static_cast<MPCRND::RoundingMode>(i); \ 265 ASSERT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value, \ 266 ulp_tolerance, r_mode); \ 267 } \ 268 } 269 270 #endif // LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H 271