/*
 * Subversion Repositories planix.SVN
 *
 * Rev
 *
 * Blame | Last modification | View Log | RSS feed
 */

#include <u.h>
#include <libc.h>
#include <oventi.h>
#include "session.h"

struct {
        int version;
        char *s;
} vtVersions[] = {
        VtVersion02, "02",
        0, 0,
};

/* error strings handed to vtSetError by the routines in this file */
static char EBigString[] = "string too long";           /* string exceeds VtMaxStringSize */
static char EBigPacket[] = "packet too long";           /* packet size does not fit the 16-bit length prefix */
static char ENullString[] = "missing string";           /* nil passed where a string is required */
static char EBadVersion[] = "bad format in version string";     /* malformed handshake line */

/* forward declaration */
static Packet *vtRPC(VtSession *z, int op, Packet *p);


/*
 * Allocate and initialise a new session.
 *
 * The session starts with no file descriptor (fd == -1), an empty
 * packet in z->part (used by vtRecvPacket to accumulate input), its
 * three locks, and "anonymous" for both uid and sid.  The inHash and
 * outHash fields are not allocated here; they keep whatever value
 * vtMemAllocZ leaves in them (presumably nil -- confirm).
 *
 * Returns the new session; release it with vtFree.
 */
VtSession *
vtAlloc(void)
{
        VtSession *s;

        s = vtMemAllocZ(sizeof *s);

        s->fd = -1;
        s->uid = vtStrDup("anonymous");
        s->sid = vtStrDup("anonymous");
        s->part = packetAlloc();

        s->lk = vtLockAlloc();
        s->inLock = vtLockAlloc();
        s->outLock = vtLockAlloc();

        return s;
}

/*
 * Put the session back into VtStateAlloc, closing the file
 * descriptor if one is open.  Runs under z->lk.
 */
void
vtReset(VtSession *z)
{
        vtLock(z->lk);
        if(z->fd >= 0){
                vtFdClose(z->fd);
                z->fd = -1;
        }
        z->cstate = VtStateAlloc;
        vtUnlock(z->lk);
}

/*
 * Report whether the session is in the connected state.
 * Returns non-zero if z->cstate is VtStateConnected, 0 otherwise.
 * The state is read without taking z->lk.
 */
int
vtConnected(VtSession *z)
{
        return z->cstate == VtStateConnected;
}

/*
 * Shut the session down and mark it VtStateClosed.
 *
 * If the session is connected, error is zero, and z->vtbl is nil,
 * a two-byte VtQGoodbye packet is sent first as a clean shutdown;
 * the result of that send is ignored.  The packet is handed to
 * vtSendPacket, which frees it.  The file descriptor, if open, is
 * then closed.  Everything after the debug print runs under z->lk.
 */
void
vtDisconnect(VtSession *z, int error)
{
        Packet *bye;
        uchar *hdr;
        int clean;

        vtDebug(z, "vtDisconnect\n");

        vtLock(z->lk);
        clean = !error && z->vtbl == nil && z->cstate == VtStateConnected;
        if(clean){
                bye = packetAlloc();
                hdr = packetHeader(bye, 2);
                hdr[0] = VtQGoodbye;
                hdr[1] = 0;
                vtSendPacket(z, bye);
        }
        if(z->fd >= 0)
                vtFdClose(z->fd);
        z->fd = -1;
        z->cstate = VtStateClosed;
        vtUnlock(z->lk);
}

/*
 * Close the session without signalling an error:
 * equivalent to vtDisconnect(z, 0).
 */
void
vtClose(VtSession *z)
{
        vtDisconnect(z, 0);
}

/*
 * Release a session and everything it owns: the three locks, both
 * SHA1 states, the partial-input packet, the uid and sid strings,
 * and the vtbl.  A nil session is ignored.
 *
 * The file descriptor is not closed here.  Before the structure
 * itself is freed it is zeroed and its fd set to -1, presumably so
 * that a stale pointer sees an obviously dead session.
 */
void
vtFree(VtSession *z)
{
        if(z == nil)
                return;

        /* locks */
        vtLockFree(z->lk);
        vtLockFree(z->inLock);
        vtLockFree(z->outLock);

        /* hashes and buffered input */
        vtSha1Free(z->inHash);
        vtSha1Free(z->outHash);
        packetFree(z->part);

        /* owned memory */
        vtMemFree(z->uid);
        vtMemFree(z->sid);
        vtMemFree(z->vtbl);

        memset(z, 0, sizeof *z);
        z->fd = -1;
        vtMemFree(z);
}

/*
 * Return the session's uid string.  The pointer refers to storage
 * owned by the session (vtFree releases it); it is not a copy.
 */
char *
vtGetUid(VtSession *s)
{
        return s->uid;
}

/*
 * Return the session's sid string.  The pointer refers to storage
 * owned by the session (vtFree releases it); it is not a copy.
 */
