Subversion Repositories planix.SVN

Rev

Blame | Last modification | View Log | RSS feed

#include <u.h>
#include <libc.h>
#include <ip.h>
#include "dns.h"

/*
 *  a dictionary of domain names for packing messages
 *
 *  each entry pairs a name (kept unpacked in buf) with the offset in
 *  the message at which that name was packed, so that pname can emit
 *  a two-byte pointer to the earlier copy instead of packing the name
 *  again.
 */
enum
{
        Ndict=  64,             /* number of entries the dictionary can hold */
};
typedef struct Dict     Dict;
struct Dict
{
        struct {
                ushort  offset;         /* pointer to packed name in message */
                char    *name;          /* pointer to unpacked name in buf */
        } x[Ndict];
        int     n;              /* size of dictionary */
        uchar   *start;         /* start of packed message */
        char    buf[16*1024];   /* buffer for unpacked names (was 4k) */
        char    *ep;            /* first free char in buf */
};

/*
 *  packing shorthands.  each one expects the calling function to have
 *  locals named p (current position in the message) and ep (end of
 *  the message buffer); NAME also expects dp (the name dictionary).
 *  p is replaced by the packer's return value.  arguments and
 *  expansions are parenthesized so that any expression can be passed
 *  and the macro can be used wherever an expression is allowed.
 */
#define NAME(x)         (p = pname(p, ep, (x), dp))
#define SYMBOL(x)       (p = psym(p, ep, (x)))
#define STRING(x)       (p = pstr(p, ep, (x)))
#define BYTES(x, n)     (p = pbytes(p, ep, (x), (n)))
#define USHORT(x)       (p = pushort(p, ep, (x)))
#define UCHAR(x)        (p = puchar(p, ep, (x)))
#define ULONG(x)        (p = pulong(p, ep, (x)))
#define V4ADDR(x)       (p = pv4addr(p, ep, (x)))
#define V6ADDR(x)       (p = pv6addr(p, ep, (x)))

/*
 *  pack np as a counted string: one length byte followed by that many
 *  bytes of np.  strings of Strlen bytes or more are cut to Strlen-1.
 *  returns the position just past the packed string, or ep+1 if the
 *  string does not fit between p and ep.
 */
static uchar*
psym(uchar *p, uchar *ep, char *np)
{
        int len;

        len = strlen(np);
        if(len > Strlen - 1)            /* DNS maximum length string */
                len = Strlen - 1;
        if(ep - p <= len)               /* need room for the count byte too */
                return ep+1;
        p[0] = len;
        memmove(p + 1, np, len);
        return p + 1 + len;
}

/*
 *  pack a string; the encoding is exactly that of psym, whose result
 *  (new position, or ep+1 on overflow) is passed back unchanged.
 */
static uchar*
pstr(uchar *p, uchar *ep, char *np)
{
        uchar *next;

        next = psym(p, ep, np);
        return next;
}

/*
 *  copy n raw bytes from np into the message at p.  returns the
 *  position after the copy, or ep+1 if fewer than n bytes remain
 *  before ep.
 */
static uchar*
pbytes(uchar *p, uchar *ep, uchar *np, int n)
{
        if(ep - p >= n){
                memmove(p, np, n);
                return p + n;
        }
        return ep+1;
}

/*
 *  pack the low byte of val at p.  returns the next position, or ep+1
 *  if no byte is left before ep.
 */
static uchar*
puchar(uchar *p, uchar *ep, int val)
{
        if(ep - p >= 1){
                p[0] = val;
                return p + 1;
        }
        return ep+1;
}

/*
 *  pack the low 16 bits of val at p, most significant byte first.
 *  returns the next position, or ep+1 if fewer than 2 bytes remain
 *  before ep.
 */
static uchar*
pushort(uchar *p, uchar *ep, int val)
{
        if(ep - p < 2)
                return ep+1;
        p[0] = val>>8;
        p[1] = val;
        return p + 2;
}

/*
 *  pack val at p as 4 bytes, most significant byte first.  returns
 *  the next position, or ep+1 if fewer than 4 bytes remain before ep.
 */
static uchar*
pulong(uchar *p, uchar *ep, int val)
{
        int shift;

        if(ep - p < 4)
                return ep+1;
        for(shift = 24; shift >= 0; shift -= 8)
                *p++ = val>>shift;
        return p;
}

