xref: /dpdk/lib/mldev/mldev_utils_neon_bfloat16.c (revision 65282e9f8e118a4ca977d1aee2d7f51f44e9bc1b)
1538f6997SSrikanth Yalavarthi /* SPDX-License-Identifier: BSD-3-Clause
2538f6997SSrikanth Yalavarthi  * Copyright (c) 2023 Marvell.
3538f6997SSrikanth Yalavarthi  */
4538f6997SSrikanth Yalavarthi 
5538f6997SSrikanth Yalavarthi #include <errno.h>
6538f6997SSrikanth Yalavarthi #include <stdint.h>
7538f6997SSrikanth Yalavarthi #include <stdlib.h>
8538f6997SSrikanth Yalavarthi 
9538f6997SSrikanth Yalavarthi #include "mldev_utils.h"
10538f6997SSrikanth Yalavarthi 
11538f6997SSrikanth Yalavarthi #include <arm_neon.h>
12538f6997SSrikanth Yalavarthi 
/* Description:
 * This file implements vector versions of the Machine Learning utility functions used to convert
 * data between the bfloat16 and float32 types, in both directions. The implementation is based on
 * Arm Neon intrinsics and is compiled only when __ARM_FEATURE_BF16 is defined.
 */
17538f6997SSrikanth Yalavarthi 
18538f6997SSrikanth Yalavarthi #ifdef __ARM_FEATURE_BF16
19538f6997SSrikanth Yalavarthi 
/* Narrow four consecutive float32 values read from @input into bfloat16 and
 * write the four results to @output.
 */
static inline void
__float32_to_bfloat16_neon_f16x4(const float32_t *input, bfloat16_t *output)
{
	/* one 128-bit load, one 4-lane conversion, one 64-bit store */
	vst1_bf16(output, vcvt_bf16_f32(vld1q_f32(input)));
}
35538f6997SSrikanth Yalavarthi 
/* Convert the single float32 value at @input to bfloat16 and write exactly
 * one element to @output.
 */
static inline void
__float32_to_bfloat16_neon_f16x1(const float32_t *input, bfloat16_t *output)
{
	/* Replicate the scalar across all lanes so the 4-lane conversion can be
	 * reused; only lane 0 of the narrowed vector is written back.
	 */
	const float32x4_t splat = vld1q_dup_f32(input);
	const bfloat16x4_t narrowed = vcvt_bf16_f32(splat);

	vst1_lane_bf16(output, narrowed, 0);
}
51538f6997SSrikanth Yalavarthi 
/* Convert nb_elements float32 values from @input into bfloat16 values in
 * @output. Complete groups of four go through the vector helper, and any
 * remainder is converted one element at a time.
 *
 * Returns 0 on success, or -EINVAL when nb_elements is zero or either buffer
 * pointer is NULL.
 */
int
rte_ml_io_float32_to_bfloat16(const void *input, void *output, uint64_t nb_elements)
{
	const float32_t *src = input;
	bfloat16_t *dst = output;
	/* elements handled per call of the 4-lane helper */
	const uint32_t vlen = sizeof(float32x4_t) / sizeof(float32_t);
	uint64_t remaining;

	if (input == NULL || output == NULL || nb_elements == 0)
		return -EINVAL;

	/* whole groups of vlen elements */
	for (remaining = nb_elements; remaining >= vlen; remaining -= vlen) {
		__float32_to_bfloat16_neon_f16x4(src, dst);
		src += vlen;
		dst += vlen;
	}

	/* between 0 and vlen - 1 trailing elements */
	for (; remaining != 0; remaining--)
		__float32_to_bfloat16_neon_f16x1(src++, dst++);

	return 0;
}
86538f6997SSrikanth Yalavarthi 
/* Widen four consecutive bfloat16 values read from @input into float32 and
 * write the four results to @output.
 */
static inline void
__bfloat16_to_float32_neon_f32x4(const bfloat16_t *input, float32_t *output)
{
	/* one 64-bit load, one 4-lane conversion, one 128-bit store */
	vst1q_f32(output, vcvt_f32_bf16(vld1_bf16(input)));
}
102538f6997SSrikanth Yalavarthi 
/* Convert the single bfloat16 value at @input to float32 and write exactly
 * one element to @output.
 */
static inline void
__bfloat16_to_float32_neon_f32x1(const bfloat16_t *input, float32_t *output)
{
	/* Replicate the scalar across all lanes so the 4-lane conversion can be
	 * reused; only lane 0 of the widened vector is written back.
	 */
	const bfloat16x4_t splat = vld1_dup_bf16(input);
	const float32x4_t widened = vcvt_f32_bf16(splat);

	vst1q_lane_f32(output, widened, 0);
}
118538f6997SSrikanth Yalavarthi 
/* Convert nb_elements bfloat16 values from @input into float32 values in
 * @output. Complete groups of four go through the vector helper, and any
 * remainder is converted one element at a time.
 *
 * Returns 0 on success, or -EINVAL when nb_elements is zero or either buffer
 * pointer is NULL.
 */
int
rte_ml_io_bfloat16_to_float32(const void *input, void *output, uint64_t nb_elements)
{
	const bfloat16_t *src = input;
	float32_t *dst = output;
	/* elements handled per call of the 4-lane helper */
	const uint32_t vlen = sizeof(float32x4_t) / sizeof(float32_t);
	uint64_t remaining;

	if (input == NULL || output == NULL || nb_elements == 0)
		return -EINVAL;

	/* whole groups of vlen elements */
	for (remaining = nb_elements; remaining >= vlen; remaining -= vlen) {
		__bfloat16_to_float32_neon_f32x4(src, dst);
		src += vlen;
		dst += vlen;
	}

	/* between 0 and vlen - 1 trailing elements */
	for (; remaining != 0; remaining--)
		__bfloat16_to_float32_neon_f32x1(src++, dst++);

	return 0;
}
153538f6997SSrikanth Yalavarthi 
154538f6997SSrikanth Yalavarthi #endif /* __ARM_FEATURE_BF16 */
155