Subversion Repositories planix.SVN

Rev

Rev 26 | Blame | Compare with Previous | Last modification | View Log | RSS feed

#include "os.h"
#include <mp.h>
#include <libsec.h>
#include <ctype.h>

/*
 * Jacobian-coordinate helpers defined outside this file.  The point
 * arithmetic below (ecassign, ecadd, ecdsaverify) uses them as follows:
 *   jacobian_affine  - given (X, Y, Z), leaves affine x, y in X, Y
 *                      (callers discard Z afterwards);
 *   jacobian_dbl     - (X3, Y3, Z3) from the point (X1, Y1, Z1);
 *                      used when doubling;
 *   jacobian_add     - (X3, Y3, Z3) from (X1, Y1, Z1) and (X2, Y2, Z2).
 * p is the field prime and a the curve coefficient.
 * NOTE(review): exact semantics (in-place vs. aliasing, handling of Z = 0)
 * live in the implementing file; confirm there.
 */
extern void jacobian_affine(mpint *p,
        mpint *X, mpint *Y, mpint *Z);
extern void jacobian_dbl(mpint *p, mpint *a,
        mpint *X1, mpint *Y1, mpint *Z1,
        mpint *X3, mpint *Y3, mpint *Z3);
extern void jacobian_add(mpint *p, mpint *a,
        mpint *X1, mpint *Y1, mpint *Z1,
        mpint *X2, mpint *Y2, mpint *Z2,
        mpint *X3, mpint *Y3, mpint *Z3);

/*
 * ecassign copies point a into b, converting between coordinate systems
 * as dictated by b: if b carries a Z coordinate the copy is Jacobian
 * (an affine source gets Z = 1); if b->z is nil the copy is affine
 * (a Jacobian source is converted through jacobian_affine).
 */
void
ecassign(ECdomain *dom, ECpoint *a, ECpoint *b)
{
        b->inf = a->inf;
        if(b->inf)
                return;
        mpassign(a->x, b->x);
        mpassign(a->y, b->y);
        if(b->z == nil){
                /* affine destination: nothing more to do for an affine source */
                if(a->z == nil)
                        return;
                /* Jacobian source: borrow a Z, convert x and y, then drop it */
                b->z = mpcopy(a->z);
                jacobian_affine(dom->p, b->x, b->y, b->z);
                mpfree(b->z);
                b->z = nil;
                return;
        }
        /* Jacobian destination: an affine source is (x, y, 1) */
        mpassign(a->z == nil ? mpone : a->z, b->z);
}

/*
 * ecadd sets s = a + b on the curve described by dom.
 *
 * An input whose z is nil is taken to be affine (Z = 1).  The result is
 * left in Jacobian form when s->z != nil and in affine form when
 * s->z == nil.  s->inf is set when the sum is the point at infinity.
 *
 * Callers in this file pass s aliased to a and/or b (ecmul, ecdsaverify);
 * that relies on jacobian_add/jacobian_dbl, defined elsewhere, tolerating
 * outputs that alias their inputs.
 * NOTE(review): confirm that property in the jacobian_* sources.
 */
void
ecadd(ECdomain *dom, ECpoint *a, ECpoint *b, ECpoint *s)
{
        /* infinity is the identity: O + O = O, O + P = P, P + O = P */
        if(a->inf && b->inf){
                s->inf = 1;
                return;
        }
        if(a->inf){
                ecassign(dom, b, s);
                return;
        }
        if(b->inf){
                ecassign(dom, a, s);
                return;
        }

        /*
         * Affine destination: give s a temporary Z = 1, compute the Jacobian
         * sum by recursing, then convert back to affine and drop Z.  (If s
         * aliases a or b, their Z also becomes 1 meanwhile, which still
         * describes the same affine point.)
         */
        if(s->z == nil){
                s->z = mpcopy(mpone);
                ecadd(dom, a, b, s);
                if(!s->inf)
                        jacobian_affine(dom->p, s->x, s->y, s->z);
                mpfree(s->z);
                s->z = nil;
                return;
        }

        /*
         * a == b compares pointers: doubling is chosen only when the same
         * object is passed twice.  Equal points held in distinct objects go
         * through jacobian_add.
         * NOTE(review): confirm jacobian_add handles P + P and P + (-P).
         */
        if(a == b)
                jacobian_dbl(dom->p, dom->a,
                        a->x, a->y, a->z != nil ? a->z : mpone,
                        s->x, s->y, s->z);
        else
                jacobian_add(dom->p, dom->a,
                        a->x, a->y, a->z != nil ? a->z : mpone,
                        b->x, b->y, b->z != nil ? b->z : mpone,
                        s->x, s->y, s->z);
        /* a Jacobian Z of zero marks the point at infinity */
        s->inf = mpcmp(s->z, mpzero) == 0;
}

