xref: /spdk/lib/util/base64_neon.c (revision 7506a7aa53d239f533af3bc768f0d2af55e735fe)
1 /*-
2  *   BSD LICENSE
3  *
4  *   Copyright (c) 2005-2007, Nick Galbreath
5  *   Copyright (c) 2013-2017, Alfred Klomp
6  *   Copyright (c) 2015-2017, Wojciech Mula
7  *   Copyright (c) 2016-2017, Matthieu Darbois
8  *   All rights reserved.
9  *
10  *   Redistribution and use in source and binary forms, with or without
11  *   modification, are permitted provided that the following conditions are
12  *   met:
13  *
14  *     * Redistributions of source code must retain the above copyright notice,
15  *       this list of conditions and the following disclaimer.
16  *     * Redistributions in binary form must reproduce the above copyright
17  *       notice, this list of conditions and the following disclaimer in the
18  *       documentation and/or other materials provided with the distribution.
19  *
20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
21  *   IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
22  *   TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
23  *   PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24  *   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
26  *   TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27  *   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28  *   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29  *   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30  *   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31  */
32 
33 #ifndef __aarch64__
34 #error Unsupported hardware
35 #endif
36 
37 #include "spdk/stdinc.h"
38 /*
39  * Encoding
40  * Use a 64-byte lookup to do the encoding.
41  * Reuse the existing 64-byte encode tables (passed in as enc_table); no NEON-specific copy is defined here.
42  *
43  * Decoding
44  * The input consists of five valid character sets in the Base64 alphabet,
45  * which we need to map back to the 6-bit values they represent.
46  * There are three ranges, two singles, and then there's the rest.
47  *
48  * LUT1[0-63] = base64_dec_table_neon64[0-63]
49  * LUT2[0-63] = base64_dec_table_neon64[64-127]
50  *   #  From       To        LUT  Characters
51  *   1  [0..42]    [255]      #1  invalid input
52  *   2  [43]       [62]       #1  +
53  *   3  [44..46]   [255]      #1  invalid input
54  *   4  [47]       [63]       #1  /
55  *   5  [48..57]   [52..61]   #1  0..9
56  *   6  [58..63]   [255]      #1  invalid input
57  *   7  [64]       [255]      #2  invalid input
58  *   8  [65..90]   [0..25]    #2  A..Z
59  *   9  [91..96]   [255]      #2 invalid input
60  *  10  [97..122]  [26..51]   #2  a..z
61  *  11  [123..126] [255]      #2 invalid input
62  * (12) Everything else => invalid input
63  */
static const uint8_t base64_dec_table_neon64[] = {
	/* First half, entries 0..63: indexed directly by the input byte. */
	/* 0..31: no valid characters */
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
	/* 32..47: '+' (43) -> 62 and '/' (47) -> 63 */
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
	0xFF, 0xFF, 0xFF,   62, 0xFF, 0xFF, 0xFF,   63,
	/* 48..63: '0'..'9' -> 52..61 */
	  52,   53,   54,   55,   56,   57,   58,   59,
	  60,   61, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,

	/*
	 * Second half, entries 64..127: indexed by (input byte - 63).
	 * The first entry is a neutral 0 used for bytes resolved by the first half.
	 */
	/* neutral 0, '@' invalid, then 'A'..'Z' -> 0..25 */
	   0, 0xFF,    0,    1,    2,    3,    4,    5,
	   6,    7,    8,    9,   10,   11,   12,   13,
	  14,   15,   16,   17,   18,   19,   20,   21,
	  22,   23,   24,   25, 0xFF, 0xFF, 0xFF, 0xFF,
	/* '_' and '`' invalid, then 'a'..'z' -> 26..51 */
	0xFF, 0xFF,   26,   27,   28,   29,   30,   31,
	  32,   33,   34,   35,   36,   37,   38,   39,
	  40,   41,   42,   43,   44,   45,   46,   47,
	  48,   49,   50,   51, 0xFF, 0xFF, 0xFF, 0xFF
};
74 
75 /*
76  * LUT1[0-63] = base64_urlsafe_dec_table_neon64[0-63]
77  * LUT2[0-63] = base64_urlsafe_dec_table_neon64[64-127]
78  *   #  From       To        LUT  Characters
79  *   1  [0..44]    [255]      #1  invalid input
80  *   2  [45]       [62]       #1  -
81  *   3  [46..47]   [255]      #1  invalid input
82  *   5  [48..57]   [52..61]   #1  0..9
83  *   6  [58..63]   [255]      #1  invalid input
84  *   7  [64]       [255]      #2  invalid input
85  *   8  [65..90]   [0..25]    #2  A..Z
86  *   9  [91..94]   [255]      #2  invalid input
87  *  10  [95]       [63]       #2  _
88  *  11  [96]       [255]      #2  invalid input
89  *  12  [97..122]  [26..51]   #2  a..z
90  *  13  [123..126] [255]      #2 invalid input
91  * (14) Everything else => invalid input
92  */
static const uint8_t base64_urlsafe_dec_table_neon64[] = {
	/* First half, entries 0..63: indexed directly by the input byte. */
	/* 0..31: no valid characters */
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
	/* 32..47: '-' (45) -> 62 */
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF,   62, 0xFF, 0xFF,
	/* 48..63: '0'..'9' -> 52..61 */
	  52,   53,   54,   55,   56,   57,   58,   59,
	  60,   61, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,

	/*
	 * Second half, entries 64..127: indexed by (input byte - 63).
	 * The first entry is a neutral 0 used for bytes resolved by the first half.
	 */
	/* neutral 0, '@' invalid, then 'A'..'Z' -> 0..25 */
	   0, 0xFF,    0,    1,    2,    3,    4,    5,
	   6,    7,    8,    9,   10,   11,   12,   13,
	  14,   15,   16,   17,   18,   19,   20,   21,
	  22,   23,   24,   25, 0xFF, 0xFF, 0xFF, 0xFF,
	/* '_' (95) -> 63, '`' invalid, then 'a'..'z' -> 26..51 */
	  63, 0xFF,   26,   27,   28,   29,   30,   31,
	  32,   33,   34,   35,   36,   37,   38,   39,
	  40,   41,   42,   43,   44,   45,   46,   47,
	  48,   49,   50,   51, 0xFF, 0xFF, 0xFF, 0xFF
};
103 
104 #include <arm_neon.h>
/*
 * Per-lane unsigned "greater than" against a scalar: the scalar n is
 * broadcast to all 16 byte lanes and compared with s using vcgtq_u8.
 */
