Subversion Repositories planix.SVN

Rev

Go to most recent revision | Blame | Compare with Previous | Last modification | View Log | RSS feed

/*
 * dnstcp - serve dns via tcp
 */
#include <u.h>
#include <libc.h>
#include <ip.h>
#include "dns.h"

/*
 * Global state.  In this file main() sets cfg.cachedb, cfg.resolver
 * (-r) and cfg.inside.
 */
Cfg cfg;

/* remote address text read by getcaller(); stays "" when it can't be read */
char    *caller = "";
/* argument of the -f option */
char    *dbfile;
/* incremented per -d; main() resets it to 0 unless it reached at least 2 */
int     debug;
uchar   ipaddr[IPaddrlen];      /* my ip address */
/*
 * logfile, maxage, needrefresh, nowns, testing, traceactivity and
 * zonerefreshprogram are not referenced by the code in this file;
 * presumably they are consumed by the shared dns code declared in
 * dns.h -- TODO confirm.
 */
char    *logfile = "dns";
int     maxage = 60*60;
/* "/net" followed by the -x argument; built in main() */
char    mntpt[Maxpath];
int     needrefresh;
/* set from time(nil) at the top of every request-loop iteration */
ulong   now;
vlong   nowns;
int     testing;
int     traceactivity;
char    *zonerefreshprogram;

static int      readmsg(int, uchar*, int);
static void     reply(int, DNSmsg*, Request*);
static void     dnzone(DNSmsg*, DNSmsg*, Request*);
static void     getcaller(char*);
static void     refreshmain(char*);

/*
 * Print the command synopsis on file descriptor 2 and exit with
 * status "usage".  The synopsis lists -d, which main() accepts;
 * note that main() only keeps debugging on when -d is given at
 * least twice.
 */
void
usage(void)
{
        fprint(2, "usage: %s [-drR] [-f ndb-file] [-x netmtpt] [conndir]\n", argv0);
        exits("usage");
}

/*
 * Serve DNS requests arriving on file descriptor 0 and write the
 * replies to file descriptor 1.  Each request is read with readmsg(),
 * decoded, validated, and answered one question at a time: zone
 * transfer questions (Taxfr) go to dnzone(), everything else to
 * dnserver() followed by reply().  When the request loop ends,
 * refreshmain() is called on the mount point.
 *
 * Options: -d debug, -f ndb-file, -r resolver, -R no recursion,
 * -x network mount point suffix.  An optional conndir argument is
 * handed to getcaller() to learn the caller's address.
 */