/*
 * ecmul sets s = k·a using right-to-left binary double-and-add.
 * A negative k multiplies by |k| and negates the result.  If a is the
 * point at infinity or k is zero, s becomes the point at infinity.
 * The result is copied into s with ecassign, so it is affine when
 * s->z == nil and Jacobian otherwise.
 *
 * NOTE: the loop branches on every bit of k, so the running time depends
 * on the scalar; this is not a constant-time implementation.
 */
void
ecmul(ECdomain *dom, ECpoint *a, mpint *k, ECpoint *s)
{
        ECpoint ns, na;
        mpint *l;

        if(a->inf || mpcmp(k, mpzero) == 0){
                s->inf = 1;
                return;
        }
        /* ns accumulates the result (starts at infinity); na holds 2^i·a */
        ns.inf = 1;
        ns.x = mpnew(0);
        ns.y = mpnew(0);
        ns.z = mpnew(0);
        na.x = mpnew(0);
        na.y = mpnew(0);
        na.z = mpnew(0);
        ecassign(dom, a, &na);
        l = mpcopy(k);
        l->sign = 1;            /* work with |k|; the sign is applied below */
        for(;;){
                if(l->p[0] & 1)
                        ecadd(dom, &na, &ns, &ns);      /* ns += 2^i·a */
                mpright(l, 1, l);
                /* stop before the final doubling, whose result would be unused */
                if(mpcmp(l, mpzero) == 0)
                        break;
                ecadd(dom, &na, &na, &na);              /* na = 2^(i+1)·a */
        }
        if(k->sign < 0 && !ns.inf){
                /* -(X, Y, Z) = (X, -Y, Z) */
                ns.y->sign = -1;
                mpmod(ns.y, dom->p, ns.y);
        }
        ecassign(dom, &ns, s);
        mpfree(ns.x);
        mpfree(ns.y);
        mpfree(ns.z);
        mpfree(na.x);
        mpfree(na.y);
        mpfree(na.z);
        mpfree(l);
}

/*
 * ecverify reports whether the affine point a lies on the curve
 * y^2 = x^3 + a*x + b over GF(p) described by dom.  The point at
 * infinity is accepted.  Coordinates must be canonical field elements:
 * an x or y outside [0, p) is rejected rather than silently reduced, so
 * non-canonical encodings of a point do not verify (cf. SEC 1 public key
 * validation).  Returns 1 if the point is valid, 0 otherwise.
 */
int
ecverify(ECdomain *dom, ECpoint *a)
{
        mpint *p, *q;
        int r;

        if(a->inf)
                return 1;

        assert(a->z == nil);    /* need affine coordinates */

        if(mpcmp(a->x, mpzero) < 0 || mpcmp(a->x, dom->p) >= 0
        || mpcmp(a->y, mpzero) < 0 || mpcmp(a->y, dom->p) >= 0)
                return 0;

        p = mpnew(0);
        q = mpnew(0);
        mpmodmul(a->y, a->y, dom->p, p);        /* p = y^2 */
        mpmodmul(a->x, a->x, dom->p, q);        /* q = x^2 */
        mpmodadd(q, dom->a, dom->p, q);         /* q = x^2 + a */
        mpmodmul(q, a->x, dom->p, q);           /* q = x^3 + a*x */
        mpmodadd(q, dom->b, dom->p, q);         /* q = x^3 + a*x + b */
        r = mpcmp(p, q);
        mpfree(p);
        mpfree(q);
        return r == 0;
}

/*
 * ecpubverify validates a public key: it must not be the point at
 * infinity, must satisfy ecverify, and n·a must be the point at
 * infinity (n = dom->n).  Returns 1 if all checks pass, 0 otherwise.
 */