#define CMPGT(s,n)      vcgtq_u8((s), vdupq_n_u8(n))
106 
/*
 * Load 64 consecutive bytes starting at p into four 16-byte vectors,
 * in memory order (val[0] holds bytes 0..15, val[3] holds bytes 48..63).
 */
static inline uint8x16x4_t
load_64byte_table(const uint8_t *p)
{
	uint8x16x4_t table;
	unsigned int quarter;

	for (quarter = 0; quarter < 4; quarter++) {
		table.val[quarter] = vld1q_u8(p + quarter * 16);
	}

	return table;
}
117 
/*
 * Base64-encode as many complete 48-byte input groups as *src_len allows.
 *
 * Each iteration consumes 48 input bytes and emits 64 output characters.
 * Fewer than 48 trailing bytes are left untouched; on return *src, *dst and
 * *src_len describe whatever was not processed.
 *
 * dst       - in/out: output cursor, advanced by 64 per group. No output
 *             length is checked here, so the buffer must have room for
 *             every group that gets encoded.
 * enc_table - 64-entry table mapping each 6-bit value to its output character.
 * src       - in/out: input cursor, advanced by 48 per group.
 * src_len   - in/out: remaining input length in bytes, reduced by 48 per group.
 */
static void
base64_encode_neon64(char **dst, const char *enc_table, const void **src, size_t *src_len)
{
	const uint8x16x4_t tbl_enc = load_64byte_table((const uint8_t *)enc_table);
	const uint8x16_t low6_mask = vdupq_n_u8(0x3F);
	/*
	 * Work on typed local cursors: arithmetic on "void *" is a compiler
	 * extension, and byte pointers also avoid casting away const on loads.
	 */
	const uint8_t *in = *src;
	uint8_t *out = (uint8_t *)*dst;
	size_t remaining = *src_len;

	while (remaining >= 48) {
		uint8x16x3_t str;
		uint8x16x4_t res;

		/* Load 48 bytes and deinterleave into three vectors of 16 bytes each */
		str = vld3q_u8(in);

		/* Divide bits of three input bytes over four output bytes and clear top two bits */
		res.val[0] = vshrq_n_u8(str.val[0], 2);
		res.val[1] = vandq_u8(vorrq_u8(vshrq_n_u8(str.val[1], 4), vshlq_n_u8(str.val[0], 4)),
				      low6_mask);
		res.val[2] = vandq_u8(vorrq_u8(vshrq_n_u8(str.val[2], 6), vshlq_n_u8(str.val[1], 2)),
				      low6_mask);
		res.val[3] = vandq_u8(str.val[2], low6_mask);

		/*
		 * The bits have now been shifted to the right locations;
		 * translate their values 0..63 to the Base64 alphabet
		 * with a 64-byte table lookup.
		 */
		res.val[0] = vqtbl4q_u8(tbl_enc, res.val[0]);
		res.val[1] = vqtbl4q_u8(tbl_enc, res.val[1]);
		res.val[2] = vqtbl4q_u8(tbl_enc, res.val[2]);
		res.val[3] = vqtbl4q_u8(tbl_enc, res.val[3]);

		/* Interleave and store result */
		vst4q_u8(out, res);

		in += 48;        /* 3 * 16 bytes of input */
		out += 64;       /* 4 * 16 bytes of output */
		remaining -= 48;
	}

	/* Publish the updated cursors and remaining length to the caller. */
	*src = in;
	*dst = (char *)out;
	*src_len = remaining;
}
156 
/*
 * Base64-decode as many complete 64-character input groups as *src_len allows.
 *
 * Each iteration consumes 64 input bytes and emits 48 output bytes. The loop
 * stops early, without consuming the group, as soon as a group contains a
 * byte that does not decode to a value in 0..63. On return *src, *dst and
 * *src_len describe whatever was not processed.
 *
 * dst              - in/out: output cursor, advanced by 48 per group. No
 *                    output length is checked here.
 * dec_table_neon64 - 128-entry decode table. Entries 0..63 are looked up
 *                    directly by the input byte; entries 64..127 are looked
 *                    up by the saturating difference (input byte - 63).
 * src              - in/out: input cursor, advanced by 64 per group.
 * src_len          - in/out: remaining input length, reduced by 64 per group.
 */
