xref: /dpdk/lib/mldev/mldev_utils_scalar.c (revision 9637de38a2e361cea6258b85880de0eac69a5d0a)
1*9637de38SSrikanth Yalavarthi /* SPDX-License-Identifier: BSD-3-Clause
2*9637de38SSrikanth Yalavarthi  * Copyright (c) 2022 Marvell.
3*9637de38SSrikanth Yalavarthi  */
4*9637de38SSrikanth Yalavarthi 
5*9637de38SSrikanth Yalavarthi #include <errno.h>
6*9637de38SSrikanth Yalavarthi #include <math.h>
7*9637de38SSrikanth Yalavarthi #include <stdint.h>
8*9637de38SSrikanth Yalavarthi 
9*9637de38SSrikanth Yalavarthi #include "mldev_utils.h"
10*9637de38SSrikanth Yalavarthi 
11*9637de38SSrikanth Yalavarthi /* Description:
12*9637de38SSrikanth Yalavarthi  * This file implements scalar versions of Machine Learning utility functions used to convert data
13*9637de38SSrikanth Yalavarthi  * types from higher precision to lower precision and vice-versa.
14*9637de38SSrikanth Yalavarthi  */
15*9637de38SSrikanth Yalavarthi 
/* BIT(nr): mask with only bit nr set; the expansion has type unsigned long. */
16*9637de38SSrikanth Yalavarthi #ifndef BIT
17*9637de38SSrikanth Yalavarthi #define BIT(nr) (1UL << (nr))
18*9637de38SSrikanth Yalavarthi #endif
19*9637de38SSrikanth Yalavarthi 
/* Width of unsigned long in bits. __SIZEOF_LONG__ is predefined by GCC and Clang;
 * with other compilers BITS_PER_LONG has to be provided before this point.
 */
20*9637de38SSrikanth Yalavarthi #ifndef BITS_PER_LONG
21*9637de38SSrikanth Yalavarthi #define BITS_PER_LONG (__SIZEOF_LONG__ * 8)
22*9637de38SSrikanth Yalavarthi #endif
23*9637de38SSrikanth Yalavarthi 
/* GENMASK_U32(h, l): contiguous mask with bits l..h (inclusive) set,
 * e.g. GENMASK_U32(12, 0) == 0x1fff. Despite the _U32 suffix the expansion is
 * evaluated in unsigned long (see BITS_PER_LONG); as the name suggests it is
 * meant for 0 <= l <= h <= 31. Both arguments may be run-time expressions.
 */
24*9637de38SSrikanth Yalavarthi #ifndef GENMASK_U32
25*9637de38SSrikanth Yalavarthi #define GENMASK_U32(h, l) (((~0UL) << (l)) & (~0UL >> (BITS_PER_LONG - 1 - (h))))
26*9637de38SSrikanth Yalavarthi #endif
27*9637de38SSrikanth Yalavarthi 
/* Field layouts described by the macros below (bit index ranges, MSB..LSB):
 *   float32 : sign [31]  exponent [30:23]  mantissa [22:0]
 *   float16 : sign [15]  exponent [14:10]  mantissa [9:0]
 *   bfloat16: sign [15]  exponent [14:7]   mantissa [6:0]
 * The *_MASK_* macros select a field in place (not shifted down to bit 0); they
 * expand through GENMASK_U32 and therefore have type unsigned long.
 */
28*9637de38SSrikanth Yalavarthi /* float32: bit index of MSB & LSB of sign, exponent and mantissa */
29*9637de38SSrikanth Yalavarthi #define FP32_LSB_M 0
30*9637de38SSrikanth Yalavarthi #define FP32_MSB_M 22
31*9637de38SSrikanth Yalavarthi #define FP32_LSB_E 23
32*9637de38SSrikanth Yalavarthi #define FP32_MSB_E 30
33*9637de38SSrikanth Yalavarthi #define FP32_LSB_S 31
34*9637de38SSrikanth Yalavarthi #define FP32_MSB_S 31
35*9637de38SSrikanth Yalavarthi 
36*9637de38SSrikanth Yalavarthi /* float32: bitmask for sign, exponent and mantissa */
37*9637de38SSrikanth Yalavarthi #define FP32_MASK_S GENMASK_U32(FP32_MSB_S, FP32_LSB_S)
38*9637de38SSrikanth Yalavarthi #define FP32_MASK_E GENMASK_U32(FP32_MSB_E, FP32_LSB_E)
39*9637de38SSrikanth Yalavarthi #define FP32_MASK_M GENMASK_U32(FP32_MSB_M, FP32_LSB_M)
40*9637de38SSrikanth Yalavarthi 
41*9637de38SSrikanth Yalavarthi /* float16: bit index of MSB & LSB of sign, exponent and mantissa */
42*9637de38SSrikanth Yalavarthi #define FP16_LSB_M 0
43*9637de38SSrikanth Yalavarthi #define FP16_MSB_M 9
44*9637de38SSrikanth Yalavarthi #define FP16_LSB_E 10
45*9637de38SSrikanth Yalavarthi #define FP16_MSB_E 14
46*9637de38SSrikanth Yalavarthi #define FP16_LSB_S 15
47*9637de38SSrikanth Yalavarthi #define FP16_MSB_S 15
48*9637de38SSrikanth Yalavarthi 
49*9637de38SSrikanth Yalavarthi /* float16: bitmask for sign, exponent and mantissa */
50*9637de38SSrikanth Yalavarthi #define FP16_MASK_S GENMASK_U32(FP16_MSB_S, FP16_LSB_S)
51*9637de38SSrikanth Yalavarthi #define FP16_MASK_E GENMASK_U32(FP16_MSB_E, FP16_LSB_E)
52*9637de38SSrikanth Yalavarthi #define FP16_MASK_M GENMASK_U32(FP16_MSB_M, FP16_LSB_M)
53*9637de38SSrikanth Yalavarthi 
54*9637de38SSrikanth Yalavarthi /* bfloat16: bit index of MSB & LSB of sign, exponent and mantissa */
55*9637de38SSrikanth Yalavarthi #define BF16_LSB_M 0
56*9637de38SSrikanth Yalavarthi #define BF16_MSB_M 6
57*9637de38SSrikanth Yalavarthi #define BF16_LSB_E 7
58*9637de38SSrikanth Yalavarthi #define BF16_MSB_E 14
59*9637de38SSrikanth Yalavarthi #define BF16_LSB_S 15
60*9637de38SSrikanth Yalavarthi #define BF16_MSB_S 15
61*9637de38SSrikanth Yalavarthi 
62*9637de38SSrikanth Yalavarthi /* bfloat16: bitmask for sign, exponent and mantissa */
63*9637de38SSrikanth Yalavarthi #define BF16_MASK_S GENMASK_U32(BF16_MSB_S, BF16_LSB_S)
64*9637de38SSrikanth Yalavarthi #define BF16_MASK_E GENMASK_U32(BF16_MSB_E, BF16_LSB_E)
65*9637de38SSrikanth Yalavarthi #define BF16_MASK_M GENMASK_U32(BF16_MSB_M, BF16_LSB_M)
66*9637de38SSrikanth Yalavarthi 
67*9637de38SSrikanth Yalavarthi /* Exponent bias */
/* stored (biased) exponent = unbiased exponent + bias */
68*9637de38SSrikanth Yalavarthi #define FP32_BIAS_E 127
69*9637de38SSrikanth Yalavarthi #define FP16_BIAS_E 15
70*9637de38SSrikanth Yalavarthi #define BF16_BIAS_E 127
71*9637de38SSrikanth Yalavarthi 
/* *_PACK(sign, exponent, mantissa): assemble a raw bit pattern from right-aligned
 * fields. The fields are shifted into place but NOT masked, so every argument
 * must already fit in its field width.
 */