int
ecpubverify(ECdomain *dom, ECpub *a)
{
        ECpoint t;
        int ok;

        if(a->inf || !ecverify(dom, a))
                return 0;

        t.x = mpnew(0);
        t.y = mpnew(0);
        t.z = mpnew(0);
        ecmul(dom, a, dom->n, &t);
        ok = t.inf;
        mpfree(t.z);
        mpfree(t.y);
        mpfree(t.x);
        return ok;
}

/*
 * fixnibble converts, in place, an ASCII hex digit ('0'-'9', 'a'-'f',
 * 'A'-'F') to its value 0..15.  The caller must already have checked
 * that *a is a hex digit.
 */
static void
fixnibble(uchar *a)
{
        uchar base;

        if(*a >= 'a')
                base = 'a' - 10;
        else if(*a >= 'A')
                base = 'A' - 10;
        else
                base = '0';
        *a -= base;
}

/*
 * octet decodes two hex digits at *s into a byte value 0..255 and
 * advances *s past them.  Returns -1, leaving *s unchanged, if either
 * character is not a hex digit.  The second character is only read
 * once the first is known to be a hex digit, so the scan never steps
 * past the string's NUL terminator.
 */
static int
octet(char **s)
{
        uchar c, d;

        c = (*s)[0];
        if(!isxdigit(c))
                return -1;
        d = (*s)[1];
        if(!isxdigit(d))
                return -1;
        *s += 2;
        fixnibble(&c);
        fixnibble(&d);
        return (c << 4) | d;
}

/*
 * halfpt parses one fixed-width hex coordinate from s: exactly
 * 2*ceil(bits(p)/8) hex digits, where p = dom->p.  The value is stored in
 * out (or a new mpint when out is nil) and *rptr is set just past the
 * digits.  Returns nil, leaving *rptr untouched, if s is too short,
 * contains anything but hex digits in that field, or cannot be allocated.
 */
static mpint*
halfpt(ECdomain *dom, char *s, char **rptr, mpint *out)
{
        char *buf, *r;
        int i, n;
        mpint *ret;

        n = ((mpsignif(dom->p)+7)/8)*2;
        if(strlen(s) < n)
                return nil;
        /*
         * strtomp would skip blanks, accept a sign and stop at the first
         * non-digit; a fixed-width field must be all hex digits.
         */
        for(i = 0; i < n; i++)
                if(!isxdigit((uchar)s[i]))
                        return nil;
        buf = malloc(n+1);
        if(buf == nil)
                return nil;
        memcpy(buf, s, n);
        buf[n] = 0;
        ret = strtomp(buf, &r, 16, out);
        /* r is only set by strtomp on success */
        if(ret != nil)
                *rptr = s + (r - buf);
        free(buf);
        return ret;
}

/*
 * mpleg returns the Jacobi symbol (a/b) for odd b > 1: 1, -1, or 0 when
 * gcd(a, b) != 1.  For prime b this is the Legendre symbol, i.e. whether
 * a is a quadratic residue mod b.
 *
 * Each round reduces m mod n and strips factors of two from m, using
 * (2/n) = -1 iff n = 3 or 5 (mod 8).  If the remaining odd part is 1 the
 * answer is r.  Otherwise it swaps m and n, applying quadratic
 * reciprocity (sign flips iff both are 3 mod 4).
 */
static int
mpleg(mpint *a, mpint *b)
{
        int r, k;
        mpint *m, *n, *t;

        r = 1;
        m = mpcopy(a);
        n = mpcopy(b);
        for(;;){
                /* >= so that m == n reduces to 0 instead of looping forever */
                if(m->sign < 0 || mpcmp(m, n) >= 0)
                        mpmod(m, n, m);
                if(mpcmp(m, mpzero) == 0){
                        r = 0;
                        break;
                }
                k = mplowbits0(m);
                if(k > 0){
                        if(k & 1)
                                switch(n->p[0] & 7){
                                case 3: case 5:
                                        r = -r;
                                }
                        mpright(m, k, m);
                }
                /*
                 * Check for 1 after removing the twos: swapping with m == 1
                 * would leave n == 1, and m mod 1 == 0 would wrongly give 0.
                 */
                if(mpcmp(m, mpone) == 0)
                        break;
                if((n->p[0] & 3) == 3 && (m->p[0] & 3) == 3)
                        r = -r;
                t = m;
                m = n;
                n = t;
        }
        mpfree(m);
        mpfree(n);
        return r;
}

