xref: /spdk/lib/util/base64_neon.c (revision 6f338d4bf3a8a91b7abe377a605a321ea2b05bf7)
1 /*   SPDX-License-Identifier: BSD-2-Clause
2  *   Copyright (c) 2005-2007, Nick Galbreath
3  *   Copyright (c) 2013-2017, Alfred Klomp
4  *   Copyright (c) 2015-2017, Wojciech Mula
5  *   Copyright (c) 2016-2017, Matthieu Darbois
6  *   All rights reserved.
7  */
8 
9 #ifndef __aarch64__
10 #error Unsupported hardware
11 #endif
12 
13 #include "spdk/stdinc.h"
14 /*
15  * Encoding
16  * Use a 64-byte lookup to do the encoding.
 * Reuse the existing 64-byte encoding tables, passed in by the caller as enc_table.
 *
19  * Decoding
20  * The input consists of five valid character sets in the Base64 alphabet,
21  * which we need to map back to the 6-bit values they represent.
22  * There are three ranges, two singles, and then there's the rest.
23  *
24  * LUT1[0-63] = base64_dec_table_neon64[0-63]
25  * LUT2[0-63] = base64_dec_table_neon64[64-127]
26  *   #  From       To        LUT  Characters
27  *   1  [0..42]    [255]      #1  invalid input
28  *   2  [43]       [62]       #1  +
29  *   3  [44..46]   [255]      #1  invalid input
30  *   4  [47]       [63]       #1  /
31  *   5  [48..57]   [52..61]   #1  0..9
32  *   6  [58..63]   [255]      #1  invalid input
33  *   7  [64]       [255]      #2  invalid input
34  *   8  [65..90]   [0..25]    #2  A..Z
 *   9  [91..96]   [255]      #2  invalid input
 *  10  [97..122]  [26..51]   #2  a..z
 *  11  [123..126] [255]      #2  invalid input
38  * (12) Everything else => invalid input
39  */
/*
 * Decode table for the standard Base64 alphabet, stored as two 64-byte halves
 * that the NEON decoder loads as two separate lookup tables.
 *
 * Slots  0..63  (LUT1) are indexed directly by input bytes 0..63.
 * Slots 64..127 (LUT2) are indexed by the saturating difference (input - 63):
 * slot 64 is therefore hit by every input <= 63 and must stay 0 so that it
 * does not disturb the LUT1 result when the two lookups are OR-ed together;
 * slots 65..127 serve input bytes 64..126.
 * 255 marks a byte that is not part of the alphabet.
 */