/*
 *  pack the address written out in name as 4 bytes at p: name is
 *  parsed with parseip and the result converted with v6tov4.  the
 *  results of parseip and v6tov4 are not checked.  returns the next
 *  position, or ep+1 if fewer than 4 bytes remain before ep.
 */
static uchar*
pv4addr(uchar *p, uchar *ep, char *name)
{
        uchar v6[IPaddrlen];

        if(ep - p >= 4){
                parseip(v6, name);
                v6tov4(p, v6);
                return p + 4;
        }
        return ep+1;
}

/*
 *  pack the address written out in name as IPaddrlen bytes at p;
 *  parseip writes straight into the message and its result is not
 *  checked.  returns the next position, or ep+1 if fewer than
 *  IPaddrlen bytes remain before ep.
 */
static uchar*
pv6addr(uchar *p, uchar *ep, char *name)
{
        if(ep - p >= IPaddrlen){
                parseip(p, name);
                return p + IPaddrlen;
        }
        return ep+1;
}

/*
 *  pack domain name np into the message at p as a series of
 *  length-prefixed labels ended by a zero byte.  if np, or the rest
 *  of np from some label boundary on, is already in the dictionary
 *  dp, a two-byte compression pointer to the earlier copy is emitted
 *  in its place and packing stops there.  names packed here are
 *  entered in the dictionary while it has room.
 *
 *  a compression pointer carries only a 14-bit offset (the top two
 *  bits of its first byte are the 0xc0 tag), so names that start
 *  beyond offset 0x3fff are neither entered in the dictionary nor
 *  pointed at; they are simply packed in full.
 *
 *  returns the position after the packed name, or ep+1 if np is
 *  Domlen bytes or longer, a label is longer than Labellen, or the
 *  name does not fit before ep.
 */
static uchar*
pname(uchar *p, uchar *ep, char *np, Dict *dp)
{
        enum { Maxptroff = 0x3fff };    /* largest offset a pointer can encode */
        int i;
        char *cp;
        char *last;             /* last component packed */

        if(strlen(np) >= Domlen) /* make sure we don't exceed DNS limits */
                return ep+1;

        last = 0;
        while(*np){
                /* look through every component in the dictionary for a match */
                for(i = 0; i < dp->n; i++){
                        if(dp->x[i].offset > Maxptroff)
                                continue;       /* can't be expressed as a pointer */
                        if(strcmp(np, dp->x[i].name) != 0)
                                continue;
                        if(ep - p < 2)
                                return ep+1;
                        *p++ = dp->x[i].offset>>8 | 0xc0;
                        *p++ = dp->x[i].offset;
                        return p;
                }

                /*
                 * if there's room, and this spot can be pointed at, enter
                 * this name in dictionary.  offsets only grow as labels are
                 * packed, so once this test fails it fails for the rest of
                 * the name and `last' stays in step with np.
                 */
                if(dp->n < Ndict && p - dp->start <= Maxptroff){
                        if(last){
                                /* the whole name is already in dp->buf */
                                last = strchr(last, '.') + 1;
                                dp->x[dp->n].name = last;
                                dp->x[dp->n].offset = p - dp->start;
                                dp->n++;
                        } else {
                                /* add to dp->buf */
                                i = strlen(np);
                                if(dp->ep + i + 1 < &dp->buf[sizeof dp->buf]){
                                        strcpy(dp->ep, np);
                                        dp->x[dp->n].name = dp->ep;
                                        last = dp->ep;
                                        dp->x[dp->n].offset = p - dp->start;
                                        dp->ep += i + 1;
                                        dp->n++;
                                }
                        }
                }

                /* put next component into message */
                cp = strchr(np, '.');
                if(cp == nil){
                        i = strlen(np);
                        cp = np + i;    /* point to null terminator */
                } else {
                        i = cp - np;
                        cp++;           /* point past '.' */
                }
                if(ep-p < i+1)
                        return ep+1;
                if (i > Labellen)
                        return ep+1;
                *p++ = i;               /* count of chars in label */
                memmove(p, np, i);
                np = cp;
                p += i;
        }

        if(p >= ep)
                return ep+1;
        *p++ = 0;       /* add top level domain */

        return p;
}

/*
 *  pack domain name np at p as length-prefixed labels ended by a zero
 *  byte, never using compression pointers and never touching the name
 *  dictionary.  returns the position after the name, or ep+1 if np is
 *  Domlen bytes or longer, a label is longer than Labellen, or the
 *  name does not fit before ep.
 */
