Subversion Repositories planix.SVN

Rev

Rev 2 | Blame | Compare with Previous | Last modification | View Log | RSS feed

/*
 * IP packet filter
 */
#include "u.h"
#include "../port/lib.h"
#include "mem.h"
#include "dat.h"
#include "fns.h"
#include "../port/error.h"

#include "ip.h"
#include "ipv6.h"

/* short names for the per-conversation state and the decision tree node defined below */
typedef struct Ipmuxrock  Ipmuxrock;
typedef struct Ipmux      Ipmux;

/*
 *  Byte-wise layout of an IPv4 header followed by the first byte of
 *  payload.  Received and transmitted packets are viewed through this
 *  structure, and filter offsets are derived from it.
 */
typedef struct Myip4hdr Myip4hdr;
struct Myip4hdr
{
        uchar   vihl;           /* Version and header length */
        uchar   tos;            /* Type of service */
        uchar   length[2];      /* packet length */
        uchar   id[2];          /* ip->identification */
        uchar   frag[2];        /* Fragment information */
        uchar   ttl;            /* Time to live */
        uchar   proto;          /* Protocol */
        uchar   cksum[2];       /* Header checksum */
        uchar   src[4];         /* IP source */
        uchar   dst[4];         /* IP destination */

        uchar   data[1];        /* start of data */
};
/*
 *  A null Myip4hdr pointer.  In this file it is only used to take the
 *  address of a member and cast it to an integer, which yields that
 *  member's offset within the header.
 */
Myip4hdr *ipoff = 0;

enum
{
        /* Txxxx: which field a filter node looks at (stored in Ipmux.type) */
        Tproto,
        Tdata,
        Tiph,
        Tdst,
        Tsrc,
        Tifc,

        /* Cxxxx: how a node is compared against a packet (stored in Ipmux.ctype) */
        Cother = 0,     /* none of the below: generic byte-by-byte masked compare */
        Cbyte,          /* single byte */
        Cmbyte,         /* single byte with mask */
        Cshort,         /* single short */
        Cmshort,        /* single short with mask */
        Clong,          /* single long */
        Cmlong,         /* single long with mask */
        Cifc,           /* single 4-byte value compared with the interface address */
        Cmifc,          /* as Cifc, with mask */
};

/* text name of each Txxxx field type, indexed by the type value */
char *ftname[] =
{
[Tproto]        "proto",
[Tdata]         "data",
[Tiph]          "iph",
[Tdst]          "dst",
[Tsrc]          "src",
[Tifc]          "ifc",
};

/*
 *  a node in the decision tree
 */
struct Ipmux
{
        Ipmux   *yes;           /* next node to try when this one matches */
        Ipmux   *no;            /* next node to try when this one does not match */
        uchar   type;           /* type of field(Txxxx) */
        uchar   ctype;          /* type of comparison(Cxxxx) */
        uchar   len;            /* length in bytes of item to compare */
        uchar   n;              /* number of items val points to */
        short   off;            /* offset of comparison */
        short   eoff;           /* end offset of comparison */
        uchar   skiphdr;        /* should offset start after ipheader */
        uchar   *val;           /* n items of len bytes each */
        uchar   *mask;          /* len bytes and-ed with the packet before comparing */
        uchar   *e;             /* val+n*len*/

        int     ref;            /* so we can garbage collect */
        Conv    *conv;          /* conversation that receives a packet matching here, or nil */
};

/*
 *  someplace to hold per conversation data
 */
struct Ipmuxrock
{
        Ipmux   *chain;         /* the conversation's own filter chain; nil until connected */
};

/* forward declarations: both are referenced before their definitions */
static int      ipmuxsprint(Ipmux*, int, char*, int);
static void     ipmuxkick(void *x);

/*
 *  Return a pointer to the first character of p that is neither
 *  a space nor a tab (possibly the terminating NUL).
 */
static char*
skipwhite(char *p)
{
        for(;;){
                if(*p != ' ' && *p != '\t')
                        return p;
                p++;
        }
}

