xref: /dpdk/lib/mldev/mldev_utils_scalar.c (revision 8c9bfcb1553d756eb5392a56ac7813b3865c3ec7)
/* SPDX-License-Identifier: BSD-3-Clause
 * Copyright (c) 2022 Marvell.
 */
49637de38SSrikanth Yalavarthi 
#include <errno.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include "mldev_utils.h"
109637de38SSrikanth Yalavarthi 
/* Description:
 * This file implements scalar versions of Machine Learning utility functions used to convert data
 * types from higher precision to lower precision and vice-versa.
 */
159637de38SSrikanth Yalavarthi 
/* Single-bit mask with bit 'nr' set (unsigned long). */
#ifndef BIT
#define BIT(nr) (1UL << (nr))
#endif

/* Width of 'unsigned long' in bits; __SIZEOF_LONG__ is predefined by GCC and Clang. */
#ifndef BITS_PER_LONG
#define BITS_PER_LONG (__SIZEOF_LONG__ * 8)
#endif

/* Contiguous mask with bits l..h (inclusive, h >= l) set.
 * The result has type 'unsigned long'; only values fitting in 32 bits are used in this file.
 */
#ifndef GENMASK_U32
#define GENMASK_U32(h, l) (((~0UL) << (l)) & (~0UL >> (BITS_PER_LONG - 1 - (h))))
#endif

/* float32: bit index of MSB & LSB of sign, exponent and mantissa
 * (1 sign bit, 8 exponent bits, 23 mantissa bits)
 */
#define FP32_LSB_M 0
#define FP32_MSB_M 22
#define FP32_LSB_E 23
#define FP32_MSB_E 30
#define FP32_LSB_S 31
#define FP32_MSB_S 31

/* float32: bitmask for sign, exponent and mantissa */
#define FP32_MASK_S GENMASK_U32(FP32_MSB_S, FP32_LSB_S)
#define FP32_MASK_E GENMASK_U32(FP32_MSB_E, FP32_LSB_E)
#define FP32_MASK_M GENMASK_U32(FP32_MSB_M, FP32_LSB_M)

/* float16: bit index of MSB & LSB of sign, exponent and mantissa
 * (1 sign bit, 5 exponent bits, 10 mantissa bits)
 */
#define FP16_LSB_M 0
#define FP16_MSB_M 9
#define FP16_LSB_E 10
#define FP16_MSB_E 14
#define FP16_LSB_S 15
#define FP16_MSB_S 15

/* float16: bitmask for sign, exponent and mantissa */
#define FP16_MASK_S GENMASK_U32(FP16_MSB_S, FP16_LSB_S)
#define FP16_MASK_E GENMASK_U32(FP16_MSB_E, FP16_LSB_E)
#define FP16_MASK_M GENMASK_U32(FP16_MSB_M, FP16_LSB_M)

/* bfloat16: bit index of MSB & LSB of sign, exponent and mantissa
 * (1 sign bit, 8 exponent bits, 7 mantissa bits)
 */
#define BF16_LSB_M 0
#define BF16_MSB_M 6
#define BF16_LSB_E 7
#define BF16_MSB_E 14
#define BF16_LSB_S 15
#define BF16_MSB_S 15

/* bfloat16: bitmask for sign, exponent and mantissa */
#define BF16_MASK_S GENMASK_U32(BF16_MSB_S, BF16_LSB_S)
#define BF16_MASK_E GENMASK_U32(BF16_MSB_E, BF16_LSB_E)
#define BF16_MASK_M GENMASK_U32(BF16_MSB_M, BF16_LSB_M)

/* Exponent bias */
#define FP32_BIAS_E 127
#define FP16_BIAS_E 15
#define BF16_BIAS_E 127

/* Assemble a float32 bit pattern from already right-aligned sign, exponent and mantissa fields.
 * The fields are not masked here; callers must pass values that fit their field width.
 */
#define FP32_PACK(sign, exponent, mantissa)                                                        \
	(((sign) << FP32_LSB_S) | ((exponent) << FP32_LSB_E) | (mantissa))

/* Assemble a float16 bit pattern from right-aligned fields (fields are not masked). */
#define FP16_PACK(sign, exponent, mantissa)                                                        \
	(((sign) << FP16_LSB_S) | ((exponent) << FP16_LSB_E) | (mantissa))

/* Assemble a bfloat16 bit pattern from right-aligned fields (fields are not masked). */
#define BF16_PACK(sign, exponent, mantissa)                                                        \
	(((sign) << BF16_LSB_S) | ((exponent) << BF16_LSB_E) | (mantissa))
809637de38SSrikanth Yalavarthi 
/* Represent float32 as float and uint32_t.
 * Writing one member and reading the other gives access to the raw bit pattern of a float.
 */
