1 /* $OpenBSD: mlkem1024.c,v 1.6 2025/01/03 08:19:24 tb Exp $ */ 2 /* 3 * Copyright (c) 2024, Google Inc. 4 * Copyright (c) 2024, Bob Beck <beck@obtuse.com> 5 * 6 * Permission to use, copy, modify, and/or distribute this software for any 7 * purpose with or without fee is hereby granted, provided that the above 8 * copyright notice and this permission notice appear in all copies. 9 * 10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY 13 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION 15 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN 16 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 17 */ 18 19 #include <assert.h> 20 #include <stdlib.h> 21 #include <string.h> 22 23 #include "bytestring.h" 24 #include "mlkem.h" 25 26 #include "sha3_internal.h" 27 #include "mlkem_internal.h" 28 #include "constant_time.h" 29 #include "crypto_internal.h" 30 31 /* Remove later */ 32 #undef LCRYPTO_ALIAS 33 #define LCRYPTO_ALIAS(A) 34 35 /* 36 * See 37 * https://csrc.nist.gov/pubs/fips/203/final 38 */ 39 40 static void 41 prf(uint8_t *out, size_t out_len, const uint8_t in[33]) 42 { 43 sha3_ctx ctx; 44 shake256_init(&ctx); 45 shake_update(&ctx, in, 33); 46 shake_xof(&ctx); 47 shake_out(&ctx, out, out_len); 48 } 49 50 /* Section 4.1 */ 51 static void 52 hash_h(uint8_t out[32], const uint8_t *in, size_t len) 53 { 54 sha3_ctx ctx; 55 sha3_init(&ctx, 32); 56 sha3_update(&ctx, in, len); 57 sha3_final(out, &ctx); 58 } 59 60 static void 61 hash_g(uint8_t out[64], const uint8_t *in, size_t len) 62 { 63 sha3_ctx ctx; 64 sha3_init(&ctx, 64); 65 sha3_update(&ctx, in, len); 66 sha3_final(out, &ctx); 67 } 68 69 /* this is called 'J' in the spec */ 70 static void 71 kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32], 72 const uint8_t *in, size_t len) 73 { 74 sha3_ctx ctx; 75 shake256_init(&ctx); 76 shake_update(&ctx, failure_secret, 32); 77 shake_update(&ctx, in, len); 78 shake_xof(&ctx); 79 shake_out(&ctx, out, MLKEM_SHARED_SECRET_BYTES); 80 } 81 82 #define DEGREE 256 83 #define RANK1024 4 84 85 static const size_t kBarrettMultiplier = 5039; 86 static const unsigned kBarrettShift = 24; 87 static const uint16_t kPrime = 3329; 88 static const int kLog2Prime = 12; 89 static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2; 90 static const int kDU1024 = 11; 91 static const int kDV1024 = 5; 92 93 /* 94 * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th 95 * root of unity. 96 */ 97 static const uint16_t kInverseDegree = 3303; 98 static const size_t kEncodedVectorSize = 99 (/*kLog2Prime=*/12 * DEGREE / 8) * RANK1024; 100 static const size_t kCompressedVectorSize = /*kDU1024=*/ 11 * RANK1024 * DEGREE / 101 8; 102 103 typedef struct scalar { 104 /* On every function entry and exit, 0 <= c < kPrime. */ 105 uint16_t c[DEGREE]; 106 } scalar; 107 108 typedef struct vector { 109 scalar v[RANK1024]; 110 } vector; 111 112 typedef struct matrix { 113 scalar v[RANK1024][RANK1024]; 114 } matrix; 115 116 /* 117 * This bit of Python will be referenced in some of the following comments: 118 * 119 * p = 3329 120 * 121 * def bitreverse(i): 122 * ret = 0 123 * for n in range(7): 124 * bit = i & 1 125 * ret <<= 1 126 * ret |= bit 127 * i >>= 1 128 * return ret 129 */ 130 131 /* kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] */ 132 static const uint16_t kNTTRoots[128] = { 133 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 134 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 135 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 136 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 137 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 138 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 139 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, 140 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 141 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 142 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 143 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154, 144 }; 145 146 /* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */ 147 static const uint16_t kInverseNTTRoots[128] = { 148 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543, 149 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903, 150 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855, 151 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010, 152 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132, 153 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607, 154 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230, 155 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745, 156 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482, 157 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920, 158 2229, 1041, 2606, 1692, 680, 2746, 568, 3312, 159 }; 160 161 /* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */ 162 static const uint16_t kModRoots[128] = { 163 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 164 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, 165 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, 166 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, 167 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, 168 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 169 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010, 170 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, 171 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 172 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, 173 2110, 1219, 2935, 394, 885, 2444, 2154, 1175, 174 }; 175 176 /* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */ 177 static uint16_t 178 reduce_once(uint16_t x) 179 { 180 assert(x < 2 * kPrime); 181 const uint16_t subtracted = x - kPrime; 182 uint16_t mask = 0u - (subtracted >> 15); 183 184 /* 185 * Although this is a constant-time select, we omit a value barrier here. 186 * Value barriers impede auto-vectorization (likely because it forces the 187 * value to transit through a general-purpose register). On AArch64, this 188 * is a difference of 2x. 189 * 190 * We usually add value barriers to selects because Clang turns 191 * consecutive selects with the same condition into a branch instead of 192 * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it 193 * seems to be safe so far but see 194 * |scalar_centered_binomial_distribution_eta_2_with_prf|. 195 */ 196 return (mask & x) | (~mask & subtracted); 197 } 198 199 /* 200 * constant time reduce x mod kPrime using Barrett reduction. x must be less 201 * than kPrime + 2×kPrime². 202 */ 203 static uint16_t 204 reduce(uint32_t x) 205 { 206 uint64_t product = (uint64_t)x * kBarrettMultiplier; 207 uint32_t quotient = (uint32_t)(product >> kBarrettShift); 208 uint32_t remainder = x - quotient * kPrime; 209 210 assert(x < kPrime + 2u * kPrime * kPrime); 211 return reduce_once(remainder); 212 } 213 214 static void 215 scalar_zero(scalar *out) 216 { 217 memset(out, 0, sizeof(*out)); 218 } 219 220 static void 221 vector_zero(vector *out) 222 { 223 memset(out, 0, sizeof(*out)); 224 } 225 226 /* 227 * In place number theoretic transform of a given scalar. 228 * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this 229 * transform leaves off the last iteration of the usual FFT code, with the 128 230 * relevant roots of unity being stored in |kNTTRoots|. This means the output 231 * should be seen as 128 elements in GF(3329^2), with the coefficients of the 232 * elements being consecutive entries in |s->c|. 233 */ 234 static void 235 scalar_ntt(scalar *s) 236 { 237 int offset = DEGREE; 238 int step; 239 /* 240 * `int` is used here because using `size_t` throughout caused a ~5% slowdown 241 * with Clang 14 on Aarch64. 242 */ 243 for (step = 1; step < DEGREE / 2; step <<= 1) { 244 int i, j, k = 0; 245 246 offset >>= 1; 247 for (i = 0; i < step; i++) { 248 const uint32_t step_root = kNTTRoots[i + step]; 249 250 for (j = k; j < k + offset; j++) { 251 uint16_t odd, even; 252 253 odd = reduce(step_root * s->c[j + offset]); 254 even = s->c[j]; 255 s->c[j] = reduce_once(odd + even); 256 s->c[j + offset] = reduce_once(even - odd + 257 kPrime); 258 } 259 k += 2 * offset; 260 } 261 } 262 } 263 264 static void 265 vector_ntt(vector *a) 266 { 267 int i; 268 269 for (i = 0; i < RANK1024; i++) { 270 scalar_ntt(&a->v[i]); 271 } 272 } 273 274 /* 275 * In place inverse number theoretic transform of a given scalar, with pairs of 276 * entries of s->v being interpreted as elements of GF(3329^2). Just as with the 277 * number theoretic transform, this leaves off the first step of the normal iFFT 278 * to account for the fact that 3329 does not have a 512th root of unity, using 279 * the precomputed 128 roots of unity stored in |kInverseNTTRoots|. 280 */ 281 static void 282 scalar_inverse_ntt(scalar *s) 283 { 284 int i, j, k, offset, step = DEGREE / 2; 285 286 /* 287 * `int` is used here because using `size_t` throughout caused a ~5% slowdown 288 * with Clang 14 on Aarch64. 289 */ 290 for (offset = 2; offset < DEGREE; offset <<= 1) { 291 step >>= 1; 292 k = 0; 293 for (i = 0; i < step; i++) { 294 uint32_t step_root = kInverseNTTRoots[i + step]; 295 for (j = k; j < k + offset; j++) { 296 uint16_t odd, even; 297 odd = s->c[j + offset]; 298 even = s->c[j]; 299 s->c[j] = reduce_once(odd + even); 300 s->c[j + offset] = reduce(step_root * 301 (even - odd + kPrime)); 302 } 303 k += 2 * offset; 304 } 305 } 306 for (i = 0; i < DEGREE; i++) { 307 s->c[i] = reduce(s->c[i] * kInverseDegree); 308 } 309 } 310 311 static void 312 vector_inverse_ntt(vector *a) 313 { 314 int i; 315 316 for (i = 0; i < RANK1024; i++) { 317 scalar_inverse_ntt(&a->v[i]); 318 } 319 } 320 321 static void 322 scalar_add(scalar *lhs, const scalar *rhs) 323 { 324 int i; 325 326 for (i = 0; i < DEGREE; i++) { 327 lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]); 328 } 329 } 330 331 static void 332 scalar_sub(scalar *lhs, const scalar *rhs) 333 { 334 int i; 335 336 for (i = 0; i < DEGREE; i++) { 337 lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime); 338 } 339 } 340 341 /* 342 * Multiplying two scalars in the number theoretically transformed state. 343 * Since 3329 does not have a 512th root of unity, this means we have to 344 * interpret the 2*ith and (2*i+1)th entries of the scalar as elements of 345 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)). 346 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed 347 * |kModRoots| table. Our Barrett transform only allows us to multiply two 348 * reduced numbers together, so we need some intermediate reduction steps, 349 * even if an uint64_t could hold 3 multiplied numbers. 350 */ 351 static void 352 scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) 353 { 354 int i; 355 356 for (i = 0; i < DEGREE / 2; i++) { 357 uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i]; 358 uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] * 359 rhs->c[2 * i + 1]; 360 uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1]; 361 uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i]; 362 363 out->c[2 * i] = 364 reduce(real_real + 365 (uint32_t)reduce(img_img) * kModRoots[i]); 366 out->c[2 * i + 1] = reduce(img_real + real_img); 367 } 368 } 369 370 static void 371 vector_add(vector *lhs, const vector *rhs) 372 { 373 int i; 374 375 for (i = 0; i < RANK1024; i++) { 376 scalar_add(&lhs->v[i], &rhs->v[i]); 377 } 378 } 379 380 static void 381 matrix_mult(vector *out, const matrix *m, const vector *a) 382 { 383 int i, j; 384 385 vector_zero(out); 386 for (i = 0; i < RANK1024; i++) { 387 for (j = 0; j < RANK1024; j++) { 388 scalar product; 389 390 scalar_mult(&product, &m->v[i][j], &a->v[j]); 391 scalar_add(&out->v[i], &product); 392 } 393 } 394 } 395 396 static void 397 matrix_mult_transpose(vector *out, const matrix *m, 398 const vector *a) 399 { 400 int i, j; 401 402 vector_zero(out); 403 for (i = 0; i < RANK1024; i++) { 404 for (j = 0; j < RANK1024; j++) { 405 scalar product; 406 407 scalar_mult(&product, &m->v[j][i], &a->v[j]); 408 scalar_add(&out->v[i], &product); 409 } 410 } 411 } 412 413 static void 414 scalar_inner_product(scalar *out, const vector *lhs, 415 const vector *rhs) 416 { 417 int i; 418 scalar_zero(out); 419 for (i = 0; i < RANK1024; i++) { 420 scalar product; 421 422 scalar_mult(&product, &lhs->v[i], &rhs->v[i]); 423 scalar_add(out, &product); 424 } 425 } 426 427 /* 428 * Algorithm 6 of spec. Rejection samples a Keccak stream to get uniformly 429 * distributed elements. This is used for matrix expansion and only operates on 430 * public inputs. 431 */ 432 static void 433 scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx) 434 { 435 int i, done = 0; 436 437 while (done < DEGREE) { 438 uint8_t block[168]; 439 440 shake_out(keccak_ctx, block, sizeof(block)); 441 for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) { 442 uint16_t d1 = block[i] + 256 * (block[i + 1] % 16); 443 uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2]; 444 445 if (d1 < kPrime) { 446 out->c[done++] = d1; 447 } 448 if (d2 < kPrime && done < DEGREE) { 449 out->c[done++] = d2; 450 } 451 } 452 } 453 } 454 455 /* 456 * Algorithm 7 of the spec, with eta fixed to two and the PRF call 457 * included. Creates binominally distributed elements by sampling 2*|eta| bits, 458 * and setting the coefficient to the count of the first bits minus the count of 459 * the second bits, resulting in a centered binomial distribution. Since eta is 460 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4, 461 * and 0 with probability 3/8. 462 */ 463 static void 464 scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out, 465 const uint8_t input[33]) 466 { 467 uint8_t entropy[128]; 468 int i; 469 470 CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8); 471 prf(entropy, sizeof(entropy), input); 472 473 for (i = 0; i < DEGREE; i += 2) { 474 uint8_t byte = entropy[i / 2]; 475 uint16_t mask; 476 uint16_t value = (byte & 1) + ((byte >> 1) & 1); 477 478 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1); 479 480 /* 481 * Add |kPrime| if |value| underflowed. See |reduce_once| for a 482 * discussion on why the value barrier is omitted. While this 483 * could have been written reduce_once(value + kPrime), this is 484 * one extra addition and small range of |value| tempts some 485 * versions of Clang to emit a branch. 486 */ 487 mask = 0u - (value >> 15); 488 out->c[i] = ((value + kPrime) & mask) | (value & ~mask); 489 490 byte >>= 4; 491 value = (byte & 1) + ((byte >> 1) & 1); 492 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1); 493 /* See above. */ 494 mask = 0u - (value >> 15); 495 out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask); 496 } 497 } 498 499 /* 500 * Generates a secret vector by using 501 * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed 502 * appending and incrementing |counter| for entry of the vector. 503 */ 504 static void 505 vector_generate_secret_eta_2(vector *out, uint8_t *counter, 506 const uint8_t seed[32]) 507 { 508 uint8_t input[33]; 509 int i; 510 511 memcpy(input, seed, 32); 512 for (i = 0; i < RANK1024; i++) { 513 input[32] = (*counter)++; 514 scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i], 515 input); 516 } 517 } 518 519 /* Expands the matrix of a seed for key generation and for encaps-CPA. */ 520 static void 521 matrix_expand(matrix *out, const uint8_t rho[32]) 522 { 523 uint8_t input[34]; 524 int i, j; 525 526 memcpy(input, rho, 32); 527 for (i = 0; i < RANK1024; i++) { 528 for (j = 0; j < RANK1024; j++) { 529 sha3_ctx keccak_ctx; 530 531 input[32] = i; 532 input[33] = j; 533 shake128_init(&keccak_ctx); 534 shake_update(&keccak_ctx, input, sizeof(input)); 535 shake_xof(&keccak_ctx); 536 scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx); 537 } 538 } 539 } 540 541 static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f, 542 0x1f, 0x3f, 0x7f, 0xff}; 543 544 static void 545 scalar_encode(uint8_t *out, const scalar *s, int bits) 546 { 547 uint8_t out_byte = 0; 548 int i, out_byte_bits = 0; 549 550 assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1); 551 for (i = 0; i < DEGREE; i++) { 552 uint16_t element = s->c[i]; 553 int element_bits_done = 0; 554 555 while (element_bits_done < bits) { 556 int chunk_bits = bits - element_bits_done; 557 int out_bits_remaining = 8 - out_byte_bits; 558 559 if (chunk_bits >= out_bits_remaining) { 560 chunk_bits = out_bits_remaining; 561 out_byte |= (element & 562 kMasks[chunk_bits - 1]) << out_byte_bits; 563 *out = out_byte; 564 out++; 565 out_byte_bits = 0; 566 out_byte = 0; 567 } else { 568 out_byte |= (element & 569 kMasks[chunk_bits - 1]) << out_byte_bits; 570 out_byte_bits += chunk_bits; 571 } 572 573 element_bits_done += chunk_bits; 574 element >>= chunk_bits; 575 } 576 } 577 578 if (out_byte_bits > 0) { 579 *out = out_byte; 580 } 581 } 582 583 /* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */ 584 static void 585 scalar_encode_1(uint8_t out[32], const scalar *s) 586 { 587 int i, j; 588 589 for (i = 0; i < DEGREE; i += 8) { 590 uint8_t out_byte = 0; 591 592 for (j = 0; j < 8; j++) { 593 out_byte |= (s->c[i + j] & 1) << j; 594 } 595 *out = out_byte; 596 out++; 597 } 598 } 599 600 /* 601 * Encodes an entire vector into 32*|RANK1024|*|bits| bytes. Note that since 256 602 * (DEGREE) is divisible by 8, the individual vector entries will always fill a 603 * whole number of bytes, so we do not need to worry about bit packing here. 604 */ 605 static void 606 vector_encode(uint8_t *out, const vector *a, int bits) 607 { 608 int i; 609 610 for (i = 0; i < RANK1024; i++) { 611 scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits); 612 } 613 } 614 615 /* 616 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in 617 * |out|. It returns one on success and zero if any parsed value is >= 618 * |kPrime|. 619 */ 620 static int 621 scalar_decode(scalar *out, const uint8_t *in, int bits) 622 { 623 uint8_t in_byte = 0; 624 int i, in_byte_bits_left = 0; 625 626 assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1); 627 628 for (i = 0; i < DEGREE; i++) { 629 uint16_t element = 0; 630 int element_bits_done = 0; 631 632 while (element_bits_done < bits) { 633 int chunk_bits = bits - element_bits_done; 634 635 if (in_byte_bits_left == 0) { 636 in_byte = *in; 637 in++; 638 in_byte_bits_left = 8; 639 } 640 641 if (chunk_bits > in_byte_bits_left) { 642 chunk_bits = in_byte_bits_left; 643 } 644 645 element |= (in_byte & kMasks[chunk_bits - 1]) << 646 element_bits_done; 647 in_byte_bits_left -= chunk_bits; 648 in_byte >>= chunk_bits; 649 650 element_bits_done += chunk_bits; 651 } 652 653 if (element >= kPrime) { 654 return 0; 655 } 656 out->c[i] = element; 657 } 658 659 return 1; 660 } 661 662 /* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */ 663 static void 664 scalar_decode_1(scalar *out, const uint8_t in[32]) 665 { 666 int i, j; 667 668 for (i = 0; i < DEGREE; i += 8) { 669 uint8_t in_byte = *in; 670 671 in++; 672 for (j = 0; j < 8; j++) { 673 out->c[i + j] = in_byte & 1; 674 in_byte >>= 1; 675 } 676 } 677 } 678 679 /* 680 * Decodes 32*|RANK1024|*|bits| bytes from |in| into |out|. It returns one on 681 * success or zero if any parsed value is >= |kPrime|. 682 */ 683 static int 684 vector_decode(vector *out, const uint8_t *in, int bits) 685 { 686 int i; 687 688 for (i = 0; i < RANK1024; i++) { 689 if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, 690 bits)) { 691 return 0; 692 } 693 } 694 return 1; 695 } 696 697 /* 698 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping 699 * numbers close to each other together. The formula used is 700 * round(2^|bits|/kPrime*x) mod 2^|bits|. 701 * Uses Barrett reduction to achieve constant time. Since we need both the 702 * remainder (for rounding) and the quotient (as the result), we cannot use 703 * |reduce| here, but need to do the Barrett reduction directly. 704 */ 705 static uint16_t 706 compress(uint16_t x, int bits) 707 { 708 uint32_t shifted = (uint32_t)x << bits; 709 uint64_t product = (uint64_t)shifted * kBarrettMultiplier; 710 uint32_t quotient = (uint32_t)(product >> kBarrettShift); 711 uint32_t remainder = shifted - quotient * kPrime; 712 713 /* 714 * Adjust the quotient to round correctly: 715 * 0 <= remainder <= kHalfPrime round to 0 716 * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1 717 * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2 718 */ 719 assert(remainder < 2u * kPrime); 720 quotient += 1 & constant_time_lt(kHalfPrime, remainder); 721 quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder); 722 return quotient & ((1 << bits) - 1); 723 } 724 725 /* 726 * Decompresses |x| by using an equi-distant representative. The formula is 727 * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to 728 * implement this logic using only bit operations. 729 */ 730 static uint16_t 731 decompress(uint16_t x, int bits) 732 { 733 uint32_t product = (uint32_t)x * kPrime; 734 uint32_t power = 1 << bits; 735 /* This is |product| % power, since |power| is a power of 2. */ 736 uint32_t remainder = product & (power - 1); 737 /* This is |product| / power, since |power| is a power of 2. */ 738 uint32_t lower = product >> bits; 739 740 /* 741 * The rounding logic works since the first half of numbers mod |power| have a 742 * 0 as first bit, and the second half has a 1 as first bit, since |power| is 743 * a power of 2. As a 12 bit number, |remainder| is always positive, so we 744 * will shift in 0s for a right shift. 745 */ 746 return lower + (remainder >> (bits - 1)); 747 } 748 749 static void 750 scalar_compress(scalar *s, int bits) 751 { 752 int i; 753 754 for (i = 0; i < DEGREE; i++) { 755 s->c[i] = compress(s->c[i], bits); 756 } 757 } 758 759 static void 760 scalar_decompress(scalar *s, int bits) 761 { 762 int i; 763 764 for (i = 0; i < DEGREE; i++) { 765 s->c[i] = decompress(s->c[i], bits); 766 } 767 } 768 769 static void 770 vector_compress(vector *a, int bits) 771 { 772 int i; 773 774 for (i = 0; i < RANK1024; i++) { 775 scalar_compress(&a->v[i], bits); 776 } 777 } 778 779 static void 780 vector_decompress(vector *a, int bits) 781 { 782 int i; 783 784 for (i = 0; i < RANK1024; i++) { 785 scalar_decompress(&a->v[i], bits); 786 } 787 } 788 789 struct public_key { 790 vector t; 791 uint8_t rho[32]; 792 uint8_t public_key_hash[32]; 793 matrix m; 794 }; 795 796 static struct public_key * 797 public_key_1024_from_external(const struct MLKEM1024_public_key *external) 798 { 799 return (struct public_key *)external; 800 } 801 802 struct private_key { 803 struct public_key pub; 804 vector s; 805 uint8_t fo_failure_secret[32]; 806 }; 807 808 static struct private_key * 809 private_key_1024_from_external(const struct MLKEM1024_private_key *external) 810 { 811 return (struct private_key *)external; 812 } 813 814 /* 815 * Calls |MLKEM1024_generate_key_external_entropy| with random bytes from 816 * |RAND_bytes|. 817 */ 818 void 819 MLKEM1024_generate_key(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES], 820 uint8_t optional_out_seed[MLKEM_SEED_BYTES], 821 struct MLKEM1024_private_key *out_private_key) 822 { 823 uint8_t entropy_buf[MLKEM_SEED_BYTES]; 824 uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed : 825 entropy_buf; 826 827 arc4random_buf(entropy, MLKEM_SEED_BYTES); 828 MLKEM1024_generate_key_external_entropy(out_encoded_public_key, 829 out_private_key, entropy); 830 } 831 LCRYPTO_ALIAS(MLKEM1024_generate_key); 832 833 int 834 MLKEM1024_private_key_from_seed(struct MLKEM1024_private_key *out_private_key, 835 const uint8_t *seed, size_t seed_len) 836 { 837 uint8_t public_key_bytes[MLKEM1024_PUBLIC_KEY_BYTES]; 838 839 if (seed_len != MLKEM_SEED_BYTES) { 840 return 0; 841 } 842 MLKEM1024_generate_key_external_entropy(public_key_bytes, 843 out_private_key, seed); 844 845 return 1; 846 } 847 LCRYPTO_ALIAS(MLKEM1024_private_key_from_seed); 848 849 static int 850 mlkem_marshal_public_key(CBB *out, const struct public_key *pub) 851 { 852 uint8_t *vector_output; 853 854 if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) { 855 return 0; 856 } 857 vector_encode(vector_output, &pub->t, kLog2Prime); 858 if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) { 859 return 0; 860 } 861 return 1; 862 } 863 864 void 865 MLKEM1024_generate_key_external_entropy( 866 uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES], 867 struct MLKEM1024_private_key *out_private_key, 868 const uint8_t entropy[MLKEM_SEED_BYTES]) 869 { 870 struct private_key *priv = private_key_1024_from_external( 871 out_private_key); 872 uint8_t augmented_seed[33]; 873 uint8_t *rho, *sigma; 874 uint8_t counter = 0; 875 uint8_t hashed[64]; 876 vector error; 877 CBB cbb; 878 879 memcpy(augmented_seed, entropy, 32); 880 augmented_seed[32] = RANK1024; 881 hash_g(hashed, augmented_seed, 33); 882 rho = hashed; 883 sigma = hashed + 32; 884 memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho)); 885 matrix_expand(&priv->pub.m, rho); 886 vector_generate_secret_eta_2(&priv->s, &counter, sigma); 887 vector_ntt(&priv->s); 888 vector_generate_secret_eta_2(&error, &counter, sigma); 889 vector_ntt(&error); 890 matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s); 891 vector_add(&priv->pub.t, &error); 892 893 /* XXX - error checking. */ 894 CBB_init_fixed(&cbb, out_encoded_public_key, MLKEM1024_PUBLIC_KEY_BYTES); 895 if (!mlkem_marshal_public_key(&cbb, &priv->pub)) { 896 abort(); 897 } 898 CBB_cleanup(&cbb); 899 900 hash_h(priv->pub.public_key_hash, out_encoded_public_key, 901 MLKEM1024_PUBLIC_KEY_BYTES); 902 memcpy(priv->fo_failure_secret, entropy + 32, 32); 903 } 904 905 void 906 MLKEM1024_public_from_private(struct MLKEM1024_public_key *out_public_key, 907 const struct MLKEM1024_private_key *private_key) 908 { 909 struct public_key *const pub = public_key_1024_from_external( 910 out_public_key); 911 const struct private_key *const priv = private_key_1024_from_external( 912 private_key); 913 914 *pub = priv->pub; 915 } 916 LCRYPTO_ALIAS(MLKEM1024_public_from_private); 917 918 /* 919 * Encrypts a message with given randomness to the ciphertext in |out|. Without 920 * applying the Fujisaki-Okamoto transform this would not result in a CCA secure 921 * scheme, since lattice schemes are vulnerable to decryption failure oracles. 922 */ 923 static void 924 encrypt_cpa(uint8_t out[MLKEM1024_CIPHERTEXT_BYTES], 925 const struct public_key *pub, const uint8_t message[32], 926 const uint8_t randomness[32]) 927 { 928 scalar expanded_message, scalar_error; 929 vector secret, error, u; 930 uint8_t counter = 0; 931 uint8_t input[33]; 932 scalar v; 933 934 vector_generate_secret_eta_2(&secret, &counter, randomness); 935 vector_ntt(&secret); 936 vector_generate_secret_eta_2(&error, &counter, randomness); 937 memcpy(input, randomness, 32); 938 input[32] = counter; 939 scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, 940 input); 941 matrix_mult(&u, &pub->m, &secret); 942 vector_inverse_ntt(&u); 943 vector_add(&u, &error); 944 scalar_inner_product(&v, &pub->t, &secret); 945 scalar_inverse_ntt(&v); 946 scalar_add(&v, &scalar_error); 947 scalar_decode_1(&expanded_message, message); 948 scalar_decompress(&expanded_message, 1); 949 scalar_add(&v, &expanded_message); 950 vector_compress(&u, kDU1024); 951 vector_encode(out, &u, kDU1024); 952 scalar_compress(&v, kDV1024); 953 scalar_encode(out + kCompressedVectorSize, &v, kDV1024); 954 } 955 956 /* Calls MLKEM1024_encap_external_entropy| with random bytes */ 957 void 958 MLKEM1024_encap(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES], 959 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES], 960 const struct MLKEM1024_public_key *public_key) 961 { 962 uint8_t entropy[MLKEM_ENCAP_ENTROPY]; 963 964 arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY); 965 MLKEM1024_encap_external_entropy(out_ciphertext, out_shared_secret, 966 public_key, entropy); 967 } 968 LCRYPTO_ALIAS(MLKEM1024_encap); 969 970 /* See section 6.2 of the spec. */ 971 void 972 MLKEM1024_encap_external_entropy( 973 uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES], 974 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES], 975 const struct MLKEM1024_public_key *public_key, 976 const uint8_t entropy[MLKEM_ENCAP_ENTROPY]) 977 { 978 const struct public_key *pub = public_key_1024_from_external(public_key); 979 uint8_t key_and_randomness[64]; 980 uint8_t input[64]; 981 982 memcpy(input, entropy, MLKEM_ENCAP_ENTROPY); 983 memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash, 984 sizeof(input) - MLKEM_ENCAP_ENTROPY); 985 hash_g(key_and_randomness, input, sizeof(input)); 986 encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32); 987 memcpy(out_shared_secret, key_and_randomness, 32); 988 } 989 990 static void 991 decrypt_cpa(uint8_t out[32], const struct private_key *priv, 992 const uint8_t ciphertext[MLKEM1024_CIPHERTEXT_BYTES]) 993 { 994 scalar mask, v; 995 vector u; 996 997 vector_decode(&u, ciphertext, kDU1024); 998 vector_decompress(&u, kDU1024); 999 vector_ntt(&u); 1000 scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV1024); 1001 scalar_decompress(&v, kDV1024); 1002 scalar_inner_product(&mask, &priv->s, &u); 1003 scalar_inverse_ntt(&mask); 1004 scalar_sub(&v, &mask); 1005 scalar_compress(&v, 1); 1006 scalar_encode_1(out, &v); 1007 } 1008 1009 /* See section 6.3 */ 1010 int 1011 MLKEM1024_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES], 1012 const uint8_t *ciphertext, size_t ciphertext_len, 1013 const struct MLKEM1024_private_key *private_key) 1014 { 1015 const struct private_key *priv = private_key_1024_from_external( 1016 private_key); 1017 uint8_t expected_ciphertext[MLKEM1024_CIPHERTEXT_BYTES]; 1018 uint8_t key_and_randomness[64]; 1019 uint8_t failure_key[32]; 1020 uint8_t decrypted[64]; 1021 uint8_t mask; 1022 int i; 1023 1024 if (ciphertext_len != MLKEM1024_CIPHERTEXT_BYTES) { 1025 arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_BYTES); 1026 return 0; 1027 } 1028 1029 decrypt_cpa(decrypted, priv, ciphertext); 1030 memcpy(decrypted + 32, priv->pub.public_key_hash, 1031 sizeof(decrypted) - 32); 1032 hash_g(key_and_randomness, decrypted, sizeof(decrypted)); 1033 encrypt_cpa(expected_ciphertext, &priv->pub, decrypted, 1034 key_and_randomness + 32); 1035 kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len); 1036 mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext, 1037 sizeof(expected_ciphertext)), 0); 1038 for (i = 0; i < MLKEM_SHARED_SECRET_BYTES; i++) { 1039 out_shared_secret[i] = constant_time_select_8(mask, 1040 key_and_randomness[i], failure_key[i]); 1041 } 1042 1043 return 1; 1044 } 1045 LCRYPTO_ALIAS(MLKEM1024_decap); 1046 1047 int 1048 MLKEM1024_marshal_public_key(CBB *out, 1049 const struct MLKEM1024_public_key *public_key) 1050 { 1051 return mlkem_marshal_public_key(out, 1052 public_key_1024_from_external(public_key)); 1053 } 1054 LCRYPTO_ALIAS(MLKEM1024_marshal_public_key); 1055 1056 /* 1057 * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate 1058 * the value of |pub->public_key_hash|. 1059 */ 1060 static int 1061 mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in) 1062 { 1063 CBS t_bytes; 1064 1065 if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) || 1066 !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime)) { 1067 return 0; 1068 } 1069 memcpy(pub->rho, CBS_data(in), sizeof(pub->rho)); 1070 if (!CBS_skip(in, sizeof(pub->rho))) 1071 return 0; 1072 matrix_expand(&pub->m, pub->rho); 1073 return 1; 1074 } 1075 1076 int 1077 MLKEM1024_parse_public_key(struct MLKEM1024_public_key *public_key, CBS *in) 1078 { 1079 struct public_key *pub = public_key_1024_from_external(public_key); 1080 CBS orig_in = *in; 1081 1082 if (!mlkem_parse_public_key_no_hash(pub, in) || 1083 CBS_len(in) != 0) { 1084 return 0; 1085 } 1086 hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in)); 1087 return 1; 1088 } 1089 LCRYPTO_ALIAS(MLKEM1024_parse_public_key); 1090 1091 int 1092 MLKEM1024_marshal_private_key(CBB *out, 1093 const struct MLKEM1024_private_key *private_key) 1094 { 1095 const struct private_key *const priv = private_key_1024_from_external( 1096 private_key); 1097 uint8_t *s_output; 1098 1099 if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) { 1100 return 0; 1101 } 1102 vector_encode(s_output, &priv->s, kLog2Prime); 1103 if (!mlkem_marshal_public_key(out, &priv->pub) || 1104 !CBB_add_bytes(out, priv->pub.public_key_hash, 1105 sizeof(priv->pub.public_key_hash)) || 1106 !CBB_add_bytes(out, priv->fo_failure_secret, 1107 sizeof(priv->fo_failure_secret))) { 1108 return 0; 1109 } 1110 return 1; 1111 } 1112 1113 int 1114 MLKEM1024_parse_private_key(struct MLKEM1024_private_key *out_private_key, 1115 CBS *in) 1116 { 1117 struct private_key *const priv = private_key_1024_from_external( 1118 out_private_key); 1119 CBS s_bytes; 1120 1121 if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) || 1122 !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) || 1123 !mlkem_parse_public_key_no_hash(&priv->pub, in)) { 1124 return 0; 1125 } 1126 memcpy(priv->pub.public_key_hash, CBS_data(in), 1127 sizeof(priv->pub.public_key_hash)); 1128 if (!CBS_skip(in, sizeof(priv->pub.public_key_hash))) 1129 return 0; 1130 memcpy(priv->fo_failure_secret, CBS_data(in), 1131 sizeof(priv->fo_failure_secret)); 1132 if (!CBS_skip(in, sizeof(priv->fo_failure_secret))) 1133 return 0; 1134 if (CBS_len(in) != 0) 1135 return 0; 1136 1137 return 1; 1138 } 1139 LCRYPTO_ALIAS(MLKEM1024_parse_private_key); 1140