Subversion Repositories planix.SVN

Rev

Go to most recent revision | Blame | Compare with Previous | Last modification | View Log | RSS feed

#include <u.h>
#include <libc.h>
#include <oventi.h>
#include "session.h"

static char EProtocolBotch[] = "venti protocol botch";
static char ELumpSize[] = "illegal lump size";
static char ENotConnected[] = "not connected to venti server";

static Packet *vtRPC(VtSession *z, int op, Packet *p);

VtSession *
vtClientAlloc(void)
{
        VtSession *z = vtAlloc();       
        return z;
}

VtSession *
vtDial(char *host, int canfail)
{
        VtSession *z;
        int fd;
        char *na;
        char e[ERRMAX];

        if(host == nil) 
                host = getenv("venti");
        if(host == nil)
                host = "$venti";

        if (host == nil) {
                if (!canfail)
                        werrstr("no venti host set");
                na = "";
                fd = -1;
        } else {
                na = netmkaddr(host, 0, "venti");
                fd = dial(na, 0, 0, 0);
        }
        if(fd < 0){
                rerrstr(e, sizeof e);
                if(!canfail){
                        vtSetError("venti dialstring %s: %s", na, e);
                        return nil;
                }
        }
        z = vtClientAlloc();
        if(fd < 0)
                strcpy(z->fderror, e);
        vtSetFd(z, fd);
        return z;
}

int
vtRedial(VtSession *z, char *host)
{
        int fd;
        char *na;

        if(host == nil) 
                host = getenv("venti");
        if(host == nil)
                host = "$venti";

        na = netmkaddr(host, 0, "venti");
        fd = dial(na, 0, 0, 0);
        if(fd < 0){
                vtOSError();
                return 0;
        }
        vtReset(z);
        vtSetFd(z, fd);
        return 1;
}

VtSession *
vtStdioServer(char *server)
{
        int pfd[2];
        VtSession *z;

        if(server == nil)
                return nil;

        if(access(server, AEXEC) < 0) {
                vtOSError();
                return nil;
        }

        if(pipe(pfd) < 0) {
                vtOSError();
                return nil;
        }

        switch(fork()) {
        case -1:
                close(pfd[0]);
                close(pfd[1]);
                vtOSError();
                return nil;
        case 0:
                close(pfd[0]);
                dup(pfd[1], 0);
                dup(pfd[1], 1);
                execl(server, "ventiserver", "-i", nil);
                exits("exec failed");
        }
        close(pfd[1]);

        z = vtClientAlloc();
        vtSetFd(z, pfd[0]);
        return z;
}

int
vtPing(VtSession *z)
{
        Packet *p = packetAlloc();

        p = vtRPC(z, VtQPing, p);
        if(p == nil)
                return 0;
        packetFree(p);
        return 1;
}

int
vtHello(VtSession *z)
{
        Packet *p;
        uchar buf[10];
        char *sid;
        int crypto, codec;

        sid = nil;

        p = packetAlloc();
        if(!vtAddString(p, vtGetVersion(z)))
                goto Err;
        if(!vtAddString(p, vtGetUid(z)))
                goto Err;
        buf[0] = vtGetCryptoStrength(z);
        buf[1] = 0;
        buf[2] = 0;
        packetAppend(p, buf, 3);
        p = vtRPC(z, VtQHello, p);
        if(p == nil)
                return 0;
        if(!vtGetString(p, &sid))
                goto Err;
        if(!packetConsume(p, buf, 2))
                goto Err;
        if(packetSize(p) != 0) {
                vtSetError(EProtocolBotch);
                goto Err;
        }
        crypto = buf[0];
        codec = buf[1];

        USED(crypto);
        USED(codec);

        packetFree(p);

        vtLock(z->lk);
        z->sid = sid;
        z->auth.state = VtAuthOK;
        vtSha1Free(z->inHash);
        z->inHash = nil;
        vtSha1Free(z->outHash);
        z->outHash = nil;
        vtUnlock(z->lk);

        return 1;
Err:
        packetFree(p);
        vtMemFree(sid);
        return 0;
}