void
main(int argc, char *argv[])
{
        volatile int len, rcode;
        volatile char tname[32];
        char *volatile err, *volatile ext = "";
        volatile uchar buf[64*1024], callip[IPaddrlen];
        volatile DNSmsg reqmsg, repmsg;
        volatile Request req;

        /* presumably a 2-minute watchdog (Plan 9 alarm counts milliseconds) -- confirm */
        alarm(2*60*1000);
        cfg.cachedb = 1;
        ARGBEGIN{
        case 'd':
                debug++;
                break;
        case 'f':
                dbfile = EARGF(usage());
                break;
        case 'r':
                cfg.resolver = 1;
                break;
        case 'R':
                norecursion = 1;
                break;
        case 'x':
                ext = EARGF(usage());
                break;
        default:
                usage();
                break;
        }ARGEND

        /* a single -d is ignored; debugging needs -d at least twice */
        if(debug < 2)
                debug = 0;

        /* first remaining argument is the connection directory */
        if(argc > 0)
                getcaller(argv[0]);

        cfg.inside = 1;
        dninit();

        snprint(mntpt, sizeof mntpt, "/net%s", ext);
        if(myipaddr(ipaddr, mntpt) < 0)
                sysfatal("can't read my ip address");
        dnslog("dnstcp call from %s to %I", caller, ipaddr);
        /* callip stays all zeros if caller can't be parsed (result unchecked) */
        memset(callip, 0, sizeof callip);
        parseip(callip, caller);

        db2cache(1);

        memset(&req, 0, sizeof req);
        /* presumably the return point for a longjmp issued elsewhere -- confirm */
        setjmp(req.mret);
        req.isslave = 0;
        procsetname("main loop");

        /* loop on requests */
        for(;; putactivity(0)){
                now = time(nil);
                memset(&repmsg, 0, sizeof repmsg);
                len = readmsg(0, buf, sizeof buf);
                /* stop on read failure, end of input, or a zero-length message */
                if(len <= 0)
                        break;

                getactivity(&req, 0);
                req.aborttime = timems() + S2MS(15*Min);
                rcode = 0;
                memset(&reqmsg, 0, sizeof reqmsg);
                err = convM2DNS(buf, len, &reqmsg, &rcode);
                if(err){
                        dnslog("server: input error: %s from %s", err, caller);
                        free(err);
                        break;
                }
                /*
                 * sanity checks apply only when decoding reported no
                 * rcode; each failure ends the session.  the else
                 * branches bind to the inner if.
                 */
                if (rcode == 0)
                        if(reqmsg.qdcount < 1){
                                dnslog("server: no questions from %s", caller);
                                break;
                        } else if(reqmsg.flags & Fresp){
                                dnslog("server: reply not request from %s",
                                        caller);
                                break;
                        } else if((reqmsg.flags & Omask) != Oquery){
                                dnslog("server: op %d from %s",
                                        reqmsg.flags & Omask, caller);
                                break;
                        }
                /*
                 * NOTE(review): reqmsg.qd is dereferenced here without a
                 * check; when rcode != 0 the qdcount test above was
                 * skipped -- confirm qd can't be nil in that case.
                 */
                if(debug)
                        dnslog("[%d] %d: serve (%s) %d %s %s",
                                getpid(), req.id, caller,
                                reqmsg.id, reqmsg.qd->owner->name,
                                rrname(reqmsg.qd->type, tname, sizeof tname));

                /* loop through each question */
                /*
                 * dnzone() unlinks the head question from reqmsg.qd;
                 * dnserver() presumably does the same, otherwise this
                 * loop would not terminate -- confirm.
                 */
                while(reqmsg.qd)
                        if(reqmsg.qd->type == Taxfr)
                                dnzone(&reqmsg, &repmsg, &req);
                        else {
                                dnserver(&reqmsg, &repmsg, &req, callip, rcode);
                                reply(1, &repmsg, &req);
                                rrfreelist(repmsg.qd);
                                rrfreelist(repmsg.an);
                                rrfreelist(repmsg.ns);
                                rrfreelist(repmsg.ar);
                        }
                rrfreelist(reqmsg.qd);          /* qd will be nil */
                rrfreelist(reqmsg.an);
                rrfreelist(reqmsg.ns);
                rrfreelist(reqmsg.ar);

                /* isslave is cleared above; presumably set by code outside this file */
                if(req.isslave){
                        putactivity(0);
                        _exits(0);
                }
        }
        refreshmain(mntpt);
}

/*
 * Read one message from fd into buf.  A message is preceded by a
 * two-byte length, most significant byte first.  Returns the message
 * length, or -1 when the length prefix or the body can't be read in
 * full or the announced length is larger than max.
 */
static int
readmsg(int fd, uchar *buf, int max)
{
        uchar hdr[2];
        int msglen;

        if(readn(fd, hdr, 2) != 2)
                return -1;

        msglen = hdr[0]<<8 | hdr[1];
        if(msglen > max || readn(fd, buf, msglen) != msglen)
                return -1;
        return msglen;
}

/*
 * Encode rep with convDNS2M, prefix it with its two-byte length
 * (most significant byte first) and write it to fd in one write.
 * With debugging on, the question and every an/ns/ar record are
 * logged first.  A short or failed write is logged and the process
 * exits.
 */
