Subversion Repositories planix.SVN

Rev

Rev 2 | Blame | Compare with Previous | Last modification | View Log | RSS feed

#include <u.h>
#include <libc.h>
#include <bio.h>
#include <regexp.h>
#include "hash.h"

enum
{
        MAXTAB = 256,
        MAXBEST = 32,
};

/* One box hash table named on the command line before "~". */
typedef struct Table Table;
struct Table
{
        char *file;     /* argument as given; also printed as the verdict */
        Hash *hash;     /* table contents, loaded by hread */
        int nmsg;       /* count of the "*nmsg*" entry, or 1 if absent or zero (used as a divisor) */
};

/*
 * Scoring state for one word of the message being classified.
 * Copied by value into best[] by noteword.
 */
typedef struct Word Word;
struct Word
{
        Stringtab *s;   /* from hmsg; points into the message hash, so invalid once it is freed */
        int count[MAXTAB];      /* counts from each table */
        double p[MAXTAB];       /* probabilities from each table */
        double mp;      /* max probability */
        int mi;         /* w.p[w.mi] = w.mp */
};

Table tab[MAXTAB];      /* box tables, in command-line order */
int ntab;               /* number of entries of tab[] in use */

Word best[MAXBEST];     /* highest-mp words of the current message, in decreasing mp order */
int mbest;              /* maximum number of entries kept in best[]; set by -m, default 15 */
int nbest;              /* number of entries of best[] in use; reset for each message */

int debug;              /* set by -D: print per-word details after each verdict */

/*
 * Print the command synopsis on standard error and exit with
 * status "usage".  The synopsis lists every flag main accepts:
 * -D (debug output), -k (print keywords) and -m maxword.
 */
void
usage(void)
{
        fprint(2, "usage: bayes [-Dk] [-m maxword] boxhash ... ~ msghash ...\n");
        exits("usage");
}

/*
 * Return n bytes of zero-filled memory from mallocz.
 * Never returns nil: allocation failure is fatal (sysfatal).
 */
void*
emalloc(int n)
{
        void *mem;

        mem = mallocz(n, 1);
        if(mem != nil)
                return mem;
        sysfatal("out of memory");
        return nil;     /* not reached */
}

/*
 * Offer a copy of *w to best[], which is kept in decreasing order of
 * mp and holds at most mbest entries.  The copy is placed after every
 * entry whose mp is strictly larger, so it goes ahead of existing
 * entries with an equal mp.  When best[] is already full, its last
 * entry is dropped to make room; a word whose slot would be at or past
 * mbest is ignored.
 */
void
noteword(Word *w)
{
        int pos, nmove;

        /* walk back from the end over entries whose mp is not larger than w's */
        pos = nbest;
        while(pos > 0 && !(w->mp < best[pos-1].mp))
                pos--;

        if(pos >= mbest)
                return;

        /* list is full: the last entry falls off */
        if(nbest == mbest)
                nbest--;

        nmove = nbest - pos;
        if(nmove > 0)
                memmove(&best[pos+1], &best[pos], nmove*sizeof best[0]);
        best[pos] = *w;
        nbest++;
}

/*
 * Open the hash file named s with Bopenlock, load it into a freshly
 * allocated Hash with Breadhash, close the file and return the Hash.
 * Failure to open the file is fatal.
 */
Hash*
hread(char *s)
{
        Biobuf *bp;
        Hash *hp;

        bp = Bopenlock(s, OREAD);
        if(bp == nil)
                sysfatal("open %s: %r", s);

        hp = emalloc(sizeof *hp);
        Breadhash(bp, hp, 1);
        Bterm(bp);
        return hp;
}

