/* SPDX-License-Identifier: BSD-3-Clause
 * Copyright(c) 2020 Intel Corporation
 */

#include <rte_vect.h>
#include <rte_fib6.h>

#include "trie.h"
#include "trie_avx512.h"

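/*
 * Transpose 16 IPv6 addresses (16 bytes each) from row-major order into
 * four ZMM registers: *first holds bytes 0-3 of every address, *second
 * bytes 4-7, *third bytes 8-11 and *fourth bytes 12-15, one address per
 * 32-bit lane. This lets the lookup read the same byte position of all
 * 16 addresses with a single 512-bit instruction.
 */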
static __rte_always_inline void
transpose_x16(uint8_t ips[16][RTE_FIB6_IPV6_ADDR_SIZE],
	__m512i *first, __m512i *second, __m512i *third, __m512i *fourth)
{
	__m512i tmp1, tmp2, tmp3, tmp4;
	__m512i tmp5, tmp6, tmp7, tmp8;
	const __rte_x86_zmm_t perm_idxes = {
		.u32 = { 0, 4, 8, 12, 2, 6, 10, 14,
			1, 5, 9, 13, 3, 7, 11, 15
		},
	};

	/* load all ip addresses */
	tmp1 = _mm512_loadu_si512(&ips[0][0]);
	tmp2 = _mm512_loadu_si512(&ips[4][0]);
	tmp3 = _mm512_loadu_si512(&ips[8][0]);
	tmp4 = _mm512_loadu_si512(&ips[12][0]);

	/* transpose 4 byte chunks of 16 ips */
	tmp5 = _mm512_unpacklo_epi32(tmp1, tmp2);
	tmp7 = _mm512_unpackhi_epi32(tmp1, tmp2);
	tmp6 = _mm512_unpacklo_epi32(tmp3, tmp4);
	tmp8 = _mm512_unpackhi_epi32(tmp3, tmp4);

	tmp1 = _mm512_unpacklo_epi32(tmp5, tmp6);
	tmp3 = _mm512_unpackhi_epi32(tmp5, tmp6);
	tmp2 = _mm512_unpacklo_epi32(tmp7, tmp8);
	tmp4 = _mm512_unpackhi_epi32(tmp7, tmp8);

	/* first 4-byte chunks of ips[] */
	*first = _mm512_permutexvar_epi32(perm_idxes.z, tmp1);
	/* second 4-byte chunks of ips[] */
	*second = _mm512_permutexvar_epi32(perm_idxes.z, tmp3);
	/* third 4-byte chunks of ips[] */
	*third = _mm512_permutexvar_epi32(perm_idxes.z, tmp2);
	/* fourth 4-byte chunks of ips[] */
	*fourth = _mm512_permutexvar_epi32(perm_idxes.z, tmp4);
}

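/*
 * Same idea for 8 addresses: *first holds bytes 0-7 of every address and
 * *second bytes 8-15, one address per 64-bit lane.
 */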
static __rte_always_inline void
transpose_x8(uint8_t ips[8][RTE_FIB6_IPV6_ADDR_SIZE],
	__m512i *first, __m512i *second)
{
	__m512i tmp1, tmp2, tmp3, tmp4;
	const __rte_x86_zmm_t perm_idxes = {
		.u64 = { 0, 2, 4, 6, 1, 3, 5, 7
		},
	};

	tmp1 = _mm512_loadu_si512(&ips[0][0]);
	tmp2 = _mm512_loadu_si512(&ips[4][0]);

	tmp3 = _mm512_unpacklo_epi64(tmp1, tmp2);
	*first = _mm512_permutexvar_epi64(perm_idxes.z, tmp3);
	tmp4 = _mm512_unpackhi_epi64(tmp1, tmp2);
	*second = _mm512_permutexvar_epi64(perm_idxes.z, tmp4);
}

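/*
 * Vector lookup for 32 IPv6 addresses, processed as two interleaved batches
 * of 16. The first 3 bytes of each address index tbl24; while an entry has
 * its LSB set it refers to a tbl8 group, which is indexed by the next
 * address byte. size selects between 2 and 4 byte table entries; results
 * are written to next_hops zero-extended to 64 bits.
 */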
static __rte_always_inline void
trie_vec_lookup_x16x2(void *p, uint8_t ips[32][RTE_FIB6_IPV6_ADDR_SIZE],
	uint64_t *next_hops, int size)
{
	struct rte_trie_tbl *dp = (struct rte_trie_tbl *)p;
	const __m512i zero = _mm512_set1_epi32(0);
	const __m512i lsb = _mm512_set1_epi32(1);
	const __m512i two_lsb = _mm512_set1_epi32(3);
	/* IPv6 four byte chunks */
	__m512i first_1, second_1, third_1, fourth_1;
	__m512i first_2, second_2, third_2, fourth_2;
	__m512i idxes_1, res_1;
	__m512i idxes_2, res_2;
	__m512i shuf_idxes;
	__m512i tmp_1, tmp2_1, bytes_1, byte_chunk_1;
	__m512i tmp_2, tmp2_2, bytes_2, byte_chunk_2;
	__m512i base_idxes;
	/* used to mask gather values if size is 2 (16 bit next hops) */
	const __m512i res_msk = _mm512_set1_epi32(UINT16_MAX);
	/* reverses the first 3 bytes of each 4-byte chunk to form the tbl24 index */
	const __rte_x86_zmm_t bswap = {
		.u8 = { 2, 1, 0, 255, 6, 5, 4, 255,
			10, 9, 8, 255, 14, 13, 12, 255,
			2, 1, 0, 255, 6, 5, 4, 255,
			10, 9, 8, 255, 14, 13, 12, 255,
			2, 1, 0, 255, 6, 5, 4, 255,
			10, 9, 8, 255, 14, 13, 12, 255,
			2, 1, 0, 255, 6, 5, 4, 255,
			10, 9, 8, 255, 14, 13, 12, 255
			},
	};
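	/* every 4th bit set: keeps only byte 0 of each 32-bit lane in byte shuffles */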
	const __mmask64 k = 0x1111111111111111;
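	/* tbl24 covers the first 3 address bytes, so traversal starts at byte 3 */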
	int i = 3;
	__mmask16 msk_ext_1, new_msk_1;
	__mmask16 msk_ext_2, new_msk_2;
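	/* every other bit set: used to expand 32-bit next hops into 64-bit slots */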
	__mmask16 exp_msk = 0x5555;

	transpose_x16(ips, &first_1, &second_1, &third_1, &fourth_1);
	transpose_x16(ips + 16, &first_2, &second_2, &third_2, &fourth_2);

	/* get_tbl24_idx() for every 4 byte chunk */
	idxes_1 = _mm512_shuffle_epi8(first_1, bswap.z);
	idxes_2 = _mm512_shuffle_epi8(first_2, bswap.z);

	/**
	 * lookup in tbl24
	 * Put it inside a branch to make the compiler happy with -O0
	 */
	if (size == sizeof(uint16_t)) {
		res_1 = _mm512_i32gather_epi32(idxes_1,
				(const int *)dp->tbl24, 2);
		res_2 = _mm512_i32gather_epi32(idxes_2,
				(const int *)dp->tbl24, 2);
		res_1 = _mm512_and_epi32(res_1, res_msk);
		res_2 = _mm512_and_epi32(res_2, res_msk);
	} else {
		res_1 = _mm512_i32gather_epi32(idxes_1,
				(const int *)dp->tbl24, 4);
		res_2 = _mm512_i32gather_epi32(idxes_2,
				(const int *)dp->tbl24, 4);
	}

	/* get extended entries indexes */
	msk_ext_1 = _mm512_test_epi32_mask(res_1, lsb);
	msk_ext_2 = _mm512_test_epi32_mask(res_2, lsb);

	tmp_1 = _mm512_srli_epi32(res_1, 1);
	tmp_2 = _mm512_srli_epi32(res_2, 1);

	/* idxes to retrieve bytes */
	shuf_idxes = _mm512_setr_epi32(3, 7, 11, 15,
				19, 23, 27, 31,
				35, 39, 43, 47,
				51, 55, 59, 63);

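	/* byte offset of the start of each 32-bit lane within the register */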
	base_idxes = _mm512_setr_epi32(0, 4, 8, 12,
				16, 20, 24, 28,
				32, 36, 40, 44,
				48, 52, 56, 60);

	/* traverse down the trie */
	while (msk_ext_1 || msk_ext_2) {
		/* tbl8 group number -> index of the group's first entry */
		idxes_1 = _mm512_maskz_slli_epi32(msk_ext_1, tmp_1, 8);
		idxes_2 = _mm512_maskz_slli_epi32(msk_ext_2, tmp_2, 8);
		/* select the registers holding byte i of every address */
		byte_chunk_1 = (i < 8) ?
			((i >= 4) ? second_1 : first_1) :
			((i >= 12) ? fourth_1 : third_1);
		byte_chunk_2 = (i < 8) ?
			((i >= 4) ? second_2 : first_2) :
			((i >= 12) ? fourth_2 : third_2);
		bytes_1 = _mm512_maskz_shuffle_epi8(k, byte_chunk_1,
				shuf_idxes);
		bytes_2 = _mm512_maskz_shuffle_epi8(k, byte_chunk_2,
				shuf_idxes);
		idxes_1 = _mm512_maskz_add_epi32(msk_ext_1, idxes_1, bytes_1);
		idxes_2 = _mm512_maskz_add_epi32(msk_ext_2, idxes_2, bytes_2);
		/* gather the next level's entries from tbl8 */
		if (size == sizeof(uint16_t)) {
			tmp_1 = _mm512_mask_i32gather_epi32(zero, msk_ext_1,
				idxes_1, (const int *)dp->tbl8, 2);
			tmp_2 = _mm512_mask_i32gather_epi32(zero, msk_ext_2,
				idxes_2, (const int *)dp->tbl8, 2);
			tmp_1 = _mm512_and_epi32(tmp_1, res_msk);
			tmp_2 = _mm512_and_epi32(tmp_2, res_msk);
		} else {
			tmp_1 = _mm512_mask_i32gather_epi32(zero, msk_ext_1,
				idxes_1, (const int *)dp->tbl8, 4);
			tmp_2 = _mm512_mask_i32gather_epi32(zero, msk_ext_2,
				idxes_2, (const int *)dp->tbl8, 4);
		}
		new_msk_1 = _mm512_test_epi32_mask(tmp_1, lsb);
		new_msk_2 = _mm512_test_epi32_mask(tmp_2, lsb);
		/* lanes that just terminated: keep their final next hop in res */
		res_1 = _mm512_mask_blend_epi32(msk_ext_1 ^ new_msk_1, res_1,
				tmp_1);
		res_2 = _mm512_mask_blend_epi32(msk_ext_2 ^ new_msk_2, res_2,
				tmp_2);
		tmp_1 = _mm512_srli_epi32(tmp_1, 1);
		tmp_2 = _mm512_srli_epi32(tmp_2, 1);
		msk_ext_1 = new_msk_1;
		msk_ext_2 = new_msk_2;

		/* advance to byte i + 1: offset (i + 1) % 4 within the chunk, plus lane base */
		shuf_idxes = _mm512_maskz_add_epi8(k, shuf_idxes, lsb);
		shuf_idxes = _mm512_and_epi32(shuf_idxes, two_lsb);
		shuf_idxes = _mm512_maskz_add_epi8(k, shuf_idxes, base_idxes);
		i++;
	}

	/* get rid of 1 LSB, now we have NH in every epi32 */
	res_1 = _mm512_srli_epi32(res_1, 1);
	res_2 = _mm512_srli_epi32(res_2, 1);
	/* extract first half of NH's, each widened into an epi64 chunk */
	tmp_1 = _mm512_maskz_expand_epi32(exp_msk, res_1);
	tmp_2 = _mm512_maskz_expand_epi32(exp_msk, res_2);
	/* extract second half of NH's */
	__m256i tmp256_1, tmp256_2;
	tmp256_1 = _mm512_extracti32x8_epi32(res_1, 1);
	tmp256_2 = _mm512_extracti32x8_epi32(res_2, 1);
	tmp2_1 = _mm512_maskz_expand_epi32(exp_msk,
		_mm512_castsi256_si512(tmp256_1));
	tmp2_2 = _mm512_maskz_expand_epi32(exp_msk,
		_mm512_castsi256_si512(tmp256_2));
	/* return NH's from two sets of registers */
	_mm512_storeu_si512(next_hops, tmp_1);
	_mm512_storeu_si512(next_hops + 8, tmp2_1);
	_mm512_storeu_si512(next_hops + 16, tmp_2);
	_mm512_storeu_si512(next_hops + 24, tmp2_2);
}

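/*
 * Vector lookup for 16 IPv6 addresses with 8 byte table entries, processed
 * as two interleaved batches of 8, one address per 64-bit lane.
 */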
static void
trie_vec_lookup_x8x2_8b(void *p, uint8_t ips[16][RTE_FIB6_IPV6_ADDR_SIZE],
	uint64_t *next_hops)
{
	struct rte_trie_tbl *dp = (struct rte_trie_tbl *)p;
	const __m512i zero = _mm512_set1_epi32(0);
	const __m512i lsb = _mm512_set1_epi32(1);
	const __m512i three_lsb = _mm512_set1_epi32(7);
	/* IPv6 eight byte chunks */
	__m512i first_1, second_1;
	__m512i first_2, second_2;
	__m512i idxes_1, res_1;
	__m512i idxes_2, res_2;
	__m512i shuf_idxes, base_idxes;
	__m512i tmp_1, bytes_1, byte_chunk_1;
	__m512i tmp_2, bytes_2, byte_chunk_2;
	/* reverses the first 3 bytes of each chunk to form the tbl24 index */
	const __rte_x86_zmm_t bswap = {
		.u8 = { 2, 1, 0, 255, 255, 255, 255, 255,
			10, 9, 8, 255, 255, 255, 255, 255,
			2, 1, 0, 255, 255, 255, 255, 255,
			10, 9, 8, 255, 255, 255, 255, 255,
			2, 1, 0, 255, 255, 255, 255, 255,
			10, 9, 8, 255, 255, 255, 255, 255,
			2, 1, 0, 255, 255, 255, 255, 255,
			10, 9, 8, 255, 255, 255, 255, 255
			},
	};
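	/* every 8th bit set: keeps only byte 0 of each 64-bit lane in byte shuffles */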
	const __mmask64 k = 0x101010101010101;
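	/* tbl24 covers the first 3 address bytes, so traversal starts at byte 3 */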
	int i = 3;
	__mmask8 msk_ext_1, new_msk_1;
	__mmask8 msk_ext_2, new_msk_2;

	transpose_x8(ips, &first_1, &second_1);
	transpose_x8(ips + 8, &first_2, &second_2);

	/* get_tbl24_idx() for every 8 byte chunk */
	idxes_1 = _mm512_shuffle_epi8(first_1, bswap.z);
	idxes_2 = _mm512_shuffle_epi8(first_2, bswap.z);

	/* lookup in tbl24 */
	res_1 = _mm512_i64gather_epi64(idxes_1, (const void *)dp->tbl24, 8);
	res_2 = _mm512_i64gather_epi64(idxes_2, (const void *)dp->tbl24, 8);
	/* get extended entries indexes */
	msk_ext_1 = _mm512_test_epi64_mask(res_1, lsb);
	msk_ext_2 = _mm512_test_epi64_mask(res_2, lsb);

	tmp_1 = _mm512_srli_epi64(res_1, 1);
	tmp_2 = _mm512_srli_epi64(res_2, 1);

	/* idxes to retrieve bytes */
	shuf_idxes = _mm512_setr_epi64(3, 11, 19, 27, 35, 43, 51, 59);

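	/* byte offset of the start of each 64-bit lane within the register */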
	base_idxes = _mm512_setr_epi64(0, 8, 16, 24, 32, 40, 48, 56);

	/* traverse down the trie */
	while (msk_ext_1 || msk_ext_2) {
		/* tbl8 group number -> index of the group's first entry */
		idxes_1 = _mm512_maskz_slli_epi64(msk_ext_1, tmp_1, 8);
		idxes_2 = _mm512_maskz_slli_epi64(msk_ext_2, tmp_2, 8);
		/* select the registers holding byte i of every address */
		byte_chunk_1 = (i < 8) ? first_1 : second_1;
		byte_chunk_2 = (i < 8) ? first_2 : second_2;
		bytes_1 = _mm512_maskz_shuffle_epi8(k, byte_chunk_1,
				shuf_idxes);
		bytes_2 = _mm512_maskz_shuffle_epi8(k, byte_chunk_2,
				shuf_idxes);
		idxes_1 = _mm512_maskz_add_epi64(msk_ext_1, idxes_1, bytes_1);
		idxes_2 = _mm512_maskz_add_epi64(msk_ext_2, idxes_2, bytes_2);
		/* gather the next level's entries from tbl8 */
		tmp_1 = _mm512_mask_i64gather_epi64(zero, msk_ext_1,
				idxes_1, (const void *)dp->tbl8, 8);
		tmp_2 = _mm512_mask_i64gather_epi64(zero, msk_ext_2,
				idxes_2, (const void *)dp->tbl8, 8);
		new_msk_1 = _mm512_test_epi64_mask(tmp_1, lsb);
		new_msk_2 = _mm512_test_epi64_mask(tmp_2, lsb);
		/* lanes that just terminated: keep their final next hop in res */
		res_1 = _mm512_mask_blend_epi64(msk_ext_1 ^ new_msk_1, res_1,
				tmp_1);
		res_2 = _mm512_mask_blend_epi64(msk_ext_2 ^ new_msk_2, res_2,
				tmp_2);
		tmp_1 = _mm512_srli_epi64(tmp_1, 1);
		tmp_2 = _mm512_srli_epi64(tmp_2, 1);
		msk_ext_1 = new_msk_1;
		msk_ext_2 = new_msk_2;

		/* advance to byte i + 1: offset (i + 1) % 8 within the chunk, plus lane base */
		shuf_idxes = _mm512_maskz_add_epi8(k, shuf_idxes, lsb);
		shuf_idxes = _mm512_and_epi64(shuf_idxes, three_lsb);
		shuf_idxes = _mm512_maskz_add_epi8(k, shuf_idxes, base_idxes);
		i++;
	}

	/* get rid of 1 LSB, now we have NH in every epi64 */
	res_1 = _mm512_srli_epi64(res_1, 1);
	res_2 = _mm512_srli_epi64(res_2, 1);
	_mm512_storeu_si512(next_hops, res_1);
	_mm512_storeu_si512(next_hops + 8, res_2);
}
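/*
 * Public bulk entry points: the vector path handles full groups of 32
 * addresses (16 for the 8 byte version); the scalar
 * rte_trie_lookup_bulk_*() routines handle any remainder.
 */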
void
rte_trie_vec_lookup_bulk_2b(void *p, uint8_t ips[][RTE_FIB6_IPV6_ADDR_SIZE],
	uint64_t *next_hops, const unsigned int n)
{
	uint32_t i;
	for (i = 0; i < (n / 32); i++) {
		trie_vec_lookup_x16x2(p, (uint8_t (*)[16])&ips[i * 32][0],
				next_hops + i * 32, sizeof(uint16_t));
	}
	rte_trie_lookup_bulk_2b(p, (uint8_t (*)[16])&ips[i * 32][0],
			next_hops + i * 32, n - i * 32);
}

void
rte_trie_vec_lookup_bulk_4b(void *p, uint8_t ips[][RTE_FIB6_IPV6_ADDR_SIZE],
	uint64_t *next_hops, const unsigned int n)
{
	uint32_t i;
	for (i = 0; i < (n / 32); i++) {
		trie_vec_lookup_x16x2(p, (uint8_t (*)[16])&ips[i * 32][0],
				next_hops + i * 32, sizeof(uint32_t));
	}
	rte_trie_lookup_bulk_4b(p, (uint8_t (*)[16])&ips[i * 32][0],
			next_hops + i * 32, n - i * 32);
}

void
rte_trie_vec_lookup_bulk_8b(void *p, uint8_t ips[][RTE_FIB6_IPV6_ADDR_SIZE],
	uint64_t *next_hops, const unsigned int n)
{
	uint32_t i;
	for (i = 0; i < (n / 16); i++) {
		trie_vec_lookup_x8x2_8b(p, (uint8_t (*)[16])&ips[i * 16][0],
				next_hops + i * 16);
	}
	rte_trie_lookup_bulk_8b(p, (uint8_t (*)[16])&ips[i * 16][0],
			next_hops + i * 16, n - i * 16);
}
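
/*
 * Usage sketch, not part of this file: these routines are normally reached
 * through the public rte_fib6 API rather than called directly. A minimal
 * flow, assuming the rte_fib6 function and enum names below match the
 * headers of this DPDK revision:
 *
 *	struct rte_fib6 *fib = rte_fib6_create("demo", 0, &conf);
 *	rte_fib6_select_lookup(fib, RTE_FIB6_LOOKUP_TRIE_VECTOR_AVX512);
 *	rte_fib6_lookup_bulk(fib, ips, next_hops, n);
 *
 * conf, ips, next_hops and n are application-provided.
 */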