1 #include <u.h>
2 #include <libc.h>
3 #include <oventi.h>
4 #include "session.h"
5
/*
 * Error strings handed to vtSetError by the server code below.
 */
static char EAuthState[]	= "bad authentication state";	/* srvHello: hello in wrong auth state */
static char ENotServer[]	= "not a server session";	/* vtExport: session has no server vtbl */
static char EVersion[]	= "incorrect version number";	/* srvHello: client version mismatch */
static char EProtocolBotch[]	= "venti protocol botch";	/* malformed or unknown request */
10
11 VtSession *
vtServerAlloc(VtServerVtbl * vtbl)12 vtServerAlloc(VtServerVtbl *vtbl)
13 {
14 VtSession *z = vtAlloc();
15 z->vtbl = vtMemAlloc(sizeof(VtServerVtbl));
16 setmalloctag(z->vtbl, getcallerpc(&vtbl));
17 *z->vtbl = *vtbl;
18 return z;
19 }
20
21 static int
srvHello(VtSession * z,char * version,char * uid,int,uchar *,int,uchar *,int)22 srvHello(VtSession *z, char *version, char *uid, int , uchar *, int , uchar *, int )
23 {
24 vtLock(z->lk);
25 if(z->auth.state != VtAuthHello) {
26 vtSetError(EAuthState);
27 goto Err;
28 }
29 if(strcmp(version, vtGetVersion(z)) != 0) {
30 vtSetError(EVersion);
31 goto Err;
32 }
33 vtMemFree(z->uid);
34 z->uid = vtStrDup(uid);
35 z->auth.state = VtAuthOK;
36 vtUnlock(z->lk);
37 return 1;
38 Err:
39 z->auth.state = VtAuthFailed;
40 vtUnlock(z->lk);
41 return 0;
42 }
43
44
45 static int
dispatchHello(VtSession * z,Packet ** pkt)46 dispatchHello(VtSession *z, Packet **pkt)
47 {
48 char *version, *uid;
49 uchar *crypto, *codec;
50 uchar buf[10];
51 int ncrypto, ncodec, cryptoStrength;
52 int ret;
53 Packet *p;
54
55 p = *pkt;
56
57 version = nil;
58 uid = nil;
59 crypto = nil;
60 codec = nil;
61
62 ret = 0;
63 if(!vtGetString(p, &version))
64 goto Err;
65 if(!vtGetString(p, &uid))
66 goto Err;
67 if(!packetConsume(p, buf, 2))
68 goto Err;
69 cryptoStrength = buf[0];
70 ncrypto = buf[1];
71 crypto = vtMemAlloc(ncrypto);
72 if(!packetConsume(p, crypto, ncrypto))
73 goto Err;
74
75 if(!packetConsume(p, buf, 1))
76 goto Err;
77 ncodec = buf[0];
78 codec = vtMemAlloc(ncodec);
79 if(!packetConsume(p, codec, ncodec))
80 goto Err;
81
82 if(packetSize(p) != 0) {
83 vtSetError(EProtocolBotch);
84 goto Err;
85 }
86 if(!srvHello(z, version, uid, cryptoStrength, crypto, ncrypto, codec, ncodec)) {
87 packetFree(p);
88 *pkt = nil;
89 } else {
90 if(!vtAddString(p, vtGetSid(z)))
91 goto Err;
92 buf[0] = vtGetCrypto(z);
93 buf[1] = vtGetCodec(z);
94 packetAppend(p, buf, 2);
95 }
96 ret = 1;
97 Err:
98 vtMemFree(version);
99 vtMemFree(uid);
100 vtMemFree(crypto);
101 vtMemFree(codec);
102 return ret;
103 }
104
105 static int
dispatchRead(VtSession * z,Packet ** pkt)106 dispatchRead(VtSession *z, Packet **pkt)
107 {
108 Packet *p;
109 int type, n;
110 uchar score[VtScoreSize], buf[4];
111
112 p = *pkt;
113 if(!packetConsume(p, score, VtScoreSize))
114 return 0;
115 if(!packetConsume(p, buf, 4))
116 return 0;
117 type = buf[0];
118 n = (buf[2]<<8) | buf[3];
119 if(packetSize(p) != 0) {
120 vtSetError(EProtocolBotch);
121 return 0;
122 }
123 packetFree(p);
124 *pkt = (*z->vtbl->read)(z, score, type, n);
125 return 1;
126 }
127
128 static int
dispatchWrite(VtSession * z,Packet ** pkt)129 dispatchWrite(VtSession *z, Packet **pkt)
130 {
131 Packet *p;
132 int type;
133 uchar score[VtScoreSize], buf[4];
134
135 p = *pkt;
136 if(!packetConsume(p, buf, 4))
137 return 0;
138 type = buf[0];
139 if(!(z->vtbl->write)(z, score, type, p)) {
140 *pkt = 0;
141 } else {
142 *pkt = packetAlloc();
143 packetAppend(*pkt, score, VtScoreSize);
144 }
145 return 1;
146 }
147
148 static int
dispatchSync(VtSession * z,Packet ** pkt)149 dispatchSync(VtSession *z, Packet **pkt)
150 {
151 (z->vtbl->sync)(z);
152 if(packetSize(*pkt) != 0) {
153 vtSetError(EProtocolBotch);
154 return 0;
155 }
156 return 1;
157 }
158
/*
 * Serve the venti protocol on session z in a new process.
 *
 * Returns 0 (error set) if z has no server vtbl or rfork fails.
 * Otherwise the calling process gets 1 back immediately, and a child
 * created with RFNOWAIT|RFMEM|RFPROC does the work.  The child
 * connects z, then loops:
 *   - receive a request;
 *   - decode the two-byte op/tid header;
 *   - dispatch on op;
 *   - send either the reply (header op+1, same tid) or, if the
 *     dispatcher left no packet, a VtRError carrying vtGetError().
 *
 * The child leaves the loop on:
 *   - a receive or send failure;
 *   - a short header or an unknown op;
 *   - a dispatcher returning 0;
 *   - VtQGoodbye.
 * It then calls the vtbl's closing hook (if set) with clean==1 only
 * for VtQGoodbye, closes and frees z, and exits.
 *
 * NOTE(review): the child frees z while sharing memory with the
 * caller, so the caller presumably must not use z after a successful
 * return -- confirm against callers.
 */