void
main(int argc, char **argv)
{
        int i, j, a, mi, oi, tot, keywords;
        double totp, p, xp[MAXTAB];
        Hash *hmsg;
        Word w;
        Stringtab *s, *t;
        Biobuf bout;

        mbest = 15;
        keywords = 0;
        ARGBEGIN{
        case 'D':
                debug = 1;
                break;
        case 'k':
                keywords = 1;
                break;
        case 'm':
                mbest = atoi(EARGF(usage()));
                if(mbest > MAXBEST)
                        sysfatal("cannot keep more than %d words", MAXBEST);
                break;
        default:
                usage();
        }ARGEND

        for(i=0; i<argc; i++)
                if(strcmp(argv[i], "~") == 0)
                        break;

        if(i > MAXTAB)
                sysfatal("cannot handle more than %d tables", MAXTAB);

        if(i+1 >= argc)
                usage();

        for(i=0; i<argc; i++){
                if(strcmp(argv[i], "~") == 0)
                        break;
                tab[ntab].file = argv[i];
                tab[ntab].hash = hread(argv[i]);
                s = findstab(tab[ntab].hash, "*nmsg*", 6, 1);
                if(s == nil || s->count == 0)
                        tab[ntab].nmsg = 1;
                else
                        tab[ntab].nmsg = s->count;
                ntab++;
        }

        Binit(&bout, 1, OWRITE);

        oi = ++i;
        for(a=i; a<argc; a++){
                hmsg = hread(argv[a]);
                nbest = 0;
                for(s=hmsg->all; s; s=s->link){
                        w.s = s;
                        tot = 0;
                        totp = 0.0;
                        for(i=0; i<ntab; i++){
                                t = findstab(tab[i].hash, s->str, s->n, 0);
                                if(t == nil)
                                        w.count[i] = 0;
                                else
                                        w.count[i] = t->count;
                                tot += w.count[i];
                                p = w.count[i]/(double)tab[i].nmsg;
                                if(p >= 1.0)
                                        p = 1.0;
                                w.p[i] = p;
                                totp += p;
                        }

                        if(tot < 5){            /* word does not appear enough; give to box 0 */
                                w.p[0] = 0.5;
                                for(i=1; i<ntab; i++)
                                        w.p[i] = 0.1;
                                w.mp = 0.5;
                                w.mi = 0;
                                noteword(&w);
                                continue;
                        }

                        w.mp = 0.0;
                        for(i=0; i<ntab; i++){
                                p = w.p[i];
                                p /= totp;
                                if(p < 0.01)
                                        p = 0.01;
                                else if(p > 0.99)
                                        p = 0.99;
                                if(p > w.mp){
                                        w.mp = p;
                                        w.mi = i;
                                }
                                w.p[i] = p;
                        }
                        noteword(&w);
                }

                totp = 0.0;
                for(i=0; i<ntab; i++){
                        p = 1.0;
                        for(j=0; j<nbest; j++)
                                p *= best[j].p[i];
                        xp[i] = p;
                        totp += p;
                }
                for(i=0; i<ntab; i++)
                        xp[i] /= totp;
                mi = 0;
                for(i=1; i<ntab; i++)
                        if(xp[i] > xp[mi])
                                mi = i;
                if(oi != argc-1)
                        Bprint(&bout, "%s: ", argv[a]);
                Bprint(&bout, "%s %f", tab[mi].file, xp[mi]);
                if(keywords){
                        for(i=0; i<nbest; i++){
                                Bprint(&bout, " ");
                                Bwrite(&bout, best[i].s->str, best[i].s->n);
                                Bprint(&bout, " %f", best[i].p[mi]);
                        }
                }
                freehash(hmsg);
                Bprint(&bout, "\n");
                if(debug){
                        for(i=0; i<nbest; i++){
                                Bwrite(&bout, best[i].s->str, best[i].s->n);
                                Bprint(&bout, " %f", best[i].p[mi]);
                                if(best[i].p[mi] < best[i].mp)
                                        Bprint(&bout, " (%f %s)", best[i].mp, tab[best[i].mi].file);
                                Bprint(&bout, "\n");
                        }
                }
        }
        Bterm(&bout);
}