union float32 {
	float f;    /* value viewed as a floating point number */
	uint32_t u; /* same storage viewed as a 32-bit pattern */
};
869637de38SSrikanth Yalavarthi 
87*8c9bfcb1SSrikanth Yalavarthi int
889637de38SSrikanth Yalavarthi rte_ml_io_float32_to_int8(float scale, uint64_t nb_elements, void *input, void *output)
899637de38SSrikanth Yalavarthi {
909637de38SSrikanth Yalavarthi 	float *input_buffer;
919637de38SSrikanth Yalavarthi 	int8_t *output_buffer;
929637de38SSrikanth Yalavarthi 	uint64_t i;
939637de38SSrikanth Yalavarthi 	int i32;
949637de38SSrikanth Yalavarthi 
959637de38SSrikanth Yalavarthi 	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
969637de38SSrikanth Yalavarthi 		return -EINVAL;
979637de38SSrikanth Yalavarthi 
989637de38SSrikanth Yalavarthi 	input_buffer = (float *)input;
999637de38SSrikanth Yalavarthi 	output_buffer = (int8_t *)output;
1009637de38SSrikanth Yalavarthi 
1019637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
1029637de38SSrikanth Yalavarthi 		i32 = (int32_t)round((*input_buffer) * scale);
1039637de38SSrikanth Yalavarthi 
1049637de38SSrikanth Yalavarthi 		if (i32 < INT8_MIN)
1059637de38SSrikanth Yalavarthi 			i32 = INT8_MIN;
1069637de38SSrikanth Yalavarthi 
1079637de38SSrikanth Yalavarthi 		if (i32 > INT8_MAX)
1089637de38SSrikanth Yalavarthi 			i32 = INT8_MAX;
1099637de38SSrikanth Yalavarthi 
1109637de38SSrikanth Yalavarthi 		*output_buffer = (int8_t)i32;
1119637de38SSrikanth Yalavarthi 
1129637de38SSrikanth Yalavarthi 		input_buffer++;
1139637de38SSrikanth Yalavarthi 		output_buffer++;
1149637de38SSrikanth Yalavarthi 	}
1159637de38SSrikanth Yalavarthi 
1169637de38SSrikanth Yalavarthi 	return 0;
1179637de38SSrikanth Yalavarthi }
1189637de38SSrikanth Yalavarthi 
/* Dequantize an array of int8 values to float32.
 *
 * Each output element is scale * (float)input.
 *
 * scale       : multiplier applied to every input element, must be non-zero
 * nb_elements : number of elements to convert, must be non-zero
 * input       : buffer of nb_elements int8_t values
 * output      : buffer of nb_elements float values
 *
 * Returns 0 on success, -EINVAL on invalid arguments.
 */
int
rte_ml_io_int8_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
{
	const int8_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (const int8_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++)
		output_buffer[i] = scale * (float)input_buffer[i];

	return 0;
}
1419637de38SSrikanth Yalavarthi 
142*8c9bfcb1SSrikanth Yalavarthi int
1439637de38SSrikanth Yalavarthi rte_ml_io_float32_to_uint8(float scale, uint64_t nb_elements, void *input, void *output)
1449637de38SSrikanth Yalavarthi {
1459637de38SSrikanth Yalavarthi 	float *input_buffer;
1469637de38SSrikanth Yalavarthi 	uint8_t *output_buffer;
1479637de38SSrikanth Yalavarthi 	int32_t i32;
1489637de38SSrikanth Yalavarthi 	uint64_t i;
1499637de38SSrikanth Yalavarthi 
1509637de38SSrikanth Yalavarthi 	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
1519637de38SSrikanth Yalavarthi 		return -EINVAL;
1529637de38SSrikanth Yalavarthi 
1539637de38SSrikanth Yalavarthi 	input_buffer = (float *)input;
1549637de38SSrikanth Yalavarthi 	output_buffer = (uint8_t *)output;
1559637de38SSrikanth Yalavarthi 
1569637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
1579637de38SSrikanth Yalavarthi 		i32 = (int32_t)round((*input_buffer) * scale);
1589637de38SSrikanth Yalavarthi 
1599637de38SSrikanth Yalavarthi 		if (i32 < 0)
1609637de38SSrikanth Yalavarthi 			i32 = 0;
1619637de38SSrikanth Yalavarthi 
1629637de38SSrikanth Yalavarthi 		if (i32 > UINT8_MAX)
1639637de38SSrikanth Yalavarthi 			i32 = UINT8_MAX;
1649637de38SSrikanth Yalavarthi 
1659637de38SSrikanth Yalavarthi 		*output_buffer = (uint8_t)i32;
1669637de38SSrikanth Yalavarthi 
1679637de38SSrikanth Yalavarthi 		input_buffer++;
1689637de38SSrikanth Yalavarthi 		output_buffer++;
1699637de38SSrikanth Yalavarthi 	}
1709637de38SSrikanth Yalavarthi 
1719637de38SSrikanth Yalavarthi 	return 0;
1729637de38SSrikanth Yalavarthi }
1739637de38SSrikanth Yalavarthi 
/* Dequantize an array of uint8 values to float32.
 *
 * Each output element is scale * (float)input.
 *
 * scale       : multiplier applied to every input element, must be non-zero
 * nb_elements : number of elements to convert, must be non-zero
 * input       : buffer of nb_elements uint8_t values
 * output      : buffer of nb_elements float values
 *
 * Returns 0 on success, -EINVAL on invalid arguments.
 */