/*
 *  Split p at the first occurrence of c: the separator is overwritten
 *  with a NUL and a pointer to the first non-blank character after it
 *  is returned.  Returns nil if c is absent or nothing but white space
 *  follows it (the separator is still overwritten in the latter case).
 */
static char*
follows(char *p, char c)
{
        char *rest;

        rest = strchr(p, c);
        if(rest != nil){
                *rest = 0;
                rest = skipwhite(rest + 1);
                if(*rest != 0)
                        return rest;
        }
        return nil;
}

/*
 *  Parse the operand (the part naming what to compare) at the start
 *  of *pp and allocate a node describing it.
 *
 *  Accepted operands are "dst", "src", "ifc", "proto" and the ranges
 *  "data[off]", "data[off:end]", "iph[off]" and "iph[off:end]".
 *  *pp is only read; the caller's pointer is not advanced.
 *
 *  Returns a new node with type, len, off, skiphdr, n and ref set and
 *  val/mask nil, or nil if the operand is not recognised, the range is
 *  malformed, or the range is too long to be represented.
 */
static Ipmux*
parseop(char **pp)
{
        char *p = *pp;
        int type, off, end, len;
        Ipmux *f;

        p = skipwhite(p);
        if(strncmp(p, "dst", 3) == 0){
                type = Tdst;
                off = (ulong)(ipoff->dst);
                len = IPv4addrlen;
                p += 3;
        }
        else if(strncmp(p, "src", 3) == 0){
                type = Tsrc;
                off = (ulong)(ipoff->src);
                len = IPv4addrlen;
                p += 3;
        }
        else if(strncmp(p, "ifc", 3) == 0){
                type = Tifc;
                off = -IPv4addrlen;
                len = IPv4addrlen;
                p += 3;
        }
        else if(strncmp(p, "proto", 5) == 0){
                type = Tproto;
                off = (ulong)&(ipoff->proto);
                len = 1;
                p += 5;
        }
        else if(strncmp(p, "data", 4) == 0 || strncmp(p, "iph", 3) == 0){
                if(strncmp(p, "data", 4) == 0) {
                        type = Tdata;
                        p += 4;
                }
                else {
                        type = Tiph;
                        p += 3;
                }
                p = skipwhite(p);
                if(*p != '[')
                        return nil;
                p++;
                off = strtoul(p, &p, 0);
                if(off < 0 || off > (64-IP4HDR))
                        return nil;
                p = skipwhite(p);
                if(*p != ':')
                        end = off;
                else {
                        p++;
                        p = skipwhite(p);
                        end = strtoul(p, &p, 0);
                        if(end < off)
                                return nil;
                        /*
                         *  the length is stored in a uchar; refuse ranges
                         *  that would otherwise be silently truncated
                         */
                        if(end - off > 254)
                                return nil;
                        p = skipwhite(p);
                }
                if(*p != ']')
                        return nil;
                p++;
                len = end - off + 1;
        }
        else
                return nil;

        f = smalloc(sizeof(*f));
        f->type = type;
        f->len = len;
        f->off = off;
        f->val = nil;
        f->mask = nil;
        f->n = 1;
        f->ref = 1;
        /* only data[] offsets are relative to the end of the ip header */
        if(type == Tdata)
                f->skiphdr = 1;
        else
                f->skiphdr = 0;

        return f;
}

/*
 *  Value of a single hexadecimal digit, either case.
 *  Any other character counts as 0.
 */
static int
htoi(char x)
{
        if('0' <= x && x <= '9')
                return x - '0';
        if('a' <= x && x <= 'f')
                return x - 'a' + 10;
        if('A' <= x && x <= 'F')
                return x - 'A' + 10;
        return 0;
}

/*
 *  Convert the two characters at p[0] and p[1] to a byte value,
 *  p[0] being the high nibble.
 */
static int
hextoi(char *p)
{
        int hi, lo;

        hi = htoi(p[0]);
        lo = htoi(p[1]);
        return (hi<<4) | lo;
}

/*
 *  Convert the hex string p into at most len bytes stored at v,
 *  two digits per byte.  Conversion stops at the end of the string;
 *  bytes of v beyond that point are left untouched.  A lone trailing
 *  digit is taken as the high nibble of a final byte.
 */