static const uint8_t base64_dec_table_neon64[] = {
	/* LUT1, inputs 0..39: invalid */
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	/* LUT1, inputs 40..47: '+' (43) -> 62, '/' (47) -> 63 */
	255, 255, 255,  62, 255, 255, 255,  63,
	/* LUT1, inputs 48..63: '0'..'9' -> 52..61, then 58..63 invalid */
	 52,  53,  54,  55,  56,  57,  58,  59,
	 60,  61, 255, 255, 255, 255, 255, 255,
	/* LUT2, slot 64 = neutral 0, slot 65 = '@' invalid, then 'A'.. -> 0.. */
	  0, 255,   0,   1,   2,   3,   4,   5,
	  6,   7,   8,   9,  10,  11,  12,  13,
	 14,  15,  16,  17,  18,  19,  20,  21,
	/* LUT2, ..'Z' -> 25, then inputs 91..94 invalid */
	 22,  23,  24,  25, 255, 255, 255, 255,
	/* LUT2, inputs 95..96 invalid, then 'a'.. -> 26.. */
	255, 255,  26,  27,  28,  29,  30,  31,
	 32,  33,  34,  35,  36,  37,  38,  39,
	 40,  41,  42,  43,  44,  45,  46,  47,
	/* LUT2, ..'z' -> 51, then inputs 123..126 invalid */
	 48,  49,  50,  51, 255, 255, 255, 255
};
50 
51 /*
52  * LUT1[0-63] = base64_urlsafe_dec_table_neon64[0-63]
53  * LUT2[0-63] = base64_urlsafe_dec_table_neon64[64-127]
54  *   #  From       To        LUT  Characters
55  *   1  [0..44]    [255]      #1  invalid input
56  *   2  [45]       [62]       #1  -
57  *   3  [46..47]   [255]      #1  invalid input
 *   4  [48..57]   [52..61]   #1  0..9
 *   5  [58..63]   [255]      #1  invalid input
 *   6  [64]       [255]      #2  invalid input
 *   7  [65..90]   [0..25]    #2  A..Z
 *   8  [91..94]   [255]      #2  invalid input
 *   9  [95]       [63]       #2  _
 *  10  [96]       [255]      #2  invalid input
 *  11  [97..122]  [26..51]   #2  a..z
 *  12  [123..126] [255]      #2  invalid input
 * (13) Everything else => invalid input
68  */
/*
 * Decode table for the URL-safe Base64 alphabet ('-' and '_' instead of '+'
 * and '/'), stored as two 64-byte halves that the NEON decoder loads as two
 * separate lookup tables.
 *
 * Slots  0..63  (LUT1) are indexed directly by input bytes 0..63.
 * Slots 64..127 (LUT2) are indexed by the saturating difference (input - 63):
 * slot 64 is therefore hit by every input <= 63 and must stay 0 so that it
 * does not disturb the LUT1 result when the two lookups are OR-ed together;
 * slots 65..127 serve input bytes 64..126.
 * 255 marks a byte that is not part of the alphabet.
 */
static const uint8_t base64_urlsafe_dec_table_neon64[] = {
	/* LUT1, inputs 0..39: invalid */
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	/* LUT1, inputs 40..47: '-' (45) -> 62 */
	255, 255, 255, 255, 255,  62, 255, 255,
	/* LUT1, inputs 48..63: '0'..'9' -> 52..61, then 58..63 invalid */
	 52,  53,  54,  55,  56,  57,  58,  59,
	 60,  61, 255, 255, 255, 255, 255, 255,
	/* LUT2, slot 64 = neutral 0, slot 65 = '@' invalid, then 'A'.. -> 0.. */
	  0, 255,   0,   1,   2,   3,   4,   5,
	  6,   7,   8,   9,  10,  11,  12,  13,
	 14,  15,  16,  17,  18,  19,  20,  21,
	/* LUT2, ..'Z' -> 25, then inputs 91..94 invalid */
	 22,  23,  24,  25, 255, 255, 255, 255,
	/* LUT2, '_' (95) -> 63, input 96 invalid, then 'a'.. -> 26.. */
	 63, 255,  26,  27,  28,  29,  30,  31,
	 32,  33,  34,  35,  36,  37,  38,  39,
	 40,  41,  42,  43,  44,  45,  46,  47,
	/* LUT2, ..'z' -> 51, then inputs 123..126 invalid */
	 48,  49,  50,  51, 255, 255, 255, 255
};
79 
80 #include <arm_neon.h>
/* Lane-wise unsigned "greater than" of vector s against the scalar n broadcast to all 16 lanes. */
#define CMPGT(s, n)	vcgtq_u8((s), vdupq_n_u8((n)))
82 
/*
 * Load the 64 bytes starting at p into four 16-byte vectors.
 *
 * The bytes are taken in memory order (no deinterleaving): val[i] receives
 * p[16 * i] .. p[16 * i + 15]. The caller must make 64 bytes readable at p.
 */
