xref: /llvm-project/libc/utils/MPFRWrapper/MPFRUtils.h (revision 7f37b34d31914120a5bb6bd341e7616773df7613)
1 //===-- MPFRUtils.h ---------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_UTILS_MPFRWRAPPER_MPFRUTILS_H
10 #define LLVM_LIBC_UTILS_MPFRWRAPPER_MPFRUTILS_H
11 
12 #include "src/__support/CPP/type_traits.h"
13 #include "src/__support/macros/config.h"
14 #include "test/UnitTest/RoundingModeUtils.h"
15 #include "test/UnitTest/Test.h"
16 
17 #include <stdint.h>
18 
19 namespace LIBC_NAMESPACE_DECL {
20 namespace testing {
21 namespace mpfr {
22 
// Operations that can be checked against MPFR. The Begin*/End* enumerators
// are category markers: an operation belongs to a category when it lies
// strictly between that category's Begin and End markers (this is how
// is_valid_operation classifies operations), so enumerators must stay inside
// their category.
enum class Operation : int {
  // Operations which take a single floating point number as input
  // and produce a single floating point number as output. The input
  // and output floating point numbers are of the same kind (Sqrt may
  // additionally produce a type no wider than its input).
  BeginUnaryOperationsSingleOutput,
  Abs,
  Acos,
  Acosh,
  Asin,
  Asinh,
  Atan,
  Atanh,
  Cbrt,
  Ceil,
  Cos,
  Cosh,
  Cospi,
  Erf,
  Exp,
  Exp2,
  Exp2m1,
  Exp10,
  Exp10m1,
  Expm1,
  Floor,
  Log,
  Log2,
  Log10,
  Log1p,
  Mod2PI,
  ModPIOver2,
  ModPIOver4,
  Round,
  RoundEven,
  Sin,
  Sinpi,
  Sinh,
  Sqrt,
  Tan,
  Tanh,
  Tanpi,
  Trunc,
  EndUnaryOperationsSingleOutput,

  // Operations which take a single floating point number as input
  // but produce two outputs. The first output is a floating point
  // number of the same type as the input. The second output is of type
  // 'int'.
  BeginUnaryOperationsTwoOutputs,
  Frexp, // Floating point output, the first output, is the fractional part.
  EndUnaryOperationsTwoOutputs,

  // Operations which take two floating point numbers of the same type as
  // input and produce a single floating point number as output. The output
  // is normally of the same type as the inputs, but any floating point
  // output type is accepted.
  BeginBinaryOperationsSingleOutput,
  Add,
  Atan2,
  Div,
  Fmod,
  Hypot,
  Mul,
  Pow,
  Sub,
  EndBinaryOperationsSingleOutput,

  // Operations which take two floating point numbers of the same type as
  // input and produce two outputs. The first output is a floating point number
  // of the same type as the inputs. The second output is of type 'int'.
  BeginBinaryOperationsTwoOutputs,
  RemQuo, // The first output(floating point) is the remainder.
  EndBinaryOperationsTwoOutputs,

  // Operations which take three floating point numbers of the same type as
  // input and produce a single floating point number as output. The output
  // is normally of the same type as the inputs; Fma accepts any floating
  // point output type.
  BeginTernaryOperationsSingleOuput,
  // Correctly spelled alias of the marker above. It has the same value, so
  // the following enumerators keep their values. The misspelled name is
  // kept for existing users.
  BeginTernaryOperationsSingleOutput = BeginTernaryOperationsSingleOuput,
  Fma,
  EndTernaryOperationsSingleOutput,
};
103 
104 using LIBC_NAMESPACE::fputil::testing::ForceRoundingMode;
105 using LIBC_NAMESPACE::fputil::testing::RoundingMode;
106 
107 template <typename T> struct BinaryInput {
108   static_assert(
109       LIBC_NAMESPACE::cpp::is_floating_point_v<T>,
110       "Template parameter of BinaryInput must be a floating point type.");
111 
112   using Type = T;
113   T x, y;
114 };
115 
116 template <typename T> struct TernaryInput {
117   static_assert(
118       LIBC_NAMESPACE::cpp::is_floating_point_v<T>,
119       "Template parameter of TernaryInput must be a floating point type.");
120 
121   using Type = T;
122   T x, y, z;
123 };
124 
// Result of an operation with two outputs (e.g. Frexp, RemQuo): the floating
// point result `f` and the accompanying integer result `i`.
template <class T> struct BinaryOutput {
  T f; // Floating point output.
  int i; // Integer output.
};
129 
130 namespace internal {
131 
132 template <typename T1, typename T2>
133 struct AreMatchingBinaryInputAndBinaryOutput {
134   static constexpr bool VALUE = false;
135 };
136 
137 template <typename T>
138 struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
139   static constexpr bool VALUE = cpp::is_floating_point_v<T>;
140 };
141 
142 template <typename T> struct IsBinaryInput {
143   static constexpr bool VALUE = false;
144 };
145 
146 template <typename T> struct IsBinaryInput<BinaryInput<T>> {
147   static constexpr bool VALUE = true;
148 };
149 
150 template <typename T> struct IsTernaryInput {
151   static constexpr bool VALUE = false;
152 };
153 
154 template <typename T> struct IsTernaryInput<TernaryInput<T>> {
155   static constexpr bool VALUE = true;
156 };
157 
158 template <typename T> struct MakeScalarInput : cpp::type_identity<T> {};
159 
160 template <typename T>
161 struct MakeScalarInput<BinaryInput<T>> : cpp::type_identity<T> {};
162 
163 template <typename T>
164 struct MakeScalarInput<TernaryInput<T>> : cpp::type_identity<T> {};
165 
166 template <typename InputType, typename OutputType>
167 bool compare_unary_operation_single_output(Operation op, InputType input,
168                                            OutputType libc_output,
169                                            double ulp_tolerance,
170                                            RoundingMode rounding);
171 template <typename T>
172 bool compare_unary_operation_two_outputs(Operation op, T input,
173                                          const BinaryOutput<T> &libc_output,
174                                          double ulp_tolerance,
175                                          RoundingMode rounding);
176 template <typename T>
177 bool compare_binary_operation_two_outputs(Operation op,
178                                           const BinaryInput<T> &input,
179                                           const BinaryOutput<T> &libc_output,
180                                           double ulp_tolerance,
181                                           RoundingMode rounding);
182 
183 template <typename InputType, typename OutputType>
184 bool compare_binary_operation_one_output(Operation op,
185                                          const BinaryInput<InputType> &input,
186                                          OutputType libc_output,
187                                          double ulp_tolerance,
188                                          RoundingMode rounding);
189 
190 template <typename InputType, typename OutputType>
191 bool compare_ternary_operation_one_output(Operation op,
192                                           const TernaryInput<InputType> &input,
193                                           OutputType libc_output,
194                                           double ulp_tolerance,
195                                           RoundingMode rounding);
196 
197 template <typename InputType, typename OutputType>
198 void explain_unary_operation_single_output_error(Operation op, InputType input,
199                                                  OutputType match_value,
200                                                  double ulp_tolerance,
201                                                  RoundingMode rounding);
202 template <typename T>
203 void explain_unary_operation_two_outputs_error(
204     Operation op, T input, const BinaryOutput<T> &match_value,
205     double ulp_tolerance, RoundingMode rounding);
206 template <typename T>
207 void explain_binary_operation_two_outputs_error(
208     Operation op, const BinaryInput<T> &input,
209     const BinaryOutput<T> &match_value, double ulp_tolerance,
210     RoundingMode rounding);
211 
212 template <typename InputType, typename OutputType>
213 void explain_binary_operation_one_output_error(
214     Operation op, const BinaryInput<InputType> &input, OutputType match_value,
215     double ulp_tolerance, RoundingMode rounding);
216 
217 template <typename InputType, typename OutputType>
218 void explain_ternary_operation_one_output_error(
219     Operation op, const TernaryInput<InputType> &input, OutputType match_value,
220     double ulp_tolerance, RoundingMode rounding);
221 
222 template <Operation op, bool silent, typename InputType, typename OutputType>
223 class MPFRMatcher : public testing::Matcher<OutputType> {
224   InputType input;
225   OutputType match_value;
226   double ulp_tolerance;
227   RoundingMode rounding;
228 
229 public:
230   MPFRMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding)
231       : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {}
232 
233   bool match(OutputType libcResult) {
234     match_value = libcResult;
235     return match(input, match_value);
236   }
237 
238   // This method is marked with NOLINT because the name `explainError` does not
239   // conform to the coding style.
240   void explainError() override { // NOLINT
241     explain_error(input, match_value);
242   }
243 
244   // Whether the `explainError` step is skipped or not.
245   bool is_silent() const override { return silent; }
246 
247 private:
248   template <typename InType, typename OutType>
249   bool match(InType in, OutType out) {
250     return compare_unary_operation_single_output(op, in, out, ulp_tolerance,
251                                                  rounding);
252   }
253 
254   template <typename T> bool match(T in, const BinaryOutput<T> &out) {
255     return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance,
256                                                rounding);
257   }
258 
259   template <typename T, typename U>
260   bool match(const BinaryInput<T> &in, U out) {
261     return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
262                                                rounding);
263   }
264 
265   template <typename T>
266   bool match(BinaryInput<T> in, const BinaryOutput<T> &out) {
267     return compare_binary_operation_two_outputs(op, in, out, ulp_tolerance,
268                                                 rounding);
269   }
270 
271   template <typename InType, typename OutType>
272   bool match(const TernaryInput<InType> &in, OutType out) {
273     return compare_ternary_operation_one_output(op, in, out, ulp_tolerance,
274                                                 rounding);
275   }
276 
277   template <typename InType, typename OutType>
278   void explain_error(InType in, OutType out) {
279     explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
280                                                 rounding);
281   }
282 
283   template <typename T> void explain_error(T in, const BinaryOutput<T> &out) {
284     explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance,
285                                               rounding);
286   }
287 
288   template <typename T>
289   void explain_error(const BinaryInput<T> &in, const BinaryOutput<T> &out) {
290     explain_binary_operation_two_outputs_error(op, in, out, ulp_tolerance,
291                                                rounding);
292   }
293 
294   template <typename T, typename U>
295   void explain_error(const BinaryInput<T> &in, U out) {
296     explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
297                                               rounding);
298   }
299 
300   template <typename InType, typename OutType>
301   void explain_error(const TernaryInput<InType> &in, OutType out) {
302     explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance,
303                                                rounding);
304   }
305 };
306 
307 } // namespace internal
308 
309 // Return true if the input and ouput types for the operation op are valid
310 // types.
311 template <Operation op, typename InputType, typename OutputType>
312 constexpr bool is_valid_operation() {
313   constexpr bool IS_NARROWING_OP =
314       (op == Operation::Sqrt && cpp::is_floating_point_v<InputType> &&
315        cpp::is_floating_point_v<OutputType> &&
316        sizeof(OutputType) <= sizeof(InputType)) ||
317       (Operation::BeginBinaryOperationsSingleOutput < op &&
318        op < Operation::EndBinaryOperationsSingleOutput &&
319        internal::IsBinaryInput<InputType>::VALUE &&
320        cpp::is_floating_point_v<
321            typename internal::MakeScalarInput<InputType>::type> &&
322        cpp::is_floating_point_v<OutputType>) ||
323       (op == Operation::Fma && internal::IsTernaryInput<InputType>::VALUE &&
324        cpp::is_floating_point_v<
325            typename internal::MakeScalarInput<InputType>::type> &&
326        cpp::is_floating_point_v<OutputType>);
327   if (IS_NARROWING_OP)
328     return true;
329   return (Operation::BeginUnaryOperationsSingleOutput < op &&
330           op < Operation::EndUnaryOperationsSingleOutput &&
331           cpp::is_same_v<InputType, OutputType> &&
332           cpp::is_floating_point_v<InputType>) ||
333          (Operation::BeginUnaryOperationsTwoOutputs < op &&
334           op < Operation::EndUnaryOperationsTwoOutputs &&
335           cpp::is_floating_point_v<InputType> &&
336           cpp::is_same_v<OutputType, BinaryOutput<InputType>>) ||
337          (Operation::BeginBinaryOperationsSingleOutput < op &&
338           op < Operation::EndBinaryOperationsSingleOutput &&
339           cpp::is_floating_point_v<OutputType> &&
340           cpp::is_same_v<InputType, BinaryInput<OutputType>>) ||
341          (Operation::BeginBinaryOperationsTwoOutputs < op &&
342           op < Operation::EndBinaryOperationsTwoOutputs &&
343           internal::AreMatchingBinaryInputAndBinaryOutput<InputType,
344                                                           OutputType>::VALUE) ||
345          (Operation::BeginTernaryOperationsSingleOuput < op &&
346           op < Operation::EndTernaryOperationsSingleOutput &&
347           cpp::is_floating_point_v<OutputType> &&
348           cpp::is_same_v<InputType, TernaryInput<OutputType>>);
349 }
350 
351 template <Operation op, typename InputType, typename OutputType>
352 __attribute__((no_sanitize("address"))) cpp::enable_if_t<
353     is_valid_operation<op, InputType, OutputType>(),
354     internal::MPFRMatcher<op, /*is_silent*/ false, InputType, OutputType>>
355 get_mpfr_matcher(InputType input, OutputType output_unused,
356                  double ulp_tolerance, RoundingMode rounding) {
357   return internal::MPFRMatcher<op, /*is_silent*/ false, InputType, OutputType>(
358       input, ulp_tolerance, rounding);
359 }
360 
361 template <Operation op, typename InputType, typename OutputType>
362 __attribute__((no_sanitize("address"))) cpp::enable_if_t<
363     is_valid_operation<op, InputType, OutputType>(),
364     internal::MPFRMatcher<op, /*is_silent*/ true, InputType, OutputType>>
365 get_silent_mpfr_matcher(InputType input, OutputType output_unused,
366                         double ulp_tolerance, RoundingMode rounding) {
367   return internal::MPFRMatcher<op, /*is_silent*/ true, InputType, OutputType>(
368       input, ulp_tolerance, rounding);
369 }
370 
371 template <typename T> T round(T x, RoundingMode mode);
372 
373 template <typename T> bool round_to_long(T x, long &result);
374 template <typename T> bool round_to_long(T x, RoundingMode mode, long &result);
375 
376 } // namespace mpfr
377 } // namespace testing
378 } // namespace LIBC_NAMESPACE_DECL
379 
// GET_MPFR_DUMMY_ARG is passed as the last argument to GET_MPFR_MACRO so that
// GET_MPFR_MACRO's variadic `...` parameter always receives at least one
// argument. This avoids the compiler warning
// `gnu-zero-variadic-macro-arguments`.
#define GET_MPFR_DUMMY_ARG(...) 0

