Subversion Repositories planix.SVN

Rev

Rev 2 | Blame | Compare with Previous | Last modification | View Log | RSS feed

/*
 * Sun RPC client.
 */
#include <u.h>
#include <libc.h>
#include <thread.h>
#include <sunrpc.h>

/*
 * One outstanding request.  sunClientRpc builds an Out on its own
 * stack, hands a pointer to the mux thread over rpcchan, and then
 * waits on creply; the mux thread fills in the result fields before
 * sending on creply.
 */
typedef struct Out Out;
struct Out
{
        char err[ERRMAX];       /* error string; filled in by the mux thread when the request fails */
        Channel *creply;        /* send to finish rpc; the requester blocks receiving on it */
        uchar *p;                       /* pending request packet; on completion the reply buffer, or nil on failure */
        int n;                          /* size of request in bytes (as passed to write) */
        ulong tag;                      /* flush tag of pending request; compared against values from flushchan */
        ulong xid;                      /* xid of pending request; assigned by the mux thread, matched against replies */
        ulong st;                       /* first send time, from msec() */
        ulong t;                        /* resend time: the request is retransmitted once msec() passes it */
        int nresend;            /* number of resends so far; feeds the backoff in twait */
        SunRpc rpc;             /* response rpc header, copied in on a successful reply */
};

/*
 * Reader thread for datagram connections.
 *
 * Each datagram read from cli->fd is copied into a freshly allocated
 * buffer behind a 4-byte big-endian payload length and passed over
 * cli->readchan; the receiver then owns the buffer.  The loop ends
 * when the read fails or returns nothing, or when the send is
 * interrupted.  On the way out the thread releases its buffers and
 * ioproc and announces its exit on cli->dying.
 *
 * v is the SunClient this thread serves.
 */
static void
udpThread(void *v)
{
        uchar *p, *buf;
        Ioproc *io;
        int n;
        SunClient *cli;
        enum { BufSize = 65536 };

        cli = v;
        buf = emalloc(BufSize);
        io = ioproc();
        p = nil;
        for(;;){
                n = ioread(io, cli->fd, buf, BufSize);
                if(n <= 0)
                        break;
                p = emalloc(4+n);
                memmove(p+4, buf, n);
                p[0] = n>>24;
                p[1] = n>>16;
                p[2] = n>>8;
                p[3] = n;
                /*
                 * A blocking send reports interruption as -1, not 0.
                 * Treat any non-success as "not delivered" so the
                 * packet is freed below instead of being leaked.
                 */
                if(sendp(cli->readchan, p) <= 0)
                        break;
                p = nil;
        }
        free(p);
        free(buf);      /* scratch buffer is private to this thread */
        closeioproc(io);
        /* keep retrying: the exit notice must not be lost to an interrupt */
        while(send(cli->dying, nil) == -1)
                ;
}

/*
 * Reader thread for stream connections.
 *
 * The stream is a sequence of fragments, each preceded by a 4-byte
 * big-endian header: the top bit marks the last fragment of a record
 * and the low 31 bits give the fragment length.  Fragments are
 * accumulated into one buffer; when the last one arrives the buffer,
 * prefixed with a 4-byte big-endian payload length, is passed over
 * cli->readchan and the receiver then owns it.  The loop ends on a
 * short read or an interrupted send.  On the way out the thread frees
 * any partial record, closes its ioproc and announces its exit on
 * cli->dying.
 *
 * v is the SunClient this thread serves.
 */
static void
netThread(void *v)
{
        uchar *p, buf[4];
        Ioproc *io;
        uint n, tot;
        int done;
        SunClient *cli;

        cli = v;
        io = ioproc();
        tot = 0;
        p = nil;
        for(;;){
                n = ioreadn(io, cli->fd, buf, 4);
                if(n != 4)
                        break;
                n = (buf[0]<<24)|(buf[1]<<16)|(buf[2]<<8)|buf[3];
                if(cli->chatty)
                        fprint(2, "%.8ux...", n);
                done = n&0x80000000;
                n &= ~0x80000000;
                if(tot == 0){
                        /* first fragment: leave room for the length prefix */
                        p = emalloc(4+n);
                        tot = 4;
                }else
                        p = erealloc(p, tot+n);
                if(ioreadn(io, cli->fd, p+tot, n) != n)
                        break;
                tot += n;
                if(done){
                        /*
                         * tot counts the 4-byte prefix as well, but the
                         * consumer treats the prefix as the number of
                         * bytes that follow it (as udpThread writes it),
                         * so store the payload length only.
                         */
                        n = tot-4;
                        p[0] = n>>24;
                        p[1] = n>>16;
                        p[2] = n>>8;
                        p[3] = n;
                        /* a blocking send reports interruption as -1, not 0 */
                        if(sendp(cli->readchan, p) <= 0)
                                break;
                        p = nil;
                        tot = 0;
                }
        }
        free(p);
        closeioproc(io);
        /* keep retrying: the exit notice must not be lost to an interrupt */
        while(send(cli->dying, 0) == -1)
                ;
}