static inline uint8x16x4_t
load_64byte_table(const uint8_t *p)
{
	uint8x16x4_t tbl;
	int quarter;

	for (quarter = 0; quarter < 4; quarter++) {
		tbl.val[quarter] = vld1q_u8(p + 16 * quarter);
	}

	return tbl;
}
93 
/*
 * Base64-encode as many whole 48-byte chunks of input as are available,
 * using NEON.
 *
 * Every loop iteration reads 48 bytes from *src and writes 64 characters to
 * *dst. Nothing else is written (no padding, no terminator). Input shorter
 * than 48 bytes is not touched; the leftover tail is described by the updated
 * *src / *src_len on return.
 *
 * dst        in/out: output cursor, advanced by 64 for each chunk encoded.
 *            The buffer must have room for 64 bytes per 48 bytes of input.
 * enc_table  64-byte table mapping each 6-bit value to its output character.
 * src        in/out: input cursor, advanced by 48 for each chunk encoded.
 * src_len    in/out: number of input bytes left at *src, reduced by 48 for
 *            each chunk encoded (always < 48 on return).
 */
static void
base64_encode_neon64(char **dst, const char *enc_table, const void **src, size_t *src_len)
{
	/* The table holds characters; the lookup helpers operate on unsigned bytes. */
	const uint8x16x4_t tbl_enc = load_64byte_table((const uint8_t *)enc_table);
	const uint8x16_t mask_6bit = vdupq_n_u8(0x3F);

	while (*src_len >= 48) {
		/* Typed view of the input so no arithmetic is done on a void pointer. */
		const uint8_t *in = *src;
		uint8x16x3_t str;
		uint8x16x4_t res;

		/*
		 * Load 48 bytes and deinterleave: str.val[k] holds byte k of each
		 * of the 16 three-byte groups.
		 */
		str = vld3q_u8(in);

		/* Divide bits of three input bytes over four output bytes and clear top two bits */
		res.val[0] = vshrq_n_u8(str.val[0], 2);
		res.val[1] = vandq_u8(vorrq_u8(vshrq_n_u8(str.val[1], 4), vshlq_n_u8(str.val[0], 4)),
				      mask_6bit);
		res.val[2] = vandq_u8(vorrq_u8(vshrq_n_u8(str.val[2], 6), vshlq_n_u8(str.val[1], 2)),
				      mask_6bit);
		res.val[3] = vandq_u8(str.val[2], mask_6bit);

		/*
		 * The bits have now been shifted to the right locations;
		 * translate their values 0..63 to the Base64 alphabet.
		 * Use a 64-byte table lookup:
		 */
		res.val[0] = vqtbl4q_u8(tbl_enc, res.val[0]);
		res.val[1] = vqtbl4q_u8(tbl_enc, res.val[1]);
		res.val[2] = vqtbl4q_u8(tbl_enc, res.val[2]);
		res.val[3] = vqtbl4q_u8(tbl_enc, res.val[3]);

		/* Interleave and store result */
		vst4q_u8((uint8_t *)*dst, res);

		*src = in + 48;  /* 3 * 16 bytes of input */
		*dst += 64;      /* 4 * 16 bytes of output */
		*src_len -= 48;
	}
}
132 
/*
 * Base64-decode as many whole 64-character chunks of input as possible,
 * using NEON.
 *
 * Every loop iteration reads 64 characters from *src and, if all of them are
 * valid, writes 48 decoded bytes to *dst. The loop stops when fewer than 64
 * characters remain or as soon as a chunk contains any character that the
 * table maps to a value above 63. In that case nothing is stored for the
 * chunk and the cursors are left pointing at its start.
 *
 * dst               in/out: output cursor, advanced by 48 for each chunk
 *                   decoded. The buffer must have room for 48 bytes per chunk.
 * dec_table_neon64  128-byte decode table: the first 64 bytes are the lookup
 *                   for inputs 0..63, the second 64 bytes are indexed by
 *                   (input - 63) and cover inputs 64..126.
 * src               in/out: input cursor, advanced by 64 for each chunk decoded.
 * src_len           in/out: number of characters left at *src, reduced by 64
 *                   for each chunk decoded.
 */
