/* SPDX-License-Identifier: BSD-3-Clause
 * Copyright (c) 2022 Marvell.
 */

#include <errno.h>
#include <math.h>
#include <stdint.h>

#include "mldev_utils.h"

/* Description:
 * This file implements scalar versions of Machine Learning utility functions used to convert data
 * types from higher precision to lower precision and vice versa.
 */

#ifndef BIT
#define BIT(nr) (1UL << (nr))
#endif

#ifndef BITS_PER_LONG
#define BITS_PER_LONG (__SIZEOF_LONG__ * 8)
#endif

#ifndef GENMASK_U32
#define GENMASK_U32(h, l) (((~0UL) << (l)) & (~0UL >> (BITS_PER_LONG - 1 - (h))))
#endif

/* float32: bit index of MSB & LSB of sign, exponent and mantissa */
#define FP32_LSB_M 0
#define FP32_MSB_M 22
#define FP32_LSB_E 23
#define FP32_MSB_E 30
#define FP32_LSB_S 31
#define FP32_MSB_S 31

/* float32: bitmask for sign, exponent and mantissa */
#define FP32_MASK_S GENMASK_U32(FP32_MSB_S, FP32_LSB_S)
#define FP32_MASK_E GENMASK_U32(FP32_MSB_E, FP32_LSB_E)
#define FP32_MASK_M GENMASK_U32(FP32_MSB_M, FP32_LSB_M)

/* float16: bit index of MSB & LSB of sign, exponent and mantissa */
#define FP16_LSB_M 0
#define FP16_MSB_M 9
#define FP16_LSB_E 10
#define FP16_MSB_E 14
#define FP16_LSB_S 15
#define FP16_MSB_S 15

/* float16: bitmask for sign, exponent and mantissa */
#define FP16_MASK_S GENMASK_U32(FP16_MSB_S, FP16_LSB_S)
#define FP16_MASK_E GENMASK_U32(FP16_MSB_E, FP16_LSB_E)
#define FP16_MASK_M GENMASK_U32(FP16_MSB_M, FP16_LSB_M)

/* bfloat16: bit index of MSB & LSB of sign, exponent and mantissa */
#define BF16_LSB_M 0
#define BF16_MSB_M 6
#define BF16_LSB_E 7
#define BF16_MSB_E 14
#define BF16_LSB_S 15
#define BF16_MSB_S 15

/* bfloat16: bitmask for sign, exponent and mantissa */
#define BF16_MASK_S GENMASK_U32(BF16_MSB_S, BF16_LSB_S)
#define BF16_MASK_E GENMASK_U32(BF16_MSB_E, BF16_LSB_E)
#define BF16_MASK_M GENMASK_U32(BF16_MSB_M, BF16_LSB_M)

/* Exponent bias */
#define FP32_BIAS_E 127
#define FP16_BIAS_E 15
#define BF16_BIAS_E 127

#define FP32_PACK(sign, exponent, mantissa) \
	(((sign) << FP32_LSB_S) | ((exponent) << FP32_LSB_E) | (mantissa))

#define FP16_PACK(sign, exponent, mantissa) \
	(((sign) << FP16_LSB_S) | ((exponent) << FP16_LSB_E) | (mantissa))

#define BF16_PACK(sign, exponent, mantissa) \
	(((sign) << BF16_LSB_S) | ((exponent) << BF16_LSB_E) | (mantissa))

/* Represent float32 as float and uint32_t */
union float32 {
	float f;
	uint32_t u;
};

__rte_weak int
rte_ml_io_float32_to_int8(float scale, uint64_t nb_elements, void *input, void *output)
{
	float *input_buffer;
	int8_t *output_buffer;
	uint64_t i;
	int32_t i32;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (float *)input;
	output_buffer = (int8_t *)output;

	for (i = 0; i < nb_elements; i++) {
		i32 = (int32_t)round((*input_buffer) * scale);

		if (i32 < INT8_MIN)
			i32 = INT8_MIN;

		if (i32 > INT8_MAX)
			i32 = INT8_MAX;

		*output_buffer = (int8_t)i32;

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

__rte_weak int
rte_ml_io_int8_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
{
	int8_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (int8_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = scale * (float)(*input_buffer);

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

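/*
 * Illustrative usage sketch (not part of the library). It shows how the int8
 * quantize/dequantize pair above might be called; the MLDEV_UTILS_SCALAR_SELFTEST
 * guard and the example function name are assumptions introduced here so the
 * sketch stays out of normal builds. Note that rte_ml_io_int8_to_float32()
 * multiplies by its scale argument, so the dequantize step passes the
 * reciprocal of the quantize scale.
 */
#ifdef MLDEV_UTILS_SCALAR_SELFTEST
#include <stdio.h>

static void
mldev_utils_scalar_int8_example(void)
{
	float in[4] = {-1.0f, -0.5f, 0.25f, 1.0f};
	float out[4];
	int8_t q[4];
	float scale = 127.0f;
	int i;

	/* quantize: q[i] = clamp(round(in[i] * scale), INT8_MIN, INT8_MAX) */
	rte_ml_io_float32_to_int8(scale, 4, in, q);

	/* dequantize: out[i] = (1 / scale) * q[i] */
	rte_ml_io_int8_to_float32(1.0f / scale, 4, q, out);

	for (i = 0; i < 4; i++)
		printf("%f -> %d -> %f\n", in[i], q[i], out[i]);
}
#endif /* MLDEV_UTILS_SCALAR_SELFTEST */
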
__rte_weak int
rte_ml_io_float32_to_uint8(float scale, uint64_t nb_elements, void *input, void *output)
{
	float *input_buffer;
	uint8_t *output_buffer;
	int32_t i32;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (float *)input;
	output_buffer = (uint8_t *)output;

	for (i = 0; i < nb_elements; i++) {
		i32 = (int32_t)round((*input_buffer) * scale);

		if (i32 < 0)
			i32 = 0;

		if (i32 > UINT8_MAX)
			i32 = UINT8_MAX;

		*output_buffer = (uint8_t)i32;

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

__rte_weak int
rte_ml_io_uint8_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
{
	uint8_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (uint8_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = scale * (float)(*input_buffer);

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

__rte_weak int
rte_ml_io_float32_to_int16(float scale, uint64_t nb_elements, void *input, void *output)
{
	float *input_buffer;
	int16_t *output_buffer;
	int32_t i32;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (float *)input;
	output_buffer = (int16_t *)output;

	for (i = 0; i < nb_elements; i++) {
		i32 = (int32_t)round((*input_buffer) * scale);

		if (i32 < INT16_MIN)
			i32 = INT16_MIN;

		if (i32 > INT16_MAX)
			i32 = INT16_MAX;

		*output_buffer = (int16_t)i32;

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

__rte_weak int
rte_ml_io_int16_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
{
	int16_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (int16_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = scale * (float)(*input_buffer);

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

__rte_weak int
rte_ml_io_float32_to_uint16(float scale, uint64_t nb_elements, void *input, void *output)
{
	float *input_buffer;
	uint16_t *output_buffer;
	int32_t i32;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (float *)input;
	output_buffer = (uint16_t *)output;

	for (i = 0; i < nb_elements; i++) {
		i32 = (int32_t)round((*input_buffer) * scale);

		if (i32 < 0)
			i32 = 0;

		if (i32 > UINT16_MAX)
			i32 = UINT16_MAX;

		*output_buffer = (uint16_t)i32;

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

__rte_weak int
rte_ml_io_uint16_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
{
	uint16_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (uint16_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = scale * (float)(*input_buffer);

		input_buffer++;
		output_buffer++;
	}

	return 0;
}

/* Convert a single precision floating point number (float32) into a half precision
 * floating point number (float16) using round to nearest rounding mode.
 */
static uint16_t
__float32_to_float16_scalar_rtn(float x)
{
	union float32 f32; /* float32 input */
	uint32_t f32_s;	/* float32 sign */
	uint32_t f32_e;	/* float32 exponent */
	uint32_t f32_m;	/* float32 mantissa */
	uint16_t f16_s;	/* float16 sign */
	uint16_t f16_e;	/* float16 exponent */
	uint16_t f16_m;	/* float16 mantissa */
	uint32_t tbits;	/* number of truncated bits */
	uint32_t tmsb;	/* MSB position of truncated bits */
	uint32_t m_32;	/* temporary float32 mantissa */
	uint16_t m_16;	/* temporary float16 mantissa */
	uint16_t u16;	/* float16 output */
	int be_16;	/* float16 biased exponent, signed */

	f32.f = x;
	f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
	f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
	f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;

	f16_s = f32_s;
	f16_e = 0;
	f16_m = 0;

	switch (f32_e) {
	case (0): /* float32: zero or subnormal number */
		f16_e = 0;
		if (f32_m == 0) /* zero */
			f16_m = 0;
		else /* subnormal number, convert to zero */
			f16_m = 0;
		break;
	case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
		f16_e = FP16_MASK_E >> FP16_LSB_E;
		if (f32_m == 0) { /* infinity */
			f16_m = 0;
		} else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
			f16_m = f32_m >> (FP32_MSB_M - FP16_MSB_M);
			f16_m |= BIT(FP16_MSB_M);
		}
		break;
	default: /* float32: normal number */
		/* compute biased exponent for float16 */
		be_16 = (int)f32_e - FP32_BIAS_E + FP16_BIAS_E;

		/* overflow, be_16 = [31-INF], set to infinity */
		if (be_16 >= (int)(FP16_MASK_E >> FP16_LSB_E)) {
			f16_e = FP16_MASK_E >> FP16_LSB_E;
			f16_m = 0;
		} else if ((be_16 >= 1) && (be_16 < (int)(FP16_MASK_E >> FP16_LSB_E))) {
			/* normal float16, be_16 = [1:30] */
			f16_e = be_16;
			m_16 = f32_m >> (FP32_LSB_E - FP16_LSB_E);
			tmsb = FP32_MSB_M - FP16_MSB_M - 1;
			if ((f32_m & GENMASK_U32(tmsb, 0)) > BIT(tmsb)) {
				/* round: non-zero truncated bits except MSB */
				m_16++;

				/* overflow into exponent */
				if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
					f16_e++;
			} else if ((f32_m & GENMASK_U32(tmsb, 0)) == BIT(tmsb)) {
				/* round: MSB of truncated bits and LSB of m_16 is set */
				if ((m_16 & 0x1) == 0x1) {
					m_16++;

					/* overflow into exponent */
					if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
						f16_e++;
				}
			}
			f16_m = m_16 & FP16_MASK_M;
		} else if ((be_16 >= -(int)(FP16_MSB_M)) && (be_16 < 1)) {
			/* underflow: zero / subnormal, be_16 = [-9:0] */
			f16_e = 0;

			/* add the implicit leading one (hidden bit) */
			m_32 = f32_m | BIT(FP32_LSB_E);
			tbits = FP32_LSB_E - FP16_LSB_E - be_16 + 1;
			m_16 = m_32 >> tbits;

			/* if non-leading truncated bits are set */
			if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) {
				m_16++;

				/* overflow into exponent */
				if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
					f16_e++;
			} else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) {
				/* if leading truncated bit is set */
				if ((m_16 & 0x1) == 0x1) {
					m_16++;

					/* overflow into exponent */
					if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
						f16_e++;
				}
			}
			f16_m = m_16 & FP16_MASK_M;
		} else if (be_16 == -(int)(FP16_MSB_M + 1)) {
			/* underflow: zero, be_16 = [-10] */
			f16_e = 0;
			if (f32_m != 0)
				f16_m = 1;
			else
				f16_m = 0;
		} else {
			/* underflow: zero, be_16 = [-INF:-11] */
			f16_e = 0;
			f16_m = 0;
		}

		break;
	}

	u16 = FP16_PACK(f16_s, f16_e, f16_m);

	return u16;
}

__rte_weak int
rte_ml_io_float32_to_float16(uint64_t nb_elements, void *input, void *output)
{
	float *input_buffer;
	uint16_t *output_buffer;
	uint64_t i;

	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (float *)input;
	output_buffer = (uint16_t *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = __float32_to_float16_scalar_rtn(*input_buffer);

		input_buffer = input_buffer + 1;
		output_buffer = output_buffer + 1;
	}

	return 0;
}

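/*
 * Worked example (illustrative, not part of the original code) of the
 * round-to-nearest-even behaviour implemented above. With a float16 target,
 * the 13 low mantissa bits of the float32 value are truncated:
 *
 *   1.00048828125f (1 + 2^-11, f32_m = 0x001000): the truncated bits equal
 *   exactly BIT(tmsb) and the kept LSB is 0, so no increment is applied and
 *   the result is 0x3c00 (1.0).
 *
 *   1.00146484375f (1 + 3 * 2^-11, f32_m = 0x003000): again an exact tie,
 *   but the kept LSB is 1, so the mantissa is incremented and the result is
 *   0x3c02 (1.001953125), i.e. the tie rounds to the even mantissa.
 */
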
/* Convert a half precision floating point number (float16) into a single precision
 * floating point number (float32).
 */
static float
__float16_to_float32_scalar_rtx(uint16_t f16)
{
	union float32 f32; /* float32 output */
	uint16_t f16_s;	/* float16 sign */
	uint16_t f16_e;	/* float16 exponent */
	uint16_t f16_m;	/* float16 mantissa */
	uint32_t f32_s;	/* float32 sign */
	uint32_t f32_e;	/* float32 exponent */
	uint32_t f32_m;	/* float32 mantissa */
	uint8_t shift;	/* number of bits to be shifted */
	uint32_t clz;	/* count of leading zeroes */
	int e_16;	/* float16 exponent unbiased */

	f16_s = (f16 & FP16_MASK_S) >> FP16_LSB_S;
	f16_e = (f16 & FP16_MASK_E) >> FP16_LSB_E;
	f16_m = (f16 & FP16_MASK_M) >> FP16_LSB_M;

	f32_s = f16_s;
	switch (f16_e) {
	case (FP16_MASK_E >> FP16_LSB_E): /* float16: infinity or nan */
		f32_e = FP32_MASK_E >> FP32_LSB_E;
		if (f16_m == 0x0) { /* infinity */
			f32_m = f16_m;
		} else { /* nan, propagate mantissa, set MSB of mantissa to 1 */
			f32_m = f16_m;
			shift = FP32_MSB_M - FP16_MSB_M;
			f32_m = (f32_m << shift) & FP32_MASK_M;
			f32_m |= BIT(FP32_MSB_M);
		}
		break;
	case 0: /* float16: zero or subnormal */
		f32_m = f16_m;
		if (f16_m == 0) { /* zero signed */
			f32_e = 0;
		} else { /* subnormal numbers */
			clz = __builtin_clz((uint32_t)f16_m) - sizeof(uint32_t) * 8 + FP16_LSB_E;
			e_16 = (int)f16_e - clz;
			f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E;

			shift = clz + (FP32_MSB_M - FP16_MSB_M) + 1;
			f32_m = (f32_m << shift) & FP32_MASK_M;
		}
		break;
	default: /* normal numbers */
		f32_m = f16_m;
		e_16 = (int)f16_e;
		f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E;

		shift = (FP32_MSB_M - FP16_MSB_M);
		f32_m = (f32_m << shift) & FP32_MASK_M;
	}

	f32.u = FP32_PACK(f32_s, f32_e, f32_m);

	return f32.f;
}

__rte_weak int
rte_ml_io_float16_to_float32(uint64_t nb_elements, void *input, void *output)
{
	uint16_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (uint16_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = __float16_to_float32_scalar_rtx(*input_buffer);

		input_buffer = input_buffer + 1;
		output_buffer = output_buffer + 1;
	}

	return 0;
}

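/*
 * Illustrative usage sketch (not part of the library) of the float32 <->
 * float16 element-wise converters above, guarded by the same assumed
 * MLDEV_UTILS_SCALAR_SELFTEST macro. 1.0f and -2.0f are exactly
 * representable in IEEE 754 binary16 (0x3c00 and 0xc000), so the round
 * trip reproduces the inputs exactly.
 */
#ifdef MLDEV_UTILS_SCALAR_SELFTEST
static void
mldev_utils_scalar_float16_example(void)
{
	float in[2] = {1.0f, -2.0f};
	uint16_t f16[2]; /* expected: 0x3c00, 0xc000 */
	float out[2];

	rte_ml_io_float32_to_float16(2, in, f16);
	rte_ml_io_float16_to_float32(2, f16, out);

	/* out[0] == 1.0f and out[1] == -2.0f again */
}
#endif /* MLDEV_UTILS_SCALAR_SELFTEST */
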
/* Convert a single precision floating point number (float32) into a
 * brain float number (bfloat16) using round to nearest rounding mode.
 */
static uint16_t
__float32_to_bfloat16_scalar_rtn(float x)
{
	union float32 f32; /* float32 input */
	uint32_t f32_s;	/* float32 sign */
	uint32_t f32_e;	/* float32 exponent */
	uint32_t f32_m;	/* float32 mantissa */
	uint16_t b16_s;	/* bfloat16 sign */
	uint16_t b16_e;	/* bfloat16 exponent */
	uint16_t b16_m;	/* bfloat16 mantissa */
	uint32_t tbits;	/* number of truncated bits */
	uint16_t u16;	/* bfloat16 output */

	f32.f = x;
	f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
	f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
	f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;

	b16_s = f32_s;
	b16_e = 0;
	b16_m = 0;

	switch (f32_e) {
	case (0): /* float32: zero or subnormal number */
		b16_e = 0;
		if (f32_m == 0) /* zero */
			b16_m = 0;
		else /* subnormal float32; exponent field carries over, round mantissa as in the normal case */
			goto bf16_normal;
		break;
	case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
		b16_e = BF16_MASK_E >> BF16_LSB_E;
		if (f32_m == 0) { /* infinity */
			b16_m = 0;
		} else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
			b16_m = f32_m >> (FP32_MSB_M - BF16_MSB_M);
			b16_m |= BIT(BF16_MSB_M);
		}
		break;
	default: /* float32: normal number, normal bfloat16 */
		goto bf16_normal;
	}

	goto bf16_pack;

bf16_normal:
	b16_e = f32_e;
	tbits = FP32_MSB_M - BF16_MSB_M;
	b16_m = f32_m >> tbits;

	/* if non-leading truncated bits are set */
	if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) {
		b16_m++;

		/* if overflow into exponent */
		if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
			b16_e++;
	} else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) {
		/* if only leading truncated bit is set */
		if ((b16_m & 0x1) == 0x1) {
			b16_m++;

			/* if overflow into exponent */
			if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
				b16_e++;
		}
	}
	b16_m = b16_m & BF16_MASK_M;

bf16_pack:
	u16 = BF16_PACK(b16_s, b16_e, b16_m);

	return u16;
}

__rte_weak int
rte_ml_io_float32_to_bfloat16(uint64_t nb_elements, void *input, void *output)
{
	float *input_buffer;
	uint16_t *output_buffer;
	uint64_t i;

	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (float *)input;
	output_buffer = (uint16_t *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = __float32_to_bfloat16_scalar_rtn(*input_buffer);

		input_buffer = input_buffer + 1;
		output_buffer = output_buffer + 1;
	}

	return 0;
}

/* Convert a brain float number (bfloat16) into a
 * single precision floating point number (float32).
 */
static float
__bfloat16_to_float32_scalar_rtx(uint16_t f16)
{
	union float32 f32; /* float32 output */
	uint16_t b16_s;	/* bfloat16 sign */
	uint16_t b16_e;	/* bfloat16 exponent */
	uint16_t b16_m;	/* bfloat16 mantissa */
	uint32_t f32_s;	/* float32 sign */
	uint32_t f32_e;	/* float32 exponent */
	uint32_t f32_m;	/* float32 mantissa */
	uint8_t shift;	/* number of bits to be shifted */

	b16_s = (f16 & BF16_MASK_S) >> BF16_LSB_S;
	b16_e = (f16 & BF16_MASK_E) >> BF16_LSB_E;
	b16_m = (f16 & BF16_MASK_M) >> BF16_LSB_M;

	f32_s = b16_s;
	switch (b16_e) {
	case (BF16_MASK_E >> BF16_LSB_E): /* bfloat16: infinity or nan */
		f32_e = FP32_MASK_E >> FP32_LSB_E;
		if (b16_m == 0x0) { /* infinity */
			f32_m = 0;
		} else { /* nan, propagate mantissa, set MSB of mantissa to 1 */
			f32_m = b16_m;
			shift = FP32_MSB_M - BF16_MSB_M;
			f32_m = (f32_m << shift) & FP32_MASK_M;
			f32_m |= BIT(FP32_MSB_M);
		}
		break;
	case 0: /* bfloat16: zero or subnormal */
		f32_m = b16_m;
		if (b16_m == 0) { /* zero signed */
			f32_e = 0;
		} else { /* subnormal numbers */
			goto fp32_normal;
		}
		break;
	default: /* bfloat16: normal number */
		goto fp32_normal;
	}

	goto fp32_pack;

fp32_normal:
	f32_m = b16_m;
	f32_e = FP32_BIAS_E + b16_e - BF16_BIAS_E;

	shift = (FP32_MSB_M - BF16_MSB_M);
	f32_m = (f32_m << shift) & FP32_MASK_M;

fp32_pack:
	f32.u = FP32_PACK(f32_s, f32_e, f32_m);

	return f32.f;
}

__rte_weak int
rte_ml_io_bfloat16_to_float32(uint64_t nb_elements, void *input, void *output)
{
	uint16_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (uint16_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = __bfloat16_to_float32_scalar_rtx(*input_buffer);

		input_buffer = input_buffer + 1;
		output_buffer = output_buffer + 1;
	}

	return 0;
}
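
/*
 * Illustrative usage sketch (not part of the library) of the float32 <->
 * bfloat16 converters above, guarded by the same assumed
 * MLDEV_UTILS_SCALAR_SELFTEST macro. bfloat16 keeps the float32 sign and
 * exponent and truncates the mantissa to 7 bits, so 1.0f and -0.5f
 * (0x3f80 and 0xbf00) survive the round trip exactly.
 */
#ifdef MLDEV_UTILS_SCALAR_SELFTEST
static void
mldev_utils_scalar_bfloat16_example(void)
{
	float in[2] = {1.0f, -0.5f};
	uint16_t bf16[2]; /* expected: 0x3f80, 0xbf00 */
	float out[2];

	rte_ml_io_float32_to_bfloat16(2, in, bf16);
	rte_ml_io_bfloat16_to_float32(2, bf16, out);

	/* out[0] == 1.0f and out[1] == -0.5f again */
}
#endif /* MLDEV_UTILS_SCALAR_SELFTEST */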