Subversion Repositories planix.SVN

Rev

Rev 2 | Blame | Compare with Previous | Last modification | View Log | RSS feed

#include <u.h>
#include <libc.h>
#include <oventi.h>
#include "session.h"

static char EAuthState[] = "bad authentication state";
static char ENotServer[] = "not a server session";
static char EVersion[] = "incorrect version number";
static char EProtocolBotch[] = "venti protocol botch";

VtSession *
vtServerAlloc(VtServerVtbl *vtbl)
{
        VtSession *z = vtAlloc();
        z->vtbl = vtMemAlloc(sizeof(VtServerVtbl));
        setmalloctag(z->vtbl, getcallerpc(&vtbl));
        *z->vtbl = *vtbl;
        return z;
}

static int
srvHello(VtSession *z, char *version, char *uid, int , uchar *, int , uchar *, int )
{
        vtLock(z->lk);
        if(z->auth.state != VtAuthHello) {
                vtSetError(EAuthState);
                goto Err;
        }
        if(strcmp(version, vtGetVersion(z)) != 0) {
                vtSetError(EVersion);
                goto Err;
        }
        vtMemFree(z->uid);
        z->uid = vtStrDup(uid);
        z->auth.state = VtAuthOK;
        vtUnlock(z->lk);
        return 1;
Err:
        z->auth.state = VtAuthFailed;
        vtUnlock(z->lk);
        return 0;
}


static int
dispatchHello(VtSession *z, Packet **pkt)
{
        char *version, *uid;
        uchar *crypto, *codec;
        uchar buf[10];
        int ncrypto, ncodec, cryptoStrength;
        int ret;
        Packet *p;

        p = *pkt;

        version = nil;  
        uid = nil;
        crypto = nil;
        codec = nil;

        ret = 0;
        if(!vtGetString(p, &version))
                goto Err;
        if(!vtGetString(p, &uid))
                goto Err;
        if(!packetConsume(p, buf, 2))
                goto Err;
        cryptoStrength = buf[0];
        ncrypto = buf[1];
        crypto = vtMemAlloc(ncrypto);
        if(!packetConsume(p, crypto, ncrypto))
                goto Err;

        if(!packetConsume(p, buf, 1))
                goto Err;
        ncodec = buf[0];
        codec = vtMemAlloc(ncodec);
        if(!packetConsume(p, codec, ncodec))
                goto Err;

        if(packetSize(p) != 0) {
                vtSetError(EProtocolBotch);
                goto Err;
        }
        if(!srvHello(z, version, uid, cryptoStrength, crypto, ncrypto, codec, ncodec)) {
                packetFree(p);
                *pkt = nil;
        } else {
                if(!vtAddString(p, vtGetSid(z)))
                        goto Err;
                buf[0] = vtGetCrypto(z);
                buf[1] = vtGetCodec(z);
                packetAppend(p, buf, 2);
        }
        ret = 1;
Err:
        vtMemFree(version);
        vtMemFree(uid);
        vtMemFree(crypto);
        vtMemFree(codec);
        return ret;
}

static int
dispatchRead(VtSession *z, Packet **pkt)
{
        Packet *p;
        int type, n;
        uchar score[VtScoreSize], buf[4];

        p = *pkt;
        if(!packetConsume(p, score, VtScoreSize))
                return 0;
        if(!packetConsume(p, buf, 4))
                return 0;
        type = buf[0];
        n = (buf[2]<<8) | buf[3];
        if(packetSize(p) != 0) {
                vtSetError(EProtocolBotch);
                return 0;
        }
        packetFree(p);
        *pkt = (*z->vtbl->read)(z, score, type, n);
        return 1;
}

static int
dispatchWrite(VtSession *z, Packet **pkt)
{
        Packet *p;
        int type;
        uchar score[VtScoreSize], buf[4];

        p = *pkt;
        if(!packetConsume(p, buf, 4))
                return 0;
        type = buf[0];
        if(!(z->vtbl->write)(z, score, type, p)) {
                *pkt = 0;
        } else {
                *pkt = packetAlloc();
                packetAppend(*pkt, score, VtScoreSize);
        }
        return 1;
}

static int
dispatchSync(VtSession *z, Packet **pkt)
{
        (z->vtbl->sync)(z);
        if(packetSize(*pkt) != 0) {
                vtSetError(EProtocolBotch);
                return 0;
        }
        return 1;
}

int
vtExport(VtSession *z)
{
        Packet *p;
        uchar buf[10], *hdr;
        int op, tid, clean;

        if(z->vtbl == nil) {
                vtSetError(ENotServer);
                return 0;
        }

        /* fork off slave */
        switch(rfork(RFNOWAIT|RFMEM|RFPROC)){
        case -1:
                vtOSError();
                return 0;
        case 0:
                break;
        default:
                return 1;
        }

        
        p = nil;
        clean = 0;
        vtAttach();
        if(!vtConnect(z, nil))
                goto Exit;

        vtDebug(z, "server connected!\n");
if(0)   vtSetDebug(z, 1);

        for(;;) {
                p = vtRecvPacket(z);
                if(p == nil) {
                        break;
                }
                vtDebug(z, "server recv: ");
                vtDebugMesg(z, p, "\n");

                if(!packetConsume(p, buf, 2)) {
                        vtSetError(EProtocolBotch);
                        break;
                }
                op = buf[0];
                tid = buf[1];
                switch(op) {
                default:
                        vtSetError(EProtocolBotch);
                        goto Exit;
                case VtQPing:
                        break;
                case VtQGoodbye:
                        clean = 1;
                        goto Exit;
                case VtQHello:
                        if(!dispatchHello(z, &p))
                                goto Exit;
                        break;
                case VtQRead:
                        if(!dispatchRead(z, &p))
                                goto Exit;
                        break;
                case VtQWrite:
                        if(!dispatchWrite(z, &p))
                                goto Exit;
                        break;
                case VtQSync:
                        if(!dispatchSync(z, &p))
                                goto Exit;
                        break;
                }
                if(p != nil) {
                        hdr = packetHeader(p, 2);
                        hdr[0] = op+1;
                        hdr[1] = tid;
                } else {
                        p = packetAlloc();
                        hdr = packetHeader(p, 2);
                        hdr[0] = VtRError;
                        hdr[1] = tid;
                        if(!vtAddString(p, vtGetError()))
                                goto Exit;
                }

                vtDebug(z, "server send: ");
                vtDebugMesg(z, p, "\n");

                if(!vtSendPacket(z, p)) {
                        p = nil;
                        goto Exit;
                }
        }
Exit:
        if(p != nil)
                packetFree(p);
        if(z->vtbl->closing)
                z->vtbl->closing(z, clean);
        vtClose(z);
        vtFree(z);
        vtDetach();

        exits(0);
        return 0;       /* never gets here */
}