/* SPDX-License-Identifier: BSD-3-Clause
 * Copyright (c) 2023 Marvell.
 */

#include <errno.h>
#include <math.h>
#include <stdint.h>

#include "mldev_utils_scalar.h"

/* Description:
 * This file implements scalar versions of Machine Learning utility functions used to convert data
 * types from bfloat16 to float32 and vice-versa.
 */

/* Convert a single precision floating point number (float32) into a
 * brain float number (bfloat16) using round to nearest rounding mode.
 */
static uint16_t
__float32_to_bfloat16_scalar_rtn(float x)
{
	union float32 f32; /* float32 input */
	uint32_t f32_s;    /* float32 sign */
	uint32_t f32_e;    /* float32 exponent */
	uint32_t f32_m;    /* float32 mantissa */
	uint16_t b16_s;    /* bfloat16 sign */
	uint16_t b16_e;    /* bfloat16 exponent */
	uint16_t b16_m;    /* bfloat16 mantissa */
	uint32_t tbits;    /* number of truncated bits */
	uint16_t u16;      /* bfloat16 output */

	f32.f = x;
	f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
	f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
	f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;

	b16_s = f32_s;
	b16_e = 0;
	b16_m = 0;

	switch (f32_e) {
	case (0): /* float32: zero or subnormal number */
		b16_e = 0;
		if (f32_m == 0) /* zero */
			b16_m = 0;
		else /* subnormal float32 number, normal bfloat16 */
			goto bf16_normal;
		break;
	case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
		b16_e = BF16_MASK_E >> BF16_LSB_E;
		if (f32_m == 0) { /* infinity */
			b16_m = 0;
		} else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
			b16_m = f32_m >> (FP32_MSB_M - BF16_MSB_M);
			b16_m |= BIT(BF16_MSB_M);
		}
		break;
	default: /* float32: normal number, normal bfloat16 */
		goto bf16_normal;
	}

	goto bf16_pack;

bf16_normal:
	b16_e = f32_e;
	tbits = FP32_MSB_M - BF16_MSB_M;
	b16_m = f32_m >> tbits;

	/* if non-leading truncated bits are set, round up */
	if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) {
		b16_m++;

		/* if overflow into exponent */
		if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
			b16_e++;
	} else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) {
		/* if only the leading truncated bit is set, round the halfway case to even */
		if ((b16_m & 0x1) == 0x1) {
			b16_m++;

			/* if overflow into exponent */
			if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
				b16_e++;
		}
	}
	b16_m = b16_m & BF16_MASK_M;

bf16_pack:
	u16 = BF16_PACK(b16_s, b16_e, b16_m);

	return u16;
}

int
rte_ml_io_float32_to_bfloat16(const void *input, void *output, uint64_t nb_elements)
{
	const float *input_buffer;
	uint16_t *output_buffer;
	uint64_t i;

	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (const float *)input;
	output_buffer = (uint16_t *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = __float32_to_bfloat16_scalar_rtn(*input_buffer);

		input_buffer = input_buffer + 1;
		output_buffer = output_buffer + 1;
	}

	return 0;
}
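
/*
 * Illustrative usage sketch (not part of the library): converts a small float32
 * buffer to bfloat16 with rte_ml_io_float32_to_bfloat16() and returns its status
 * code. The guard macro MLDEV_UTILS_BF16_EXAMPLE is hypothetical and only keeps
 * the sketch out of regular builds. For reference, 1.0f (0x3F800000) converts to
 * bfloat16 0x3F80, i.e. the upper 16 bits of the float32 encoding.
 */
#ifdef MLDEV_UTILS_BF16_EXAMPLE
static int
example_float32_to_bfloat16(void)
{
	float in[4] = {0.0f, 1.0f, -2.5f, 3.140625f};
	uint16_t out[4];

	/* nb_elements must be non-zero and both pointers non-NULL, else -EINVAL */
	return rte_ml_io_float32_to_bfloat16(in, out, 4);
}
#endif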

/* Convert a brain float number (bfloat16) into a
 * single precision floating point number (float32).
 */
static float
__bfloat16_to_float32_scalar_rtx(uint16_t f16)
{
	union float32 f32; /* float32 output */
	uint16_t b16_s;    /* bfloat16 sign */
	uint16_t b16_e;    /* bfloat16 exponent */
	uint16_t b16_m;    /* bfloat16 mantissa */
	uint32_t f32_s;    /* float32 sign */
	uint32_t f32_e;    /* float32 exponent */
	uint32_t f32_m;    /* float32 mantissa */
	uint8_t shift;     /* number of bits to be shifted */

	b16_s = (f16 & BF16_MASK_S) >> BF16_LSB_S;
	b16_e = (f16 & BF16_MASK_E) >> BF16_LSB_E;
	b16_m = (f16 & BF16_MASK_M) >> BF16_LSB_M;

	f32_s = b16_s;
	switch (b16_e) {
	case (BF16_MASK_E >> BF16_LSB_E): /* bfloat16: infinity or nan */
		f32_e = FP32_MASK_E >> FP32_LSB_E;
		if (b16_m == 0x0) { /* infinity */
			f32_m = 0;
		} else { /* nan, propagate mantissa, set MSB of mantissa to 1 */
			f32_m = b16_m;
			shift = FP32_MSB_M - BF16_MSB_M;
			f32_m = (f32_m << shift) & FP32_MASK_M;
			f32_m |= BIT(FP32_MSB_M);
		}
		break;
	case 0: /* bfloat16: zero or subnormal */
		f32_m = b16_m;
		if (b16_m == 0) { /* signed zero */
			f32_e = 0;
		} else { /* subnormal numbers */
			goto fp32_normal;
		}
		break;
	default: /* bfloat16: normal number */
		goto fp32_normal;
	}

	goto fp32_pack;

fp32_normal:
	f32_m = b16_m;
	f32_e = FP32_BIAS_E + b16_e - BF16_BIAS_E;

	shift = (FP32_MSB_M - BF16_MSB_M);
	f32_m = (f32_m << shift) & FP32_MASK_M;

fp32_pack:
	f32.u = FP32_PACK(f32_s, f32_e, f32_m);

	return f32.f;
}

int
rte_ml_io_bfloat16_to_float32(const void *input, void *output, uint64_t nb_elements)
{
	const uint16_t *input_buffer;
	float *output_buffer;
	uint64_t i;

	if ((nb_elements == 0) || (input == NULL) || (output == NULL))
		return -EINVAL;

	input_buffer = (const uint16_t *)input;
	output_buffer = (float *)output;

	for (i = 0; i < nb_elements; i++) {
		*output_buffer = __bfloat16_to_float32_scalar_rtx(*input_buffer);

		input_buffer = input_buffer + 1;
		output_buffer = output_buffer + 1;
	}

	return 0;
}
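
/*
 * Illustrative usage sketch (not part of the library): expands a small bfloat16
 * buffer back to float32 with rte_ml_io_bfloat16_to_float32() and returns its
 * status code. The guard macro MLDEV_UTILS_BF16_EXAMPLE is hypothetical and only
 * keeps the sketch out of regular builds. For reference, bfloat16 0x3F80 expands
 * to float32 0x3F800000 (1.0f): the 16-bit pattern becomes the upper half of the
 * float32 encoding and the low mantissa bits are zero-filled.
 */
#ifdef MLDEV_UTILS_BF16_EXAMPLE
static int
example_bfloat16_to_float32(void)
{
	uint16_t in[4] = {0x0000, 0x3F80, 0xC020, 0x4049};
	float out[4];

	/* nb_elements must be non-zero and both pointers non-NULL, else -EINVAL */
	return rte_ml_io_bfloat16_to_float32(in, out, 4);
}
#endif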