xref: /inferno-os/appl/lib/db.b (revision 9b29ac7ea714507a9c0690620c02c8ca5ab25f90)
1implement DB;
2
3include "sys.m";
4	sys: Sys;
5
6include "dial.m";
7
8include "keyring.m";
9
10include "security.m";
11
12include "db.m";
13
14RES_HEADER_SIZE: con 22;
15
#
# open connects to the database server at addr (through connect, with
# the "none" algorithm) and logs in through dbopen.
# Returns (handle, nil) or (nil, list of error messages).
#
open(addr, username, password, dbname: string): (ref DB_Handle, list of string)
{
	(conn, why) := connect(addr, "none");
	if(conn != nil)
		return dbopen(conn, username, password, dbname);
	return (nil, why :: nil);
}
23
#
# connect dials the database server at addr and authenticates the
# connection.
#
# addr is completed with dial->netmkaddr using default network "net"
# and default service "6669". Authentication uses Auth->client with
# algorithm alg. The certificate is /usr/<user>/keyring/<net!host>, or
# /usr/<user>/keyring/default when that file does not exist.
#
# Returns (authenticated fd, nil), or (nil, error message).
#
connect(addr: string, alg: string): (ref Sys->FD, string)
{
	if (sys == nil)
		sys = load Sys Sys->PATH;

	dial := load Dial Dial->PATH;
	if(dial == nil)
		return (nil, sys->sprint("load %s: %r", Dial->PATH));

	addr = dial->netmkaddr(addr, "net", "6669");	# infdb

	conn := dial->dial(addr, nil);
	if(conn == nil)
		return (nil, sys->sprint("can't dial %s: %r", addr));

	# certificate files are named net!host: drop the service for the key search
	(n, addrparts) := sys->tokenize(addr, "!");
	if(n >= 2)
		addr = hd addrparts + "!" + hd tl addrparts;

	# fail cleanly, like the Dial and Auth loads, rather than crashing
	# on kr->readauthinfo below
	kr := load Keyring Keyring->PATH;
	if(kr == nil)
		return (nil, sys->sprint("DB init: can't load module Keyring %r"));

	kd := "/usr/" + user() + "/keyring/";
	cert := kd + addr;
	if(sys->stat(cert).t0 < 0)
		cert = kd + "default";

	# ai may be nil when no certificate could be read. That is deliberately
	# not rejected here ("certificate not found, use getauthinfo first");
	# au->client is left to handle a nil ai.
	ai := kr->readauthinfo(cert);

	au := load Auth Auth->PATH;
	if(au == nil)
		return (nil, sys->sprint("DB init: can't load module Auth %r"));

	err := au->init();
	if(err != nil)
		return (nil, sys->sprint("DB init: can't initialize module Auth: %s", err));

	fd: ref Sys->FD;
	(fd, err) = au->client(alg, ai, conn.dfd);
	if(fd == nil)
		return (nil, sys->sprint("DB init: authentication failed: %s", err));

	return (fd, nil);
}
76
#
# dbopen logs in to the database over an already-connected fd and opens
# an SQL stream.
#
# It sends an 'I' request carrying "username/password/dbname"; the reply
# is the connection number. It then sends an 'O' request carrying that
# number; the reply is the stream number.
#
# Returns (handle, nil), or (nil, list of error messages).
#
dbopen(fd: ref Sys->FD, username, password, dbname: string): (ref DB_Handle, list of string)
{
	# sendReq needs sys; it is only loaded by connect, which a caller of
	# dbopen may not have used
	if(sys == nil)
		sys = load Sys Sys->PATH;
	if(fd == nil)
		return (nil, "DB: nil file descriptor" :: nil);

	dbh := ref DB_Handle;
	dbh.datafd = fd;
	dbh.lock = makelock();
	dbh.sqlstream = -1;	# no stream yet

	logon := array of byte (username +"/"+ password +"/"+ dbname);
	(mtype, nil, nil, data) := sendReq(dbh, 'I', logon);
	if(mtype == 'h')
		return (nil, (sys->sprint("DB: couldn't initialize %s for %s", dbname, username) :: string data :: nil));
	dbh.sqlconn = int string data;

	(mtype, nil, nil, data) = sendReq(dbh, 'O', array of byte string dbh.sqlconn);
	if(mtype == 'h')
		return (nil, (sys->sprint("DB: couldn't open SQL connection") :: string data :: nil));
	dbh.sqlstream = int string data;
	return (dbh, nil);
}
95
#
# SQLOpen requests another SQL stream on oldh's connection ('O' request).
# On success it returns (0, nh). nh is a copy of *oldh, so it shares the
# same fd and lock, but carries the new stream number. On failure it
# returns (-1, nil) and oldh is untouched.
#
DB_Handle.SQLOpen(oldh: self ref DB_Handle): (int, ref DB_Handle)
{
	nh := ref *oldh;
	req := array of byte string nh.sqlconn;
	(rtype, nil, nil, reply) := sendReq(nh, 'O', req);
	if(rtype != 'h'){
		nh.sqlstream = int string reply;
		return (0, nh);
	}
	return (-1, nil);
}
105
#
# SQLClose sends a 'K' request for this handle's stream. On success it
# clears the stream number to -1 and returns 0; on failure it returns -1.
#
DB_Handle.SQLClose(dbh: self ref DB_Handle): int
{
	(rtype, nil, nil, nil) := sendReq(dbh, 'K', array[0] of byte);
	if(rtype != 'h'){
		dbh.sqlstream = -1;
		return 0;
	}
	return -1;
}
114
#
# SQL sends command as a 'W' request.
# Returns (0, nil) on success. On failure it returns -1 with a generic
# message followed by the reply text.
#
DB_Handle.SQL(dbh: self ref DB_Handle, command: string): (int, list of string)
{
	(rtype, nil, nil, reply) := sendReq(dbh, 'W', array of byte command);
	if(rtype != 'h')
		return (0, nil);
	detail := string reply :: nil;
	return (-1, "Probable SQL format error" :: detail);
}
122
#
# columns sends a 'C' request and returns the integer in the reply,
# or 0 if the request failed.
#
DB_Handle.columns(dbh: self ref DB_Handle): int
{
	(rtype, nil, nil, reply) := sendReq(dbh, 'C', array[0] of byte);
	ncol := 0;
	if(rtype != 'h')
		ncol = int string reply;
	return ncol;
}
130
#
# nextRow sends an 'N' request and returns the integer in the reply,
# or 0 if the request failed.
#
DB_Handle.nextRow(dbh: self ref DB_Handle): int
{
	(rtype, nil, nil, reply) := sendReq(dbh, 'N', array[0] of byte);
	if(rtype != 'h')
		return int string reply;
	return 0;
}
138
#
# read sends an 'R' request for column columnI.
# Returns (byte count, bytes) on success, or (-1, reply bytes) on failure.
#
DB_Handle.read(dbh: self ref DB_Handle, columnI: int): (int, array of byte)
{
	(rtype, nil, nil, reply) := sendReq(dbh, 'R', array of byte string columnI);
	count := len reply;
	if(rtype == 'h')
		count = -1;
	return (count, reply);
}
146
#
# write sends val for parameter number paramI as a 'P' request.
#
# The parameter number travels as a fixed 4-byte "%3d " prefix, so
# paramI must format in three columns (-99..999). Otherwise the prefix
# would be longer, and truncating it would send a malformed request.
#
# Returns len val on success, -1 on failure or out-of-range paramI.
#
DB_Handle.write(dbh: self ref DB_Handle, paramI: int, val: array of byte)
									: int
{
	param := array of byte sys->sprint("%3d ", paramI);
	if(len param != 4)
		return -1;	# does not fit the fixed-width parameter field

	outbuf := array[len param + len val] of byte;
	outbuf[0:] = param;
	outbuf[len param:] = val;
	(mtype, nil, nil, nil) := sendReq(dbh, 'P', outbuf);
	if(mtype == 'h')
		return -1;
	return len val;
}
161
#
# columnTitle sends a 'T' request for column columnI and returns the
# reply as a string, or nil if the request failed.
#
DB_Handle.columnTitle(handle: self ref DB_Handle, columnI: int): string
{
	req := array of byte string columnI;
	(rtype, nil, nil, reply) := sendReq(handle, 'T', req);
	if(rtype != 'h')
		return string reply;
	return nil;
}
169
#
# errmsg sends an 'H' request and returns the reply text.
# The reply type is ignored, so a local failure text from sendReq is
# returned as-is.
#
DB_Handle.errmsg(dbh: self ref DB_Handle): string
{
	reply: array of byte;
	(nil, nil, nil, reply) = sendReq(dbh, 'H', array[0] of byte);
	return string reply;
}
175
#
# sendReq performs one request/reply exchange on dbh's connection.
# It holds dbh's lock for the whole exchange so concurrent requests on
# handles sharing the lock do not interleave.
#
# Request: mtype, '1', len data (%11d), ' ', stream (%3d), ' ' (18 bytes),
#	then data, then a newline.
# Reply:   RES_HEADER_SIZE-byte header with type [0], version [1],
#	length [2:13], stream [14:17], return code [18:21]; then
#	length data bytes and a trailing newline.
#
# Returns (reply type, reply stream, return code, reply data). Any local
# failure (write error, short or failed read, bad header) is reported as
# type 'h' with a description in the data.
#
sendReq(dbh: ref DB_Handle, mtype: int, data: array of byte) : (int, int, int, array of byte)
{
	lock(dbh);
	# every exit of exchange comes back here, so the lock is always released
	r := exchange(dbh, mtype, data);
	unlock(dbh);
	return r;
}

