1 /* $OpenBSD: buffertest.c,v 1.6 2022/07/22 19:34:55 jsing Exp $ */ 2 /* 3 * Copyright (c) 2019, 2022 Joel Sing <jsing@openbsd.org> 4 * 5 * Permission to use, copy, modify, and distribute this software for any 6 * purpose with or without fee is hereby granted, provided that the above 7 * copyright notice and this permission notice appear in all copies. 8 * 9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 */ 17 18 #include <err.h> 19 #include <stdio.h> 20 #include <stdlib.h> 21 #include <string.h> 22 23 #include "tls_internal.h" 24 25 uint8_t testdata[] = { 26 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 27 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 28 }; 29 30 struct read_state { 31 uint8_t *buf; 32 size_t len; 33 size_t offset; 34 }; 35 36 static ssize_t 37 read_cb(void *buf, size_t buflen, void *cb_arg) 38 { 39 struct read_state *rs = cb_arg; 40 ssize_t n; 41 42 if (rs->offset > rs->len) 43 return TLS_IO_EOF; 44 45 if ((size_t)(n = buflen) > (rs->len - rs->offset)) 46 n = rs->len - rs->offset; 47 48 if (n == 0) 49 return TLS_IO_WANT_POLLIN; 50 51 memcpy(buf, &rs->buf[rs->offset], n); 52 rs->offset += n; 53 54 return n; 55 } 56 57 struct extend_test { 58 size_t extend_len; 59 size_t read_len; 60 ssize_t want_ret; 61 }; 62 63 const struct extend_test extend_tests[] = { 64 { 65 .extend_len = 4, 66 .read_len = 0, 67 .want_ret = TLS_IO_WANT_POLLIN, 68 }, 69 { 70 .extend_len = 4, 71 .read_len = 8, 72 .want_ret = 4, 73 }, 74 { 75 .extend_len = 12, 76 .read_len = 8, 77 .want_ret = TLS_IO_WANT_POLLIN, 78 }, 79 { 80 .extend_len = 12, 81 .read_len = 10, 82 .want_ret = TLS_IO_WANT_POLLIN, 83 }, 84 { 85 .extend_len = 12, 86 .read_len = 12, 87 .want_ret = 12, 88 }, 89 { 90 .extend_len = 16, 91 .read_len = 16, 92 .want_ret = 16, 93 }, 94 { 95 .extend_len = 20, 96 .read_len = 1, 97 .want_ret = TLS_IO_EOF, 98 }, 99 }; 100 101 #define N_EXTEND_TESTS (sizeof(extend_tests) / sizeof(extend_tests[0])) 102 103 static int 104 tls_buffer_extend_test(void) 105 { 106 const struct extend_test *et; 107 struct tls_buffer *buf; 108 struct read_state rs; 109 uint8_t *data = NULL; 110 size_t i, data_len; 111 ssize_t ret; 112 CBS cbs; 113 int failed = 1; 114 115 rs.buf = testdata; 116 rs.offset = 0; 117 118 if ((buf = tls_buffer_new(0)) == NULL) 119 errx(1, "tls_buffer_new"); 120 121 for (i = 0; i < N_EXTEND_TESTS; i++) { 122 et = &extend_tests[i]; 123 rs.len = et->read_len; 124 125 ret = tls_buffer_extend(buf, et->extend_len, read_cb, &rs); 126 if (ret != extend_tests[i].want_ret) { 127 fprintf(stderr, "FAIL: Test %zd - extend returned %zd, " 128 "want %zd\n", i, ret, et->want_ret); 129 goto failed; 130 } 131 132 if (!tls_buffer_data(buf, &cbs)) { 133 fprintf(stderr, "FAIL: Test %zd - failed to get data\n", 134 i); 135 goto failed; 136 } 137 138 if (!CBS_mem_equal(&cbs, testdata, CBS_len(&cbs))) { 139 fprintf(stderr, "FAIL: Test %zd - extend buffer " 140 "mismatch", i); 141 goto failed; 142 } 143 } 144 145 if (!tls_buffer_finish(buf, &data, &data_len)) { 146 fprintf(stderr, "FAIL: failed to finish\n"); 147 goto failed; 148 } 149 150 tls_buffer_free(buf); 151 buf = NULL; 152 153 if (data_len != sizeof(testdata)) { 154 fprintf(stderr, "FAIL: got data length %zu, want %zu\n", 155 data_len, sizeof(testdata)); 156 goto failed; 157 } 158 if (memcmp(data, testdata, data_len) != 0) { 159 fprintf(stderr, "FAIL: data mismatch\n"); 160 goto failed; 161 } 162 163 failed = 0; 164 165 failed: 166 tls_buffer_free(buf); 167 free(data); 168 169 return failed; 170 } 171 172 struct read_write_test { 173 uint8_t pattern; 174 size_t read; 175 size_t write; 176 size_t append; 177 ssize_t want; 178 }; 179 180 const struct read_write_test read_write_tests[] = { 181 { 182 .read = 2048, 183 .want = TLS_IO_WANT_POLLIN, 184 }, 185 { 186 .pattern = 0xdb, 187 .write = 2048, 188 .want = 2048, 189 }, 190 { 191 .pattern = 0xbd, 192 .append = 2048, 193 .want = 1, 194 }, 195 { 196 .pattern = 0xdb, 197 .read = 2048, 198 .want = 2048, 199 }, 200 { 201 .pattern = 0xfe, 202 .append = 1024, 203 .want = 1, 204 }, 205 { 206 .pattern = 0xbd, 207 .read = 1000, 208 .want = 1000, 209 }, 210 { 211 .pattern = 0xbd, 212 .read = 1048, 213 .want = 1048, 214 }, 215 { 216 .pattern = 0xdb, 217 .write = 2048, 218 .want = 2048, 219 }, 220 { 221 .pattern = 0xbd, 222 .append = 1024, 223 .want = 1, 224 }, 225 { 226 .pattern = 0xee, 227 .append = 4096, 228 .want = 1, 229 }, 230 { 231 .pattern = 0xfe, 232 .append = 1, 233 .want = 0, 234 }, 235 { 236 .pattern = 0xfe, 237 .write = 1, 238 .want = TLS_IO_FAILURE, 239 }, 240 { 241 .pattern = 0xfe, 242 .read = 1024, 243 .want = 1024, 244 }, 245 { 246 .pattern = 0xdb, 247 .read = 2048, 248 .want = 2048, 249 }, 250 { 251 .pattern = 0xbd, 252 .read = 1024, 253 .want = 1024, 254 }, 255 { 256 .pattern = 0xee, 257 .read = 1024, 258 .want = 1024, 259 }, 260 { 261 .pattern = 0xee, 262 .read = 4096, 263 .want = 3072, 264 }, 265 { 266 .read = 2048, 267 .want = TLS_IO_WANT_POLLIN, 268 }, 269 }; 270 271 #define N_READ_WRITE_TESTS (sizeof(read_write_tests) / sizeof(read_write_tests[0])) 272 273 static int 274 tls_buffer_read_write_test(void) 275 { 276 const struct read_write_test *rwt; 277 struct tls_buffer *buf = NULL; 278 uint8_t *rbuf = NULL, *wbuf = NULL; 279 ssize_t n; 280 size_t i; 281 int ret; 282 int failed = 1; 283 284 if ((buf = tls_buffer_new(0)) == NULL) 285 errx(1, "tls_buffer_new"); 286 287 tls_buffer_set_capacity_limit(buf, 8192); 288 289 for (i = 0; i < N_READ_WRITE_TESTS; i++) { 290 rwt = &read_write_tests[i]; 291 292 if (rwt->append > 0) { 293 free(wbuf); 294 if ((wbuf = malloc(rwt->append)) == NULL) 295 errx(1, "malloc"); 296 memset(wbuf, rwt->pattern, rwt->append); 297 if ((ret = tls_buffer_append(buf, wbuf, rwt->append)) != 298 rwt->want) { 299 fprintf(stderr, "FAIL: test %zu - " 300 "tls_buffer_append() = %d, want %zu\n", 301 i, ret, rwt->want); 302 goto failed; 303 } 304 } 305 306 if (rwt->write > 0) { 307 free(wbuf); 308 if ((wbuf = malloc(rwt->write)) == NULL) 309 errx(1, "malloc"); 310 memset(wbuf, rwt->pattern, rwt->write); 311 if ((n = tls_buffer_write(buf, wbuf, rwt->write)) != 312 rwt->want) { 313 fprintf(stderr, "FAIL: test %zu - " 314 "tls_buffer_write() = %zi, want %zu\n", 315 i, n, rwt->want); 316 goto failed; 317 } 318 } 319 320 if (rwt->read > 0) { 321 free(rbuf); 322 if ((rbuf = calloc(1, rwt->read)) == NULL) 323 errx(1, "malloc"); 324 if ((n = tls_buffer_read(buf, rbuf, rwt->read)) != 325 rwt->want) { 326 fprintf(stderr, "FAIL: test %zu - " 327 "tls_buffer_read() = %zi, want %zu\n", 328 i, n, rwt->want); 329 goto failed; 330 } 331 if (rwt->want > 0) { 332 free(wbuf); 333 if ((wbuf = malloc(rwt->want)) == NULL) 334 errx(1, "malloc"); 335 memset(wbuf, rwt->pattern, rwt->want); 336 if (memcmp(rbuf, wbuf, rwt->want) != 0) { 337 fprintf(stderr, "FAIL: test %zu - " 338 "read byte mismatch\n", i); 339 goto failed; 340 } 341 } 342 } 343 } 344 345 failed = 0; 346 347 failed: 348 tls_buffer_free(buf); 349 free(rbuf); 350 free(wbuf); 351 352 return failed; 353 } 354 355 int 356 main(int argc, char **argv) 357 { 358 int failed = 0; 359 360 failed |= tls_buffer_extend_test(); 361 failed |= tls_buffer_read_write_test(); 362 363 return failed; 364 } 365