int
rte_ml_io_uint8_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
{
	const uint8_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (const uint8_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++)
		output_buffer[i] = scale * (float)input_buffer[i];

	return 0;
}
1969637de38SSrikanth Yalavarthi 
197*8c9bfcb1SSrikanth Yalavarthi int
1989637de38SSrikanth Yalavarthi rte_ml_io_float32_to_int16(float scale, uint64_t nb_elements, void *input, void *output)
1999637de38SSrikanth Yalavarthi {
2009637de38SSrikanth Yalavarthi 	float *input_buffer;
2019637de38SSrikanth Yalavarthi 	int16_t *output_buffer;
2029637de38SSrikanth Yalavarthi 	int32_t i32;
2039637de38SSrikanth Yalavarthi 	uint64_t i;
2049637de38SSrikanth Yalavarthi 
2059637de38SSrikanth Yalavarthi 	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
2069637de38SSrikanth Yalavarthi 		return -EINVAL;
2079637de38SSrikanth Yalavarthi 
2089637de38SSrikanth Yalavarthi 	input_buffer = (float *)input;
2099637de38SSrikanth Yalavarthi 	output_buffer = (int16_t *)output;
2109637de38SSrikanth Yalavarthi 
2119637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
2129637de38SSrikanth Yalavarthi 		i32 = (int32_t)round((*input_buffer) * scale);
2139637de38SSrikanth Yalavarthi 
2149637de38SSrikanth Yalavarthi 		if (i32 < INT16_MIN)
2159637de38SSrikanth Yalavarthi 			i32 = INT16_MIN;
2169637de38SSrikanth Yalavarthi 
2179637de38SSrikanth Yalavarthi 		if (i32 > INT16_MAX)
2189637de38SSrikanth Yalavarthi 			i32 = INT16_MAX;
2199637de38SSrikanth Yalavarthi 
2209637de38SSrikanth Yalavarthi 		*output_buffer = (int16_t)i32;
2219637de38SSrikanth Yalavarthi 
2229637de38SSrikanth Yalavarthi 		input_buffer++;
2239637de38SSrikanth Yalavarthi 		output_buffer++;
2249637de38SSrikanth Yalavarthi 	}
2259637de38SSrikanth Yalavarthi 
2269637de38SSrikanth Yalavarthi 	return 0;
2279637de38SSrikanth Yalavarthi }
2289637de38SSrikanth Yalavarthi 
/* Dequantize an array of int16 values to float32.
 *
 * Each output element is scale * (float)input.
 *
 * scale       : multiplier applied to every input element, must be non-zero
 * nb_elements : number of elements to convert, must be non-zero
 * input       : buffer of nb_elements int16_t values
 * output      : buffer of nb_elements float values
 *
 * Returns 0 on success, -EINVAL on invalid arguments.
 */
int
rte_ml_io_int16_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
{
	const int16_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (const int16_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++)
		output_buffer[i] = scale * (float)input_buffer[i];

	return 0;
}
2519637de38SSrikanth Yalavarthi 
252*8c9bfcb1SSrikanth Yalavarthi int
2539637de38SSrikanth Yalavarthi rte_ml_io_float32_to_uint16(float scale, uint64_t nb_elements, void *input, void *output)
2549637de38SSrikanth Yalavarthi {
2559637de38SSrikanth Yalavarthi 	float *input_buffer;
2569637de38SSrikanth Yalavarthi 	uint16_t *output_buffer;
2579637de38SSrikanth Yalavarthi 	int32_t i32;
2589637de38SSrikanth Yalavarthi 	uint64_t i;
2599637de38SSrikanth Yalavarthi 
2609637de38SSrikanth Yalavarthi 	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
2619637de38SSrikanth Yalavarthi 		return -EINVAL;
2629637de38SSrikanth Yalavarthi 
2639637de38SSrikanth Yalavarthi 	input_buffer = (float *)input;
2649637de38SSrikanth Yalavarthi 	output_buffer = (uint16_t *)output;
2659637de38SSrikanth Yalavarthi 
2669637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
2679637de38SSrikanth Yalavarthi 		i32 = (int32_t)round((*input_buffer) * scale);
2689637de38SSrikanth Yalavarthi 
2699637de38SSrikanth Yalavarthi 		if (i32 < 0)
2709637de38SSrikanth Yalavarthi 			i32 = 0;
2719637de38SSrikanth Yalavarthi 
2729637de38SSrikanth Yalavarthi 		if (i32 > UINT16_MAX)
2739637de38SSrikanth Yalavarthi 			i32 = UINT16_MAX;
2749637de38SSrikanth Yalavarthi 
2759637de38SSrikanth Yalavarthi 		*output_buffer = (uint16_t)i32;
2769637de38SSrikanth Yalavarthi 
2779637de38SSrikanth Yalavarthi 		input_buffer++;
2789637de38SSrikanth Yalavarthi 		output_buffer++;
2799637de38SSrikanth Yalavarthi 	}
2809637de38SSrikanth Yalavarthi 
2819637de38SSrikanth Yalavarthi 	return 0;
2829637de38SSrikanth Yalavarthi }
2839637de38SSrikanth Yalavarthi 
/* Dequantize an array of uint16 values to float32.
 *
 * Each output element is scale * (float)input.
 *
 * scale       : multiplier applied to every input element, must be non-zero
 * nb_elements : number of elements to convert, must be non-zero
 * input       : buffer of nb_elements uint16_t values
 * output      : buffer of nb_elements float values
 *
 * Returns 0 on success, -EINVAL on invalid arguments.
 */