static uchar*
prawname(uchar *p, uchar *ep, char *np)
{
        char *dot;
        int n;

        if(strlen(np) >= Domlen)
                return ep+1;
        while(*np){
                dot = strchr(np, '.');
                if(dot == nil)
                        n = strlen(np);
                else
                        n = dot - np;
                if(ep - p < n+1 || n > Labellen)
                        return ep+1;
                *p++ = n;               /* count of chars in label */
                memmove(p, np, n);
                p += n;
                np += n;
                if(*np == '.')
                        np++;
        }
        if(p >= ep)
                return ep+1;
        *p++ = 0;                       /* root label ends the name */
        return p;
}

/*
 *  pack resource record rp into the message at p: owner name, type,
 *  class, ttl, a two-byte rdata length and then the type-specific
 *  rdata.  dp is the name dictionary used for compression.
 *
 *  returns the position after the record, or a pointer beyond ep if
 *  the record does not fit; the caller is expected to discard the
 *  record in that case.  types not listed in the switch are packed
 *  with empty rdata.
 */
static uchar*
convRR2M(RR *rp, uchar *p, uchar *ep, Dict *dp)
{
        uchar *lp, *data;
        int len, ttl;
        Txt *t;

        NAME(rp->owner->name);
        USHORT(rp->type);
        USHORT(rp->owner->class);

        /* egregious overuse of ttl (it's absolute time in the cache) */
        if(rp->db)
                ttl = rp->ttl;
        else
                ttl = rp->ttl - now;
        if(ttl < 0)
                ttl = 0;
        ULONG(ttl);

        lp = p;                 /* leave room for the rdata length */
        p += 2;
        data = p;

        /* also catches an overflow in the fixed part above (p was ep+1) */
        if(data >= ep)
                return p+1;

        switch(rp->type){
        case Thinfo:
                SYMBOL(rp->cpu->name);
                SYMBOL(rp->os->name);
                break;
        case Tcname:
        case Tmb:
        case Tmd:
        case Tmf:
        case Tns:
                NAME(rp->host->name);
                break;
        case Tmg:
        case Tmr:
                NAME(rp->mb->name);
                break;
        case Tminfo:
                NAME(rp->rmb->name);
                NAME(rp->mb->name);
                break;
        case Tmx:
                USHORT(rp->pref);
                NAME(rp->host->name);
                break;
        case Ta:
                V4ADDR(rp->ip->name);
                break;
        case Taaaa:
                V6ADDR(rp->ip->name);
                break;
        case Tptr:
                NAME(rp->ptr->name);
                break;
        case Tsoa:
                NAME(rp->host->name);
                NAME(rp->rmb->name);
                ULONG(rp->soa->serial);
                ULONG(rp->soa->refresh);
                ULONG(rp->soa->retry);
                ULONG(rp->soa->expire);
                ULONG(rp->soa->minttl);
                break;
        case Tsrv:
                USHORT(rp->srv->pri);
                USHORT(rp->srv->weight);
                USHORT(rp->port);
                /*
                 * rfc2782 sez no name compression, but the target is
                 * still a domain name and must go out as labels, not
                 * as a single counted string.
                 */
                p = prawname(p, ep, rp->host->name);
                break;
        case Ttxt:
                for(t = rp->txt; t != nil; t = t->next)
                        STRING(t->p);
                break;
        case Tnull:
                BYTES(rp->null->data, rp->null->dlen);
                break;
        case Trp:
                NAME(rp->rmb->name);
                NAME(rp->rp->name);
                break;
        case Tkey:
                USHORT(rp->key->flags);
                UCHAR(rp->key->proto);
                UCHAR(rp->key->alg);
                BYTES(rp->key->data, rp->key->dlen);
                break;
        case Tsig:
                USHORT(rp->sig->type);
                UCHAR(rp->sig->alg);
                UCHAR(rp->sig->labels);
                ULONG(rp->sig->ttl);
                ULONG(rp->sig->exp);
                ULONG(rp->sig->incep);
                USHORT(rp->sig->tag);
                NAME(rp->sig->signer->name);
                BYTES(rp->sig->data, rp->sig->dlen);
                break;
        case Tcert:
                USHORT(rp->cert->type);
                USHORT(rp->cert->tag);
                UCHAR(rp->cert->alg);
                BYTES(rp->cert->data, rp->cert->dlen);
                break;
        default:
                /* no rdata is packed for other types */
                break;
        }

        /* stuff in the rdata section length */
        len = p - data;
        *lp++ = len >> 8;
        *lp = len;

        return p;
}

