1*9637de38SSrikanth Yalavarthi /* SPDX-License-Identifier: BSD-3-Clause 2*9637de38SSrikanth Yalavarthi * Copyright (c) 2022 Marvell. 3*9637de38SSrikanth Yalavarthi */ 4*9637de38SSrikanth Yalavarthi 5*9637de38SSrikanth Yalavarthi #include <errno.h> 6*9637de38SSrikanth Yalavarthi #include <math.h> 7*9637de38SSrikanth Yalavarthi #include <stdint.h> 8*9637de38SSrikanth Yalavarthi 9*9637de38SSrikanth Yalavarthi #include "mldev_utils.h" 10*9637de38SSrikanth Yalavarthi 11*9637de38SSrikanth Yalavarthi /* Description: 12*9637de38SSrikanth Yalavarthi * This file implements scalar versions of Machine Learning utility functions used to convert data 13*9637de38SSrikanth Yalavarthi * types from higher precision to lower precision and vice-versa. 14*9637de38SSrikanth Yalavarthi */ 15*9637de38SSrikanth Yalavarthi 16*9637de38SSrikanth Yalavarthi #ifndef BIT 17*9637de38SSrikanth Yalavarthi #define BIT(nr) (1UL << (nr)) 18*9637de38SSrikanth Yalavarthi #endif 19*9637de38SSrikanth Yalavarthi 20*9637de38SSrikanth Yalavarthi #ifndef BITS_PER_LONG 21*9637de38SSrikanth Yalavarthi #define BITS_PER_LONG (__SIZEOF_LONG__ * 8) 22*9637de38SSrikanth Yalavarthi #endif 23*9637de38SSrikanth Yalavarthi 24*9637de38SSrikanth Yalavarthi #ifndef GENMASK_U32 25*9637de38SSrikanth Yalavarthi #define GENMASK_U32(h, l) (((~0UL) << (l)) & (~0UL >> (BITS_PER_LONG - 1 - (h)))) 26*9637de38SSrikanth Yalavarthi #endif 27*9637de38SSrikanth Yalavarthi 28*9637de38SSrikanth Yalavarthi /* float32: bit index of MSB & LSB of sign, exponent and mantissa */ 29*9637de38SSrikanth Yalavarthi #define FP32_LSB_M 0 30*9637de38SSrikanth Yalavarthi #define FP32_MSB_M 22 31*9637de38SSrikanth Yalavarthi #define FP32_LSB_E 23 32*9637de38SSrikanth Yalavarthi #define FP32_MSB_E 30 33*9637de38SSrikanth Yalavarthi #define FP32_LSB_S 31 34*9637de38SSrikanth Yalavarthi #define FP32_MSB_S 31 35*9637de38SSrikanth Yalavarthi 36*9637de38SSrikanth Yalavarthi /* float32: bitmask for sign, exponent and mantissa */ 37*9637de38SSrikanth Yalavarthi #define FP32_MASK_S GENMASK_U32(FP32_MSB_S, FP32_LSB_S) 38*9637de38SSrikanth Yalavarthi #define FP32_MASK_E GENMASK_U32(FP32_MSB_E, FP32_LSB_E) 39*9637de38SSrikanth Yalavarthi #define FP32_MASK_M GENMASK_U32(FP32_MSB_M, FP32_LSB_M) 40*9637de38SSrikanth Yalavarthi 41*9637de38SSrikanth Yalavarthi /* float16: bit index of MSB & LSB of sign, exponent and mantissa */ 42*9637de38SSrikanth Yalavarthi #define FP16_LSB_M 0 43*9637de38SSrikanth Yalavarthi #define FP16_MSB_M 9 44*9637de38SSrikanth Yalavarthi #define FP16_LSB_E 10 45*9637de38SSrikanth Yalavarthi #define FP16_MSB_E 14 46*9637de38SSrikanth Yalavarthi #define FP16_LSB_S 15 47*9637de38SSrikanth Yalavarthi #define FP16_MSB_S 15 48*9637de38SSrikanth Yalavarthi 49*9637de38SSrikanth Yalavarthi /* float16: bitmask for sign, exponent and mantissa */ 50*9637de38SSrikanth Yalavarthi #define FP16_MASK_S GENMASK_U32(FP16_MSB_S, FP16_LSB_S) 51*9637de38SSrikanth Yalavarthi #define FP16_MASK_E GENMASK_U32(FP16_MSB_E, FP16_LSB_E) 52*9637de38SSrikanth Yalavarthi #define FP16_MASK_M GENMASK_U32(FP16_MSB_M, FP16_LSB_M) 53*9637de38SSrikanth Yalavarthi 54*9637de38SSrikanth Yalavarthi /* bfloat16: bit index of MSB & LSB of sign, exponent and mantissa */ 55*9637de38SSrikanth Yalavarthi #define BF16_LSB_M 0 56*9637de38SSrikanth Yalavarthi #define BF16_MSB_M 6 57*9637de38SSrikanth Yalavarthi #define BF16_LSB_E 7 58*9637de38SSrikanth Yalavarthi #define BF16_MSB_E 14 59*9637de38SSrikanth Yalavarthi #define BF16_LSB_S 15 60*9637de38SSrikanth Yalavarthi #define BF16_MSB_S 15 61*9637de38SSrikanth Yalavarthi 62*9637de38SSrikanth Yalavarthi /* bfloat16: bitmask for sign, exponent and mantissa */ 63*9637de38SSrikanth Yalavarthi #define BF16_MASK_S GENMASK_U32(BF16_MSB_S, BF16_LSB_S) 64*9637de38SSrikanth Yalavarthi #define BF16_MASK_E GENMASK_U32(BF16_MSB_E, BF16_LSB_E) 65*9637de38SSrikanth Yalavarthi #define BF16_MASK_M GENMASK_U32(BF16_MSB_M, BF16_LSB_M) 66*9637de38SSrikanth Yalavarthi 67*9637de38SSrikanth Yalavarthi /* Exponent bias */ 68*9637de38SSrikanth Yalavarthi #define FP32_BIAS_E 127 69*9637de38SSrikanth Yalavarthi #define FP16_BIAS_E 15 70*9637de38SSrikanth Yalavarthi #define BF16_BIAS_E 127 71*9637de38SSrikanth Yalavarthi 72*9637de38SSrikanth Yalavarthi #define FP32_PACK(sign, exponent, mantissa) \ 73*9637de38SSrikanth Yalavarthi (((sign) << FP32_LSB_S) | ((exponent) << FP32_LSB_E) | (mantissa)) 74*9637de38SSrikanth Yalavarthi 75*9637de38SSrikanth Yalavarthi #define FP16_PACK(sign, exponent, mantissa) \ 76*9637de38SSrikanth Yalavarthi (((sign) << FP16_LSB_S) | ((exponent) << FP16_LSB_E) | (mantissa)) 77*9637de38SSrikanth Yalavarthi 78*9637de38SSrikanth Yalavarthi #define BF16_PACK(sign, exponent, mantissa) \ 79*9637de38SSrikanth Yalavarthi (((sign) << BF16_LSB_S) | ((exponent) << BF16_LSB_E) | (mantissa)) 80*9637de38SSrikanth Yalavarthi 81*9637de38SSrikanth Yalavarthi /* Represent float32 as float and uint32_t */ 82*9637de38SSrikanth Yalavarthi union float32 { 83*9637de38SSrikanth Yalavarthi float f; 84*9637de38SSrikanth Yalavarthi uint32_t u; 85*9637de38SSrikanth Yalavarthi }; 86*9637de38SSrikanth Yalavarthi 87*9637de38SSrikanth Yalavarthi __rte_weak int 88*9637de38SSrikanth Yalavarthi rte_ml_io_float32_to_int8(float scale, uint64_t nb_elements, void *input, void *output) 89*9637de38SSrikanth Yalavarthi { 90*9637de38SSrikanth Yalavarthi float *input_buffer; 91*9637de38SSrikanth Yalavarthi int8_t *output_buffer; 92*9637de38SSrikanth Yalavarthi uint64_t i; 93*9637de38SSrikanth Yalavarthi int i32; 94*9637de38SSrikanth Yalavarthi 95*9637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 96*9637de38SSrikanth Yalavarthi return -EINVAL; 97*9637de38SSrikanth Yalavarthi 98*9637de38SSrikanth Yalavarthi input_buffer = (float *)input; 99*9637de38SSrikanth Yalavarthi output_buffer = (int8_t *)output; 100*9637de38SSrikanth Yalavarthi 101*9637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 102*9637de38SSrikanth Yalavarthi i32 = (int32_t)round((*input_buffer) * scale); 103*9637de38SSrikanth Yalavarthi 104*9637de38SSrikanth Yalavarthi if (i32 < INT8_MIN) 105*9637de38SSrikanth Yalavarthi i32 = INT8_MIN; 106*9637de38SSrikanth Yalavarthi 107*9637de38SSrikanth Yalavarthi if (i32 > INT8_MAX) 108*9637de38SSrikanth Yalavarthi i32 = INT8_MAX; 109*9637de38SSrikanth Yalavarthi 110*9637de38SSrikanth Yalavarthi *output_buffer = (int8_t)i32; 111*9637de38SSrikanth Yalavarthi 112*9637de38SSrikanth Yalavarthi input_buffer++; 113*9637de38SSrikanth Yalavarthi output_buffer++; 114*9637de38SSrikanth Yalavarthi } 115*9637de38SSrikanth Yalavarthi 116*9637de38SSrikanth Yalavarthi return 0; 117*9637de38SSrikanth Yalavarthi } 118*9637de38SSrikanth Yalavarthi 119*9637de38SSrikanth Yalavarthi __rte_weak int 120*9637de38SSrikanth Yalavarthi rte_ml_io_int8_to_float32(float scale, uint64_t nb_elements, void *input, void *output) 121*9637de38SSrikanth Yalavarthi { 122*9637de38SSrikanth Yalavarthi int8_t *input_buffer; 123*9637de38SSrikanth Yalavarthi float *output_buffer; 124*9637de38SSrikanth Yalavarthi uint64_t i; 125*9637de38SSrikanth Yalavarthi 126*9637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 127*9637de38SSrikanth Yalavarthi return -EINVAL; 128*9637de38SSrikanth Yalavarthi 129*9637de38SSrikanth Yalavarthi input_buffer = (int8_t *)input; 130*9637de38SSrikanth Yalavarthi output_buffer = (float *)output; 131*9637de38SSrikanth Yalavarthi 132*9637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 133*9637de38SSrikanth Yalavarthi *output_buffer = scale * (float)(*input_buffer); 134*9637de38SSrikanth Yalavarthi 135*9637de38SSrikanth Yalavarthi input_buffer++; 136*9637de38SSrikanth Yalavarthi output_buffer++; 137*9637de38SSrikanth Yalavarthi } 138*9637de38SSrikanth Yalavarthi 139*9637de38SSrikanth Yalavarthi return 0; 140*9637de38SSrikanth Yalavarthi } 141*9637de38SSrikanth Yalavarthi 142*9637de38SSrikanth Yalavarthi __rte_weak int 143*9637de38SSrikanth Yalavarthi rte_ml_io_float32_to_uint8(float scale, uint64_t nb_elements, void *input, void *output) 144*9637de38SSrikanth Yalavarthi { 145*9637de38SSrikanth Yalavarthi float *input_buffer; 146*9637de38SSrikanth Yalavarthi uint8_t *output_buffer; 147*9637de38SSrikanth Yalavarthi int32_t i32; 148*9637de38SSrikanth Yalavarthi uint64_t i; 149*9637de38SSrikanth Yalavarthi 150*9637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 151*9637de38SSrikanth Yalavarthi return -EINVAL; 152*9637de38SSrikanth Yalavarthi 153*9637de38SSrikanth Yalavarthi input_buffer = (float *)input; 154*9637de38SSrikanth Yalavarthi output_buffer = (uint8_t *)output; 155*9637de38SSrikanth Yalavarthi 156*9637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 157*9637de38SSrikanth Yalavarthi i32 = (int32_t)round((*input_buffer) * scale); 158*9637de38SSrikanth Yalavarthi 159*9637de38SSrikanth Yalavarthi if (i32 < 0) 160*9637de38SSrikanth Yalavarthi i32 = 0; 161*9637de38SSrikanth Yalavarthi 162*9637de38SSrikanth Yalavarthi if (i32 > UINT8_MAX) 163*9637de38SSrikanth Yalavarthi i32 = UINT8_MAX; 164*9637de38SSrikanth Yalavarthi 165*9637de38SSrikanth Yalavarthi *output_buffer = (uint8_t)i32; 166*9637de38SSrikanth Yalavarthi 167*9637de38SSrikanth Yalavarthi input_buffer++; 168*9637de38SSrikanth Yalavarthi output_buffer++; 169*9637de38SSrikanth Yalavarthi } 170*9637de38SSrikanth Yalavarthi 171*9637de38SSrikanth Yalavarthi return 0; 172*9637de38SSrikanth Yalavarthi } 173*9637de38SSrikanth Yalavarthi 174*9637de38SSrikanth Yalavarthi __rte_weak int 175*9637de38SSrikanth Yalavarthi rte_ml_io_uint8_to_float32(float scale, uint64_t nb_elements, void *input, void *output) 176*9637de38SSrikanth Yalavarthi { 177*9637de38SSrikanth Yalavarthi uint8_t *input_buffer; 178*9637de38SSrikanth Yalavarthi float *output_buffer; 179*9637de38SSrikanth Yalavarthi uint64_t i; 180*9637de38SSrikanth Yalavarthi 181*9637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 182*9637de38SSrikanth Yalavarthi return -EINVAL; 183*9637de38SSrikanth Yalavarthi 184*9637de38SSrikanth Yalavarthi input_buffer = (uint8_t *)input; 185*9637de38SSrikanth Yalavarthi output_buffer = (float *)output; 186*9637de38SSrikanth Yalavarthi 187*9637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 188*9637de38SSrikanth Yalavarthi *output_buffer = scale * (float)(*input_buffer); 189*9637de38SSrikanth Yalavarthi 190*9637de38SSrikanth Yalavarthi input_buffer++; 191*9637de38SSrikanth Yalavarthi output_buffer++; 192*9637de38SSrikanth Yalavarthi } 193*9637de38SSrikanth Yalavarthi 194*9637de38SSrikanth Yalavarthi return 0; 195*9637de38SSrikanth Yalavarthi } 196*9637de38SSrikanth Yalavarthi 197*9637de38SSrikanth Yalavarthi __rte_weak int 198*9637de38SSrikanth Yalavarthi rte_ml_io_float32_to_int16(float scale, uint64_t nb_elements, void *input, void *output) 199*9637de38SSrikanth Yalavarthi { 200*9637de38SSrikanth Yalavarthi float *input_buffer; 201*9637de38SSrikanth Yalavarthi int16_t *output_buffer; 202*9637de38SSrikanth Yalavarthi int32_t i32; 203*9637de38SSrikanth Yalavarthi uint64_t i; 204*9637de38SSrikanth Yalavarthi 205*9637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 206*9637de38SSrikanth Yalavarthi return -EINVAL; 207*9637de38SSrikanth Yalavarthi 208*9637de38SSrikanth Yalavarthi input_buffer = (float *)input; 209*9637de38SSrikanth Yalavarthi output_buffer = (int16_t *)output; 210*9637de38SSrikanth Yalavarthi 211*9637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 212*9637de38SSrikanth Yalavarthi i32 = (int32_t)round((*input_buffer) * scale); 213*9637de38SSrikanth Yalavarthi 214*9637de38SSrikanth Yalavarthi if (i32 < INT16_MIN) 215*9637de38SSrikanth Yalavarthi i32 = INT16_MIN; 216*9637de38SSrikanth Yalavarthi 217*9637de38SSrikanth Yalavarthi if (i32 > INT16_MAX) 218*9637de38SSrikanth Yalavarthi i32 = INT16_MAX; 219*9637de38SSrikanth Yalavarthi 220*9637de38SSrikanth Yalavarthi *output_buffer = (int16_t)i32; 221*9637de38SSrikanth Yalavarthi 222*9637de38SSrikanth Yalavarthi input_buffer++; 223*9637de38SSrikanth Yalavarthi output_buffer++; 224*9637de38SSrikanth Yalavarthi } 225*9637de38SSrikanth Yalavarthi 226*9637de38SSrikanth Yalavarthi return 0; 227*9637de38SSrikanth Yalavarthi } 228*9637de38SSrikanth Yalavarthi 229*9637de38SSrikanth Yalavarthi __rte_weak int 230*9637de38SSrikanth Yalavarthi rte_ml_io_int16_to_float32(float scale, uint64_t nb_elements, void *input, void *output) 231*9637de38SSrikanth Yalavarthi { 232*9637de38SSrikanth Yalavarthi int16_t *input_buffer; 233*9637de38SSrikanth Yalavarthi float *output_buffer; 234*9637de38SSrikanth Yalavarthi uint64_t i; 235*9637de38SSrikanth Yalavarthi 236*9637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 237*9637de38SSrikanth Yalavarthi return -EINVAL; 238*9637de38SSrikanth Yalavarthi 239*9637de38SSrikanth Yalavarthi input_buffer = (int16_t *)input; 240*9637de38SSrikanth Yalavarthi output_buffer = (float *)output; 241*9637de38SSrikanth Yalavarthi 242*9637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 243*9637de38SSrikanth Yalavarthi *output_buffer = scale * (float)(*input_buffer); 244*9637de38SSrikanth Yalavarthi 245*9637de38SSrikanth Yalavarthi input_buffer++; 246*9637de38SSrikanth Yalavarthi output_buffer++; 247*9637de38SSrikanth Yalavarthi } 248*9637de38SSrikanth Yalavarthi 249*9637de38SSrikanth Yalavarthi return 0; 250*9637de38SSrikanth Yalavarthi } 251*9637de38SSrikanth Yalavarthi 252*9637de38SSrikanth Yalavarthi __rte_weak int 253*9637de38SSrikanth Yalavarthi rte_ml_io_float32_to_uint16(float scale, uint64_t nb_elements, void *input, void *output) 254*9637de38SSrikanth Yalavarthi { 255*9637de38SSrikanth Yalavarthi float *input_buffer; 256*9637de38SSrikanth Yalavarthi uint16_t *output_buffer; 257*9637de38SSrikanth Yalavarthi int32_t i32; 258*9637de38SSrikanth Yalavarthi uint64_t i; 259*9637de38SSrikanth Yalavarthi 260*9637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 261*9637de38SSrikanth Yalavarthi return -EINVAL; 262*9637de38SSrikanth Yalavarthi 263*9637de38SSrikanth Yalavarthi input_buffer = (float *)input; 264*9637de38SSrikanth Yalavarthi output_buffer = (uint16_t *)output; 265*9637de38SSrikanth Yalavarthi 266*9637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 267*9637de38SSrikanth Yalavarthi i32 = (int32_t)round((*input_buffer) * scale); 268*9637de38SSrikanth Yalavarthi 269*9637de38SSrikanth Yalavarthi if (i32 < 0) 270*9637de38SSrikanth Yalavarthi i32 = 0; 271*9637de38SSrikanth Yalavarthi 272*9637de38SSrikanth Yalavarthi if (i32 > UINT16_MAX) 273*9637de38SSrikanth Yalavarthi i32 = UINT16_MAX; 274*9637de38SSrikanth Yalavarthi 275*9637de38SSrikanth Yalavarthi *output_buffer = (uint16_t)i32; 276*9637de38SSrikanth Yalavarthi 277*9637de38SSrikanth Yalavarthi input_buffer++; 278*9637de38SSrikanth Yalavarthi output_buffer++; 279*9637de38SSrikanth Yalavarthi } 280*9637de38SSrikanth Yalavarthi 281*9637de38SSrikanth Yalavarthi return 0; 282*9637de38SSrikanth Yalavarthi } 283*9637de38SSrikanth Yalavarthi 284*9637de38SSrikanth Yalavarthi __rte_weak int 285*9637de38SSrikanth Yalavarthi rte_ml_io_uint16_to_float32(float scale, uint64_t nb_elements, void *input, void *output) 286*9637de38SSrikanth Yalavarthi { 287*9637de38SSrikanth Yalavarthi uint16_t *input_buffer; 288*9637de38SSrikanth Yalavarthi float *output_buffer; 289*9637de38SSrikanth Yalavarthi uint64_t i; 290*9637de38SSrikanth Yalavarthi 291*9637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 292*9637de38SSrikanth Yalavarthi return -EINVAL; 293*9637de38SSrikanth Yalavarthi 294*9637de38SSrikanth Yalavarthi input_buffer = (uint16_t *)input; 295*9637de38SSrikanth Yalavarthi output_buffer = (float *)output; 296*9637de38SSrikanth Yalavarthi 297*9637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 298*9637de38SSrikanth Yalavarthi *output_buffer = scale * (float)(*input_buffer); 299*9637de38SSrikanth Yalavarthi 300*9637de38SSrikanth Yalavarthi input_buffer++; 301*9637de38SSrikanth Yalavarthi output_buffer++; 302*9637de38SSrikanth Yalavarthi } 303*9637de38SSrikanth Yalavarthi 304*9637de38SSrikanth Yalavarthi return 0; 305*9637de38SSrikanth Yalavarthi } 306*9637de38SSrikanth Yalavarthi 307*9637de38SSrikanth Yalavarthi /* Convert a single precision floating point number (float32) into a half precision 308*9637de38SSrikanth Yalavarthi * floating point number (float16) using round to nearest rounding mode. 309*9637de38SSrikanth Yalavarthi */ 310*9637de38SSrikanth Yalavarthi static uint16_t 311*9637de38SSrikanth Yalavarthi __float32_to_float16_scalar_rtn(float x) 312*9637de38SSrikanth Yalavarthi { 313*9637de38SSrikanth Yalavarthi union float32 f32; /* float32 input */ 314*9637de38SSrikanth Yalavarthi uint32_t f32_s; /* float32 sign */ 315*9637de38SSrikanth Yalavarthi uint32_t f32_e; /* float32 exponent */ 316*9637de38SSrikanth Yalavarthi uint32_t f32_m; /* float32 mantissa */ 317*9637de38SSrikanth Yalavarthi uint16_t f16_s; /* float16 sign */ 318*9637de38SSrikanth Yalavarthi uint16_t f16_e; /* float16 exponent */ 319*9637de38SSrikanth Yalavarthi uint16_t f16_m; /* float16 mantissa */ 320*9637de38SSrikanth Yalavarthi uint32_t tbits; /* number of truncated bits */ 321*9637de38SSrikanth Yalavarthi uint32_t tmsb; /* MSB position of truncated bits */ 322*9637de38SSrikanth Yalavarthi uint32_t m_32; /* temporary float32 mantissa */ 323*9637de38SSrikanth Yalavarthi uint16_t m_16; /* temporary float16 mantissa */ 324*9637de38SSrikanth Yalavarthi uint16_t u16; /* float16 output */ 325*9637de38SSrikanth Yalavarthi int be_16; /* float16 biased exponent, signed */ 326*9637de38SSrikanth Yalavarthi 327*9637de38SSrikanth Yalavarthi f32.f = x; 328*9637de38SSrikanth Yalavarthi f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S; 329*9637de38SSrikanth Yalavarthi f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E; 330*9637de38SSrikanth Yalavarthi f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M; 331*9637de38SSrikanth Yalavarthi 332*9637de38SSrikanth Yalavarthi f16_s = f32_s; 333*9637de38SSrikanth Yalavarthi f16_e = 0; 334*9637de38SSrikanth Yalavarthi f16_m = 0; 335*9637de38SSrikanth Yalavarthi 336*9637de38SSrikanth Yalavarthi switch (f32_e) { 337*9637de38SSrikanth Yalavarthi case (0): /* float32: zero or subnormal number */ 338*9637de38SSrikanth Yalavarthi f16_e = 0; 339*9637de38SSrikanth Yalavarthi if (f32_m == 0) /* zero */ 340*9637de38SSrikanth Yalavarthi f16_m = 0; 341*9637de38SSrikanth Yalavarthi else /* subnormal number, convert to zero */ 342*9637de38SSrikanth Yalavarthi f16_m = 0; 343*9637de38SSrikanth Yalavarthi break; 344*9637de38SSrikanth Yalavarthi case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */ 345*9637de38SSrikanth Yalavarthi f16_e = FP16_MASK_E >> FP16_LSB_E; 346*9637de38SSrikanth Yalavarthi if (f32_m == 0) { /* infinity */ 347*9637de38SSrikanth Yalavarthi f16_m = 0; 348*9637de38SSrikanth Yalavarthi } else { /* nan, propagate mantissa and set MSB of mantissa to 1 */ 349*9637de38SSrikanth Yalavarthi f16_m = f32_m >> (FP32_MSB_M - FP16_MSB_M); 350*9637de38SSrikanth Yalavarthi f16_m |= BIT(FP16_MSB_M); 351*9637de38SSrikanth Yalavarthi } 352*9637de38SSrikanth Yalavarthi break; 353*9637de38SSrikanth Yalavarthi default: /* float32: normal number */ 354*9637de38SSrikanth Yalavarthi /* compute biased exponent for float16 */ 355*9637de38SSrikanth Yalavarthi be_16 = (int)f32_e - FP32_BIAS_E + FP16_BIAS_E; 356*9637de38SSrikanth Yalavarthi 357*9637de38SSrikanth Yalavarthi /* overflow, be_16 = [31-INF], set to infinity */ 358*9637de38SSrikanth Yalavarthi if (be_16 >= (int)(FP16_MASK_E >> FP16_LSB_E)) { 359*9637de38SSrikanth Yalavarthi f16_e = FP16_MASK_E >> FP16_LSB_E; 360*9637de38SSrikanth Yalavarthi f16_m = 0; 361*9637de38SSrikanth Yalavarthi } else if ((be_16 >= 1) && (be_16 < (int)(FP16_MASK_E >> FP16_LSB_E))) { 362*9637de38SSrikanth Yalavarthi /* normal float16, be_16 = [1:30]*/ 363*9637de38SSrikanth Yalavarthi f16_e = be_16; 364*9637de38SSrikanth Yalavarthi m_16 = f32_m >> (FP32_LSB_E - FP16_LSB_E); 365*9637de38SSrikanth Yalavarthi tmsb = FP32_MSB_M - FP16_MSB_M - 1; 366*9637de38SSrikanth Yalavarthi if ((f32_m & GENMASK_U32(tmsb, 0)) > BIT(tmsb)) { 367*9637de38SSrikanth Yalavarthi /* round: non-zero truncated bits except MSB */ 368*9637de38SSrikanth Yalavarthi m_16++; 369*9637de38SSrikanth Yalavarthi 370*9637de38SSrikanth Yalavarthi /* overflow into exponent */ 371*9637de38SSrikanth Yalavarthi if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1) 372*9637de38SSrikanth Yalavarthi f16_e++; 373*9637de38SSrikanth Yalavarthi } else if ((f32_m & GENMASK_U32(tmsb, 0)) == BIT(tmsb)) { 374*9637de38SSrikanth Yalavarthi /* round: MSB of truncated bits and LSB of m_16 is set */ 375*9637de38SSrikanth Yalavarthi if ((m_16 & 0x1) == 0x1) { 376*9637de38SSrikanth Yalavarthi m_16++; 377*9637de38SSrikanth Yalavarthi 378*9637de38SSrikanth Yalavarthi /* overflow into exponent */ 379*9637de38SSrikanth Yalavarthi if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1) 380*9637de38SSrikanth Yalavarthi f16_e++; 381*9637de38SSrikanth Yalavarthi } 382*9637de38SSrikanth Yalavarthi } 383*9637de38SSrikanth Yalavarthi f16_m = m_16 & FP16_MASK_M; 384*9637de38SSrikanth Yalavarthi } else if ((be_16 >= -(int)(FP16_MSB_M)) && (be_16 < 1)) { 385*9637de38SSrikanth Yalavarthi /* underflow: zero / subnormal, be_16 = [-9:0] */ 386*9637de38SSrikanth Yalavarthi f16_e = 0; 387*9637de38SSrikanth Yalavarthi 388*9637de38SSrikanth Yalavarthi /* add implicit leading zero */ 389*9637de38SSrikanth Yalavarthi m_32 = f32_m | BIT(FP32_LSB_E); 390*9637de38SSrikanth Yalavarthi tbits = FP32_LSB_E - FP16_LSB_E - be_16 + 1; 391*9637de38SSrikanth Yalavarthi m_16 = m_32 >> tbits; 392*9637de38SSrikanth Yalavarthi 393*9637de38SSrikanth Yalavarthi /* if non-leading truncated bits are set */ 394*9637de38SSrikanth Yalavarthi if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) { 395*9637de38SSrikanth Yalavarthi m_16++; 396*9637de38SSrikanth Yalavarthi 397*9637de38SSrikanth Yalavarthi /* overflow into exponent */ 398*9637de38SSrikanth Yalavarthi if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1) 399*9637de38SSrikanth Yalavarthi f16_e++; 400*9637de38SSrikanth Yalavarthi } else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) { 401*9637de38SSrikanth Yalavarthi /* if leading truncated bit is set */ 402*9637de38SSrikanth Yalavarthi if ((m_16 & 0x1) == 0x1) { 403*9637de38SSrikanth Yalavarthi m_16++; 404*9637de38SSrikanth Yalavarthi 405*9637de38SSrikanth Yalavarthi /* overflow into exponent */ 406*9637de38SSrikanth Yalavarthi if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1) 407*9637de38SSrikanth Yalavarthi f16_e++; 408*9637de38SSrikanth Yalavarthi } 409*9637de38SSrikanth Yalavarthi } 410*9637de38SSrikanth Yalavarthi f16_m = m_16 & FP16_MASK_M; 411*9637de38SSrikanth Yalavarthi } else if (be_16 == -(int)(FP16_MSB_M + 1)) { 412*9637de38SSrikanth Yalavarthi /* underflow: zero, be_16 = [-10] */ 413*9637de38SSrikanth Yalavarthi f16_e = 0; 414*9637de38SSrikanth Yalavarthi if (f32_m != 0) 415*9637de38SSrikanth Yalavarthi f16_m = 1; 416*9637de38SSrikanth Yalavarthi else 417*9637de38SSrikanth Yalavarthi f16_m = 0; 418*9637de38SSrikanth Yalavarthi } else { 419*9637de38SSrikanth Yalavarthi /* underflow: zero, be_16 = [-INF:-11] */ 420*9637de38SSrikanth Yalavarthi f16_e = 0; 421*9637de38SSrikanth Yalavarthi f16_m = 0; 422*9637de38SSrikanth Yalavarthi } 423*9637de38SSrikanth Yalavarthi 424*9637de38SSrikanth Yalavarthi break; 425*9637de38SSrikanth Yalavarthi } 426*9637de38SSrikanth Yalavarthi 427*9637de38SSrikanth Yalavarthi u16 = FP16_PACK(f16_s, f16_e, f16_m); 428*9637de38SSrikanth Yalavarthi 429*9637de38SSrikanth Yalavarthi return u16; 430*9637de38SSrikanth Yalavarthi } 431*9637de38SSrikanth Yalavarthi 432*9637de38SSrikanth Yalavarthi __rte_weak int 433*9637de38SSrikanth Yalavarthi rte_ml_io_float32_to_float16(uint64_t nb_elements, void *input, void *output) 434*9637de38SSrikanth Yalavarthi { 435*9637de38SSrikanth Yalavarthi float *input_buffer; 436*9637de38SSrikanth Yalavarthi uint16_t *output_buffer; 437*9637de38SSrikanth Yalavarthi uint64_t i; 438*9637de38SSrikanth Yalavarthi 439*9637de38SSrikanth Yalavarthi if ((nb_elements == 0) || (input == NULL) || (output == NULL)) 440*9637de38SSrikanth Yalavarthi return -EINVAL; 441*9637de38SSrikanth Yalavarthi 442*9637de38SSrikanth Yalavarthi input_buffer = (float *)input; 443*9637de38SSrikanth Yalavarthi output_buffer = (uint16_t *)output; 444*9637de38SSrikanth Yalavarthi 445*9637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 446*9637de38SSrikanth Yalavarthi *output_buffer = __float32_to_float16_scalar_rtn(*input_buffer); 447*9637de38SSrikanth Yalavarthi 448*9637de38SSrikanth Yalavarthi input_buffer = input_buffer + 1; 449*9637de38SSrikanth Yalavarthi output_buffer = output_buffer + 1; 450*9637de38SSrikanth Yalavarthi } 451*9637de38SSrikanth Yalavarthi 452*9637de38SSrikanth Yalavarthi return 0; 453*9637de38SSrikanth Yalavarthi } 454*9637de38SSrikanth Yalavarthi 455*9637de38SSrikanth Yalavarthi /* Convert a half precision floating point number (float16) into a single precision 456*9637de38SSrikanth Yalavarthi * floating point number (float32). 457*9637de38SSrikanth Yalavarthi */ 458*9637de38SSrikanth Yalavarthi static float 459*9637de38SSrikanth Yalavarthi __float16_to_float32_scalar_rtx(uint16_t f16) 460*9637de38SSrikanth Yalavarthi { 461*9637de38SSrikanth Yalavarthi union float32 f32; /* float32 output */ 462*9637de38SSrikanth Yalavarthi uint16_t f16_s; /* float16 sign */ 463*9637de38SSrikanth Yalavarthi uint16_t f16_e; /* float16 exponent */ 464*9637de38SSrikanth Yalavarthi uint16_t f16_m; /* float16 mantissa */ 465*9637de38SSrikanth Yalavarthi uint32_t f32_s; /* float32 sign */ 466*9637de38SSrikanth Yalavarthi uint32_t f32_e; /* float32 exponent */ 467*9637de38SSrikanth Yalavarthi uint32_t f32_m; /* float32 mantissa*/ 468*9637de38SSrikanth Yalavarthi uint8_t shift; /* number of bits to be shifted */ 469*9637de38SSrikanth Yalavarthi uint32_t clz; /* count of leading zeroes */ 470*9637de38SSrikanth Yalavarthi int e_16; /* float16 exponent unbiased */ 471*9637de38SSrikanth Yalavarthi 472*9637de38SSrikanth Yalavarthi f16_s = (f16 & FP16_MASK_S) >> FP16_LSB_S; 473*9637de38SSrikanth Yalavarthi f16_e = (f16 & FP16_MASK_E) >> FP16_LSB_E; 474*9637de38SSrikanth Yalavarthi f16_m = (f16 & FP16_MASK_M) >> FP16_LSB_M; 475*9637de38SSrikanth Yalavarthi 476*9637de38SSrikanth Yalavarthi f32_s = f16_s; 477*9637de38SSrikanth Yalavarthi switch (f16_e) { 478*9637de38SSrikanth Yalavarthi case (FP16_MASK_E >> FP16_LSB_E): /* float16: infinity or nan */ 479*9637de38SSrikanth Yalavarthi f32_e = FP32_MASK_E >> FP32_LSB_E; 480*9637de38SSrikanth Yalavarthi if (f16_m == 0x0) { /* infinity */ 481*9637de38SSrikanth Yalavarthi f32_m = f16_m; 482*9637de38SSrikanth Yalavarthi } else { /* nan, propagate mantissa, set MSB of mantissa to 1 */ 483*9637de38SSrikanth Yalavarthi f32_m = f16_m; 484*9637de38SSrikanth Yalavarthi shift = FP32_MSB_M - FP16_MSB_M; 485*9637de38SSrikanth Yalavarthi f32_m = (f32_m << shift) & FP32_MASK_M; 486*9637de38SSrikanth Yalavarthi f32_m |= BIT(FP32_MSB_M); 487*9637de38SSrikanth Yalavarthi } 488*9637de38SSrikanth Yalavarthi break; 489*9637de38SSrikanth Yalavarthi case 0: /* float16: zero or sub-normal */ 490*9637de38SSrikanth Yalavarthi f32_m = f16_m; 491*9637de38SSrikanth Yalavarthi if (f16_m == 0) { /* zero signed */ 492*9637de38SSrikanth Yalavarthi f32_e = 0; 493*9637de38SSrikanth Yalavarthi } else { /* subnormal numbers */ 494*9637de38SSrikanth Yalavarthi clz = __builtin_clz((uint32_t)f16_m) - sizeof(uint32_t) * 8 + FP16_LSB_E; 495*9637de38SSrikanth Yalavarthi e_16 = (int)f16_e - clz; 496*9637de38SSrikanth Yalavarthi f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E; 497*9637de38SSrikanth Yalavarthi 498*9637de38SSrikanth Yalavarthi shift = clz + (FP32_MSB_M - FP16_MSB_M) + 1; 499*9637de38SSrikanth Yalavarthi f32_m = (f32_m << shift) & FP32_MASK_M; 500*9637de38SSrikanth Yalavarthi } 501*9637de38SSrikanth Yalavarthi break; 502*9637de38SSrikanth Yalavarthi default: /* normal numbers */ 503*9637de38SSrikanth Yalavarthi f32_m = f16_m; 504*9637de38SSrikanth Yalavarthi e_16 = (int)f16_e; 505*9637de38SSrikanth Yalavarthi f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E; 506*9637de38SSrikanth Yalavarthi 507*9637de38SSrikanth Yalavarthi shift = (FP32_MSB_M - FP16_MSB_M); 508*9637de38SSrikanth Yalavarthi f32_m = (f32_m << shift) & FP32_MASK_M; 509*9637de38SSrikanth Yalavarthi } 510*9637de38SSrikanth Yalavarthi 511*9637de38SSrikanth Yalavarthi f32.u = FP32_PACK(f32_s, f32_e, f32_m); 512*9637de38SSrikanth Yalavarthi 513*9637de38SSrikanth Yalavarthi return f32.f; 514*9637de38SSrikanth Yalavarthi } 515*9637de38SSrikanth Yalavarthi 516*9637de38SSrikanth Yalavarthi __rte_weak int 517*9637de38SSrikanth Yalavarthi rte_ml_io_float16_to_float32(uint64_t nb_elements, void *input, void *output) 518*9637de38SSrikanth Yalavarthi { 519*9637de38SSrikanth Yalavarthi uint16_t *input_buffer; 520*9637de38SSrikanth Yalavarthi float *output_buffer; 521*9637de38SSrikanth Yalavarthi uint64_t i; 522*9637de38SSrikanth Yalavarthi 523*9637de38SSrikanth Yalavarthi if ((nb_elements == 0) || (input == NULL) || (output == NULL)) 524*9637de38SSrikanth Yalavarthi return -EINVAL; 525*9637de38SSrikanth Yalavarthi 526*9637de38SSrikanth Yalavarthi input_buffer = (uint16_t *)input; 527*9637de38SSrikanth Yalavarthi output_buffer = (float *)output; 528*9637de38SSrikanth Yalavarthi 529*9637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 530*9637de38SSrikanth Yalavarthi *output_buffer = __float16_to_float32_scalar_rtx(*input_buffer); 531*9637de38SSrikanth Yalavarthi 532*9637de38SSrikanth Yalavarthi input_buffer = input_buffer + 1; 533*9637de38SSrikanth Yalavarthi output_buffer = output_buffer + 1; 534*9637de38SSrikanth Yalavarthi } 535*9637de38SSrikanth Yalavarthi 536*9637de38SSrikanth Yalavarthi return 0; 537*9637de38SSrikanth Yalavarthi } 538*9637de38SSrikanth Yalavarthi 539*9637de38SSrikanth Yalavarthi /* Convert a single precision floating point number (float32) into a 540*9637de38SSrikanth Yalavarthi * brain float number (bfloat16) using round to nearest rounding mode. 541*9637de38SSrikanth Yalavarthi */ 542*9637de38SSrikanth Yalavarthi static uint16_t 543*9637de38SSrikanth Yalavarthi __float32_to_bfloat16_scalar_rtn(float x) 544*9637de38SSrikanth Yalavarthi { 545*9637de38SSrikanth Yalavarthi union float32 f32; /* float32 input */ 546*9637de38SSrikanth Yalavarthi uint32_t f32_s; /* float32 sign */ 547*9637de38SSrikanth Yalavarthi uint32_t f32_e; /* float32 exponent */ 548*9637de38SSrikanth Yalavarthi uint32_t f32_m; /* float32 mantissa */ 549*9637de38SSrikanth Yalavarthi uint16_t b16_s; /* float16 sign */ 550*9637de38SSrikanth Yalavarthi uint16_t b16_e; /* float16 exponent */ 551*9637de38SSrikanth Yalavarthi uint16_t b16_m; /* float16 mantissa */ 552*9637de38SSrikanth Yalavarthi uint32_t tbits; /* number of truncated bits */ 553*9637de38SSrikanth Yalavarthi uint16_t u16; /* float16 output */ 554*9637de38SSrikanth Yalavarthi 555*9637de38SSrikanth Yalavarthi f32.f = x; 556*9637de38SSrikanth Yalavarthi f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S; 557*9637de38SSrikanth Yalavarthi f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E; 558*9637de38SSrikanth Yalavarthi f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M; 559*9637de38SSrikanth Yalavarthi 560*9637de38SSrikanth Yalavarthi b16_s = f32_s; 561*9637de38SSrikanth Yalavarthi b16_e = 0; 562*9637de38SSrikanth Yalavarthi b16_m = 0; 563*9637de38SSrikanth Yalavarthi 564*9637de38SSrikanth Yalavarthi switch (f32_e) { 565*9637de38SSrikanth Yalavarthi case (0): /* float32: zero or subnormal number */ 566*9637de38SSrikanth Yalavarthi b16_e = 0; 567*9637de38SSrikanth Yalavarthi if (f32_m == 0) /* zero */ 568*9637de38SSrikanth Yalavarthi b16_m = 0; 569*9637de38SSrikanth Yalavarthi else /* subnormal float32 number, normal bfloat16 */ 570*9637de38SSrikanth Yalavarthi goto bf16_normal; 571*9637de38SSrikanth Yalavarthi break; 572*9637de38SSrikanth Yalavarthi case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */ 573*9637de38SSrikanth Yalavarthi b16_e = BF16_MASK_E >> BF16_LSB_E; 574*9637de38SSrikanth Yalavarthi if (f32_m == 0) { /* infinity */ 575*9637de38SSrikanth Yalavarthi b16_m = 0; 576*9637de38SSrikanth Yalavarthi } else { /* nan, propagate mantissa and set MSB of mantissa to 1 */ 577*9637de38SSrikanth Yalavarthi b16_m = f32_m >> (FP32_MSB_M - BF16_MSB_M); 578*9637de38SSrikanth Yalavarthi b16_m |= BIT(BF16_MSB_M); 579*9637de38SSrikanth Yalavarthi } 580*9637de38SSrikanth Yalavarthi break; 581*9637de38SSrikanth Yalavarthi default: /* float32: normal number, normal bfloat16 */ 582*9637de38SSrikanth Yalavarthi goto bf16_normal; 583*9637de38SSrikanth Yalavarthi } 584*9637de38SSrikanth Yalavarthi 585*9637de38SSrikanth Yalavarthi goto bf16_pack; 586*9637de38SSrikanth Yalavarthi 587*9637de38SSrikanth Yalavarthi bf16_normal: 588*9637de38SSrikanth Yalavarthi b16_e = f32_e; 589*9637de38SSrikanth Yalavarthi tbits = FP32_MSB_M - BF16_MSB_M; 590*9637de38SSrikanth Yalavarthi b16_m = f32_m >> tbits; 591*9637de38SSrikanth Yalavarthi 592*9637de38SSrikanth Yalavarthi /* if non-leading truncated bits are set */ 593*9637de38SSrikanth Yalavarthi if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) { 594*9637de38SSrikanth Yalavarthi b16_m++; 595*9637de38SSrikanth Yalavarthi 596*9637de38SSrikanth Yalavarthi /* if overflow into exponent */ 597*9637de38SSrikanth Yalavarthi if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1) 598*9637de38SSrikanth Yalavarthi b16_e++; 599*9637de38SSrikanth Yalavarthi } else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) { 600*9637de38SSrikanth Yalavarthi /* if only leading truncated bit is set */ 601*9637de38SSrikanth Yalavarthi if ((b16_m & 0x1) == 0x1) { 602*9637de38SSrikanth Yalavarthi b16_m++; 603*9637de38SSrikanth Yalavarthi 604*9637de38SSrikanth Yalavarthi /* if overflow into exponent */ 605*9637de38SSrikanth Yalavarthi if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1) 606*9637de38SSrikanth Yalavarthi b16_e++; 607*9637de38SSrikanth Yalavarthi } 608*9637de38SSrikanth Yalavarthi } 609*9637de38SSrikanth Yalavarthi b16_m = b16_m & BF16_MASK_M; 610*9637de38SSrikanth Yalavarthi 611*9637de38SSrikanth Yalavarthi bf16_pack: 612*9637de38SSrikanth Yalavarthi u16 = BF16_PACK(b16_s, b16_e, b16_m); 613*9637de38SSrikanth Yalavarthi 614*9637de38SSrikanth Yalavarthi return u16; 615*9637de38SSrikanth Yalavarthi } 616*9637de38SSrikanth Yalavarthi 617*9637de38SSrikanth Yalavarthi __rte_weak int 618*9637de38SSrikanth Yalavarthi rte_ml_io_float32_to_bfloat16(uint64_t nb_elements, void *input, void *output) 619*9637de38SSrikanth Yalavarthi { 620*9637de38SSrikanth Yalavarthi float *input_buffer; 621*9637de38SSrikanth Yalavarthi uint16_t *output_buffer; 622*9637de38SSrikanth Yalavarthi uint64_t i; 623*9637de38SSrikanth Yalavarthi 624*9637de38SSrikanth Yalavarthi if ((nb_elements == 0) || (input == NULL) || (output == NULL)) 625*9637de38SSrikanth Yalavarthi return -EINVAL; 626*9637de38SSrikanth Yalavarthi 627*9637de38SSrikanth Yalavarthi input_buffer = (float *)input; 628*9637de38SSrikanth Yalavarthi output_buffer = (uint16_t *)output; 629*9637de38SSrikanth Yalavarthi 630*9637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 631*9637de38SSrikanth Yalavarthi *output_buffer = __float32_to_bfloat16_scalar_rtn(*input_buffer); 632*9637de38SSrikanth Yalavarthi 633*9637de38SSrikanth Yalavarthi input_buffer = input_buffer + 1; 634*9637de38SSrikanth Yalavarthi output_buffer = output_buffer + 1; 635*9637de38SSrikanth Yalavarthi } 636*9637de38SSrikanth Yalavarthi 637*9637de38SSrikanth Yalavarthi return 0; 638*9637de38SSrikanth Yalavarthi } 639*9637de38SSrikanth Yalavarthi 640*9637de38SSrikanth Yalavarthi /* Convert a brain float number (bfloat16) into a 641*9637de38SSrikanth Yalavarthi * single precision floating point number (float32). 642*9637de38SSrikanth Yalavarthi */ 643*9637de38SSrikanth Yalavarthi static float 644*9637de38SSrikanth Yalavarthi __bfloat16_to_float32_scalar_rtx(uint16_t f16) 645*9637de38SSrikanth Yalavarthi { 646*9637de38SSrikanth Yalavarthi union float32 f32; /* float32 output */ 647*9637de38SSrikanth Yalavarthi uint16_t b16_s; /* float16 sign */ 648*9637de38SSrikanth Yalavarthi uint16_t b16_e; /* float16 exponent */ 649*9637de38SSrikanth Yalavarthi uint16_t b16_m; /* float16 mantissa */ 650*9637de38SSrikanth Yalavarthi uint32_t f32_s; /* float32 sign */ 651*9637de38SSrikanth Yalavarthi uint32_t f32_e; /* float32 exponent */ 652*9637de38SSrikanth Yalavarthi uint32_t f32_m; /* float32 mantissa*/ 653*9637de38SSrikanth Yalavarthi uint8_t shift; /* number of bits to be shifted */ 654*9637de38SSrikanth Yalavarthi 655*9637de38SSrikanth Yalavarthi b16_s = (f16 & BF16_MASK_S) >> BF16_LSB_S; 656*9637de38SSrikanth Yalavarthi b16_e = (f16 & BF16_MASK_E) >> BF16_LSB_E; 657*9637de38SSrikanth Yalavarthi b16_m = (f16 & BF16_MASK_M) >> BF16_LSB_M; 658*9637de38SSrikanth Yalavarthi 659*9637de38SSrikanth Yalavarthi f32_s = b16_s; 660*9637de38SSrikanth Yalavarthi switch (b16_e) { 661*9637de38SSrikanth Yalavarthi case (BF16_MASK_E >> BF16_LSB_E): /* bfloat16: infinity or nan */ 662*9637de38SSrikanth Yalavarthi f32_e = FP32_MASK_E >> FP32_LSB_E; 663*9637de38SSrikanth Yalavarthi if (b16_m == 0x0) { /* infinity */ 664*9637de38SSrikanth Yalavarthi f32_m = 0; 665*9637de38SSrikanth Yalavarthi } else { /* nan, propagate mantissa, set MSB of mantissa to 1 */ 666*9637de38SSrikanth Yalavarthi f32_m = b16_m; 667*9637de38SSrikanth Yalavarthi shift = FP32_MSB_M - BF16_MSB_M; 668*9637de38SSrikanth Yalavarthi f32_m = (f32_m << shift) & FP32_MASK_M; 669*9637de38SSrikanth Yalavarthi f32_m |= BIT(FP32_MSB_M); 670*9637de38SSrikanth Yalavarthi } 671*9637de38SSrikanth Yalavarthi break; 672*9637de38SSrikanth Yalavarthi case 0: /* bfloat16: zero or subnormal */ 673*9637de38SSrikanth Yalavarthi f32_m = b16_m; 674*9637de38SSrikanth Yalavarthi if (b16_m == 0) { /* zero signed */ 675*9637de38SSrikanth Yalavarthi f32_e = 0; 676*9637de38SSrikanth Yalavarthi } else { /* subnormal numbers */ 677*9637de38SSrikanth Yalavarthi goto fp32_normal; 678*9637de38SSrikanth Yalavarthi } 679*9637de38SSrikanth Yalavarthi break; 680*9637de38SSrikanth Yalavarthi default: /* bfloat16: normal number */ 681*9637de38SSrikanth Yalavarthi goto fp32_normal; 682*9637de38SSrikanth Yalavarthi } 683*9637de38SSrikanth Yalavarthi 684*9637de38SSrikanth Yalavarthi goto fp32_pack; 685*9637de38SSrikanth Yalavarthi 686*9637de38SSrikanth Yalavarthi fp32_normal: 687*9637de38SSrikanth Yalavarthi f32_m = b16_m; 688*9637de38SSrikanth Yalavarthi f32_e = FP32_BIAS_E + b16_e - BF16_BIAS_E; 689*9637de38SSrikanth Yalavarthi 690*9637de38SSrikanth Yalavarthi shift = (FP32_MSB_M - BF16_MSB_M); 691*9637de38SSrikanth Yalavarthi f32_m = (f32_m << shift) & FP32_MASK_M; 692*9637de38SSrikanth Yalavarthi 693*9637de38SSrikanth Yalavarthi fp32_pack: 694*9637de38SSrikanth Yalavarthi f32.u = FP32_PACK(f32_s, f32_e, f32_m); 695*9637de38SSrikanth Yalavarthi 696*9637de38SSrikanth Yalavarthi return f32.f; 697*9637de38SSrikanth Yalavarthi } 698*9637de38SSrikanth Yalavarthi 699*9637de38SSrikanth Yalavarthi __rte_weak int 700*9637de38SSrikanth Yalavarthi rte_ml_io_bfloat16_to_float32(uint64_t nb_elements, void *input, void *output) 701*9637de38SSrikanth Yalavarthi { 702*9637de38SSrikanth Yalavarthi uint16_t *input_buffer; 703*9637de38SSrikanth Yalavarthi float *output_buffer; 704*9637de38SSrikanth Yalavarthi uint64_t i; 705*9637de38SSrikanth Yalavarthi 706*9637de38SSrikanth Yalavarthi if ((nb_elements == 0) || (input == NULL) || (output == NULL)) 707*9637de38SSrikanth Yalavarthi return -EINVAL; 708*9637de38SSrikanth Yalavarthi 709*9637de38SSrikanth Yalavarthi input_buffer = (uint16_t *)input; 710*9637de38SSrikanth Yalavarthi output_buffer = (float *)output; 711*9637de38SSrikanth Yalavarthi 712*9637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 713*9637de38SSrikanth Yalavarthi *output_buffer = __bfloat16_to_float32_scalar_rtx(*input_buffer); 714*9637de38SSrikanth Yalavarthi 715*9637de38SSrikanth Yalavarthi input_buffer = input_buffer + 1; 716*9637de38SSrikanth Yalavarthi output_buffer = output_buffer + 1; 717*9637de38SSrikanth Yalavarthi } 718*9637de38SSrikanth Yalavarthi 719*9637de38SSrikanth Yalavarthi return 0; 720*9637de38SSrikanth Yalavarthi } 721