/* SPDX-License-Identifier: BSD-3-Clause
 * Copyright (c) 2022 Marvell.
 */

#include <errno.h>
#include <math.h>
#include <stdint.h>

#include "mldev_utils.h"

/* Description:
 * This file implements scalar versions of Machine Learning utility functions used to convert data
 * types from higher precision to lower precision and vice-versa.
 */
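
/* Illustrative usage (a sketch, not part of the library): quantize a float32
 * buffer to int8 with a scale, then dequantize with the reciprocal scale.
 *
 *	float in[4] = {0.1f, -0.5f, 0.75f, 1.0f};
 *	int8_t q[4];
 *	float out[4];
 *
 *	rte_ml_io_float32_to_int8(127.0f, 4, in, q);         // q[i] = round(in[i] * 127), saturated
 *	rte_ml_io_int8_to_float32(1.0f / 127.0f, 4, q, out); // out[i] = q[i] / 127
 */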

#ifndef BIT
#define BIT(nr) (1UL << (nr))
#endif

#ifndef BITS_PER_LONG
#define BITS_PER_LONG (__SIZEOF_LONG__ * 8)
#endif

#ifndef GENMASK_U32
#define GENMASK_U32(h, l) (((~0UL) << (l)) & (~0UL >> (BITS_PER_LONG - 1 - (h))))
#endif
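
/* For example, with the float32 field positions defined below:
 *	GENMASK_U32(31, 31) = 0x80000000 (sign bit)
 *	GENMASK_U32(30, 23) = 0x7F800000 (exponent field)
 *	GENMASK_U32(22, 0)  = 0x007FFFFF (mantissa field)
 */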

/* float32: bit index of MSB & LSB of sign, exponent and mantissa */
#define FP32_LSB_M 0
#define FP32_MSB_M 22
#define FP32_LSB_E 23
#define FP32_MSB_E 30
#define FP32_LSB_S 31
#define FP32_MSB_S 31

/* float32: bitmask for sign, exponent and mantissa */
#define FP32_MASK_S GENMASK_U32(FP32_MSB_S, FP32_LSB_S)
#define FP32_MASK_E GENMASK_U32(FP32_MSB_E, FP32_LSB_E)
#define FP32_MASK_M GENMASK_U32(FP32_MSB_M, FP32_LSB_M)

/* float16: bit index of MSB & LSB of sign, exponent and mantissa */
#define FP16_LSB_M 0
#define FP16_MSB_M 9
#define FP16_LSB_E 10
#define FP16_MSB_E 14
#define FP16_LSB_S 15
#define FP16_MSB_S 15

/* float16: bitmask for sign, exponent and mantissa */
#define FP16_MASK_S GENMASK_U32(FP16_MSB_S, FP16_LSB_S)
#define FP16_MASK_E GENMASK_U32(FP16_MSB_E, FP16_LSB_E)
#define FP16_MASK_M GENMASK_U32(FP16_MSB_M, FP16_LSB_M)

/* bfloat16: bit index of MSB & LSB of sign, exponent and mantissa */
#define BF16_LSB_M 0
#define BF16_MSB_M 6
#define BF16_LSB_E 7
#define BF16_MSB_E 14
#define BF16_LSB_S 15
#define BF16_MSB_S 15

/* bfloat16: bitmask for sign, exponent and mantissa */
#define BF16_MASK_S GENMASK_U32(BF16_MSB_S, BF16_LSB_S)
#define BF16_MASK_E GENMASK_U32(BF16_MSB_E, BF16_LSB_E)
#define BF16_MASK_M GENMASK_U32(BF16_MSB_M, BF16_LSB_M)

/* Exponent bias */
#define FP32_BIAS_E 127
#define FP16_BIAS_E 15
#define BF16_BIAS_E 127
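
/* A stored exponent e with bias b encodes the scale 2^(e - b) for normal
 * numbers; for example, 1.0f stores e = 127 in float32 and e = 15 in float16.
 */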

#define FP32_PACK(sign, exponent, mantissa)                                                        \
	(((sign) << FP32_LSB_S) | ((exponent) << FP32_LSB_E) | (mantissa))

#define FP16_PACK(sign, exponent, mantissa)                                                        \
	(((sign) << FP16_LSB_S) | ((exponent) << FP16_LSB_E) | (mantissa))

#define BF16_PACK(sign, exponent, mantissa)                                                        \
	(((sign) << BF16_LSB_S) | ((exponent) << BF16_LSB_E) | (mantissa))

/* Represent float32 as float and uint32_t */
union float32 {
	float f;
	uint32_t u;
};
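
/* Writing f32.f and reading f32.u reinterprets the float's bits as a uint32_t;
 * type punning through a union is well-defined in C99 and later.
 */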

__rte_weak int
rte_ml_io_float32_to_int8(float scale, uint64_t nb_elements, void *input, void *output)
{
	float *input_buffer;
	int8_t *output_buffer;
	uint64_t i;
	int32_t i32;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (float *)input;
	output_buffer = (int8_t *)output;

	for (i = 0; i < nb_elements; i++) {
		i32 = (int32_t)round((*input_buffer) * scale);

		if (i32 < INT8_MIN)
			i32 = INT8_MIN;

		if (i32 > INT8_MAX)
			i32 = INT8_MAX;

		*output_buffer = (int8_t)i32;

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

__rte_weak int
rte_ml_io_int8_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
{
	int8_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (int8_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = scale * (float)(*input_buffer);

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

__rte_weak int
rte_ml_io_float32_to_uint8(float scale, uint64_t nb_elements, void *input, void *output)
{
	float *input_buffer;
	uint8_t *output_buffer;
	int32_t i32;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (float *)input;
	output_buffer = (uint8_t *)output;

	for (i = 0; i < nb_elements; i++) {
		i32 = (int32_t)round((*input_buffer) * scale);

		if (i32 < 0)
			i32 = 0;

		if (i32 > UINT8_MAX)
			i32 = UINT8_MAX;

		*output_buffer = (uint8_t)i32;

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

__rte_weak int
rte_ml_io_uint8_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
{
	uint8_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (uint8_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = scale * (float)(*input_buffer);

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

__rte_weak int
rte_ml_io_float32_to_int16(float scale, uint64_t nb_elements, void *input, void *output)
{
	float *input_buffer;
	int16_t *output_buffer;
	int32_t i32;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (float *)input;
	output_buffer = (int16_t *)output;

	for (i = 0; i < nb_elements; i++) {
		i32 = (int32_t)round((*input_buffer) * scale);

		if (i32 < INT16_MIN)
			i32 = INT16_MIN;

		if (i32 > INT16_MAX)
			i32 = INT16_MAX;

		*output_buffer = (int16_t)i32;

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

__rte_weak int
rte_ml_io_int16_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
{
	int16_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (int16_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = scale * (float)(*input_buffer);

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

__rte_weak int
rte_ml_io_float32_to_uint16(float scale, uint64_t nb_elements, void *input, void *output)
{
	float *input_buffer;
	uint16_t *output_buffer;
	int32_t i32;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (float *)input;
	output_buffer = (uint16_t *)output;

	for (i = 0; i < nb_elements; i++) {
		i32 = (int32_t)round((*input_buffer) * scale);

		if (i32 < 0)
			i32 = 0;

		if (i32 > UINT16_MAX)
			i32 = UINT16_MAX;

		*output_buffer = (uint16_t)i32;

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

__rte_weak int
rte_ml_io_uint16_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
{
	uint16_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (uint16_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = scale * (float)(*input_buffer);

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

/* Convert a single precision floating point number (float32) into a half precision
 * floating point number (float16) using round to nearest-even rounding mode.
 */
static uint16_t
__float32_to_float16_scalar_rtn(float x)
{
	union float32 f32; /* float32 input */
	uint32_t f32_s;	   /* float32 sign */
	uint32_t f32_e;	   /* float32 exponent */
	uint32_t f32_m;	   /* float32 mantissa */
	uint16_t f16_s;	   /* float16 sign */
	uint16_t f16_e;	   /* float16 exponent */
	uint16_t f16_m;	   /* float16 mantissa */
	uint32_t tbits;	   /* number of truncated bits */
	uint32_t tmsb;	   /* MSB position of truncated bits */
	uint32_t m_32;	   /* temporary float32 mantissa */
	uint16_t m_16;	   /* temporary float16 mantissa */
	uint16_t u16;	   /* float16 output */
	int be_16;	   /* float16 biased exponent, signed */

	f32.f = x;
	f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
	f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
	f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;

	f16_s = f32_s;
	f16_e = 0;
	f16_m = 0;

	switch (f32_e) {
	case (0): /* float32: zero or subnormal number, result is signed zero */
		f16_e = 0;
		f16_m = 0;
		break;
	case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
		f16_e = FP16_MASK_E >> FP16_LSB_E;
		if (f32_m == 0) { /* infinity */
			f16_m = 0;
		} else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
			f16_m = f32_m >> (FP32_MSB_M - FP16_MSB_M);
			f16_m |= BIT(FP16_MSB_M);
		}
		break;
	default: /* float32: normal number */
		/* compute biased exponent for float16 */
		be_16 = (int)f32_e - FP32_BIAS_E + FP16_BIAS_E;

		/* overflow, be_16 = [31:INF], set to infinity */
		if (be_16 >= (int)(FP16_MASK_E >> FP16_LSB_E)) {
			f16_e = FP16_MASK_E >> FP16_LSB_E;
			f16_m = 0;
		} else if ((be_16 >= 1) && (be_16 < (int)(FP16_MASK_E >> FP16_LSB_E))) {
			/* normal float16, be_16 = [1:30] */
			f16_e = be_16;
			m_16 = f32_m >> (FP32_LSB_E - FP16_LSB_E);
			tmsb = FP32_MSB_M - FP16_MSB_M - 1;
			if ((f32_m & GENMASK_U32(tmsb, 0)) > BIT(tmsb)) {
				/* round up: truncated bits are above the halfway point */
				m_16++;

				/* overflow into exponent */
				if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
					f16_e++;
			} else if ((f32_m & GENMASK_U32(tmsb, 0)) == BIT(tmsb)) {
				/* tie: truncated bits are exactly halfway, round to even */
				if ((m_16 & 0x1) == 0x1) {
					m_16++;

					/* overflow into exponent */
					if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
						f16_e++;
				}
			}
			f16_m = m_16 & FP16_MASK_M;
		} else if ((be_16 >= -(int)(FP16_MSB_M)) && (be_16 < 1)) {
			/* underflow: subnormal or zero, be_16 = [-9:0] */
			f16_e = 0;

			/* make the implicit leading 1 explicit */
			m_32 = f32_m | BIT(FP32_LSB_E);
			tbits = FP32_LSB_E - FP16_LSB_E - be_16 + 1;
			m_16 = m_32 >> tbits;

			/* round up: truncated bits are above the halfway point */
			if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) {
				m_16++;

				/* overflow into exponent */
				if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
					f16_e++;
			} else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) {
				/* tie: truncated bits are exactly halfway, round to even */
				if ((m_16 & 0x1) == 0x1) {
					m_16++;

					/* overflow into exponent */
					if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
						f16_e++;
				}
			}
			f16_m = m_16 & FP16_MASK_M;
		} else if (be_16 == -(int)(FP16_MSB_M + 1)) {
			/* underflow: be_16 = [-10], round to smallest subnormal or zero */
			f16_e = 0;
			if (f32_m != 0)
				f16_m = 1;
			else
				f16_m = 0;
		} else {
			/* underflow: zero, be_16 = [-INF:-11] */
			f16_e = 0;
			f16_m = 0;
		}

		break;
	}

	u16 = FP16_PACK(f16_s, f16_e, f16_m);

	return u16;
}
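
/* Worked examples (a sketch of expected results):
 *	__float32_to_float16_scalar_rtn(1.0f)     = 0x3C00 (sign 0, exponent 15, mantissa 0)
 *	__float32_to_float16_scalar_rtn(-2.0f)    = 0xC000
 *	__float32_to_float16_scalar_rtn(65536.0f) = 0x7C00 (overflows to +infinity)
 */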

__rte_weak int
rte_ml_io_float32_to_float16(uint64_t nb_elements, void *input, void *output)
{
	float *input_buffer;
	uint16_t *output_buffer;
	uint64_t i;

	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (float *)input;
	output_buffer = (uint16_t *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = __float32_to_float16_scalar_rtn(*input_buffer);

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

/* Convert a half precision floating point number (float16) into a single precision
 * floating point number (float32).
 */
static float
__float16_to_float32_scalar_rtx(uint16_t f16)
{
	union float32 f32; /* float32 output */
	uint16_t f16_s;	   /* float16 sign */
	uint16_t f16_e;	   /* float16 exponent */
	uint16_t f16_m;	   /* float16 mantissa */
	uint32_t f32_s;	   /* float32 sign */
	uint32_t f32_e;	   /* float32 exponent */
	uint32_t f32_m;	   /* float32 mantissa */
	uint8_t shift;	   /* number of bits to be shifted */
	uint32_t clz;	   /* count of leading zeroes */
	int e_16;	   /* float16 exponent unbiased */

	f16_s = (f16 & FP16_MASK_S) >> FP16_LSB_S;
	f16_e = (f16 & FP16_MASK_E) >> FP16_LSB_E;
	f16_m = (f16 & FP16_MASK_M) >> FP16_LSB_M;

	f32_s = f16_s;
	switch (f16_e) {
	case (FP16_MASK_E >> FP16_LSB_E): /* float16: infinity or nan */
		f32_e = FP32_MASK_E >> FP32_LSB_E;
		if (f16_m == 0x0) { /* infinity */
			f32_m = f16_m;
		} else { /* nan, propagate mantissa, set MSB of mantissa to 1 */
			f32_m = f16_m;
			shift = FP32_MSB_M - FP16_MSB_M;
			f32_m = (f32_m << shift) & FP32_MASK_M;
			f32_m |= BIT(FP32_MSB_M);
		}
		break;
	case 0: /* float16: zero or subnormal */
		f32_m = f16_m;
		if (f16_m == 0) { /* signed zero */
			f32_e = 0;
		} else { /* subnormal numbers */
			clz = __builtin_clz((uint32_t)f16_m) - sizeof(uint32_t) * 8 + FP16_LSB_E;
			e_16 = (int)f16_e - clz;
			f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E;

			shift = clz + (FP32_MSB_M - FP16_MSB_M) + 1;
			f32_m = (f32_m << shift) & FP32_MASK_M;
		}
		break;
	default: /* normal numbers */
		f32_m = f16_m;
		e_16 = (int)f16_e;
		f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E;

		shift = (FP32_MSB_M - FP16_MSB_M);
		f32_m = (f32_m << shift) & FP32_MASK_M;
	}

	f32.u = FP32_PACK(f32_s, f32_e, f32_m);

	return f32.f;
}
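
/* Worked examples (a sketch of expected results):
 *	0x3C00 -> 1.0f, 0xC000 -> -2.0f, 0x7C00 -> +infinity
 *	0x0001 -> 2^-24 (~5.96e-8, the smallest float16 subnormal)
 */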

__rte_weak int
rte_ml_io_float16_to_float32(uint64_t nb_elements, void *input, void *output)
{
	uint16_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (uint16_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = __float16_to_float32_scalar_rtx(*input_buffer);

		input_buffer++;
		output_buffer++;
	}

	return 0;
}
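
/* Illustrative round trip (a sketch): float32 -> float16 -> float32. Values
 * exactly representable in float16 survive unchanged; others keep roughly
 * 11 bits of mantissa precision.
 *
 *	float in[2] = {0.5f, 3.14159f}, out[2];
 *	uint16_t half[2];
 *
 *	rte_ml_io_float32_to_float16(2, in, half);
 *	rte_ml_io_float16_to_float32(2, half, out); // out[0] == 0.5f exactly
 */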

/* Convert a single precision floating point number (float32) into a
 * brain float number (bfloat16) using round to nearest-even rounding mode.
 */
static uint16_t
__float32_to_bfloat16_scalar_rtn(float x)
{
	union float32 f32; /* float32 input */
	uint32_t f32_s;	   /* float32 sign */
	uint32_t f32_e;	   /* float32 exponent */
	uint32_t f32_m;	   /* float32 mantissa */
	uint16_t b16_s;	   /* bfloat16 sign */
	uint16_t b16_e;	   /* bfloat16 exponent */
	uint16_t b16_m;	   /* bfloat16 mantissa */
	uint32_t tbits;	   /* number of truncated bits */
	uint16_t u16;	   /* bfloat16 output */

	f32.f = x;
	f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
	f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
	f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;

	b16_s = f32_s;
	b16_e = 0;
	b16_m = 0;

	switch (f32_e) {
	case (0): /* float32: zero or subnormal number */
		b16_e = 0;
		if (f32_m == 0) /* zero */
			b16_m = 0;
		else /* subnormal float32, same bias: reuse the shift-and-round path */
			goto bf16_normal;
		break;
	case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
		b16_e = BF16_MASK_E >> BF16_LSB_E;
		if (f32_m == 0) { /* infinity */
			b16_m = 0;
		} else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
			b16_m = f32_m >> (FP32_MSB_M - BF16_MSB_M);
			b16_m |= BIT(BF16_MSB_M);
		}
		break;
	default: /* float32: normal number, normal bfloat16 */
		goto bf16_normal;
	}

	goto bf16_pack;

bf16_normal:
	b16_e = f32_e;
	tbits = FP32_MSB_M - BF16_MSB_M;
	b16_m = f32_m >> tbits;

	/* round up: truncated bits are above the halfway point */
	if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) {
		b16_m++;

		/* if overflow into exponent */
		if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
			b16_e++;
	} else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) {
		/* tie: truncated bits are exactly halfway, round to even */
		if ((b16_m & 0x1) == 0x1) {
			b16_m++;

			/* if overflow into exponent */
			if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
				b16_e++;
		}
	}
	b16_m = b16_m & BF16_MASK_M;

bf16_pack:
	u16 = BF16_PACK(b16_s, b16_e, b16_m);

	return u16;
}
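
/* Worked example (a sketch of expected results): for normal numbers the result
 * is the upper 16 bits of the float32 encoding, rounded to nearest-even,
 * e.g. 1.0f (0x3F800000) -> 0x3F80.
 */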

__rte_weak int
rte_ml_io_float32_to_bfloat16(uint64_t nb_elements, void *input, void *output)
{
	float *input_buffer;
	uint16_t *output_buffer;
	uint64_t i;

	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (float *)input;
	output_buffer = (uint16_t *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = __float32_to_bfloat16_scalar_rtn(*input_buffer);

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

/* Convert a brain float number (bfloat16) into a
 * single precision floating point number (float32).
 */
static float
__bfloat16_to_float32_scalar_rtx(uint16_t f16)
{
	union float32 f32; /* float32 output */
	uint16_t b16_s;	   /* bfloat16 sign */
	uint16_t b16_e;	   /* bfloat16 exponent */
	uint16_t b16_m;	   /* bfloat16 mantissa */
	uint32_t f32_s;	   /* float32 sign */
	uint32_t f32_e;	   /* float32 exponent */
	uint32_t f32_m;	   /* float32 mantissa */
	uint8_t shift;	   /* number of bits to be shifted */

	b16_s = (f16 & BF16_MASK_S) >> BF16_LSB_S;
	b16_e = (f16 & BF16_MASK_E) >> BF16_LSB_E;
	b16_m = (f16 & BF16_MASK_M) >> BF16_LSB_M;

	f32_s = b16_s;
	switch (b16_e) {
	case (BF16_MASK_E >> BF16_LSB_E): /* bfloat16: infinity or nan */
		f32_e = FP32_MASK_E >> FP32_LSB_E;
		if (b16_m == 0x0) { /* infinity */
			f32_m = 0;
		} else { /* nan, propagate mantissa, set MSB of mantissa to 1 */
			f32_m = b16_m;
			shift = FP32_MSB_M - BF16_MSB_M;
			f32_m = (f32_m << shift) & FP32_MASK_M;
			f32_m |= BIT(FP32_MSB_M);
		}
		break;
	case 0: /* bfloat16: zero or subnormal */
		f32_m = b16_m;
		if (b16_m == 0) { /* signed zero */
			f32_e = 0;
		} else { /* subnormal numbers, same bias: reuse the normal path */
			goto fp32_normal;
		}
		break;
	default: /* bfloat16: normal number */
		goto fp32_normal;
	}

	goto fp32_pack;

fp32_normal:
	f32_m = b16_m;
	f32_e = FP32_BIAS_E + b16_e - BF16_BIAS_E;

	shift = (FP32_MSB_M - BF16_MSB_M);
	f32_m = (f32_m << shift) & FP32_MASK_M;

fp32_pack:
	f32.u = FP32_PACK(f32_s, f32_e, f32_m);

	return f32.f;
}
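
/* Note: every bfloat16 value is exactly representable in float32 (same
 * exponent range, wider mantissa), so this conversion is lossless.
 */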

__rte_weak int
rte_ml_io_bfloat16_to_float32(uint64_t nb_elements, void *input, void *output)
{
	uint16_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (uint16_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = __bfloat16_to_float32_scalar_rtx(*input_buffer);

		input_buffer++;
		output_buffer++;
	}

	return 0;
}
721