xref: /plan9/sys/src/cmd/upas/bayes/bayes.c (revision 4246b6162acdbb658503b8bdc98024362bfbf0fe)
1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <regexp.h>
5 #include "hash.h"
6 
7 enum
8 {
9 	MAXTAB = 256,
10 	MAXBEST = 32,
11 };
12 
13 typedef struct Table Table;
14 struct Table
15 {
16 	char *file;
17 	Hash *hash;
18 	int nmsg;
19 };
20 
21 typedef struct Word Word;
22 struct Word
23 {
24 	Stringtab *s;	/* from hmsg */
25 	int count[MAXTAB];	/* counts from each table */
26 	double p[MAXTAB];	/* probabilities from each table */
27 	double mp;	/* max probability */
28 	int mi;		/* w.p[w.mi] = w.mp */
29 };
30 
31 Table tab[MAXTAB];
32 int ntab;
33 
34 Word best[MAXBEST];
35 int mbest;
36 int nbest;
37 
38 int debug;
39 
40 void
usage(void)41 usage(void)
42 {
43 	fprint(2, "usage: bayes [-D] [-m maxword] boxhash ... ~ msghash ...\n");
44 	exits("usage");
45 }
46 
47 void*
emalloc(int n)48 emalloc(int n)
49 {
50 	void *v;
51 
52 	v = mallocz(n, 1);
53 	if(v == nil)
54 		sysfatal("out of memory");
55 	return v;
56 }
57 
58 void
noteword(Word * w)59 noteword(Word *w)
60 {
61 	int i;
62 
63 	for(i=nbest-1; i>=0; i--)
64 		if(w->mp < best[i].mp)
65 			break;
66 	i++;
67 
68 	if(i >= mbest)
69 		return;
70 	if(nbest == mbest)
71 		nbest--;
72 	if(i < nbest)
73 		memmove(&best[i+1], &best[i], (nbest-i)*sizeof(best[0]));
74 	best[i] = *w;
75 	nbest++;
76 }
77 
78 Hash*
hread(char * s)79 hread(char *s)
80 {
81 	Hash *h;
82 	Biobuf *b;
83 
84 	if((b = Bopenlock(s, OREAD)) == nil)
85 		sysfatal("open %s: %r", s);
86 
87 	h = emalloc(sizeof(Hash));
88 	Breadhash(b, h, 1);
89 	Bterm(b);
90 	return h;
91 }
92 
93 void
main(int argc,char ** argv)94 main(int argc, char **argv)
95 {
96 	int i, j, a, mi, oi, tot, keywords;
97 	double totp, p, xp[MAXTAB];
98 	Hash *hmsg;
99 	Word w;
100 	Stringtab *s, *t;
101 	Biobuf bout;
102 
103 	mbest = 15;
104 	keywords = 0;
105 	ARGBEGIN{
106 	case 'D':
107 		debug = 1;
108 		break;
109 	case 'k':
110 		keywords = 1;
111 		break;
112 	case 'm':
113 		mbest = atoi(EARGF(usage()));
114 		if(mbest > MAXBEST)
115 			sysfatal("cannot keep more than %d words", MAXBEST);
116 		break;
117 	default:
118 		usage();
119 	}ARGEND
120 
121 	for(i=0; i<argc; i++)
122 		if(strcmp(argv[i], "~") == 0)
123 			break;
124 
125 	if(i > MAXTAB)
126 		sysfatal("cannot handle more than %d tables", MAXTAB);
127 
128 	if(i+1 >= argc)
129 		usage();
130 
131 	for(i=0; i<argc; i++){
132 		if(strcmp(argv[i], "~") == 0)
133 			break;
134 		tab[ntab].file = argv[i];
135 		tab[ntab].hash = hread(argv[i]);
136 		s = findstab(tab[ntab].hash, "*nmsg*", 6, 1);
137 		if(s == nil || s->count == 0)
138 			tab[ntab].nmsg = 1;
139 		else
140 			tab[ntab].nmsg = s->count;
141 		ntab++;
142 	}
143 
144 	Binit(&bout, 1, OWRITE);
145 
146 	oi = ++i;
147 	for(a=i; a<argc; a++){
148 		hmsg = hread(argv[a]);
149 		nbest = 0;
150 		for(s=hmsg->all; s; s=s->link){
151 			w.s = s;
152 			tot = 0;
153 			totp = 0.0;
154 			for(i=0; i<ntab; i++){
155 				t = findstab(tab[i].hash, s->str, s->n, 0);
156 				if(t == nil)
157 					w.count[i] = 0;
158 				else
159 					w.count[i] = t->count;
160 				tot += w.count[i];
161 				p = w.count[i]/(double)tab[i].nmsg;
162 				if(p >= 1.0)
163 					p = 1.0;
164 				w.p[i] = p;
165 				totp += p;
166 			}
167 
168 			if(tot < 5){		/* word does not appear enough; give to box 0 */
169 				w.p[0] = 0.5;
170 				for(i=1; i<ntab; i++)
171 					w.p[i] = 0.1;
172 				w.mp = 0.5;
173 				w.mi = 0;
174 				noteword(&w);
175 				continue;
176 			}
177 
178 			w.mp = 0.0;
179 			for(i=0; i<ntab; i++){
180 				p = w.p[i];
181 				p /= totp;
182 				if(p < 0.01)
183 					p = 0.01;
184 				else if(p > 0.99)
185 					p = 0.99;
186 				if(p > w.mp){
187 					w.mp = p;
188 					w.mi = i;
189 				}
190 				w.p[i] = p;
191 			}
192 			noteword(&w);
193 		}
194 
195 		totp = 0.0;
196 		for(i=0; i<ntab; i++){
197 			p = 1.0;
198 			for(j=0; j<nbest; j++)
199 				p *= best[j].p[i];
200 			xp[i] = p;
201 			totp += p;
202 		}
203 		for(i=0; i<ntab; i++)
204 			xp[i] /= totp;
205 		mi = 0;
206 		for(i=1; i<ntab; i++)
207 			if(xp[i] > xp[mi])
208 				mi = i;
209 		if(oi != argc-1)
210 			Bprint(&bout, "%s: ", argv[a]);
211 		Bprint(&bout, "%s %f", tab[mi].file, xp[mi]);
212 		if(keywords){
213 			for(i=0; i<nbest; i++){
214 				Bprint(&bout, " ");
215 				Bwrite(&bout, best[i].s->str, best[i].s->n);
216 				Bprint(&bout, " %f", best[i].p[mi]);
217 			}
218 		}
219 		freehash(hmsg);
220 		Bprint(&bout, "\n");
221 		if(debug){
222 			for(i=0; i<nbest; i++){
223 				Bwrite(&bout, best[i].s->str, best[i].s->n);
224 				Bprint(&bout, " %f", best[i].p[mi]);
225 				if(best[i].p[mi] < best[i].mp)
226 					Bprint(&bout, " (%f %s)", best[i].mp, tab[best[i].mi].file);
227 				Bprint(&bout, "\n");
228 			}
229 		}
230 	}
231 	Bterm(&bout);
232 }
233