1 /* $OpenBSD: mlkem768.c,v 1.7 2025/01/03 08:19:24 tb Exp $ */ 2 /* 3 * Copyright (c) 2024, Google Inc. 4 * Copyright (c) 2024, Bob Beck <beck@obtuse.com> 5 * 6 * Permission to use, copy, modify, and/or distribute this software for any 7 * purpose with or without fee is hereby granted, provided that the above 8 * copyright notice and this permission notice appear in all copies. 9 * 10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY 13 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION 15 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN 16 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 17 */ 18 19 #include <assert.h> 20 #include <stdlib.h> 21 #include <string.h> 22 23 #include "bytestring.h" 24 #include "mlkem.h" 25 26 #include "sha3_internal.h" 27 #include "mlkem_internal.h" 28 #include "constant_time.h" 29 #include "crypto_internal.h" 30 31 /* Remove later */ 32 #undef LCRYPTO_ALIAS 33 #define LCRYPTO_ALIAS(A) 34 35 /* 36 * See 37 * https://csrc.nist.gov/pubs/fips/203/final 38 */ 39 40 static void 41 prf(uint8_t *out, size_t out_len, const uint8_t in[33]) 42 { 43 sha3_ctx ctx; 44 shake256_init(&ctx); 45 shake_update(&ctx, in, 33); 46 shake_xof(&ctx); 47 shake_out(&ctx, out, out_len); 48 } 49 50 /* Section 4.1 */ 51 static void 52 hash_h(uint8_t out[32], const uint8_t *in, size_t len) 53 { 54 sha3_ctx ctx; 55 sha3_init(&ctx, 32); 56 sha3_update(&ctx, in, len); 57 sha3_final(out, &ctx); 58 } 59 60 static void 61 hash_g(uint8_t out[64], const uint8_t *in, size_t len) 62 { 63 sha3_ctx ctx; 64 sha3_init(&ctx, 64); 65 sha3_update(&ctx, in, len); 66 sha3_final(out, &ctx); 67 } 68 69 /* this is called 'J' in the spec */ 70 static void 71 kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32], 72 const uint8_t *in, size_t len) 73 { 74 sha3_ctx ctx; 75 shake256_init(&ctx); 76 shake_update(&ctx, failure_secret, 32); 77 shake_update(&ctx, in, len); 78 shake_xof(&ctx); 79 shake_out(&ctx, out, MLKEM_SHARED_SECRET_BYTES); 80 } 81 82 #define DEGREE 256 83 #define RANK768 3 84 85 static const size_t kBarrettMultiplier = 5039; 86 static const unsigned kBarrettShift = 24; 87 static const uint16_t kPrime = 3329; 88 static const int kLog2Prime = 12; 89 static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2; 90 static const int kDU768 = 10; 91 static const int kDV768 = 4; 92 /* 93 * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th 94 * root of unity. 95 */ 96 static const uint16_t kInverseDegree = 3303; 97 static const size_t kEncodedVectorSize = 98 (/*kLog2Prime=*/12 * DEGREE / 8) * RANK768; 99 static const size_t kCompressedVectorSize = /*kDU768=*/ 10 * RANK768 * DEGREE / 100 8; 101 102 typedef struct scalar { 103 /* On every function entry and exit, 0 <= c < kPrime. */ 104 uint16_t c[DEGREE]; 105 } scalar; 106 107 typedef struct vector { 108 scalar v[RANK768]; 109 } vector; 110 111 typedef struct matrix { 112 scalar v[RANK768][RANK768]; 113 } matrix; 114 115 /* 116 * This bit of Python will be referenced in some of the following comments: 117 * 118 * p = 3329 119 * 120 * def bitreverse(i): 121 * ret = 0 122 * for n in range(7): 123 * bit = i & 1 124 * ret <<= 1 125 * ret |= bit 126 * i >>= 1 127 * return ret 128 */ 129 130 /* kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] */ 131 static const uint16_t kNTTRoots[128] = { 132 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 133 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 134 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 135 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 136 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 137 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 138 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, 139 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 140 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 141 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 142 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154, 143 }; 144 145 /* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */ 146 static const uint16_t kInverseNTTRoots[128] = { 147 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543, 148 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903, 149 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855, 150 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010, 151 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132, 152 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607, 153 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230, 154 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745, 155 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482, 156 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920, 157 2229, 1041, 2606, 1692, 680, 2746, 568, 3312, 158 }; 159 160 /* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */ 161 static const uint16_t kModRoots[128] = { 162 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 163 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, 164 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, 165 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, 166 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, 167 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 168 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010, 169 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, 170 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 171 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, 172 2110, 1219, 2935, 394, 885, 2444, 2154, 1175, 173 }; 174 175 /* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */ 176 static uint16_t 177 reduce_once(uint16_t x) 178 { 179 assert(x < 2 * kPrime); 180 const uint16_t subtracted = x - kPrime; 181 uint16_t mask = 0u - (subtracted >> 15); 182 183 /* 184 * Although this is a constant-time select, we omit a value barrier here. 185 * Value barriers impede auto-vectorization (likely because it forces the 186 * value to transit through a general-purpose register). On AArch64, this 187 * is a difference of 2x. 188 * 189 * We usually add value barriers to selects because Clang turns 190 * consecutive selects with the same condition into a branch instead of 191 * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it 192 * seems to be safe so far but see 193 * |scalar_centered_binomial_distribution_eta_2_with_prf|. 194 */ 195 return (mask & x) | (~mask & subtracted); 196 } 197 198 /* 199 * constant time reduce x mod kPrime using Barrett reduction. x must be less 200 * than kPrime + 2×kPrime². 201 */ 202 static uint16_t 203 reduce(uint32_t x) 204 { 205 uint64_t product = (uint64_t)x * kBarrettMultiplier; 206 uint32_t quotient = (uint32_t)(product >> kBarrettShift); 207 uint32_t remainder = x - quotient * kPrime; 208 209 assert(x < kPrime + 2u * kPrime * kPrime); 210 return reduce_once(remainder); 211 } 212 213 static void 214 scalar_zero(scalar *out) 215 { 216 memset(out, 0, sizeof(*out)); 217 } 218 219 static void 220 vector_zero(vector *out) 221 { 222 memset(out, 0, sizeof(*out)); 223 } 224 225 /* 226 * In place number theoretic transform of a given scalar. 227 * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this 228 * transform leaves off the last iteration of the usual FFT code, with the 128 229 * relevant roots of unity being stored in |kNTTRoots|. This means the output 230 * should be seen as 128 elements in GF(3329^2), with the coefficients of the 231 * elements being consecutive entries in |s->c|. 232 */ 233 static void 234 scalar_ntt(scalar *s) 235 { 236 int offset = DEGREE; 237 int step; 238 /* 239 * `int` is used here because using `size_t` throughout caused a ~5% slowdown 240 * with Clang 14 on Aarch64. 241 */ 242 for (step = 1; step < DEGREE / 2; step <<= 1) { 243 int i, j, k = 0; 244 245 offset >>= 1; 246 for (i = 0; i < step; i++) { 247 const uint32_t step_root = kNTTRoots[i + step]; 248 249 for (j = k; j < k + offset; j++) { 250 uint16_t odd, even; 251 252 odd = reduce(step_root * s->c[j + offset]); 253 even = s->c[j]; 254 s->c[j] = reduce_once(odd + even); 255 s->c[j + offset] = reduce_once(even - odd + 256 kPrime); 257 } 258 k += 2 * offset; 259 } 260 } 261 } 262 263 static void 264 vector_ntt(vector *a) 265 { 266 int i; 267 268 for (i = 0; i < RANK768; i++) { 269 scalar_ntt(&a->v[i]); 270 } 271 } 272 273 /* 274 * In place inverse number theoretic transform of a given scalar, with pairs of 275 * entries of s->v being interpreted as elements of GF(3329^2). Just as with the 276 * number theoretic transform, this leaves off the first step of the normal iFFT 277 * to account for the fact that 3329 does not have a 512th root of unity, using 278 * the precomputed 128 roots of unity stored in |kInverseNTTRoots|. 279 */ 280 static void 281 scalar_inverse_ntt(scalar *s) 282 { 283 int i, j, k, offset, step = DEGREE / 2; 284 285 /* 286 * `int` is used here because using `size_t` throughout caused a ~5% slowdown 287 * with Clang 14 on Aarch64. 288 */ 289 for (offset = 2; offset < DEGREE; offset <<= 1) { 290 step >>= 1; 291 k = 0; 292 for (i = 0; i < step; i++) { 293 uint32_t step_root = kInverseNTTRoots[i + step]; 294 for (j = k; j < k + offset; j++) { 295 uint16_t odd, even; 296 odd = s->c[j + offset]; 297 even = s->c[j]; 298 s->c[j] = reduce_once(odd + even); 299 s->c[j + offset] = reduce(step_root * 300 (even - odd + kPrime)); 301 } 302 k += 2 * offset; 303 } 304 } 305 for (i = 0; i < DEGREE; i++) { 306 s->c[i] = reduce(s->c[i] * kInverseDegree); 307 } 308 } 309 310 static void 311 vector_inverse_ntt(vector *a) 312 { 313 int i; 314 315 for (i = 0; i < RANK768; i++) { 316 scalar_inverse_ntt(&a->v[i]); 317 } 318 } 319 320 static void 321 scalar_add(scalar *lhs, const scalar *rhs) 322 { 323 int i; 324 325 for (i = 0; i < DEGREE; i++) { 326 lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]); 327 } 328 } 329 330 static void 331 scalar_sub(scalar *lhs, const scalar *rhs) 332 { 333 int i; 334 335 for (i = 0; i < DEGREE; i++) { 336 lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime); 337 } 338 } 339 340 /* 341 * Multiplying two scalars in the number theoretically transformed state. 342 * Since 3329 does not have a 512th root of unity, this means we have to 343 * interpret the 2*ith and (2*i+1)th entries of the scalar as elements of 344 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)). 345 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed 346 * |kModRoots| table. Our Barrett transform only allows us to multiply two 347 * reduced numbers together, so we need some intermediate reduction steps, 348 * even if an uint64_t could hold 3 multiplied numbers. 349 */ 350 static void 351 scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) 352 { 353 int i; 354 355 for (i = 0; i < DEGREE / 2; i++) { 356 uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i]; 357 uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] * 358 rhs->c[2 * i + 1]; 359 uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1]; 360 uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i]; 361 362 out->c[2 * i] = 363 reduce(real_real + 364 (uint32_t)reduce(img_img) * kModRoots[i]); 365 out->c[2 * i + 1] = reduce(img_real + real_img); 366 } 367 } 368 369 static void 370 vector_add(vector *lhs, const vector *rhs) 371 { 372 int i; 373 374 for (i = 0; i < RANK768; i++) { 375 scalar_add(&lhs->v[i], &rhs->v[i]); 376 } 377 } 378 379 static void 380 matrix_mult(vector *out, const matrix *m, const vector *a) 381 { 382 int i, j; 383 384 vector_zero(out); 385 for (i = 0; i < RANK768; i++) { 386 for (j = 0; j < RANK768; j++) { 387 scalar product; 388 389 scalar_mult(&product, &m->v[i][j], &a->v[j]); 390 scalar_add(&out->v[i], &product); 391 } 392 } 393 } 394 395 static void 396 matrix_mult_transpose(vector *out, const matrix *m, 397 const vector *a) 398 { 399 int i, j; 400 401 vector_zero(out); 402 for (i = 0; i < RANK768; i++) { 403 for (j = 0; j < RANK768; j++) { 404 scalar product; 405 406 scalar_mult(&product, &m->v[j][i], &a->v[j]); 407 scalar_add(&out->v[i], &product); 408 } 409 } 410 } 411 412 static void 413 scalar_inner_product(scalar *out, const vector *lhs, 414 const vector *rhs) 415 { 416 int i; 417 scalar_zero(out); 418 for (i = 0; i < RANK768; i++) { 419 scalar product; 420 421 scalar_mult(&product, &lhs->v[i], &rhs->v[i]); 422 scalar_add(out, &product); 423 } 424 } 425 426 /* 427 * Algorithm 6 of spec. Rejection samples a Keccak stream to get uniformly 428 * distributed elements. This is used for matrix expansion and only operates on 429 * public inputs. 430 */ 431 static void 432 scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx) 433 { 434 int i, done = 0; 435 436 while (done < DEGREE) { 437 uint8_t block[168]; 438 439 shake_out(keccak_ctx, block, sizeof(block)); 440 for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) { 441 uint16_t d1 = block[i] + 256 * (block[i + 1] % 16); 442 uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2]; 443 444 if (d1 < kPrime) { 445 out->c[done++] = d1; 446 } 447 if (d2 < kPrime && done < DEGREE) { 448 out->c[done++] = d2; 449 } 450 } 451 } 452 } 453 454 /* 455 * Algorithm 7 of the spec, with eta fixed to two and the PRF call 456 * included. Creates binominally distributed elements by sampling 2*|eta| bits, 457 * and setting the coefficient to the count of the first bits minus the count of 458 * the second bits, resulting in a centered binomial distribution. Since eta is 459 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4, 460 * and 0 with probability 3/8. 461 */ 462 static void 463 scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out, 464 const uint8_t input[33]) 465 { 466 uint8_t entropy[128]; 467 int i; 468 469 CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8); 470 prf(entropy, sizeof(entropy), input); 471 472 for (i = 0; i < DEGREE; i += 2) { 473 uint8_t byte = entropy[i / 2]; 474 uint16_t mask; 475 uint16_t value = (byte & 1) + ((byte >> 1) & 1); 476 477 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1); 478 479 /* 480 * Add |kPrime| if |value| underflowed. See |reduce_once| for a 481 * discussion on why the value barrier is omitted. While this 482 * could have been written reduce_once(value + kPrime), this is 483 * one extra addition and small range of |value| tempts some 484 * versions of Clang to emit a branch. 485 */ 486 mask = 0u - (value >> 15); 487 out->c[i] = ((value + kPrime) & mask) | (value & ~mask); 488 489 byte >>= 4; 490 value = (byte & 1) + ((byte >> 1) & 1); 491 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1); 492 /* See above. */ 493 mask = 0u - (value >> 15); 494 out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask); 495 } 496 } 497 498 /* 499 * Generates a secret vector by using 500 * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed 501 * appending and incrementing |counter| for entry of the vector. 502 */ 503 static void 504 vector_generate_secret_eta_2(vector *out, uint8_t *counter, 505 const uint8_t seed[32]) 506 { 507 uint8_t input[33]; 508 int i; 509 510 memcpy(input, seed, 32); 511 for (i = 0; i < RANK768; i++) { 512 input[32] = (*counter)++; 513 scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i], 514 input); 515 } 516 } 517 518 /* Expands the matrix of a seed for key generation and for encaps-CPA. */ 519 static void 520 matrix_expand(matrix *out, const uint8_t rho[32]) 521 { 522 uint8_t input[34]; 523 int i, j; 524 525 memcpy(input, rho, 32); 526 for (i = 0; i < RANK768; i++) { 527 for (j = 0; j < RANK768; j++) { 528 sha3_ctx keccak_ctx; 529 530 input[32] = i; 531 input[33] = j; 532 shake128_init(&keccak_ctx); 533 shake_update(&keccak_ctx, input, sizeof(input)); 534 shake_xof(&keccak_ctx); 535 scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx); 536 } 537 } 538 } 539 540 static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f, 541 0x1f, 0x3f, 0x7f, 0xff}; 542 543 static void 544 scalar_encode(uint8_t *out, const scalar *s, int bits) 545 { 546 uint8_t out_byte = 0; 547 int i, out_byte_bits = 0; 548 549 assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1); 550 for (i = 0; i < DEGREE; i++) { 551 uint16_t element = s->c[i]; 552 int element_bits_done = 0; 553 554 while (element_bits_done < bits) { 555 int chunk_bits = bits - element_bits_done; 556 int out_bits_remaining = 8 - out_byte_bits; 557 558 if (chunk_bits >= out_bits_remaining) { 559 chunk_bits = out_bits_remaining; 560 out_byte |= (element & 561 kMasks[chunk_bits - 1]) << out_byte_bits; 562 *out = out_byte; 563 out++; 564 out_byte_bits = 0; 565 out_byte = 0; 566 } else { 567 out_byte |= (element & 568 kMasks[chunk_bits - 1]) << out_byte_bits; 569 out_byte_bits += chunk_bits; 570 } 571 572 element_bits_done += chunk_bits; 573 element >>= chunk_bits; 574 } 575 } 576 577 if (out_byte_bits > 0) { 578 *out = out_byte; 579 } 580 } 581 582 /* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */ 583 static void 584 scalar_encode_1(uint8_t out[32], const scalar *s) 585 { 586 int i, j; 587 588 for (i = 0; i < DEGREE; i += 8) { 589 uint8_t out_byte = 0; 590 591 for (j = 0; j < 8; j++) { 592 out_byte |= (s->c[i + j] & 1) << j; 593 } 594 *out = out_byte; 595 out++; 596 } 597 } 598 599 /* 600 * Encodes an entire vector into 32*|RANK768|*|bits| bytes. Note that since 256 601 * (DEGREE) is divisible by 8, the individual vector entries will always fill a 602 * whole number of bytes, so we do not need to worry about bit packing here. 603 */ 604 static void 605 vector_encode(uint8_t *out, const vector *a, int bits) 606 { 607 int i; 608 609 for (i = 0; i < RANK768; i++) { 610 scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits); 611 } 612 } 613 614 /* 615 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in 616 * |out|. It returns one on success and zero if any parsed value is >= 617 * |kPrime|. 618 */ 619 static int 620 scalar_decode(scalar *out, const uint8_t *in, int bits) 621 { 622 uint8_t in_byte = 0; 623 int i, in_byte_bits_left = 0; 624 625 assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1); 626 627 for (i = 0; i < DEGREE; i++) { 628 uint16_t element = 0; 629 int element_bits_done = 0; 630 631 while (element_bits_done < bits) { 632 int chunk_bits = bits - element_bits_done; 633 634 if (in_byte_bits_left == 0) { 635 in_byte = *in; 636 in++; 637 in_byte_bits_left = 8; 638 } 639 640 if (chunk_bits > in_byte_bits_left) { 641 chunk_bits = in_byte_bits_left; 642 } 643 644 element |= (in_byte & kMasks[chunk_bits - 1]) << 645 element_bits_done; 646 in_byte_bits_left -= chunk_bits; 647 in_byte >>= chunk_bits; 648 649 element_bits_done += chunk_bits; 650 } 651 652 if (element >= kPrime) { 653 return 0; 654 } 655 out->c[i] = element; 656 } 657 658 return 1; 659 } 660 661 /* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */ 662 static void 663 scalar_decode_1(scalar *out, const uint8_t in[32]) 664 { 665 int i, j; 666 667 for (i = 0; i < DEGREE; i += 8) { 668 uint8_t in_byte = *in; 669 670 in++; 671 for (j = 0; j < 8; j++) { 672 out->c[i + j] = in_byte & 1; 673 in_byte >>= 1; 674 } 675 } 676 } 677 678 /* 679 * Decodes 32*|RANK768|*|bits| bytes from |in| into |out|. It returns one on 680 * success or zero if any parsed value is >= |kPrime|. 681 */ 682 static int 683 vector_decode(vector *out, const uint8_t *in, int bits) 684 { 685 int i; 686 687 for (i = 0; i < RANK768; i++) { 688 if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, 689 bits)) { 690 return 0; 691 } 692 } 693 return 1; 694 } 695 696 /* 697 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping 698 * numbers close to each other together. The formula used is 699 * round(2^|bits|/kPrime*x) mod 2^|bits|. 700 * Uses Barrett reduction to achieve constant time. Since we need both the 701 * remainder (for rounding) and the quotient (as the result), we cannot use 702 * |reduce| here, but need to do the Barrett reduction directly. 703 */ 704 static uint16_t 705 compress(uint16_t x, int bits) 706 { 707 uint32_t shifted = (uint32_t)x << bits; 708 uint64_t product = (uint64_t)shifted * kBarrettMultiplier; 709 uint32_t quotient = (uint32_t)(product >> kBarrettShift); 710 uint32_t remainder = shifted - quotient * kPrime; 711 712 /* 713 * Adjust the quotient to round correctly: 714 * 0 <= remainder <= kHalfPrime round to 0 715 * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1 716 * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2 717 */ 718 assert(remainder < 2u * kPrime); 719 quotient += 1 & constant_time_lt(kHalfPrime, remainder); 720 quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder); 721 return quotient & ((1 << bits) - 1); 722 } 723 724 /* 725 * Decompresses |x| by using an equi-distant representative. The formula is 726 * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to 727 * implement this logic using only bit operations. 728 */ 729 static uint16_t 730 decompress(uint16_t x, int bits) 731 { 732 uint32_t product = (uint32_t)x * kPrime; 733 uint32_t power = 1 << bits; 734 /* This is |product| % power, since |power| is a power of 2. */ 735 uint32_t remainder = product & (power - 1); 736 /* This is |product| / power, since |power| is a power of 2. */ 737 uint32_t lower = product >> bits; 738 739 /* 740 * The rounding logic works since the first half of numbers mod |power| have a 741 * 0 as first bit, and the second half has a 1 as first bit, since |power| is 742 * a power of 2. As a 12 bit number, |remainder| is always positive, so we 743 * will shift in 0s for a right shift. 744 */ 745 return lower + (remainder >> (bits - 1)); 746 } 747 748 static void 749 scalar_compress(scalar *s, int bits) 750 { 751 int i; 752 753 for (i = 0; i < DEGREE; i++) { 754 s->c[i] = compress(s->c[i], bits); 755 } 756 } 757 758 static void 759 scalar_decompress(scalar *s, int bits) 760 { 761 int i; 762 763 for (i = 0; i < DEGREE; i++) { 764 s->c[i] = decompress(s->c[i], bits); 765 } 766 } 767 768 static void 769 vector_compress(vector *a, int bits) 770 { 771 int i; 772 773 for (i = 0; i < RANK768; i++) { 774 scalar_compress(&a->v[i], bits); 775 } 776 } 777 778 static void 779 vector_decompress(vector *a, int bits) 780 { 781 int i; 782 783 for (i = 0; i < RANK768; i++) { 784 scalar_decompress(&a->v[i], bits); 785 } 786 } 787 788 struct public_key { 789 vector t; 790 uint8_t rho[32]; 791 uint8_t public_key_hash[32]; 792 matrix m; 793 }; 794 795 static struct public_key * 796 public_key_768_from_external(const struct MLKEM768_public_key *external) 797 { 798 return (struct public_key *)external; 799 } 800 801 struct private_key { 802 struct public_key pub; 803 vector s; 804 uint8_t fo_failure_secret[32]; 805 }; 806 807 static struct private_key * 808 private_key_768_from_external(const struct MLKEM768_private_key *external) 809 { 810 return (struct private_key *)external; 811 } 812 813 /* 814 * Calls |MLKEM768_generate_key_external_entropy| with random bytes from 815 * |RAND_bytes|. 816 */ 817 void 818 MLKEM768_generate_key(uint8_t out_encoded_public_key[MLKEM768_PUBLIC_KEY_BYTES], 819 uint8_t optional_out_seed[MLKEM_SEED_BYTES], 820 struct MLKEM768_private_key *out_private_key) 821 { 822 uint8_t entropy_buf[MLKEM_SEED_BYTES]; 823 uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed : 824 entropy_buf; 825 826 arc4random_buf(entropy, MLKEM_SEED_BYTES); 827 MLKEM768_generate_key_external_entropy(out_encoded_public_key, 828 out_private_key, entropy); 829 } 830 LCRYPTO_ALIAS(MLKEM768_generate_key); 831 832 int 833 MLKEM768_private_key_from_seed(struct MLKEM768_private_key *out_private_key, 834 const uint8_t *seed, size_t seed_len) 835 { 836 uint8_t public_key_bytes[MLKEM768_PUBLIC_KEY_BYTES]; 837 838 if (seed_len != MLKEM_SEED_BYTES) { 839 return 0; 840 } 841 MLKEM768_generate_key_external_entropy(public_key_bytes, 842 out_private_key, seed); 843 844 return 1; 845 } 846 LCRYPTO_ALIAS(MLKEM768_private_key_from_seed); 847 848 static int 849 mlkem_marshal_public_key(CBB *out, const struct public_key *pub) 850 { 851 uint8_t *vector_output; 852 853 if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) { 854 return 0; 855 } 856 vector_encode(vector_output, &pub->t, kLog2Prime); 857 if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) { 858 return 0; 859 } 860 return 1; 861 } 862 863 void 864 MLKEM768_generate_key_external_entropy( 865 uint8_t out_encoded_public_key[MLKEM768_PUBLIC_KEY_BYTES], 866 struct MLKEM768_private_key *out_private_key, 867 const uint8_t entropy[MLKEM_SEED_BYTES]) 868 { 869 struct private_key *priv = private_key_768_from_external( 870 out_private_key); 871 uint8_t augmented_seed[33]; 872 uint8_t *rho, *sigma; 873 uint8_t counter = 0; 874 uint8_t hashed[64]; 875 vector error; 876 CBB cbb; 877 878 memcpy(augmented_seed, entropy, 32); 879 augmented_seed[32] = RANK768; 880 hash_g(hashed, augmented_seed, 33); 881 rho = hashed; 882 sigma = hashed + 32; 883 memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho)); 884 matrix_expand(&priv->pub.m, rho); 885 vector_generate_secret_eta_2(&priv->s, &counter, sigma); 886 vector_ntt(&priv->s); 887 vector_generate_secret_eta_2(&error, &counter, sigma); 888 vector_ntt(&error); 889 matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s); 890 vector_add(&priv->pub.t, &error); 891 892 /* XXX - error checking */ 893 CBB_init_fixed(&cbb, out_encoded_public_key, MLKEM768_PUBLIC_KEY_BYTES); 894 if (!mlkem_marshal_public_key(&cbb, &priv->pub)) { 895 abort(); 896 } 897 CBB_cleanup(&cbb); 898 899 hash_h(priv->pub.public_key_hash, out_encoded_public_key, 900 MLKEM768_PUBLIC_KEY_BYTES); 901 memcpy(priv->fo_failure_secret, entropy + 32, 32); 902 } 903 904 void 905 MLKEM768_public_from_private(struct MLKEM768_public_key *out_public_key, 906 const struct MLKEM768_private_key *private_key) 907 { 908 struct public_key *const pub = public_key_768_from_external( 909 out_public_key); 910 const struct private_key *const priv = private_key_768_from_external( 911 private_key); 912 913 *pub = priv->pub; 914 } 915 LCRYPTO_ALIAS(MLKEM768_public_from_private); 916 917 /* 918 * Encrypts a message with given randomness to the ciphertext in |out|. Without 919 * applying the Fujisaki-Okamoto transform this would not result in a CCA secure 920 * scheme, since lattice schemes are vulnerable to decryption failure oracles. 921 */ 922 static void 923 encrypt_cpa(uint8_t out[MLKEM768_CIPHERTEXT_BYTES], 924 const struct public_key *pub, const uint8_t message[32], 925 const uint8_t randomness[32]) 926 { 927 scalar expanded_message, scalar_error; 928 vector secret, error, u; 929 uint8_t counter = 0; 930 uint8_t input[33]; 931 scalar v; 932 933 vector_generate_secret_eta_2(&secret, &counter, randomness); 934 vector_ntt(&secret); 935 vector_generate_secret_eta_2(&error, &counter, randomness); 936 memcpy(input, randomness, 32); 937 input[32] = counter; 938 scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, 939 input); 940 matrix_mult(&u, &pub->m, &secret); 941 vector_inverse_ntt(&u); 942 vector_add(&u, &error); 943 scalar_inner_product(&v, &pub->t, &secret); 944 scalar_inverse_ntt(&v); 945 scalar_add(&v, &scalar_error); 946 scalar_decode_1(&expanded_message, message); 947 scalar_decompress(&expanded_message, 1); 948 scalar_add(&v, &expanded_message); 949 vector_compress(&u, kDU768); 950 vector_encode(out, &u, kDU768); 951 scalar_compress(&v, kDV768); 952 scalar_encode(out + kCompressedVectorSize, &v, kDV768); 953 } 954 955 /* Calls MLKEM768_encap_external_entropy| with random bytes */ 956 void 957 MLKEM768_encap(uint8_t out_ciphertext[MLKEM768_CIPHERTEXT_BYTES], 958 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES], 959 const struct MLKEM768_public_key *public_key) 960 { 961 uint8_t entropy[MLKEM_ENCAP_ENTROPY]; 962 963 arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY); 964 MLKEM768_encap_external_entropy(out_ciphertext, out_shared_secret, 965 public_key, entropy); 966 } 967 LCRYPTO_ALIAS(MLKEM768_encap); 968 969 /* See section 6.2 of the spec. */ 970 void 971 MLKEM768_encap_external_entropy( 972 uint8_t out_ciphertext[MLKEM768_CIPHERTEXT_BYTES], 973 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES], 974 const struct MLKEM768_public_key *public_key, 975 const uint8_t entropy[MLKEM_ENCAP_ENTROPY]) 976 { 977 const struct public_key *pub = public_key_768_from_external(public_key); 978 uint8_t key_and_randomness[64]; 979 uint8_t input[64]; 980 981 memcpy(input, entropy, MLKEM_ENCAP_ENTROPY); 982 memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash, 983 sizeof(input) - MLKEM_ENCAP_ENTROPY); 984 hash_g(key_and_randomness, input, sizeof(input)); 985 encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32); 986 memcpy(out_shared_secret, key_and_randomness, 32); 987 } 988 989 static void 990 decrypt_cpa(uint8_t out[32], const struct private_key *priv, 991 const uint8_t ciphertext[MLKEM768_CIPHERTEXT_BYTES]) 992 { 993 scalar mask, v; 994 vector u; 995 996 vector_decode(&u, ciphertext, kDU768); 997 vector_decompress(&u, kDU768); 998 vector_ntt(&u); 999 scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV768); 1000 scalar_decompress(&v, kDV768); 1001 scalar_inner_product(&mask, &priv->s, &u); 1002 scalar_inverse_ntt(&mask); 1003 scalar_sub(&v, &mask); 1004 scalar_compress(&v, 1); 1005 scalar_encode_1(out, &v); 1006 } 1007 1008 /* See section 6.3 */ 1009 int 1010 MLKEM768_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES], 1011 const uint8_t *ciphertext, size_t ciphertext_len, 1012 const struct MLKEM768_private_key *private_key) 1013 { 1014 const struct private_key *priv = private_key_768_from_external( 1015 private_key); 1016 uint8_t expected_ciphertext[MLKEM768_CIPHERTEXT_BYTES]; 1017 uint8_t key_and_randomness[64]; 1018 uint8_t failure_key[32]; 1019 uint8_t decrypted[64]; 1020 uint8_t mask; 1021 int i; 1022 1023 if (ciphertext_len != MLKEM768_CIPHERTEXT_BYTES) { 1024 arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_BYTES); 1025 return 0; 1026 } 1027 1028 decrypt_cpa(decrypted, priv, ciphertext); 1029 memcpy(decrypted + 32, priv->pub.public_key_hash, 1030 sizeof(decrypted) - 32); 1031 hash_g(key_and_randomness, decrypted, sizeof(decrypted)); 1032 encrypt_cpa(expected_ciphertext, &priv->pub, decrypted, 1033 key_and_randomness + 32); 1034 kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len); 1035 mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext, 1036 sizeof(expected_ciphertext)), 0); 1037 for (i = 0; i < MLKEM_SHARED_SECRET_BYTES; i++) { 1038 out_shared_secret[i] = constant_time_select_8(mask, 1039 key_and_randomness[i], failure_key[i]); 1040 } 1041 1042 return 1; 1043 } 1044 LCRYPTO_ALIAS(MLKEM768_decap); 1045 1046 int 1047 MLKEM768_marshal_public_key(CBB *out, 1048 const struct MLKEM768_public_key *public_key) 1049 { 1050 return mlkem_marshal_public_key(out, 1051 public_key_768_from_external(public_key)); 1052 } 1053 LCRYPTO_ALIAS(MLKEM768_marshal_public_key); 1054 1055 /* 1056 * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate 1057 * the value of |pub->public_key_hash|. 1058 */ 1059 static int 1060 mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in) 1061 { 1062 CBS t_bytes; 1063 1064 if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) || 1065 !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime)) { 1066 return 0; 1067 } 1068 memcpy(pub->rho, CBS_data(in), sizeof(pub->rho)); 1069 if (!CBS_skip(in, sizeof(pub->rho))) 1070 return 0; 1071 matrix_expand(&pub->m, pub->rho); 1072 return 1; 1073 } 1074 1075 int 1076 MLKEM768_parse_public_key(struct MLKEM768_public_key *public_key, CBS *in) 1077 { 1078 struct public_key *pub = public_key_768_from_external(public_key); 1079 CBS orig_in = *in; 1080 1081 if (!mlkem_parse_public_key_no_hash(pub, in) || 1082 CBS_len(in) != 0) { 1083 return 0; 1084 } 1085 hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in)); 1086 return 1; 1087 } 1088 LCRYPTO_ALIAS(MLKEM768_parse_public_key); 1089 1090 int 1091 MLKEM768_marshal_private_key(CBB *out, 1092 const struct MLKEM768_private_key *private_key) 1093 { 1094 const struct private_key *const priv = private_key_768_from_external( 1095 private_key); 1096 uint8_t *s_output; 1097 1098 if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) { 1099 return 0; 1100 } 1101 vector_encode(s_output, &priv->s, kLog2Prime); 1102 if (!mlkem_marshal_public_key(out, &priv->pub) || 1103 !CBB_add_bytes(out, priv->pub.public_key_hash, 1104 sizeof(priv->pub.public_key_hash)) || 1105 !CBB_add_bytes(out, priv->fo_failure_secret, 1106 sizeof(priv->fo_failure_secret))) { 1107 return 0; 1108 } 1109 return 1; 1110 } 1111 1112 int 1113 MLKEM768_parse_private_key(struct MLKEM768_private_key *out_private_key, 1114 CBS *in) 1115 { 1116 struct private_key *const priv = private_key_768_from_external( 1117 out_private_key); 1118 CBS s_bytes; 1119 1120 if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) || 1121 !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) || 1122 !mlkem_parse_public_key_no_hash(&priv->pub, in)) { 1123 return 0; 1124 } 1125 memcpy(priv->pub.public_key_hash, CBS_data(in), 1126 sizeof(priv->pub.public_key_hash)); 1127 if (!CBS_skip(in, sizeof(priv->pub.public_key_hash))) 1128 return 0; 1129 memcpy(priv->fo_failure_secret, CBS_data(in), 1130 sizeof(priv->fo_failure_secret)); 1131 if (!CBS_skip(in, sizeof(priv->fo_failure_secret))) 1132 return 0; 1133 if (CBS_len(in) != 0) 1134 return 0; 1135 1136 return 1; 1137 } 1138 LCRYPTO_ALIAS(MLKEM768_parse_private_key); 1139