static void
reply(int fd, DNSmsg *rep, Request *req)
{
        uchar pkt[64*1024];
        char tbuf[32];
        int msglen, total, nw;
        RR *rr;

        if(debug){
                dnslog("%d: reply (%s) %s %s %ux",
                        req->id, caller,
                        rep->qd->owner->name,
                        rrname(rep->qd->type, tbuf, sizeof tbuf),
                        rep->flags);
                rr = rep->an;
                while(rr){
                        dnslog("an %R", rr);
                        rr = rr->next;
                }
                rr = rep->ns;
                while(rr){
                        dnslog("ns %R", rr);
                        rr = rr->next;
                }
                rr = rep->ar;
                while(rr){
                        dnslog("ar %R", rr);
                        rr = rr->next;
                }
        }

        /* leave room at the front of pkt for the length prefix */
        msglen = convDNS2M(rep, pkt+2, sizeof(pkt) - 2);
        pkt[0] = msglen>>8;
        pkt[1] = msglen;

        total = msglen+2;
        nw = write(fd, pkt, total);
        if(nw == total)
                return;

        dnslog("[%d] sending reply: %d instead of %d", getpid(), nw,
                total);
        exits(0);
}

/*
 *  Hash table for domain names.  The hash is based only on the
 *  first element of the domain name.
 */
extern DN       *ht[HTLEN];

/*
 * Count the dot-separated elements of name: one more than the
 * number of '.' characters, so the empty string counts as one.
 */
static int
numelem(char *name)
{
        int count;
        char *p;

        count = 1;
        p = name;
        while(*p != '\0'){
                if(*p == '.')
                        count++;
                p++;
        }
        return count;
}

/*
 * Report whether dp's name lies in the zone called name at exactly
 * the given depth.  namelen must be strlen(name); depth is compared
 * with numelem() of dp's name.  Returns 1 when dp's name has that
 * many elements and either equals name or ends in "." followed by
 * name; returns 0 otherwise, including when dp has no name.
 */
int
inzone(DN *dp, char *name, int namelen, int depth)
{
        char *dname, *tail;
        int dlen;

        dname = dp->name;
        if(dname == nil || numelem(dname) != depth)
                return 0;

        dlen = strlen(dname);
        if(dlen < namelen)
                return 0;

        /* the last namelen bytes must match the zone name */
        tail = dname + dlen - namelen;
        if(strcmp(name, tail) != 0)
                return 0;

        /* a longer name must have an element boundary before the suffix */
        return dlen == namelen || tail[-1] == '.';
}

/*
 * Answer the zone transfer question at the head of reqp's question
 * list (main() calls this for questions of type Taxfr).  The question
 * is moved from reqp to repp, then a series of replies is written to
 * file descriptor 1: the zone's SOA lookup result, one reply for each
 * non-negative, non-SOA record of every name found in the zone, and
 * the SOA once more.  If no SOA is found, only the first (answerless)
 * reply is sent.  The question is freed before returning.
 */
static void
dnzone(DNSmsg *reqp, DNSmsg *repp, Request *req)
{
        DN *dp, *ndp;
        RR r, *rp;
        int h, depth, found, nlen;

        memset(repp, 0, sizeof(*repp));
        repp->id = reqp->id;
        /* unlink the first question from the request and give it to the reply */
        repp->qd = reqp->qd;
        reqp->qd = reqp->qd->next;
        repp->qd->next = 0;
        repp->flags = Fauth | Fresp | Oquery;
        if(!norecursion)
                repp->flags |= Fcanrec;
        dp = repp->qd->owner;

        /* send the soa */
        repp->an = rrlookup(dp, Tsoa, NOneg);
        reply(1, repp, req);
        /* no soa: the reply above went out without an answer; stop here */
        if(repp->an == 0)
                goto out;
        rrfreelist(repp->an);
        repp->an = nil;

        nlen = strlen(dp->name);

        /* construct a breadth-first search of the name space (hard with a hash) */
        /*
         * the answer section points at the local r, which is overwritten
         * with a copy of each record in turn, so every reply carries
         * exactly one answer.
         */
        repp->an = &r;
        for(depth = numelem(dp->name); ; depth++){
                found = 0;
                /* scan the whole hash table for names at this depth */
                for(h = 0; h < HTLEN; h++)
                        for(ndp = ht[h]; ndp; ndp = ndp->next)
                                if(inzone(ndp, dp->name, nlen, depth)){
                                        for(rp = ndp->rr; rp; rp = rp->next){
                                                /*
                                                 * there shouldn't be negatives,
                                                 * but just in case.
                                                 * don't send any soa's,
                                                 * ns's are enough.
                                                 */
                                                if (rp->negative ||
                                                    rp->type == Tsoa)
                                                        continue;
                                                r = *rp;
                                                r.next = 0;
                                                reply(1, repp, req);
                                        }
                                        /* set even when every record of the name was skipped */
                                        found = 1;
                                }
                /* the first depth with no names in the zone ends the walk */
                if(!found)
                        break;
        }

        /* resend the soa */
        /* this assignment also stops repp->an from pointing at the local r */
        repp->an = rrlookup(dp, Tsoa, NOneg);
        reply(1, repp, req);
        rrfreelist(repp->an);
        repp->an = nil;
out:
        rrfree(repp->qd);
        repp->qd = nil;
}

