xref: /spdk/lib/util/base64_neon.c (revision 075d422f3480d3db11013734f833304606867da4)
1  /*   SPDX-License-Identifier: BSD-2-Clause
2   *   Copyright (C) 2020 Intel Corporation.
3   *   Copyright (c) 2005-2007, Nick Galbreath
4   *   Copyright (c) 2013-2017, Alfred Klomp
5   *   Copyright (c) 2015-2017, Wojciech Mula
6   *   Copyright (c) 2016-2017, Matthieu Darbois
7   *   All rights reserved.
8   */
9  
10  #ifndef __aarch64__
11  #error Unsupported hardware
12  #endif
13  
14  #include "spdk/stdinc.h"
/*
 * Encoding
 * Use a 64-byte lookup to do the encoding.
 * Reuse the existing encoding tables, which the caller passes in as enc_table.
 *
 * Decoding
 * The input consists of five valid character sets in the Base64 alphabet,
 * which we need to map back to the 6-bit values they represent.
 * There are three ranges, two singles, and then there's the rest.
 *
 * LUT1[0-63] = base64_dec_table_neon64[0-63]
 * LUT2[0-63] = base64_dec_table_neon64[64-127]
 * LUT2 is indexed with (input - 63), saturated at 0, so LUT2[0] is a zero
 * placeholder used for every input that is resolved through LUT1.
 *   #  From       To        LUT  Characters
 *   1  [0..42]    [255]      #1  invalid input
 *   2  [43]       [62]       #1  +
 *   3  [44..46]   [255]      #1  invalid input
 *   4  [47]       [63]       #1  /
 *   5  [48..57]   [52..61]   #1  0..9
 *   6  [58..63]   [255]      #1  invalid input
 *   7  [64]       [255]      #2  invalid input
 *   8  [65..90]   [0..25]    #2  A..Z
 *   9  [91..96]   [255]      #2  invalid input
 *  10  [97..122]  [26..51]   #2  a..z
 *  11  [123..126] [255]      #2  invalid input
 * (12) Everything else => invalid input
 */
static const uint8_t base64_dec_table_neon64[] = {
	/* Entries 0..63 (LUT1): indexed by the input byte itself; 255 marks invalid input. */
	255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,	/* 0..15 */
	255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,	/* 16..31 */
	255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,  62, 255, 255, 255,  63,	/* '+' = 62, '/' = 63 */
	52,  53,  54,  55,  56,  57,  58,  59,  60,  61, 255, 255, 255, 255, 255, 255,	/* '0'..'9' */
	/*
	 * Entries 64..127 (LUT2): indexed by (input - 63). Entry 64 is the zero
	 * placeholder selected for inputs below 64, entry 65 is '@' (invalid).
	 */
	0, 255,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,	/* 'A'..'N' */
	14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25, 255, 255, 255, 255,	/* 'O'..'Z' */
	255, 255,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,	/* '_', '`' invalid; 'a'..'n' */
	40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51, 255, 255, 255, 255	/* 'o'..'z' */
};
51  
/*
 * LUT1[0-63] = base64_urlsafe_dec_table_neon64[0-63]
 * LUT2[0-63] = base64_urlsafe_dec_table_neon64[64-127]
 *   #  From       To        LUT  Characters
 *   1  [0..44]    [255]      #1  invalid input
 *   2  [45]       [62]       #1  -
 *   3  [46..47]   [255]      #1  invalid input
 *   4  [48..57]   [52..61]   #1  0..9
 *   5  [58..63]   [255]      #1  invalid input
 *   6  [64]       [255]      #2  invalid input
 *   7  [65..90]   [0..25]    #2  A..Z
 *   8  [91..94]   [255]      #2  invalid input
 *   9  [95]       [63]       #2  _
 *  10  [96]       [255]      #2  invalid input
 *  11  [97..122]  [26..51]   #2  a..z
 *  12  [123..126] [255]      #2  invalid input
 * (13) Everything else => invalid input
 */