72*9637de38SSrikanth Yalavarthi #define FP32_PACK(sign, exponent, mantissa)                                                        \
73*9637de38SSrikanth Yalavarthi 	(((sign) << FP32_LSB_S) | ((exponent) << FP32_LSB_E) | (mantissa))
74*9637de38SSrikanth Yalavarthi 
75*9637de38SSrikanth Yalavarthi #define FP16_PACK(sign, exponent, mantissa)                                                        \
76*9637de38SSrikanth Yalavarthi 	(((sign) << FP16_LSB_S) | ((exponent) << FP16_LSB_E) | (mantissa))
77*9637de38SSrikanth Yalavarthi 
78*9637de38SSrikanth Yalavarthi #define BF16_PACK(sign, exponent, mantissa)                                                        \
79*9637de38SSrikanth Yalavarthi 	(((sign) << BF16_LSB_S) | ((exponent) << BF16_LSB_E) | (mantissa))
80*9637de38SSrikanth Yalavarthi 
/* Type-punning helper: write one member and read the other to reinterpret a
 * float's bits as a uint32_t (or vice versa). Reading a union member other than
 * the one last written is defined in C (the bytes are reinterpreted), unlike
 * punning through a cast pointer. The FP32_* field macros assume float is a
 * 32-bit IEEE-754 binary32 with the sign in bit 31.
 */