/*
 * mpsqrt computes a square root of n modulo the odd prime p by Cipolla's
 * algorithm.  The root is stored in r, which may alias n.  Returns 1 on
 * success and 0 if n is not a quadratic residue mod p.
 */
static int
mpsqrt(mpint *n, mpint *p, mpint *r)
{
        mpint *a, *t, *s, *xp, *xq, *yp, *yq, *zp, *zq, *N;
        int ok;

        if(mpleg(n, p) == -1)
                return 0;
        a = mpnew(0);
        t = mpnew(0);
        s = mpnew(0);
        N = mpnew(0);
        xp = mpnew(0);
        xq = mpnew(0);
        yp = mpnew(0);
        yq = mpnew(0);
        zp = mpnew(0);
        zq = mpnew(0);
        ok = 0;

        /*
         * sqrt(0) = 0.  Handle it up front: for n = 0 (mod p) every a*a - n
         * is a square, so the non-residue search below would never end.
         */
        mpmod(n, p, t);
        if(mpcmp(t, mpzero) == 0){
                mpassign(mpzero, r);
                ok = 1;
                goto out;
        }

        /* find a in [1, p) with t = a*a - n a quadratic non-residue */
        for(;;){
                for(;;){
                        mpnrand(p, genrandom, a);
                        if(mpcmp(a, mpzero) > 0)
                                break;
                }
                mpmul(a, a, t);
                mpsub(t, n, t);
                mpmod(t, p, t);
                if(mpleg(t, p) == -1)
                        break;
        }

        /*
         * Raise (a + w), where w*w = t, to N = (p+1)/2 in GF(p^2) by square
         * and multiply.  Elements are x = xp + xq*w, result y = yp + yq*w.
         */
        mpadd(p, mpone, N);
        mpright(N, 1, N);
        mpmul(a, a, t);
        mpsub(t, n, t);
        mpassign(a, xp);
        uitomp(1, xq);
        uitomp(1, yp);
        uitomp(0, yq);
        while(mpcmp(N, mpzero) != 0){
                if(N->p[0] & 1){
                        /* y *= x */
                        mpmul(xp, yp, zp);
                        mpmul(xq, yq, zq);
                        mpmul(zq, t, zq);
                        mpadd(zp, zq, zp);
                        mpmod(zp, p, zp);
                        mpmul(xp, yq, zq);
                        mpmul(xq, yp, s);
                        mpadd(zq, s, zq);
                        mpmod(zq, p, yq);
                        mpassign(zp, yp);
                }
                /* x *= x */
                mpmul(xp, xp, zp);
                mpmul(xq, xq, zq);
                mpmul(zq, t, zq);
                mpadd(zp, zq, zp);
                mpmod(zp, p, zp);
                mpmul(xp, xq, zq);
                mpadd(zq, zq, zq);
                mpmod(zq, p, xq);
                mpassign(zp, xp);
                mpright(N, 1, N);
        }

        /*
         * y squares to n.  Its w-component is zero exactly when n is a square
         * in GF(p).  A non-zero component means n was not a residue after
         * all: report failure instead of aborting the process.
         */
        if(mpcmp(yq, mpzero) != 0)
                goto out;
        mpassign(yp, r);
        ok = 1;
out:
        mpfree(a);
        mpfree(t);
        mpfree(s);
        mpfree(N);
        mpfree(xp);
        mpfree(xq);
        mpfree(yp);
        mpfree(yq);
        mpfree(zp);
        mpfree(zq);
        return ok;
}

/*
 * ecverifyxy runs ecverify on the affine x, y of p regardless of p->z.
 * strtoec may be handed a point whose z is allocated, but ecverify
 * asserts z == nil.
 */
static int
ecverifyxy(ECdomain *dom, ECpoint *p)
{
        ECpoint t;

        t = *p;
        t.z = nil;
        return ecverify(dom, &t);
}