static void
base64_decode_neon64(void **dst, const uint8_t *dec_table_neon64, const uint8_t **src,
		     size_t *src_len)
{
	/*
	 * First LUT tbl_dec1 will use VTBL instruction (out of range indices are set to 0 in destination).
	 * Second LUT tbl_dec2 will use VTBX instruction (out of range indices will be unchanged in destination).
	 * Input [64..126] will be mapped to index [1..63] in tbl_dec2. Index 0 means that value comes from tbl_dec1.
	 */
	const uint8x16x4_t tbl_dec1 = load_64byte_table(dec_table_neon64);
	const uint8x16x4_t tbl_dec2 = load_64byte_table(dec_table_neon64 + 64);
	const uint8x16_t offset = vdupq_n_u8(63U);

	while (*src_len >= 64) {
		uint8x16x4_t dec1, dec2;
		uint8x16x3_t dec;
		uint8x16_t classified;

		/*
		 * Load 64 bytes and deinterleave: str.val[k] holds character k of
		 * each of the 16 four-character groups.
		 */
		uint8x16x4_t str = vld4q_u8(*src);

		/*
		 * Get indices for 2nd LUT. The subtraction saturates, so inputs
		 * 0..63 give index 0 and inputs above 126 give an index >= 64,
		 * which is out of range for the 64-byte table.
		 */
		dec2.val[0] = vqsubq_u8(str.val[0], offset);
		dec2.val[1] = vqsubq_u8(str.val[1], offset);
		dec2.val[2] = vqsubq_u8(str.val[2], offset);
		dec2.val[3] = vqsubq_u8(str.val[3], offset);

		/* Get values from 1st LUT */
		dec1.val[0] = vqtbl4q_u8(tbl_dec1, str.val[0]);
		dec1.val[1] = vqtbl4q_u8(tbl_dec1, str.val[1]);
		dec1.val[2] = vqtbl4q_u8(tbl_dec1, str.val[2]);
		dec1.val[3] = vqtbl4q_u8(tbl_dec1, str.val[3]);

		/*
		 * Get values from 2nd LUT. The fallback operand is the index itself,
		 * so an out-of-range index (>= 64) survives the lookup and is
		 * caught by the validity check below.
		 */
		dec2.val[0] = vqtbx4q_u8(dec2.val[0], tbl_dec2, dec2.val[0]);
		dec2.val[1] = vqtbx4q_u8(dec2.val[1], tbl_dec2, dec2.val[1]);
		dec2.val[2] = vqtbx4q_u8(dec2.val[2], tbl_dec2, dec2.val[2]);
		dec2.val[3] = vqtbx4q_u8(dec2.val[3], tbl_dec2, dec2.val[3]);

		/* Get final values */
		str.val[0] = vorrq_u8(dec1.val[0], dec2.val[0]);
		str.val[1] = vorrq_u8(dec1.val[1], dec2.val[1]);
		str.val[2] = vorrq_u8(dec1.val[2], dec2.val[2]);
		str.val[3] = vorrq_u8(dec1.val[3], dec2.val[3]);

		/* Check for invalid input, any value larger than 63 */
		classified = CMPGT(str.val[0], 63);
		classified = vorrq_u8(classified, CMPGT(str.val[1], 63));
		classified = vorrq_u8(classified, CMPGT(str.val[2], 63));
		classified = vorrq_u8(classified, CMPGT(str.val[3], 63));

		/* check that all bits are zero */
		if (vmaxvq_u8(classified) != 0U) {
			break;
		}

		/* Compress four bytes into three */
		dec.val[0] = vorrq_u8(vshlq_n_u8(str.val[0], 2), vshrq_n_u8(str.val[1], 4));
		dec.val[1] = vorrq_u8(vshlq_n_u8(str.val[1], 4), vshrq_n_u8(str.val[2], 2));
		dec.val[2] = vorrq_u8(vshlq_n_u8(str.val[2], 6), str.val[3]);

		/* Interleave and store decoded result */
		vst3q_u8((uint8_t *)*dst, dec);

		*src += 64;
		/* Advance through a byte pointer; arithmetic on void * is a GNU extension. */
		*dst = (uint8_t *)*dst + 48;
		*src_len -= 64;
	}
}
202