/*
 *  pack the question entry rp at p: owner name, type and class.
 *  returns the position after the entry; on overflow the packers hand
 *  back ep+1, which is passed on to the caller.
 */
static uchar*
convQ2M(RR *rp, uchar *p, uchar *ep, Dict *dp)
{
        p = pname(p, ep, rp->owner->name, dp);
        p = pushort(p, ep, rp->type);
        p = pushort(p, ep, rp->owner->class);
        return p;
}

/*
 *  pack the list of RRs starting at rp into the message at p, as
 *  question entries if quest is set and as full records otherwise.
 *  packing stops at the first RR that does not fit before ep; that RR
 *  and the rest of the list are left out.  *countp is set to the
 *  number of RRs actually packed.  returns the position after the
 *  last RR that fit (never beyond ep unless p already was).
 */
static uchar*
rrloop(RR *rp, int *countp, uchar *p, uchar *ep, Dict *dp, int quest)
{
        uchar *np;
        char *dictep;
        int dictn;

        *countp = 0;
        for(; rp && p < ep; rp = rp->next){
                /* remember the dictionary as it was before this RR */
                dictn = dp->n;
                dictep = dp->ep;

                if(quest)
                        np = convQ2M(rp, p, ep, dp);
                else
                        np = convRR2M(rp, p, ep, dp);
                if(np > ep){
                        /*
                         * the RR is being dropped, but names from it may
                         * already be in the dictionary with offsets at or
                         * after p.  later RRs are packed over those bytes,
                         * so forget the entries lest a compression pointer
                         * be aimed at something else.
                         */
                        dp->n = dictn;
                        dp->ep = dictep;
                        break;
                }
                p = np;
                (*countp)++;
        }
        return p;
}

/*
 *  convert into a message
 *
 *  packs m into the len bytes at buf: first the question, answer,
 *  authority and additional sections (so the real counts are known),
 *  then the 12-byte header in front of them.  the four counts in m
 *  are overwritten with the number of RRs that were actually packed.
 *  if a question, answer or authority RR had to be left out for lack
 *  of room, Ftrunc is or'ed into the flags that go in the header.
 *  returns the number of bytes of buf that hold the message.
 */
int
convDNS2M(DNSmsg *m, uchar *buf, int len)
{
        ulong trunc = 0;
        uchar *p, *ep, *np;
        RR *rp;
        int want;
        Dict d;

        d.n = 0;
        d.start = buf;
        d.ep = d.buf;
        memset(buf, 0, len);
        m->qdcount = m->ancount = m->nscount = m->arcount = 0;

        /* first pack in the RR's so we can get real counts */
        p = buf + 12;
        ep = buf + len;
        p = rrloop(m->qd, &m->qdcount, p, ep, &d, 1);
        p = rrloop(m->an, &m->ancount, p, ep, &d, 0);
        p = rrloop(m->ns, &m->nscount, p, ep, &d, 0);
        p = rrloop(m->ar, &m->arcount, p, ep, &d, 0);

        /*
         * rrloop stops quietly at the first RR that doesn't fit, so p
         * can pass ep only if buf is shorter than the header.  to notice
         * dropped RRs, compare what was packed with what was asked for.
         * the additional section is left out of the comparison: losing
         * extra information is not a reason to mark the reply truncated.
         */
        want = 0;
        for(rp = m->qd; rp != nil; rp = rp->next)
                want++;
        for(rp = m->an; rp != nil; rp = rp->next)
                want++;
        for(rp = m->ns; rp != nil; rp = rp->next)
                want++;
        if(p > ep || m->qdcount + m->ancount + m->nscount < want) {
                trunc = Ftrunc;
                dnslog("udp packet full; truncating my reply");
                if(p > ep)
                        p = ep;
        }

        /* now pack the rest */
        np = p;
        p = buf;
        ep = buf + len;
        USHORT(m->id);
        USHORT(m->flags | trunc);
        USHORT(m->qdcount);
        USHORT(m->ancount);
        USHORT(m->nscount);
        USHORT(m->arcount);
        USED(p);
        return np - buf;
}