static void
parseval(uchar *v, char *p, int len)
{
        while(len-- > 0 && p[0] != 0){
                *v++ = hextoi(p);
                /* odd number of digits: stop here and never step over the NUL */
                if(p[1] == 0)
                        break;
                p += 2;
        }
}

static Ipmux*
parsemux(char *p)
{
        int n, nomask;
        Ipmux *f;
        char *val;
        char *mask;
        char *vals[20];
        uchar *v;

        /* parse operand */
        f = parseop(&p);
        if(f == nil)
                return nil;

        /* find value */
        val = follows(p, '=');
        if(val == nil)
                goto parseerror;

        /* parse mask */
        mask = follows(p, '&');
        if(mask != nil){
                switch(f->type){
                case Tsrc:
                case Tdst:
                case Tifc:
                        f->mask = smalloc(f->len);
                        v4parseip(f->mask, mask);
                        break;
                case Tdata:
                case Tiph:
                        f->mask = smalloc(f->len);
                        parseval(f->mask, mask, f->len);
                        break;
                default:
                        goto parseerror;
                }
                nomask = 0;
        } else {
                nomask = 1;
                f->mask = smalloc(f->len);
                memset(f->mask, 0xff, f->len);
        }

        /* parse vals */
        f->n = getfields(val, vals, sizeof(vals)/sizeof(char*), 1, "|");
        if(f->n == 0)
                goto parseerror;
        f->val = smalloc(f->n*f->len);
        v = f->val;
        for(n = 0; n < f->n; n++){
                switch(f->type){
                case Tsrc:
                case Tdst:
                case Tifc:
                        v4parseip(v, vals[n]);
                        break;
                case Tproto:
                case Tdata:
                case Tiph:
                        parseval(v, vals[n], f->len);
                        break;
                }
                v += f->len;
        }

        f->eoff = f->off + f->len;
        f->e = f->val + f->n*f->len;
        f->ctype = Cother;
        if(f->n == 1){
                switch(f->len){
                case 1:
                        f->ctype = nomask ? Cbyte : Cmbyte;
                        break;
                case 2:
                        f->ctype = nomask ? Cshort : Cmshort;
                        break;
                case 4:
                        if(f->type == Tifc)
                                f->ctype = nomask ? Cifc : Cmifc;
                        else
                                f->ctype = nomask ? Clong : Cmlong;
                        break;
                }
        }
        return f;

parseerror:
        if(f->mask)
                free(f->mask);
        if(f->val)
                free(f->val);
        free(f);
        return nil;
}

/*
 *  Order two nodes by the field they examine; the values themselves
 *  are not looked at.  The keys, in order of precedence, are: field
 *  type (lower first), effective offset (earlier first), match length
 *  (longer first) and finally the mask (a node with a mask before one
 *  without; otherwise by the bytes of the mask).
 *
 *  returns:    <0 if a is a more specific match
 *               0 if a and b are matching on the same fields
 *              >0 if b is a more specific match
 */
static int
ipmuxcmp(Ipmux *a, Ipmux *b)
{
        int d;

        d = a->type - b->type;
        if(d == 0)
                d = (a->off+((int)a->skiphdr)*(ulong)ipoff->data) -
                        (b->off+((int)b->skiphdr)*(ulong)ipoff->data);
        if(d == 0)
                d = b->len - a->len;
        if(d != 0)
                return d;

        /* same bytes of the packet: let the masks decide */
        if(a->mask == nil)
                return b->mask == nil ? 0 : 1;
        if(b->mask == nil)
                return -1;
        return memcmp(b->mask, a->mask, a->len);
}

/*
 *  Compare the value lists of two nodes that ipmuxcmp considers
 *  equal.  Lists of different total size never compare equal;
 *  otherwise the result is that of a byte-wise comparison.
 */
static int
ipmuxvalcmp(Ipmux *a, Ipmux *b)
{
        int asize, bsize;

        asize = a->len*a->n;
        bsize = b->len*b->n;
        if(bsize != asize)
                return bsize - asize;
        return memcmp(a->val, b->val, asize);
}

