1 /* crypto/rsa/rsa_oaep.c */ 2 /* Written by Ulf Moeller. This software is distributed on an "AS IS" 3 basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. */ 4 5 /* EME_OAEP as defined in RFC 2437 (PKCS #1 v2.0) */ 6 7 #if !defined(NO_SHA) && !defined(NO_SHA1) 8 #include <stdio.h> 9 #include "cryptlib.h" 10 #include <openssl/bn.h> 11 #include <openssl/rsa.h> 12 #include <openssl/sha.h> 13 #include <openssl/rand.h> 14 15 int MGF1(unsigned char *mask, long len, unsigned char *seed, long seedlen); 16 17 int RSA_padding_add_PKCS1_OAEP(unsigned char *to, int tlen, 18 unsigned char *from, int flen, unsigned char *param, int plen) 19 { 20 int i, emlen = tlen - 1; 21 unsigned char *db, *seed; 22 unsigned char *dbmask, seedmask[SHA_DIGEST_LENGTH]; 23 24 if (flen > emlen - 2 * SHA_DIGEST_LENGTH - 1) 25 { 26 RSAerr(RSA_F_RSA_PADDING_ADD_PKCS1_OAEP, 27 RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE); 28 return (0); 29 } 30 31 if (emlen < 2 * SHA_DIGEST_LENGTH + 1) 32 { 33 RSAerr(RSA_F_RSA_PADDING_ADD_PKCS1_OAEP, RSA_R_KEY_SIZE_TOO_SMALL); 34 return (0); 35 } 36 37 dbmask = OPENSSL_malloc(emlen - SHA_DIGEST_LENGTH); 38 if (dbmask == NULL) 39 { 40 RSAerr(RSA_F_RSA_PADDING_ADD_PKCS1_OAEP, ERR_R_MALLOC_FAILURE); 41 return (0); 42 } 43 44 to[0] = 0; 45 seed = to + 1; 46 db = to + SHA_DIGEST_LENGTH + 1; 47 48 SHA1(param, plen, db); 49 memset(db + SHA_DIGEST_LENGTH, 0, 50 emlen - flen - 2 * SHA_DIGEST_LENGTH - 1); 51 db[emlen - flen - SHA_DIGEST_LENGTH - 1] = 0x01; 52 memcpy(db + emlen - flen - SHA_DIGEST_LENGTH, from, (unsigned int) flen); 53 if (RAND_bytes(seed, SHA_DIGEST_LENGTH) <= 0) 54 return (0); 55 #ifdef PKCS_TESTVECT 56 memcpy(seed, 57 "\xaa\xfd\x12\xf6\x59\xca\xe6\x34\x89\xb4\x79\xe5\x07\x6d\xde\xc2\xf0\x6c\xb5\x8f", 58 20); 59 #endif 60 61 MGF1(dbmask, emlen - SHA_DIGEST_LENGTH, seed, SHA_DIGEST_LENGTH); 62 for (i = 0; i < emlen - SHA_DIGEST_LENGTH; i++) 63 db[i] ^= dbmask[i]; 64 65 MGF1(seedmask, SHA_DIGEST_LENGTH, db, emlen - SHA_DIGEST_LENGTH); 66 for (i = 0; i < SHA_DIGEST_LENGTH; i++) 67 seed[i] ^= seedmask[i]; 68 69 OPENSSL_free(dbmask); 70 return (1); 71 } 72 73 int RSA_padding_check_PKCS1_OAEP(unsigned char *to, int tlen, 74 unsigned char *from, int flen, int num, unsigned char *param, 75 int plen) 76 { 77 int i, dblen, mlen = -1; 78 unsigned char *maskeddb; 79 int lzero; 80 unsigned char *db = NULL, seed[SHA_DIGEST_LENGTH], phash[SHA_DIGEST_LENGTH]; 81 82 if (--num < 2 * SHA_DIGEST_LENGTH + 1) 83 goto decoding_err; 84 85 lzero = num - flen; 86 if (lzero < 0) 87 goto decoding_err; 88 maskeddb = from - lzero + SHA_DIGEST_LENGTH; 89 90 dblen = num - SHA_DIGEST_LENGTH; 91 db = OPENSSL_malloc(dblen); 92 if (db == NULL) 93 { 94 RSAerr(RSA_F_RSA_PADDING_ADD_PKCS1_OAEP, ERR_R_MALLOC_FAILURE); 95 return (-1); 96 } 97 98 MGF1(seed, SHA_DIGEST_LENGTH, maskeddb, dblen); 99 for (i = lzero; i < SHA_DIGEST_LENGTH; i++) 100 seed[i] ^= from[i - lzero]; 101 102 MGF1(db, dblen, seed, SHA_DIGEST_LENGTH); 103 for (i = 0; i < dblen; i++) 104 db[i] ^= maskeddb[i]; 105 106 SHA1(param, plen, phash); 107 108 if (memcmp(db, phash, SHA_DIGEST_LENGTH) != 0) 109 goto decoding_err; 110 else 111 { 112 for (i = SHA_DIGEST_LENGTH; i < dblen; i++) 113 if (db[i] != 0x00) 114 break; 115 if (db[i] != 0x01 || i++ >= dblen) 116 goto decoding_err; 117 else 118 { 119 mlen = dblen - i; 120 if (tlen < mlen) 121 { 122 RSAerr(RSA_F_RSA_PADDING_CHECK_PKCS1_OAEP, RSA_R_DATA_TOO_LARGE); 123 mlen = -1; 124 } 125 else 126 memcpy(to, db + i, mlen); 127 } 128 } 129 OPENSSL_free(db); 130 return (mlen); 131 132 decoding_err: 133 /* to avoid chosen ciphertext attacks, the error message should not reveal 134 * which kind of decoding error happened */ 135 RSAerr(RSA_F_RSA_PADDING_CHECK_PKCS1_OAEP, RSA_R_OAEP_DECODING_ERROR); 136 if (db != NULL) OPENSSL_free(db); 137 return -1; 138 } 139 140 int MGF1(unsigned char *mask, long len, unsigned char *seed, long seedlen) 141 { 142 long i, outlen = 0; 143 unsigned char cnt[4]; 144 SHA_CTX c; 145 unsigned char md[SHA_DIGEST_LENGTH]; 146 147 for (i = 0; outlen < len; i++) 148 { 149 cnt[0] = (i >> 24) & 255, cnt[1] = (i >> 16) & 255, 150 cnt[2] = (i >> 8) & 255, cnt[3] = i & 255; 151 SHA1_Init(&c); 152 SHA1_Update(&c, seed, seedlen); 153 SHA1_Update(&c, cnt, 4); 154 if (outlen + SHA_DIGEST_LENGTH <= len) 155 { 156 SHA1_Final(mask + outlen, &c); 157 outlen += SHA_DIGEST_LENGTH; 158 } 159 else 160 { 161 SHA1_Final(md, &c); 162 memcpy(mask + outlen, md, len - outlen); 163 outlen = len; 164 } 165 } 166 return (0); 167 } 168 #endif 169