xref: /llvm-project/libc/src/__support/fixed_point/sqrt.h (revision 5ff3ff33ff930e4ec49da7910612d8a41eb068cb)
//===-- Calculate square root of fixed point numbers. -----*- C++ -*-=========//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8ded4ea97Slntue 
9ded4ea97Slntue #ifndef LLVM_LIBC_SRC___SUPPORT_FIXEDPOINT_SQRT_H
10ded4ea97Slntue #define LLVM_LIBC_SRC___SUPPORT_FIXEDPOINT_SQRT_H
11ded4ea97Slntue 
1273aab2f6Slntue #include "include/llvm-libc-macros/stdfix-macros.h"
13ded4ea97Slntue #include "src/__support/CPP/bit.h"
14ad33fe12Slntue #include "src/__support/CPP/limits.h" // CHAR_BIT
15ded4ea97Slntue #include "src/__support/CPP/type_traits.h"
16ded4ea97Slntue #include "src/__support/macros/attributes.h"   // LIBC_INLINE
17*5ff3ff33SPetr Hosek #include "src/__support/macros/config.h"
18ded4ea97Slntue #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
19ded4ea97Slntue 
20ded4ea97Slntue #include "fx_rep.h"
21ded4ea97Slntue 
22ded4ea97Slntue #ifdef LIBC_COMPILER_HAS_FIXED_POINT
23ded4ea97Slntue 
24*5ff3ff33SPetr Hosek namespace LIBC_NAMESPACE_DECL {
25*5ff3ff33SPetr Hosek namespace fixed_point {
26ded4ea97Slntue 
27ded4ea97Slntue namespace internal {
28ded4ea97Slntue 
29ded4ea97Slntue template <typename T> struct SqrtConfig;
30ded4ea97Slntue 
31ded4ea97Slntue template <> struct SqrtConfig<unsigned short fract> {
32ded4ea97Slntue   using Type = unsigned short fract;
33ded4ea97Slntue   static constexpr int EXTRA_STEPS = 0;
34ded4ea97Slntue 
35ded4ea97Slntue   // Linear approximation for the initial values, with errors bounded by:
36ded4ea97Slntue   //   max(1.5 * 2^-11, eps)
37ded4ea97Slntue   // Generated with Sollya:
38ded4ea97Slntue   // > for i from 4 to 15 do {
39ded4ea97Slntue   //     P = fpminimax(sqrt(x), 1, [|8, 8|], [i * 2^-4, (i + 1)*2^-4],
40ded4ea97Slntue   //                   fixed, absolute);
41ded4ea97Slntue   //     print("{", coeff(P, 1), "uhr,", coeff(P, 0), "uhr},");
42ded4ea97Slntue   //   };
43ad33fe12Slntue   static constexpr Type FIRST_APPROX[12][2] = {
44ded4ea97Slntue       {0x1.e8p-1uhr, 0x1.0cp-2uhr}, {0x1.bap-1uhr, 0x1.28p-2uhr},
45ded4ea97Slntue       {0x1.94p-1uhr, 0x1.44p-2uhr}, {0x1.74p-1uhr, 0x1.6p-2uhr},
46ded4ea97Slntue       {0x1.6p-1uhr, 0x1.74p-2uhr},  {0x1.4ep-1uhr, 0x1.88p-2uhr},
47ded4ea97Slntue       {0x1.3ep-1uhr, 0x1.9cp-2uhr}, {0x1.32p-1uhr, 0x1.acp-2uhr},
48ded4ea97Slntue       {0x1.22p-1uhr, 0x1.c4p-2uhr}, {0x1.18p-1uhr, 0x1.d4p-2uhr},
49ded4ea97Slntue       {0x1.08p-1uhr, 0x1.fp-2uhr},  {0x1.04p-1uhr, 0x1.f8p-2uhr},
50ded4ea97Slntue   };
51ad33fe12Slntue };
52ad33fe12Slntue 
53ad33fe12Slntue template <> struct SqrtConfig<unsigned fract> {
54ad33fe12Slntue   using Type = unsigned fract;
55ad33fe12Slntue   static constexpr int EXTRA_STEPS = 1;
56ad33fe12Slntue 
57ad33fe12Slntue   // Linear approximation for the initial values, with errors bounded by:
58ad33fe12Slntue   //   max(1.5 * 2^-11, eps)
59ad33fe12Slntue   // Generated with Sollya:
607dd4ce48Slntue   // > for i from 4 to 14 do {
61ad33fe12Slntue   //     P = fpminimax(sqrt(x), 1, [|16, 16|], [i * 2^-4, (i + 1)*2^-4],
62ad33fe12Slntue   //                   fixed, absolute);
63ad33fe12Slntue   //     print("{", coeff(P, 1), "ur,", coeff(P, 0), "ur},");
64ad33fe12Slntue   //   };
657dd4ce48Slntue   // For the last interval [15/16, 1), we choose the linear function Q such that
667dd4ce48Slntue   //   Q(1) = 1 and Q(15/16) = P(15/16),
677dd4ce48Slntue   // where P is the polynomial generated by Sollya above for [14/16, 15/16].
687dd4ce48Slntue   // This is to prevent overflow in the last interval [15/16, 1).
69ad33fe12Slntue   static constexpr Type FIRST_APPROX[12][2] = {
70ad33fe12Slntue       {0x1.e378p-1ur, 0x1.0ebp-2ur},  {0x1.b512p-1ur, 0x1.2b94p-2ur},
71ad33fe12Slntue       {0x1.91fp-1ur, 0x1.45dcp-2ur},  {0x1.7622p-1ur, 0x1.5e24p-2ur},
72ad33fe12Slntue       {0x1.5f5ap-1ur, 0x1.74e4p-2ur}, {0x1.4c58p-1ur, 0x1.8a4p-2ur},
73ad33fe12Slntue       {0x1.3c1ep-1ur, 0x1.9e84p-2ur}, {0x1.2e0cp-1ur, 0x1.b1d8p-2ur},
74ad33fe12Slntue       {0x1.21aap-1ur, 0x1.c468p-2ur}, {0x1.16bap-1ur, 0x1.d62cp-2ur},
757dd4ce48Slntue       {0x1.0cfp-1ur, 0x1.e74cp-2ur},  {0x1.039p-1ur, 0x1.f8ep-2ur},
76ad33fe12Slntue   };
77ad33fe12Slntue };
78ad33fe12Slntue 
79ad33fe12Slntue template <> struct SqrtConfig<unsigned long fract> {
80ad33fe12Slntue   using Type = unsigned long fract;
81ad33fe12Slntue   static constexpr int EXTRA_STEPS = 2;
82ad33fe12Slntue 
83ad33fe12Slntue   // Linear approximation for the initial values, with errors bounded by:
84ad33fe12Slntue   //   max(1.5 * 2^-11, eps)
85ad33fe12Slntue   // Generated with Sollya:
867dd4ce48Slntue   // > for i from 4 to 14 do {
87ad33fe12Slntue   //     P = fpminimax(sqrt(x), 1, [|32, 32|], [i * 2^-4, (i + 1)*2^-4],
88ad33fe12Slntue   //                   fixed, absolute);
89ad33fe12Slntue   //     print("{", coeff(P, 1), "ulr,", coeff(P, 0), "ulr},");
90ad33fe12Slntue   //   };
917dd4ce48Slntue   // For the last interval [15/16, 1), we choose the linear function Q such that
927dd4ce48Slntue   //   Q(1) = 1 and Q(15/16) = P(15/16),
937dd4ce48Slntue   // where P is the polynomial generated by Sollya above for [14/16, 15/16].
947dd4ce48Slntue   // This is to prevent overflow in the last interval [15/16, 1).
95ad33fe12Slntue   static constexpr Type FIRST_APPROX[12][2] = {
96ad33fe12Slntue       {0x1.e3779b98p-1ulr, 0x1.0eaff788p-2ulr},
97ad33fe12Slntue       {0x1.b5167872p-1ulr, 0x1.2b908ad4p-2ulr},
98ad33fe12Slntue       {0x1.91f195cap-1ulr, 0x1.45da800cp-2ulr},
99ad33fe12Slntue       {0x1.761ebcb4p-1ulr, 0x1.5e27004cp-2ulr},
100ad33fe12Slntue       {0x1.5f619986p-1ulr, 0x1.74db933cp-2ulr},
101ad33fe12Slntue       {0x1.4c583adep-1ulr, 0x1.8a3fbfccp-2ulr},
102ad33fe12Slntue       {0x1.3c1a591cp-1ulr, 0x1.9e88373cp-2ulr},
103ad33fe12Slntue       {0x1.2e08545ap-1ulr, 0x1.b1dd2534p-2ulr},
104ad33fe12Slntue       {0x1.21b05c0ap-1ulr, 0x1.c45e023p-2ulr},
105ad33fe12Slntue       {0x1.16becd02p-1ulr, 0x1.d624031p-2ulr},
106ad33fe12Slntue       {0x1.0cf49fep-1ulr, 0x1.e743b844p-2ulr},
1077dd4ce48Slntue       {0x1.038cdfcp-1ulr, 0x1.f8e6408p-2ulr},
108ad33fe12Slntue   };
109ad33fe12Slntue };
110ad33fe12Slntue 
111ad33fe12Slntue template <>
112ad33fe12Slntue struct SqrtConfig<unsigned short accum> : SqrtConfig<unsigned fract> {};
113ad33fe12Slntue 
114ad33fe12Slntue template <>
115ad33fe12Slntue struct SqrtConfig<unsigned accum> : SqrtConfig<unsigned long fract> {};
116ad33fe12Slntue 
117ad33fe12Slntue // Integer square root
118ad33fe12Slntue template <> struct SqrtConfig<unsigned short> {
119ad33fe12Slntue   using OutType = unsigned short accum;
120ad33fe12Slntue   using FracType = unsigned fract;
121ad33fe12Slntue   // For fast-but-less-accurate version
122ad33fe12Slntue   using FastFracType = unsigned short fract;
123ad33fe12Slntue   using HalfType = unsigned char;
124ad33fe12Slntue };
125ad33fe12Slntue 
126ad33fe12Slntue template <> struct SqrtConfig<unsigned int> {
127ad33fe12Slntue   using OutType = unsigned accum;
128ad33fe12Slntue   using FracType = unsigned long fract;
129ad33fe12Slntue   // For fast-but-less-accurate version
130ad33fe12Slntue   using FastFracType = unsigned fract;
131ad33fe12Slntue   using HalfType = unsigned short;
132ad33fe12Slntue };
133ad33fe12Slntue 
// TODO: The unsigned long accum type is 64 bits wide and would need a 64-bit
// fract type.  We will probably use DyadicFloat<64> for the intermediate
// computations instead.
136ded4ea97Slntue 
137ded4ea97Slntue } // namespace internal
138ded4ea97Slntue 
139ad33fe12Slntue // Core computation for sqrt with normalized inputs (0.25 <= x < 1).
140ad33fe12Slntue template <typename Config>
141ad33fe12Slntue LIBC_INLINE constexpr typename Config::Type
142ad33fe12Slntue sqrt_core(typename Config::Type x_frac) {
143ad33fe12Slntue   using FracType = typename Config::Type;
144ad33fe12Slntue   using FXRep = FXRep<FracType>;
145ad33fe12Slntue   using StorageType = typename FXRep::StorageType;
146ad33fe12Slntue   // Exact case:
147ad33fe12Slntue   if (x_frac == FXRep::ONE_FOURTH())
148ad33fe12Slntue     return FXRep::ONE_HALF();
149ad33fe12Slntue 
150ad33fe12Slntue   // Use use Newton method to approximate sqrt(a):
151ad33fe12Slntue   //   x_{n + 1} = 1/2 (x_n + a / x_n)
152ad33fe12Slntue   // For the initial values, we choose x_0
153ad33fe12Slntue 
154ad33fe12Slntue   // Use the leading 4 bits to do look up for sqrt(x).
155ad33fe12Slntue   // After normalization, 0.25 <= x_frac < 1, so the leading 4 bits of x_frac
156ad33fe12Slntue   // are between 0b0100 and 0b1111.  Hence the lookup table only needs 12
157ad33fe12Slntue   // entries, and we can get the index by subtracting the leading 4 bits of
158ad33fe12Slntue   // x_frac by 4 = 0b0100.
159ad33fe12Slntue   StorageType x_bit = cpp::bit_cast<StorageType>(x_frac);
160ad33fe12Slntue   int index = (static_cast<int>(x_bit >> (FXRep::TOTAL_LEN - 4))) - 4;
161ad33fe12Slntue   FracType a = Config::FIRST_APPROX[index][0];
162ad33fe12Slntue   FracType b = Config::FIRST_APPROX[index][1];
163ad33fe12Slntue 
164ad33fe12Slntue   // Initial approximation step.
165ad33fe12Slntue   // Estimated error bounds: | r - sqrt(x_frac) | < max(1.5 * 2^-11, eps).
166ad33fe12Slntue   FracType r = a * x_frac + b;
167ad33fe12Slntue 
168ad33fe12Slntue   // Further Newton-method iterations for square-root:
169ad33fe12Slntue   //   x_{n + 1} = 0.5 * (x_n + a / x_n)
170ad33fe12Slntue   // We distribute and do the multiplication by 0.5 first to avoid overflow.
171ad33fe12Slntue   // TODO: Investigate the performance and accuracy of using division-free
172ad33fe12Slntue   // iterations from:
173ad33fe12Slntue   //   Blanchard, J. D. and Chamberland, M., "Newton's Method Without Division",
174ad33fe12Slntue   //   The American Mathematical Monthly (2023).
175ad33fe12Slntue   //   https://chamberland.math.grinnell.edu/papers/newton.pdf
176ad33fe12Slntue   for (int i = 0; i < Config::EXTRA_STEPS; ++i)
177ad33fe12Slntue     r = (r >> 1) + (x_frac >> 1) / r;
178ad33fe12Slntue 
179ad33fe12Slntue   return r;
180ad33fe12Slntue }
181ad33fe12Slntue 
182ded4ea97Slntue template <typename T>
183ded4ea97Slntue LIBC_INLINE constexpr cpp::enable_if_t<cpp::is_fixed_point_v<T>, T> sqrt(T x) {
184ded4ea97Slntue   using BitType = typename FXRep<T>::StorageType;
185ded4ea97Slntue   BitType x_bit = cpp::bit_cast<BitType>(x);
186ded4ea97Slntue 
187ded4ea97Slntue   if (LIBC_UNLIKELY(x_bit == 0))
188ded4ea97Slntue     return FXRep<T>::ZERO();
189ded4ea97Slntue 
190ded4ea97Slntue   int leading_zeros = cpp::countl_zero(x_bit);
191ded4ea97Slntue   constexpr int STORAGE_LENGTH = sizeof(BitType) * CHAR_BIT;
192ded4ea97Slntue   constexpr int EXP_ADJUSTMENT = STORAGE_LENGTH - FXRep<T>::FRACTION_LEN - 1;
193ded4ea97Slntue   // x_exp is the real exponent of the leading bit of x.
194ded4ea97Slntue   int x_exp = EXP_ADJUSTMENT - leading_zeros;
195ded4ea97Slntue   int shift = EXP_ADJUSTMENT - 1 - (x_exp & (~1));
196ded4ea97Slntue   // Normalize.
197ded4ea97Slntue   x_bit <<= shift;
198ded4ea97Slntue   using FracType = typename internal::SqrtConfig<T>::Type;
199ded4ea97Slntue   FracType x_frac = cpp::bit_cast<FracType>(x_bit);
200ded4ea97Slntue 
201ad33fe12Slntue   // Compute sqrt(x_frac) using Newton-method.
202ad33fe12Slntue   FracType r = sqrt_core<internal::SqrtConfig<T>>(x_frac);
203ded4ea97Slntue 
204ded4ea97Slntue   // Re-scaling
205ded4ea97Slntue   r >>= EXP_ADJUSTMENT - (x_exp >> 1);
206ded4ea97Slntue 
207ded4ea97Slntue   // Return result.
208ded4ea97Slntue   return cpp::bit_cast<T>(r);
209ded4ea97Slntue }
210ded4ea97Slntue 
211ad33fe12Slntue // Integer square root - Accurate version:
212ad33fe12Slntue // Absolute errors < 2^(-fraction length).
213ad33fe12Slntue template <typename T>
214ad33fe12Slntue LIBC_INLINE constexpr typename internal::SqrtConfig<T>::OutType isqrt(T x) {
215ad33fe12Slntue   using OutType = typename internal::SqrtConfig<T>::OutType;
216ad33fe12Slntue   using FracType = typename internal::SqrtConfig<T>::FracType;
217ad33fe12Slntue 
218ad33fe12Slntue   if (x == 0)
219ad33fe12Slntue     return FXRep<OutType>::ZERO();
220ad33fe12Slntue 
221ad33fe12Slntue   // Normalize the leading bits to the first two bits.
222ad33fe12Slntue   // Shift and then Bit cast x to x_frac gives us:
223ad33fe12Slntue   //   x = 2^(FRACTION_LEN + 1 - shift) * x_frac;
224ad33fe12Slntue   int leading_zeros = cpp::countl_zero(x);
225ad33fe12Slntue   int shift = ((leading_zeros >> 1) << 1);
226ad33fe12Slntue   x <<= shift;
227ad33fe12Slntue   // Convert to frac type and compute square root.
228ad33fe12Slntue   FracType x_frac = cpp::bit_cast<FracType>(x);
229ad33fe12Slntue   FracType r = sqrt_core<internal::SqrtConfig<FracType>>(x_frac);
230ad33fe12Slntue   // To rescale back to the OutType (Accum)
231ad33fe12Slntue   r >>= (shift >> 1);
232ad33fe12Slntue 
233ad33fe12Slntue   return cpp::bit_cast<OutType>(r);
234ad33fe12Slntue }
235ad33fe12Slntue 
236ad33fe12Slntue // Integer square root - Fast but less accurate version:
237ad33fe12Slntue // Relative errors < 2^(-fraction length).
238ad33fe12Slntue template <typename T>
239ad33fe12Slntue LIBC_INLINE constexpr typename internal::SqrtConfig<T>::OutType
240ad33fe12Slntue isqrt_fast(T x) {
241ad33fe12Slntue   using OutType = typename internal::SqrtConfig<T>::OutType;
242ad33fe12Slntue   using FracType = typename internal::SqrtConfig<T>::FastFracType;
243ad33fe12Slntue   using StorageType = typename FXRep<FracType>::StorageType;
244ad33fe12Slntue 
245ad33fe12Slntue   if (x == 0)
246ad33fe12Slntue     return FXRep<OutType>::ZERO();
247ad33fe12Slntue 
248ad33fe12Slntue   // Normalize the leading bits to the first two bits.
249ad33fe12Slntue   // Shift and then Bit cast x to x_frac gives us:
250ad33fe12Slntue   //   x = 2^(FRACTION_LEN + 1 - shift) * x_frac;
251ad33fe12Slntue   int leading_zeros = cpp::countl_zero(x);
252ad33fe12Slntue   int shift = (leading_zeros & (~1));
253ad33fe12Slntue   x <<= shift;
254ad33fe12Slntue   // Convert to frac type and compute square root.
255ad33fe12Slntue   FracType x_frac = cpp::bit_cast<FracType>(
256ad33fe12Slntue       static_cast<StorageType>(x >> FXRep<FracType>::FRACTION_LEN));
257ad33fe12Slntue   OutType r =
258ad33fe12Slntue       static_cast<OutType>(sqrt_core<internal::SqrtConfig<FracType>>(x_frac));
259ad33fe12Slntue   // To rescale back to the OutType (Accum)
260ad33fe12Slntue   r <<= (FXRep<OutType>::INTEGRAL_LEN - (shift >> 1));
261ad33fe12Slntue   return cpp::bit_cast<OutType>(r);
262ad33fe12Slntue }
263ad33fe12Slntue 
264*5ff3ff33SPetr Hosek } // namespace fixed_point
265*5ff3ff33SPetr Hosek } // namespace LIBC_NAMESPACE_DECL
266ded4ea97Slntue 
267ded4ea97Slntue #endif // LIBC_COMPILER_HAS_FIXED_POINT
268ded4ea97Slntue 
269ded4ea97Slntue #endif // LLVM_LIBC_SRC___SUPPORT_FIXEDPOINT_SQRT_H
270