/*
 * strtoec parses a hex-encoded point: a one-byte prefix followed by
 * fixed-width coordinates.  The prefixes are:
 *   00  the point at infinity
 *   02  compressed x with even y
 *   03  compressed x with odd y
 *   04  uncompressed x then y
 * The result goes into ret, or into a newly allocated point if ret is nil.
 * If ret->z is non-nil it is set to 1, giving Jacobian (x, y, 1).
 *
 * On success returns the point and, if rptr is non-nil, sets *rptr just
 * past the encoding.  On failure returns nil and, if rptr is non-nil,
 * sets *rptr where parsing stopped.
 */
ECpoint*
strtoec(ECdomain *dom, char *s, char **rptr, ECpoint *ret)
{
        int allocd, o;
        mpint *r;

        allocd = 0;
        if(ret == nil){
                allocd = 1;
                ret = mallocz(sizeof(*ret), 1);
                if(ret == nil)
                        return nil;
                ret->x = mpnew(0);
                ret->y = mpnew(0);
        }
        ret->inf = 0;
        o = 0;
        switch(octet(&s)){
        case 0:
                ret->inf = 1;
                break;
        case 3:
                o = 1;          /* odd y requested */
                /* fall through */
        case 2:
                if(halfpt(dom, s, &s, ret->x) == nil)
                        goto err;
                /* y^2 = x^3 + a*x + b; mpsqrt reduces mod p */
                r = mpnew(0);
                mpmul(ret->x, ret->x, r);
                mpadd(r, dom->a, r);
                mpmul(r, ret->x, r);
                mpadd(r, dom->b, r);
                if(!mpsqrt(r, dom->p, r)){
                        mpfree(r);
                        goto err;
                }
                if(mpcmp(r, mpzero) == 0){
                        /* y = 0 is even; an odd-y prefix cannot encode it */
                        if(o){
                                mpfree(r);
                                goto err;
                        }
                }else if((r->p[0] & 1) != o)
                        mpsub(dom->p, r, r);    /* take the other root, p - y */
                mpassign(r, ret->y);
                mpfree(r);
                if(!ecverifyxy(dom, ret))
                        goto err;
                break;
        case 4:
                if(halfpt(dom, s, &s, ret->x) == nil)
                        goto err;
                if(halfpt(dom, s, &s, ret->y) == nil)
                        goto err;
                if(!ecverifyxy(dom, ret))
                        goto err;
                break;
        default:
                /* unknown prefix, or not a pair of hex digits */
                goto err;
        }
        if(ret->z != nil && !ret->inf)
                mpassign(mpone, ret->z);
        if(rptr)
                *rptr = s;
        return ret;

err:
        if(rptr)
                *rptr = s;
        if(allocd){
                mpfree(ret->x);
                mpfree(ret->y);
                free(ret);
        }
        return nil;
}

/*
 * ecgen generates a key pair on dom.  The private scalar d is drawn by
 * mpnrand from [0, n), retrying until it is non-zero.  The public point is
 * d·G.  Fills p, or allocates a fresh ECpriv when p is nil.  Returns the
 * key, or nil if allocation fails.
 */
ECpriv*
ecgen(ECdomain *dom, ECpriv *p)
{
        if(p == nil){
                if((p = mallocz(sizeof(*p), 1)) == nil)
                        return nil;
                p->x = mpnew(0);
                p->y = mpnew(0);
                p->d = mpnew(0);
        }
        do
                mpnrand(dom->n, genrandom, p->d);
        while(mpcmp(p->d, mpzero) <= 0);
        ecmul(dom, &dom->G, p->d, p);
        return p;
}

/*
 * ecdsasign produces an ECDSA signature (r, s) over the len-byte digest
 * dig with private key priv.  The digest is truncated to the bit length
 * of n.  A fresh nonce is generated with ecgen until both r and s are
 * non-zero.
 */