/*
 *  Insert f into the chain at *l, which is linked through the yes
 *  pointers and kept in ipmuxcmp order.  f goes in front of the first
 *  node it sorts strictly before, or at the end.
 */
static void
ipmuxchain(Ipmux **l, Ipmux *f)
{
        while(*l != nil && ipmuxcmp(f, *l) >= 0)
                l = &(*l)->yes;
        f->yes = *l;
        *l = f;
}

/*
 *  Duplicate a tree.  Every node and its value list are newly
 *  allocated; the mask pointer is copied as is, so a copy shares
 *  its mask with the node it was made from.
 */
static Ipmux*
ipmuxcopy(Ipmux *f)
{
        Ipmux *nf;
        int nbytes;

        if(f == nil)
                return nil;

        nbytes = f->n*f->len;
        nf = smalloc(sizeof *nf);
        *nf = *f;
        nf->no = ipmuxcopy(f->no);
        nf->yes = ipmuxcopy(f->yes);

        /* the structure copy left val pointing at the original's values */
        nf->val = smalloc(nbytes);
        memmove(nf->val, f->val, nbytes);
        nf->e = nf->val + nbytes;
        return nf;
}

/*
 *  Release a single node and its value list.  The mask and the
 *  nodes linked through yes/no are left alone.
 */
static void
ipmuxfree(Ipmux *f)
{
        uchar *vals;

        vals = f->val;
        if(vals != nil)
                free(vals);
        free(f);
}

/*
 *  Release a node together with everything reachable through its
 *  yes and no links.  A nil tree is accepted and ignored.
 */
static void
ipmuxtreefree(Ipmux *f)
{
        if(f == nil)
                return;
        /*
         *  descend into both subtrees rather than releasing only the
         *  direct children, so chains of any length are fully freed
         */
        ipmuxtreefree(f->no);
        ipmuxtreefree(f->yes);
        ipmuxfree(f);
}

/*
 *  Merge tree b into tree a and return the root of the result.
 *  Both arguments are consumed: their nodes end up in the result or
 *  are freed.
 *
 *  When the two roots test different fields, the one that sorts first
 *  stays on top and the other is merged into both of its branches (a
 *  copy is made for the no branch).  When they test the same field
 *  with the same values the roots are combined and the reference count
 *  bumped; with different values b is pushed down a's no branch.
 */
static Ipmux*
ipmuxmerge(Ipmux *a, Ipmux *b)
{
        int order;
        Ipmux *dup, *tmp;

        if(a == nil)
                return b;
        if(b == nil)
                return a;

        order = ipmuxcmp(a, b);
        if(order == 0){
                if(ipmuxvalcmp(a, b) != 0){
                        /* same field, other values: mutually exclusive */
                        a->no = ipmuxmerge(a->no, b);
                        return a;
                }
                /* identical test: fold b's branches into a and drop b */
                a->yes = ipmuxmerge(a->yes, b->yes);
                a->no = ipmuxmerge(a->no, b->no);
                a->ref++;
                ipmuxfree(b);
                return a;
        }

        /* arrange for a to be the root that sorts first */
        if(order > 0){
                tmp = a;
                a = b;
                b = tmp;
        }
        dup = ipmuxcopy(b);
        a->yes = ipmuxmerge(a->yes, b);
        a->no = ipmuxmerge(a->no, dup);
        return a;
}

/*
 *  remove a chain from a demux tree.  This is like merging except that
 *  we remove instead of insert.
 *
 *  Returns 0 once the chain has been removed along the path taken and
 *  a negative value if some part of it could not be found.
 */