# exchange does the I/O for sendReq; the caller must hold dbh's lock.
exchange(dbh: ref DB_Handle, mtype: int, data: array of byte): (int, int, int, array of byte)
{
	header := array of byte sys->sprint("%c1%11d %3d ", mtype, len data, dbh.sqlstream);
	if(sys->write(dbh.datafd, header, 18) != 18)
		return ('h', dbh.sqlstream, 0, array of byte "header write failure");
	if(sys->write(dbh.datafd, data, len data) != len data)
		return ('h', dbh.sqlstream, 0, array of byte "data write failure");
	if(sys->write(dbh.datafd, array of byte "\n", 1) != 1)
		return ('h', dbh.sqlstream, 0, array of byte "header write failure");

	hbuf := array[RES_HEADER_SIZE] of byte;
	n := sys->readn(dbh.datafd, hbuf, len hbuf);
	if(n != len hbuf)
		return ('h', dbh.sqlstream, 0, readerr(n));

	rtype := int hbuf[0];
	# hbuf[1] is the protocol version; it is not checked
	datalen := int string hbuf[2:13];
	rstrm := int string hbuf[14:17];
	retcode := int string hbuf[18:21];

	# a negative length would make the array allocation below raise
	if(datalen < 0)
		return ('h', dbh.sqlstream, 0, array of byte "bad reply header");

	databuf := array[datalen] of byte;
	n = sys->readn(dbh.datafd, databuf, datalen);
	if(n != datalen)	# don't hand back a zero-padded partial reply
		return ('h', dbh.sqlstream, 0, readerr(n));

	sys->read(dbh.datafd, hbuf, 1);	# the trailing newline
	return (rtype, rstrm, retcode, databuf);
}