void
ecdsasign(ECdomain *dom, ECpriv *priv, uchar *dig, int len, mpint *r, mpint *s)
{
        ECpriv k;               /* nonce k (in k.d) and the point k·G */
        mpint *e, *kinv;
        int excess;

        k.x = mpnew(0);
        k.y = mpnew(0);
        k.z = nil;
        k.d = mpnew(0);
        e = betomp(dig, len, nil);
        kinv = mpnew(0);
        excess = 8*len - mpsignif(dom->n);
        if(excess > 0)
                mpright(e, excess, e);
        for(;;){
                ecgen(dom, &k);
                mpmod(k.x, dom->n, r);
                if(mpcmp(r, mpzero) != 0){
                        /* s = (e + r*d) / k mod n */
                        mpmul(r, priv->d, s);
                        mpadd(e, s, s);
                        mpinvert(k.d, dom->n, kinv);
                        mpmodmul(s, kinv, dom->n, s);
                        if(mpcmp(s, mpzero) != 0)
                                break;
                }
        }
        mpfree(kinv);
        mpfree(e);
        mpfree(k.x);
        mpfree(k.y);
        mpfree(k.d);
}

/*
 * ecdsaverify checks an ECDSA signature (r, s) over the len-byte digest
 * dig against public key pub.  Both r and s must lie in [1, n-1].  It
 * computes R = (e/s)·G + (r/s)·pub and accepts iff R is not the point at
 * infinity and x(R) mod n == r.  Returns 1 if the signature is valid,
 * 0 otherwise.
 */
int
ecdsaverify(ECdomain *dom, ECpub *pub, uchar *dig, int len, mpint *r, mpint *s)
{
        mpint *E, *t, *u1, *u2;
        ECpoint R, S;
        int ret;

        /* reject out-of-range components (s was previously unchecked above) */
        if(mpcmp(r, mpone) < 0 || mpcmp(r, dom->n) >= 0
        || mpcmp(s, mpone) < 0 || mpcmp(s, dom->n) >= 0)
                return 0;
        E = betomp(dig, len, nil);
        if(mpsignif(dom->n) < 8*len)
                mpright(E, 8*len - mpsignif(dom->n), E);
        t = mpnew(0);
        u1 = mpnew(0);
        u2 = mpnew(0);
        R.x = mpnew(0);
        R.y = mpnew(0);
        R.z = mpnew(0);
        S.x = mpnew(0);
        S.y = mpnew(0);
        S.z = mpnew(0);
        mpinvert(s, dom->n, t);                 /* t = 1/s mod n */
        mpmodmul(E, t, dom->n, u1);             /* u1 = e/s */
        mpmodmul(r, t, dom->n, u2);             /* u2 = r/s */
        ecmul(dom, &dom->G, u1, &R);
        ecmul(dom, pub, u2, &S);
        ecadd(dom, &R, &S, &R);
        ret = 0;
        if(!R.inf){
                jacobian_affine(dom->p, R.x, R.y, R.z);
                mpmod(R.x, dom->n, t);
                ret = mpcmp(r, t) == 0;
        }
        mpfree(E);
        mpfree(t);
        mpfree(u1);
        mpfree(u2);
        mpfree(R.x);
        mpfree(R.y);
        mpfree(R.z);
        mpfree(S.x);
        mpfree(S.y);
        mpfree(S.z);
        return ret;
}

/*
 * Base58 digit alphabet: index i is the character for digit value i.
 * 0, O, I and l are omitted.
 */
static char *code = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz";

/*
 * base58enc writes the base58 encoding of the len bytes at src into dst.
 * The encoding is most significant digit first, and each leading zero
 * byte becomes a leading code[0].  dst is not NUL-terminated; the caller
 * must provide room for the digits.
 */
void
base58enc(uchar *src, char *dst, int len)
{
        mpint *n, *r, *b;
        char t;
        int i, nd;

        n = betomp(src, len, nil);
        b = uitomp(58, nil);
        r = mpnew(0);
        nd = 0;
        /* digits come out least significant first */
        while(mpcmp(n, mpzero) != 0){
                mpdiv(n, b, n, r);
                dst[nd++] = code[mptoui(r)];
        }
        /* bounded by len: an all-zero input must not run off the buffer */
        for(i = 0; i < len && src[i] == 0; i++)
                dst[nd++] = code[0];
        /* reverse into most-significant-first order */
        for(i = 0; i < nd/2; i++){
                t = dst[i];
                dst[i] = dst[nd-1-i];
                dst[nd-1-i] = t;
        }
        mpfree(n);
        mpfree(r);
        mpfree(b);
}

