xref: /llvm-project/libclc/generic/lib/math/clc_powr.cl (revision 78b5bb702fe97fe85f66d72598d0dfa7c49fe001)
1/*
2 * Copyright (c) 2014 Advanced Micro Devices, Inc.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to deal
6 * in the Software without restriction, including without limitation the rights
7 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 * copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice shall be included in
12 * all copies or substantial portions of the Software.
13 *
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 * THE SOFTWARE.
21 */
22
23#include <clc/clc.h>
24#include <clc/clcmacro.h>
25#include <clc/math/clc_fabs.h>
26#include <clc/math/clc_mad.h>
27#include <clc/math/clc_subnormal_config.h>
28#include <clc/math/math.h>
29#include <clc/math/tables.h>
30
// Compute powr using log and exp:
// x^y = exp(y * log(x))
//
34// we take care not to lose precision in the intermediate steps
35//
// When computing log, calculate it in split (head + tail) form:
//
// r = f * (p_inv_head + p_inv_tail)
// r = rh + rt
//
// Evaluate the log polynomial using r; in the final addition, do
// poly = poly + ((rh-r) + rt)
43//
44// lth = -r
45// ltt = ((xexp * log2_t) - poly) + logT
46// lt = lth + ltt
47//
48// lh = (xexp * log2_h) + logH
49// l = lh + lt
50//
51// Calculate final log answer as gh and gt,
52// gh = l & higher-half bits
53// gt = (((ltt - (lt - lth)) + ((lh - l) + lt)) + (l - gh))
54//
55// yh = y & higher-half bits
56// yt = y - yh
57//
58// Before entering computation of exp,
59// vs = ((yt*gt + yt*gh) + yh*gt)
60// v = vs + yh*gh
61// vt = ((yh*gh - v) + vs)
62//
63// In calculation of exp, add vt to r that is used for poly
64// At the end of exp, do
65// ((((expT * poly) + expT) + expH*poly) + expH)
66
67_CLC_DEF _CLC_OVERLOAD float __clc_powr(float x, float y) {
68  int ix = as_int(x);
69  int ax = ix & EXSIGNBIT_SP32;
70  int xpos = ix == ax;
71
72  int iy = as_int(y);
73  int ay = iy & EXSIGNBIT_SP32;
74  int ypos = iy == ay;
75
76  // Extra precise log calculation
77  // First handle case that x is close to 1
78  float r = 1.0f - as_float(ax);
79  int near1 = __clc_fabs(r) < 0x1.0p-4f;
80  float r2 = r * r;
81
82  // Coefficients are just 1/3, 1/4, 1/5 and 1/6
83  float poly = __clc_mad(
84      r,
85      __clc_mad(r,
86                __clc_mad(r, __clc_mad(r, 0x1.24924ap-3f, 0x1.555556p-3f),
87                          0x1.99999ap-3f),
88                0x1.000000p-2f),
89      0x1.555556p-2f);
90
91  poly *= r2 * r;
92
93  float lth_near1 = -r2 * 0.5f;
94  float ltt_near1 = -poly;
95  float lt_near1 = lth_near1 + ltt_near1;
96  float lh_near1 = -r;
97  float l_near1 = lh_near1 + lt_near1;
98
99  // Computations for x not near 1
100  int m = (int)(ax >> EXPSHIFTBITS_SP32) - EXPBIAS_SP32;
101  float mf = (float)m;
102  int ixs = as_int(as_float(ax | 0x3f800000) - 1.0f);
103  float mfs = (float)((ixs >> EXPSHIFTBITS_SP32) - 253);
104  int c = m == -127;
105  int ixn = c ? ixs : ax;
106  float mfn = c ? mfs : mf;
107
108  int indx = (ixn & 0x007f0000) + ((ixn & 0x00008000) << 1);
109
110  // F - Y
111  float f = as_float(0x3f000000 | indx) -
112            as_float(0x3f000000 | (ixn & MANTBITS_SP32));
113
114  indx = indx >> 16;
115  float2 tv = USE_TABLE(log_inv_tbl_ep, indx);
116  float rh = f * tv.s0;
117  float rt = f * tv.s1;
118  r = rh + rt;
119
120  poly = __clc_mad(r, __clc_mad(r, 0x1.0p-2f, 0x1.555556p-2f), 0x1.0p-1f) *
121         (r * r);
122  poly += (rh - r) + rt;
123
124  const float LOG2_HEAD = 0x1.62e000p-1f;  // 0.693115234
125  const float LOG2_TAIL = 0x1.0bfbe8p-15f; // 0.0000319461833
126  tv = USE_TABLE(loge_tbl, indx);
127  float lth = -r;
128  float ltt = __clc_mad(mfn, LOG2_TAIL, -poly) + tv.s1;
129  float lt = lth + ltt;
130  float lh = __clc_mad(mfn, LOG2_HEAD, tv.s0);
131  float l = lh + lt;
132
133  // Select near 1 or not
134  lth = near1 ? lth_near1 : lth;
135  ltt = near1 ? ltt_near1 : ltt;
136  lt = near1 ? lt_near1 : lt;
137  lh = near1 ? lh_near1 : lh;
138  l = near1 ? l_near1 : l;
139
140  float gh = as_float(as_int(l) & 0xfffff000);
141  float gt = ((ltt - (lt - lth)) + ((lh - l) + lt)) + (l - gh);
142
143  float yh = as_float(iy & 0xfffff000);
144
145  float yt = y - yh;
146
147  float ylogx_s = __clc_mad(gt, yh, __clc_mad(gh, yt, yt * gt));
148  float ylogx = __clc_mad(yh, gh, ylogx_s);
149  float ylogx_t = __clc_mad(yh, gh, -ylogx) + ylogx_s;
150
151  // Extra precise exp of ylogx
152  // 64/log2 : 92.332482616893657
153  const float R_64_BY_LOG2 = 0x1.715476p+6f;
154  int n = convert_int(ylogx * R_64_BY_LOG2);
155  float nf = (float)n;
156
157  int j = n & 0x3f;
158  m = n >> 6;
159  int m2 = m << EXPSHIFTBITS_SP32;
160  // log2/64 lead: 0.0108032227
161  const float R_LOG2_BY_64_LD = 0x1.620000p-7f;
162  // log2/64 tail: 0.0000272020388
163  const float R_LOG2_BY_64_TL = 0x1.c85fdep-16f;
164  r = __clc_mad(nf, -R_LOG2_BY_64_TL, __clc_mad(nf, -R_LOG2_BY_64_LD, ylogx)) +
165      ylogx_t;
166
167  // Truncated Taylor series for e^r
168  poly = __clc_mad(__clc_mad(__clc_mad(r, 0x1.555556p-5f, 0x1.555556p-3f), r,
169                             0x1.000000p-1f),
170                   r * r, r);
171
172  tv = USE_TABLE(exp_tbl_ep, j);
173
174  float expylogx =
175      __clc_mad(tv.s0, poly, __clc_mad(tv.s1, poly, tv.s1)) + tv.s0;
176  float sexpylogx = expylogx * as_float(0x1 << (m + 149));
177  float texpylogx = as_float(as_int(expylogx) + m2);
178  expylogx = m < -125 ? sexpylogx : texpylogx;
179
180  // Result is +-Inf if (ylogx + ylogx_t) > 128*log2
181  expylogx = ((ylogx > 0x1.62e430p+6f) |
182              (ylogx == 0x1.62e430p+6f & ylogx_t > -0x1.05c610p-22f))
183                 ? as_float(PINFBITPATT_SP32)
184                 : expylogx;
185
186  // Result is 0 if ylogx < -149*log2
187  expylogx = ylogx < -0x1.9d1da0p+6f ? 0.0f : expylogx;
188
189  // Classify y:
190  //   inty = 0 means not an integer.
191  //   inty = 1 means odd integer.
192  //   inty = 2 means even integer.
193
194  int yexp = (int)(ay >> EXPSHIFTBITS_SP32) - EXPBIAS_SP32 + 1;
195  int mask = (1 << (24 - yexp)) - 1;
196  int yodd = ((iy >> (24 - yexp)) & 0x1) != 0;
197  int inty = yodd ? 1 : 2;
198  inty = (iy & mask) != 0 ? 0 : inty;
199  inty = yexp < 1 ? 0 : inty;
200  inty = yexp > 24 ? 2 : inty;
201
202  float signval = as_float((as_uint(expylogx) ^ SIGNBIT_SP32));
203  expylogx = ((inty == 1) & !xpos) ? signval : expylogx;
204  int ret = as_int(expylogx);
205
206  // Corner case handling
207  ret = ax < 0x3f800000 & iy == NINFBITPATT_SP32 ? PINFBITPATT_SP32 : ret;
208  ret = ax < 0x3f800000 & iy == PINFBITPATT_SP32 ? 0 : ret;
209  ret = ax == 0x3f800000 & ay < PINFBITPATT_SP32 ? 0x3f800000 : ret;
210  ret = ax == 0x3f800000 & ay == PINFBITPATT_SP32 ? QNANBITPATT_SP32 : ret;
211  ret = ax > 0x3f800000 & iy == NINFBITPATT_SP32 ? 0 : ret;
212  ret = ax > 0x3f800000 & iy == PINFBITPATT_SP32 ? PINFBITPATT_SP32 : ret;
213  ret = ((ix < PINFBITPATT_SP32) & (ay == 0)) ? 0x3f800000 : ret;
214  ret = ((ax == PINFBITPATT_SP32) & !ypos) ? 0 : ret;
215  ret = ((ax == PINFBITPATT_SP32) & ypos) ? PINFBITPATT_SP32 : ret;
216  ret = ((ax == PINFBITPATT_SP32) & (iy == PINFBITPATT_SP32)) ? PINFBITPATT_SP32
217                                                              : ret;
218  ret = ((ax == PINFBITPATT_SP32) & (ay == 0)) ? QNANBITPATT_SP32 : ret;
219  ret = ((ax == 0) & !ypos) ? PINFBITPATT_SP32 : ret;
220  ret = ((ax == 0) & ypos) ? 0 : ret;
221  ret = ((ax == 0) & (ay == 0)) ? QNANBITPATT_SP32 : ret;
222  ret = ((ax != 0) & !xpos) ? QNANBITPATT_SP32 : ret;
223  ret = ax > PINFBITPATT_SP32 ? ix : ret;
224  ret = ay > PINFBITPATT_SP32 ? iy : ret;
225
226  return as_float(ret);
227}
228_CLC_BINARY_VECTORIZE(_CLC_DEF _CLC_OVERLOAD, float, __clc_powr, float, float)
229
#ifdef cl_khr_fp64
// Double-precision powr(x, y): x^y for x >= 0, computed as exp(y * log(x))
// with head/tail (extra-precision) intermediates.
//
// Special cases, as implemented by the selects at the end of the function
// (later selects take priority over earlier ones):
//   x NaN or y NaN               -> the NaN operand (y's NaN wins if both)
//   x < 0, x != -0 (incl. -Inf)  -> NaN
//   x == +-0:  y == +-0 -> NaN,   y > 0 -> +0,    y < 0 -> +Inf
//   x == +Inf: y == +-0 -> NaN,   y > 0 -> +Inf,  y < 0 -> +0
//   x == 1:    y == +-Inf -> NaN, finite y -> 1
//   finite x > 0, y == +-0       -> 1
//   0 < x < 1: y == -Inf -> +Inf, y == +Inf -> +0
//   x > 1:     y == -Inf -> +0,   y == +Inf -> +Inf
_CLC_DEF _CLC_OVERLOAD double __clc_powr(double x, double y) {
  const double real_log2_tail = 5.76999904754328540596e-08;
  const double real_log2_lead = 6.93147122859954833984e-01;

  long ux = as_long(x);
  long ax = ux & (~SIGNBIT_DP64);
  // The sign of x only matters for the special cases: every negative x is
  // resolved there, so the main path below works purely on |x|.
  int xpos = ax == ux;

  long uy = as_long(y);
  long ay = uy & (~SIGNBIT_DP64);
  int ypos = ay == uy;

  // Extended precision log: log(|x|) ~= H + T, then y * log(|x|) ~= v + vt.
  double v, vt;
  {
    // Unbiased exponent of |x|; -1023 marks a zero/subnormal input.
    // (Named xexp_i rather than `exp` to avoid shadowing the exp() builtin.)
    int xexp_i = (int)(ax >> 52) - 1023;
    int mask_exp_1023 = xexp_i == -1023;
    double xexp = (double)xexp_i;
    long mantissa = ax & 0x000FFFFFFFFFFFFFL;

    // Subnormal x: renormalise the mantissa with an FP subtraction and
    // recompute the exponent from the result.
    long temp_ux = as_long(as_double(0x3ff0000000000000L | mantissa) - 1.0);
    xexp_i = ((temp_ux & 0x7FF0000000000000L) >> 52) - 2045;
    double xexp1 = (double)xexp_i;
    long mantissa1 = temp_ux & 0x000FFFFFFFFFFFFFL;

    xexp = mask_exp_1023 ? xexp1 : xexp;
    mantissa = mask_exp_1023 ? mantissa1 : mantissa;

    // Table index from the top 8 mantissa bits, rounded using the 9th bit.
    long rax = (mantissa & 0x000ff00000000000) +
               ((mantissa & 0x0000080000000000) << 1);
    int index = rax >> 44;

    double F = as_double(rax | 0x3FE0000000000000L);
    double Y = as_double(mantissa | 0x3FE0000000000000L);
    double f = F - Y;
    double2 tv = USE_TABLE(log_f_inv_tbl, index);
    double log_h = tv.s0;
    double log_t = tv.s1;
    double f_inv = (log_h + log_t) * f;
    double r1 = as_double(as_long(f_inv) & 0xfffffffff8000000L);
    double r2 = fma(-F, r1, f) * (log_h + log_t);
    double r = r1 + r2;

    double poly = fma(
        r, fma(r, fma(r, fma(r, 1.0 / 7.0, 1.0 / 6.0), 1.0 / 5.0), 1.0 / 4.0),
        1.0 / 3.0);
    poly = poly * r * r * r;

    double hr1r1 = 0.5 * r1 * r1;
    double poly0h = r1 + hr1r1;
    double poly0t = r1 - poly0h + hr1r1;
    poly = fma(r1, r2, fma(0.5 * r2, r2, poly)) + r2 + poly0t;

    tv = USE_TABLE(powlog_tbl, index);
    log_h = tv.s0;
    log_t = tv.s1;

    double resT_t = fma(xexp, real_log2_tail, log_t) - poly;
    double resT = resT_t - poly0h;
    double resH = fma(xexp, real_log2_lead, log_h);
    double resT_h = poly0h;

    double H = resT + resH;
    double H_h = as_double(as_long(H) & 0xfffffffff8000000L);
    double T = (resH - H + resT) + (resT_t - (resT + resT_h)) + (H - H_h);
    H = H_h;

    // Split y so that y_head * H is exact.
    double y_head = as_double(uy & 0xfffffffff8000000L);
    double y_tail = y - y_head;

    double temp = fma(y_tail, H, fma(y_head, T, y_tail * T));
    v = fma(y_head, H, temp);
    vt = fma(y_head, H, -v) + temp;
  }

  // Now calculate exp of (v,vt)

  double expv;
  {
    const double max_exp_arg = 709.782712893384;
    const double min_exp_arg = -745.1332191019411;
    const double sixtyfour_by_lnof2 = 92.33248261689366;
    const double lnof2_by_64_head = 0.010830424260348081;
    const double lnof2_by_64_tail = -4.359010638708991e-10;

    // v = (64*m + j) * ln2/64 + r, so exp(v) = 2^m * 2^(j/64) * exp(r).
    double temp = v * sixtyfour_by_lnof2;
    int n = (int)temp;
    double dn = (double)n;
    int j = n & 0x0000003f;
    int m = n >> 6;

    double2 tv = USE_TABLE(two_to_jby64_ep_tbl, j);
    double f1 = tv.s0;
    double f2 = tv.s1;
    double f = f1 + f2;

    double r1 = fma(dn, -lnof2_by_64_head, v);
    double r2 = dn * lnof2_by_64_tail;
    double r = (r1 + r2) + vt;

    double q = fma(
        r,
        fma(r,
            fma(r,
                fma(r, 1.38889490863777199667e-03, 8.33336798434219616221e-03),
                4.16666666662260795726e-02),
            1.66666666665260878863e-01),
        5.00000000000000008883e-01);
    q = fma(r * r, q, r);

    expv = fma(f, q, f2) + f1;
    expv = ldexp(expv, m);

    // Out-of-range arguments saturate; this also covers the cases where the
    // (int) conversion of temp above is out of range.
    expv = v > max_exp_arg ? as_double(0x7FF0000000000000L) : expv;
    expv = v < min_exp_arg ? 0.0 : expv;
  }

  // Unlike pow, no odd-integer sign fix-up is needed: powr is only defined
  // for x >= 0, and every negative x (including -0) is fully decided by the
  // special cases below.
  long ret = as_long(expv);

  // Now all the edge cases
  ret = ax < 0x3ff0000000000000L & uy == NINFBITPATT_DP64 ? PINFBITPATT_DP64
                                                          : ret;
  ret = ax < 0x3ff0000000000000L & uy == PINFBITPATT_DP64 ? 0L : ret;
  ret = ax == 0x3ff0000000000000L & ay < PINFBITPATT_DP64 ? 0x3ff0000000000000L
                                                          : ret;
  ret = ax == 0x3ff0000000000000L & ay == PINFBITPATT_DP64 ? QNANBITPATT_DP64
                                                           : ret;
  ret = ax > 0x3ff0000000000000L & uy == NINFBITPATT_DP64 ? 0L : ret;
  ret = ax > 0x3ff0000000000000L & uy == PINFBITPATT_DP64 ? PINFBITPATT_DP64
                                                          : ret;
  ret = ux < PINFBITPATT_DP64 & ay == 0L ? 0x3ff0000000000000L : ret;
  ret = ((ax == PINFBITPATT_DP64) & !ypos) ? 0L : ret;
  ret = ((ax == PINFBITPATT_DP64) & ypos) ? PINFBITPATT_DP64 : ret;
  ret = ((ax == PINFBITPATT_DP64) & (uy == PINFBITPATT_DP64)) ? PINFBITPATT_DP64
                                                              : ret;
  ret = ((ax == PINFBITPATT_DP64) & (ay == 0L)) ? QNANBITPATT_DP64 : ret;
  ret = ((ax == 0L) & !ypos) ? PINFBITPATT_DP64 : ret;
  ret = ((ax == 0L) & ypos) ? 0L : ret;
  ret = ((ax == 0L) & (ay == 0L)) ? QNANBITPATT_DP64 : ret;
  ret = ((ax != 0L) & !xpos) ? QNANBITPATT_DP64 : ret;
  ret = ax > PINFBITPATT_DP64 ? ux : ret;
  ret = ay > PINFBITPATT_DP64 ? uy : ret;

  return as_double(ret);
}
_CLC_BINARY_VECTORIZE(_CLC_DEF _CLC_OVERLOAD, double, __clc_powr, double, double)
#endif
396