int
vtExport(VtSession *z)
{
	Packet *p;
	uchar buf[10], *hdr;
	int op, tid, clean;

	if(z->vtbl == nil) {
		vtSetError(ENotServer);
		return 0;
	}

	/* fork off slave */
	switch(rfork(RFNOWAIT|RFMEM|RFPROC)){
	case -1:
		vtOSError();
		return 0;
	case 0:
		break;
	default:
		return 1;
	}


	/*
	 * Child from here on.  p holds the packet currently owned by this
	 * loop (nil when none); Exit frees it.  clean is set only by
	 * VtQGoodbye and is passed to the closing hook.
	 */
	p = nil;
	clean = 0;
	vtAttach();
	if(!vtConnect(z, nil))
		goto Exit;

	vtDebug(z, "server connected!\n");
	/* change 0 to 1 to turn on protocol debugging for this session */
	if(0) vtSetDebug(z, 1);

	for(;;) {
		p = vtRecvPacket(z);
		if(p == nil) {
			break;
		}
		vtDebug(z, "server recv: ");
		vtDebugMesg(z, p, "\n");

		/* every message starts with op[1] tid[1] */
		if(!packetConsume(p, buf, 2)) {
			vtSetError(EProtocolBotch);
			break;
		}
		op = buf[0];
		tid = buf[1];
		/*
		 * Dispatchers may replace p with a reply packet or set it
		 * to nil to request an error reply; a 0 return is fatal
		 * to the connection.
		 */
		switch(op) {
		default:
			vtSetError(EProtocolBotch);
			goto Exit;
		case VtQPing:
			/* empty reply: just the VtRPing header added below */
			break;
		case VtQGoodbye:
			clean = 1;
			goto Exit;
		case VtQHello:
			if(!dispatchHello(z, &p))
				goto Exit;
			break;
		case VtQRead:
			if(!dispatchRead(z, &p))
				goto Exit;
			break;
		case VtQWrite:
			if(!dispatchWrite(z, &p))
				goto Exit;
			break;
		case VtQSync:
			if(!dispatchSync(z, &p))
				goto Exit;
			break;
		}
		if(p != nil) {
			/* reply op is the request op plus one */
			hdr = packetHeader(p, 2);
			hdr[0] = op+1;
			hdr[1] = tid;
		} else {
			/* no reply packet: report the current error string */
			p = packetAlloc();
			hdr = packetHeader(p, 2);
			hdr[0] = VtRError;
			hdr[1] = tid;
			if(!vtAddString(p, vtGetError()))
				goto Exit;
		}

		vtDebug(z, "server send: ");
		vtDebugMesg(z, p, "\n");

		/*
		 * The code treats p as handed off to vtSendPacket whether
		 * or not the send succeeds: it is never freed here, and the
		 * next vtRecvPacket overwrites it.
		 */
		if(!vtSendPacket(z, p)) {
			p = nil;
			goto Exit;
		}
	}
Exit:
	if(p != nil)
		packetFree(p);
	if(z->vtbl->closing)
		z->vtbl->closing(z, clean);
	vtClose(z);
	vtFree(z);
	vtDetach();

	exits(0);
	return 0;	/* never gets here */
}
265
266