81*9637de38SSrikanth Yalavarthi /* Represent float32 as float and uint32_t */
82*9637de38SSrikanth Yalavarthi union float32 {
83*9637de38SSrikanth Yalavarthi 	float f;
84*9637de38SSrikanth Yalavarthi 	uint32_t u;
85*9637de38SSrikanth Yalavarthi };
86*9637de38SSrikanth Yalavarthi 
87*9637de38SSrikanth Yalavarthi __rte_weak int
88*9637de38SSrikanth Yalavarthi rte_ml_io_float32_to_int8(float scale, uint64_t nb_elements, void *input, void *output)
89*9637de38SSrikanth Yalavarthi {
90*9637de38SSrikanth Yalavarthi 	float *input_buffer;
91*9637de38SSrikanth Yalavarthi 	int8_t *output_buffer;
92*9637de38SSrikanth Yalavarthi 	uint64_t i;
93*9637de38SSrikanth Yalavarthi 	int i32;
94*9637de38SSrikanth Yalavarthi 
95*9637de38SSrikanth Yalavarthi 	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
96*9637de38SSrikanth Yalavarthi 		return -EINVAL;
97*9637de38SSrikanth Yalavarthi 
98*9637de38SSrikanth Yalavarthi 	input_buffer = (float *)input;
99*9637de38SSrikanth Yalavarthi 	output_buffer = (int8_t *)output;
100*9637de38SSrikanth Yalavarthi 
101*9637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
102*9637de38SSrikanth Yalavarthi 		i32 = (int32_t)round((*input_buffer) * scale);
103*9637de38SSrikanth Yalavarthi 
104*9637de38SSrikanth Yalavarthi 		if (i32 < INT8_MIN)
105*9637de38SSrikanth Yalavarthi 			i32 = INT8_MIN;
106*9637de38SSrikanth Yalavarthi 
107*9637de38SSrikanth Yalavarthi 		if (i32 > INT8_MAX)
108*9637de38SSrikanth Yalavarthi 			i32 = INT8_MAX;
109*9637de38SSrikanth Yalavarthi 
110*9637de38SSrikanth Yalavarthi 		*output_buffer = (int8_t)i32;
111*9637de38SSrikanth Yalavarthi 
112*9637de38SSrikanth Yalavarthi 		input_buffer++;
113*9637de38SSrikanth Yalavarthi 		output_buffer++;
114*9637de38SSrikanth Yalavarthi 	}
115*9637de38SSrikanth Yalavarthi 
116*9637de38SSrikanth Yalavarthi 	return 0;
117*9637de38SSrikanth Yalavarthi }
118*9637de38SSrikanth Yalavarthi 
119*9637de38SSrikanth Yalavarthi __rte_weak int
120*9637de38SSrikanth Yalavarthi rte_ml_io_int8_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
121*9637de38SSrikanth Yalavarthi {
122*9637de38SSrikanth Yalavarthi 	int8_t *input_buffer;
123*9637de38SSrikanth Yalavarthi 	float *output_buffer;
124*9637de38SSrikanth Yalavarthi 	uint64_t i;
125*9637de38SSrikanth Yalavarthi 
126*9637de38SSrikanth Yalavarthi 	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
127*9637de38SSrikanth Yalavarthi 		return -EINVAL;
128*9637de38SSrikanth Yalavarthi 
129*9637de38SSrikanth Yalavarthi 	input_buffer = (int8_t *)input;
130*9637de38SSrikanth Yalavarthi 	output_buffer = (float *)output;
131*9637de38SSrikanth Yalavarthi 
132*9637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
133*9637de38SSrikanth Yalavarthi 		*output_buffer = scale * (float)(*input_buffer);
134*9637de38SSrikanth Yalavarthi 
135*9637de38SSrikanth Yalavarthi 		input_buffer++;
136*9637de38SSrikanth Yalavarthi 		output_buffer++;
137*9637de38SSrikanth Yalavarthi 	}
138*9637de38SSrikanth Yalavarthi 
139*9637de38SSrikanth Yalavarthi 	return 0;
140*9637de38SSrikanth Yalavarthi }
141*9637de38SSrikanth Yalavarthi 
142*9637de38SSrikanth Yalavarthi __rte_weak int
143*9637de38SSrikanth Yalavarthi rte_ml_io_float32_to_uint8(float scale, uint64_t nb_elements, void *input, void *output)
144*9637de38SSrikanth Yalavarthi {
145*9637de38SSrikanth Yalavarthi 	float *input_buffer;
146*9637de38SSrikanth Yalavarthi 	uint8_t *output_buffer;
147*9637de38SSrikanth Yalavarthi 	int32_t i32;
148*9637de38SSrikanth Yalavarthi 	uint64_t i;
149*9637de38SSrikanth Yalavarthi 
150*9637de38SSrikanth Yalavarthi 	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
151*9637de38SSrikanth Yalavarthi 		return -EINVAL;
152*9637de38SSrikanth Yalavarthi 
153*9637de38SSrikanth Yalavarthi 	input_buffer = (float *)input;
154*9637de38SSrikanth Yalavarthi 	output_buffer = (uint8_t *)output;
155*9637de38SSrikanth Yalavarthi 
156*9637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
157*9637de38SSrikanth Yalavarthi 		i32 = (int32_t)round((*input_buffer) * scale);
158*9637de38SSrikanth Yalavarthi 
159*9637de38SSrikanth Yalavarthi 		if (i32 < 0)
160*9637de38SSrikanth Yalavarthi 			i32 = 0;
161*9637de38SSrikanth Yalavarthi 
162*9637de38SSrikanth Yalavarthi 		if (i32 > UINT8_MAX)
163*9637de38SSrikanth Yalavarthi 			i32 = UINT8_MAX;
164*9637de38SSrikanth Yalavarthi 
165*9637de38SSrikanth Yalavarthi 		*output_buffer = (uint8_t)i32;
166*9637de38SSrikanth Yalavarthi 
167*9637de38SSrikanth Yalavarthi 		input_buffer++;
168*9637de38SSrikanth Yalavarthi 		output_buffer++;
169*9637de38SSrikanth Yalavarthi 	}
170*9637de38SSrikanth Yalavarthi 
171*9637de38SSrikanth Yalavarthi 	return 0;
172*9637de38SSrikanth Yalavarthi }
173*9637de38SSrikanth Yalavarthi 
174*9637de38SSrikanth Yalavarthi __rte_weak int
175*9637de38SSrikanth Yalavarthi rte_ml_io_uint8_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
176*9637de38SSrikanth Yalavarthi {
177*9637de38SSrikanth Yalavarthi 	uint8_t *input_buffer;
178*9637de38SSrikanth Yalavarthi 	float *output_buffer;
179*9637de38SSrikanth Yalavarthi 	uint64_t i;
180*9637de38SSrikanth Yalavarthi 
181*9637de38SSrikanth Yalavarthi 	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
182*9637de38SSrikanth Yalavarthi 		return -EINVAL;
183*9637de38SSrikanth Yalavarthi 
184*9637de38SSrikanth Yalavarthi 	input_buffer = (uint8_t *)input;
185*9637de38SSrikanth Yalavarthi 	output_buffer = (float *)output;
186*9637de38SSrikanth Yalavarthi 
187*9637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
188*9637de38SSrikanth Yalavarthi 		*output_buffer = scale * (float)(*input_buffer);
189*9637de38SSrikanth Yalavarthi 
190*9637de38SSrikanth Yalavarthi 		input_buffer++;
191*9637de38SSrikanth Yalavarthi 		output_buffer++;
192*9637de38SSrikanth Yalavarthi 	}
193*9637de38SSrikanth Yalavarthi 
194*9637de38SSrikanth Yalavarthi 	return 0;
195*9637de38SSrikanth Yalavarthi }
196*9637de38SSrikanth Yalavarthi 
197*9637de38SSrikanth Yalavarthi __rte_weak int
198*9637de38SSrikanth Yalavarthi rte_ml_io_float32_to_int16(float scale, uint64_t nb_elements, void *input, void *output)
199*9637de38SSrikanth Yalavarthi {
200*9637de38SSrikanth Yalavarthi 	float *input_buffer;
201*9637de38SSrikanth Yalavarthi 	int16_t *output_buffer;
202*9637de38SSrikanth Yalavarthi 	int32_t i32;
203*9637de38SSrikanth Yalavarthi 	uint64_t i;
204*9637de38SSrikanth Yalavarthi 
205*9637de38SSrikanth Yalavarthi 	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
206*9637de38SSrikanth Yalavarthi 		return -EINVAL;
207*9637de38SSrikanth Yalavarthi 
208*9637de38SSrikanth Yalavarthi 	input_buffer = (float *)input;
209*9637de38SSrikanth Yalavarthi 	output_buffer = (int16_t *)output;
210*9637de38SSrikanth Yalavarthi 
211*9637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
212*9637de38SSrikanth Yalavarthi 		i32 = (int32_t)round((*input_buffer) * scale);
213*9637de38SSrikanth Yalavarthi 
214*9637de38SSrikanth Yalavarthi 		if (i32 < INT16_MIN)
215*9637de38SSrikanth Yalavarthi 			i32 = INT16_MIN;
216*9637de38SSrikanth Yalavarthi 
217*9637de38SSrikanth Yalavarthi 		if (i32 > INT16_MAX)
218*9637de38SSrikanth Yalavarthi 			i32 = INT16_MAX;
219*9637de38SSrikanth Yalavarthi 
220*9637de38SSrikanth Yalavarthi 		*output_buffer = (int16_t)i32;
221*9637de38SSrikanth Yalavarthi 
222*9637de38SSrikanth Yalavarthi 		input_buffer++;
223*9637de38SSrikanth Yalavarthi 		output_buffer++;
224*9637de38SSrikanth Yalavarthi 	}
225*9637de38SSrikanth Yalavarthi 
226*9637de38SSrikanth Yalavarthi 	return 0;
227*9637de38SSrikanth Yalavarthi }
228*9637de38SSrikanth Yalavarthi 
229*9637de38SSrikanth Yalavarthi __rte_weak int
230*9637de38SSrikanth Yalavarthi rte_ml_io_int16_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
231*9637de38SSrikanth Yalavarthi {
232*9637de38SSrikanth Yalavarthi 	int16_t *input_buffer;
233*9637de38SSrikanth Yalavarthi 	float *output_buffer;
234*9637de38SSrikanth Yalavarthi 	uint64_t i;
235*9637de38SSrikanth Yalavarthi 
236*9637de38SSrikanth Yalavarthi 	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
237*9637de38SSrikanth Yalavarthi 		return -EINVAL;
238*9637de38SSrikanth Yalavarthi 
239*9637de38SSrikanth Yalavarthi 	input_buffer = (int16_t *)input;
240*9637de38SSrikanth Yalavarthi 	output_buffer = (float *)output;
241*9637de38SSrikanth Yalavarthi 
242*9637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
243*9637de38SSrikanth Yalavarthi 		*output_buffer = scale * (float)(*input_buffer);
244*9637de38SSrikanth Yalavarthi 
245*9637de38SSrikanth Yalavarthi 		input_buffer++;
246*9637de38SSrikanth Yalavarthi 		output_buffer++;
247*9637de38SSrikanth Yalavarthi 	}
248*9637de38SSrikanth Yalavarthi 
249*9637de38SSrikanth Yalavarthi 	return 0;
250*9637de38SSrikanth Yalavarthi }
251*9637de38SSrikanth Yalavarthi 
252*9637de38SSrikanth Yalavarthi __rte_weak int
253*9637de38SSrikanth Yalavarthi rte_ml_io_float32_to_uint16(float scale, uint64_t nb_elements, void *input, void *output)
254*9637de38SSrikanth Yalavarthi {
255*9637de38SSrikanth Yalavarthi 	float *input_buffer;
256*9637de38SSrikanth Yalavarthi 	uint16_t *output_buffer;
257*9637de38SSrikanth Yalavarthi 	int32_t i32;
258*9637de38SSrikanth Yalavarthi 	uint64_t i;
259*9637de38SSrikanth Yalavarthi 
260*9637de38SSrikanth Yalavarthi 	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
261*9637de38SSrikanth Yalavarthi 		return -EINVAL;
262*9637de38SSrikanth Yalavarthi 
263*9637de38SSrikanth Yalavarthi 	input_buffer = (float *)input;
264*9637de38SSrikanth Yalavarthi 	output_buffer = (uint16_t *)output;
265*9637de38SSrikanth Yalavarthi 
266*9637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
267*9637de38SSrikanth Yalavarthi 		i32 = (int32_t)round((*input_buffer) * scale);
268*9637de38SSrikanth Yalavarthi 
269*9637de38SSrikanth Yalavarthi 		if (i32 < 0)
270*9637de38SSrikanth Yalavarthi 			i32 = 0;
271*9637de38SSrikanth Yalavarthi 
272*9637de38SSrikanth Yalavarthi 		if (i32 > UINT16_MAX)
273*9637de38SSrikanth Yalavarthi 			i32 = UINT16_MAX;
274*9637de38SSrikanth Yalavarthi 
275*9637de38SSrikanth Yalavarthi 		*output_buffer = (uint16_t)i32;
276*9637de38SSrikanth Yalavarthi 
277*9637de38SSrikanth Yalavarthi 		input_buffer++;
278*9637de38SSrikanth Yalavarthi 		output_buffer++;
279*9637de38SSrikanth Yalavarthi 	}
280*9637de38SSrikanth Yalavarthi 
281*9637de38SSrikanth Yalavarthi 	return 0;
282*9637de38SSrikanth Yalavarthi }
283*9637de38SSrikanth Yalavarthi 
284*9637de38SSrikanth Yalavarthi __rte_weak int
285*9637de38SSrikanth Yalavarthi rte_ml_io_uint16_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
286*9637de38SSrikanth Yalavarthi {
287*9637de38SSrikanth Yalavarthi 	uint16_t *input_buffer;
288*9637de38SSrikanth Yalavarthi 	float *output_buffer;
289*9637de38SSrikanth Yalavarthi 	uint64_t i;
290*9637de38SSrikanth Yalavarthi 
291*9637de38SSrikanth Yalavarthi 	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
292*9637de38SSrikanth Yalavarthi 		return -EINVAL;
293*9637de38SSrikanth Yalavarthi 
294*9637de38SSrikanth Yalavarthi 	input_buffer = (uint16_t *)input;
295*9637de38SSrikanth Yalavarthi 	output_buffer = (float *)output;
296*9637de38SSrikanth Yalavarthi 
297*9637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
298*9637de38SSrikanth Yalavarthi 		*output_buffer = scale * (float)(*input_buffer);
299*9637de38SSrikanth Yalavarthi 
300*9637de38SSrikanth Yalavarthi 		input_buffer++;
301*9637de38SSrikanth Yalavarthi 		output_buffer++;
302*9637de38SSrikanth Yalavarthi 	}
303*9637de38SSrikanth Yalavarthi 
304*9637de38SSrikanth Yalavarthi 	return 0;
305*9637de38SSrikanth Yalavarthi }
306*9637de38SSrikanth Yalavarthi 
307*9637de38SSrikanth Yalavarthi /* Convert a single precision floating point number (float32) into a half precision
308*9637de38SSrikanth Yalavarthi  * floating point number (float16) using round to nearest rounding mode.
309*9637de38SSrikanth Yalavarthi  */
310*9637de38SSrikanth Yalavarthi static uint16_t
311*9637de38SSrikanth Yalavarthi __float32_to_float16_scalar_rtn(float x)
312*9637de38SSrikanth Yalavarthi {
313*9637de38SSrikanth Yalavarthi 	union float32 f32; /* float32 input */
314*9637de38SSrikanth Yalavarthi 	uint32_t f32_s;	   /* float32 sign */
315*9637de38SSrikanth Yalavarthi 	uint32_t f32_e;	   /* float32 exponent */
316*9637de38SSrikanth Yalavarthi 	uint32_t f32_m;	   /* float32 mantissa */
317*9637de38SSrikanth Yalavarthi 	uint16_t f16_s;	   /* float16 sign */
318*9637de38SSrikanth Yalavarthi 	uint16_t f16_e;	   /* float16 exponent */
319*9637de38SSrikanth Yalavarthi 	uint16_t f16_m;	   /* float16 mantissa */
320*9637de38SSrikanth Yalavarthi 	uint32_t tbits;	   /* number of truncated bits */
321*9637de38SSrikanth Yalavarthi 	uint32_t tmsb;	   /* MSB position of truncated bits */
322*9637de38SSrikanth Yalavarthi 	uint32_t m_32;	   /* temporary float32 mantissa */
323*9637de38SSrikanth Yalavarthi 	uint16_t m_16;	   /* temporary float16 mantissa */
324*9637de38SSrikanth Yalavarthi 	uint16_t u16;	   /* float16 output */
325*9637de38SSrikanth Yalavarthi 	int be_16;	   /* float16 biased exponent, signed */
326*9637de38SSrikanth Yalavarthi 
327*9637de38SSrikanth Yalavarthi 	f32.f = x;
328*9637de38SSrikanth Yalavarthi 	f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
329*9637de38SSrikanth Yalavarthi 	f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
330*9637de38SSrikanth Yalavarthi 	f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;
331*9637de38SSrikanth Yalavarthi 
332*9637de38SSrikanth Yalavarthi 	f16_s = f32_s;
333*9637de38SSrikanth Yalavarthi 	f16_e = 0;
334*9637de38SSrikanth Yalavarthi 	f16_m = 0;
335*9637de38SSrikanth Yalavarthi 
336*9637de38SSrikanth Yalavarthi 	switch (f32_e) {
337*9637de38SSrikanth Yalavarthi 	case (0): /* float32: zero or subnormal number */
338*9637de38SSrikanth Yalavarthi 		f16_e = 0;
339*9637de38SSrikanth Yalavarthi 		if (f32_m == 0) /* zero */
340*9637de38SSrikanth Yalavarthi 			f16_m = 0;
341*9637de38SSrikanth Yalavarthi 		else /* subnormal number, convert to zero */
342*9637de38SSrikanth Yalavarthi 			f16_m = 0;
343*9637de38SSrikanth Yalavarthi 		break;
344*9637de38SSrikanth Yalavarthi 	case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
345*9637de38SSrikanth Yalavarthi 		f16_e = FP16_MASK_E >> FP16_LSB_E;
346*9637de38SSrikanth Yalavarthi 		if (f32_m == 0) { /* infinity */
347*9637de38SSrikanth Yalavarthi 			f16_m = 0;
348*9637de38SSrikanth Yalavarthi 		} else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
349*9637de38SSrikanth Yalavarthi 			f16_m = f32_m >> (FP32_MSB_M - FP16_MSB_M);
350*9637de38SSrikanth Yalavarthi 			f16_m |= BIT(FP16_MSB_M);
351*9637de38SSrikanth Yalavarthi 		}
352*9637de38SSrikanth Yalavarthi 		break;
353*9637de38SSrikanth Yalavarthi 	default: /* float32: normal number */
354*9637de38SSrikanth Yalavarthi 		/* compute biased exponent for float16 */
355*9637de38SSrikanth Yalavarthi 		be_16 = (int)f32_e - FP32_BIAS_E + FP16_BIAS_E;
356*9637de38SSrikanth Yalavarthi 
357*9637de38SSrikanth Yalavarthi 		/* overflow, be_16 = [31-INF], set to infinity */
358*9637de38SSrikanth Yalavarthi 		if (be_16 >= (int)(FP16_MASK_E >> FP16_LSB_E)) {
359*9637de38SSrikanth Yalavarthi 			f16_e = FP16_MASK_E >> FP16_LSB_E;
360*9637de38SSrikanth Yalavarthi 			f16_m = 0;
361*9637de38SSrikanth Yalavarthi 		} else if ((be_16 >= 1) && (be_16 < (int)(FP16_MASK_E >> FP16_LSB_E))) {
362*9637de38SSrikanth Yalavarthi 			/* normal float16, be_16 = [1:30]*/
363*9637de38SSrikanth Yalavarthi 			f16_e = be_16;
364*9637de38SSrikanth Yalavarthi 			m_16 = f32_m >> (FP32_LSB_E - FP16_LSB_E);
365*9637de38SSrikanth Yalavarthi 			tmsb = FP32_MSB_M - FP16_MSB_M - 1;
366*9637de38SSrikanth Yalavarthi 			if ((f32_m & GENMASK_U32(tmsb, 0)) > BIT(tmsb)) {
367*9637de38SSrikanth Yalavarthi 				/* round: non-zero truncated bits except MSB */
368*9637de38SSrikanth Yalavarthi 				m_16++;
369*9637de38SSrikanth Yalavarthi 
370*9637de38SSrikanth Yalavarthi 				/* overflow into exponent */
371*9637de38SSrikanth Yalavarthi 				if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
372*9637de38SSrikanth Yalavarthi 					f16_e++;
373*9637de38SSrikanth Yalavarthi 			} else if ((f32_m & GENMASK_U32(tmsb, 0)) == BIT(tmsb)) {
374*9637de38SSrikanth Yalavarthi 				/* round: MSB of truncated bits and LSB of m_16 is set */
375*9637de38SSrikanth Yalavarthi 				if ((m_16 & 0x1) == 0x1) {
376*9637de38SSrikanth Yalavarthi 					m_16++;
377*9637de38SSrikanth Yalavarthi 
378*9637de38SSrikanth Yalavarthi 					/* overflow into exponent */
379*9637de38SSrikanth Yalavarthi 					if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
380*9637de38SSrikanth Yalavarthi 						f16_e++;
381*9637de38SSrikanth Yalavarthi 				}
382*9637de38SSrikanth Yalavarthi 			}
383*9637de38SSrikanth Yalavarthi 			f16_m = m_16 & FP16_MASK_M;
384*9637de38SSrikanth Yalavarthi 		} else if ((be_16 >= -(int)(FP16_MSB_M)) && (be_16 < 1)) {
385*9637de38SSrikanth Yalavarthi 			/* underflow: zero / subnormal, be_16 = [-9:0] */
386*9637de38SSrikanth Yalavarthi 			f16_e = 0;
387*9637de38SSrikanth Yalavarthi 
388*9637de38SSrikanth Yalavarthi 			/* add implicit leading zero */
389*9637de38SSrikanth Yalavarthi 			m_32 = f32_m | BIT(FP32_LSB_E);
390*9637de38SSrikanth Yalavarthi 			tbits = FP32_LSB_E - FP16_LSB_E - be_16 + 1;
391*9637de38SSrikanth Yalavarthi 			m_16 = m_32 >> tbits;
392*9637de38SSrikanth Yalavarthi 
393*9637de38SSrikanth Yalavarthi 			/* if non-leading truncated bits are set */
394*9637de38SSrikanth Yalavarthi 			if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) {
395*9637de38SSrikanth Yalavarthi 				m_16++;
396*9637de38SSrikanth Yalavarthi 
397*9637de38SSrikanth Yalavarthi 				/* overflow into exponent */
398*9637de38SSrikanth Yalavarthi 				if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
399*9637de38SSrikanth Yalavarthi 					f16_e++;
400*9637de38SSrikanth Yalavarthi 			} else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) {
401*9637de38SSrikanth Yalavarthi 				/* if leading truncated bit is set */
402*9637de38SSrikanth Yalavarthi 				if ((m_16 & 0x1) == 0x1) {
403*9637de38SSrikanth Yalavarthi 					m_16++;
404*9637de38SSrikanth Yalavarthi 
405*9637de38SSrikanth Yalavarthi 					/* overflow into exponent */
406*9637de38SSrikanth Yalavarthi 					if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
407*9637de38SSrikanth Yalavarthi 						f16_e++;
408*9637de38SSrikanth Yalavarthi 				}
409*9637de38SSrikanth Yalavarthi 			}
410*9637de38SSrikanth Yalavarthi 			f16_m = m_16 & FP16_MASK_M;
411*9637de38SSrikanth Yalavarthi 		} else if (be_16 == -(int)(FP16_MSB_M + 1)) {
412*9637de38SSrikanth Yalavarthi 			/* underflow: zero, be_16 = [-10] */
413*9637de38SSrikanth Yalavarthi 			f16_e = 0;
414*9637de38SSrikanth Yalavarthi 			if (f32_m != 0)
415*9637de38SSrikanth Yalavarthi 				f16_m = 1;
416*9637de38SSrikanth Yalavarthi 			else
417*9637de38SSrikanth Yalavarthi 				f16_m = 0;
418*9637de38SSrikanth Yalavarthi 		} else {
419*9637de38SSrikanth Yalavarthi 			/* underflow: zero, be_16 = [-INF:-11] */
420*9637de38SSrikanth Yalavarthi 			f16_e = 0;
421*9637de38SSrikanth Yalavarthi 			f16_m = 0;
422*9637de38SSrikanth Yalavarthi 		}
423*9637de38SSrikanth Yalavarthi 
424*9637de38SSrikanth Yalavarthi 		break;
425*9637de38SSrikanth Yalavarthi 	}
426*9637de38SSrikanth Yalavarthi 
427*9637de38SSrikanth Yalavarthi 	u16 = FP16_PACK(f16_s, f16_e, f16_m);
428*9637de38SSrikanth Yalavarthi 
429*9637de38SSrikanth Yalavarthi 	return u16;
430*9637de38SSrikanth Yalavarthi }
431*9637de38SSrikanth Yalavarthi 
432*9637de38SSrikanth Yalavarthi __rte_weak int
433*9637de38SSrikanth Yalavarthi rte_ml_io_float32_to_float16(uint64_t nb_elements, void *input, void *output)
434*9637de38SSrikanth Yalavarthi {
435*9637de38SSrikanth Yalavarthi 	float *input_buffer;
436*9637de38SSrikanth Yalavarthi 	uint16_t *output_buffer;
437*9637de38SSrikanth Yalavarthi 	uint64_t i;
438*9637de38SSrikanth Yalavarthi 
439*9637de38SSrikanth Yalavarthi 	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
440*9637de38SSrikanth Yalavarthi 		return -EINVAL;
441*9637de38SSrikanth Yalavarthi 
442*9637de38SSrikanth Yalavarthi 	input_buffer = (float *)input;
443*9637de38SSrikanth Yalavarthi 	output_buffer = (uint16_t *)output;
444*9637de38SSrikanth Yalavarthi 
445*9637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
446*9637de38SSrikanth Yalavarthi 		*output_buffer = __float32_to_float16_scalar_rtn(*input_buffer);
447*9637de38SSrikanth Yalavarthi 
448*9637de38SSrikanth Yalavarthi 		input_buffer = input_buffer + 1;
449*9637de38SSrikanth Yalavarthi 		output_buffer = output_buffer + 1;
450*9637de38SSrikanth Yalavarthi 	}
451*9637de38SSrikanth Yalavarthi 
452*9637de38SSrikanth Yalavarthi 	return 0;
453*9637de38SSrikanth Yalavarthi }
454*9637de38SSrikanth Yalavarthi 
455*9637de38SSrikanth Yalavarthi /* Convert a half precision floating point number (float16) into a single precision
456*9637de38SSrikanth Yalavarthi  * floating point number (float32).
457*9637de38SSrikanth Yalavarthi  */
458*9637de38SSrikanth Yalavarthi static float
459*9637de38SSrikanth Yalavarthi __float16_to_float32_scalar_rtx(uint16_t f16)
460*9637de38SSrikanth Yalavarthi {
461*9637de38SSrikanth Yalavarthi 	union float32 f32; /* float32 output */
462*9637de38SSrikanth Yalavarthi 	uint16_t f16_s;	   /* float16 sign */
463*9637de38SSrikanth Yalavarthi 	uint16_t f16_e;	   /* float16 exponent */
464*9637de38SSrikanth Yalavarthi 	uint16_t f16_m;	   /* float16 mantissa */
465*9637de38SSrikanth Yalavarthi 	uint32_t f32_s;	   /* float32 sign */
466*9637de38SSrikanth Yalavarthi 	uint32_t f32_e;	   /* float32 exponent */
467*9637de38SSrikanth Yalavarthi 	uint32_t f32_m;	   /* float32 mantissa*/
468*9637de38SSrikanth Yalavarthi 	uint8_t shift;	   /* number of bits to be shifted */
469*9637de38SSrikanth Yalavarthi 	uint32_t clz;	   /* count of leading zeroes */
470*9637de38SSrikanth Yalavarthi 	int e_16;	   /* float16 exponent unbiased */
471*9637de38SSrikanth Yalavarthi 
472*9637de38SSrikanth Yalavarthi 	f16_s = (f16 & FP16_MASK_S) >> FP16_LSB_S;
473*9637de38SSrikanth Yalavarthi 	f16_e = (f16 & FP16_MASK_E) >> FP16_LSB_E;
474*9637de38SSrikanth Yalavarthi 	f16_m = (f16 & FP16_MASK_M) >> FP16_LSB_M;
475*9637de38SSrikanth Yalavarthi 
476*9637de38SSrikanth Yalavarthi 	f32_s = f16_s;
477*9637de38SSrikanth Yalavarthi 	switch (f16_e) {
478*9637de38SSrikanth Yalavarthi 	case (FP16_MASK_E >> FP16_LSB_E): /* float16: infinity or nan */
479*9637de38SSrikanth Yalavarthi 		f32_e = FP32_MASK_E >> FP32_LSB_E;
480*9637de38SSrikanth Yalavarthi 		if (f16_m == 0x0) { /* infinity */
481*9637de38SSrikanth Yalavarthi 			f32_m = f16_m;
482*9637de38SSrikanth Yalavarthi 		} else { /* nan, propagate mantissa, set MSB of mantissa to 1 */
483*9637de38SSrikanth Yalavarthi 			f32_m = f16_m;
484*9637de38SSrikanth Yalavarthi 			shift = FP32_MSB_M - FP16_MSB_M;
485*9637de38SSrikanth Yalavarthi 			f32_m = (f32_m << shift) & FP32_MASK_M;
486*9637de38SSrikanth Yalavarthi 			f32_m |= BIT(FP32_MSB_M);
487*9637de38SSrikanth Yalavarthi 		}
488*9637de38SSrikanth Yalavarthi 		break;
489*9637de38SSrikanth Yalavarthi 	case 0: /* float16: zero or sub-normal */
490*9637de38SSrikanth Yalavarthi 		f32_m = f16_m;
491*9637de38SSrikanth Yalavarthi 		if (f16_m == 0) { /* zero signed */
492*9637de38SSrikanth Yalavarthi 			f32_e = 0;
493*9637de38SSrikanth Yalavarthi 		} else { /* subnormal numbers */
494*9637de38SSrikanth Yalavarthi 			clz = __builtin_clz((uint32_t)f16_m) - sizeof(uint32_t) * 8 + FP16_LSB_E;
495*9637de38SSrikanth Yalavarthi 			e_16 = (int)f16_e - clz;
496*9637de38SSrikanth Yalavarthi 			f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E;
497*9637de38SSrikanth Yalavarthi 
498*9637de38SSrikanth Yalavarthi 			shift = clz + (FP32_MSB_M - FP16_MSB_M) + 1;
499*9637de38SSrikanth Yalavarthi 			f32_m = (f32_m << shift) & FP32_MASK_M;
500*9637de38SSrikanth Yalavarthi 		}
501*9637de38SSrikanth Yalavarthi 		break;
502*9637de38SSrikanth Yalavarthi 	default: /* normal numbers */
503*9637de38SSrikanth Yalavarthi 		f32_m = f16_m;
504*9637de38SSrikanth Yalavarthi 		e_16 = (int)f16_e;
505*9637de38SSrikanth Yalavarthi 		f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E;
506*9637de38SSrikanth Yalavarthi 
507*9637de38SSrikanth Yalavarthi 		shift = (FP32_MSB_M - FP16_MSB_M);
508*9637de38SSrikanth Yalavarthi 		f32_m = (f32_m << shift) & FP32_MASK_M;
509*9637de38SSrikanth Yalavarthi 	}
510*9637de38SSrikanth Yalavarthi 
511*9637de38SSrikanth Yalavarthi 	f32.u = FP32_PACK(f32_s, f32_e, f32_m);
512*9637de38SSrikanth Yalavarthi 
513*9637de38SSrikanth Yalavarthi 	return f32.f;
514*9637de38SSrikanth Yalavarthi }
515*9637de38SSrikanth Yalavarthi 
516*9637de38SSrikanth Yalavarthi __rte_weak int
517*9637de38SSrikanth Yalavarthi rte_ml_io_float16_to_float32(uint64_t nb_elements, void *input, void *output)
518*9637de38SSrikanth Yalavarthi {
519*9637de38SSrikanth Yalavarthi 	uint16_t *input_buffer;
520*9637de38SSrikanth Yalavarthi 	float *output_buffer;
521*9637de38SSrikanth Yalavarthi 	uint64_t i;
522*9637de38SSrikanth Yalavarthi 
523*9637de38SSrikanth Yalavarthi 	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
524*9637de38SSrikanth Yalavarthi 		return -EINVAL;
525*9637de38SSrikanth Yalavarthi 
526*9637de38SSrikanth Yalavarthi 	input_buffer = (uint16_t *)input;
527*9637de38SSrikanth Yalavarthi 	output_buffer = (float *)output;
528*9637de38SSrikanth Yalavarthi 
529*9637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
530*9637de38SSrikanth Yalavarthi 		*output_buffer = __float16_to_float32_scalar_rtx(*input_buffer);
531*9637de38SSrikanth Yalavarthi 
532*9637de38SSrikanth Yalavarthi 		input_buffer = input_buffer + 1;
533*9637de38SSrikanth Yalavarthi 		output_buffer = output_buffer + 1;
534*9637de38SSrikanth Yalavarthi 	}
535*9637de38SSrikanth Yalavarthi 
536*9637de38SSrikanth Yalavarthi 	return 0;
537*9637de38SSrikanth Yalavarthi }
538*9637de38SSrikanth Yalavarthi 
539*9637de38SSrikanth Yalavarthi /* Convert a single precision floating point number (float32) into a
540*9637de38SSrikanth Yalavarthi  * brain float number (bfloat16) using round to nearest rounding mode.
541*9637de38SSrikanth Yalavarthi  */
542*9637de38SSrikanth Yalavarthi static uint16_t
543*9637de38SSrikanth Yalavarthi __float32_to_bfloat16_scalar_rtn(float x)
544*9637de38SSrikanth Yalavarthi {
545*9637de38SSrikanth Yalavarthi 	union float32 f32; /* float32 input */
546*9637de38SSrikanth Yalavarthi 	uint32_t f32_s;	   /* float32 sign */
547*9637de38SSrikanth Yalavarthi 	uint32_t f32_e;	   /* float32 exponent */
548*9637de38SSrikanth Yalavarthi 	uint32_t f32_m;	   /* float32 mantissa */
549*9637de38SSrikanth Yalavarthi 	uint16_t b16_s;	   /* float16 sign */
550*9637de38SSrikanth Yalavarthi 	uint16_t b16_e;	   /* float16 exponent */
551*9637de38SSrikanth Yalavarthi 	uint16_t b16_m;	   /* float16 mantissa */
552*9637de38SSrikanth Yalavarthi 	uint32_t tbits;	   /* number of truncated bits */
553*9637de38SSrikanth Yalavarthi 	uint16_t u16;	   /* float16 output */
554*9637de38SSrikanth Yalavarthi 
555*9637de38SSrikanth Yalavarthi 	f32.f = x;
556*9637de38SSrikanth Yalavarthi 	f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
557*9637de38SSrikanth Yalavarthi 	f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
558*9637de38SSrikanth Yalavarthi 	f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;
559*9637de38SSrikanth Yalavarthi 
560*9637de38SSrikanth Yalavarthi 	b16_s = f32_s;
561*9637de38SSrikanth Yalavarthi 	b16_e = 0;
562*9637de38SSrikanth Yalavarthi 	b16_m = 0;
563*9637de38SSrikanth Yalavarthi 
564*9637de38SSrikanth Yalavarthi 	switch (f32_e) {
565*9637de38SSrikanth Yalavarthi 	case (0): /* float32: zero or subnormal number */
566*9637de38SSrikanth Yalavarthi 		b16_e = 0;
567*9637de38SSrikanth Yalavarthi 		if (f32_m == 0) /* zero */
568*9637de38SSrikanth Yalavarthi 			b16_m = 0;
569*9637de38SSrikanth Yalavarthi 		else /* subnormal float32 number, normal bfloat16 */
570*9637de38SSrikanth Yalavarthi 			goto bf16_normal;
571*9637de38SSrikanth Yalavarthi 		break;
572*9637de38SSrikanth Yalavarthi 	case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
573*9637de38SSrikanth Yalavarthi 		b16_e = BF16_MASK_E >> BF16_LSB_E;
574*9637de38SSrikanth Yalavarthi 		if (f32_m == 0) { /* infinity */
575*9637de38SSrikanth Yalavarthi 			b16_m = 0;
576*9637de38SSrikanth Yalavarthi 		} else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
577*9637de38SSrikanth Yalavarthi 			b16_m = f32_m >> (FP32_MSB_M - BF16_MSB_M);
578*9637de38SSrikanth Yalavarthi 			b16_m |= BIT(BF16_MSB_M);
579*9637de38SSrikanth Yalavarthi 		}
580*9637de38SSrikanth Yalavarthi 		break;
581*9637de38SSrikanth Yalavarthi 	default: /* float32: normal number, normal bfloat16 */
582*9637de38SSrikanth Yalavarthi 		goto bf16_normal;
583*9637de38SSrikanth Yalavarthi 	}
584*9637de38SSrikanth Yalavarthi 
585*9637de38SSrikanth Yalavarthi 	goto bf16_pack;
586*9637de38SSrikanth Yalavarthi 
587*9637de38SSrikanth Yalavarthi bf16_normal:
588*9637de38SSrikanth Yalavarthi 	b16_e = f32_e;
589*9637de38SSrikanth Yalavarthi 	tbits = FP32_MSB_M - BF16_MSB_M;
590*9637de38SSrikanth Yalavarthi 	b16_m = f32_m >> tbits;
591*9637de38SSrikanth Yalavarthi 
592*9637de38SSrikanth Yalavarthi 	/* if non-leading truncated bits are set */
593*9637de38SSrikanth Yalavarthi 	if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) {
594*9637de38SSrikanth Yalavarthi 		b16_m++;
595*9637de38SSrikanth Yalavarthi 
596*9637de38SSrikanth Yalavarthi 		/* if overflow into exponent */
597*9637de38SSrikanth Yalavarthi 		if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
598*9637de38SSrikanth Yalavarthi 			b16_e++;
599*9637de38SSrikanth Yalavarthi 	} else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) {
600*9637de38SSrikanth Yalavarthi 		/* if only leading truncated bit is set */
601*9637de38SSrikanth Yalavarthi 		if ((b16_m & 0x1) == 0x1) {
602*9637de38SSrikanth Yalavarthi 			b16_m++;
603*9637de38SSrikanth Yalavarthi 
604*9637de38SSrikanth Yalavarthi 			/* if overflow into exponent */
605*9637de38SSrikanth Yalavarthi 			if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
606*9637de38SSrikanth Yalavarthi 				b16_e++;
607*9637de38SSrikanth Yalavarthi 		}
608*9637de38SSrikanth Yalavarthi 	}
609*9637de38SSrikanth Yalavarthi 	b16_m = b16_m & BF16_MASK_M;
610*9637de38SSrikanth Yalavarthi 
611*9637de38SSrikanth Yalavarthi bf16_pack:
612*9637de38SSrikanth Yalavarthi 	u16 = BF16_PACK(b16_s, b16_e, b16_m);
613*9637de38SSrikanth Yalavarthi 
614*9637de38SSrikanth Yalavarthi 	return u16;
615*9637de38SSrikanth Yalavarthi }
616*9637de38SSrikanth Yalavarthi 
617*9637de38SSrikanth Yalavarthi __rte_weak int
618*9637de38SSrikanth Yalavarthi rte_ml_io_float32_to_bfloat16(uint64_t nb_elements, void *input, void *output)
619*9637de38SSrikanth Yalavarthi {
620*9637de38SSrikanth Yalavarthi 	float *input_buffer;
621*9637de38SSrikanth Yalavarthi 	uint16_t *output_buffer;
622*9637de38SSrikanth Yalavarthi 	uint64_t i;
623*9637de38SSrikanth Yalavarthi 
624*9637de38SSrikanth Yalavarthi 	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
625*9637de38SSrikanth Yalavarthi 		return -EINVAL;
626*9637de38SSrikanth Yalavarthi 
627*9637de38SSrikanth Yalavarthi 	input_buffer = (float *)input;
628*9637de38SSrikanth Yalavarthi 	output_buffer = (uint16_t *)output;
629*9637de38SSrikanth Yalavarthi 
630*9637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
631*9637de38SSrikanth Yalavarthi 		*output_buffer = __float32_to_bfloat16_scalar_rtn(*input_buffer);
632*9637de38SSrikanth Yalavarthi 
633*9637de38SSrikanth Yalavarthi 		input_buffer = input_buffer + 1;
634*9637de38SSrikanth Yalavarthi 		output_buffer = output_buffer + 1;
635*9637de38SSrikanth Yalavarthi 	}
636*9637de38SSrikanth Yalavarthi 
637*9637de38SSrikanth Yalavarthi 	return 0;
638*9637de38SSrikanth Yalavarthi }
639*9637de38SSrikanth Yalavarthi 
640*9637de38SSrikanth Yalavarthi /* Convert a brain float number (bfloat16) into a
641*9637de38SSrikanth Yalavarthi  * single precision floating point number (float32).
642*9637de38SSrikanth Yalavarthi  */
643*9637de38SSrikanth Yalavarthi static float
644*9637de38SSrikanth Yalavarthi __bfloat16_to_float32_scalar_rtx(uint16_t f16)
645*9637de38SSrikanth Yalavarthi {
646*9637de38SSrikanth Yalavarthi 	union float32 f32; /* float32 output */
647*9637de38SSrikanth Yalavarthi 	uint16_t b16_s;	   /* float16 sign */
648*9637de38SSrikanth Yalavarthi 	uint16_t b16_e;	   /* float16 exponent */
649*9637de38SSrikanth Yalavarthi 	uint16_t b16_m;	   /* float16 mantissa */
650*9637de38SSrikanth Yalavarthi 	uint32_t f32_s;	   /* float32 sign */
651*9637de38SSrikanth Yalavarthi 	uint32_t f32_e;	   /* float32 exponent */
652*9637de38SSrikanth Yalavarthi 	uint32_t f32_m;	   /* float32 mantissa*/
653*9637de38SSrikanth Yalavarthi 	uint8_t shift;	   /* number of bits to be shifted */
654*9637de38SSrikanth Yalavarthi 
655*9637de38SSrikanth Yalavarthi 	b16_s = (f16 & BF16_MASK_S) >> BF16_LSB_S;
656*9637de38SSrikanth Yalavarthi 	b16_e = (f16 & BF16_MASK_E) >> BF16_LSB_E;
657*9637de38SSrikanth Yalavarthi 	b16_m = (f16 & BF16_MASK_M) >> BF16_LSB_M;
658*9637de38SSrikanth Yalavarthi 
659*9637de38SSrikanth Yalavarthi 	f32_s = b16_s;
660*9637de38SSrikanth Yalavarthi 	switch (b16_e) {
661*9637de38SSrikanth Yalavarthi 	case (BF16_MASK_E >> BF16_LSB_E): /* bfloat16: infinity or nan */
662*9637de38SSrikanth Yalavarthi 		f32_e = FP32_MASK_E >> FP32_LSB_E;
663*9637de38SSrikanth Yalavarthi 		if (b16_m == 0x0) { /* infinity */
664*9637de38SSrikanth Yalavarthi 			f32_m = 0;
665*9637de38SSrikanth Yalavarthi 		} else { /* nan, propagate mantissa, set MSB of mantissa to 1 */
666*9637de38SSrikanth Yalavarthi 			f32_m = b16_m;
667*9637de38SSrikanth Yalavarthi 			shift = FP32_MSB_M - BF16_MSB_M;
668*9637de38SSrikanth Yalavarthi 			f32_m = (f32_m << shift) & FP32_MASK_M;
669*9637de38SSrikanth Yalavarthi 			f32_m |= BIT(FP32_MSB_M);
670*9637de38SSrikanth Yalavarthi 		}
671*9637de38SSrikanth Yalavarthi 		break;
672*9637de38SSrikanth Yalavarthi 	case 0: /* bfloat16: zero or subnormal */
673*9637de38SSrikanth Yalavarthi 		f32_m = b16_m;
674*9637de38SSrikanth Yalavarthi 		if (b16_m == 0) { /* zero signed */
675*9637de38SSrikanth Yalavarthi 			f32_e = 0;
676*9637de38SSrikanth Yalavarthi 		} else { /* subnormal numbers */
677*9637de38SSrikanth Yalavarthi 			goto fp32_normal;
678*9637de38SSrikanth Yalavarthi 		}
679*9637de38SSrikanth Yalavarthi 		break;
680*9637de38SSrikanth Yalavarthi 	default: /* bfloat16: normal number */
681*9637de38SSrikanth Yalavarthi 		goto fp32_normal;
682*9637de38SSrikanth Yalavarthi 	}
683*9637de38SSrikanth Yalavarthi 
684*9637de38SSrikanth Yalavarthi 	goto fp32_pack;
685*9637de38SSrikanth Yalavarthi 
686*9637de38SSrikanth Yalavarthi fp32_normal:
687*9637de38SSrikanth Yalavarthi 	f32_m = b16_m;
688*9637de38SSrikanth Yalavarthi 	f32_e = FP32_BIAS_E + b16_e - BF16_BIAS_E;
689*9637de38SSrikanth Yalavarthi 
690*9637de38SSrikanth Yalavarthi 	shift = (FP32_MSB_M - BF16_MSB_M);
691*9637de38SSrikanth Yalavarthi 	f32_m = (f32_m << shift) & FP32_MASK_M;
692*9637de38SSrikanth Yalavarthi 
693*9637de38SSrikanth Yalavarthi fp32_pack:
694*9637de38SSrikanth Yalavarthi 	f32.u = FP32_PACK(f32_s, f32_e, f32_m);
695*9637de38SSrikanth Yalavarthi 
696*9637de38SSrikanth Yalavarthi 	return f32.f;
697*9637de38SSrikanth Yalavarthi }
698*9637de38SSrikanth Yalavarthi 
699*9637de38SSrikanth Yalavarthi __rte_weak int
700*9637de38SSrikanth Yalavarthi rte_ml_io_bfloat16_to_float32(uint64_t nb_elements, void *input, void *output)
701*9637de38SSrikanth Yalavarthi {
702*9637de38SSrikanth Yalavarthi 	uint16_t *input_buffer;
703*9637de38SSrikanth Yalavarthi 	float *output_buffer;
704*9637de38SSrikanth Yalavarthi 	uint64_t i;
705*9637de38SSrikanth Yalavarthi 
706*9637de38SSrikanth Yalavarthi 	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
707*9637de38SSrikanth Yalavarthi 		return -EINVAL;
708*9637de38SSrikanth Yalavarthi 
709*9637de38SSrikanth Yalavarthi 	input_buffer = (uint16_t *)input;
710*9637de38SSrikanth Yalavarthi 	output_buffer = (float *)output;
711*9637de38SSrikanth Yalavarthi 
712*9637de38SSrikanth Yalavarthi 	for (i = 0; i < nb_elements; i++) {
713*9637de38SSrikanth Yalavarthi 		*output_buffer = __bfloat16_to_float32_scalar_rtx(*input_buffer);
714*9637de38SSrikanth Yalavarthi 
715*9637de38SSrikanth Yalavarthi 		input_buffer = input_buffer + 1;
716*9637de38SSrikanth Yalavarthi 		output_buffer = output_buffer + 1;
717*9637de38SSrikanth Yalavarthi 	}
718*9637de38SSrikanth Yalavarthi 
719*9637de38SSrikanth Yalavarthi 	return 0;
720*9637de38SSrikanth Yalavarthi }
721