1538f6997SSrikanth Yalavarthi /* SPDX-License-Identifier: BSD-3-Clause 2538f6997SSrikanth Yalavarthi * Copyright (c) 2023 Marvell. 3538f6997SSrikanth Yalavarthi */ 4538f6997SSrikanth Yalavarthi 5538f6997SSrikanth Yalavarthi #include <errno.h> 6538f6997SSrikanth Yalavarthi #include <stdint.h> 7538f6997SSrikanth Yalavarthi #include <stdlib.h> 8538f6997SSrikanth Yalavarthi 9538f6997SSrikanth Yalavarthi #include "mldev_utils.h" 10538f6997SSrikanth Yalavarthi 11538f6997SSrikanth Yalavarthi #include <arm_neon.h> 12538f6997SSrikanth Yalavarthi 13538f6997SSrikanth Yalavarthi /* Description: 14538f6997SSrikanth Yalavarthi * This file implements vector versions of Machine Learning utility functions used to convert data 15538f6997SSrikanth Yalavarthi * types from bfloat16 to float and vice-versa. Implementation is based on Arm Neon intrinsics. 16538f6997SSrikanth Yalavarthi */ 17538f6997SSrikanth Yalavarthi 18538f6997SSrikanth Yalavarthi #ifdef __ARM_FEATURE_BF16 19538f6997SSrikanth Yalavarthi 20538f6997SSrikanth Yalavarthi static inline void 21*65282e9fSSrikanth Yalavarthi __float32_to_bfloat16_neon_f16x4(const float32_t *input, bfloat16_t *output) 22538f6997SSrikanth Yalavarthi { 23538f6997SSrikanth Yalavarthi float32x4_t f32x4; 24538f6997SSrikanth Yalavarthi bfloat16x4_t bf16x4; 25538f6997SSrikanth Yalavarthi 26538f6997SSrikanth Yalavarthi /* load 4 x float32_t elements */ 27538f6997SSrikanth Yalavarthi f32x4 = vld1q_f32(input); 28538f6997SSrikanth Yalavarthi 29538f6997SSrikanth Yalavarthi /* convert float32x4_t to bfloat16x4_t */ 30538f6997SSrikanth Yalavarthi bf16x4 = vcvt_bf16_f32(f32x4); 31538f6997SSrikanth Yalavarthi 32538f6997SSrikanth Yalavarthi /* store bfloat16x4_t */ 33538f6997SSrikanth Yalavarthi vst1_bf16(output, bf16x4); 34538f6997SSrikanth Yalavarthi } 35538f6997SSrikanth Yalavarthi 36538f6997SSrikanth Yalavarthi static inline void 37*65282e9fSSrikanth Yalavarthi __float32_to_bfloat16_neon_f16x1(const float32_t *input, bfloat16_t *output) 38538f6997SSrikanth Yalavarthi { 39538f6997SSrikanth Yalavarthi float32x4_t f32x4; 40538f6997SSrikanth Yalavarthi bfloat16x4_t bf16x4; 41538f6997SSrikanth Yalavarthi 42538f6997SSrikanth Yalavarthi /* load element to 4 lanes */ 43538f6997SSrikanth Yalavarthi f32x4 = vld1q_dup_f32(input); 44538f6997SSrikanth Yalavarthi 45538f6997SSrikanth Yalavarthi /* convert float32_t to bfloat16_t */ 46538f6997SSrikanth Yalavarthi bf16x4 = vcvt_bf16_f32(f32x4); 47538f6997SSrikanth Yalavarthi 48538f6997SSrikanth Yalavarthi /* store lane 0 / 1 element */ 49538f6997SSrikanth Yalavarthi vst1_lane_bf16(output, bf16x4, 0); 50538f6997SSrikanth Yalavarthi } 51538f6997SSrikanth Yalavarthi 52538f6997SSrikanth Yalavarthi int 53*65282e9fSSrikanth Yalavarthi rte_ml_io_float32_to_bfloat16(const void *input, void *output, uint64_t nb_elements) 54538f6997SSrikanth Yalavarthi { 55*65282e9fSSrikanth Yalavarthi const float32_t *input_buffer; 56538f6997SSrikanth Yalavarthi bfloat16_t *output_buffer; 57538f6997SSrikanth Yalavarthi uint64_t nb_iterations; 58538f6997SSrikanth Yalavarthi uint32_t vlen; 59538f6997SSrikanth Yalavarthi uint64_t i; 60538f6997SSrikanth Yalavarthi 61538f6997SSrikanth Yalavarthi if ((nb_elements == 0) || (input == NULL) || (output == NULL)) 62538f6997SSrikanth Yalavarthi return -EINVAL; 63538f6997SSrikanth Yalavarthi 64*65282e9fSSrikanth Yalavarthi input_buffer = (const float32_t *)input; 65538f6997SSrikanth Yalavarthi output_buffer = (bfloat16_t *)output; 66538f6997SSrikanth Yalavarthi vlen = 2 * sizeof(float32_t) / sizeof(bfloat16_t); 67538f6997SSrikanth Yalavarthi nb_iterations = nb_elements / vlen; 68538f6997SSrikanth Yalavarthi 69538f6997SSrikanth Yalavarthi /* convert vlen elements in each iteration */ 70538f6997SSrikanth Yalavarthi for (i = 0; i < nb_iterations; i++) { 71538f6997SSrikanth Yalavarthi __float32_to_bfloat16_neon_f16x4(input_buffer, output_buffer); 72538f6997SSrikanth Yalavarthi input_buffer += vlen; 73538f6997SSrikanth Yalavarthi output_buffer += vlen; 74538f6997SSrikanth Yalavarthi } 75538f6997SSrikanth Yalavarthi 76538f6997SSrikanth Yalavarthi /* convert leftover elements */ 77538f6997SSrikanth Yalavarthi i = i * vlen; 78538f6997SSrikanth Yalavarthi for (; i < nb_elements; i++) { 79538f6997SSrikanth Yalavarthi __float32_to_bfloat16_neon_f16x1(input_buffer, output_buffer); 80538f6997SSrikanth Yalavarthi input_buffer++; 81538f6997SSrikanth Yalavarthi output_buffer++; 82538f6997SSrikanth Yalavarthi } 83538f6997SSrikanth Yalavarthi 84538f6997SSrikanth Yalavarthi return 0; 85538f6997SSrikanth Yalavarthi } 86538f6997SSrikanth Yalavarthi 87538f6997SSrikanth Yalavarthi static inline void 88*65282e9fSSrikanth Yalavarthi __bfloat16_to_float32_neon_f32x4(const bfloat16_t *input, float32_t *output) 89538f6997SSrikanth Yalavarthi { 90538f6997SSrikanth Yalavarthi bfloat16x4_t bf16x4; 91538f6997SSrikanth Yalavarthi float32x4_t f32x4; 92538f6997SSrikanth Yalavarthi 93538f6997SSrikanth Yalavarthi /* load 4 x bfloat16_t elements */ 94538f6997SSrikanth Yalavarthi bf16x4 = vld1_bf16(input); 95538f6997SSrikanth Yalavarthi 96538f6997SSrikanth Yalavarthi /* convert bfloat16x4_t to float32x4_t */ 97538f6997SSrikanth Yalavarthi f32x4 = vcvt_f32_bf16(bf16x4); 98538f6997SSrikanth Yalavarthi 99538f6997SSrikanth Yalavarthi /* store float32x4_t */ 100538f6997SSrikanth Yalavarthi vst1q_f32(output, f32x4); 101538f6997SSrikanth Yalavarthi } 102538f6997SSrikanth Yalavarthi 103538f6997SSrikanth Yalavarthi static inline void 104*65282e9fSSrikanth Yalavarthi __bfloat16_to_float32_neon_f32x1(const bfloat16_t *input, float32_t *output) 105538f6997SSrikanth Yalavarthi { 106538f6997SSrikanth Yalavarthi bfloat16x4_t bf16x4; 107538f6997SSrikanth Yalavarthi float32x4_t f32x4; 108538f6997SSrikanth Yalavarthi 109538f6997SSrikanth Yalavarthi /* load element to 4 lanes */ 110538f6997SSrikanth Yalavarthi bf16x4 = vld1_dup_bf16(input); 111538f6997SSrikanth Yalavarthi 112538f6997SSrikanth Yalavarthi /* convert bfloat16_t to float32_t */ 113538f6997SSrikanth Yalavarthi f32x4 = vcvt_f32_bf16(bf16x4); 114538f6997SSrikanth Yalavarthi 115538f6997SSrikanth Yalavarthi /* store lane 0 / 1 element */ 116538f6997SSrikanth Yalavarthi vst1q_lane_f32(output, f32x4, 0); 117538f6997SSrikanth Yalavarthi } 118538f6997SSrikanth Yalavarthi 119538f6997SSrikanth Yalavarthi int 120*65282e9fSSrikanth Yalavarthi rte_ml_io_bfloat16_to_float32(const void *input, void *output, uint64_t nb_elements) 121538f6997SSrikanth Yalavarthi { 122*65282e9fSSrikanth Yalavarthi const bfloat16_t *input_buffer; 123538f6997SSrikanth Yalavarthi float32_t *output_buffer; 124538f6997SSrikanth Yalavarthi uint64_t nb_iterations; 125538f6997SSrikanth Yalavarthi uint32_t vlen; 126538f6997SSrikanth Yalavarthi uint64_t i; 127538f6997SSrikanth Yalavarthi 128538f6997SSrikanth Yalavarthi if ((nb_elements == 0) || (input == NULL) || (output == NULL)) 129538f6997SSrikanth Yalavarthi return -EINVAL; 130538f6997SSrikanth Yalavarthi 131*65282e9fSSrikanth Yalavarthi input_buffer = (const bfloat16_t *)input; 132538f6997SSrikanth Yalavarthi output_buffer = (float32_t *)output; 133538f6997SSrikanth Yalavarthi vlen = 2 * sizeof(float32_t) / sizeof(bfloat16_t); 134538f6997SSrikanth Yalavarthi nb_iterations = nb_elements / vlen; 135538f6997SSrikanth Yalavarthi 136538f6997SSrikanth Yalavarthi /* convert vlen elements in each iteration */ 137538f6997SSrikanth Yalavarthi for (i = 0; i < nb_iterations; i++) { 138538f6997SSrikanth Yalavarthi __bfloat16_to_float32_neon_f32x4(input_buffer, output_buffer); 139538f6997SSrikanth Yalavarthi input_buffer += vlen; 140538f6997SSrikanth Yalavarthi output_buffer += vlen; 141538f6997SSrikanth Yalavarthi } 142538f6997SSrikanth Yalavarthi 143538f6997SSrikanth Yalavarthi /* convert leftover elements */ 144538f6997SSrikanth Yalavarthi i = i * vlen; 145538f6997SSrikanth Yalavarthi for (; i < nb_elements; i++) { 146538f6997SSrikanth Yalavarthi __bfloat16_to_float32_neon_f32x1(input_buffer, output_buffer); 147538f6997SSrikanth Yalavarthi input_buffer++; 148538f6997SSrikanth Yalavarthi output_buffer++; 149538f6997SSrikanth Yalavarthi } 150538f6997SSrikanth Yalavarthi 151538f6997SSrikanth Yalavarthi return 0; 152538f6997SSrikanth Yalavarthi } 153538f6997SSrikanth Yalavarthi 154538f6997SSrikanth Yalavarthi #endif /* __ARM_FEATURE_BF16 */ 155