char *
vtGetSid(VtSession *z)
{
        return z->sid;
}

/*
 * Set the session's debug flag under z->lk.
 * Returns the value the flag had before the call.
 */
int
vtSetDebug(VtSession *z, int debug)
{
        int prev;

        vtLock(z->lk);
        prev = z->debug;
        z->debug = debug;
        vtUnlock(z->lk);

        return prev;
}

/*
 * Attach file descriptor fd to a session that is still in
 * VtStateAlloc.  Any descriptor the session already holds is
 * closed first.
 *
 * Returns 1 on success.  If the session is in any other state the
 * error is set to "bad state", nothing is changed, and 0 is returned.
 */
int
vtSetFd(VtSession *z, int fd)
{
        int ok;

        vtLock(z->lk);
        ok = z->cstate == VtStateAlloc;
        if(ok){
                if(z->fd >= 0)
                        vtFdClose(z->fd);
                z->fd = fd;
        }else
                vtSetError("bad state");
        vtUnlock(z->lk);

        return ok;
}

/*
 * Return the session's file descriptor; other routines in this file
 * use -1 to mean "none".  Read without taking z->lk.
 */
int
vtGetFd(VtSession *z)
{
        return z->fd;
}

/*
 * Request a crypto strength for a session still in VtStateAlloc.
 * Only VtCryptoStrengthNone is accepted, and accepting it stores
 * nothing in the session.
 *
 * Returns 1 if c is VtCryptoStrengthNone; otherwise sets the error
 * ("bad state" or "not supported yet") and returns 0.
 * The state is checked without taking z->lk.
 */
int
vtSetCryptoStrength(VtSession *z, int c)
{
        if(z->cstate != VtStateAlloc) {
                vtSetError("bad state");
                return 0;
        }
        if(c == VtCryptoStrengthNone)
                return 1;

        vtSetError("not supported yet");
        return 0;
}

/* Return the session's cryptoStrength field. */
int
vtGetCryptoStrength(VtSession *s)
{
        return s->cryptoStrength;
}

/*
 * Set the compression setting of a session still in VtStateAlloc.
 *
 * The parameter is named fd for historical reasons; it is the
 * compression value, and it is stored in z->compression (the field
 * vtGetCompression returns).  It must not be stored in z->fd:
 * doing so would overwrite, and leak, the connection's descriptor.
 *
 * Returns 1 on success.  If the session is in any other state the
 * error is set to "bad state" and 0 is returned.
 */
int
vtSetCompression(VtSession *z, int fd)
{
        vtLock(z->lk);
        if(z->cstate != VtStateAlloc) {
                vtSetError("bad state");
                vtUnlock(z->lk);
                return 0;
        }
        z->compression = fd;
        vtUnlock(z->lk);
        return 1;
}

/* Return the session's compression field. */
int
vtGetCompression(VtSession *s)
{
        return s->compression;
}

/* Return the session's crypto field. */
int
vtGetCrypto(VtSession *s)
{
        return s->crypto;
}

/* Return the session's codec field. */
int
vtGetCodec(VtSession *s)
{
        return s->codec;
}

/*
 * Return the handshake token for the session's protocol version,
 * looked up in vtVersions, or "unknown" if no version has been set
 * (z->version == 0).  A non-zero version missing from the table
 * trips an assertion.
 */
char *
vtGetVersion(VtSession *z)
{
        int i, want;

        want = z->version;
        if(want == 0)
                return "unknown";

        i = 0;
        while(vtVersions[i].version != 0){
                if(vtVersions[i].version == want)
                        return vtVersions[i].s;
                i++;
        }

        /* a version was recorded that the table does not contain */
        assert(0);
        return nil;
}

/*
 * Read the peer's version line from z->fd, one byte at a time, and
 * choose a protocol version from it.  The caller must hold z->inLock.
 *
 * The line must start with prefix, contain no character below ' ',
 * end in '\n', and fit in VtMaxStringSize bytes including the
 * terminator.  After the prefix comes a list of version tokens
 * separated by ':' and ended by '-' or the end of the line.  The
 * first token equal to an entry in vtVersions wins and its version
 * number is stored in *ret.
 *
 * Each byte read is also fed to z->inHash when that is set.
 *
 * Returns 1 on success.  Returns 0 on failure: a malformed line sets
 * EBadVersion, a line with no known version sets its own message,
 * and a failed read returns without setting an error here
 * (presumably vtFdReadFully sets one -- confirm).
 */