/*
 * base58dec decodes the NUL-terminated base58 string src into len bytes
 * at dst, stored with mptober.  Returns 0 on success.  Returns -1 and sets
 * the error string on an invalid character, or when the value does not
 * fit in len bytes (previously it was silently truncated).
 */
int
base58dec(char *src, uchar *dst, int len)
{
        mpint *n, *b, *r;
        char *t;
        int rv;

        n = mpnew(0);
        r = mpnew(0);
        b = uitomp(58, nil);
        rv = -1;
        for(; *src; src++){
                t = strchr(code, *src);
                if(t == nil){
                        werrstr("invalid base58 char");
                        goto out;
                }
                /* n = n*58 + digit */
                uitomp(t - code, r);
                mpmul(n, b, n);
                mpadd(n, r, n);
        }
        if(mpsignif(n) > 8*len){
                werrstr("base58 value too large");
                goto out;
        }
        mptober(n, dst, len);
        rv = 0;
out:
        mpfree(n);
        mpfree(r);
        mpfree(b);
        return rv;
}

/*
 * ecdominit zeroes dom and allocates its parameters.  If init is non-nil
 * it fills in p, a, b, G = (x, y), n and h, after which the prime p is
 * replaced by mpfield(p).
 */
void
ecdominit(ECdomain *dom, void (*init)(mpint *p, mpint *a, mpint *b, mpint *x, mpint *y, mpint *n, mpint *h))
{
        memset(dom, 0, sizeof(*dom));
        /* field and curve coefficients */
        dom->p = mpnew(0);
        dom->a = mpnew(0);
        dom->b = mpnew(0);
        /* base point, its order and the cofactor */
        dom->G.x = mpnew(0);
        dom->G.y = mpnew(0);
        dom->n = mpnew(0);
        dom->h = mpnew(0);
        if(init == nil)
                return;
        (*init)(dom->p, dom->a, dom->b, dom->G.x, dom->G.y, dom->n, dom->h);
        dom->p = mpfield(dom->p);
}

/*
 * ecdomfree releases the parameters allocated by ecdominit and zeroes dom.
 */
void
ecdomfree(ECdomain *dom)
{
        mpfree(dom->G.x);
        mpfree(dom->G.y);
        mpfree(dom->n);
        mpfree(dom->h);
        mpfree(dom->p);
        mpfree(dom->a);
        mpfree(dom->b);
        memset(dom, 0, sizeof(*dom));
}

/*
 * ecencodepub writes pub in uncompressed form, 0x04 || X || Y, with each
 * coordinate ceil(bits(p)/8) bytes, into data.  Returns the number of
 * bytes written, or 0 if len is too small.
 */
int
ecencodepub(ECdomain *dom, ECpub *pub, uchar *data, int len)
{
        int fb, need;

        fb = (mpsignif(dom->p)+7)/8;
        need = 1 + 2*fb;
        if(len < need)
                return 0;
        *data++ = 0x04;
        mptober(pub->x, data, fb);
        mptober(pub->y, data + fb, fb);
        return need;
}

/*
 * ecdecodepub parses an uncompressed public key (0x04 || X || Y) of
 * exactly 1 + 2*ceil(bits(p)/8) bytes.  It returns a newly allocated key
 * that has passed ecpubverify, or nil.
 */
ECpub*
ecdecodepub(ECdomain *dom, uchar *data, int len)
{
        ECpub *pub;
        int fb;

        fb = (mpsignif(dom->p)+7)/8;
        if(len != 1 + 2*fb || *data != 0x04)
                return nil;
        if((pub = mallocz(sizeof(*pub), 1)) == nil)
                return nil;
        data++;
        pub->x = betomp(data, fb, nil);
        pub->y = betomp(data + fb, fb, nil);
        if(ecpubverify(dom, pub))
                return pub;
        ecpubfree(pub);
        return nil;
}

/*
 * ecpubfree releases a public key's coordinates and the key itself.
 * A nil key is ignored.
 */
void
ecpubfree(ECpub *p)
{
        if(p != nil){
                mpfree(p->x);
                mpfree(p->y);
                free(p);
        }
}