1 /* $OpenBSD: mlkem_tests_util.c,v 1.3 2024/12/20 00:07:12 tb Exp $ */ 2 /* 3 * Copyright (c) 2024 Google Inc. 4 * Copyright (c) 2024 Bob Beck <beck@obtuse.com> 5 * Copyright (c) 2024 Theo Buehler <tb@openbsd.org> 6 * 7 * Permission to use, copy, modify, and/or distribute this software for any 8 * purpose with or without fee is hereby granted, provided that the above 9 * copyright notice and this permission notice appear in all copies. 10 * 11 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 12 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 13 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY 14 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 15 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION 16 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN 17 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 18 */ 19 20 #include <err.h> 21 #include <stdint.h> 22 #include <stdio.h> 23 #include <string.h> 24 25 #include "bytestring.h" 26 #include "mlkem.h" 27 28 #include "mlkem_internal.h" 29 30 #include "mlkem_tests_util.h" 31 32 int failure; 33 int test_number; 34 35 static void 36 hexdump(const uint8_t *buf, size_t len, const uint8_t *compare) 37 { 38 const char *mark = ""; 39 size_t i; 40 41 for (i = 1; i <= len; i++) { 42 if (compare != NULL) 43 mark = (buf[i - 1] != compare[i - 1]) ? "*" : " "; 44 fprintf(stderr, " %s0x%02hhx,%s", mark, buf[i - 1], 45 i % 8 && i != len ? "" : "\n"); 46 } 47 fprintf(stderr, "\n"); 48 } 49 50 int 51 compare_data(const uint8_t *want, const uint8_t *got, size_t len, size_t line, 52 const char *msg) 53 { 54 if (memcmp(want, got, len) == 0) 55 return 0; 56 57 warnx("FAIL: #%zu - %s differs", line, msg); 58 fprintf(stderr, "want:\n"); 59 hexdump(want, len, got); 60 fprintf(stderr, "got:\n"); 61 hexdump(got, len, want); 62 fprintf(stderr, "\n"); 63 64 return 1; 65 } 66 67 int 68 compare_length(size_t want, size_t got, size_t line, const char *msg) 69 { 70 if (want == got) 71 return 1; 72 73 warnx("#%zu: %s: want %zu, got %zu", line, msg, want, got); 74 return 0; 75 } 76 77 static int 78 hex_get_nibble_cbs(CBS *cbs, uint8_t *out_nibble) 79 { 80 uint8_t c; 81 82 if (!CBS_get_u8(cbs, &c)) 83 return 0; 84 85 if (c >= '0' && c <= '9') { 86 *out_nibble = c - '0'; 87 return 1; 88 } 89 if (c >= 'a' && c <= 'f') { 90 *out_nibble = c - 'a' + 10; 91 return 1; 92 } 93 if (c >= 'A' && c <= 'F') { 94 *out_nibble = c - 'A' + 10; 95 return 1; 96 } 97 98 return 0; 99 } 100 101 void 102 hex_decode_cbs(CBS *cbs, CBB *cbb, size_t line, const char *msg) 103 { 104 if (!CBB_init(cbb, 0)) 105 errx(1, "#%zu %s: %s CBB_init", line, msg, __func__); 106 107 while (CBS_len(cbs) > 0) { 108 uint8_t hi, lo; 109 110 if (!hex_get_nibble_cbs(cbs, &hi)) 111 errx(1, "#%zu %s: %s nibble", line, msg, __func__); 112 if (!hex_get_nibble_cbs(cbs, &lo)) 113 errx(1, "#%zu %s: %s nibble", line, msg, __func__); 114 115 if (!CBB_add_u8(cbb, hi << 4 | lo)) 116 errx(1, "#%zu %s: %s CBB_add_u8", line, msg, __func__); 117 } 118 } 119 120 int 121 get_string_cbs(CBS *cbs_in, const char *str, size_t line, const char *msg) 122 { 123 CBS cbs; 124 size_t len = strlen(str); 125 126 if (!CBS_get_bytes(cbs_in, &cbs, len)) 127 errx(1, "#%zu %s: %s CBB_get_bytes", line, msg, __func__); 128 129 return CBS_mem_equal(&cbs, str, len); 130 } 131 132 int 133 mlkem768_encode_private_key(const struct MLKEM768_private_key *priv, 134 uint8_t **out_buf, size_t *out_len) 135 { 136 CBB cbb; 137 int ret = 0; 138 139 if (!CBB_init(&cbb, MLKEM768_PUBLIC_KEY_BYTES)) 140 goto err; 141 if (!MLKEM768_marshal_private_key(&cbb, priv)) 142 goto err; 143 if (!CBB_finish(&cbb, out_buf, out_len)) 144 goto err; 145 146 ret = 1; 147 148 err: 149 CBB_cleanup(&cbb); 150 151 return ret; 152 } 153 154 int 155 mlkem768_encode_public_key(const struct MLKEM768_public_key *pub, 156 uint8_t **out_buf, size_t *out_len) 157 { 158 CBB cbb; 159 int ret = 0; 160 161 if (!CBB_init(&cbb, MLKEM768_PUBLIC_KEY_BYTES)) 162 goto err; 163 if (!MLKEM768_marshal_public_key(&cbb, pub)) 164 goto err; 165 if (!CBB_finish(&cbb, out_buf, out_len)) 166 goto err; 167 168 ret = 1; 169 170 err: 171 CBB_cleanup(&cbb); 172 173 return ret; 174 } 175 176 int 177 mlkem1024_encode_private_key(const struct MLKEM1024_private_key *priv, 178 uint8_t **out_buf, size_t *out_len) 179 { 180 CBB cbb; 181 int ret = 0; 182 183 if (!CBB_init(&cbb, MLKEM1024_PUBLIC_KEY_BYTES)) 184 goto err; 185 if (!MLKEM1024_marshal_private_key(&cbb, priv)) 186 goto err; 187 if (!CBB_finish(&cbb, out_buf, out_len)) 188 goto err; 189 190 ret = 1; 191 192 err: 193 CBB_cleanup(&cbb); 194 195 return ret; 196 } 197 198 int 199 mlkem1024_encode_public_key(const struct MLKEM1024_public_key *pub, 200 uint8_t **out_buf, size_t *out_len) 201 { 202 CBB cbb; 203 int ret = 0; 204 205 if (!CBB_init(&cbb, MLKEM1024_PUBLIC_KEY_BYTES)) 206 goto err; 207 if (!MLKEM1024_marshal_public_key(&cbb, pub)) 208 goto err; 209 if (!CBB_finish(&cbb, out_buf, out_len)) 210 goto err; 211 212 ret = 1; 213 214 err: 215 CBB_cleanup(&cbb); 216 217 return ret; 218 } 219