/*
 * Copyright (c) 2014 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

// This version is derived from the generic fma software implementation
// (__clc_sw_fma), but avoids the use of ulong in favor of uint2. The logic has
// been updated as appropriate.
26 27#include <clc/clc.h> 28#include <clc/clcmacro.h> 29#include <clc/math/math.h> 30 31struct fp { 32 uint2 mantissa; 33 int exponent; 34 uint sign; 35}; 36 37static uint2 u2_set(uint hi, uint lo) { 38 uint2 res; 39 res.lo = lo; 40 res.hi = hi; 41 return res; 42} 43 44static uint2 u2_set_u(uint val) { return u2_set(0, val); } 45 46static uint2 u2_mul(uint a, uint b) { 47 uint2 res; 48 res.hi = mul_hi(a, b); 49 res.lo = a * b; 50 return res; 51} 52 53static uint2 u2_sll(uint2 val, uint shift) { 54 if (shift == 0) 55 return val; 56 if (shift < 32) { 57 val.hi <<= shift; 58 val.hi |= val.lo >> (32 - shift); 59 val.lo <<= shift; 60 } else { 61 val.hi = val.lo << (shift - 32); 62 val.lo = 0; 63 } 64 return val; 65} 66 67static uint2 u2_srl(uint2 val, uint shift) { 68 if (shift == 0) 69 return val; 70 if (shift < 32) { 71 val.lo >>= shift; 72 val.lo |= val.hi << (32 - shift); 73 val.hi >>= shift; 74 } else { 75 val.lo = val.hi >> (shift - 32); 76 val.hi = 0; 77 } 78 return val; 79} 80 81static uint2 u2_or(uint2 a, uint b) { 82 a.lo |= b; 83 return a; 84} 85 86static uint2 u2_and(uint2 a, uint2 b) { 87 a.lo &= b.lo; 88 a.hi &= b.hi; 89 return a; 90} 91 92static uint2 u2_add(uint2 a, uint2 b) { 93 uint carry = (hadd(a.lo, b.lo) >> 31) & 0x1; 94 a.lo += b.lo; 95 a.hi += b.hi + carry; 96 return a; 97} 98 99static uint2 u2_add_u(uint2 a, uint b) { return u2_add(a, u2_set_u(b)); } 100 101static uint2 u2_inv(uint2 a) { 102 a.lo = ~a.lo; 103 a.hi = ~a.hi; 104 return u2_add_u(a, 1); 105} 106 107static uint u2_clz(uint2 a) { 108 uint leading_zeroes = clz(a.hi); 109 if (leading_zeroes == 32) { 110 leading_zeroes += clz(a.lo); 111 } 112 return leading_zeroes; 113} 114 115static bool u2_eq(uint2 a, uint2 b) { return a.lo == b.lo && a.hi == b.hi; } 116 117static bool u2_zero(uint2 a) { return u2_eq(a, u2_set_u(0)); } 118 119static bool u2_gt(uint2 a, uint2 b) { 120 return a.hi > b.hi || (a.hi == b.hi && a.lo > b.lo); 121} 122 123_CLC_DEF _CLC_OVERLOAD float fma(float a, float b, 
float c) { 124 /* special cases */ 125 if (isnan(a) || isnan(b) || isnan(c) || isinf(a) || isinf(b)) { 126 return mad(a, b, c); 127 } 128 129 /* If only c is inf, and both a,b are regular numbers, the result is c*/ 130 if (isinf(c)) { 131 return c; 132 } 133 134 a = __clc_flush_denormal_if_not_supported(a); 135 b = __clc_flush_denormal_if_not_supported(b); 136 c = __clc_flush_denormal_if_not_supported(c); 137 138 if (a == 0.0f || b == 0.0f) { 139 return c; 140 } 141 142 if (c == 0) { 143 return a * b; 144 } 145 146 struct fp st_a, st_b, st_c; 147 148 st_a.exponent = a == .0f ? 0 : ((as_uint(a) & 0x7f800000) >> 23) - 127; 149 st_b.exponent = b == .0f ? 0 : ((as_uint(b) & 0x7f800000) >> 23) - 127; 150 st_c.exponent = c == .0f ? 0 : ((as_uint(c) & 0x7f800000) >> 23) - 127; 151 152 st_a.mantissa = u2_set_u(a == .0f ? 0 : (as_uint(a) & 0x7fffff) | 0x800000); 153 st_b.mantissa = u2_set_u(b == .0f ? 0 : (as_uint(b) & 0x7fffff) | 0x800000); 154 st_c.mantissa = u2_set_u(c == .0f ? 0 : (as_uint(c) & 0x7fffff) | 0x800000); 155 156 st_a.sign = as_uint(a) & 0x80000000; 157 st_b.sign = as_uint(b) & 0x80000000; 158 st_c.sign = as_uint(c) & 0x80000000; 159 160 // Multiplication. 161 // Move the product to the highest bits to maximize precision 162 // mantissa is 24 bits => product is 48 bits, 2bits non-fraction. 163 // Add one bit for future addition overflow, 164 // add another bit to detect subtraction underflow 165 struct fp st_mul; 166 st_mul.sign = st_a.sign ^ st_b.sign; 167 st_mul.mantissa = u2_sll(u2_mul(st_a.mantissa.lo, st_b.mantissa.lo), 14); 168 st_mul.exponent = 169 !u2_zero(st_mul.mantissa) ? 
st_a.exponent + st_b.exponent : 0; 170 171 // FIXME: Detecting a == 0 || b == 0 above crashed GCN isel 172 if (st_mul.exponent == 0 && u2_zero(st_mul.mantissa)) 173 return c; 174 175// Mantissa is 23 fractional bits, shift it the same way as product mantissa 176#define C_ADJUST 37ul 177 178 // both exponents are bias adjusted 179 int exp_diff = st_mul.exponent - st_c.exponent; 180 181 st_c.mantissa = u2_sll(st_c.mantissa, C_ADJUST); 182 uint2 cutoff_bits = u2_set_u(0); 183 uint2 cutoff_mask = u2_add(u2_sll(u2_set_u(1), abs(exp_diff)), 184 u2_set(0xffffffff, 0xffffffff)); 185 if (exp_diff > 0) { 186 cutoff_bits = 187 exp_diff >= 64 ? st_c.mantissa : u2_and(st_c.mantissa, cutoff_mask); 188 st_c.mantissa = 189 exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_c.mantissa, exp_diff); 190 } else { 191 cutoff_bits = -exp_diff >= 64 ? st_mul.mantissa 192 : u2_and(st_mul.mantissa, cutoff_mask); 193 st_mul.mantissa = 194 -exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_mul.mantissa, -exp_diff); 195 } 196 197 struct fp st_fma; 198 st_fma.sign = st_mul.sign; 199 st_fma.exponent = max(st_mul.exponent, st_c.exponent); 200 if (st_c.sign == st_mul.sign) { 201 st_fma.mantissa = u2_add(st_mul.mantissa, st_c.mantissa); 202 } else { 203 // cutoff bits borrow one 204 st_fma.mantissa = 205 u2_add(u2_add(st_mul.mantissa, u2_inv(st_c.mantissa)), 206 (!u2_zero(cutoff_bits) && (st_mul.exponent > st_c.exponent) 207 ? 
u2_set(0xffffffff, 0xffffffff) 208 : u2_set_u(0))); 209 } 210 211 // underflow: st_c.sign != st_mul.sign, and magnitude switches the sign 212 if (u2_gt(st_fma.mantissa, u2_set(0x7fffffff, 0xffffffff))) { 213 st_fma.mantissa = u2_inv(st_fma.mantissa); 214 st_fma.sign = st_mul.sign ^ 0x80000000; 215 } 216 217 // detect overflow/underflow 218 int overflow_bits = 3 - u2_clz(st_fma.mantissa); 219 220 // adjust exponent 221 st_fma.exponent += overflow_bits; 222 223 // handle underflow 224 if (overflow_bits < 0) { 225 st_fma.mantissa = u2_sll(st_fma.mantissa, -overflow_bits); 226 overflow_bits = 0; 227 } 228 229 // rounding 230 uint2 trunc_mask = u2_add(u2_sll(u2_set_u(1), C_ADJUST + overflow_bits), 231 u2_set(0xffffffff, 0xffffffff)); 232 uint2 trunc_bits = 233 u2_or(u2_and(st_fma.mantissa, trunc_mask), !u2_zero(cutoff_bits)); 234 uint2 last_bit = 235 u2_and(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits)); 236 uint2 grs_bits = u2_sll(u2_set_u(4), C_ADJUST - 3 + overflow_bits); 237 238 // round to nearest even 239 if (u2_gt(trunc_bits, grs_bits) || 240 (u2_eq(trunc_bits, grs_bits) && !u2_zero(last_bit))) { 241 st_fma.mantissa = 242 u2_add(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits)); 243 } 244 245 // Shift mantissa back to bit 23 246 st_fma.mantissa = u2_srl(st_fma.mantissa, C_ADJUST + overflow_bits); 247 248 // Detect rounding overflow 249 if (u2_gt(st_fma.mantissa, u2_set_u(0xffffff))) { 250 ++st_fma.exponent; 251 st_fma.mantissa = u2_srl(st_fma.mantissa, 1); 252 } 253 254 if (u2_zero(st_fma.mantissa)) { 255 return 0.0f; 256 } 257 258 // Flating point range limit 259 if (st_fma.exponent > 127) { 260 return as_float(as_uint(INFINITY) | st_fma.sign); 261 } 262 263 // Flush denormals 264 if (st_fma.exponent <= -127) { 265 return as_float(st_fma.sign); 266 } 267 268 return as_float(st_fma.sign | ((st_fma.exponent + 127) << 23) | 269 ((uint)st_fma.mantissa.lo & 0x7fffff)); 270} 271_CLC_TERNARY_VECTORIZE(_CLC_DEF _CLC_OVERLOAD, float, fma, 
float, float, float) 272 273#ifdef cl_khr_fp16 274 275#pragma OPENCL EXTENSION cl_khr_fp16 : enable 276 277_CLC_DEF _CLC_OVERLOAD half fma(half a, half b, half c) { 278 return (half)mad((float)a, (float)b, (float)c); 279} 280_CLC_TERNARY_VECTORIZE(_CLC_DEF _CLC_OVERLOAD, half, fma, half, half, half) 281 282#endif 283