static const uint8_t base64_urlsafe_dec_table_neon64[] = {
	/* Entries 0..63 (LUT1): indexed by the input byte itself; 255 marks invalid input. */
	255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,	/* 0..15 */
	255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,	/* 16..31 */
	255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,  62, 255, 255,	/* '-' = 62 */
	52,  53,  54,  55,  56,  57,  58,  59,  60,  61, 255, 255, 255, 255, 255, 255,	/* '0'..'9' */
	/*
	 * Entries 64..127 (LUT2): indexed by (input - 63). Entry 64 is the zero
	 * placeholder selected for inputs below 64, entry 65 is '@' (invalid).
	 */
	0, 255,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,	/* 'A'..'N' */
	14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25, 255, 255, 255, 255,	/* 'O'..'Z' */
	63, 255,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,	/* '_' = 63; 'a'..'n' */
	40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51, 255, 255, 255, 255	/* 'o'..'z' */
};
80  
81  #include <arm_neon.h>
/* Per-lane unsigned compare: a result byte is 0xFF where the lane of s is greater than n, else 0x00. */
#define CMPGT(s,n)      vcgtq_u8((s), vdupq_n_u8(n))
83  
/*
 * Load 64 consecutive bytes starting at p into four 16-byte vectors, in
 * order, so the result can serve as the table operand of the 4-register
 * table lookups (vqtbl4q_u8 / vqtbx4q_u8) used below.
 *
 * p must reference at least 64 readable bytes.
 */
static inline uint8x16x4_t
load_64byte_table(const uint8_t *p)
{
	uint8x16x4_t ret;
	ret.val[0] = vld1q_u8(p +  0);
	ret.val[1] = vld1q_u8(p + 16);
	ret.val[2] = vld1q_u8(p + 32);
	ret.val[3] = vld1q_u8(p + 48);
	return ret;
}
94  
/*
 * Base64-encode as many whole 48-byte groups of input as are available.
 *
 * dst       in/out: output cursor; 64 characters are written per group and
 *           the cursor is advanced past them. No terminator is written.
 * enc_table 64-entry table mapping a 6-bit value to its output character.
 * src       in/out: input cursor; advanced by 48 bytes per group.
 * src_len   in/out: remaining input length; reduced by 48 per group.
 *
 * On return *src_len is below 48; that tail is left untouched for the caller.
 */
static void
base64_encode_neon64(char **dst, const char *enc_table, const void **src, size_t *src_len)
{
	/* The table is declared as char; the NEON loads operate on unsigned bytes. */
	const uint8x16x4_t tbl_enc = load_64byte_table((const uint8_t *)enc_table);
	/* Mask that keeps the low six bits of every lane. */
	const uint8x16_t low6 = vdupq_n_u8(0x3F);

	while (*src_len >= 48) {
		uint8x16x3_t str;
		uint8x16x4_t res;

		/*
		 * Load 48 bytes and deinterleave: val[0], val[1] and val[2] hold the
		 * first, second and third byte of each of the 16 three-byte groups.
		 */
		str = vld3q_u8((const uint8_t *)*src);

		/* Divide bits of three input bytes over four output bytes and clear top two bits */
		res.val[0] = vshrq_n_u8(str.val[0], 2);
		res.val[1] = vandq_u8(vorrq_u8(vshrq_n_u8(str.val[1], 4), vshlq_n_u8(str.val[0], 4)),
				      low6);
		res.val[2] = vandq_u8(vorrq_u8(vshrq_n_u8(str.val[2], 6), vshlq_n_u8(str.val[1], 2)),
				      low6);
		res.val[3] = vandq_u8(str.val[2], low6);

		/*
		 * The bits have now been shifted to the right locations;
		 * translate their values 0..63 to the Base64 alphabet.
		 * Use a 64-byte table lookup:
		 */
		res.val[0] = vqtbl4q_u8(tbl_enc, res.val[0]);
		res.val[1] = vqtbl4q_u8(tbl_enc, res.val[1]);
		res.val[2] = vqtbl4q_u8(tbl_enc, res.val[2]);
		res.val[3] = vqtbl4q_u8(tbl_enc, res.val[3]);

		/* Interleave and store result */
		vst4q_u8((uint8_t *)*dst, res);

		*src = (const uint8_t *)*src + 48;	/* 3 * 16 bytes of input */
		*dst += 64;				/* 4 * 16 bytes of output */
		*src_len -= 48;
	}
}
133  
/*
 * Base64-decode as many whole 64-character groups of input as are valid.
 *
 * dst              in/out: output cursor; 48 bytes are written per group and
 *                  the cursor is advanced past them.
 * dec_table_neon64 128-entry decode table: entries 0..63 form the first
 *                  lookup table, entries 64..127 the second.
 * src              in/out: input cursor; advanced by 64 characters per group.
 * src_len          in/out: remaining input length; reduced by 64 per group.
 *
 * The loop ends when fewer than 64 characters remain, or at the first group
 * in which any character decodes to a value above 63. In the latter case
 * nothing is stored for that group and the cursors still point at its start,
 * so the caller can inspect it (presumably with a scalar path; that code is
 * not part of this function).
 */