/*
 * Set the global caller from the contents of the file dir/remote,
 * dropping one trailing newline if present.  The text lives in a
 * static buffer, which first holds the file's path.  caller is left
 * unchanged when the file can't be opened or yields no data.
 */
static void
getcaller(char *dir)
{
        static char peer[128];
        int fd, nread;

        snprint(peer, sizeof(peer), "%s/remote", dir);
        if((fd = open(peer, OREAD)) < 0)
                return;

        /* reuse the buffer for the file contents, keeping room for the NUL */
        nread = read(fd, peer, sizeof peer - 1);
        close(fd);
        if(nread > 0){
                if(peer[nread-1] == '\n')
                        nread--;
                peer[nread] = 0;
                caller = peer;
        }
}

/*
 * Write the string "refresh" to the file net/dns.  Failure to open
 * the file is logged; the result of the write itself is not checked.
 */
static void
refreshmain(char *net)
{
        char path[128];
        int ctl;

        snprint(path, sizeof(path), "%s/dns", net);
        if(debug)
                dnslog("refreshing %s", path);

        ctl = open(path, ORDWR);
        if(ctl < 0){
                dnslog("can't refresh %s", path);
                return;
        }
        fprint(ctl, "refresh");
        close(ctl);
}

/*
 *  the following varies between dnsdebug and dns
 */
/*
 * Log a received message: one line with the request id, the address
 * and the decoded header flags, then one line for each question,
 * answer, authority and additional record.
 */
void
logreply(int id, uchar *addr, DNSmsg *mp)
{
        RR *rr;
        char *sauth, *strunc, *srd, *sra, *snx;

        sauth = "";
        if(mp->flags & Fauth)
                sauth = " auth";
        strunc = "";
        if(mp->flags & Ftrunc)
                strunc = " trunc";
        srd = "";
        if(mp->flags & Frecurse)
                srd = " rd";
        sra = "";
        if(mp->flags & Fcanrec)
                sra = " ra";
        /* " nx" needs both the auth bit and the Rname response code */
        snx = "";
        if((mp->flags & (Fauth|Rmask)) == (Fauth|Rname))
                snx = " nx";

        dnslog("%d: rcvd %I flags:%s%s%s%s%s", id, addr,
                sauth, strunc, srd, sra, snx);

        rr = mp->qd;
        while(rr != nil){
                dnslog("%d: rcvd %I qd %s", id, addr, rr->owner->name);
                rr = rr->next;
        }
        rr = mp->an;
        while(rr != nil){
                dnslog("%d: rcvd %I an %R", id, addr, rr);
                rr = rr->next;
        }
        rr = mp->ns;
        while(rr != nil){
                dnslog("%d: rcvd %I ns %R", id, addr, rr);
                rr = rr->next;
        }
        rr = mp->ar;
        while(rr != nil){
                dnslog("%d: rcvd %I ar %R", id, addr, rr);
                rr = rr->next;
        }
}

/*
 * Log an outgoing query: request id and sub-id, destination address
 * and server name, the name being asked about, and the textual form
 * of the record type as produced by rrname().
 */
void
logsend(int id, int subid, uchar *addr, char *sname, char *rname, int type)
{
        char tbuf[12];          /* scratch space for rrname() */

        dnslog("%d.%d: sending to %I/%s %s %s",
                id, subid, addr, sname, rname,
                rrname(type, tbuf, sizeof tbuf));
}

/*
 * Forward to dnsservers(class) and return its result unchanged.
 */
RR*
getdnsservers(int class)
{
        return dnsservers(class);
}