xref: /spdk/lib/util/base64_neon.c (revision 784b9d48746955f210926648a0131f84f58de76f)
1 /*   SPDX-License-Identifier: BSD-2-Clause
2  *   Copyright (C) 2020 Intel Corporation.
3  *   Copyright (c) 2005-2007, Nick Galbreath
4  *   Copyright (c) 2013-2017, Alfred Klomp
5  *   Copyright (c) 2015-2017, Wojciech Mula
6  *   Copyright (c) 2016-2017, Matthieu Darbois
7  *   All rights reserved.
8  */
9 
10 #ifndef __aarch64__
11 #error Unsupported hardware
12 #endif
13 
14 #include "spdk/stdinc.h"
/*
 * Encoding
 * Use a 64-byte table lookup (vqtbl4q_u8) to do the encoding.
 * The 64-entry encoding table (standard or URL-safe alphabet) is supplied by the caller.
 *
20  * Decoding
21  * The input consists of five valid character sets in the Base64 alphabet,
22  * which we need to map back to the 6-bit values they represent.
23  * There are three ranges, two singles, and then there's the rest.
24  *
25  * LUT1[0-63] = base64_dec_table_neon64[0-63]
26  * LUT2[0-63] = base64_dec_table_neon64[64-127]
27  *   #  From       To        LUT  Characters
28  *   1  [0..42]    [255]      #1  invalid input
29  *   2  [43]       [62]       #1  +
30  *   3  [44..46]   [255]      #1  invalid input
31  *   4  [47]       [63]       #1  /
32  *   5  [48..57]   [52..61]   #1  0..9
33  *   6  [58..63]   [255]      #1  invalid input
34  *   7  [64]       [255]      #2  invalid input
35  *   8  [65..90]   [0..25]    #2  A..Z
36  *   9  [91..96]   [255]      #2 invalid input
37  *  10  [97..122]  [26..51]   #2  a..z
38  *  11  [123..126] [255]      #2 invalid input
39  * (12) Everything else => invalid input
40  */
static const uint8_t base64_dec_table_neon64[] = {
	/* LUT1, entries 0..63: indexed directly by the input character. */
	255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, /* chars   0..15 */
	255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, /* chars  16..31 */
	255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,  62, 255, 255, 255,  63, /* chars  32..47: '+' '/' */
	 52,  53,  54,  55,  56,  57,  58,  59,  60,  61, 255, 255, 255, 255, 255, 255, /* chars  48..63: '0'..'9' */
	/*
	 * LUT2, entries 64..127: entry 64 is the 0 filler returned for every
	 * input below 64; entry 64 + k holds the value for character 63 + k.
	 */
	  0, 255,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13, /* filler, chars 64..78: '@', 'A'..'N' */
	 14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25, 255, 255, 255, 255, /* chars  79..94: 'O'..'Z' */
	255, 255,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39, /* chars  95..110: 'a'..'n' */
	 40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51, 255, 255, 255, 255  /* chars 111..126: 'o'..'z' */
};
51 
52 /*
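 * LUT1[0-63] = base64_urlsafe_dec_table_neon64[0-63], indexed directly by the input byte.
 * LUT2[0-63] = base64_urlsafe_dec_table_neon64[64-127], indexed by the input byte minus 63
 * (saturating at 0), so LUT2[0] is 0 and inputs 64..126 land on LUT2[1..63].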
55  *   #  From       To        LUT  Characters
56  *   1  [0..44]    [255]      #1  invalid input
57  *   2  [45]       [62]       #1  -
58  *   3  [46..47]   [255]      #1  invalid input
59  *   5  [48..57]   [52..61]   #1  0..9
60  *   6  [58..63]   [255]      #1  invalid input
61  *   7  [64]       [255]      #2  invalid input
62  *   8  [65..90]   [0..25]    #2  A..Z
63  *   9  [91..94]   [255]      #2  invalid input
64  *  10  [95]       [63]       #2  _
65  *  11  [96]       [255]      #2  invalid input
66  *  12  [97..122]  [26..51]   #2  a..z
67  *  13  [123..126] [255]      #2 invalid input
68  * (14) Everything else => invalid input
69  */
static const uint8_t base64_urlsafe_dec_table_neon64[] = {
	/* LUT1, entries 0..63: indexed directly by the input character. */
	255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, /* chars   0..15 */
	255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, /* chars  16..31 */
	255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,  62, 255, 255, /* chars  32..47: '-' */
	 52,  53,  54,  55,  56,  57,  58,  59,  60,  61, 255, 255, 255, 255, 255, 255, /* chars  48..63: '0'..'9' */
	/*
	 * LUT2, entries 64..127: entry 64 is the 0 filler returned for every
	 * input below 64; entry 64 + k holds the value for character 63 + k.
	 */
	  0, 255,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13, /* filler, chars 64..78: '@', 'A'..'N' */
	 14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25, 255, 255, 255, 255, /* chars  79..94: 'O'..'Z' */
	 63, 255,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39, /* chars  95..110: '_', 'a'..'n' */
	 40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51, 255, 255, 255, 255  /* chars 111..126: 'o'..'z' */
};
80 
#include <arm_neon.h>

/*
 * CMPGT(s, n): lane-wise unsigned comparison of the uint8x16_t vector s with
 * the scalar n broadcast to all 16 lanes. A result lane is 0xFF where s > n
 * and 0x00 otherwise.
 */
#define CMPGT(s, n)	(vcgtq_u8((s), vdupq_n_u8((n))))
83 
/*
 * Read 64 consecutive bytes starting at p into four 16-byte vectors, in
 * order (bytes 0..15 in val[0], ..., bytes 48..63 in val[3]), so that the
 * result can be used as the table operand of vqtbl4q_u8 / vqtbx4q_u8.
 */
static inline uint8x16x4_t
load_64byte_table(const uint8_t *p)
{
	uint8x16x4_t tbl;

	for (int i = 0; i < 4; i++) {
		tbl.val[i] = vld1q_u8(p + 16 * i);
	}

	return tbl;
}
94 
/*
 * Encode as many whole 48-byte chunks of *src as possible into Base64.
 *
 * Each iteration consumes 48 input bytes (three 16-lane vectors after the
 * de-interleaving load) and stores 64 output characters. Input shorter than
 * 48 bytes is left untouched for the caller to handle.
 *
 * dst:       in/out output cursor; advanced by 64 for every chunk written.
 * enc_table: 64-entry table mapping each 6-bit value 0..63 to its character.
 * src:       in/out input cursor; advanced by 48 for every chunk consumed.
 * src_len:   in/out remaining input length; reduced by 48 per chunk.
 */
static void
base64_encode_neon64(char **dst, const char *enc_table, const void **src, size_t *src_len)
{
	/* char and uint8_t pointers are distinct types; convert explicitly. */
	const uint8x16x4_t tbl_enc = load_64byte_table((const uint8_t *)enc_table);
	const uint8x16_t mask6 = vdupq_n_u8(0x3F);
	/*
	 * Advance typed local cursors: ISO C does not allow arithmetic on
	 * void pointers, and locals need not be reloaded after each store.
	 */
	const uint8_t *in = *src;
	uint8_t *out = (uint8_t *)*dst;
	size_t len = *src_len;

	while (len >= 48) {
		/* Load 48 bytes and deinterleave: val[k] holds byte k of every 3-byte group */
		const uint8x16x3_t str = vld3q_u8(in);
		uint8x16x4_t res;

		/* Divide bits of three input bytes over four output bytes and clear top two bits */
		res.val[0] = vshrq_n_u8(str.val[0], 2);
		res.val[1] = vandq_u8(vorrq_u8(vshrq_n_u8(str.val[1], 4), vshlq_n_u8(str.val[0], 4)),
				      mask6);
		res.val[2] = vandq_u8(vorrq_u8(vshrq_n_u8(str.val[2], 6), vshlq_n_u8(str.val[1], 2)),
				      mask6);
		res.val[3] = vandq_u8(str.val[2], mask6);

		/*
		 * Every lane now holds a value in 0..63; translate it to the
		 * Base64 alphabet with a single 64-byte table lookup.
		 */
		res.val[0] = vqtbl4q_u8(tbl_enc, res.val[0]);
		res.val[1] = vqtbl4q_u8(tbl_enc, res.val[1]);
		res.val[2] = vqtbl4q_u8(tbl_enc, res.val[2]);
		res.val[3] = vqtbl4q_u8(tbl_enc, res.val[3]);

		/* Interleave and store result */
		vst4q_u8(out, res);

		in += 48;	/* 3 * 16 bytes of input */
		out += 64;	/* 4 * 16 bytes of output */
		len -= 48;
	}

	*src = in;
	*dst = (char *)out;
	*src_len = len;
}
133 
/*
 * Decode as many whole 64-character chunks of *src as possible from Base64.
 *
 * Each iteration consumes 64 input characters and stores 48 decoded bytes.
 * If a chunk contains a character the table marks as invalid, the loop stops
 * before storing anything for that chunk. *src, *dst and *src_len then
 * describe the start of that chunk, which is left unconsumed for the caller.
 * Input shorter than 64 characters is also left unconsumed.
 *
 * dst:              in/out output cursor; advanced by 48 per decoded chunk.
 * dec_table_neon64: 128-byte table. Entries 0..63 form LUT1, indexed by the
 *                   character. Entries 64..127 form LUT2, indexed by the
 *                   character minus 63 (saturating at 0).
 * src:              in/out input cursor; advanced by 64 per decoded chunk.
 * src_len:          in/out remaining input length; reduced by 64 per chunk.
 */
static void
base64_decode_neon64(void **dst, const uint8_t *dec_table_neon64, const uint8_t **src,
		     size_t *src_len)
{
	/*
	 * First LUT tbl_dec1 uses the TBL instruction: out-of-range indices (>= 64)
	 * produce 0 in the destination.
	 * Second LUT tbl_dec2 uses the TBX instruction: out-of-range indices leave
	 * the destination lane unchanged.
	 * Input [64..126] is mapped to index [1..63] in tbl_dec2. Index 0 (every input
	 * <= 63) yields 0 there, meaning the value comes from tbl_dec1.
	 * Inputs >= 127 keep their index (>= 64) from TBX and are rejected below.
	 */
	const uint8x16x4_t tbl_dec1 = load_64byte_table(dec_table_neon64);
	const uint8x16x4_t tbl_dec2 = load_64byte_table(dec_table_neon64 + 64);
	const uint8x16_t offset = vdupq_n_u8(63U);
	/* Typed local cursors: ISO C does not allow arithmetic on void pointers. */
	const uint8_t *in = *src;
	uint8_t *out = *dst;
	size_t len = *src_len;

	while (len >= 64) {
		uint8x16x4_t dec1, dec2;
		uint8x16x3_t dec;

		/* Load 64 bytes and deinterleave */
		uint8x16x4_t str = vld4q_u8(in);

		/* Get indices for 2nd LUT */
		dec2.val[0] = vqsubq_u8(str.val[0], offset);
		dec2.val[1] = vqsubq_u8(str.val[1], offset);
		dec2.val[2] = vqsubq_u8(str.val[2], offset);
		dec2.val[3] = vqsubq_u8(str.val[3], offset);

		/* Get values from 1st LUT */
		dec1.val[0] = vqtbl4q_u8(tbl_dec1, str.val[0]);
		dec1.val[1] = vqtbl4q_u8(tbl_dec1, str.val[1]);
		dec1.val[2] = vqtbl4q_u8(tbl_dec1, str.val[2]);
		dec1.val[3] = vqtbl4q_u8(tbl_dec1, str.val[3]);

		/* Get values from 2nd LUT */
		dec2.val[0] = vqtbx4q_u8(dec2.val[0], tbl_dec2, dec2.val[0]);
		dec2.val[1] = vqtbx4q_u8(dec2.val[1], tbl_dec2, dec2.val[1]);
		dec2.val[2] = vqtbx4q_u8(dec2.val[2], tbl_dec2, dec2.val[2]);
		dec2.val[3] = vqtbx4q_u8(dec2.val[3], tbl_dec2, dec2.val[3]);

		/* Get final values: at most one of the two lookups is non-zero per valid lane */
		str.val[0] = vorrq_u8(dec1.val[0], dec2.val[0]);
		str.val[1] = vorrq_u8(dec1.val[1], dec2.val[1]);
		str.val[2] = vorrq_u8(dec1.val[2], dec2.val[2]);
		str.val[3] = vorrq_u8(dec1.val[3], dec2.val[3]);

		/* Check for invalid input, any value larger than 63 */
		uint8x16_t classified = CMPGT(str.val[0], 63);
		classified = vorrq_u8(classified, CMPGT(str.val[1], 63));
		classified = vorrq_u8(classified, CMPGT(str.val[2], 63));
		classified = vorrq_u8(classified, CMPGT(str.val[3], 63));

		/* Stop before consuming this chunk unless every lane is valid */
		if (vmaxvq_u8(classified) != 0U) {
			break;
		}

		/* Compress four 6-bit values into three bytes */
		dec.val[0] = vorrq_u8(vshlq_n_u8(str.val[0], 2), vshrq_n_u8(str.val[1], 4));
		dec.val[1] = vorrq_u8(vshlq_n_u8(str.val[1], 4), vshrq_n_u8(str.val[2], 2));
		dec.val[2] = vorrq_u8(vshlq_n_u8(str.val[2], 6), str.val[3]);

		/* Interleave and store decoded result */
		vst3q_u8(out, dec);

		in += 64;
		out += 48;
		len -= 64;
	}

	*src = in;
	*dst = out;
	*src_len = len;
}
203