// Expands to its sixth argument. The *_MPFR_MATCH dispatchers invoke it as
//   GET_MPFR_MACRO(user_args..., FIVE_ARG_MACRO, FOUR_ARG_MACRO,
//                  GET_MPFR_DUMMY_ARG)(user_args...)
// so four user arguments select FOUR_ARG_MACRO and five select
// FIVE_ARG_MACRO. Parameter names avoid reserved identifiers (identifiers
// containing a double underscore are reserved to the implementation).
#define GET_MPFR_MACRO(ARG1, ARG2, ARG3, ARG4, ARG5, NAME, ...) NAME
385 
// EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance[, rounding])
// wraps EXPECT_THAT around an MPFR matcher for `op`. With four arguments the
// comparison uses RoundingMode::Nearest; a fifth argument selects the mode.
#define EXPECT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,      \
                                   rounding)                                   \
  EXPECT_THAT(match_value,                                                     \
              LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<op>(             \
                  input, match_value, ulp_tolerance, rounding))

#define EXPECT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)       \
  EXPECT_MPFR_MATCH_ROUNDING(                                                  \
      op, input, match_value, ulp_tolerance,                                   \
      LIBC_NAMESPACE::testing::mpfr::RoundingMode::Nearest)

#define EXPECT_MPFR_MATCH(...)                                                 \
  GET_MPFR_MACRO(__VA_ARGS__, EXPECT_MPFR_MATCH_ROUNDING,                      \
                 EXPECT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                \
  (__VA_ARGS__)
402 
// TEST_MPFR_MATCH(op, input, match_value, ulp_tolerance[, rounding])
// evaluates to the bool returned by MPFRMatcher::match for `match_value`,
// instead of wrapping the matcher in EXPECT_THAT. With four arguments the
// comparison uses RoundingMode::Nearest; a fifth argument selects the mode.
#define TEST_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,        \
                                 rounding)                                     \
  LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<op>(input, match_value,      \
                                                      ulp_tolerance, rounding) \
      .match(match_value)