static int
ipmuxremove(Ipmux **l, Ipmux *f)
{
        int order, rv;
        Ipmux *node;

        if(f == nil)
                return 0;               /* we've removed it all */
        node = *l;
        if(node == nil)
                return -1;

        order = ipmuxcmp(node, f);
        if(order > 0){
                /* f represents an earlier field than *l, this should be impossible */
                return -1;
        }
        if(order < 0){
                /* *l is matching an earlier field, descend both paths */
                rv = ipmuxremove(&node->yes, f);
                return rv + ipmuxremove(&node->no, f);
        }

        /* from here on f and *l are comparing the same fields */
        if(ipmuxvalcmp(node, f) != 0){
                /* different values mean mutually exclusive */
                return ipmuxremove(&node->no, f);
        }

        /* we found a match */
        node->ref--;
        if(node->ref != 0){
                /*
                 *  the node is still in use: carry on with the rest of
                 *  the chain, which is constrained to match the yes side.
                 */
                return ipmuxremove(&node->yes, f->yes);
        }

        /*
         *  a dead node implies the whole yes side is also dead.
         *  since our chain is constrained to be on that side,
         *  we're done.
         */
        ipmuxtreefree(node->yes);
        *l = node->no;
        ipmuxfree(node);
        return 0;
}

/*
 *  connection request is a semicolon-separated list of filters
 *  e.g. proto=17;data[0:4]=11aa22bb;ifc=135.104.9.2&255.255.255.0
 *
 *  there's no protection against overlapping specs.
 *
 *  Each filter is parsed into a node and the nodes are linked into a
 *  chain in canonical order.  The chain is kept with the conversation
 *  and a copy of it is merged into the protocol's demultiplexing tree.
 *  Returns nil on success or an error string.
 */
static char*
ipmuxconnect(Conv *c, char **argv, int argc)
{
        int i, nfield;
        char *field[10];
        Ipmux *node, *chain, *copy;
        Ipmuxrock *rock;
        Fs *fs;

        fs = c->p->f;

        if(argc != 2)
                return Ebadarg;

        nfield = getfields(argv[1], field, nelem(field), 1, ";");
        if(nfield <= 0)
                return Ebadarg;

        chain = nil;
        node = nil;
        i = 0;
        while(i < nfield){
                node = parsemux(field[i++]);
                if(node == nil){
                        ipmuxtreefree(chain);
                        return Ebadarg;
                }
                ipmuxchain(&chain, node);
        }
        if(chain == nil)
                return Ebadarg;

        /* the conversation is recorded on the node that was parsed last */
        node->conv = c;

        /* the original stays with the conversation so it can later be removed */
        copy = ipmuxcopy(chain);
        rock = (Ipmuxrock*)(c->ptcl);
        rock->chain = chain;

        /* add the copy to the protocol demultiplexor tree */
        wlock(fs);
        fs->ipmux->priv = ipmuxmerge(fs->ipmux->priv, copy);
        wunlock(fs);

        Fsconnected(c, nil);
        return nil;
}

/*
 *  Format the conversation's filter chain into state (at most n
 *  bytes) and return the number of bytes produced.
 */
static int
ipmuxstate(Conv *c, char *state, int n)
{
        Ipmuxrock *rock = (Ipmuxrock*)(c->ptcl);

        return ipmuxsprint(rock->chain, 0, state, n);
}

/*
 *  Set up a new conversation: open its read and write queues
 *  (writes are pushed out by ipmuxkick) and start with no filter chain.
 */
static void
ipmuxcreate(Conv *c)
{
        Ipmuxrock *rock;

        rock = (Ipmuxrock*)(c->ptcl);

        c->rq = qopen(64*1024, Qmsg, 0, c);
        c->wq = qopen(64*1024, Qkick, ipmuxkick, c);
        rock->chain = nil;
}

/*
 *  Announcing is not supported; always returns an error string.
 */
static char*
ipmuxannounce(Conv*, char**, int)
{
        char *err;

        err = "ipmux does not support announce";
        return err;
}

/*
 *  Tear down a conversation: close its queues, clear its addresses
 *  and ports, take its filter chain out of the demultiplexing tree
 *  and release the conversation's own copy of the chain.
 */
static void
ipmuxclose(Conv *c)
{
        Ipmuxrock *rock;
        Fs *fs;

        fs = c->p->f;
        rock = (Ipmuxrock*)(c->ptcl);

        qclose(c->rq);
        qclose(c->wq);
        qclose(c->eq);

        ipmove(c->laddr, IPnoaddr);
        ipmove(c->raddr, IPnoaddr);
        c->lport = 0;
        c->rport = 0;

        /* the tree is shared with the receive path, so change it under the write lock */
        wlock(fs);
        ipmuxremove(&(c->p->priv), rock->chain);
        wunlock(fs);

        ipmuxtreefree(rock->chain);
        rock->chain = nil;
}