# readerr describes a failed or short read whose result was n.
# It must be called immediately after that read so %r is still current.
readerr(n: int): array of byte
{
	if(n < 0)
		return sys->aprint("read error: %r");
	if(n == 0)
		return sys->aprint("lost connection");
	return sys->aprint("read error: short read");
}
226
#
# makelock returns a channel buffered to one element, used as a lock
# (see lock and unlock).
#
makelock(): chan of int
{
	lk := chan[1] of int;
	return lk;
}
231
#
# lock acquires h's lock by sending into h.lock. With the one-slot
# channel from makelock, this blocks while another holder's value is
# still buffered. The value sent (the stream number) is never examined
# in this file.
#
lock(h: ref DB_Handle)
{
	c := h.lock;
	c <-= h.sqlstream;
}
236
#
# unlock releases h's lock by draining the value that lock buffered.
#
unlock(h: ref DB_Handle)
{
	c := h.lock;
	<-c;
}
241
#
# user returns the text read from /dev/user, or "" if the file cannot be
# opened or nothing can be read from it.
#
user(): string
{
	# load Sys only if needed, as connect does
	if(sys == nil)
		sys = load Sys Sys->PATH;
	fd := sys->open("/dev/user", sys->OREAD);
	if(fd == nil)
		return "";
	buf := array[Sys->NAMEMAX] of byte;
	n := sys->read(fd, buf, len buf);
	if(n <= 0)
		return "";
	return string buf[0:n];
}
254