#define TEST_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)         \
  TEST_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,              \
                           LIBC_NAMESPACE::testing::mpfr::RoundingMode::Nearest)

// Both arities must dispatch to a TEST_ macro. Dispatching the 4-argument
// form to EXPECT_MPFR_MATCH_DEFAULT would yield an EXPECT_THAT statement
// rather than a bool expression.
#define TEST_MPFR_MATCH(...)                                                   \
  GET_MPFR_MACRO(__VA_ARGS__, TEST_MPFR_MATCH_ROUNDING,                        \
                 TEST_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                  \
  (__VA_ARGS__)
413 
// EXPECT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance) runs
// EXPECT_MPFR_MATCH once for each of the rounding modes Nearest, Upward,
// Downward and TowardZero.
// - Before each check the mode is forced with a ForceRoundingMode object. The
//   check is skipped when that object's `success` is false.
// - `input` and `match_value` are expanded inside each check, so they are
//   evaluated after the corresponding mode has been forced.
// - The ForceRoundingMode objects live until the end of the enclosing block.
// Local names avoid reserved identifiers (leading double underscore).
#define EXPECT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)  \
  {                                                                            \
    namespace mpfr = LIBC_NAMESPACE::testing::mpfr;                            \
    mpfr::ForceRoundingMode mpfr_force_nearest(mpfr::RoundingMode::Nearest);   \
    if (mpfr_force_nearest.success) {                                          \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Nearest);                          \
    }                                                                          \
    mpfr::ForceRoundingMode mpfr_force_upward(mpfr::RoundingMode::Upward);     \
    if (mpfr_force_upward.success) {                                           \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Upward);                           \
    }                                                                          \
    mpfr::ForceRoundingMode mpfr_force_downward(                               \
        mpfr::RoundingMode::Downward);                                         \
    if (mpfr_force_downward.success) {                                         \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Downward);                         \
    }                                                                          \
    mpfr::ForceRoundingMode mpfr_force_toward_zero(                            \
        mpfr::RoundingMode::TowardZero);                                       \
    if (mpfr_force_toward_zero.success) {                                      \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::TowardZero);                       \
    }                                                                          \
  }
