Subversion Repositories planix.SVN

Rev

Rev 33 | Details | Compare with Previous | Last modification | View Log | RSS feed

Rev Author Line No. Line
2 - 1
#include "os.h"
2
#include <mp.h>
3
#include "dat.h"
4
 
5
//
6
//  from Knuth's 1969 Seminumerical Algorithms, pp 233-235 and pp 258-260
7
//
8
//  mpvecdigmuladd is an assembly language routine that performs the inner
9
//  loop.
10
//
11
//  the karatsuba trade-off is set empirically by measuring the algs on
12
//  a 400 MHz Pentium II.
13
//
14
 
15
// karatsuba like (see knuth pg 258)
16
// prereq: p is already zeroed
17
static void
18
mpkaratsuba(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
19
{
20
	mpdigit *t, *u0, *u1, *v0, *v1, *u0v0, *u1v1, *res, *diffprod;
21
	int u0len, u1len, v0len, v1len, reslen;
22
	int sign, n;
23
 
24
	// divide each piece in half
25
	n = alen/2;
26
	if(alen&1)
27
		n++;
28
	u0len = n;
29
	u1len = alen-n;
30
	if(blen > n){
31
		v0len = n;
32
		v1len = blen-n;
33
	} else {
34
		v0len = blen;
35
		v1len = 0;
36
	}
37
	u0 = a;
38
	u1 = a + u0len;
39
	v0 = b;
40
	v1 = b + v0len;
41
 
42
	// room for the partial products
43
	t = mallocz(Dbytes*5*(2*n+1), 1);
44
	if(t == nil)
45
		sysfatal("mpkaratsuba: %r");
46
	u0v0 = t;
47
	u1v1 = t + (2*n+1);
48
	diffprod = t + 2*(2*n+1);
49
	res = t + 3*(2*n+1);
50
	reslen = 4*n+1;
51
 
52
	// t[0] = (u1-u0)
53
	sign = 1;
54
	if(mpveccmp(u1, u1len, u0, u0len) < 0){
55
		sign = -1;
56
		mpvecsub(u0, u0len, u1, u1len, u0v0);
57
	} else
58
		mpvecsub(u1, u1len, u0, u1len, u0v0);
59
 
60
	// t[1] = (v0-v1)
61
	if(mpveccmp(v0, v0len, v1, v1len) < 0){
62
		sign *= -1;
63
		mpvecsub(v1, v1len, v0, v1len, u1v1);
64
	} else
65
		mpvecsub(v0, v0len, v1, v1len, u1v1);
66
 
67
	// t[4:5] = (u1-u0)*(v0-v1)
68
	mpvecmul(u0v0, u0len, u1v1, v0len, diffprod);
69
 
70
	// t[0:1] = u1*v1
71
	memset(t, 0, 2*(2*n+1)*Dbytes);
72
	if(v1len > 0)
73
		mpvecmul(u1, u1len, v1, v1len, u1v1);
74
 
75
	// t[2:3] = u0v0
76
	mpvecmul(u0, u0len, v0, v0len, u0v0);
77
 
78
	// res = u0*v0<<n + u0*v0
79
	mpvecadd(res, reslen, u0v0, u0len+v0len, res);
80
	mpvecadd(res+n, reslen-n, u0v0, u0len+v0len, res+n);
81
 
82
	// res += u1*v1<<n + u1*v1<<2*n
83
	if(v1len > 0){
84
		mpvecadd(res+n, reslen-n, u1v1, u1len+v1len, res+n);
85
		mpvecadd(res+2*n, reslen-2*n, u1v1, u1len+v1len, res+2*n);
86
	}
87
 
88
	// res += (u1-u0)*(v0-v1)<<n
89
	if(sign < 0)
90
		mpvecsub(res+n, reslen-n, diffprod, u0len+v0len, res+n);
91
	else
92
		mpvecadd(res+n, reslen-n, diffprod, u0len+v0len, res+n);
93
	memmove(p, res, (alen+blen)*Dbytes);
94
 
95
	free(t);
96
}
97
 
98
// Digit length of the longer operand at or above which mpvecmul switches
// from the O(n^2) schoolbook loop to mpkaratsuba (provided the shorter
// operand has more than one digit).  The crossover was tuned empirically.
enum {
	KARATSUBAMIN = 32,
};
99
 
100
void
101
mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
102
{
103
	int i;
104
	mpdigit d;
105
	mpdigit *t;
106
 
107
	// both mpvecdigmuladd and karatsuba are fastest when a is the longer vector
108
	if(alen < blen){
109
		i = alen;
110
		alen = blen;
111
		blen = i;
112
		t = a;
113
		a = b;
114
		b = t;
115
	}
116
 
117
	if(alen >= KARATSUBAMIN && blen > 1){
118
		// O(n^1.585)
119
		mpkaratsuba(a, alen, b, blen, p);
120
	} else {
121
		// O(n^2)
122
		for(i = 0; i < blen; i++){
123
			d = b[i];
124
			if(d != 0)
125
				mpvecdigmuladd(a, alen, d, &p[i]);
126
		}
127
	}
128
}
129
 
130
void
33 7u83 131
mpvectsmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
132
{
133
	int i;
134
	mpdigit *t;
135
 
136
	if(alen < blen){
137
		i = alen;
138
		alen = blen;
139
		blen = i;
140
		t = a;
141
		a = b;
142
		b = t;
143
	}
144
	if(blen == 0)
145
		return;
146
	for(i = 0; i < blen; i++)
147
		mpvecdigmuladd(a, alen, b[i], &p[i]);
148
}
149
 
150
void
2 - 151
mpmul(mpint *b1, mpint *b2, mpint *prod)
152
{
153
	mpint *oprod;
154
 
33 7u83 155
	oprod = prod;
2 - 156
	if(prod == b1 || prod == b2){
157
		prod = mpnew(0);
33 7u83 158
		prod->flags = oprod->flags;
2 - 159
	}
33 7u83 160
	prod->flags |= (b1->flags | b2->flags) & MPtimesafe;
2 - 161
 
162
	prod->top = 0;
163
	mpbits(prod, (b1->top+b2->top+1)*Dbits);
33 7u83 164
	if(prod->flags & MPtimesafe)
165
		mpvectsmul(b1->p, b1->top, b2->p, b2->top, prod->p);
166
	else
167
		mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
2 - 168
	prod->top = b1->top+b2->top+1;
169
	prod->sign = b1->sign*b2->sign;
170
	mpnorm(prod);
171
 
33 7u83 172
	if(oprod != prod){
2 - 173
		mpassign(prod, oprod);
174
		mpfree(prod);
175
	}
176
}