xref: /dpdk/lib/mldev/mldev_utils_scalar_bfloat16.c (revision 65282e9f8e118a4ca977d1aee2d7f51f44e9bc1b)
/* SPDX-License-Identifier: BSD-3-Clause
 * Copyright (c) 2023 Marvell.
 */

#include <errno.h>
#include <math.h>
#include <stdint.h>

#include "mldev_utils_scalar.h"

/* Description:
 * This file implements scalar versions of the Machine Learning utility functions used to convert
 * data between the bfloat16 and float32 data types.
 */
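
/* Format summary, for reference when reading the bit manipulation below:
 * - float32:  1 sign bit, 8 exponent bits (bias 127), 23 mantissa bits.
 * - bfloat16: 1 sign bit, 8 exponent bits (bias 127),  7 mantissa bits.
 * Because both types use the same exponent width and bias, a bfloat16 value is effectively the
 * upper 16 bits of the corresponding float32 encoding, with the low 16 mantissa bits rounded off.
 */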

/* Convert a single precision floating point number (float32) into a
 * brain float number (bfloat16) using round-to-nearest rounding mode, with ties rounded to even.
 */
static uint16_t
__float32_to_bfloat16_scalar_rtn(float x)
{
	union float32 f32; /* float32 input */
	uint32_t f32_s;	   /* float32 sign */
	uint32_t f32_e;	   /* float32 exponent */
	uint32_t f32_m;	   /* float32 mantissa */
	uint16_t b16_s;	   /* bfloat16 sign */
	uint16_t b16_e;	   /* bfloat16 exponent */
	uint16_t b16_m;	   /* bfloat16 mantissa */
	uint32_t tbits;	   /* number of truncated bits */
	uint16_t u16;	   /* bfloat16 output */

	f32.f = x;
	f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
	f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
	f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;

	b16_s = f32_s;
	b16_e = 0;
	b16_m = 0;

	switch (f32_e) {
	case (0): /* float32: zero or subnormal number */
		b16_e = 0;
		if (f32_m == 0) /* zero */
			b16_m = 0;
		else /* subnormal float32; handled by the common rounding path below */
			goto bf16_normal;
		break;
	case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
		b16_e = BF16_MASK_E >> BF16_LSB_E;
		if (f32_m == 0) { /* infinity */
			b16_m = 0;
		} else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
			b16_m = f32_m >> (FP32_MSB_M - BF16_MSB_M);
			b16_m |= BIT(BF16_MSB_M);
		}
		break;
	default: /* float32: normal number, normal bfloat16 */
		goto bf16_normal;
	}

	goto bf16_pack;

bf16_normal:
	b16_e = f32_e;
	tbits = FP32_MSB_M - BF16_MSB_M;
	b16_m = f32_m >> tbits;

	/* truncated bits are above the halfway point: round up */
	if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) {
		b16_m++;

		/* carry mantissa overflow into the exponent */
		if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
			b16_e++;
	} else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) {
		/* only the leading truncated bit is set (a tie): round to even */
		if ((b16_m & 0x1) == 0x1) {
			b16_m++;

			/* carry mantissa overflow into the exponent */
			if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
				b16_e++;
		}
	}
	b16_m = b16_m & BF16_MASK_M;

bf16_pack:
	u16 = BF16_PACK(b16_s, b16_e, b16_m);

	return u16;
}
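
/* Worked example for __float32_to_bfloat16_scalar_rtn() (illustrative):
 * - 1.0f is 0x3F800000; its low 16 mantissa bits are zero, so the result is simply the upper
 *   16 bits of the encoding: 0x3F80.
 * - 0x3F808000 (1.00390625f) keeps mantissa 0 and discards exactly the halfway pattern 0x8000;
 *   the kept LSB is 0 (even), so there is no round-up: result 0x3F80 (1.0f).
 * - 0x3F818000 (1.01171875f) is also a halfway case, but the kept LSB is 1 (odd), so the
 *   mantissa rounds up to 2: result 0x3F82 (1.015625f).
 */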

int
rte_ml_io_float32_to_bfloat16(const void *input, void *output, uint64_t nb_elements)
{
	const float *input_buffer;
	uint16_t *output_buffer;
	uint64_t i;

	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (const float *)input;
	output_buffer = (uint16_t *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = __float32_to_bfloat16_scalar_rtn(*input_buffer);

		input_buffer = input_buffer + 1;
		output_buffer = output_buffer + 1;
	}

	return 0;
}
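
/* Illustrative usage sketch (not part of the library). The buffer names and element count are
 * made up for this example; -EINVAL is returned for a NULL pointer or a zero element count.
 *
 *	float src[4] = {1.0f, -2.5f, 0.15625f, 3.14159f};
 *	uint16_t dst[4];
 *	int ret;
 *
 *	ret = rte_ml_io_float32_to_bfloat16(src, dst, 4);
 *	if (ret != 0)
 *		return ret;
 */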

/* Convert a brain float number (bfloat16) into a
 * single precision floating point number (float32).
 */
static float
__bfloat16_to_float32_scalar_rtx(uint16_t f16)
{
	union float32 f32; /* float32 output */
	uint16_t b16_s;	   /* bfloat16 sign */
	uint16_t b16_e;	   /* bfloat16 exponent */
	uint16_t b16_m;	   /* bfloat16 mantissa */
	uint32_t f32_s;	   /* float32 sign */
	uint32_t f32_e;	   /* float32 exponent */
	uint32_t f32_m;	   /* float32 mantissa */
	uint8_t shift;	   /* number of bits to be shifted */

	b16_s = (f16 & BF16_MASK_S) >> BF16_LSB_S;
	b16_e = (f16 & BF16_MASK_E) >> BF16_LSB_E;
	b16_m = (f16 & BF16_MASK_M) >> BF16_LSB_M;

	f32_s = b16_s;
	switch (b16_e) {
	case (BF16_MASK_E >> BF16_LSB_E): /* bfloat16: infinity or nan */
		f32_e = FP32_MASK_E >> FP32_LSB_E;
		if (b16_m == 0x0) { /* infinity */
			f32_m = 0;
		} else { /* nan, propagate mantissa, set MSB of mantissa to 1 */
			f32_m = b16_m;
			shift = FP32_MSB_M - BF16_MSB_M;
			f32_m = (f32_m << shift) & FP32_MASK_M;
			f32_m |= BIT(FP32_MSB_M);
		}
		break;
	case 0: /* bfloat16: zero or subnormal */
		f32_m = b16_m;
		if (b16_m == 0) { /* signed zero */
			f32_e = 0;
		} else { /* subnormal numbers */
			goto fp32_normal;
		}
		break;
	default: /* bfloat16: normal number */
		goto fp32_normal;
	}

	goto fp32_pack;

fp32_normal:
	f32_m = b16_m;
	f32_e = FP32_BIAS_E + b16_e - BF16_BIAS_E;

	shift = (FP32_MSB_M - BF16_MSB_M);
	f32_m = (f32_m << shift) & FP32_MASK_M;

fp32_pack:
	f32.u = FP32_PACK(f32_s, f32_e, f32_m);

	return f32.f;
}
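
/* Worked example for __bfloat16_to_float32_scalar_rtx() (illustrative):
 * 0x3F80 unpacks to sign 0, exponent 0x7F, mantissa 0; shifting the mantissa left by 16 bits and
 * repacking gives 0x3F800000, i.e. 1.0f. Because bfloat16 and float32 share the same exponent
 * width and bias, every finite bfloat16 value is exactly representable in float32, so this
 * direction needs no rounding.
 */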

int
rte_ml_io_bfloat16_to_float32(const void *input, void *output, uint64_t nb_elements)
{
	const uint16_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (const uint16_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = __bfloat16_to_float32_scalar_rtx(*input_buffer);

		input_buffer = input_buffer + 1;
		output_buffer = output_buffer + 1;
	}

	return 0;
}
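
/* Illustrative round-trip sketch (not part of the library); the variable names are made up for
 * the example. Converting float32 -> bfloat16 -> float32 keeps only 8 significant bits, so for
 * normal values the recovered number differs from the original by at most about 1 part in 256
 * (half a bfloat16 ULP):
 *
 *	float in[2] = {3.14159f, -0.001f};
 *	uint16_t bf16[2];
 *	float out[2];
 *
 *	rte_ml_io_float32_to_bfloat16(in, bf16, 2);
 *	rte_ml_io_bfloat16_to_float32(bf16, out, 2);
 */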