Subversion Repositories planix.SVN

Rev

Rev 33 | Blame | Compare with Previous | Last modification | View Log | RSS feed

#include "os.h"
#include <mp.h>
#include "dat.h"

//
//  from knuth's 1969 seminumerical algorithms, pp 233-235 and pp 258-260
//
//  mpvecdigmuladd performs the inner loop; it is defined outside this
//  file (presumably as an assembly language routine).
//
//  the karatsuba trade-off is set empirically by measuring the algs on
//  a 400 MHz Pentium II.
//

// karatsuba like multiplication (see knuth pg 258)
//
// a is split into u0 (low n digits) and u1 (the remaining digits), b into
// v0 and v1, with n = ceil(alen/2).  With B standing for one digit's base,
// the product is assembled from three sub-products as
//
//	a*b = u0*v0 + (u0*v0 + u1*v1 + (u1-u0)*(v0-v1))*B^n + u1*v1*B^(2n)
//
// The sub-products are computed with mpvecmul, so this recurses for
// operands that are still large enough.
//
// The only caller in this file (mpvecmul) passes alen >= blen.
//
// prereq (original note): p is already zeroed.  The code here only
// overwrites p[0:alen+blen-1] with the result via memmove.
//
// Calls sysfatal if the scratch allocation fails.
static void
mpkaratsuba(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
{
        mpdigit *t, *u0, *u1, *v0, *v1, *u0v0, *u1v1, *res, *diffprod;
        int u0len, u1len, v0len, v1len, reslen;
        int sign, n;

        // divide each piece in half; n = ceil(alen/2), so u0len >= u1len
        n = alen/2;
        if(alen&1)
                n++;
        u0len = n;
        u1len = alen-n;
        // b is split at the same digit position; if b has at most n digits
        // it is all v0 and v1 is empty
        if(blen > n){
                v0len = n;
                v1len = blen-n;
        } else {
                v0len = blen;
                v1len = 0;
        }
        u0 = a;
        u1 = a + u0len;
        v0 = b;
        v1 = b + v0len;

        // room for the partial products: one allocation (requested zeroed)
        // of 5 slots of 2*n+1 digits each, laid out as
        //	slot 0:    u0v0
        //	slot 1:    u1v1
        //	slot 2:    diffprod
        //	slots 3-4: res (reslen = 4*n+1 digits are used as the accumulator)
        t = mallocz(Dbytes*5*(2*n+1), 1);
        if(t == nil)
                sysfatal("mpkaratsuba: %r");
        u0v0 = t;
        u1v1 = t + (2*n+1);
        diffprod = t + 2*(2*n+1);
        res = t + 3*(2*n+1);
        reslen = 4*n+1;

        // slot 0 (u0v0, used as scratch for now) = |u1-u0|;
        // sign records whether the difference was negative.
        // NOTE(review): the else branch passes u1len as the length of u0.
        // Presumably mpvecsub needs alen >= blen, and u1 >= u0 implies any
        // digit of u0 beyond u1len is zero -- confirm against mpvecsub.
        sign = 1;
        if(mpveccmp(u1, u1len, u0, u0len) < 0){
                sign = -1;
                mpvecsub(u0, u0len, u1, u1len, u0v0);
        } else
                mpvecsub(u1, u1len, u0, u1len, u0v0);

        // slot 1 (u1v1, used as scratch for now) = |v0-v1|;
        // sign becomes the sign of (u1-u0)*(v0-v1).
        // NOTE(review): same length trick as above, v1len is passed for v0.
        if(mpveccmp(v0, v0len, v1, v1len) < 0){
                sign *= -1;
                mpvecsub(v1, v1len, v0, v1len, u1v1);
        } else
                mpvecsub(v0, v0len, v1, v1len, u1v1);

        // slot 2 (diffprod) = |u1-u0|*|v0-v1|
        mpvecmul(u0v0, u0len, u1v1, v0len, diffprod);

        // clear the two scratch slots before reusing them for the real
        // products (mpvecmul's schoolbook path goes through mpvecdigmuladd,
        // which presumably accumulates into its output -- confirm)
        // slot 1 = u1*v1, left zero when v1 is empty
        memset(t, 0, 2*(2*n+1)*Dbytes);
        if(v1len > 0)
                mpvecmul(u1, u1len, v1, v1len, u1v1);

        // slot 0 = u0*v0
        mpvecmul(u0, u0len, v0, v0len, u0v0);

        // res = u0*v0 + u0*v0<<n   (shifts are in whole digits; res starts
        // out zero from the allocation)
        mpvecadd(res, reslen, u0v0, u0len+v0len, res);
        mpvecadd(res+n, reslen-n, u0v0, u0len+v0len, res+n);

        // res += u1*v1<<n + u1*v1<<2*n
        if(v1len > 0){
                mpvecadd(res+n, reslen-n, u1v1, u1len+v1len, res+n);
                mpvecadd(res+2*n, reslen-2*n, u1v1, u1len+v1len, res+2*n);
        }

        // res += (u1-u0)*(v0-v1)<<n, subtracting the magnitude when the
        // signed product is negative
        if(sign < 0)
                mpvecsub(res+n, reslen-n, diffprod, u0len+v0len, res+n);
        else
                mpvecadd(res+n, reslen-n, diffprod, u0len+v0len, res+n);
        // only the low alen+blen digits of res are handed back
        memmove(p, res, (alen+blen)*Dbytes);

        free(t);
}

// length (in digits) of the longer operand from which mpvecmul
// hands the work to mpkaratsuba instead of the schoolbook loop
#define KARATSUBAMIN 32

// Multiply the digit vectors a (alen digits) and b (blen digits) into p.
//
// The operands are ordered so that a is the longer one.  If a has at
// least KARATSUBAMIN digits and b has more than one digit, mpkaratsuba
// does the multiplication.  Otherwise each non-zero digit of b is
// multiplied into p by mpvecdigmuladd at the matching digit offset.
//
// NOTE(review): mpvecdigmuladd presumably adds into p, so p is expected
// to be zeroed by the caller -- confirm.
void
mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
{
        mpdigit *pswap;
        mpdigit digit;
        int nswap, k;

        // both mpvecdigmuladd and karatsuba are fastest when a is the longer vector
        if(blen > alen){
                pswap = a;
                a = b;
                b = pswap;
                nswap = alen;
                alen = blen;
                blen = nswap;
        }

        // O(n^1.585)
        if(alen >= KARATSUBAMIN && blen > 1){
                mpkaratsuba(a, alen, b, blen, p);
                return;
        }

        // O(n^2); zero digits of b contribute nothing and are skipped
        k = 0;
        while(k < blen){
                digit = b[k];
                if(digit != 0)
                        mpvecdigmuladd(a, alen, digit, p+k);
                k++;
        }
}

// Schoolbook multiplication of a (alen digits) by b (blen digits) into p;
// mpmul uses this variant when MPtimesafe is set.
//
// Unlike mpvecmul, every digit of the shorter operand is fed to
// mpvecdigmuladd, zero or not, and mpkaratsuba is never used.
// Nothing is done when the shorter operand is empty.
//
// NOTE(review): mpvecdigmuladd presumably adds into p, so p is expected
// to be zeroed by the caller -- confirm.
void
mpvectsmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
{
        mpdigit *longer, *shorter;
        int nlong, nshort, k;

        // order the operands so the multiplicand is the longer vector
        if(alen < blen){
                longer = b;
                nlong = blen;
                shorter = a;
                nshort = alen;
        } else {
                longer = a;
                nlong = alen;
                shorter = b;
                nshort = blen;
        }

        // one multiply-accumulate pass per digit of the shorter vector;
        // an empty shorter vector simply runs zero passes
        for(k = 0; k < nshort; k++)
                mpvecdigmuladd(longer, nlong, shorter[k], p+k);
}

// prod = b1 * b2
//
// prod may be the same mpint as b1 or b2: in that case the product is
// built in a temporary from mpnew (carrying prod's flags), assigned back
// with mpassign and the temporary is freed.
//
// MPtimesafe is propagated from either operand to the result, and when
// it is set the digit multiplication is done by mpvectsmul instead of
// mpvecmul.
void
mpmul(mpint *b1, mpint *b2, mpint *prod)
{
        mpint *dst;
        int ndigits;

        // remember the caller's destination; work in a scratch mpint if
        // the destination is also an operand
        dst = prod;
        if(prod == b1 || prod == b2){
                prod = mpnew(0);
                prod->flags = dst->flags;
        }
        prod->flags |= (b1->flags | b2->flags) & MPtimesafe;

        // the result gets one digit more than the operands' digits combined
        // (prod is neither b1 nor b2 here, so resetting its top is safe)
        ndigits = b1->top+b2->top+1;
        prod->top = 0;
        mpbits(prod, ndigits*Dbits);

        if((prod->flags & MPtimesafe) == 0)
                mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
        else
                mpvectsmul(b1->p, b1->top, b2->p, b2->top, prod->p);

        prod->top = ndigits;
        prod->sign = b1->sign*b2->sign;
        mpnorm(prod);

        // hand the result back if it was built in the scratch mpint
        if(prod != dst){
                mpassign(prod, dst);
                mpfree(prod);
        }
}