static void
base64_decode_neon64(void **dst, const uint8_t *dec_table_neon64, const uint8_t **src,
		     size_t *src_len)
{
	/*
	 * First LUT tbl_dec1 will use VTBL instruction (out of range indices are set to 0 in destination).
	 * Second LUT tbl_dec2 will use VTBX instruction (out of range indices will be unchanged in destination).
	 * Input [64..126] will be mapped to index [1..63] in tbl_dec2. Index 0 means that value comes from tbl_dec1.
	 */
	const uint8x16x4_t tbl_dec1 = load_64byte_table(dec_table_neon64);
	const uint8x16x4_t tbl_dec2 = load_64byte_table(dec_table_neon64 + 64);
	const uint8x16_t offset = vdupq_n_u8(63U);

	while (*src_len >= 64) {

		uint8x16x4_t dec1, dec2;
		uint8x16x3_t dec;

		/* Load 64 bytes and deinterleave */
		uint8x16x4_t str = vld4q_u8(*src);

		/*
		 * Get indices for 2nd LUT. The subtraction saturates, so inputs
		 * 0..63 all become index 0 and inputs above 126 become 64 or more,
		 * which is outside the 64-entry table.
		 */
		dec2.val[0] = vqsubq_u8(str.val[0], offset);
		dec2.val[1] = vqsubq_u8(str.val[1], offset);
		dec2.val[2] = vqsubq_u8(str.val[2], offset);
		dec2.val[3] = vqsubq_u8(str.val[3], offset);

		/* Get values from 1st LUT */
		dec1.val[0] = vqtbl4q_u8(tbl_dec1, str.val[0]);
		dec1.val[1] = vqtbl4q_u8(tbl_dec1, str.val[1]);
		dec1.val[2] = vqtbl4q_u8(tbl_dec1, str.val[2]);
		dec1.val[3] = vqtbl4q_u8(tbl_dec1, str.val[3]);

		/*
		 * Get values from 2nd LUT. An out-of-range index is left as is,
		 * i.e. a value of at least 64, which the check below rejects.
		 */
		dec2.val[0] = vqtbx4q_u8(dec2.val[0], tbl_dec2, dec2.val[0]);
		dec2.val[1] = vqtbx4q_u8(dec2.val[1], tbl_dec2, dec2.val[1]);
		dec2.val[2] = vqtbx4q_u8(dec2.val[2], tbl_dec2, dec2.val[2]);
		dec2.val[3] = vqtbx4q_u8(dec2.val[3], tbl_dec2, dec2.val[3]);

		/* Get final values */
		str.val[0] = vorrq_u8(dec1.val[0], dec2.val[0]);
		str.val[1] = vorrq_u8(dec1.val[1], dec2.val[1]);
		str.val[2] = vorrq_u8(dec1.val[2], dec2.val[2]);
		str.val[3] = vorrq_u8(dec1.val[3], dec2.val[3]);

		/* Check for invalid input, any value larger than 63 */
		uint8x16_t classified = CMPGT(str.val[0], 63);
		classified = vorrq_u8(classified, CMPGT(str.val[1], 63));
		classified = vorrq_u8(classified, CMPGT(str.val[2], 63));
		classified = vorrq_u8(classified, CMPGT(str.val[3], 63));

		/* check that all bits are zero */
		if (vmaxvq_u8(classified) != 0U) {
			break;
		}

		/* Compress four bytes into three */
		dec.val[0] = vorrq_u8(vshlq_n_u8(str.val[0], 2), vshrq_n_u8(str.val[1], 4));
		dec.val[1] = vorrq_u8(vshlq_n_u8(str.val[1], 4), vshrq_n_u8(str.val[2], 2));
		dec.val[2] = vorrq_u8(vshlq_n_u8(str.val[2], 6), str.val[3]);

		/* Interleave and store decoded result */
		vst3q_u8((uint8_t *)*dst, dec);

		*src += 64;
		*dst = (uint8_t *)*dst + 48;
		*src_len -= 64;
	}
}
203