int
rte_ml_io_uint16_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
{
	const uint16_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (const uint16_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++)
		output_buffer[i] = scale * (float)input_buffer[i];

	return 0;
}
3069637de38SSrikanth Yalavarthi 
3079637de38SSrikanth Yalavarthi /* Convert a single precision floating point number (float32) into a half precision
3089637de38SSrikanth Yalavarthi  * floating point number (float16) using round to nearest rounding mode.
3099637de38SSrikanth Yalavarthi  */
3109637de38SSrikanth Yalavarthi static uint16_t
3119637de38SSrikanth Yalavarthi __float32_to_float16_scalar_rtn(float x)
3129637de38SSrikanth Yalavarthi {
3139637de38SSrikanth Yalavarthi 	union float32 f32; /* float32 input */
3149637de38SSrikanth Yalavarthi 	uint32_t f32_s;	   /* float32 sign */
3159637de38SSrikanth Yalavarthi 	uint32_t f32_e;	   /* float32 exponent */
3169637de38SSrikanth Yalavarthi 	uint32_t f32_m;	   /* float32 mantissa */
3179637de38SSrikanth Yalavarthi 	uint16_t f16_s;	   /* float16 sign */
3189637de38SSrikanth Yalavarthi 	uint16_t f16_e;	   /* float16 exponent */
3199637de38SSrikanth Yalavarthi 	uint16_t f16_m;	   /* float16 mantissa */
3209637de38SSrikanth Yalavarthi 	uint32_t tbits;	   /* number of truncated bits */
3219637de38SSrikanth Yalavarthi 	uint32_t tmsb;	   /* MSB position of truncated bits */
3229637de38SSrikanth Yalavarthi 	uint32_t m_32;	   /* temporary float32 mantissa */
3239637de38SSrikanth Yalavarthi 	uint16_t m_16;	   /* temporary float16 mantissa */
3249637de38SSrikanth Yalavarthi 	uint16_t u16;	   /* float16 output */
3259637de38SSrikanth Yalavarthi 	int be_16;	   /* float16 biased exponent, signed */
3269637de38SSrikanth Yalavarthi 
3279637de38SSrikanth Yalavarthi 	f32.f = x;
3289637de38SSrikanth Yalavarthi 	f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
3299637de38SSrikanth Yalavarthi 	f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
3309637de38SSrikanth Yalavarthi 	f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;
3319637de38SSrikanth Yalavarthi 
3329637de38SSrikanth Yalavarthi 	f16_s = f32_s;
3339637de38SSrikanth Yalavarthi 	f16_e = 0;
3349637de38SSrikanth Yalavarthi 	f16_m = 0;
3359637de38SSrikanth Yalavarthi 
3369637de38SSrikanth Yalavarthi 	switch (f32_e) {
3379637de38SSrikanth Yalavarthi 	case (0): /* float32: zero or subnormal number */
3389637de38SSrikanth Yalavarthi 		f16_e = 0;
339f71c5365SSrikanth Yalavarthi 		f16_m = 0; /* convert to zero */
3409637de38SSrikanth Yalavarthi 		break;
3419637de38SSrikanth Yalavarthi 	case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
3429637de38SSrikanth Yalavarthi 		f16_e = FP16_MASK_E >> FP16_LSB_E;
3439637de38SSrikanth Yalavarthi 		if (f32_m == 0) { /* infinity */
3449637de38SSrikanth Yalavarthi 			f16_m = 0;
3459637de38SSrikanth Yalavarthi 		} else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
3469637de38SSrikanth Yalavarthi 			f16_m = f32_m >> (FP32_MSB_M - FP16_MSB_M);
3479637de38SSrikanth Yalavarthi 			f16_m |= BIT(FP16_MSB_M);
3489637de38SSrikanth Yalavarthi 		}
3499637de38SSrikanth Yalavarthi 		break;
3509637de38SSrikanth Yalavarthi 	default: /* float32: normal number */
3519637de38SSrikanth Yalavarthi 		/* compute biased exponent for float16 */
3529637de38SSrikanth Yalavarthi 		be_16 = (int)f32_e - FP32_BIAS_E + FP16_BIAS_E;
3539637de38SSrikanth Yalavarthi 
3549637de38SSrikanth Yalavarthi 		/* overflow, be_16 = [31-INF], set to infinity */
3559637de38SSrikanth Yalavarthi 		if (be_16 >= (int)(FP16_MASK_E >> FP16_LSB_E)) {
3569637de38SSrikanth Yalavarthi 			f16_e = FP16_MASK_E >> FP16_LSB_E;
3579637de38SSrikanth Yalavarthi 			f16_m = 0;
3589637de38SSrikanth Yalavarthi 		} else if ((be_16 >= 1) && (be_16 < (int)(FP16_MASK_E >> FP16_LSB_E))) {
3599637de38SSrikanth Yalavarthi 			/* normal float16, be_16 = [1:30]*/
3609637de38SSrikanth Yalavarthi 			f16_e = be_16;
3619637de38SSrikanth Yalavarthi 			m_16 = f32_m >> (FP32_LSB_E - FP16_LSB_E);
3629637de38SSrikanth Yalavarthi 			tmsb = FP32_MSB_M - FP16_MSB_M - 1;
3639637de38SSrikanth Yalavarthi 			if ((f32_m & GENMASK_U32(tmsb, 0)) > BIT(tmsb)) {
3649637de38SSrikanth Yalavarthi 				/* round: non-zero truncated bits except MSB */
3659637de38SSrikanth Yalavarthi 				m_16++;
3669637de38SSrikanth Yalavarthi 
3679637de38SSrikanth Yalavarthi 				/* overflow into exponent */
3689637de38SSrikanth Yalavarthi 				if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
3699637de38SSrikanth Yalavarthi 					f16_e++;
3709637de38SSrikanth Yalavarthi 			} else if ((f32_m & GENMASK_U32(tmsb, 0)) == BIT(tmsb)) {
3719637de38SSrikanth Yalavarthi 				/* round: MSB of truncated bits and LSB of m_16 is set */
3729637de38SSrikanth Yalavarthi 				if ((m_16 & 0x1) == 0x1) {
3739637de38SSrikanth Yalavarthi 					m_16++;
3749637de38SSrikanth Yalavarthi 
3759637de38SSrikanth Yalavarthi 					/* overflow into exponent */
3769637de38SSrikanth Yalavarthi 					if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
3779637de38SSrikanth Yalavarthi 						f16_e++;
3789637de38SSrikanth Yalavarthi 				}
3799637de38SSrikanth Yalavarthi 			}
3809637de38SSrikanth Yalavarthi 			f16_m = m_16 & FP16_MASK_M;
3819637de38SSrikanth Yalavarthi 		} else if ((be_16 >= -(int)(FP16_MSB_M)) && (be_16 < 1)) {
3829637de38SSrikanth Yalavarthi 			/* underflow: zero / subnormal, be_16 = [-9:0] */
3839637de38SSrikanth Yalavarthi 			f16_e = 0;
3849637de38SSrikanth Yalavarthi 
3859637de38SSrikanth Yalavarthi 			/* add implicit leading zero */
3869637de38SSrikanth Yalavarthi 			m_32 = f32_m | BIT(FP32_LSB_E);
3879637de38SSrikanth Yalavarthi 			tbits = FP32_LSB_E - FP16_LSB_E - be_16 + 1;
3889637de38SSrikanth Yalavarthi 			m_16 = m_32 >> tbits;
3899637de38SSrikanth Yalavarthi 
3909637de38SSrikanth Yalavarthi 			/* if non-leading truncated bits are set */
3919637de38SSrikanth Yalavarthi 			if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) {
3929637de38SSrikanth Yalavarthi 				m_16++;
3939637de38SSrikanth Yalavarthi 
3949637de38SSrikanth Yalavarthi 				/* overflow into exponent */
3959637de38SSrikanth Yalavarthi 				if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
3969637de38SSrikanth Yalavarthi 					f16_e++;
3979637de38SSrikanth Yalavarthi 			} else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) {
3989637de38SSrikanth Yalavarthi 				/* if leading truncated bit is set */
3999637de38SSrikanth Yalavarthi 				if ((m_16 & 0x1) == 0x1) {
4009637de38SSrikanth Yalavarthi 					m_16++;
4019637de38SSrikanth Yalavarthi 
4029637de38SSrikanth Yalavarthi 					/* overflow into exponent */
4039637de38SSrikanth Yalavarthi 					if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
4049637de38SSrikanth Yalavarthi 						f16_e++;
4059637de38SSrikanth Yalavarthi 				}
4069637de38SSrikanth Yalavarthi 			}
4079637de38SSrikanth Yalavarthi 			f16_m = m_16 & FP16_MASK_M;
4089637de38SSrikanth Yalavarthi 		} else if (be_16 == -(int)(FP16_MSB_M + 1)) {
4099637de38SSrikanth Yalavarthi 			/* underflow: zero, be_16 = [-10] */
4109637de38SSrikanth Yalavarthi 			f16_e = 0;
4119637de38SSrikanth Yalavarthi 			if (f32_m != 0)
4129637de38SSrikanth Yalavarthi 				f16_m = 1;
4139637de38SSrikanth Yalavarthi 			else
4149637de38SSrikanth Yalavarthi 				f16_m = 0;
4159637de38SSrikanth Yalavarthi 		} else {
4169637de38SSrikanth Yalavarthi 			/* underflow: zero, be_16 = [-INF:-11] */
4179637de38SSrikanth Yalavarthi 			f16_e = 0;
4189637de38SSrikanth Yalavarthi 			f16_m = 0;
4199637de38SSrikanth Yalavarthi 		}
4209637de38SSrikanth Yalavarthi 
4219637de38SSrikanth Yalavarthi 		break;
4229637de38SSrikanth Yalavarthi 	}
4239637de38SSrikanth Yalavarthi 
4249637de38SSrikanth Yalavarthi 	u16 = FP16_PACK(f16_s, f16_e, f16_m);
4259637de38SSrikanth Yalavarthi 
4269637de38SSrikanth Yalavarthi 	return u16;
4279637de38SSrikanth Yalavarthi }
4289637de38SSrikanth Yalavarthi 
/* Convert an array of float32 values to float16 bit patterns.
 *
 * Every element is converted with __float32_to_float16_scalar_rtn().
 *
 * nb_elements : number of elements to convert, must be non-zero
 * input       : buffer of nb_elements float values
 * output      : buffer of nb_elements uint16_t values receiving the float16 bit patterns
 *
 * Returns 0 on success, -EINVAL on invalid arguments.
 */
