xref: /llvm-project/libc/utils/MPCWrapper/MPCUtils.cpp (revision 7f37b34d31914120a5bb6bd341e7616773df7613)
1 //===-- Utils which wrap MPC ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MPCUtils.h"
10 
11 #include "src/__support/CPP/array.h"
12 #include "src/__support/CPP/stringstream.h"
13 #include "utils/MPFRWrapper/MPCommon.h"
14 
15 #include <stdint.h>
16 
17 #include "mpc.h"
18 
19 template <typename T> using FPBits = LIBC_NAMESPACE::fputil::FPBits<T>;
20 
21 namespace LIBC_NAMESPACE_DECL {
22 namespace testing {
23 namespace mpc {
24 
25 static inline cpp::string str(RoundingMode mode) {
26   switch (mode) {
27   case RoundingMode::Upward:
28     return "MPFR_RNDU";
29   case RoundingMode::Downward:
30     return "MPFR_RNDD";
31   case RoundingMode::TowardZero:
32     return "MPFR_RNDZ";
33   case RoundingMode::Nearest:
34     return "MPFR_RNDN";
35   }
36 }
37 
38 class MPCNumber {
39 private:
40   unsigned int precision;
41   mpc_t value;
42   mpc_rnd_t mpc_rounding;
43 
44 public:
45   explicit MPCNumber(unsigned int p) : precision(p), mpc_rounding(MPC_RNDNN) {
46     mpc_init2(value, precision);
47   }
48 
49   MPCNumber() : precision(256), mpc_rounding(MPC_RNDNN) {
50     mpc_init2(value, 256);
51   }
52 
53   MPCNumber(unsigned int p, mpc_rnd_t rnd) : precision(p), mpc_rounding(rnd) {
54     mpc_init2(value, precision);
55   }
56 
57   template <typename XType,
58             cpp::enable_if_t<cpp::is_same_v<_Complex float, XType>, bool> = 0>
59   MPCNumber(XType x,
60             unsigned int precision = mpfr::ExtraPrecision<float>::VALUE,
61             RoundingMode rnd = RoundingMode::Nearest)
62       : precision(precision),
63         mpc_rounding(MPC_RND(mpfr::get_mpfr_rounding_mode(rnd),
64                              mpfr::get_mpfr_rounding_mode(rnd))) {
65     mpc_init2(value, precision);
66     Complex<float> x_c = cpp::bit_cast<Complex<float>>(x);
67     mpfr_t real, imag;
68     mpfr_init2(real, precision);
69     mpfr_init2(imag, precision);
70     mpfr_set_flt(real, x_c.real, mpfr::get_mpfr_rounding_mode(rnd));
71     mpfr_set_flt(imag, x_c.imag, mpfr::get_mpfr_rounding_mode(rnd));
72     mpc_set_fr_fr(value, real, imag, mpc_rounding);
73     mpfr_clear(real);
74     mpfr_clear(imag);
75   }
76 
77   template <typename XType,
78             cpp::enable_if_t<cpp::is_same_v<_Complex double, XType>, bool> = 0>
79   MPCNumber(XType x,
80             unsigned int precision = mpfr::ExtraPrecision<double>::VALUE,
81             RoundingMode rnd = RoundingMode::Nearest)
82       : precision(precision),
83         mpc_rounding(MPC_RND(mpfr::get_mpfr_rounding_mode(rnd),
84                              mpfr::get_mpfr_rounding_mode(rnd))) {
85     mpc_init2(value, precision);
86     Complex<double> x_c = cpp::bit_cast<Complex<double>>(x);
87     mpc_set_d_d(value, x_c.real, x_c.imag, mpc_rounding);
88   }
89 
90   MPCNumber(const MPCNumber &other)
91       : precision(other.precision), mpc_rounding(other.mpc_rounding) {
92     mpc_init2(value, precision);
93     mpc_set(value, other.value, mpc_rounding);
94   }
95 
96   ~MPCNumber() { mpc_clear(value); }
97 
98   MPCNumber &operator=(const MPCNumber &rhs) {
99     precision = rhs.precision;
100     mpc_rounding = rhs.mpc_rounding;
101     mpc_init2(value, precision);
102     mpc_set(value, rhs.value, mpc_rounding);
103     return *this;
104   }
105 
106   void setValue(mpc_t val) const { mpc_set(val, value, mpc_rounding); }
107 
108   mpc_t &getValue() { return value; }
109 
110   MPCNumber carg() const {
111     mpfr_t res;
112     MPCNumber result(precision, mpc_rounding);
113 
114     mpfr_init2(res, precision);
115 
116     mpc_arg(res, value, MPC_RND_RE(mpc_rounding));
117     mpc_set_fr(result.value, res, mpc_rounding);
118 
119     mpfr_clear(res);
120 
121     return result;
122   }
123 
124   MPCNumber cproj() const {
125     MPCNumber result(precision, mpc_rounding);
126     mpc_proj(result.value, value, mpc_rounding);
127     return result;
128   }
129 };
130 
131 namespace internal {
132 
133 template <typename InputType>
134 cpp::enable_if_t<cpp::is_complex_v<InputType>, MPCNumber>
135 unary_operation(Operation op, InputType input, unsigned int precision,
136                 RoundingMode rounding) {
137   MPCNumber mpcInput(input, precision, rounding);
138   switch (op) {
139   case Operation::Carg:
140     return mpcInput.carg();
141   case Operation::Cproj:
142     return mpcInput.cproj();
143   default:
144     __builtin_unreachable();
145   }
146 }
147 
148 template <typename InputType, typename OutputType>
149 bool compare_unary_operation_single_output_same_type(Operation op,
150                                                      InputType input,
151                                                      OutputType libc_result,
152                                                      double ulp_tolerance,
153                                                      RoundingMode rounding) {
154 
155   unsigned int precision =
156       mpfr::get_precision<make_real_t<InputType>>(ulp_tolerance);
157 
158   MPCNumber mpc_result;
159   mpc_result = unary_operation(op, input, precision, rounding);
160 
161   mpc_t mpc_result_val;
162   mpc_init2(mpc_result_val, precision);
163   mpc_result.setValue(mpc_result_val);
164 
165   mpfr_t real, imag;
166   mpfr_init2(real, precision);
167   mpfr_init2(imag, precision);
168   mpc_real(real, mpc_result_val, mpfr::get_mpfr_rounding_mode(rounding));
169   mpc_imag(imag, mpc_result_val, mpfr::get_mpfr_rounding_mode(rounding));
170 
171   mpfr::MPFRNumber mpfr_real(real, precision, rounding);
172   mpfr::MPFRNumber mpfr_imag(imag, precision, rounding);
173 
174   double ulp_real = mpfr_real.ulp(
175       (cpp::bit_cast<Complex<make_real_t<InputType>>>(libc_result)).real);
176   double ulp_imag = mpfr_imag.ulp(
177       (cpp::bit_cast<Complex<make_real_t<InputType>>>(libc_result)).imag);
178   mpc_clear(mpc_result_val);
179   mpfr_clear(real);
180   mpfr_clear(imag);
181   return (ulp_real <= ulp_tolerance) && (ulp_imag <= ulp_tolerance);
182 }
183 
184 template bool compare_unary_operation_single_output_same_type(
185     Operation, _Complex float, _Complex float, double, RoundingMode);
186 template bool compare_unary_operation_single_output_same_type(
187     Operation, _Complex double, _Complex double, double, RoundingMode);
188 
189 template <typename InputType, typename OutputType>
190 bool compare_unary_operation_single_output_different_type(
191     Operation op, InputType input, OutputType libc_result, double ulp_tolerance,
192     RoundingMode rounding) {
193 
194   unsigned int precision =
195       mpfr::get_precision<make_real_t<InputType>>(ulp_tolerance);
196 
197   MPCNumber mpc_result;
198   mpc_result = unary_operation(op, input, precision, rounding);
199 
200   mpc_t mpc_result_val;
201   mpc_init2(mpc_result_val, precision);
202   mpc_result.setValue(mpc_result_val);
203 
204   mpfr_t real;
205   mpfr_init2(real, precision);
206   mpc_real(real, mpc_result_val, mpfr::get_mpfr_rounding_mode(rounding));
207 
208   mpfr::MPFRNumber mpfr_real(real, precision, rounding);
209 
210   double ulp_real = mpfr_real.ulp(libc_result);
211   mpc_clear(mpc_result_val);
212   mpfr_clear(real);
213   return (ulp_real <= ulp_tolerance);
214 }
215 
216 template bool compare_unary_operation_single_output_different_type(
217     Operation, _Complex float, float, double, RoundingMode);
218 template bool compare_unary_operation_single_output_different_type(
219     Operation, _Complex double, double, double, RoundingMode);
220 
221 template <typename InputType, typename OutputType>
222 void explain_unary_operation_single_output_different_type_error(
223     Operation op, InputType input, OutputType libc_result, double ulp_tolerance,
224     RoundingMode rounding) {
225 
226   unsigned int precision =
227       mpfr::get_precision<make_real_t<InputType>>(ulp_tolerance);
228 
229   MPCNumber mpc_result;
230   mpc_result = unary_operation(op, input, precision, rounding);
231 
232   mpc_t mpc_result_val;
233   mpc_init2(mpc_result_val, precision);
234   mpc_result.setValue(mpc_result_val);
235 
236   mpfr_t real;
237   mpfr_init2(real, precision);
238   mpc_real(real, mpc_result_val, mpfr::get_mpfr_rounding_mode(rounding));
239 
240   mpfr::MPFRNumber mpfr_result(real, precision, rounding);
241   mpfr::MPFRNumber mpfrLibcResult(libc_result, precision, rounding);
242   mpfr::MPFRNumber mpfrInputReal(
243       cpp::bit_cast<Complex<make_real_t<InputType>>>(input).real, precision,
244       rounding);
245   mpfr::MPFRNumber mpfrInputImag(
246       cpp::bit_cast<Complex<make_real_t<InputType>>>(input).imag, precision,
247       rounding);
248 
249   cpp::array<char, 2048> msg_buf;
250   cpp::StringStream msg(msg_buf);
251   msg << "Match value not within tolerance value of MPFR result:\n"
252       << "  Input: " << mpfrInputReal.str() << " + " << mpfrInputImag.str()
253       << "i\n"
254       << "  Rounding mode: " << str(rounding) << '\n'
255       << "    Libc: " << mpfrLibcResult.str() << '\n'
256       << "    MPC: " << mpfr_result.str() << '\n'
257       << '\n'
258       << "  ULP error: " << mpfr_result.ulp_as_mpfr_number(libc_result).str()
259       << '\n';
260   tlog << msg.str();
261   mpc_clear(mpc_result_val);
262   mpfr_clear(real);
263 }
264 
265 template void explain_unary_operation_single_output_different_type_error(
266     Operation, _Complex float, float, double, RoundingMode);
267 template void explain_unary_operation_single_output_different_type_error(
268     Operation, _Complex double, double, double, RoundingMode);
269 
270 template <typename InputType, typename OutputType>
271 void explain_unary_operation_single_output_same_type_error(
272     Operation op, InputType input, OutputType libc_result, double ulp_tolerance,
273     RoundingMode rounding) {
274 
275   unsigned int precision =
276       mpfr::get_precision<make_real_t<InputType>>(ulp_tolerance);
277 
278   MPCNumber mpc_result;
279   mpc_result = unary_operation(op, input, precision, rounding);
280 
281   mpc_t mpc_result_val;
282   mpc_init2(mpc_result_val, precision);
283   mpc_result.setValue(mpc_result_val);
284 
285   mpfr_t real, imag;
286   mpfr_init2(real, precision);
287   mpfr_init2(imag, precision);
288   mpc_real(real, mpc_result_val, mpfr::get_mpfr_rounding_mode(rounding));
289   mpc_imag(imag, mpc_result_val, mpfr::get_mpfr_rounding_mode(rounding));
290 
291   mpfr::MPFRNumber mpfr_real(real, precision, rounding);
292   mpfr::MPFRNumber mpfr_imag(imag, precision, rounding);
293   mpfr::MPFRNumber mpfrLibcResultReal(
294       cpp::bit_cast<Complex<make_real_t<InputType>>>(libc_result).real,
295       precision, rounding);
296   mpfr::MPFRNumber mpfrLibcResultImag(
297       cpp::bit_cast<Complex<make_real_t<InputType>>>(libc_result).imag,
298       precision, rounding);
299   mpfr::MPFRNumber mpfrInputReal(
300       cpp::bit_cast<Complex<make_real_t<InputType>>>(input).real, precision,
301       rounding);
302   mpfr::MPFRNumber mpfrInputImag(
303       cpp::bit_cast<Complex<make_real_t<InputType>>>(input).imag, precision,
304       rounding);
305 
306   cpp::array<char, 2048> msg_buf;
307   cpp::StringStream msg(msg_buf);
308   msg << "Match value not within tolerance value of MPFR result:\n"
309       << "  Input: " << mpfrInputReal.str() << " + " << mpfrInputImag.str()
310       << "i\n"
311       << "  Rounding mode: " << str(rounding) << " , " << str(rounding) << '\n'
312       << "    Libc: " << mpfrLibcResultReal.str() << " + "
313       << mpfrLibcResultImag.str() << "i\n"
314       << "    MPC: " << mpfr_real.str() << " + " << mpfr_imag.str() << "i\n"
315       << '\n'
316       << "  ULP error: "
317       << mpfr_real
318              .ulp_as_mpfr_number(
319                  cpp::bit_cast<Complex<make_real_t<InputType>>>(libc_result)
320                      .real)
321              .str()
322       << " , "
323       << mpfr_imag
324              .ulp_as_mpfr_number(
325                  cpp::bit_cast<Complex<make_real_t<InputType>>>(libc_result)
326                      .imag)
327              .str()
328       << '\n';
329   tlog << msg.str();
330   mpc_clear(mpc_result_val);
331   mpfr_clear(real);
332   mpfr_clear(imag);
333 }
334 
335 template void explain_unary_operation_single_output_same_type_error(
336     Operation, _Complex float, _Complex float, double, RoundingMode);
337 template void explain_unary_operation_single_output_same_type_error(
338     Operation, _Complex double, _Complex double, double, RoundingMode);
339 
340 } // namespace internal
341 
342 } // namespace mpc
343 } // namespace testing
344 } // namespace LIBC_NAMESPACE_DECL
345