static int
vtVersionRead(VtSession *z, char *prefix, int *ret)
{
        char c;
        char buf[VtMaxStringSize];
        char *q, *p, *pp;
        int i, n;

        q = prefix;
        p = buf;
        for(;;) {
                /* leave room for the terminating NUL */
                if(p >= buf + sizeof(buf)) {
                        vtSetError(EBadVersion);
                        return 0;
                }
                if(!vtFdReadFully(z->fd, (uchar*)&c, 1))
                        return 0;
                if(z->inHash)
                        vtSha1Update(z->inHash, (uchar*)&c, 1);
                if(c == '\n') {
                        *p = 0;
                        break;
                }
                /* reject control characters and any mismatch while still inside the prefix */
                if(c < ' ' || *q && c != *q) {
                        vtSetError(EBadVersion);
                        return 0;
                }
                *p++ = c;
                if(*q)
                        q++;
        }

        vtDebug(z, "version string in: %s\n", buf);

        /*
         * The line may have ended before the whole prefix was seen.
         * Skipping strlen(prefix) bytes would then step past the NUL
         * into uninitialised parts of buf, so refuse such lines.
         */
        if(*q != 0) {
                vtSetError(EBadVersion);
                return 0;
        }

        p = buf + strlen(prefix);
        for(;;) {
                /* find the end of the current token */
                for(pp=p; *pp && *pp != ':'  && *pp != '-'; pp++)
                        ;
                n = pp - p;
                for(i=0; vtVersions[i].version; i++) {
                        if(strlen(vtVersions[i].s) != n)
                                continue;
                        if(memcmp(vtVersions[i].s, p, n) == 0) {
                                *ret = vtVersions[i].version;
                                return 1;
                        }
                }
                p = pp;
                if(*p != ':') {
                        /* list exhausted without a version we know */
                        vtSetError("no supported protocol version");
                        return 0;
                }
                p++;
        }
}

/*
 * Receive one length-prefixed packet from a connected session.
 *
 * Input is accumulated in z->part under z->inLock.  The first two
 * bytes give the payload length, big-endian; they are consumed and
 * the following len bytes are split off with packetSplit and
 * returned.  Anything beyond that presumably stays in z->part for
 * the next call -- confirm against packetSplit.
 *
 * Returns the packet, or 0/nil on failure.  "session not connected"
 * is set if the session is not in VtStateConnected (checked before
 * the lock is taken); read failures return nil without setting an
 * error here.
 */
Packet*
vtRecvPacket(VtSession *z)
{
        uchar buf[10], *b;
        int n;
        Packet *p;
        int size, len;

        if(z->cstate != VtStateConnected) {
                vtSetError("session not connected");
                return 0;
        }

        vtLock(z->inLock);
        p = z->part;
        /* read until at least the 2-byte length header is buffered */
        size = packetSize(p);
        while(size < 2) {
                /* ask for up to MaxFragSize bytes, then trim to what actually arrived */
                b = packetTrailer(p, MaxFragSize);
                assert(b != nil);
                n = vtFdRead(z->fd, b, MaxFragSize);
                /*
                 * NOTE(review): on this failure path the trailer just
                 * requested has not been trimmed; presumably it stays
                 * in z->part -- confirm.
                 */
                if(n <= 0)
                        goto Err;
                size += n;
                packetTrim(p, 0, size);
        }

        if(!packetConsume(p, buf, 2))
                goto Err;
        /* payload length, high byte first */
        len = (buf[0] << 8) | buf[1];
        size -= 2;

        /* read the rest of the payload, at most MaxFragSize bytes at a time */
        while(size < len) {
                n = len - size;
                if(n > MaxFragSize)
                        n = MaxFragSize;
                b = packetTrailer(p, n);
                if(!vtFdReadFully(z->fd, b, n))
                        goto Err;
                size += n;
        }
        p = packetSplit(p, len);
        vtUnlock(z->inLock);
        return p;
Err:    
        vtUnlock(z->inLock);
        return nil;     
}

/*
 * Frame packet p with a two-byte big-endian length and write it to
 * z->fd, one fragment at a time.
 *
 * The packet is always freed, on success and on failure, so the
 * caller must not use it afterwards.
 *
 * Returns 1 on success.  Returns 0 if the packet is 1<<16 bytes or
 * larger (EBigPacket is set) or if a write fails.
 */
int
vtSendPacket(VtSession *z, Packet *p)
{
        IOchunk chunk;
        uchar hdr[2];
        int len, got, ok;

        len = packetSize(p);
        if(len >= (1<<16)) {
                vtSetError(EBigPacket);
                packetFree(p);
                return 0;
        }

        /* add framing */
        hdr[0] = len>>8;
        hdr[1] = len;
        packetPrefix(p, hdr, 2);

        /* write the leading fragment, drop it, repeat until nothing is left */
        ok = 1;
        while((got = packetFragments(p, &chunk, 1, 0)) != 0) {
                if(!vtFdWrite(z->fd, chunk.addr, chunk.len)) {
                        ok = 0;
                        break;
                }
                packetConsume(p, nil, got);
        }

        packetFree(p);
        return ok;
}