int
vtSync(VtSession *z)
{
        Packet *p = packetAlloc();

        p = vtRPC(z, VtQSync, p);
        if(p == nil)
                return 0;
        if(packetSize(p) != 0){
                vtSetError(EProtocolBotch);
                goto Err;
        }
        packetFree(p);
        return 1;

Err:
        packetFree(p);
        return 0;
}

int
vtWrite(VtSession *z, uchar score[VtScoreSize], int type, uchar *buf, int n)
{
        Packet *p = packetAlloc();

        packetAppend(p, buf, n);
        return vtWritePacket(z, score, type, p);
}

int
vtWritePacket(VtSession *z, uchar score[VtScoreSize], int type, Packet *p)
{
        int n = packetSize(p);
        uchar *hdr;

        if(n > VtMaxLumpSize || n < 0) {
                vtSetError(ELumpSize);
                goto Err;
        }
        
        if(n == 0) {
                memmove(score, vtZeroScore, VtScoreSize);
                return 1;
        }

        hdr = packetHeader(p, 4);
        hdr[0] = type;
        hdr[1] = 0;     /* pad */
        hdr[2] = 0;     /* pad */
        hdr[3] = 0;     /* pad */
        p = vtRPC(z, VtQWrite, p);
        if(p == nil)
                return 0;
        if(!packetConsume(p, score, VtScoreSize))
                goto Err;
        if(packetSize(p) != 0) {
                vtSetError(EProtocolBotch);
                goto Err;
        }
        packetFree(p);
        return 1;
Err:
        packetFree(p);
        return 0;
}

int
vtRead(VtSession *z, uchar score[VtScoreSize], int type, uchar *buf, int n)
{
        Packet *p;

        p = vtReadPacket(z, score, type, n);
        if(p == nil)
                return -1;
        n = packetSize(p);
        packetCopy(p, buf, 0, n);
        packetFree(p);
        return n;
}

Packet *
vtReadPacket(VtSession *z, uchar score[VtScoreSize], int type, int n)
{
        Packet *p;
        uchar buf[10];

        if(n < 0 || n > VtMaxLumpSize) {
                vtSetError(ELumpSize);
                return nil;
        }

        p = packetAlloc();
        if(memcmp(score, vtZeroScore, VtScoreSize) == 0)
                return p;

        packetAppend(p, score, VtScoreSize);
        buf[0] = type;
        buf[1] = 0;     /* pad */
        buf[2] = n >> 8;
        buf[3] = n;
        packetAppend(p, buf, 4);
        return vtRPC(z, VtQRead, p);
}


static Packet *
vtRPC(VtSession *z, int op, Packet *p)
{
        uchar *hdr, buf[2];
        char *err;

        if(z == nil){
                vtSetError(ENotConnected);
                return nil;
        }

        /*
         * single threaded for the momment
         */
        vtLock(z->lk);
        if(z->cstate != VtStateConnected){
                vtSetError(ENotConnected);
                goto Err;
        }
        hdr = packetHeader(p, 2);
        hdr[0] = op;    /* op */
        hdr[1] = 0;     /* tid */
        vtDebug(z, "client send: ");
        vtDebugMesg(z, p, "\n");
        if(!vtSendPacket(z, p)) {
                p = nil;
                goto Err;
        }
        p = vtRecvPacket(z);
        if(p == nil)
                goto Err;
        vtDebug(z, "client recv: ");
        vtDebugMesg(z, p, "\n");
        if(!packetConsume(p, buf, 2))
                goto Err;
        if(buf[0] == VtRError) {
                if(!vtGetString(p, &err)) {
                        vtSetError(EProtocolBotch);
                        goto Err;
                }
                vtSetError(err);
                vtMemFree(err);
                packetFree(p);
                vtUnlock(z->lk);
                return nil;
        }
        if(buf[0] != op+1 || buf[1] != 0) {
                vtSetError(EProtocolBotch);
                goto Err;
        }
        vtUnlock(z->lk);
        return p;
Err:
        vtDebug(z, "vtRPC failed: %s\n", vtGetError());
        if(p != nil)
                packetFree(p);
        vtUnlock(z->lk);
        vtDisconnect(z, 1);
        return nil;
}