/*
 *  takes a fully formed ip packet from the conversation's write
 *  queue and just passes it down the stack, using the v6 output
 *  routine when the version nibble says IPv6 and the v4 one otherwise.
 */
static void
ipmuxkick(void *x)
{
        Conv *c;
        Block *bp;
        Myip4hdr *ih4;

        c = x;
        bp = qget(c->wq);
        if(bp == nil)
                return;

        ih4 = (Myip4hdr*)(bp->rp);
        if((ih4->vihl & 0xF0) == IP_VER6)
                ipoput6(c->p->f, bp, 0, ((Ip6hdr*)ih4)->ttl, 0, nil);
        else
                ipoput4(c->p->f, bp, 0, ih4->ttl, ih4->tos, nil);
}

/*
 *  Receive routine: run an incoming packet through the decision tree.
 *
 *  Starting at the root, each node is compared with the packet and
 *  the walk continues along its yes or no link.  The conversation of
 *  the last matching node that carries one receives the packet, with
 *  the local interface address tacked on in front.  A packet that
 *  selects no conversation is handed to the handler registered for
 *  its protocol number, or freed if there is none.
 */
static void
ipmuxiput(Proto *p, Ipifc *ifc, Block *bp)
{
        int len, hl, skip;
        Fs *f = p->f;
        uchar *m, *h, *v, *e, *ve, *hp;
        Conv *c;
        Ipmux *mux;
        Myip4hdr *ip;
        Ip6hdr *ip6;

        ip = (Myip4hdr*)bp->rp;
        hl = (ip->vihl&0x0F)<<2;

        if(p->priv == nil)
                goto nomatch;

        h = bp->rp;
        len = BLEN(bp);

        /* run the v4 filter */
        rlock(f);
        c = nil;
        mux = f->ipmux->priv;
        while(mux != nil){
                /* data[] offsets count from the end of the ip header */
                skip = ((int)mux->skiphdr)*hl;

                /* the compared bytes, header skip included, must lie inside the packet */
                if(mux->eoff + skip > len){
                        mux = mux->no;
                        continue;
                }
                hp = h + mux->off + skip;
                switch(mux->ctype){
                case Cbyte:
                        if(*mux->val == *hp)
                                goto yes;
                        break;
                case Cmbyte:
                        if((*hp & *mux->mask) == *mux->val)
                                goto yes;
                        break;
                case Cshort:
                        if(*((ushort*)mux->val) == *(ushort*)hp)
                                goto yes;
                        break;
                case Cmshort:
                        if((*(ushort*)hp & (*((ushort*)mux->mask))) == *((ushort*)mux->val))
                                goto yes;
                        break;
                case Clong:
                        if(*((ulong*)mux->val) == *(ulong*)hp)
                                goto yes;
                        break;
                case Cmlong:
                        if((*(ulong*)hp & (*((ulong*)mux->mask))) == *((ulong*)mux->val))
                                goto yes;
                        break;
                case Cifc:
                        if(*((ulong*)mux->val) == *(ulong*)(ifc->lifc->local + IPv4off))
                                goto yes;
                        break;
                case Cmifc:
                        if((*(ulong*)(ifc->lifc->local + IPv4off) & (*((ulong*)mux->mask))) == *((ulong*)mux->val))
                                goto yes;
                        break;
                default:
                        /* generic case: try each alternative value byte by byte */
                        v = mux->val;
                        for(e = mux->e; v < e; v = ve){
                                m = mux->mask;
                                /*
                                 *  compare against the same bytes the fast
                                 *  cases use: the interface address for ifc
                                 *  nodes, otherwise the packet at the node's
                                 *  offset with the header skip applied
                                 */
                                if(mux->type == Tifc)
                                        hp = ifc->lifc->local + IPv4off;
                                else
                                        hp = h + mux->off + skip;
                                for(ve = v + mux->len; v < ve; v++){
                                        if((*hp++ & *m++) != *v)
                                                break;
                                }
                                if(v == ve)
                                        goto yes;
                        }
                }
                mux = mux->no;
                continue;
yes:
                if(mux->conv != nil)
                        c = mux->conv;
                mux = mux->yes;
        }
        runlock(f);

        if(c != nil){
                /* tack on interface address */
                bp = padblock(bp, IPaddrlen);
                ipmove(bp->rp, ifc->lifc->local);
                bp = concatblock(bp);
                if(bp != nil)
                        if(qpass(c->rq, bp) < 0)
                                print("ipmuxiput: qpass failed\n");
                return;
        }

nomatch:
        /* doesn't match any filter, hand it to the specific protocol handler */
        ip = (Myip4hdr*)bp->rp;
        if((ip->vihl & 0xF0) == IP_VER4) {
                p = f->t2p[ip->proto];
        } else {
                ip6 = (Ip6hdr*)bp->rp;
                p = f->t2p[ip6->proto];
        }
        if(p && p->rcv)
                (*p->rcv)(p, ifc, bp);
        else
                freeblist(bp);
        return;
}

