/* xref: /inferno-os/libkeyring/rsaalg.c (revision d3641b487cf5cdc46e9b537d30eb37736e5c7b1a) */
1 #include <lib9.h>
2 #include <kernel.h>
3 #include <isa.h>
4 #include "interp.h"
5 #include "../libinterp/keyringif.h"
6 #include "mp.h"
7 #include "libsec.h"
8 #include "keys.h"
9 
10 static char*	pkattr[] = { "n", "ek", nil };
11 static char*	skattr[] = { "n", "ek", "!dk", "!p", "!q", "!kp", "!kq", "!c2", nil };
12 static char*	sigattr[] = { "val", nil };
13 
14 static void*
15 rsa_str2sk(char *str, char **strp)
16 {
17 	RSApriv *rsa;
18 	char *p;
19 
20 	rsa = rsaprivalloc();
21 	rsa->pub.n = base64tobig(str, &p);
22 	rsa->pub.ek = base64tobig(p, &p);
23 	rsa->dk = base64tobig(p, &p);
24 	rsa->p = base64tobig(p, &p);
25 	rsa->q = base64tobig(p, &p);
26 	rsa->kp = base64tobig(p, &p);
27 	rsa->kq = base64tobig(p, &p);
28 	rsa->c2 = base64tobig(p, &p);
29 	if(strp)
30 		*strp = p;
31 	if(rsa->pub.n == nil || rsa->pub.ek == nil ||
32 	   rsa->dk == nil || rsa->p == nil || rsa->q == nil ||
33 	   rsa->kp == nil || rsa->kq == nil || rsa->c2 == nil){
34 		rsaprivfree(rsa);
35 		return nil;
36 	}
37 
38 	return rsa;
39 }
40 
41 static void*
42 rsa_str2pk(char *str, char **strp)
43 {
44 	RSApub *rsa;
45 	char *p;
46 
47 	rsa = rsapuballoc();
48 	rsa->n = base64tobig(str, &p);
49 	rsa->ek = base64tobig(p, &p);
50 	if(strp)
51 		*strp = p;
52 	if(rsa->n == nil || rsa->ek == nil){
53 		rsapubfree(rsa);
54 		return nil;
55 	}
56 
57 	return rsa;
58 }
59 
60 static void*
61 rsa_str2sig(char *str, char **strp)
62 {
63 	mpint *rsa;
64 	char *p;
65 
66 	rsa = base64tobig(str, &p);
67 	if(rsa == nil)
68 		return nil;
69 	if(strp)
70 		*strp = p;
71 	return rsa;
72 }
73 
74 static int
75 rsa_sk2str(void *vrsa, char *buf, int len)
76 {
77 	RSApriv *rsa;
78 	char *cp, *ep;
79 
80 	rsa = vrsa;
81 	ep = buf + len - 1;
82 	cp = buf;
83 
84 	cp += snprint(cp, ep - cp, "%U\n", rsa->pub.n);
85 	cp += snprint(cp, ep - cp, "%U\n", rsa->pub.ek);
86 	cp += snprint(cp, ep - cp, "%U\n", rsa->dk);
87 	cp += snprint(cp, ep - cp, "%U\n", rsa->p);
88 	cp += snprint(cp, ep - cp, "%U\n", rsa->q);
89 	cp += snprint(cp, ep - cp, "%U\n", rsa->kp);
90 	cp += snprint(cp, ep - cp, "%U\n", rsa->kq);
91 	cp += snprint(cp, ep - cp, "%U\n", rsa->c2);
92 	*cp = 0;
93 
94 	return cp - buf;
95 }
96 
97 static int
98 rsa_pk2str(void *vrsa, char *buf, int len)
99 {
100 	RSApub *rsa;
101 	char *cp, *ep;
102 
103 	rsa = vrsa;
104 	ep = buf + len - 1;
105 	cp = buf;
106 	cp += snprint(cp, ep - cp, "%U\n", rsa->n);
107 	cp += snprint(cp, ep - cp, "%U\n", rsa->ek);
108 	*cp = 0;
109 
110 	return cp - buf;
111 }
112 
113 static int
114 rsa_sig2str(void *vrsa, char *buf, int len)
115 {
116 	mpint *rsa;
117 	char *cp, *ep;
118 
119 	rsa = vrsa;
120 	ep = buf + len - 1;
121 	cp = buf;
122 
123 	cp += snprint(cp, ep - cp, "%U\n", rsa);
124 	*cp = 0;
125 
126 	return cp - buf;
127 }
128 
129 static void*
130 rsa_sk2pk(void *vs)
131 {
132 	return rsaprivtopub((RSApriv*)vs);
133 }
134 
135 /* generate an rsa secret key */
136 static void*
137 rsa_gen(int len)
138 {
139 	RSApriv *key;
140 
141 	for(;;){
142 		key = rsagen(len, 6, 0);
143 		if(mpsignif(key->pub.n) == len)
144 			return key;
145 		rsaprivfree(key);
146 	}
147 }
148 
149 /* generate an rsa secret key with same params as a public key */
150 static void*
151 rsa_genfrompk(void *vpub)
152 {
153 	RSApub *pub;
154 
155 	pub = vpub;
156 	return rsagen(mpsignif(pub->n), mpsignif(pub->ek), 0);
157 }
158 
159 static void*
160 rsa_sign(mpint* m, void *key)
161 {
162 	return rsadecrypt((RSApriv*)key, m, nil);
163 }
164 
165 static int
166 rsa_verify(mpint* m, void *sig, void *key)
167 {
168 	mpint *t;
169 	int r;
170 
171 	t = rsaencrypt((RSApub*)key, (mpint*)sig, nil);
172 	r = mpcmp(t, m) == 0;
173 	mpfree(t);
174 	return r;
175 }
176 
177 static void
178 rsa_freepriv(void *a)
179 {
180 	rsaprivfree((RSApriv*)a);
181 }
182 
183 static void
184 rsa_freepub(void *a)
185 {
186 	rsapubfree((RSApub*)a);
187 }
188 
189 static void
190 rsa_freesig(void *a)
191 {
192 	mpfree(a);
193 }
194 
195 SigAlgVec*
196 rsainit(void)
197 {
198 	SigAlgVec *vec;
199 
200 	vec = malloc(sizeof(SigAlgVec));
201 	if(vec == nil)
202 		return nil;
203 
204 	vec->name = "rsa";
205 
206 	vec->pkattr = pkattr;
207 	vec->skattr = skattr;
208 	vec->sigattr = sigattr;
209 
210 	vec->str2sk = rsa_str2sk;
211 	vec->str2pk = rsa_str2pk;
212 	vec->str2sig = rsa_str2sig;
213 
214 	vec->sk2str = rsa_sk2str;
215 	vec->pk2str = rsa_pk2str;
216 	vec->sig2str = rsa_sig2str;
217 
218 	vec->sk2pk = rsa_sk2pk;
219 
220 	vec->gensk = rsa_gen;
221 	vec->genskfrompk = rsa_genfrompk;
222 	vec->sign = rsa_sign;
223 	vec->verify = rsa_verify;
224 
225 	vec->skfree = rsa_freepriv;
226 	vec->pkfree = rsa_freepub;
227 	vec->sigfree = rsa_freesig;
228 
229 	return vec;
230 }
231