int
rte_ml_io_float32_to_float16(uint64_t nb_elements, void *input, void *output)
{
	const float *input_buffer;
	uint16_t *output_buffer;
	uint64_t i;

	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (const float *)input;
	output_buffer = (uint16_t *)output;

	for (i = 0; i < nb_elements; i++)
		output_buffer[i] = __float32_to_float16_scalar_rtn(input_buffer[i]);

	return 0;
}
4519637de38SSrikanth Yalavarthi 
4529637de38SSrikanth Yalavarthi /* Convert a half precision floating point number (float16) into a single precision
4539637de38SSrikanth Yalavarthi  * floating point number (float32).
4549637de38SSrikanth Yalavarthi  */
4559637de38SSrikanth Yalavarthi static float
4569637de38SSrikanth Yalavarthi __float16_to_float32_scalar_rtx(uint16_t f16)
4579637de38SSrikanth Yalavarthi {
4589637de38SSrikanth Yalavarthi 	union float32 f32; /* float32 output */
4599637de38SSrikanth Yalavarthi 	uint16_t f16_s;	   /* float16 sign */
4609637de38SSrikanth Yalavarthi 	uint16_t f16_e;	   /* float16 exponent */
4619637de38SSrikanth Yalavarthi 	uint16_t f16_m;	   /* float16 mantissa */
4629637de38SSrikanth Yalavarthi 	uint32_t f32_s;	   /* float32 sign */
4639637de38SSrikanth Yalavarthi 	uint32_t f32_e;	   /* float32 exponent */
4649637de38SSrikanth Yalavarthi 	uint32_t f32_m;	   /* float32 mantissa*/
4659637de38SSrikanth Yalavarthi 	uint8_t shift;	   /* number of bits to be shifted */
4669637de38SSrikanth Yalavarthi 	uint32_t clz;	   /* count of leading zeroes */
4679637de38SSrikanth Yalavarthi 	int e_16;	   /* float16 exponent unbiased */
4689637de38SSrikanth Yalavarthi 
4699637de38SSrikanth Yalavarthi 	f16_s = (f16 & FP16_MASK_S) >> FP16_LSB_S;
4709637de38SSrikanth Yalavarthi 	f16_e = (f16 & FP16_MASK_E) >> FP16_LSB_E;
4719637de38SSrikanth Yalavarthi 	f16_m = (f16 & FP16_MASK_M) >> FP16_LSB_M;
4729637de38SSrikanth Yalavarthi 
4739637de38SSrikanth Yalavarthi 	f32_s = f16_s;
4749637de38SSrikanth Yalavarthi 	switch (f16_e) {
4759637de38SSrikanth Yalavarthi 	case (FP16_MASK_E >> FP16_LSB_E): /* float16: infinity or nan */
4769637de38SSrikanth Yalavarthi 		f32_e = FP32_MASK_E >> FP32_LSB_E;
4779637de38SSrikanth Yalavarthi 		if (f16_m == 0x0) { /* infinity */
4789637de38SSrikanth Yalavarthi 			f32_m = f16_m;
4799637de38SSrikanth Yalavarthi 		} else { /* nan, propagate mantissa, set MSB of mantissa to 1 */
4809637de38SSrikanth Yalavarthi 			f32_m = f16_m;
4819637de38SSrikanth Yalavarthi 			shift = FP32_MSB_M - FP16_MSB_M;
4829637de38SSrikanth Yalavarthi 			f32_m = (f32_m << shift) & FP32_MASK_M;
4839637de38SSrikanth Yalavarthi 			f32_m |= BIT(FP32_MSB_M);
4849637de38SSrikanth Yalavarthi 		}
4859637de38SSrikanth Yalavarthi 		break;
4869637de38SSrikanth Yalavarthi 	case 0: /* float16: zero or sub-normal */
4879637de38SSrikanth Yalavarthi 		f32_m = f16_m;
4889637de38SSrikanth Yalavarthi 		if (f16_m == 0) { /* zero signed */
4899637de38SSrikanth Yalavarthi 			f32_e = 0;
4909637de38SSrikanth Yalavarthi 		} else { /* subnormal numbers */
4919637de38SSrikanth Yalavarthi 			clz = __builtin_clz((uint32_t)f16_m) - sizeof(uint32_t) * 8 + FP16_LSB_E;
4929637de38SSrikanth Yalavarthi 			e_16 = (int)f16_e - clz;
4939637de38SSrikanth Yalavarthi 			f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E;
4949637de38SSrikanth Yalavarthi 
4959637de38SSrikanth Yalavarthi 			shift = clz + (FP32_MSB_M - FP16_MSB_M) + 1;
4969637de38SSrikanth Yalavarthi 			f32_m = (f32_m << shift) & FP32_MASK_M;
4979637de38SSrikanth Yalavarthi 		}
4989637de38SSrikanth Yalavarthi 		break;
4999637de38SSrikanth Yalavarthi 	default: /* normal numbers */
5009637de38SSrikanth Yalavarthi 		f32_m = f16_m;
5019637de38SSrikanth Yalavarthi 		e_16 = (int)f16_e;
5029637de38SSrikanth Yalavarthi 		f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E;
5039637de38SSrikanth Yalavarthi 
5049637de38SSrikanth Yalavarthi 		shift = (FP32_MSB_M - FP16_MSB_M);
5059637de38SSrikanth Yalavarthi 		f32_m = (f32_m << shift) & FP32_MASK_M;
5069637de38SSrikanth Yalavarthi 	}
5079637de38SSrikanth Yalavarthi 
5089637de38SSrikanth Yalavarthi 	f32.u = FP32_PACK(f32_s, f32_e, f32_m);
5099637de38SSrikanth Yalavarthi 
5109637de38SSrikanth Yalavarthi 	return f32.f;
5119637de38SSrikanth Yalavarthi }
5129637de38SSrikanth Yalavarthi 
/* Convert an array of float16 bit patterns to float32 values.
 *
 * Every element is converted with __float16_to_float32_scalar_rtx().
 *
 * nb_elements : number of elements to convert, must be non-zero
 * input       : buffer of nb_elements uint16_t values holding float16 bit patterns
 * output      : buffer of nb_elements float values
 *
 * Returns 0 on success, -EINVAL on invalid arguments.
 */