/*
 *  Format the tree rooted at mux into buf (len bytes available),
 *  one node per line, indented by one space per level.  A node is
 *  shown as h[first:last]&mask=value|value|... in hex, followed by
 *  its no subtree and then its yes subtree one level deeper.  A nil
 *  tree prints as an empty (indented) line.  Returns the number of
 *  bytes produced.
 */
static int
ipmuxsprint(Ipmux *mux, int level, char *buf, int len)
{
        int i, j, n, first;
        uchar *v;

        n = 0;
        for(i = level; i > 0; i--)
                n += snprint(buf+n, len-n, " ");
        if(mux == nil)
                return n + snprint(buf+n, len-n, "\n");

        /* offsets are shown relative to the start of the ip header */
        first = mux->off+((int)mux->skiphdr)*((int)ipoff->data);
        n += snprint(buf+n, len-n, "h[%d:%d]&", first, first+mux->len-1);

        for(i = 0; i < mux->len; i++)
                n += snprint(buf+n, len - n, "%2.2ux", mux->mask[i]);
        n += snprint(buf+n, len-n, "=");

        v = mux->val;
        for(j = 0; j < mux->n; j++){
                for(i = 0; i < mux->len; i++, v++)
                        n += snprint(buf+n, len - n, "%2.2ux", *v);
                n += snprint(buf+n, len-n, "|");
        }
        n += snprint(buf+n, len-n, "\n");

        n += ipmuxsprint(mux->no, level+1, buf+n, len-n);
        n += ipmuxsprint(mux->yes, level+1, buf+n, len-n);
        return n;
}

/*
 *  Format the protocol's whole decision tree into buf, holding the
 *  read lock while doing so.  Returns the number of bytes produced.
 */
static int
ipmuxstats(Proto *p, char *buf, int len)
{
        int nbytes;
        Fs *fs;

        fs = p->f;
        rlock(fs);
        nbytes = ipmuxsprint(p->priv, 0, buf, len);
        runlock(fs);

        return nbytes;
}

/*
 *  Allocate the ipmux protocol descriptor, fill in its entry points
 *  and register it with the ip stack f.
 */
void
ipmuxinit(Fs *f)
{
        Proto *pr;

        pr = smalloc(sizeof *pr);

        /* identity and sizing */
        pr->name = "ipmux";
        pr->ipproto = -1;
        pr->nc = 64;
        pr->ptclsize = sizeof(Ipmuxrock);
        pr->priv = nil;

        /* entry points */
        pr->connect = ipmuxconnect;
        pr->announce = ipmuxannounce;
        pr->state = ipmuxstate;
        pr->create = ipmuxcreate;
        pr->close = ipmuxclose;
        pr->rcv = ipmuxiput;
        pr->ctl = nil;
        pr->advise = nil;
        pr->stats = ipmuxstats;

        f->ipmux = pr;                          /* hack for Fsrcvpcol */

        Fsproto(f, pr);
}