/*
 * Tick generator: sleeps 200ms at a time and sends a value on
 * cli->timerchan after each sleep.  The loop ends when the sleep
 * fails or the send is interrupted; the thread then closes its
 * ioproc and announces its exit on cli->dying.
 *
 * v is the SunClient this thread serves.
 */
static void
timerThread(void *v)
{
        Ioproc *io;
        SunClient *cli;

        cli = v;
        io = ioproc();
        for(;;){
                if(iosleep(io, 200) < 0)
                        break;
                /* a blocking send reports interruption as -1, not 0 */
                if(sendul(cli->timerchan, 0) <= 0)
                        break;
        }
        closeioproc(io);
        /* keep retrying: the exit notice must not be lost to an interrupt */
        while(send(cli->dying, 0) == -1)
                ;
}

/*
 * Current value of nsec() scaled down to milliseconds,
 * converted to ulong by the return.
 */
static ulong
msec(void)
{
        enum { NsPerMs = 1000000 };

        return nsec()/NsPerMs;
}

/*
 * Retransmission backoff: how long to wait before the next resend.
 *
 * rtt is the base interval and nresend the number of resends made
 * so far.  The interval is rtt for nresend <= 1, doubled for 2 and 3,
 * shifted left by nresend-2 for 4 through 18, and in every case
 * capped at 60 seconds (60*1000); beyond 18 resends it is the cap.
 */
static ulong
twait(ulong rtt, int nresend)
{
        enum { MaxWait = 60*1000 };
        ulong wait;

        if(nresend > 18)
                return MaxWait;

        wait = rtt;
        if(nresend >= 4)
                wait <<= nresend-2;
        else if(nresend >= 2)
                wait *= 2;

        return wait > MaxWait ? MaxWait : wait;
}

/*
 * Multiplexer thread: the only code that touches the table of
 * outstanding requests.  It waits with alt() on four channels:
 *      rpcchan   - new requests (Out*) to transmit; nil means shut down
 *      timerchan - ticks that drive retransmission and timeouts
 *      flushchan - tags whose pending requests are to be abandoned
 *      readchan  - reply packets, each prefixed with a 4-byte length
 * A request is finished by sending on its creply channel.  On failure
 * o->p is nil and o->err holds the reason; on success o->p is the
 * reply buffer and o->rpc the unpacked reply header.
 *
 * v is the SunClient this thread serves.
 */