438 
// Evaluates to the bool returned by match() on a silent MPFR matcher
// (obtained from get_silent_mpfr_matcher, whose is_silent() returns true) for
// operation `op_kind` on `in_value`, compared with `out_value` under `mode`.
#define TEST_MPFR_MATCH_ROUNDING_SILENTLY(op_kind, in_value, out_value,        \
                                          tolerance, mode)                     \
  LIBC_NAMESPACE::testing::mpfr::get_silent_mpfr_matcher<op_kind>(             \
      in_value, out_value, tolerance, mode)                                    \
      .match(out_value)
444 
// ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance[, rounding])
// wraps ASSERT_THAT around an MPFR matcher for `op`. With four arguments the
// comparison uses RoundingMode::Nearest; a fifth argument selects the mode.
#define ASSERT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,      \
                                   rounding)                                   \
  ASSERT_THAT(match_value,                                                     \
              LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<op>(             \
                  input, match_value, ulp_tolerance, rounding))

#define ASSERT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)       \
  ASSERT_MPFR_MATCH_ROUNDING(                                                  \
      op, input, match_value, ulp_tolerance,                                   \
      LIBC_NAMESPACE::testing::mpfr::RoundingMode::Nearest)

#define ASSERT_MPFR_MATCH(...)                                                 \
  GET_MPFR_MACRO(__VA_ARGS__, ASSERT_MPFR_MATCH_ROUNDING,                      \
                 ASSERT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                \
  (__VA_ARGS__)
461 
// ASSERT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance) runs
// ASSERT_MPFR_MATCH once for each of the rounding modes Nearest, Upward,
// Downward and TowardZero.
// - Before each check the mode is forced with a ForceRoundingMode object. The
//   check is skipped when that object's `success` is false.
// - `input` and `match_value` are expanded inside each check, so they are
//   evaluated after the corresponding mode has been forced.
// - The ForceRoundingMode objects live until the end of the enclosing block.
// Local names avoid reserved identifiers (leading double underscore).
#define ASSERT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)  \
  {                                                                            \
    namespace mpfr = LIBC_NAMESPACE::testing::mpfr;                            \
    mpfr::ForceRoundingMode mpfr_force_nearest(mpfr::RoundingMode::Nearest);   \
    if (mpfr_force_nearest.success) {                                          \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Nearest);                          \
    }                                                                          \
    mpfr::ForceRoundingMode mpfr_force_upward(mpfr::RoundingMode::Upward);     \
    if (mpfr_force_upward.success) {                                           \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Upward);                           \
    }                                                                          \
    mpfr::ForceRoundingMode mpfr_force_downward(                               \
        mpfr::RoundingMode::Downward);                                         \
    if (mpfr_force_downward.success) {                                         \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Downward);                         \
    }                                                                          \
    mpfr::ForceRoundingMode mpfr_force_toward_zero(                            \
        mpfr::RoundingMode::TowardZero);                                       \
    if (mpfr_force_toward_zero.success) {                                      \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::TowardZero);                       \
    }                                                                          \
  }
486 
487 #endif // LLVM_LIBC_UTILS_MPFRWRAPPER_MPFRUTILS_H
488