int
rte_ml_io_float16_to_float32(uint64_t nb_elements, void *input, void *output)
{
	const uint16_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (const uint16_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++)
		output_buffer[i] = __float16_to_float32_scalar_rtx(input_buffer[i]);

	return 0;
}
5359637de38SSrikanth Yalavarthi 
5369637de38SSrikanth Yalavarthi /* Convert a single precision floating point number (float32) into a
5379637de38SSrikanth Yalavarthi  * brain float number (bfloat16) using round to nearest rounding mode.
5389637de38SSrikanth Yalavarthi  */
5399637de38SSrikanth Yalavarthi static uint16_t
5409637de38SSrikanth Yalavarthi __float32_to_bfloat16_scalar_rtn(float x)
5419637de38SSrikanth Yalavarthi {
5429637de38SSrikanth Yalavarthi 	union float32 f32; /* float32 input */
5439637de38SSrikanth Yalavarthi 	uint32_t f32_s;	   /* float32 sign */
5449637de38SSrikanth Yalavarthi 	uint32_t f32_e;	   /* float32 exponent */
5459637de38SSrikanth Yalavarthi 	uint32_t f32_m;	   /* float32 mantissa */
5469637de38SSrikanth Yalavarthi 	uint16_t b16_s;	   /* float16 sign */
5479637de38SSrikanth Yalavarthi 	uint16_t b16_e;	   /* float16 exponent */
5489637de38SSrikanth Yalavarthi 	uint16_t b16_m;	   /* float16 mantissa */
5499637de38SSrikanth Yalavarthi 	uint32_t tbits;	   /* number of truncated bits */
5509637de38SSrikanth Yalavarthi 	uint16_t u16;	   /* float16 output */
5519637de38SSrikanth Yalavarthi 
5529637de38SSrikanth Yalavarthi 	f32.f = x;
5539637de38SSrikanth Yalavarthi 	f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
5549637de38SSrikanth Yalavarthi 	f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
5559637de38SSrikanth Yalavarthi 	f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;
5569637de38SSrikanth Yalavarthi 
5579637de38SSrikanth Yalavarthi 	b16_s = f32_s;
5589637de38SSrikanth Yalavarthi 	b16_e = 0;
5599637de38SSrikanth Yalavarthi 	b16_m = 0;
5609637de38SSrikanth Yalavarthi 
5619637de38SSrikanth Yalavarthi 	switch (f32_e) {
5629637de38SSrikanth Yalavarthi 	case (0): /* float32: zero or subnormal number */
5639637de38SSrikanth Yalavarthi 		b16_e = 0;
5649637de38SSrikanth Yalavarthi 		if (f32_m == 0) /* zero */
5659637de38SSrikanth Yalavarthi 			b16_m = 0;
5669637de38SSrikanth Yalavarthi 		else /* subnormal float32 number, normal bfloat16 */
5679637de38SSrikanth Yalavarthi 			goto bf16_normal;
5689637de38SSrikanth Yalavarthi 		break;
5699637de38SSrikanth Yalavarthi 	case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
5709637de38SSrikanth Yalavarthi 		b16_e = BF16_MASK_E >> BF16_LSB_E;
5719637de38SSrikanth Yalavarthi 		if (f32_m == 0) { /* infinity */
5729637de38SSrikanth Yalavarthi 			b16_m = 0;
5739637de38SSrikanth Yalavarthi 		} else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
5749637de38SSrikanth Yalavarthi 			b16_m = f32_m >> (FP32_MSB_M - BF16_MSB_M);
5759637de38SSrikanth Yalavarthi 			b16_m |= BIT(BF16_MSB_M);
5769637de38SSrikanth Yalavarthi 		}
5779637de38SSrikanth Yalavarthi 		break;
5789637de38SSrikanth Yalavarthi 	default: /* float32: normal number, normal bfloat16 */
5799637de38SSrikanth Yalavarthi 		goto bf16_normal;
5809637de38SSrikanth Yalavarthi 	}
5819637de38SSrikanth Yalavarthi 
5829637de38SSrikanth Yalavarthi 	goto bf16_pack;
5839637de38SSrikanth Yalavarthi 
5849637de38SSrikanth Yalavarthi bf16_normal:
5859637de38SSrikanth Yalavarthi 	b16_e = f32_e;
5869637de38SSrikanth Yalavarthi 	tbits = FP32_MSB_M - BF16_MSB_M;
5879637de38SSrikanth Yalavarthi 	b16_m = f32_m >> tbits;
5889637de38SSrikanth Yalavarthi 
5899637de38SSrikanth Yalavarthi 	/* if non-leading truncated bits are set */
5909637de38SSrikanth Yalavarthi 	if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) {
5919637de38SSrikanth Yalavarthi 		b16_m++;
5929637de38SSrikanth Yalavarthi 
5939637de38SSrikanth Yalavarthi 		/* if overflow into exponent */
5949637de38SSrikanth Yalavarthi 		if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
5959637de38SSrikanth Yalavarthi 			b16_e++;
5969637de38SSrikanth Yalavarthi 	} else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) {
5979637de38SSrikanth Yalavarthi 		/* if only leading truncated bit is set */
5989637de38SSrikanth Yalavarthi 		if ((b16_m & 0x1) == 0x1) {
5999637de38SSrikanth Yalavarthi 			b16_m++;
6009637de38SSrikanth Yalavarthi 
6019637de38SSrikanth Yalavarthi 			/* if overflow into exponent */
6029637de38SSrikanth Yalavarthi 			if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
6039637de38SSrikanth Yalavarthi 				b16_e++;
6049637de38SSrikanth Yalavarthi 		}
6059637de38SSrikanth Yalavarthi 	}
6069637de38SSrikanth Yalavarthi 	b16_m = b16_m & BF16_MASK_M;
6079637de38SSrikanth Yalavarthi 
6089637de38SSrikanth Yalavarthi bf16_pack:
6099637de38SSrikanth Yalavarthi 	u16 = BF16_PACK(b16_s, b16_e, b16_m);
6109637de38SSrikanth Yalavarthi 
6119637de38SSrikanth Yalavarthi 	return u16;
6129637de38SSrikanth Yalavarthi }
6139637de38SSrikanth Yalavarthi 
/* Convert an array of float32 values into bfloat16 values.
 *
 * Reads nb_elements floats from input and writes one 16-bit result per
 * element to output, using __float32_to_bfloat16_scalar_rtn() for each.
 *
 * Returns 0 on success, or -EINVAL when nb_elements is zero or when either
 * buffer pointer is NULL.
 */
