xref: /llvm-project/libc/utils/MPCWrapper/MPCUtils.h (revision 7f37b34d31914120a5bb6bd341e7616773df7613)
1 //===-- MPCUtils.h ----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H
10 #define LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H
11 
12 #include "src/__support/CPP/type_traits.h"
13 #include "src/__support/complex_type.h"
14 #include "src/__support/macros/config.h"
15 #include "src/__support/macros/properties/complex_types.h"
16 #include "src/__support/macros/properties/types.h"
17 #include "test/UnitTest/RoundingModeUtils.h"
18 #include "test/UnitTest/Test.h"
19 
20 #include <stdint.h>
21 
22 namespace LIBC_NAMESPACE_DECL {
23 namespace testing {
24 namespace mpc {
25 
// Identifies the complex-valued operation being checked. Operations are
// grouped by signature, and every group is bracketed by a Begin*/End* sentinel
// pair so that group membership can be tested with ordered comparisons.
// The enumerator order is therefore significant.
enum class Operation {
  // Group 1: one complex input, one real output whose floating point type is
  // the type of the real/imaginary part of the input.
  BeginUnaryOperationsSingleOutputDifferentOutputType = 0,
  Carg,
  Cabs,
  EndUnaryOperationsSingleOutputDifferentOutputType,

  // Group 2: one complex input, one complex output of the same kind as the
  // input.
  BeginUnaryOperationsSingleOutputSameOutputType,
  Cproj,
  Csqrt,
  Clog,
  Cexp,
  Csinh,
  Ccosh,
  Ctanh,
  Casinh,
  Cacosh,
  Catanh,
  Csin,
  Ccos,
  Ctan,
  Casin,
  Cacos,
  Catan,
  EndUnaryOperationsSingleOutputSameOutputType,

  // Group 3: two complex inputs, one complex output of the same kind as the
  // inputs.
  BeginBinaryOperationsSingleOutput,
  Cpow,
  EndBinaryOperationsSingleOutput,
};
64 
// Make fputil::testing::RoundingMode available unqualified inside the mpc
// namespace; the declarations and the matcher below take it as a parameter.
65 using LIBC_NAMESPACE::fputil::testing::RoundingMode;
66 
67 template <typename T> struct BinaryInput {
68   static_assert(LIBC_NAMESPACE::cpp::is_complex_v<T>,
69                 "Template parameter of BinaryInput must be a complex floating "
70                 "point type.");
71 
72   using Type = T;
73   T x, y;
74 };
75 
76 namespace internal {
77 
// Declaration only; the definition is not part of this header.
// MPCMatcher::match forwards to this function when the input and output types
// are identical and returns its result as the match outcome.
// NOTE(review): presumably compares libc_output against an MPC-computed
// reference within ulp_tolerance under the given rounding mode -- confirm
// against the implementation.
78 template <typename InputType, typename OutputType>
79 bool compare_unary_operation_single_output_same_type(Operation op,
80                                                      InputType input,
81                                                      OutputType libc_output,
82                                                      double ulp_tolerance,
83                                                      RoundingMode rounding);
84 
// Declaration only; the definition is not part of this header.
// MPCMatcher::match forwards to this function when the input and output types
// differ and returns its result as the match outcome.
// NOTE(review): presumably compares libc_output against an MPC-computed
// reference within ulp_tolerance under the given rounding mode -- confirm
// against the implementation.
85 template <typename InputType, typename OutputType>
86 bool compare_unary_operation_single_output_different_type(
87     Operation op, InputType input, OutputType libc_output, double ulp_tolerance,
88     RoundingMode rounding);
89 
// Declaration only; the definition is not part of this header.
// MPCMatcher::match forwards to this function when the matcher input is a
// BinaryInput and returns its result as the match outcome.
// NOTE(review): presumably compares libc_output against an MPC-computed
// reference for the operand pair within ulp_tolerance -- confirm against the
// implementation.
90 template <typename InputType, typename OutputType>
91 bool compare_binary_operation_one_output(Operation op,
92                                          const BinaryInput<InputType> &input,
93                                          OutputType libc_output,
94                                          double ulp_tolerance,
95                                          RoundingMode rounding);
96 
// Declaration only; the definition is not part of this header.
// Called from MPCMatcher::explainError (through explain_error) when the input
// and output types are identical. match_value is the value last passed to
// MPCMatcher::match.
// NOTE(review): presumably reports the details of the mismatch -- confirm
// against the implementation.
97 template <typename InputType, typename OutputType>
98 void explain_unary_operation_single_output_same_type_error(
99     Operation op, InputType input, OutputType match_value, double ulp_tolerance,
100     RoundingMode rounding);
101 
// Declaration only; the definition is not part of this header.
// Called from MPCMatcher::explainError (through explain_error) when the input
// and output types differ. match_value is the value last passed to
// MPCMatcher::match.
// NOTE(review): presumably reports the details of the mismatch -- confirm
// against the implementation.
102 template <typename InputType, typename OutputType>
103 void explain_unary_operation_single_output_different_type_error(
104     Operation op, InputType input, OutputType match_value, double ulp_tolerance,
105     RoundingMode rounding);
106 
// Declaration only; the definition is not part of this header.
// Called from MPCMatcher::explainError (through explain_error) when the
// matcher input is a BinaryInput. match_value is the value last passed to
// MPCMatcher::match.
// NOTE(review): presumably reports the details of the mismatch -- confirm
// against the implementation.
107 template <typename InputType, typename OutputType>
108 void explain_binary_operation_one_output_error(
109     Operation op, const BinaryInput<InputType> &input, OutputType match_value,
110     double ulp_tolerance, RoundingMode rounding);
111 
112 template <Operation op, typename InputType, typename OutputType>
113 class MPCMatcher : public testing::Matcher<OutputType> {
114 private:
115   InputType input;
116   OutputType match_value;
117   double ulp_tolerance;
118   RoundingMode rounding;
119 
120 public:
121   MPCMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding)
122       : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {}
123 
124   bool match(OutputType libcResult) {
125     match_value = libcResult;
126     return match(input, match_value);
127   }
128 
129   void explainError() override { // NOLINT
130     explain_error(input, match_value);
131   }
132 
133 private:
134   template <typename InType, typename OutType>
135   bool match(InType in, OutType out) {
136     if (cpp::is_same_v<InType, OutType>) {
137       return compare_unary_operation_single_output_same_type(
138           op, in, out, ulp_tolerance, rounding);
139     } else {
140       return compare_unary_operation_single_output_different_type(
141           op, in, out, ulp_tolerance, rounding);
142     }
143   }
144 
145   template <typename T, typename U>
146   bool match(const BinaryInput<T> &in, U out) {
147     return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
148                                                rounding);
149   }
150 
151   template <typename InType, typename OutType>
152   void explain_error(InType in, OutType out) {
153     if (cpp::is_same_v<InType, OutType>) {
154       explain_unary_operation_single_output_same_type_error(
155           op, in, out, ulp_tolerance, rounding);
156     } else {
157       explain_unary_operation_single_output_different_type_error(
158           op, in, out, ulp_tolerance, rounding);
159     }
160   }
161 
162   template <typename T, typename U>
163   void explain_error(const BinaryInput<T> &in, U out) {
164     explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
165                                               rounding);
166   }
167 };
168 
169 } // namespace internal
170 
171 // Return true if the input and ouput types for the operation op are valid
172 // types.
173 template <Operation op, typename InputType, typename OutputType>
174 constexpr bool is_valid_operation() {
175   return (Operation::BeginBinaryOperationsSingleOutput < op &&
176           op < Operation::EndBinaryOperationsSingleOutput &&
177           cpp::is_complex_type_same<InputType, OutputType> &&
178           cpp::is_complex_v<InputType>) ||
179          (Operation::BeginUnaryOperationsSingleOutputSameOutputType < op &&
180           op < Operation::EndUnaryOperationsSingleOutputSameOutputType &&
181           cpp::is_complex_type_same<InputType, OutputType> &&
182           cpp::is_complex_v<InputType>) ||
183          (Operation::BeginUnaryOperationsSingleOutputDifferentOutputType < op &&
184           op < Operation::EndUnaryOperationsSingleOutputDifferentOutputType &&
185           cpp::is_same_v<make_real_t<InputType>, OutputType> &&
186           cpp::is_complex_v<InputType>);
187 }
188 
189 template <Operation op, typename InputType, typename OutputType>
190 cpp::enable_if_t<is_valid_operation<op, InputType, OutputType>(),
191                  internal::MPCMatcher<op, InputType, OutputType>>
192 get_mpc_matcher(InputType input, [[maybe_unused]] OutputType output,
193                 double ulp_tolerance, RoundingMode rounding) {
194   return internal::MPCMatcher<op, InputType, OutputType>(input, ulp_tolerance,
195                                                          rounding);
196 }
197 
198 } // namespace mpc
199 } // namespace testing
200 } // namespace LIBC_NAMESPACE_DECL
201 
// Wraps EXPECT_THAT: checks match_value against the matcher returned by
// get_mpc_matcher<op> for the given input and ulp_tolerance, with the rounding
// mode fixed to RoundingMode::Nearest.
// Note: match_value appears twice in the expansion, so the argument expression
// is evaluated more than once.
202 #define EXPECT_MPC_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)        \
203   EXPECT_THAT(match_value,                                                     \
204               LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(               \
205                   input, match_value, ulp_tolerance,                           \
206                   LIBC_NAMESPACE::fputil::testing::RoundingMode::Nearest))
207 
// Wraps EXPECT_THAT: same check as EXPECT_MPC_MATCH_DEFAULT, but the rounding
// mode handed to get_mpc_matcher<op> is supplied by the caller.
// Note: match_value appears twice in the expansion.
208 #define EXPECT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,       \
209                                   rounding)                                    \
210   EXPECT_THAT(match_value, LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(  \
211                                input, match_value, ulp_tolerance, rounding))
212 
// Helper for EXPECT_MPC_MATCH_ALL_ROUNDING: constructs an
// MPCRND::ForceRoundingMode for `rounding` and runs
// EXPECT_MPC_MATCH_ROUNDING only when its `success` member is true.
// Relies on the namespace alias MPCRND, which is declared by
// EXPECT_MPC_MATCH_ALL_ROUNDING, so it is not usable on its own unless the
// caller provides that alias.
// NOTE(review): `__r` contains a double underscore, which makes it an
// identifier reserved for the implementation.
213 #define EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value,           \
214                                              ulp_tolerance, rounding)          \
215   {                                                                            \
216     MPCRND::ForceRoundingMode __r(rounding);                                   \
217     if (__r.success) {                                                         \
218       EXPECT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,         \
219                                 rounding);                                     \
220     }                                                                          \
221   }
222 
// Runs the EXPECT-style check once per rounding mode: iterates i = 0..3,
// converts i to MPCRND::RoundingMode and invokes
// EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER with it. Also declares the MPCRND
// namespace alias that the helper depends on.
// The expansion is a bare brace block, so input and match_value are
// re-evaluated on every iteration.
// NOTE(review): the bound 4 presumably equals the number of RoundingMode
// enumerators -- confirm against RoundingModeUtils.h.
223 #define EXPECT_MPC_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)   \
224   {                                                                            \
225     namespace MPCRND = LIBC_NAMESPACE::fputil::testing;                        \
226     for (int i = 0; i < 4; i++) {                                              \
227       MPCRND::RoundingMode r_mode = static_cast<MPCRND::RoundingMode>(i);      \
228       EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value,             \
229                                            ulp_tolerance, r_mode);             \
230     }                                                                          \
231   }
232 
// Expands to an expression rather than an assertion: builds the matcher with
// get_mpc_matcher<op> and calls .match(match_value) on it directly, so the
// value of the expression is the bool returned by the matcher's match().
// Note: match_value appears twice in the expansion.
233 #define TEST_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,         \
234                                 rounding)                                      \
235   LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(input, match_value,        \
236                                                     ulp_tolerance, rounding)   \
237       .match(match_value)
238 
// Wraps ASSERT_THAT: checks match_value against the matcher returned by
// get_mpc_matcher<op> for the given input and ulp_tolerance, with the rounding
// mode fixed to RoundingMode::Nearest. ASSERT counterpart of
// EXPECT_MPC_MATCH_DEFAULT.
// Note: match_value appears twice in the expansion.
239 #define ASSERT_MPC_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)        \
240   ASSERT_THAT(match_value,                                                     \
241               LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(               \
242                   input, match_value, ulp_tolerance,                           \
243                   LIBC_NAMESPACE::fputil::testing::RoundingMode::Nearest))
244 
// Wraps ASSERT_THAT: same check as ASSERT_MPC_MATCH_DEFAULT, but the rounding
// mode handed to get_mpc_matcher<op> is supplied by the caller.
// Note: match_value appears twice in the expansion.
245 #define ASSERT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,       \
246                                   rounding)                                    \
247   ASSERT_THAT(match_value, LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(  \
248                                input, match_value, ulp_tolerance, rounding))
249 
// Helper for ASSERT_MPC_MATCH_ALL_ROUNDING: constructs an
// MPCRND::ForceRoundingMode for `rounding` and runs
// ASSERT_MPC_MATCH_ROUNDING only when its `success` member is true.
// Relies on the namespace alias MPCRND, which is declared by
// ASSERT_MPC_MATCH_ALL_ROUNDING, so it is not usable on its own unless the
// caller provides that alias.
// NOTE(review): `__r` contains a double underscore, which makes it an
// identifier reserved for the implementation.
250 #define ASSERT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value,           \
251                                              ulp_tolerance, rounding)          \
252   {                                                                            \
253     MPCRND::ForceRoundingMode __r(rounding);                                   \
254     if (__r.success) {                                                         \
255       ASSERT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,         \
256                                 rounding);                                     \
257     }                                                                          \
258   }
259 
// Runs the ASSERT-style check once per rounding mode: iterates i = 0..3,
// converts i to MPCRND::RoundingMode and invokes
// ASSERT_MPC_MATCH_ALL_ROUNDING_HELPER with it. Also declares the MPCRND
// namespace alias that the helper depends on.
// The expansion is a bare brace block, so input and match_value are
// re-evaluated on every iteration.
// NOTE(review): the bound 4 presumably equals the number of RoundingMode
// enumerators -- confirm against RoundingModeUtils.h.
260 #define ASSERT_MPC_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)   \
261   {                                                                            \
262     namespace MPCRND = LIBC_NAMESPACE::fputil::testing;                        \
263     for (int i = 0; i < 4; i++) {                                              \
264       MPCRND::RoundingMode r_mode = static_cast<MPCRND::RoundingMode>(i);      \
265       ASSERT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value,             \
266                                            ulp_tolerance, r_mode);             \
267     }                                                                          \
268   }
269 
270 #endif // LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H
271