xref: /dpdk/lib/mldev/mldev_utils_neon_bfloat16.c (revision 65282e9f8e118a4ca977d1aee2d7f51f44e9bc1b)
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright (c) 2023 Marvell.
3  */
4 
5 #include <errno.h>
6 #include <stdint.h>
7 #include <stdlib.h>
8 
9 #include "mldev_utils.h"
10 
11 #include <arm_neon.h>
12 
13 /* Description:
14  * This file implements vector versions of Machine Learning utility functions used to convert data
15  * types from bfloat16 to float and vice-versa. Implementation is based on Arm Neon intrinsics.
16  */
17 
18 #ifdef __ARM_FEATURE_BF16
19 
/* Narrow four consecutive float32 values read from 'input' into four bfloat16
 * values written to 'output': one 128-bit vector load, one vector
 * float32 -> bfloat16 conversion, and one 64-bit vector store.
 */
static inline void
__float32_to_bfloat16_neon_f16x4(const float32_t *input, bfloat16_t *output)
{
	const float32x4_t src = vld1q_f32(input);

	vst1_bf16(output, vcvt_bf16_f32(src));
}
35 
/* Narrow a single float32 value read from 'input' to bfloat16 and write it to
 * 'output'. The scalar is broadcast into all four lanes so the same vector
 * conversion as the 4-wide path can be used. Only lane 0 of the result is
 * stored, so exactly one bfloat16_t is written.
 */
static inline void
__float32_to_bfloat16_neon_f16x1(const float32_t *input, bfloat16_t *output)
{
	const bfloat16x4_t narrowed = vcvt_bf16_f32(vld1q_dup_f32(input));

	vst1_lane_bf16(output, narrowed, 0);
}
51 
/* Convert 'nb_elements' float32 values from 'input' into bfloat16 values stored
 * in 'output'. Full groups of four are converted with the 4-wide vector helper;
 * the remaining 0-3 elements are converted one at a time.
 *
 * Returns 0 on success, or -EINVAL if 'nb_elements' is zero or either buffer
 * pointer is NULL.
 */
int
rte_ml_io_float32_to_bfloat16(const void *input, void *output, uint64_t nb_elements)
{
	/* Elements per vector step: the lane count of one float32x4_t (4). */
	const uint64_t lanes = sizeof(float32x4_t) / sizeof(float32_t);
	const float32_t *src = input;
	bfloat16_t *dst = output;
	uint64_t remaining;

	if (input == NULL || output == NULL || nb_elements == 0)
		return -EINVAL;

	/* bulk: whole vectors of 'lanes' elements */
	for (remaining = nb_elements; remaining >= lanes; remaining -= lanes) {
		__float32_to_bfloat16_neon_f16x4(src, dst);
		src += lanes;
		dst += lanes;
	}

	/* tail: fewer than 'lanes' elements left, handle them individually */
	while (remaining-- > 0)
		__float32_to_bfloat16_neon_f16x1(src++, dst++);

	return 0;
}
86 
/* Widen four consecutive bfloat16 values read from 'input' into four float32
 * values written to 'output': one 64-bit vector load, one vector
 * bfloat16 -> float32 conversion, and one 128-bit vector store.
 */
static inline void
__bfloat16_to_float32_neon_f32x4(const bfloat16_t *input, float32_t *output)
{
	const float32x4_t widened = vcvt_f32_bf16(vld1_bf16(input));

	vst1q_f32(output, widened);
}
102 
/* Widen a single bfloat16 value read from 'input' to float32 and write it to
 * 'output'. The scalar is broadcast into all four lanes so the vector
 * conversion can be used. Only lane 0 of the result is stored, so exactly one
 * float32_t is written.
 */
static inline void
__bfloat16_to_float32_neon_f32x1(const bfloat16_t *input, float32_t *output)
{
	const bfloat16x4_t src = vld1_dup_bf16(input);

	vst1q_lane_f32(output, vcvt_f32_bf16(src), 0);
}
118 
/* Convert 'nb_elements' bfloat16 values from 'input' into float32 values stored
 * in 'output'. Full groups of four are converted with the 4-wide vector helper;
 * the remaining 0-3 elements are converted one at a time.
 *
 * Returns 0 on success, or -EINVAL if 'nb_elements' is zero or either buffer
 * pointer is NULL.
 */
int
rte_ml_io_bfloat16_to_float32(const void *input, void *output, uint64_t nb_elements)
{
	/* Elements per vector step: the lane count of one float32x4_t (4). */
	const uint64_t lanes = sizeof(float32x4_t) / sizeof(float32_t);
	const bfloat16_t *src = input;
	float32_t *dst = output;
	uint64_t remaining;

	if (input == NULL || output == NULL || nb_elements == 0)
		return -EINVAL;

	/* bulk: whole vectors of 'lanes' elements */
	for (remaining = nb_elements; remaining >= lanes; remaining -= lanes) {
		__bfloat16_to_float32_neon_f32x4(src, dst);
		src += lanes;
		dst += lanes;
	}

	/* tail: fewer than 'lanes' elements left, handle them individually */
	while (remaining-- > 0)
		__bfloat16_to_float32_neon_f32x1(src++, dst++);

	return 0;
}
153 
154 #endif /* __ARM_FEATURE_BF16 */
155