Subversion Repositories planix.SVN

Rev

Rev 2 | Details | Compare with Previous | Last modification | View Log | RSS feed

Rev Author Line No. Line
2 - 1
#include <u.h>
2
#include <libc.h>
3
#include <bio.h>
4
#include <regexp.h>
5
#include "hash.h"
6
 
7
/* Fixed capacities for the statically sized arrays in this file. */
enum
{
	MAXTAB = 256,	/* most box hash tables accepted before the "~" */
	MAXBEST = 32	/* capacity of best[]; upper limit for -m */
};
12
 
13
typedef struct Table Table;
14
struct Table
15
{
16
	char *file;
17
	Hash *hash;
18
	int nmsg;
19
};
20
 
21
typedef struct Word Word;
22
struct Word
23
{
24
	Stringtab *s;	/* from hmsg */
25
	int count[MAXTAB];	/* counts from each table */
26
	double p[MAXTAB];	/* probabilities from each table */
27
	double mp;	/* max probability */
28
	int mi;		/* w.p[w.mi] = w.mp */
29
};
30
 
31
Table tab[MAXTAB];
32
int ntab;
33
 
34
Word best[MAXBEST];
35
int mbest;
36
int nbest;
37
 
38
int debug;
39
 
40
/*
 * Print the command synopsis on file descriptor 2 and exit with
 * status "usage".
 */
void
usage(void)
{
	/* -k is accepted by main's option parser, so it belongs in the synopsis */
	fprint(2, "usage: bayes [-Dk] [-m maxword] boxhash ... ~ msghash ...\n");
	exits("usage");
}
46
 
47
void*
48
emalloc(int n)
49
{
50
	void *v;
51
 
52
	v = mallocz(n, 1);
53
	if(v == nil)
54
		sysfatal("out of memory");
55
	return v;
56
}
57
 
58
void
59
noteword(Word *w)
60
{
61
	int i;
62
 
63
	for(i=nbest-1; i>=0; i--)
64
		if(w->mp < best[i].mp)
65
			break;
66
	i++;
67
 
68
	if(i >= mbest)
69
		return;
70
	if(nbest == mbest)
71
		nbest--;
72
	if(i < nbest)
73
		memmove(&best[i+1], &best[i], (nbest-i)*sizeof(best[0]));
74
	best[i] = *w;
75
	nbest++;
76
}
77
 
78
Hash*
79
hread(char *s)
80
{
81
	Hash *h;
82
	Biobuf *b;
83
 
84
	if((b = Bopenlock(s, OREAD)) == nil)
85
		sysfatal("open %s: %r", s);
86
 
87
	h = emalloc(sizeof(Hash));
88
	Breadhash(b, h, 1);
89
	Bterm(b);
90
	return h;
91
}
92
 
93
void
94
main(int argc, char **argv)
95
{
96
	int i, j, a, mi, oi, tot, keywords;
97
	double totp, p, xp[MAXTAB];
98
	Hash *hmsg;
99
	Word w;
100
	Stringtab *s, *t;
101
	Biobuf bout;
102
 
103
	mbest = 15;
104
	keywords = 0;
105
	ARGBEGIN{
106
	case 'D':
107
		debug = 1;
108
		break;
109
	case 'k':
110
		keywords = 1;
111
		break;
112
	case 'm':
113
		mbest = atoi(EARGF(usage()));
114
		if(mbest > MAXBEST)
115
			sysfatal("cannot keep more than %d words", MAXBEST);
116
		break;
117
	default:
118
		usage();
119
	}ARGEND
120
 
121
	for(i=0; i<argc; i++)
122
		if(strcmp(argv[i], "~") == 0)
123
			break;
124
 
125
	if(i > MAXTAB)
126
		sysfatal("cannot handle more than %d tables", MAXTAB);
127
 
128
	if(i+1 >= argc)
129
		usage();
130
 
131
	for(i=0; i<argc; i++){
132
		if(strcmp(argv[i], "~") == 0)
133
			break;
134
		tab[ntab].file = argv[i];
135
		tab[ntab].hash = hread(argv[i]);
136
		s = findstab(tab[ntab].hash, "*nmsg*", 6, 1);
137
		if(s == nil || s->count == 0)
138
			tab[ntab].nmsg = 1;
139
		else
140
			tab[ntab].nmsg = s->count;
141
		ntab++;
142
	}
143
 
144
	Binit(&bout, 1, OWRITE);
145
 
146
	oi = ++i;
147
	for(a=i; a<argc; a++){
148
		hmsg = hread(argv[a]);
149
		nbest = 0;
150
		for(s=hmsg->all; s; s=s->link){
151
			w.s = s;
152
			tot = 0;
153
			totp = 0.0;
154
			for(i=0; i<ntab; i++){
155
				t = findstab(tab[i].hash, s->str, s->n, 0);
156
				if(t == nil)
157
					w.count[i] = 0;
158
				else
159
					w.count[i] = t->count;
160
				tot += w.count[i];
161
				p = w.count[i]/(double)tab[i].nmsg;
162
				if(p >= 1.0)
163
					p = 1.0;
164
				w.p[i] = p;
165
				totp += p;
166
			}
167
 
168
			if(tot < 5){		/* word does not appear enough; give to box 0 */
169
				w.p[0] = 0.5;
170
				for(i=1; i<ntab; i++)
171
					w.p[i] = 0.1;
172
				w.mp = 0.5;
173
				w.mi = 0;
174
				noteword(&w);
175
				continue;
176
			}
177
 
178
			w.mp = 0.0;
179
			for(i=0; i<ntab; i++){
180
				p = w.p[i];
181
				p /= totp;
182
				if(p < 0.01)
183
					p = 0.01;
184
				else if(p > 0.99)
185
					p = 0.99;
186
				if(p > w.mp){
187
					w.mp = p;
188
					w.mi = i;
189
				}
190
				w.p[i] = p;
191
			}
192
			noteword(&w);
193
		}
194
 
195
		totp = 0.0;
196
		for(i=0; i<ntab; i++){
197
			p = 1.0;
198
			for(j=0; j<nbest; j++)
199
				p *= best[j].p[i];
200
			xp[i] = p;
201
			totp += p;
202
		}
203
		for(i=0; i<ntab; i++)
204
			xp[i] /= totp;
205
		mi = 0;
206
		for(i=1; i<ntab; i++)
207
			if(xp[i] > xp[mi])
208
				mi = i;
209
		if(oi != argc-1)
210
			Bprint(&bout, "%s: ", argv[a]);
211
		Bprint(&bout, "%s %f", tab[mi].file, xp[mi]);
212
		if(keywords){
213
			for(i=0; i<nbest; i++){
214
				Bprint(&bout, " ");
215
				Bwrite(&bout, best[i].s->str, best[i].s->n);
216
				Bprint(&bout, " %f", best[i].p[mi]);
217
			}
218
		}
219
		freehash(hmsg);
220
		Bprint(&bout, "\n");
221
		if(debug){
222
			for(i=0; i<nbest; i++){
223
				Bwrite(&bout, best[i].s->str, best[i].s->n);
224
				Bprint(&bout, " %f", best[i].p[mi]);
225
				if(best[i].p[mi] < best[i].mp)
226
					Bprint(&bout, " (%f %s)", best[i].mp, tab[best[i].mi].file);
227
				Bprint(&bout, "\n");
228
			}
229
		}
230
	}
231
	Bterm(&bout);
232
}