static void
base64_decode_neon64(void **dst, const uint8_t *dec_table_neon64, const uint8_t **src,
		     size_t *src_len)
{
	/*
	 * First LUT tbl_dec1 uses the TBL lookup (out of range indices yield 0).
	 * Second LUT tbl_dec2 uses the TBX lookup (out of range indices leave the
	 * destination unchanged).
	 * Input [64..126] maps to index [1..63] in tbl_dec2. Index 0 means that
	 * the value comes from tbl_dec1.
	 */
	const uint8x16x4_t tbl_dec1 = load_64byte_table(dec_table_neon64);
	const uint8x16x4_t tbl_dec2 = load_64byte_table(dec_table_neon64 + 64);
	const uint8x16_t offset = vdupq_n_u8(63U);
	/*
	 * Work on typed local cursors: arithmetic on "void *" is a compiler
	 * extension, and the input never needs its const qualifier removed.
	 */
	const uint8_t *in = *src;
	uint8_t *out = *dst;
	size_t remaining = *src_len;

	while (remaining >= 64) {
		uint8x16x3_t dec;
		uint8x16_t invalid = vdupq_n_u8(0);
		unsigned int i;

		/* Load 64 bytes and deinterleave into four vectors of 16 bytes each */
		uint8x16x4_t str = vld4q_u8(in);

		for (i = 0; i < 4; i++) {
			/* Index into the 2nd LUT; saturates to 0 for input below 63 */
			const uint8x16_t idx2 = vqsubq_u8(str.val[i], offset);
			/* Value from the 1st LUT; 0 for input of 64 and above */
			const uint8x16_t dec1 = vqtbl4q_u8(tbl_dec1, str.val[i]);
			/* Value from the 2nd LUT; index kept as-is when it is 64 or more */
			const uint8x16_t dec2 = vqtbx4q_u8(idx2, tbl_dec2, idx2);

			/* Exactly one LUT contributes a non-zero value for valid input */
			str.val[i] = vorrq_u8(dec1, dec2);

			/* Flag invalid input: any decoded value larger than 63 */
			invalid = vorrq_u8(invalid, CMPGT(str.val[i], 63));
		}

		/* Stop before consuming this group if any lane was flagged */
		if (vmaxvq_u8(invalid) != 0U) {
			break;
		}

		/* Compress four 6-bit values into three bytes */
		dec.val[0] = vorrq_u8(vshlq_n_u8(str.val[0], 2), vshrq_n_u8(str.val[1], 4));
		dec.val[1] = vorrq_u8(vshlq_n_u8(str.val[1], 4), vshrq_n_u8(str.val[2], 2));
		dec.val[2] = vorrq_u8(vshlq_n_u8(str.val[2], 6), str.val[3]);

		/* Interleave and store decoded result */
		vst3q_u8(out, dec);

		in += 64;
		out += 48;
		remaining -= 64;
	}

	/* Publish the updated cursors and remaining length to the caller. */
	*src = in;
	*dst = out;
	*src_len = remaining;
}
226