static void
rpcMuxThread(void *v)
{
        uchar *buf, *p, *ep;
        int i, n, nout, mout;
        ulong t, xidgen, tag;
        Alt a[5];
        Out *o, **out;
        SunRpc rpc;
        SunClient *cli;

        cli = v;
        /* out[0..nout) holds the outstanding requests; mout is its capacity */
        mout = 16;
        nout = 0;
        out = emalloc(mout*sizeof(out[0]));
        /* xids count up from a random starting point */
        xidgen = truerand();

        a[0].op = CHANRCV;
        a[0].c = cli->rpcchan;
        a[0].v = &o;
        /* the timer entry stays disabled until a request is outstanding */
        a[1].op = CHANNOP;
        a[1].c = cli->timerchan;
        a[1].v = nil;
        a[2].op = CHANRCV;
        a[2].c = cli->flushchan;
        a[2].v = &tag;
        a[3].op = CHANRCV;
        a[3].c = cli->readchan;
        a[3].v = &buf;
        a[4].op = CHANEND;

        for(;;){
                switch(alt(a)){
                case 0:  /* o = <-rpcchan */
                        /* a nil request is the shutdown signal */
                        if(o == nil)
                                goto Done;
                        cli->nsend++;
                        /* set xid */
                        o->xid = ++xidgen;
                        /* when needcount is set the packet starts with a 4-byte count; the xid follows it */
                        if(cli->needcount)
                                p = o->p+4;
                        else
                                p = o->p;
                        p[0] = xidgen>>24;
                        p[1] = xidgen>>16;
                        p[2] = xidgen>>8;
                        p[3] = xidgen;
                        if(write(cli->fd, o->p, o->n) != o->n){
                                /* fail the request at once; it never enters the table */
                                free(o->p);
                                o->p = nil;
                                snprint(o->err, sizeof o->err, "write: %r");
                                sendp(o->creply, 0);
                                break;
                        }
                        /* grow the table by doubling when it is full */
                        if(nout >= mout){
                                mout *= 2;
                                out = erealloc(out, mout*sizeof(out[0]));
                        }
                        o->st = msec();
                        o->nresend = 0;
                        o->t = o->st + twait(cli->rtt.avg, 0);
if(cli->chatty) fprint(2, "send %lux %lud %lud\n", o->xid, o->st, o->t);
                        out[nout++] = o;
                        /* there is work now: start listening for timer ticks */
                        a[1].op = CHANRCV;
                        break;

                case 1: /* <-timerchan */
                        t = msec();
                        for(i=0; i<nout; i++){
                                o = out[i];
                                /* signed difference, so the comparison survives wraparound of the clock */
                                if((int)(t - o->t) > 0){
if(cli->chatty) fprint(2, "resend %lux %lud %lud\n", o->xid, t, o->t);
                                        /* give up once maxwait (if nonzero) has elapsed since the first send */
                                        if(cli->maxwait && t - o->st >= cli->maxwait){
                                                free(o->p);
                                                o->p = nil;
                                                strcpy(o->err, "timeout");
                                                sendp(o->creply, 0);
                                                /* swap-remove: move the last entry here and revisit this slot */
                                                out[i--] = out[--nout];
                                                continue;
                                        }
                                        cli->nresend++;
                                        o->nresend++;
                                        /* schedule the next resend with backoff before retransmitting */
                                        o->t = t + twait(cli->rtt.avg, o->nresend);
                                        if(write(cli->fd, o->p, o->n) != o->n){
                                                free(o->p);
                                                o->p = nil;
                                                snprint(o->err, sizeof o->err, "rewrite: %r");
                                                sendp(o->creply, 0);
                                                /* swap-remove, as above */
                                                out[i--] = out[--nout];
                                                continue;
                                        }
                                }
                        }
                        /* stop ticking if no work; rpcchan will turn it back on */
                        if(nout == 0)
                                a[1].op = CHANNOP;
                        break;
                        
                case 2: /* tag = <-flushchan */
                        /* fail every outstanding request that carries this tag */
                        for(i=0; i<nout; i++){
                                o = out[i];
                                if(o->tag == tag){
                                        /* swap-remove first; o still points at the removed request */
                                        out[i--] = out[--nout];
                                        strcpy(o->err, "flushed");
                                        free(o->p);
                                        o->p = nil;
                                        sendp(o->creply, 0);
                                }
                        }
                        break;

                case 3: /* buf = <-readchan */
                        /* buf is a 4-byte big-endian length followed by that many bytes of packet */
                        p = buf;
                        n = (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
                        p += 4;
                        ep = p+n;
                        if(sunRpcUnpack(p, ep, &p, &rpc) < 0){
                                fprint(2, "in: %.*H unpack failed\n", n, buf+4);
                                free(buf);
                                break;
                        }
                        if(cli->chatty)
                                fprint(2, "in: %B\n", &rpc);
                        /* only replies are expected here; drop anything marked as a call */
                        if(rpc.iscall){
                                fprint(2, "did not get reply\n");
                                free(buf);
                                break;
                        }
                        /* find the outstanding request with the same xid */
                        o = nil;
                        for(i=0; i<nout; i++){
                                o = out[i];
                                if(o->xid == rpc.xid)
                                        break;
                        }
                        if(i==nout){
                                /* no match: drop the packet */
                                if(cli->chatty) fprint(2, "did not find waiting request\n");
                                free(buf);
                                break;
                        }
                        out[i] = out[--nout];
                        /* the request packet is no longer needed */
                        free(o->p);
                        o->p = nil;
                        if(rpc.status == SunSuccess){
                                /* hand the reply buffer and unpacked header to the requester */
                                o->p = buf;
                                o->rpc = rpc;
                        }else{
                                /* report the rpc status through the error string */
                                o->p = nil;
                                free(buf);
                                sunErrstr(rpc.status);
                                rerrstr(o->err, sizeof o->err);
                        }
                        sendp(o->creply, 0);
                        break;
                }
        }
Done:
        /* only the table itself is released here, then the exit is announced */
        free(out);
        sendp(cli->dying, 0);
}

/*
 * Dial address and start the threads that service the connection.
 *
 * An address containing "udp!" gets the datagram reader plus a
 * timer thread; any other address gets the stream reader, with
 * needcount set and no timer.  The mux thread is started in both
 * cases.  Returns the new client, or nil if the dial fails.
 */
SunClient*
sunDial(char *address)
{
        int fd, isudp;
        SunClient *cli;

        fd = dial(address, 0, 0, 0);
        if(fd < 0)
                return nil;

        cli = emalloc(sizeof *cli);
        cli->fd = fd;
        cli->maxwait = 15000;
        cli->rtt.avg = 1000;

        /* all channels are unbuffered */
        cli->dying = chancreate(sizeof(void*), 0);
        cli->rpcchan = chancreate(sizeof(Out*), 0);
        cli->timerchan = chancreate(sizeof(ulong), 0);
        cli->flushchan = chancreate(sizeof(ulong), 0);
        cli->readchan = chancreate(sizeof(uchar*), 0);

        isudp = strstr(address, "udp!") != nil;
        cli->needcount = !isudp;
        if(isudp){
                cli->nettid = threadcreate(udpThread, cli, SunStackSize);
                cli->timertid = threadcreate(timerThread, cli, SunStackSize);
        }else{
                /* assume reliable: don't need timer */
                /* BUG: netThread should know how to redial */
                cli->nettid = threadcreate(netThread, cli, SunStackSize);
        }
        threadcreate(rpcMuxThread, cli, SunStackSize);

        return cli;
}

/*
 * Shut down a client created by sunDial and release everything it owns.
 *
 * The reader thread and, if present, the timer thread are interrupted
 * until each has announced its exit on cli->dying.  The mux thread is
 * then told to stop by sending nil on cli->rpcchan and its exit is
 * awaited.  Finally the descriptor, all five channels, the program
 * table and the client itself are released.  cli must not be used
 * after this returns.
 */
void
sunClientClose(SunClient *cli)
{
        int n;

        /*
         * Threadints get you out of any stuck system calls
         * or thread rendezvouses, but do nothing if the thread
         * is in the ready state.  Keep interrupting until it takes.
         */
        n = 0;
        if(!cli->timertid)
                n++;    /* no timer thread: count it as already gone */
        while(n < 2){
                threadint(cli->nettid);
                if(cli->timertid)
                        threadint(cli->timertid);
                yield();
                /* collect whatever exit notices have arrived so far */
                while(nbrecv(cli->dying, nil) == 1)
                        n++;
        }

        /* a nil request tells the mux thread to exit */
        sendp(cli->rpcchan, 0);
        recvp(cli->dying);

        /* everyone's gone: clean up */
        close(cli->fd);
        chanfree(cli->flushchan);
        chanfree(cli->readchan);
        chanfree(cli->timerchan);
        chanfree(cli->rpcchan);
        chanfree(cli->dying);
        free(cli->prog);        /* the array only; the SunProgs belong to the caller */
        free(cli);
}
        
/*
 * Ask the mux thread to abandon the pending requests that were
 * issued with this tag; they complete with the error "flushed".
 */
void
sunClientFlushRpc(SunClient *cli, ulong tag)
{
        Channel *flush;

        flush = cli->flushchan;
        sendul(flush, tag);
}

/*
 * Register program p with the client so that sunClientRpc can find
 * it by program and version number.  The table grows 16 slots at a
 * time; only the pointer p is stored.
 */
void
sunClientProg(SunClient *cli, SunProg *p)
{
        enum { Chunk = 16 };

        /* the table is full (or empty) whenever the count is a multiple of Chunk */
        if(cli->nprog%Chunk == 0)
                cli->prog = erealloc(cli->prog, (cli->nprog+Chunk)*sizeof(cli->prog[0]));
        cli->prog[cli->nprog] = p;
        cli->nprog++;
}

/*
 * Perform one remote procedure call and wait for its outcome.
 *
 * cli    - client returned by sunDial
 * tag    - caller's tag for the request; sunClientFlushRpc with the
 *          same tag abandons it
 * tx     - the call to send; its program and version must have been
 *          registered with sunClientProg
 * rx     - filled in from the reply
 * tofree - if non-nil, receives the reply buffer on success and the
 *          caller must free it; if nil the buffer is freed here.
 *          NOTE(review): rx presumably may refer to data inside that
 *          buffer -- confirm against sunCallUnpack.
 *
 * Returns 0 on success, -1 with the error string set on failure.
 */
int
sunClientRpc(SunClient *cli, ulong tag, SunCall *tx, SunCall *rx, uchar **tofree)
{
        uchar *pkt, *wp, *end;
        int i, hdr, rpclen, calllen, total, reclen;
        Out o;
        SunProg *prog;
        SunStatus ok;

        /* look up the registered program matching the call */
        prog = nil;
        for(i=0; i<cli->nprog; i++){
                if(cli->prog[i]->prog != tx->rpc.prog)
                        continue;
                if(cli->prog[i]->vers != tx->rpc.vers)
                        continue;
                prog = cli->prog[i];
                break;
        }
        if(prog == nil){
                werrstr("unknown sun rpc program %d version %d", tx->rpc.prog, tx->rpc.vers);
                return -1;
        }

        if(cli->chatty){
                fprint(2, "out: %B\n", &tx->rpc);
                fprint(2, "\t%C\n", tx);
        }

        /* packet = optional 4-byte count + rpc header + call body */
        rpclen = sunRpcSize(&tx->rpc);
        calllen = sunCallSize(prog, tx);
        hdr = cli->needcount ? 4 : 0;
        total = hdr+rpclen+calllen;

        pkt = emalloc(total);
        end = pkt+total;
        wp = pkt;
        if(hdr){
                /* count of the bytes that follow, with the top bit set */
                reclen = total-4;
                wp[0] = (reclen>>24)|0x80;
                wp[1] = reclen>>16;
                wp[2] = reclen>>8;
                wp[3] = reclen;
                wp += 4;
        }

        ok = sunRpcPack(wp, end, &wp, &tx->rpc);
        if(ok == SunSuccess)
                ok = sunCallPack(prog, wp, end, &wp, tx);
        if(ok != SunSuccess){
                sunErrstr(ok);
                free(pkt);
                return -1;
        }
        /* the packers must have filled the buffer exactly */
        if(wp != end){
                werrstr("rpc: packet size mismatch");
                free(pkt);
                return -1;
        }

        /* hand the request to the mux thread and wait for it to finish */
        memset(&o, 0, sizeof o);
        o.p = pkt;
        o.n = total;
        o.tag = tag;
        o.creply = chancreate(sizeof(void*), 0);

        sendp(cli->rpcchan, &o);
        recvp(o.creply);
        chanfree(o.creply);

        /* the mux thread leaves o.p nil on failure, with the reason in o.err */
        if(o.p == nil){
                werrstr("%s", o.err);
                return -1;
        }

        /* o.p is now the reply buffer; decode the body into rx */
        rx->rpc = o.rpc;
        rx->rpc.proc = tx->rpc.proc;
        rx->rpc.prog = tx->rpc.prog;
        rx->rpc.vers = tx->rpc.vers;
        rx->type = (rx->rpc.proc<<1)|1;
        wp = o.rpc.data;
        end = wp+o.rpc.ndata;
        ok = sunCallUnpack(prog, wp, end, &wp, rx);
        if(ok != SunSuccess){
                sunErrstr(ok);
                werrstr("unpack: %r");
                free(o.p);
                return -1;
        }

        if(cli->chatty){
                fprint(2, "in: %B\n", &rx->rpc);
                fprint(2, "in:\t%C\n", rx);
        }

        if(tofree == nil)
                free(o.p);
        else
                *tofree = o.p;

        return 0;
}