int
rte_ml_io_float32_to_bfloat16(uint64_t nb_elements, void *input, void *output)
{
	const float *src = input;
	uint16_t *dst = output;
	uint64_t idx;

	if (nb_elements == 0 || input == NULL || output == NULL)
		return -EINVAL;

	for (idx = 0; idx < nb_elements; idx++)
		dst[idx] = __float32_to_bfloat16_scalar_rtn(src[idx]);

	return 0;
}
6369637de38SSrikanth Yalavarthi 
6379637de38SSrikanth Yalavarthi /* Convert a brain float number (bfloat16) into a
6389637de38SSrikanth Yalavarthi  * single precision floating point number (float32).
6399637de38SSrikanth Yalavarthi  */
6409637de38SSrikanth Yalavarthi static float
6419637de38SSrikanth Yalavarthi __bfloat16_to_float32_scalar_rtx(uint16_t f16)
6429637de38SSrikanth Yalavarthi {
6439637de38SSrikanth Yalavarthi 	union float32 f32; /* float32 output */
6449637de38SSrikanth Yalavarthi 	uint16_t b16_s;	   /* float16 sign */
6459637de38SSrikanth Yalavarthi 	uint16_t b16_e;	   /* float16 exponent */
6469637de38SSrikanth Yalavarthi 	uint16_t b16_m;	   /* float16 mantissa */
6479637de38SSrikanth Yalavarthi 	uint32_t f32_s;	   /* float32 sign */
6489637de38SSrikanth Yalavarthi 	uint32_t f32_e;	   /* float32 exponent */
6499637de38SSrikanth Yalavarthi 	uint32_t f32_m;	   /* float32 mantissa*/
6509637de38SSrikanth Yalavarthi 	uint8_t shift;	   /* number of bits to be shifted */
6519637de38SSrikanth Yalavarthi 
6529637de38SSrikanth Yalavarthi 	b16_s = (f16 & BF16_MASK_S) >> BF16_LSB_S;
6539637de38SSrikanth Yalavarthi 	b16_e = (f16 & BF16_MASK_E) >> BF16_LSB_E;
6549637de38SSrikanth Yalavarthi 	b16_m = (f16 & BF16_MASK_M) >> BF16_LSB_M;
6559637de38SSrikanth Yalavarthi 
6569637de38SSrikanth Yalavarthi 	f32_s = b16_s;
6579637de38SSrikanth Yalavarthi 	switch (b16_e) {
6589637de38SSrikanth Yalavarthi 	case (BF16_MASK_E >> BF16_LSB_E): /* bfloat16: infinity or nan */
6599637de38SSrikanth Yalavarthi 		f32_e = FP32_MASK_E >> FP32_LSB_E;
6609637de38SSrikanth Yalavarthi 		if (b16_m == 0x0) { /* infinity */
6619637de38SSrikanth Yalavarthi 			f32_m = 0;
6629637de38SSrikanth Yalavarthi 		} else { /* nan, propagate mantissa, set MSB of mantissa to 1 */
6639637de38SSrikanth Yalavarthi 			f32_m = b16_m;
6649637de38SSrikanth Yalavarthi 			shift = FP32_MSB_M - BF16_MSB_M;
6659637de38SSrikanth Yalavarthi 			f32_m = (f32_m << shift) & FP32_MASK_M;
6669637de38SSrikanth Yalavarthi 			f32_m |= BIT(FP32_MSB_M);
6679637de38SSrikanth Yalavarthi 		}
6689637de38SSrikanth Yalavarthi 		break;
6699637de38SSrikanth Yalavarthi 	case 0: /* bfloat16: zero or subnormal */
6709637de38SSrikanth Yalavarthi 		f32_m = b16_m;
6719637de38SSrikanth Yalavarthi 		if (b16_m == 0) { /* zero signed */
6729637de38SSrikanth Yalavarthi 			f32_e = 0;
6739637de38SSrikanth Yalavarthi 		} else { /* subnormal numbers */
6749637de38SSrikanth Yalavarthi 			goto fp32_normal;
6759637de38SSrikanth Yalavarthi 		}
6769637de38SSrikanth Yalavarthi 		break;
6779637de38SSrikanth Yalavarthi 	default: /* bfloat16: normal number */
6789637de38SSrikanth Yalavarthi 		goto fp32_normal;
6799637de38SSrikanth Yalavarthi 	}
6809637de38SSrikanth Yalavarthi 
6819637de38SSrikanth Yalavarthi 	goto fp32_pack;
6829637de38SSrikanth Yalavarthi 
6839637de38SSrikanth Yalavarthi fp32_normal:
6849637de38SSrikanth Yalavarthi 	f32_m = b16_m;
6859637de38SSrikanth Yalavarthi 	f32_e = FP32_BIAS_E + b16_e - BF16_BIAS_E;
6869637de38SSrikanth Yalavarthi 
6879637de38SSrikanth Yalavarthi 	shift = (FP32_MSB_M - BF16_MSB_M);
6889637de38SSrikanth Yalavarthi 	f32_m = (f32_m << shift) & FP32_MASK_M;
6899637de38SSrikanth Yalavarthi 
6909637de38SSrikanth Yalavarthi fp32_pack:
6919637de38SSrikanth Yalavarthi 	f32.u = FP32_PACK(f32_s, f32_e, f32_m);
6929637de38SSrikanth Yalavarthi 
6939637de38SSrikanth Yalavarthi 	return f32.f;
6949637de38SSrikanth Yalavarthi }
6959637de38SSrikanth Yalavarthi 
/* Convert an array of bfloat16 values into float32 values.
 *
 * Reads nb_elements 16-bit values from input and writes one float per
 * element to output, using __bfloat16_to_float32_scalar_rtx() for each.
 *
 * Returns 0 on success, or -EINVAL when nb_elements is zero or when either
 * buffer pointer is NULL.
 */
int
rte_ml_io_bfloat16_to_float32(uint64_t nb_elements, void *input, void *output)
{
	const uint16_t *src = input;
	float *dst = output;
	uint64_t idx;

	if (nb_elements == 0 || input == NULL || output == NULL)
		return -EINVAL;

	for (idx = 0; idx < nb_elements; idx++)
		dst[idx] = __bfloat16_to_float32_scalar_rtx(src[idx]);

	return 0;
}
718