/*
 * Consume a counted string from the front of packet p: a two-byte
 * big-endian length followed by that many bytes.
 *
 * On success *ret receives a newly allocated, NUL-terminated copy
 * which the caller must free with vtMemFree, and 1 is returned.
 * Returns 0 if the packet is too short or the length exceeds
 * VtMaxStringSize (EBigString is set in that case); *ret is then
 * left untouched.
 */
int
vtGetString(Packet *p, char **ret)
{
        uchar hdr[2];
        char *str;
        int len;

        if(!packetConsume(p, hdr, 2))
                return 0;

        len = hdr[0]<<8 | hdr[1];
        if(len > VtMaxStringSize) {
                vtSetError(EBigString);
                return 0;
        }

        str = vtMemAlloc(len+1);
        /* tag the allocation with our caller's pc */
        setmalloctag(str, getcallerpc(&p));
        if(packetConsume(p, (uchar*)str, len)) {
                str[len] = 0;
                *ret = str;
                return 1;
        }

        vtMemFree(str);
        return 0;
}

/*
 * Append string s to packet p as a counted string: a two-byte
 * big-endian length followed by the bytes of s, without the NUL.
 *
 * Returns 1 on success.  Returns 0, leaving the packet unchanged,
 * if s is nil (ENullString) or longer than VtMaxStringSize
 * (EBigString).
 */
int
vtAddString(Packet *p, char *s)
{
        uchar hdr[2];
        int len;

        if(s == nil) {
                vtSetError(ENullString);
                return 0;
        }

        len = strlen(s);
        if(len > VtMaxStringSize) {
                vtSetError(EBigString);
                return 0;
        }

        hdr[0] = len>>8;
        hdr[1] = len;
        packetAppend(p, hdr, sizeof hdr);
        packetAppend(p, (uchar*)s, len);
        return 1;
}

/*
 * Perform the version handshake on a session in VtStateAlloc and,
 * unless z->vtbl is set, follow it with vtHello.
 *
 * The client sends "venti-" followed by the ':'-separated tokens
 * from vtVersions and "-libventi\n", then reads the peer's line
 * with vtVersionRead, which stores the chosen version in
 * z->version.  The handshake runs with z->lk, z->inLock and
 * z->outLock held.  password is unused.
 *
 * Returns 1 on success.  On failure returns 0; if the failure
 * happened after the state and descriptor checks, the descriptor
 * is closed and the session is left in VtStateClosed.
 */
int
vtConnect(VtSession *z, char *password)
{
        char buf[VtMaxStringSize], *p, *ep, *prefix;
        int i;

        USED(password);
        vtLock(z->lk);
        if(z->cstate != VtStateAlloc) {
                vtSetError("bad session state");
                vtUnlock(z->lk);
                return 0;
        }
        if(z->fd < 0){
                vtSetError("%s", z->fderror);
                vtUnlock(z->lk);
                return 0;
        }

        /* be a little anal */
        vtLock(z->inLock);
        vtLock(z->outLock);

        /* build the version line; every piece goes through seprint so nothing is written past ep */
        prefix = "venti-";
        p = buf;
        ep = buf + sizeof(buf);
        p = seprint(p, ep, "%s", prefix);
        for(i=0; vtVersions[i].version; i++)
                p = seprint(p, ep, "%s%s", i != 0 ? ":" : "", vtVersions[i].s);
        p = seprint(p, ep, "-libventi\n");
        assert(p-buf < sizeof(buf));
        if(z->outHash)
                vtSha1Update(z->outHash, (uchar*)buf, p-buf);
        if(!vtFdWrite(z->fd, (uchar*)buf, p-buf))
                goto Err;

        vtDebug(z, "version string out: %s", buf);

        if(!vtVersionRead(z, prefix, &z->version))
                goto Err;

        vtDebug(z, "version = %d: %s\n", z->version, vtGetVersion(z));

        vtUnlock(z->inLock);
        vtUnlock(z->outLock);
        z->cstate = VtStateConnected;
        vtUnlock(z->lk);

        if(z->vtbl)
                return 1;

        if(!vtHello(z)) {
                /*
                 * All three locks were released above, so the Err
                 * path (which unlocks them) must not be used here.
                 * Tear the session down under z->lk alone.
                 */
                vtLock(z->lk);
                if(z->fd >= 0)
                        vtFdClose(z->fd);
                z->fd = -1;
                z->cstate = VtStateClosed;
                vtUnlock(z->lk);
                return 0;
        }
        return 1;

Err:
        /* reached only with z->lk, z->inLock and z->outLock held */
        if(z->fd >= 0)
                vtFdClose(z->fd);
        z->fd = -1;
        vtUnlock(z->inLock);
        vtUnlock(z->outLock);
        z->cstate = VtStateClosed;
        vtUnlock(z->lk);
        return 0;
}