19637de38SSrikanth Yalavarthi /* SPDX-License-Identifier: BSD-3-Clause 29637de38SSrikanth Yalavarthi * Copyright (c) 2022 Marvell. 39637de38SSrikanth Yalavarthi */ 49637de38SSrikanth Yalavarthi 59637de38SSrikanth Yalavarthi #include <errno.h> 69637de38SSrikanth Yalavarthi #include <math.h> 79637de38SSrikanth Yalavarthi #include <stdint.h> 89637de38SSrikanth Yalavarthi 99637de38SSrikanth Yalavarthi #include "mldev_utils.h" 109637de38SSrikanth Yalavarthi 119637de38SSrikanth Yalavarthi /* Description: 129637de38SSrikanth Yalavarthi * This file implements scalar versions of Machine Learning utility functions used to convert data 139637de38SSrikanth Yalavarthi * types from higher precision to lower precision and vice-versa. 149637de38SSrikanth Yalavarthi */ 159637de38SSrikanth Yalavarthi 169637de38SSrikanth Yalavarthi #ifndef BIT 179637de38SSrikanth Yalavarthi #define BIT(nr) (1UL << (nr)) 189637de38SSrikanth Yalavarthi #endif 199637de38SSrikanth Yalavarthi 209637de38SSrikanth Yalavarthi #ifndef BITS_PER_LONG 219637de38SSrikanth Yalavarthi #define BITS_PER_LONG (__SIZEOF_LONG__ * 8) 229637de38SSrikanth Yalavarthi #endif 239637de38SSrikanth Yalavarthi 249637de38SSrikanth Yalavarthi #ifndef GENMASK_U32 259637de38SSrikanth Yalavarthi #define GENMASK_U32(h, l) (((~0UL) << (l)) & (~0UL >> (BITS_PER_LONG - 1 - (h)))) 269637de38SSrikanth Yalavarthi #endif 279637de38SSrikanth Yalavarthi 289637de38SSrikanth Yalavarthi /* float32: bit index of MSB & LSB of sign, exponent and mantissa */ 299637de38SSrikanth Yalavarthi #define FP32_LSB_M 0 309637de38SSrikanth Yalavarthi #define FP32_MSB_M 22 319637de38SSrikanth Yalavarthi #define FP32_LSB_E 23 329637de38SSrikanth Yalavarthi #define FP32_MSB_E 30 339637de38SSrikanth Yalavarthi #define FP32_LSB_S 31 349637de38SSrikanth Yalavarthi #define FP32_MSB_S 31 359637de38SSrikanth Yalavarthi 369637de38SSrikanth Yalavarthi /* float32: bitmask for sign, exponent and mantissa */ 379637de38SSrikanth Yalavarthi #define FP32_MASK_S GENMASK_U32(FP32_MSB_S, FP32_LSB_S) 389637de38SSrikanth Yalavarthi #define FP32_MASK_E GENMASK_U32(FP32_MSB_E, FP32_LSB_E) 399637de38SSrikanth Yalavarthi #define FP32_MASK_M GENMASK_U32(FP32_MSB_M, FP32_LSB_M) 409637de38SSrikanth Yalavarthi 419637de38SSrikanth Yalavarthi /* float16: bit index of MSB & LSB of sign, exponent and mantissa */ 429637de38SSrikanth Yalavarthi #define FP16_LSB_M 0 439637de38SSrikanth Yalavarthi #define FP16_MSB_M 9 449637de38SSrikanth Yalavarthi #define FP16_LSB_E 10 459637de38SSrikanth Yalavarthi #define FP16_MSB_E 14 469637de38SSrikanth Yalavarthi #define FP16_LSB_S 15 479637de38SSrikanth Yalavarthi #define FP16_MSB_S 15 489637de38SSrikanth Yalavarthi 499637de38SSrikanth Yalavarthi /* float16: bitmask for sign, exponent and mantissa */ 509637de38SSrikanth Yalavarthi #define FP16_MASK_S GENMASK_U32(FP16_MSB_S, FP16_LSB_S) 519637de38SSrikanth Yalavarthi #define FP16_MASK_E GENMASK_U32(FP16_MSB_E, FP16_LSB_E) 529637de38SSrikanth Yalavarthi #define FP16_MASK_M GENMASK_U32(FP16_MSB_M, FP16_LSB_M) 539637de38SSrikanth Yalavarthi 549637de38SSrikanth Yalavarthi /* bfloat16: bit index of MSB & LSB of sign, exponent and mantissa */ 559637de38SSrikanth Yalavarthi #define BF16_LSB_M 0 569637de38SSrikanth Yalavarthi #define BF16_MSB_M 6 579637de38SSrikanth Yalavarthi #define BF16_LSB_E 7 589637de38SSrikanth Yalavarthi #define BF16_MSB_E 14 599637de38SSrikanth Yalavarthi #define BF16_LSB_S 15 609637de38SSrikanth Yalavarthi #define BF16_MSB_S 15 619637de38SSrikanth Yalavarthi 629637de38SSrikanth Yalavarthi /* bfloat16: bitmask for sign, exponent and mantissa */ 639637de38SSrikanth Yalavarthi #define BF16_MASK_S GENMASK_U32(BF16_MSB_S, BF16_LSB_S) 649637de38SSrikanth Yalavarthi #define BF16_MASK_E GENMASK_U32(BF16_MSB_E, BF16_LSB_E) 659637de38SSrikanth Yalavarthi #define BF16_MASK_M GENMASK_U32(BF16_MSB_M, BF16_LSB_M) 669637de38SSrikanth Yalavarthi 679637de38SSrikanth Yalavarthi /* Exponent bias */ 689637de38SSrikanth Yalavarthi #define FP32_BIAS_E 127 699637de38SSrikanth Yalavarthi #define FP16_BIAS_E 15 709637de38SSrikanth Yalavarthi #define BF16_BIAS_E 127 719637de38SSrikanth Yalavarthi 729637de38SSrikanth Yalavarthi #define FP32_PACK(sign, exponent, mantissa) \ 739637de38SSrikanth Yalavarthi (((sign) << FP32_LSB_S) | ((exponent) << FP32_LSB_E) | (mantissa)) 749637de38SSrikanth Yalavarthi 759637de38SSrikanth Yalavarthi #define FP16_PACK(sign, exponent, mantissa) \ 769637de38SSrikanth Yalavarthi (((sign) << FP16_LSB_S) | ((exponent) << FP16_LSB_E) | (mantissa)) 779637de38SSrikanth Yalavarthi 789637de38SSrikanth Yalavarthi #define BF16_PACK(sign, exponent, mantissa) \ 799637de38SSrikanth Yalavarthi (((sign) << BF16_LSB_S) | ((exponent) << BF16_LSB_E) | (mantissa)) 809637de38SSrikanth Yalavarthi 819637de38SSrikanth Yalavarthi /* Represent float32 as float and uint32_t */ 829637de38SSrikanth Yalavarthi union float32 { 839637de38SSrikanth Yalavarthi float f; 849637de38SSrikanth Yalavarthi uint32_t u; 859637de38SSrikanth Yalavarthi }; 869637de38SSrikanth Yalavarthi 87*8c9bfcb1SSrikanth Yalavarthi int 889637de38SSrikanth Yalavarthi rte_ml_io_float32_to_int8(float scale, uint64_t nb_elements, void *input, void *output) 899637de38SSrikanth Yalavarthi { 909637de38SSrikanth Yalavarthi float *input_buffer; 919637de38SSrikanth Yalavarthi int8_t *output_buffer; 929637de38SSrikanth Yalavarthi uint64_t i; 939637de38SSrikanth Yalavarthi int i32; 949637de38SSrikanth Yalavarthi 959637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 969637de38SSrikanth Yalavarthi return -EINVAL; 979637de38SSrikanth Yalavarthi 989637de38SSrikanth Yalavarthi input_buffer = (float *)input; 999637de38SSrikanth Yalavarthi output_buffer = (int8_t *)output; 1009637de38SSrikanth Yalavarthi 1019637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 1029637de38SSrikanth Yalavarthi i32 = (int32_t)round((*input_buffer) * scale); 1039637de38SSrikanth Yalavarthi 1049637de38SSrikanth Yalavarthi if (i32 < INT8_MIN) 1059637de38SSrikanth Yalavarthi i32 = INT8_MIN; 1069637de38SSrikanth Yalavarthi 1079637de38SSrikanth Yalavarthi if (i32 > INT8_MAX) 1089637de38SSrikanth Yalavarthi i32 = INT8_MAX; 1099637de38SSrikanth Yalavarthi 1109637de38SSrikanth Yalavarthi *output_buffer = (int8_t)i32; 1119637de38SSrikanth Yalavarthi 1129637de38SSrikanth Yalavarthi input_buffer++; 1139637de38SSrikanth Yalavarthi output_buffer++; 1149637de38SSrikanth Yalavarthi } 1159637de38SSrikanth Yalavarthi 1169637de38SSrikanth Yalavarthi return 0; 1179637de38SSrikanth Yalavarthi } 1189637de38SSrikanth Yalavarthi 119*8c9bfcb1SSrikanth Yalavarthi int 1209637de38SSrikanth Yalavarthi rte_ml_io_int8_to_float32(float scale, uint64_t nb_elements, void *input, void *output) 1219637de38SSrikanth Yalavarthi { 1229637de38SSrikanth Yalavarthi int8_t *input_buffer; 1239637de38SSrikanth Yalavarthi float *output_buffer; 1249637de38SSrikanth Yalavarthi uint64_t i; 1259637de38SSrikanth Yalavarthi 1269637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 1279637de38SSrikanth Yalavarthi return -EINVAL; 1289637de38SSrikanth Yalavarthi 1299637de38SSrikanth Yalavarthi input_buffer = (int8_t *)input; 1309637de38SSrikanth Yalavarthi output_buffer = (float *)output; 1319637de38SSrikanth Yalavarthi 1329637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 1339637de38SSrikanth Yalavarthi *output_buffer = scale * (float)(*input_buffer); 1349637de38SSrikanth Yalavarthi 1359637de38SSrikanth Yalavarthi input_buffer++; 1369637de38SSrikanth Yalavarthi output_buffer++; 1379637de38SSrikanth Yalavarthi } 1389637de38SSrikanth Yalavarthi 1399637de38SSrikanth Yalavarthi return 0; 1409637de38SSrikanth Yalavarthi } 1419637de38SSrikanth Yalavarthi 142*8c9bfcb1SSrikanth Yalavarthi int 1439637de38SSrikanth Yalavarthi rte_ml_io_float32_to_uint8(float scale, uint64_t nb_elements, void *input, void *output) 1449637de38SSrikanth Yalavarthi { 1459637de38SSrikanth Yalavarthi float *input_buffer; 1469637de38SSrikanth Yalavarthi uint8_t *output_buffer; 1479637de38SSrikanth Yalavarthi int32_t i32; 1489637de38SSrikanth Yalavarthi uint64_t i; 1499637de38SSrikanth Yalavarthi 1509637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 1519637de38SSrikanth Yalavarthi return -EINVAL; 1529637de38SSrikanth Yalavarthi 1539637de38SSrikanth Yalavarthi input_buffer = (float *)input; 1549637de38SSrikanth Yalavarthi output_buffer = (uint8_t *)output; 1559637de38SSrikanth Yalavarthi 1569637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 1579637de38SSrikanth Yalavarthi i32 = (int32_t)round((*input_buffer) * scale); 1589637de38SSrikanth Yalavarthi 1599637de38SSrikanth Yalavarthi if (i32 < 0) 1609637de38SSrikanth Yalavarthi i32 = 0; 1619637de38SSrikanth Yalavarthi 1629637de38SSrikanth Yalavarthi if (i32 > UINT8_MAX) 1639637de38SSrikanth Yalavarthi i32 = UINT8_MAX; 1649637de38SSrikanth Yalavarthi 1659637de38SSrikanth Yalavarthi *output_buffer = (uint8_t)i32; 1669637de38SSrikanth Yalavarthi 1679637de38SSrikanth Yalavarthi input_buffer++; 1689637de38SSrikanth Yalavarthi output_buffer++; 1699637de38SSrikanth Yalavarthi } 1709637de38SSrikanth Yalavarthi 1719637de38SSrikanth Yalavarthi return 0; 1729637de38SSrikanth Yalavarthi } 1739637de38SSrikanth Yalavarthi 174*8c9bfcb1SSrikanth Yalavarthi int 1759637de38SSrikanth Yalavarthi rte_ml_io_uint8_to_float32(float scale, uint64_t nb_elements, void *input, void *output) 1769637de38SSrikanth Yalavarthi { 1779637de38SSrikanth Yalavarthi uint8_t *input_buffer; 1789637de38SSrikanth Yalavarthi float *output_buffer; 1799637de38SSrikanth Yalavarthi uint64_t i; 1809637de38SSrikanth Yalavarthi 1819637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 1829637de38SSrikanth Yalavarthi return -EINVAL; 1839637de38SSrikanth Yalavarthi 1849637de38SSrikanth Yalavarthi input_buffer = (uint8_t *)input; 1859637de38SSrikanth Yalavarthi output_buffer = (float *)output; 1869637de38SSrikanth Yalavarthi 1879637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 1889637de38SSrikanth Yalavarthi *output_buffer = scale * (float)(*input_buffer); 1899637de38SSrikanth Yalavarthi 1909637de38SSrikanth Yalavarthi input_buffer++; 1919637de38SSrikanth Yalavarthi output_buffer++; 1929637de38SSrikanth Yalavarthi } 1939637de38SSrikanth Yalavarthi 1949637de38SSrikanth Yalavarthi return 0; 1959637de38SSrikanth Yalavarthi } 1969637de38SSrikanth Yalavarthi 197*8c9bfcb1SSrikanth Yalavarthi int 1989637de38SSrikanth Yalavarthi rte_ml_io_float32_to_int16(float scale, uint64_t nb_elements, void *input, void *output) 1999637de38SSrikanth Yalavarthi { 2009637de38SSrikanth Yalavarthi float *input_buffer; 2019637de38SSrikanth Yalavarthi int16_t *output_buffer; 2029637de38SSrikanth Yalavarthi int32_t i32; 2039637de38SSrikanth Yalavarthi uint64_t i; 2049637de38SSrikanth Yalavarthi 2059637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 2069637de38SSrikanth Yalavarthi return -EINVAL; 2079637de38SSrikanth Yalavarthi 2089637de38SSrikanth Yalavarthi input_buffer = (float *)input; 2099637de38SSrikanth Yalavarthi output_buffer = (int16_t *)output; 2109637de38SSrikanth Yalavarthi 2119637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 2129637de38SSrikanth Yalavarthi i32 = (int32_t)round((*input_buffer) * scale); 2139637de38SSrikanth Yalavarthi 2149637de38SSrikanth Yalavarthi if (i32 < INT16_MIN) 2159637de38SSrikanth Yalavarthi i32 = INT16_MIN; 2169637de38SSrikanth Yalavarthi 2179637de38SSrikanth Yalavarthi if (i32 > INT16_MAX) 2189637de38SSrikanth Yalavarthi i32 = INT16_MAX; 2199637de38SSrikanth Yalavarthi 2209637de38SSrikanth Yalavarthi *output_buffer = (int16_t)i32; 2219637de38SSrikanth Yalavarthi 2229637de38SSrikanth Yalavarthi input_buffer++; 2239637de38SSrikanth Yalavarthi output_buffer++; 2249637de38SSrikanth Yalavarthi } 2259637de38SSrikanth Yalavarthi 2269637de38SSrikanth Yalavarthi return 0; 2279637de38SSrikanth Yalavarthi } 2289637de38SSrikanth Yalavarthi 229*8c9bfcb1SSrikanth Yalavarthi int 2309637de38SSrikanth Yalavarthi rte_ml_io_int16_to_float32(float scale, uint64_t nb_elements, void *input, void *output) 2319637de38SSrikanth Yalavarthi { 2329637de38SSrikanth Yalavarthi int16_t *input_buffer; 2339637de38SSrikanth Yalavarthi float *output_buffer; 2349637de38SSrikanth Yalavarthi uint64_t i; 2359637de38SSrikanth Yalavarthi 2369637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 2379637de38SSrikanth Yalavarthi return -EINVAL; 2389637de38SSrikanth Yalavarthi 2399637de38SSrikanth Yalavarthi input_buffer = (int16_t *)input; 2409637de38SSrikanth Yalavarthi output_buffer = (float *)output; 2419637de38SSrikanth Yalavarthi 2429637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 2439637de38SSrikanth Yalavarthi *output_buffer = scale * (float)(*input_buffer); 2449637de38SSrikanth Yalavarthi 2459637de38SSrikanth Yalavarthi input_buffer++; 2469637de38SSrikanth Yalavarthi output_buffer++; 2479637de38SSrikanth Yalavarthi } 2489637de38SSrikanth Yalavarthi 2499637de38SSrikanth Yalavarthi return 0; 2509637de38SSrikanth Yalavarthi } 2519637de38SSrikanth Yalavarthi 252*8c9bfcb1SSrikanth Yalavarthi int 2539637de38SSrikanth Yalavarthi rte_ml_io_float32_to_uint16(float scale, uint64_t nb_elements, void *input, void *output) 2549637de38SSrikanth Yalavarthi { 2559637de38SSrikanth Yalavarthi float *input_buffer; 2569637de38SSrikanth Yalavarthi uint16_t *output_buffer; 2579637de38SSrikanth Yalavarthi int32_t i32; 2589637de38SSrikanth Yalavarthi uint64_t i; 2599637de38SSrikanth Yalavarthi 2609637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 2619637de38SSrikanth Yalavarthi return -EINVAL; 2629637de38SSrikanth Yalavarthi 2639637de38SSrikanth Yalavarthi input_buffer = (float *)input; 2649637de38SSrikanth Yalavarthi output_buffer = (uint16_t *)output; 2659637de38SSrikanth Yalavarthi 2669637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 2679637de38SSrikanth Yalavarthi i32 = (int32_t)round((*input_buffer) * scale); 2689637de38SSrikanth Yalavarthi 2699637de38SSrikanth Yalavarthi if (i32 < 0) 2709637de38SSrikanth Yalavarthi i32 = 0; 2719637de38SSrikanth Yalavarthi 2729637de38SSrikanth Yalavarthi if (i32 > UINT16_MAX) 2739637de38SSrikanth Yalavarthi i32 = UINT16_MAX; 2749637de38SSrikanth Yalavarthi 2759637de38SSrikanth Yalavarthi *output_buffer = (uint16_t)i32; 2769637de38SSrikanth Yalavarthi 2779637de38SSrikanth Yalavarthi input_buffer++; 2789637de38SSrikanth Yalavarthi output_buffer++; 2799637de38SSrikanth Yalavarthi } 2809637de38SSrikanth Yalavarthi 2819637de38SSrikanth Yalavarthi return 0; 2829637de38SSrikanth Yalavarthi } 2839637de38SSrikanth Yalavarthi 284*8c9bfcb1SSrikanth Yalavarthi int 2859637de38SSrikanth Yalavarthi rte_ml_io_uint16_to_float32(float scale, uint64_t nb_elements, void *input, void *output) 2869637de38SSrikanth Yalavarthi { 2879637de38SSrikanth Yalavarthi uint16_t *input_buffer; 2889637de38SSrikanth Yalavarthi float *output_buffer; 2899637de38SSrikanth Yalavarthi uint64_t i; 2909637de38SSrikanth Yalavarthi 2919637de38SSrikanth Yalavarthi if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) 2929637de38SSrikanth Yalavarthi return -EINVAL; 2939637de38SSrikanth Yalavarthi 2949637de38SSrikanth Yalavarthi input_buffer = (uint16_t *)input; 2959637de38SSrikanth Yalavarthi output_buffer = (float *)output; 2969637de38SSrikanth Yalavarthi 2979637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 2989637de38SSrikanth Yalavarthi *output_buffer = scale * (float)(*input_buffer); 2999637de38SSrikanth Yalavarthi 3009637de38SSrikanth Yalavarthi input_buffer++; 3019637de38SSrikanth Yalavarthi output_buffer++; 3029637de38SSrikanth Yalavarthi } 3039637de38SSrikanth Yalavarthi 3049637de38SSrikanth Yalavarthi return 0; 3059637de38SSrikanth Yalavarthi } 3069637de38SSrikanth Yalavarthi 3079637de38SSrikanth Yalavarthi /* Convert a single precision floating point number (float32) into a half precision 3089637de38SSrikanth Yalavarthi * floating point number (float16) using round to nearest rounding mode. 3099637de38SSrikanth Yalavarthi */ 3109637de38SSrikanth Yalavarthi static uint16_t 3119637de38SSrikanth Yalavarthi __float32_to_float16_scalar_rtn(float x) 3129637de38SSrikanth Yalavarthi { 3139637de38SSrikanth Yalavarthi union float32 f32; /* float32 input */ 3149637de38SSrikanth Yalavarthi uint32_t f32_s; /* float32 sign */ 3159637de38SSrikanth Yalavarthi uint32_t f32_e; /* float32 exponent */ 3169637de38SSrikanth Yalavarthi uint32_t f32_m; /* float32 mantissa */ 3179637de38SSrikanth Yalavarthi uint16_t f16_s; /* float16 sign */ 3189637de38SSrikanth Yalavarthi uint16_t f16_e; /* float16 exponent */ 3199637de38SSrikanth Yalavarthi uint16_t f16_m; /* float16 mantissa */ 3209637de38SSrikanth Yalavarthi uint32_t tbits; /* number of truncated bits */ 3219637de38SSrikanth Yalavarthi uint32_t tmsb; /* MSB position of truncated bits */ 3229637de38SSrikanth Yalavarthi uint32_t m_32; /* temporary float32 mantissa */ 3239637de38SSrikanth Yalavarthi uint16_t m_16; /* temporary float16 mantissa */ 3249637de38SSrikanth Yalavarthi uint16_t u16; /* float16 output */ 3259637de38SSrikanth Yalavarthi int be_16; /* float16 biased exponent, signed */ 3269637de38SSrikanth Yalavarthi 3279637de38SSrikanth Yalavarthi f32.f = x; 3289637de38SSrikanth Yalavarthi f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S; 3299637de38SSrikanth Yalavarthi f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E; 3309637de38SSrikanth Yalavarthi f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M; 3319637de38SSrikanth Yalavarthi 3329637de38SSrikanth Yalavarthi f16_s = f32_s; 3339637de38SSrikanth Yalavarthi f16_e = 0; 3349637de38SSrikanth Yalavarthi f16_m = 0; 3359637de38SSrikanth Yalavarthi 3369637de38SSrikanth Yalavarthi switch (f32_e) { 3379637de38SSrikanth Yalavarthi case (0): /* float32: zero or subnormal number */ 3389637de38SSrikanth Yalavarthi f16_e = 0; 339f71c5365SSrikanth Yalavarthi f16_m = 0; /* convert to zero */ 3409637de38SSrikanth Yalavarthi break; 3419637de38SSrikanth Yalavarthi case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */ 3429637de38SSrikanth Yalavarthi f16_e = FP16_MASK_E >> FP16_LSB_E; 3439637de38SSrikanth Yalavarthi if (f32_m == 0) { /* infinity */ 3449637de38SSrikanth Yalavarthi f16_m = 0; 3459637de38SSrikanth Yalavarthi } else { /* nan, propagate mantissa and set MSB of mantissa to 1 */ 3469637de38SSrikanth Yalavarthi f16_m = f32_m >> (FP32_MSB_M - FP16_MSB_M); 3479637de38SSrikanth Yalavarthi f16_m |= BIT(FP16_MSB_M); 3489637de38SSrikanth Yalavarthi } 3499637de38SSrikanth Yalavarthi break; 3509637de38SSrikanth Yalavarthi default: /* float32: normal number */ 3519637de38SSrikanth Yalavarthi /* compute biased exponent for float16 */ 3529637de38SSrikanth Yalavarthi be_16 = (int)f32_e - FP32_BIAS_E + FP16_BIAS_E; 3539637de38SSrikanth Yalavarthi 3549637de38SSrikanth Yalavarthi /* overflow, be_16 = [31-INF], set to infinity */ 3559637de38SSrikanth Yalavarthi if (be_16 >= (int)(FP16_MASK_E >> FP16_LSB_E)) { 3569637de38SSrikanth Yalavarthi f16_e = FP16_MASK_E >> FP16_LSB_E; 3579637de38SSrikanth Yalavarthi f16_m = 0; 3589637de38SSrikanth Yalavarthi } else if ((be_16 >= 1) && (be_16 < (int)(FP16_MASK_E >> FP16_LSB_E))) { 3599637de38SSrikanth Yalavarthi /* normal float16, be_16 = [1:30]*/ 3609637de38SSrikanth Yalavarthi f16_e = be_16; 3619637de38SSrikanth Yalavarthi m_16 = f32_m >> (FP32_LSB_E - FP16_LSB_E); 3629637de38SSrikanth Yalavarthi tmsb = FP32_MSB_M - FP16_MSB_M - 1; 3639637de38SSrikanth Yalavarthi if ((f32_m & GENMASK_U32(tmsb, 0)) > BIT(tmsb)) { 3649637de38SSrikanth Yalavarthi /* round: non-zero truncated bits except MSB */ 3659637de38SSrikanth Yalavarthi m_16++; 3669637de38SSrikanth Yalavarthi 3679637de38SSrikanth Yalavarthi /* overflow into exponent */ 3689637de38SSrikanth Yalavarthi if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1) 3699637de38SSrikanth Yalavarthi f16_e++; 3709637de38SSrikanth Yalavarthi } else if ((f32_m & GENMASK_U32(tmsb, 0)) == BIT(tmsb)) { 3719637de38SSrikanth Yalavarthi /* round: MSB of truncated bits and LSB of m_16 is set */ 3729637de38SSrikanth Yalavarthi if ((m_16 & 0x1) == 0x1) { 3739637de38SSrikanth Yalavarthi m_16++; 3749637de38SSrikanth Yalavarthi 3759637de38SSrikanth Yalavarthi /* overflow into exponent */ 3769637de38SSrikanth Yalavarthi if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1) 3779637de38SSrikanth Yalavarthi f16_e++; 3789637de38SSrikanth Yalavarthi } 3799637de38SSrikanth Yalavarthi } 3809637de38SSrikanth Yalavarthi f16_m = m_16 & FP16_MASK_M; 3819637de38SSrikanth Yalavarthi } else if ((be_16 >= -(int)(FP16_MSB_M)) && (be_16 < 1)) { 3829637de38SSrikanth Yalavarthi /* underflow: zero / subnormal, be_16 = [-9:0] */ 3839637de38SSrikanth Yalavarthi f16_e = 0; 3849637de38SSrikanth Yalavarthi 3859637de38SSrikanth Yalavarthi /* add implicit leading zero */ 3869637de38SSrikanth Yalavarthi m_32 = f32_m | BIT(FP32_LSB_E); 3879637de38SSrikanth Yalavarthi tbits = FP32_LSB_E - FP16_LSB_E - be_16 + 1; 3889637de38SSrikanth Yalavarthi m_16 = m_32 >> tbits; 3899637de38SSrikanth Yalavarthi 3909637de38SSrikanth Yalavarthi /* if non-leading truncated bits are set */ 3919637de38SSrikanth Yalavarthi if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) { 3929637de38SSrikanth Yalavarthi m_16++; 3939637de38SSrikanth Yalavarthi 3949637de38SSrikanth Yalavarthi /* overflow into exponent */ 3959637de38SSrikanth Yalavarthi if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1) 3969637de38SSrikanth Yalavarthi f16_e++; 3979637de38SSrikanth Yalavarthi } else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) { 3989637de38SSrikanth Yalavarthi /* if leading truncated bit is set */ 3999637de38SSrikanth Yalavarthi if ((m_16 & 0x1) == 0x1) { 4009637de38SSrikanth Yalavarthi m_16++; 4019637de38SSrikanth Yalavarthi 4029637de38SSrikanth Yalavarthi /* overflow into exponent */ 4039637de38SSrikanth Yalavarthi if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1) 4049637de38SSrikanth Yalavarthi f16_e++; 4059637de38SSrikanth Yalavarthi } 4069637de38SSrikanth Yalavarthi } 4079637de38SSrikanth Yalavarthi f16_m = m_16 & FP16_MASK_M; 4089637de38SSrikanth Yalavarthi } else if (be_16 == -(int)(FP16_MSB_M + 1)) { 4099637de38SSrikanth Yalavarthi /* underflow: zero, be_16 = [-10] */ 4109637de38SSrikanth Yalavarthi f16_e = 0; 4119637de38SSrikanth Yalavarthi if (f32_m != 0) 4129637de38SSrikanth Yalavarthi f16_m = 1; 4139637de38SSrikanth Yalavarthi else 4149637de38SSrikanth Yalavarthi f16_m = 0; 4159637de38SSrikanth Yalavarthi } else { 4169637de38SSrikanth Yalavarthi /* underflow: zero, be_16 = [-INF:-11] */ 4179637de38SSrikanth Yalavarthi f16_e = 0; 4189637de38SSrikanth Yalavarthi f16_m = 0; 4199637de38SSrikanth Yalavarthi } 4209637de38SSrikanth Yalavarthi 4219637de38SSrikanth Yalavarthi break; 4229637de38SSrikanth Yalavarthi } 4239637de38SSrikanth Yalavarthi 4249637de38SSrikanth Yalavarthi u16 = FP16_PACK(f16_s, f16_e, f16_m); 4259637de38SSrikanth Yalavarthi 4269637de38SSrikanth Yalavarthi return u16; 4279637de38SSrikanth Yalavarthi } 4289637de38SSrikanth Yalavarthi 429*8c9bfcb1SSrikanth Yalavarthi int 4309637de38SSrikanth Yalavarthi rte_ml_io_float32_to_float16(uint64_t nb_elements, void *input, void *output) 4319637de38SSrikanth Yalavarthi { 4329637de38SSrikanth Yalavarthi float *input_buffer; 4339637de38SSrikanth Yalavarthi uint16_t *output_buffer; 4349637de38SSrikanth Yalavarthi uint64_t i; 4359637de38SSrikanth Yalavarthi 4369637de38SSrikanth Yalavarthi if ((nb_elements == 0) || (input == NULL) || (output == NULL)) 4379637de38SSrikanth Yalavarthi return -EINVAL; 4389637de38SSrikanth Yalavarthi 4399637de38SSrikanth Yalavarthi input_buffer = (float *)input; 4409637de38SSrikanth Yalavarthi output_buffer = (uint16_t *)output; 4419637de38SSrikanth Yalavarthi 4429637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 4439637de38SSrikanth Yalavarthi *output_buffer = __float32_to_float16_scalar_rtn(*input_buffer); 4449637de38SSrikanth Yalavarthi 4459637de38SSrikanth Yalavarthi input_buffer = input_buffer + 1; 4469637de38SSrikanth Yalavarthi output_buffer = output_buffer + 1; 4479637de38SSrikanth Yalavarthi } 4489637de38SSrikanth Yalavarthi 4499637de38SSrikanth Yalavarthi return 0; 4509637de38SSrikanth Yalavarthi } 4519637de38SSrikanth Yalavarthi 4529637de38SSrikanth Yalavarthi /* Convert a half precision floating point number (float16) into a single precision 4539637de38SSrikanth Yalavarthi * floating point number (float32). 4549637de38SSrikanth Yalavarthi */ 4559637de38SSrikanth Yalavarthi static float 4569637de38SSrikanth Yalavarthi __float16_to_float32_scalar_rtx(uint16_t f16) 4579637de38SSrikanth Yalavarthi { 4589637de38SSrikanth Yalavarthi union float32 f32; /* float32 output */ 4599637de38SSrikanth Yalavarthi uint16_t f16_s; /* float16 sign */ 4609637de38SSrikanth Yalavarthi uint16_t f16_e; /* float16 exponent */ 4619637de38SSrikanth Yalavarthi uint16_t f16_m; /* float16 mantissa */ 4629637de38SSrikanth Yalavarthi uint32_t f32_s; /* float32 sign */ 4639637de38SSrikanth Yalavarthi uint32_t f32_e; /* float32 exponent */ 4649637de38SSrikanth Yalavarthi uint32_t f32_m; /* float32 mantissa*/ 4659637de38SSrikanth Yalavarthi uint8_t shift; /* number of bits to be shifted */ 4669637de38SSrikanth Yalavarthi uint32_t clz; /* count of leading zeroes */ 4679637de38SSrikanth Yalavarthi int e_16; /* float16 exponent unbiased */ 4689637de38SSrikanth Yalavarthi 4699637de38SSrikanth Yalavarthi f16_s = (f16 & FP16_MASK_S) >> FP16_LSB_S; 4709637de38SSrikanth Yalavarthi f16_e = (f16 & FP16_MASK_E) >> FP16_LSB_E; 4719637de38SSrikanth Yalavarthi f16_m = (f16 & FP16_MASK_M) >> FP16_LSB_M; 4729637de38SSrikanth Yalavarthi 4739637de38SSrikanth Yalavarthi f32_s = f16_s; 4749637de38SSrikanth Yalavarthi switch (f16_e) { 4759637de38SSrikanth Yalavarthi case (FP16_MASK_E >> FP16_LSB_E): /* float16: infinity or nan */ 4769637de38SSrikanth Yalavarthi f32_e = FP32_MASK_E >> FP32_LSB_E; 4779637de38SSrikanth Yalavarthi if (f16_m == 0x0) { /* infinity */ 4789637de38SSrikanth Yalavarthi f32_m = f16_m; 4799637de38SSrikanth Yalavarthi } else { /* nan, propagate mantissa, set MSB of mantissa to 1 */ 4809637de38SSrikanth Yalavarthi f32_m = f16_m; 4819637de38SSrikanth Yalavarthi shift = FP32_MSB_M - FP16_MSB_M; 4829637de38SSrikanth Yalavarthi f32_m = (f32_m << shift) & FP32_MASK_M; 4839637de38SSrikanth Yalavarthi f32_m |= BIT(FP32_MSB_M); 4849637de38SSrikanth Yalavarthi } 4859637de38SSrikanth Yalavarthi break; 4869637de38SSrikanth Yalavarthi case 0: /* float16: zero or sub-normal */ 4879637de38SSrikanth Yalavarthi f32_m = f16_m; 4889637de38SSrikanth Yalavarthi if (f16_m == 0) { /* zero signed */ 4899637de38SSrikanth Yalavarthi f32_e = 0; 4909637de38SSrikanth Yalavarthi } else { /* subnormal numbers */ 4919637de38SSrikanth Yalavarthi clz = __builtin_clz((uint32_t)f16_m) - sizeof(uint32_t) * 8 + FP16_LSB_E; 4929637de38SSrikanth Yalavarthi e_16 = (int)f16_e - clz; 4939637de38SSrikanth Yalavarthi f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E; 4949637de38SSrikanth Yalavarthi 4959637de38SSrikanth Yalavarthi shift = clz + (FP32_MSB_M - FP16_MSB_M) + 1; 4969637de38SSrikanth Yalavarthi f32_m = (f32_m << shift) & FP32_MASK_M; 4979637de38SSrikanth Yalavarthi } 4989637de38SSrikanth Yalavarthi break; 4999637de38SSrikanth Yalavarthi default: /* normal numbers */ 5009637de38SSrikanth Yalavarthi f32_m = f16_m; 5019637de38SSrikanth Yalavarthi e_16 = (int)f16_e; 5029637de38SSrikanth Yalavarthi f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E; 5039637de38SSrikanth Yalavarthi 5049637de38SSrikanth Yalavarthi shift = (FP32_MSB_M - FP16_MSB_M); 5059637de38SSrikanth Yalavarthi f32_m = (f32_m << shift) & FP32_MASK_M; 5069637de38SSrikanth Yalavarthi } 5079637de38SSrikanth Yalavarthi 5089637de38SSrikanth Yalavarthi f32.u = FP32_PACK(f32_s, f32_e, f32_m); 5099637de38SSrikanth Yalavarthi 5109637de38SSrikanth Yalavarthi return f32.f; 5119637de38SSrikanth Yalavarthi } 5129637de38SSrikanth Yalavarthi 513*8c9bfcb1SSrikanth Yalavarthi int 5149637de38SSrikanth Yalavarthi rte_ml_io_float16_to_float32(uint64_t nb_elements, void *input, void *output) 5159637de38SSrikanth Yalavarthi { 5169637de38SSrikanth Yalavarthi uint16_t *input_buffer; 5179637de38SSrikanth Yalavarthi float *output_buffer; 5189637de38SSrikanth Yalavarthi uint64_t i; 5199637de38SSrikanth Yalavarthi 5209637de38SSrikanth Yalavarthi if ((nb_elements == 0) || (input == NULL) || (output == NULL)) 5219637de38SSrikanth Yalavarthi return -EINVAL; 5229637de38SSrikanth Yalavarthi 5239637de38SSrikanth Yalavarthi input_buffer = (uint16_t *)input; 5249637de38SSrikanth Yalavarthi output_buffer = (float *)output; 5259637de38SSrikanth Yalavarthi 5269637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 5279637de38SSrikanth Yalavarthi *output_buffer = __float16_to_float32_scalar_rtx(*input_buffer); 5289637de38SSrikanth Yalavarthi 5299637de38SSrikanth Yalavarthi input_buffer = input_buffer + 1; 5309637de38SSrikanth Yalavarthi output_buffer = output_buffer + 1; 5319637de38SSrikanth Yalavarthi } 5329637de38SSrikanth Yalavarthi 5339637de38SSrikanth Yalavarthi return 0; 5349637de38SSrikanth Yalavarthi } 5359637de38SSrikanth Yalavarthi 5369637de38SSrikanth Yalavarthi /* Convert a single precision floating point number (float32) into a 5379637de38SSrikanth Yalavarthi * brain float number (bfloat16) using round to nearest rounding mode. 5389637de38SSrikanth Yalavarthi */ 5399637de38SSrikanth Yalavarthi static uint16_t 5409637de38SSrikanth Yalavarthi __float32_to_bfloat16_scalar_rtn(float x) 5419637de38SSrikanth Yalavarthi { 5429637de38SSrikanth Yalavarthi union float32 f32; /* float32 input */ 5439637de38SSrikanth Yalavarthi uint32_t f32_s; /* float32 sign */ 5449637de38SSrikanth Yalavarthi uint32_t f32_e; /* float32 exponent */ 5459637de38SSrikanth Yalavarthi uint32_t f32_m; /* float32 mantissa */ 5469637de38SSrikanth Yalavarthi uint16_t b16_s; /* float16 sign */ 5479637de38SSrikanth Yalavarthi uint16_t b16_e; /* float16 exponent */ 5489637de38SSrikanth Yalavarthi uint16_t b16_m; /* float16 mantissa */ 5499637de38SSrikanth Yalavarthi uint32_t tbits; /* number of truncated bits */ 5509637de38SSrikanth Yalavarthi uint16_t u16; /* float16 output */ 5519637de38SSrikanth Yalavarthi 5529637de38SSrikanth Yalavarthi f32.f = x; 5539637de38SSrikanth Yalavarthi f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S; 5549637de38SSrikanth Yalavarthi f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E; 5559637de38SSrikanth Yalavarthi f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M; 5569637de38SSrikanth Yalavarthi 5579637de38SSrikanth Yalavarthi b16_s = f32_s; 5589637de38SSrikanth Yalavarthi b16_e = 0; 5599637de38SSrikanth Yalavarthi b16_m = 0; 5609637de38SSrikanth Yalavarthi 5619637de38SSrikanth Yalavarthi switch (f32_e) { 5629637de38SSrikanth Yalavarthi case (0): /* float32: zero or subnormal number */ 5639637de38SSrikanth Yalavarthi b16_e = 0; 5649637de38SSrikanth Yalavarthi if (f32_m == 0) /* zero */ 5659637de38SSrikanth Yalavarthi b16_m = 0; 5669637de38SSrikanth Yalavarthi else /* subnormal float32 number, normal bfloat16 */ 5679637de38SSrikanth Yalavarthi goto bf16_normal; 5689637de38SSrikanth Yalavarthi break; 5699637de38SSrikanth Yalavarthi case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */ 5709637de38SSrikanth Yalavarthi b16_e = BF16_MASK_E >> BF16_LSB_E; 5719637de38SSrikanth Yalavarthi if (f32_m == 0) { /* infinity */ 5729637de38SSrikanth Yalavarthi b16_m = 0; 5739637de38SSrikanth Yalavarthi } else { /* nan, propagate mantissa and set MSB of mantissa to 1 */ 5749637de38SSrikanth Yalavarthi b16_m = f32_m >> (FP32_MSB_M - BF16_MSB_M); 5759637de38SSrikanth Yalavarthi b16_m |= BIT(BF16_MSB_M); 5769637de38SSrikanth Yalavarthi } 5779637de38SSrikanth Yalavarthi break; 5789637de38SSrikanth Yalavarthi default: /* float32: normal number, normal bfloat16 */ 5799637de38SSrikanth Yalavarthi goto bf16_normal; 5809637de38SSrikanth Yalavarthi } 5819637de38SSrikanth Yalavarthi 5829637de38SSrikanth Yalavarthi goto bf16_pack; 5839637de38SSrikanth Yalavarthi 5849637de38SSrikanth Yalavarthi bf16_normal: 5859637de38SSrikanth Yalavarthi b16_e = f32_e; 5869637de38SSrikanth Yalavarthi tbits = FP32_MSB_M - BF16_MSB_M; 5879637de38SSrikanth Yalavarthi b16_m = f32_m >> tbits; 5889637de38SSrikanth Yalavarthi 5899637de38SSrikanth Yalavarthi /* if non-leading truncated bits are set */ 5909637de38SSrikanth Yalavarthi if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) { 5919637de38SSrikanth Yalavarthi b16_m++; 5929637de38SSrikanth Yalavarthi 5939637de38SSrikanth Yalavarthi /* if overflow into exponent */ 5949637de38SSrikanth Yalavarthi if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1) 5959637de38SSrikanth Yalavarthi b16_e++; 5969637de38SSrikanth Yalavarthi } else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) { 5979637de38SSrikanth Yalavarthi /* if only leading truncated bit is set */ 5989637de38SSrikanth Yalavarthi if ((b16_m & 0x1) == 0x1) { 5999637de38SSrikanth Yalavarthi b16_m++; 6009637de38SSrikanth Yalavarthi 6019637de38SSrikanth Yalavarthi /* if overflow into exponent */ 6029637de38SSrikanth Yalavarthi if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1) 6039637de38SSrikanth Yalavarthi b16_e++; 6049637de38SSrikanth Yalavarthi } 6059637de38SSrikanth Yalavarthi } 6069637de38SSrikanth Yalavarthi b16_m = b16_m & BF16_MASK_M; 6079637de38SSrikanth Yalavarthi 6089637de38SSrikanth Yalavarthi bf16_pack: 6099637de38SSrikanth Yalavarthi u16 = BF16_PACK(b16_s, b16_e, b16_m); 6109637de38SSrikanth Yalavarthi 6119637de38SSrikanth Yalavarthi return u16; 6129637de38SSrikanth Yalavarthi } 6139637de38SSrikanth Yalavarthi 614*8c9bfcb1SSrikanth Yalavarthi int 6159637de38SSrikanth Yalavarthi rte_ml_io_float32_to_bfloat16(uint64_t nb_elements, void *input, void *output) 6169637de38SSrikanth Yalavarthi { 6179637de38SSrikanth Yalavarthi float *input_buffer; 6189637de38SSrikanth Yalavarthi uint16_t *output_buffer; 6199637de38SSrikanth Yalavarthi uint64_t i; 6209637de38SSrikanth Yalavarthi 6219637de38SSrikanth Yalavarthi if ((nb_elements == 0) || (input == NULL) || (output == NULL)) 6229637de38SSrikanth Yalavarthi return -EINVAL; 6239637de38SSrikanth Yalavarthi 6249637de38SSrikanth Yalavarthi input_buffer = (float *)input; 6259637de38SSrikanth Yalavarthi output_buffer = (uint16_t *)output; 6269637de38SSrikanth Yalavarthi 6279637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 6289637de38SSrikanth Yalavarthi *output_buffer = __float32_to_bfloat16_scalar_rtn(*input_buffer); 6299637de38SSrikanth Yalavarthi 6309637de38SSrikanth Yalavarthi input_buffer = input_buffer + 1; 6319637de38SSrikanth Yalavarthi output_buffer = output_buffer + 1; 6329637de38SSrikanth Yalavarthi } 6339637de38SSrikanth Yalavarthi 6349637de38SSrikanth Yalavarthi return 0; 6359637de38SSrikanth Yalavarthi } 6369637de38SSrikanth Yalavarthi 6379637de38SSrikanth Yalavarthi /* Convert a brain float number (bfloat16) into a 6389637de38SSrikanth Yalavarthi * single precision floating point number (float32). 6399637de38SSrikanth Yalavarthi */ 6409637de38SSrikanth Yalavarthi static float 6419637de38SSrikanth Yalavarthi __bfloat16_to_float32_scalar_rtx(uint16_t f16) 6429637de38SSrikanth Yalavarthi { 6439637de38SSrikanth Yalavarthi union float32 f32; /* float32 output */ 6449637de38SSrikanth Yalavarthi uint16_t b16_s; /* float16 sign */ 6459637de38SSrikanth Yalavarthi uint16_t b16_e; /* float16 exponent */ 6469637de38SSrikanth Yalavarthi uint16_t b16_m; /* float16 mantissa */ 6479637de38SSrikanth Yalavarthi uint32_t f32_s; /* float32 sign */ 6489637de38SSrikanth Yalavarthi uint32_t f32_e; /* float32 exponent */ 6499637de38SSrikanth Yalavarthi uint32_t f32_m; /* float32 mantissa*/ 6509637de38SSrikanth Yalavarthi uint8_t shift; /* number of bits to be shifted */ 6519637de38SSrikanth Yalavarthi 6529637de38SSrikanth Yalavarthi b16_s = (f16 & BF16_MASK_S) >> BF16_LSB_S; 6539637de38SSrikanth Yalavarthi b16_e = (f16 & BF16_MASK_E) >> BF16_LSB_E; 6549637de38SSrikanth Yalavarthi b16_m = (f16 & BF16_MASK_M) >> BF16_LSB_M; 6559637de38SSrikanth Yalavarthi 6569637de38SSrikanth Yalavarthi f32_s = b16_s; 6579637de38SSrikanth Yalavarthi switch (b16_e) { 6589637de38SSrikanth Yalavarthi case (BF16_MASK_E >> BF16_LSB_E): /* bfloat16: infinity or nan */ 6599637de38SSrikanth Yalavarthi f32_e = FP32_MASK_E >> FP32_LSB_E; 6609637de38SSrikanth Yalavarthi if (b16_m == 0x0) { /* infinity */ 6619637de38SSrikanth Yalavarthi f32_m = 0; 6629637de38SSrikanth Yalavarthi } else { /* nan, propagate mantissa, set MSB of mantissa to 1 */ 6639637de38SSrikanth Yalavarthi f32_m = b16_m; 6649637de38SSrikanth Yalavarthi shift = FP32_MSB_M - BF16_MSB_M; 6659637de38SSrikanth Yalavarthi f32_m = (f32_m << shift) & FP32_MASK_M; 6669637de38SSrikanth Yalavarthi f32_m |= BIT(FP32_MSB_M); 6679637de38SSrikanth Yalavarthi } 6689637de38SSrikanth Yalavarthi break; 6699637de38SSrikanth Yalavarthi case 0: /* bfloat16: zero or subnormal */ 6709637de38SSrikanth Yalavarthi f32_m = b16_m; 6719637de38SSrikanth Yalavarthi if (b16_m == 0) { /* zero signed */ 6729637de38SSrikanth Yalavarthi f32_e = 0; 6739637de38SSrikanth Yalavarthi } else { /* subnormal numbers */ 6749637de38SSrikanth Yalavarthi goto fp32_normal; 6759637de38SSrikanth Yalavarthi } 6769637de38SSrikanth Yalavarthi break; 6779637de38SSrikanth Yalavarthi default: /* bfloat16: normal number */ 6789637de38SSrikanth Yalavarthi goto fp32_normal; 6799637de38SSrikanth Yalavarthi } 6809637de38SSrikanth Yalavarthi 6819637de38SSrikanth Yalavarthi goto fp32_pack; 6829637de38SSrikanth Yalavarthi 6839637de38SSrikanth Yalavarthi fp32_normal: 6849637de38SSrikanth Yalavarthi f32_m = b16_m; 6859637de38SSrikanth Yalavarthi f32_e = FP32_BIAS_E + b16_e - BF16_BIAS_E; 6869637de38SSrikanth Yalavarthi 6879637de38SSrikanth Yalavarthi shift = (FP32_MSB_M - BF16_MSB_M); 6889637de38SSrikanth Yalavarthi f32_m = (f32_m << shift) & FP32_MASK_M; 6899637de38SSrikanth Yalavarthi 6909637de38SSrikanth Yalavarthi fp32_pack: 6919637de38SSrikanth Yalavarthi f32.u = FP32_PACK(f32_s, f32_e, f32_m); 6929637de38SSrikanth Yalavarthi 6939637de38SSrikanth Yalavarthi return f32.f; 6949637de38SSrikanth Yalavarthi } 6959637de38SSrikanth Yalavarthi 696*8c9bfcb1SSrikanth Yalavarthi int 6979637de38SSrikanth Yalavarthi rte_ml_io_bfloat16_to_float32(uint64_t nb_elements, void *input, void *output) 6989637de38SSrikanth Yalavarthi { 6999637de38SSrikanth Yalavarthi uint16_t *input_buffer; 7009637de38SSrikanth Yalavarthi float *output_buffer; 7019637de38SSrikanth Yalavarthi uint64_t i; 7029637de38SSrikanth Yalavarthi 7039637de38SSrikanth Yalavarthi if ((nb_elements == 0) || (input == NULL) || (output == NULL)) 7049637de38SSrikanth Yalavarthi return -EINVAL; 7059637de38SSrikanth Yalavarthi 7069637de38SSrikanth Yalavarthi input_buffer = (uint16_t *)input; 7079637de38SSrikanth Yalavarthi output_buffer = (float *)output; 7089637de38SSrikanth Yalavarthi 7099637de38SSrikanth Yalavarthi for (i = 0; i < nb_elements; i++) { 7109637de38SSrikanth Yalavarthi *output_buffer = __bfloat16_to_float32_scalar_rtx(*input_buffer); 7119637de38SSrikanth Yalavarthi 7129637de38SSrikanth Yalavarthi input_buffer = input_buffer + 1; 7139637de38SSrikanth Yalavarthi output_buffer = output_buffer + 1; 7149637de38SSrikanth Yalavarthi } 7159637de38SSrikanth Yalavarthi 7169637de38SSrikanth Yalavarthi return 0; 7179637de38SSrikanth Yalavarthi } 718