Subversion Repositories planix.SVN

Rev

Go to most recent revision | Details | Last modification | View Log | RSS feed

Rev Author Line No. Line
2 - 1
#include <u.h>
2
#include <libc.h>
3
#include <bio.h>
4
#include <auth.h>
5
#include <mp.h>
6
#include <libsec.h>
7
 
8
// The main groups of functions are:
9
//		client/server - main handshake protocol definition
10
//		message functions - formatting handshake messages
11
//		cipher choices - catalog of digest and encrypt algorithms
12
//		security functions - PKCS#1, sslHMAC, session keygen
13
//		general utility functions - malloc, serialization
14
// The handshake protocol builds on the TLS/SSL3 record layer protocol,
15
// which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.
16
 
17
enum {
18
	TLSFinishedLen = 12,
19
	SSL3FinishedLen = MD5dlen+SHA1dlen,
20
	MaxKeyData = 136,	// amount of secret we may need
21
	MaxChunk = 1<<14,
22
	RandomSize = 32,
23
	SidSize = 32,
24
	MasterSecretSize = 48,
25
	AQueue = 0,
26
	AFlush = 1,
27
};
28
 
29
/* forward declaration: TlsConnection below holds a TlsSec pointer */
typedef struct TlsSec TlsSec;
30
 
31
typedef struct Bytes{
32
	int len;
33
	uchar data[1];  // [len]
34
} Bytes;
35
 
36
/*
 * A counted int vector (e.g. the cipher list of a ClientHello).
 * data is declared with one element but is used as data[len].
 */
typedef struct Ints{
	int	len;
	int	data[1];	/* really [len] */
} Ints;
40
 
41
/* One entry of the supported cipher-suite catalog (cipherAlgs). */
typedef struct Algs{
	char	*enc;		/* encryption algorithm name, e.g. "rc4_128" */
	char	*digest;	/* MAC digest name, e.g. "sha1" */
	int	nsecret;	/* bytes of key material the suite needs */
	int	tlsid;		/* TLS cipher-suite number */
	int	ok;		/* NOTE(review): presumably set when the suite is usable; confirm in initCiphers */
} Algs;
48
 
49
typedef struct Finished{
50
	uchar verify[SSL3FinishedLen];
51
	int n;
52
} Finished;
53
 
54
typedef struct TlsConnection{
55
	TlsSec *sec;	// security management goo
56
	int hand, ctl;	// record layer file descriptors
57
	int erred;		// set when tlsError called
58
	int (*trace)(char*fmt, ...); // for debugging
59
	int version;	// protocol we are speaking
60
	int verset;		// version has been set
61
	int ver2hi;		// server got a version 2 hello
62
	int isClient;	// is this the client or server?
63
	Bytes *sid;		// SessionID
64
	Bytes *cert;	// only last - no chain
65
 
66
	Lock statelk;
67
	int state;		// must be set using setstate
68
 
69
	// input buffer for handshake messages
70
	uchar buf[MaxChunk+2048];
71
	uchar *rp, *ep;
72
 
73
	uchar crandom[RandomSize];	// client random
74
	uchar srandom[RandomSize];	// server random
75
	int clientVersion;	// version in ClientHello
76
	char *digest;	// name of digest algorithm to use
77
	char *enc;		// name of encryption algorithm to use
78
	int nsecret;	// amount of secret data to init keys
79
 
80
	// for finished messages
81
	MD5state	hsmd5;	// handshake hash
82
	SHAstate	hssha1;	// handshake hash
83
	Finished	finished;
84
} TlsConnection;
85
 
86
typedef struct Msg{
87
	int tag;
88
	union {
89
		struct {
90
			int version;
91
			uchar 	random[RandomSize];
92
			Bytes*	sid;
93
			Ints*	ciphers;
94
			Bytes*	compressors;
95
		} clientHello;
96
		struct {
97
			int version;
98
			uchar 	random[RandomSize];
99
			Bytes*	sid;
100
			int cipher;
101
			int compressor;
102
		} serverHello;
103
		struct {
104
			int ncert;
105
			Bytes **certs;
106
		} certificate;
107
		struct {
108
			Bytes *types;
109
			int nca;
110
			Bytes **cas;
111
		} certificateRequest;
112
		struct {
113
			Bytes *key;
114
		} clientKeyExchange;
115
		Finished finished;
116
	} u;
117
} Msg;
118
 
119
typedef struct TlsSec{
120
	char *server;	// name of remote; nil for server
121
	int ok;	// <0 killed; == 0 in progress; >0 reusable
122
	RSApub *rsapub;
123
	AuthRpc *rpc;	// factotum for rsa private key
124
	uchar sec[MasterSecretSize];	// master secret
125
	uchar crandom[RandomSize];	// client random
126
	uchar srandom[RandomSize];	// server random
127
	int clientVers;		// version in ClientHello
128
	int vers;			// final version
129
	// byte generation and handshake checksum
130
	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
131
	void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
132
	int nfin;
133
} TlsSec;
134
 
135
 
136
/* protocol version numbers as (major<<8)|minor */
enum {
	TLSVersion	= 0x0301,
	SSL3Version	= 0x0300,
	ProtocolVersion	= 0x0301,	/* maximum version we speak */
	MinProtoVersion	= 0x0300,	/* limits on version we accept */
	MaxProtoVersion	= 0x03ff,
};
143
 
144
// handshake type: the first byte of every handshake message
enum {
	HHelloRequest		= 0,
	HClientHello		= 1,
	HServerHello		= 2,
	HSSL2ClientHello	= 9,	/* local convention;  see devtls.c */
	HCertificate		= 11,
	HServerKeyExchange	= 12,
	HCertificateRequest	= 13,
	HServerHelloDone	= 14,
	HCertificateVerify	= 15,
	HClientKeyExchange	= 16,
	HFinished		= 20,
	HMax			= 21,
};
159
 
160
// alerts: the err codes passed to tlsError
enum {
	/* closure and record-layer problems */
	ECloseNotify		= 0,
	EUnexpectedMessage	= 10,
	EBadRecordMac		= 20,
	EDecryptionFailed	= 21,
	ERecordOverflow		= 22,
	EDecompressionFailure	= 30,

	/* handshake and certificate problems */
	EHandshakeFailure	= 40,
	ENoCertificate		= 41,
	EBadCertificate		= 42,
	EUnsupportedCertificate	= 43,
	ECertificateRevoked	= 44,
	ECertificateExpired	= 45,
	ECertificateUnknown	= 46,
	EIllegalParameter	= 47,
	EUnknownCa		= 48,
	EAccessDenied		= 49,
	EDecodeError		= 50,
	EDecryptError		= 51,

	/* policy and miscellaneous */
	EExportRestriction	= 60,
	EProtocolVersion	= 70,
	EInsufficientSecurity	= 71,
	EInternalError		= 80,
	EUserCanceled		= 90,
	ENoRenegotiation	= 100,
	EMax			= 256,	/* all alert codes are below this */
};
188
 
189
// cipher suites (TLS cipher-suite numbers)
enum {
	TLS_NULL_WITH_NULL_NULL			= 0x0000,
	TLS_RSA_WITH_NULL_MD5			= 0x0001,
	TLS_RSA_WITH_NULL_SHA			= 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5		= 0x0003,
	TLS_RSA_WITH_RC4_128_MD5		= 0x0004,
	TLS_RSA_WITH_RC4_128_SHA		= 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5	= 0x0006,
	TLS_RSA_WITH_IDEA_CBC_SHA		= 0x0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0008,
	TLS_RSA_WITH_DES_CBC_SHA		= 0x0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA		= 0x000a,
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x000b,
	TLS_DH_DSS_WITH_DES_CBC_SHA		= 0x000c,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA	= 0x000d,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x000e,
	TLS_DH_RSA_WITH_DES_CBC_SHA		= 0x000f,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0010,
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA		= 0x0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA	= 0x0013,	/* ZZZ must be implemented for tls1.0 compliance */
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA		= 0x0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0016,
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5	= 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5		= 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA	= 0x0019,
	TLS_DH_anon_WITH_DES_CBC_SHA		= 0x001a,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA	= 0x001b,

	/* aes, aka rijndael with 128 bit blocks */
	TLS_RSA_WITH_AES_128_CBC_SHA		= 0x002f,
	TLS_DH_DSS_WITH_AES_128_CBC_SHA		= 0x0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA		= 0x0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA	= 0x0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA	= 0x0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA	= 0x0034,
	TLS_RSA_WITH_AES_256_CBC_SHA		= 0x0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA		= 0x0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA		= 0x0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA	= 0x0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA	= 0x0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA	= 0x003a,
	CipherMax				/* one past the largest suite number above */
};
234
 
235
// compression methods: only the null method is known
enum {
	CompressionNull	= 0,
	CompressionMax,		/* one past the last method */
};
240
 
241
static Algs cipherAlgs[] = {
242
	{"rc4_128", "md5", 2*(16+MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
243
	{"rc4_128", "sha1", 2*(16+SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
244
	{"3des_ede_cbc", "sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
245
	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_RSA_WITH_AES_128_CBC_SHA},
246
	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA}
247
};
248
 
249
static uchar compressors[] = {
250
	CompressionNull,
251
};
252
 
253
static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chain);
254
static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...));
255
 
256
static void	msgClear(Msg *m);
257
static char* msgPrint(char *buf, int n, Msg *m);
258
static int	msgRecv(TlsConnection *c, Msg *m);
259
static int	msgSend(TlsConnection *c, Msg *m, int act);
260
static void	tlsError(TlsConnection *c, int err, char *msg, ...);
261
#pragma	varargck argpos	tlsError 3
262
static int setVersion(TlsConnection *c, int version);
263
static int finishedMatch(TlsConnection *c, Finished *f);
264
static void tlsConnectionFree(TlsConnection *c);
265
 
266
static int setAlgs(TlsConnection *c, int a);
267
static int okCipher(Ints *cv);
268
static int okCompression(Bytes *cv);
269
static int initCiphers(void);
270
static Ints* makeciphers(void);
271
 
272
static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
273
static int	tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
274
static TlsSec*	tlsSecInitc(int cvers, uchar *crandom);
275
static int	tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
276
static int	tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
277
static void	tlsSecOk(TlsSec *sec);
278
static void	tlsSecKill(TlsSec *sec);
279
static void	tlsSecClose(TlsSec *sec);
280
static void	setMasterSecret(TlsSec *sec, Bytes *pm);
281
static void	serverMasterSecret(TlsSec *sec, uchar *epm, int nepm);
282
static void	setSecrets(TlsSec *sec, uchar *kd, int nkd);
283
static int	clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
284
static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
285
static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
286
static void	tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
287
static void	sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
288
static void	sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
289
			uchar *seed0, int nseed0, uchar *seed1, int nseed1);
290
static int setVers(TlsSec *sec, int version);
291
 
292
static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
293
static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
294
static void factotum_rsa_close(AuthRpc*rpc);
295
 
296
static void* emalloc(int);
297
static void* erealloc(void*, int);
298
static void put32(uchar *p, u32int);
299
static void put24(uchar *p, int);
300
static void put16(uchar *p, int);
301
static u32int get32(uchar *p);
302
static int get24(uchar *p);
303
static int get16(uchar *p);
304
static Bytes* newbytes(int len);
305
static Bytes* makebytes(uchar* buf, int len);
306
static void freebytes(Bytes* b);
307
static Ints* newints(int len);
308
static Ints* makeints(int* buf, int len);
309
static void freeints(Ints* b);
310
 
311
//================= client/server ========================
312
 
313
//	push TLS onto fd, returning new (application) file descriptor
314
//		or -1 if error.
315
int
316
tlsServer(int fd, TLSconn *conn)
317
{
318
	char buf[8];
319
	char dname[64];
320
	int n, data, ctl, hand;
321
	TlsConnection *tls;
322
 
323
	if(conn == nil)
324
		return -1;
325
	ctl = open("#a/tls/clone", ORDWR);
326
	if(ctl < 0)
327
		return -1;
328
	n = read(ctl, buf, sizeof(buf)-1);
329
	if(n < 0){
330
		close(ctl);
331
		return -1;
332
	}
333
	buf[n] = 0;
334
	sprint(conn->dir, "#a/tls/%s", buf);
335
	sprint(dname, "#a/tls/%s/hand", buf);
336
	hand = open(dname, ORDWR);
337
	if(hand < 0){
338
		close(ctl);
339
		return -1;
340
	}
341
	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
342
	tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
343
	sprint(dname, "#a/tls/%s/data", buf);
344
	data = open(dname, ORDWR);
345
	close(fd);
346
	close(hand);
347
	close(ctl);
348
	if(data < 0){
349
		return -1;
350
	}
351
	if(tls == nil){
352
		close(data);
353
		return -1;
354
	}
355
	if(conn->cert)
356
		free(conn->cert);
357
	conn->cert = 0;  // client certificates are not yet implemented
358
	conn->certlen = 0;
359
	conn->sessionIDlen = tls->sid->len;
360
	conn->sessionID = emalloc(conn->sessionIDlen);
361
	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
362
	if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
363
		tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst,  tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
364
	tlsConnectionFree(tls);
365
	return data;
366
}
367
 
368
//	push TLS onto fd, returning new (application) file descriptor
369
//		or -1 if error.
370
int
371
tlsClient(int fd, TLSconn *conn)
372
{
373
	char buf[8];
374
	char dname[64];
375
	int n, data, ctl, hand;
376
	TlsConnection *tls;
377
 
378
	if(!conn)
379
		return -1;
380
	ctl = open("#a/tls/clone", ORDWR);
381
	if(ctl < 0)
382
		return -1;
383
	n = read(ctl, buf, sizeof(buf)-1);
384
	if(n < 0){
385
		close(ctl);
386
		return -1;
387
	}
388
	buf[n] = 0;
389
	sprint(conn->dir, "#a/tls/%s", buf);
390
	sprint(dname, "#a/tls/%s/hand", buf);
391
	hand = open(dname, ORDWR);
392
	if(hand < 0){
393
		close(ctl);
394
		return -1;
395
	}
396
	sprint(dname, "#a/tls/%s/data", buf);
397
	data = open(dname, ORDWR);
398
	if(data < 0)
399
		return -1;
400
	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
401
	tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->trace);
402
	close(fd);
403
	close(hand);
404
	close(ctl);
405
	if(tls == nil){
406
		close(data);
407
		return -1;
408
	}
409
	conn->certlen = tls->cert->len;
410
	conn->cert = emalloc(conn->certlen);
411
	memcpy(conn->cert, tls->cert->data, conn->certlen);
412
	conn->sessionIDlen = tls->sid->len;
413
	conn->sessionID = emalloc(conn->sessionIDlen);
414
	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
415
	if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
416
		tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst,  tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
417
	tlsConnectionFree(tls);
418
	return data;
419
}
420
 
421
static int
422
countchain(PEMChain *p)
423
{
424
	int i = 0;
425
 
426
	while (p) {
427
		i++;
428
		p = p->next;
429
	}
430
	return i;
431
}
432
 
433
static TlsConnection *
434
tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chp)
435
{
436
	TlsConnection *c;
437
	Msg m;
438
	Bytes *csid;
439
	uchar sid[SidSize], kd[MaxKeyData];
440
	char *secrets;
441
	int cipher, compressor, nsid, rv, numcerts, i;
442
 
443
	if(trace)
444
		trace("tlsServer2\n");
445
	if(!initCiphers())
446
		return nil;
447
	c = emalloc(sizeof(TlsConnection));
448
	c->ctl = ctl;
449
	c->hand = hand;
450
	c->trace = trace;
451
	c->version = ProtocolVersion;
452
 
453
	memset(&m, 0, sizeof(m));
454
	if(!msgRecv(c, &m)){
455
		if(trace)
456
			trace("initial msgRecv failed\n");
457
		goto Err;
458
	}
459
	if(m.tag != HClientHello) {
460
		tlsError(c, EUnexpectedMessage, "expected a client hello");
461
		goto Err;
462
	}
463
	c->clientVersion = m.u.clientHello.version;
464
	if(trace)
465
		trace("ClientHello version %x\n", c->clientVersion);
466
	if(setVersion(c, m.u.clientHello.version) < 0) {
467
		tlsError(c, EIllegalParameter, "incompatible version");
468
		goto Err;
469
	}
470
 
471
	memmove(c->crandom, m.u.clientHello.random, RandomSize);
472
	cipher = okCipher(m.u.clientHello.ciphers);
473
	if(cipher < 0) {
474
		// reply with EInsufficientSecurity if we know that's the case
475
		if(cipher == -2)
476
			tlsError(c, EInsufficientSecurity, "cipher suites too weak");
477
		else
478
			tlsError(c, EHandshakeFailure, "no matching cipher suite");
479
		goto Err;
480
	}
481
	if(!setAlgs(c, cipher)){
482
		tlsError(c, EHandshakeFailure, "no matching cipher suite");
483
		goto Err;
484
	}
485
	compressor = okCompression(m.u.clientHello.compressors);
486
	if(compressor < 0) {
487
		tlsError(c, EHandshakeFailure, "no matching compressor");
488
		goto Err;
489
	}
490
 
491
	csid = m.u.clientHello.sid;
492
	if(trace)
493
		trace("  cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
494
	c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
495
	if(c->sec == nil){
496
		tlsError(c, EHandshakeFailure, "can't initialize security: %r");
497
		goto Err;
498
	}
499
	c->sec->rpc = factotum_rsa_open(cert, ncert);
500
	if(c->sec->rpc == nil){
501
		tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
502
		goto Err;
503
	}
504
	c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0);
505
	msgClear(&m);
506
 
507
	m.tag = HServerHello;
508
	m.u.serverHello.version = c->version;
509
	memmove(m.u.serverHello.random, c->srandom, RandomSize);
510
	m.u.serverHello.cipher = cipher;
511
	m.u.serverHello.compressor = compressor;
512
	c->sid = makebytes(sid, nsid);
513
	m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
514
	if(!msgSend(c, &m, AQueue))
515
		goto Err;
516
	msgClear(&m);
517
 
518
	m.tag = HCertificate;
519
	numcerts = countchain(chp);
520
	m.u.certificate.ncert = 1 + numcerts;
521
	m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes));
522
	m.u.certificate.certs[0] = makebytes(cert, ncert);
523
	for (i = 0; i < numcerts && chp; i++, chp = chp->next)
524
		m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
525
	if(!msgSend(c, &m, AQueue))
526
		goto Err;
527
	msgClear(&m);
528
 
529
	m.tag = HServerHelloDone;
530
	if(!msgSend(c, &m, AFlush))
531
		goto Err;
532
	msgClear(&m);
533
 
534
	if(!msgRecv(c, &m))
535
		goto Err;
536
	if(m.tag != HClientKeyExchange) {
537
		tlsError(c, EUnexpectedMessage, "expected a client key exchange");
538
		goto Err;
539
	}
540
	if(tlsSecSecrets(c->sec, c->version, m.u.clientKeyExchange.key->data, m.u.clientKeyExchange.key->len, kd, c->nsecret) < 0){
541
		tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
542
		goto Err;
543
	}
544
	if(trace)
545
		trace("tls secrets\n");
546
	secrets = (char*)emalloc(2*c->nsecret);
547
	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
548
	rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
549
	memset(secrets, 0, 2*c->nsecret);
550
	free(secrets);
551
	memset(kd, 0, c->nsecret);
552
	if(rv < 0){
553
		tlsError(c, EHandshakeFailure, "can't set keys: %r");
554
		goto Err;
555
	}
556
	msgClear(&m);
557
 
558
	/* no CertificateVerify; skip to Finished */
559
	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
560
		tlsError(c, EInternalError, "can't set finished: %r");
561
		goto Err;
562
	}
563
	if(!msgRecv(c, &m))
564
		goto Err;
565
	if(m.tag != HFinished) {
566
		tlsError(c, EUnexpectedMessage, "expected a finished");
567
		goto Err;
568
	}
569
	if(!finishedMatch(c, &m.u.finished)) {
570
		tlsError(c, EHandshakeFailure, "finished verification failed");
571
		goto Err;
572
	}
573
	msgClear(&m);
574
 
575
	/* change cipher spec */
576
	if(fprint(c->ctl, "changecipher") < 0){
577
		tlsError(c, EInternalError, "can't enable cipher: %r");
578
		goto Err;
579
	}
580
 
581
	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
582
		tlsError(c, EInternalError, "can't set finished: %r");
583
		goto Err;
584
	}
585
	m.tag = HFinished;
586
	m.u.finished = c->finished;
587
	if(!msgSend(c, &m, AFlush))
588
		goto Err;
589
	msgClear(&m);
590
	if(trace)
591
		trace("tls finished\n");
592
 
593
	if(fprint(c->ctl, "opened") < 0)
594
		goto Err;
595
	tlsSecOk(c->sec);
596
	return c;
597
 
598
Err:
599
	msgClear(&m);
600
	tlsConnectionFree(c);
601
	return 0;
602
}
603
 
604
static TlsConnection *
605
tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...))
606
{
607
	TlsConnection *c;
608
	Msg m;
609
	uchar kd[MaxKeyData], *epm;
610
	char *secrets;
611
	int creq, nepm, rv;
612
 
613
	if(!initCiphers())
614
		return nil;
615
	epm = nil;
616
	c = emalloc(sizeof(TlsConnection));
617
	c->version = ProtocolVersion;
618
	c->ctl = ctl;
619
	c->hand = hand;
620
	c->trace = trace;
621
	c->isClient = 1;
622
	c->clientVersion = c->version;
623
 
624
	c->sec = tlsSecInitc(c->clientVersion, c->crandom);
625
	if(c->sec == nil)
626
		goto Err;
627
 
628
	/* client hello */
629
	memset(&m, 0, sizeof(m));
630
	m.tag = HClientHello;
631
	m.u.clientHello.version = c->clientVersion;
632
	memmove(m.u.clientHello.random, c->crandom, RandomSize);
633
	m.u.clientHello.sid = makebytes(csid, ncsid);
634
	m.u.clientHello.ciphers = makeciphers();
635
	m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
636
	if(!msgSend(c, &m, AFlush))
637
		goto Err;
638
	msgClear(&m);
639
 
640
	/* server hello */
641
	if(!msgRecv(c, &m))
642
		goto Err;
643
	if(m.tag != HServerHello) {
644
		tlsError(c, EUnexpectedMessage, "expected a server hello");
645
		goto Err;
646
	}
647
	if(setVersion(c, m.u.serverHello.version) < 0) {
648
		tlsError(c, EIllegalParameter, "incompatible version %r");
649
		goto Err;
650
	}
651
	memmove(c->srandom, m.u.serverHello.random, RandomSize);
652
	c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
653
	if(c->sid->len != 0 && c->sid->len != SidSize) {
654
		tlsError(c, EIllegalParameter, "invalid server session identifier");
655
		goto Err;
656
	}
657
	if(!setAlgs(c, m.u.serverHello.cipher)) {
658
		tlsError(c, EIllegalParameter, "invalid cipher suite");
659
		goto Err;
660
	}
661
	if(m.u.serverHello.compressor != CompressionNull) {
662
		tlsError(c, EIllegalParameter, "invalid compression");
663
		goto Err;
664
	}
665
	msgClear(&m);
666
 
667
	/* certificate */
668
	if(!msgRecv(c, &m) || m.tag != HCertificate) {
669
		tlsError(c, EUnexpectedMessage, "expected a certificate");
670
		goto Err;
671
	}
672
	if(m.u.certificate.ncert < 1) {
673
		tlsError(c, EIllegalParameter, "runt certificate");
674
		goto Err;
675
	}
676
	c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
677
	msgClear(&m);
678
 
679
	/* server key exchange (optional) */
680
	if(!msgRecv(c, &m))
681
		goto Err;
682
	if(m.tag == HServerKeyExchange) {
683
		tlsError(c, EUnexpectedMessage, "got an server key exchange");
684
		goto Err;
685
		// If implementing this later, watch out for rollback attack
686
		// described in Wagner Schneier 1996, section 4.4.
687
	}
688
 
689
	/* certificate request (optional) */
690
	creq = 0;
691
	if(m.tag == HCertificateRequest) {
692
		creq = 1;
693
		msgClear(&m);
694
		if(!msgRecv(c, &m))
695
			goto Err;
696
	}
697
 
698
	if(m.tag != HServerHelloDone) {
699
		tlsError(c, EUnexpectedMessage, "expected a server hello done");
700
		goto Err;
701
	}
702
	msgClear(&m);
703
 
704
	if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom,
705
			c->cert->data, c->cert->len, c->version, &epm, &nepm,
706
			kd, c->nsecret) < 0){
707
		tlsError(c, EBadCertificate, "invalid x509/rsa certificate");
708
		goto Err;
709
	}
710
	secrets = (char*)emalloc(2*c->nsecret);
711
	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
712
	rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
713
	memset(secrets, 0, 2*c->nsecret);
714
	free(secrets);
715
	memset(kd, 0, c->nsecret);
716
	if(rv < 0){
717
		tlsError(c, EHandshakeFailure, "can't set keys: %r");
718
		goto Err;
719
	}
720
 
721
	if(creq) {
722
		/* send a zero length certificate */
723
		m.tag = HCertificate;
724
		if(!msgSend(c, &m, AFlush))
725
			goto Err;
726
		msgClear(&m);
727
	}
728
 
729
	/* client key exchange */
730
	m.tag = HClientKeyExchange;
731
	m.u.clientKeyExchange.key = makebytes(epm, nepm);
732
	free(epm);
733
	epm = nil;
734
	if(m.u.clientKeyExchange.key == nil) {
735
		tlsError(c, EHandshakeFailure, "can't set secret: %r");
736
		goto Err;
737
	}
738
	if(!msgSend(c, &m, AFlush))
739
		goto Err;
740
	msgClear(&m);
741
 
742
	/* change cipher spec */
743
	if(fprint(c->ctl, "changecipher") < 0){
744
		tlsError(c, EInternalError, "can't enable cipher: %r");
745
		goto Err;
746
	}
747
 
748
	// Cipherchange must occur immediately before Finished to avoid
749
	// potential hole;  see section 4.3 of Wagner Schneier 1996.
750
	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
751
		tlsError(c, EInternalError, "can't set finished 1: %r");
752
		goto Err;
753
	}
754
	m.tag = HFinished;
755
	m.u.finished = c->finished;
756
 
757
	if(!msgSend(c, &m, AFlush)) {
758
		fprint(2, "tlsClient nepm=%d\n", nepm);
759
		tlsError(c, EInternalError, "can't flush after client Finished: %r");
760
		goto Err;
761
	}
762
	msgClear(&m);
763
 
764
	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
765
		fprint(2, "tlsClient nepm=%d\n", nepm);
766
		tlsError(c, EInternalError, "can't set finished 0: %r");
767
		goto Err;
768
	}
769
	if(!msgRecv(c, &m)) {
770
		fprint(2, "tlsClient nepm=%d\n", nepm);
771
		tlsError(c, EInternalError, "can't read server Finished: %r");
772
		goto Err;
773
	}
774
	if(m.tag != HFinished) {
775
		fprint(2, "tlsClient nepm=%d\n", nepm);
776
		tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
777
		goto Err;
778
	}
779
 
780
	if(!finishedMatch(c, &m.u.finished)) {
781
		tlsError(c, EHandshakeFailure, "finished verification failed");
782
		goto Err;
783
	}
784
	msgClear(&m);
785
 
786
	if(fprint(c->ctl, "opened") < 0){
787
		if(trace)
788
			trace("unable to do final open: %r\n");
789
		goto Err;
790
	}
791
	tlsSecOk(c->sec);
792
	return c;
793
 
794
Err:
795
	free(epm);
796
	msgClear(&m);
797
	tlsConnectionFree(c);
798
	return 0;
799
}
800
 
801
 
802
//================= message functions ========================
803
 
804
/*
 * Outgoing handshake messages are assembled in sendbuf.  sendp marks
 * where the next message starts; messages sent with AQueue accumulate
 * until an AFlush write.  This is file-static, so it is shared by every
 * connection in the process.
 */
static uchar sendbuf[9000];
static uchar *sendp;
805
 
806
static int
807
msgSend(TlsConnection *c, Msg *m, int act)
808
{
809
	uchar *p; // sendp = start of new message;  p = write pointer
810
	int nn, n, i;
811
 
812
	if(sendp == nil)
813
		sendp = sendbuf;
814
	p = sendp;
815
	if(c->trace)
816
		c->trace("send %s", msgPrint((char*)p, (sizeof sendbuf) - (p-sendbuf), m));
817
 
818
	p[0] = m->tag;	// header - fill in size later
819
	p += 4;
820
 
821
	switch(m->tag) {
822
	default:
823
		tlsError(c, EInternalError, "can't encode a %d", m->tag);
824
		goto Err;
825
	case HClientHello:
826
		// version
827
		put16(p, m->u.clientHello.version);
828
		p += 2;
829
 
830
		// random
831
		memmove(p, m->u.clientHello.random, RandomSize);
832
		p += RandomSize;
833
 
834
		// sid
835
		n = m->u.clientHello.sid->len;
836
		assert(n < 256);
837
		p[0] = n;
838
		memmove(p+1, m->u.clientHello.sid->data, n);
839
		p += n+1;
840
 
841
		n = m->u.clientHello.ciphers->len;
842
		assert(n > 0 && n < 200);
843
		put16(p, n*2);
844
		p += 2;
845
		for(i=0; i<n; i++) {
846
			put16(p, m->u.clientHello.ciphers->data[i]);
847
			p += 2;
848
		}
849
 
850
		n = m->u.clientHello.compressors->len;
851
		assert(n > 0);
852
		p[0] = n;
853
		memmove(p+1, m->u.clientHello.compressors->data, n);
854
		p += n+1;
855
		break;
856
	case HServerHello:
857
		put16(p, m->u.serverHello.version);
858
		p += 2;
859
 
860
		// random
861
		memmove(p, m->u.serverHello.random, RandomSize);
862
		p += RandomSize;
863
 
864
		// sid
865
		n = m->u.serverHello.sid->len;
866
		assert(n < 256);
867
		p[0] = n;
868
		memmove(p+1, m->u.serverHello.sid->data, n);
869
		p += n+1;
870
 
871
		put16(p, m->u.serverHello.cipher);
872
		p += 2;
873
		p[0] = m->u.serverHello.compressor;
874
		p += 1;
875
		break;
876
	case HServerHelloDone:
877
		break;
878
	case HCertificate:
879
		nn = 0;
880
		for(i = 0; i < m->u.certificate.ncert; i++)
881
			nn += 3 + m->u.certificate.certs[i]->len;
882
		if(p + 3 + nn - sendbuf > sizeof(sendbuf)) {
883
			tlsError(c, EInternalError, "output buffer too small for certificate");
884
			goto Err;
885
		}
886
		put24(p, nn);
887
		p += 3;
888
		for(i = 0; i < m->u.certificate.ncert; i++){
889
			put24(p, m->u.certificate.certs[i]->len);
890
			p += 3;
891
			memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
892
			p += m->u.certificate.certs[i]->len;
893
		}
894
		break;
895
	case HClientKeyExchange:
896
		n = m->u.clientKeyExchange.key->len;
897
		if(c->version != SSL3Version){
898
			put16(p, n);
899
			p += 2;
900
		}
901
		memmove(p, m->u.clientKeyExchange.key->data, n);
902
		p += n;
903
		break;
904
	case HFinished:
905
		memmove(p, m->u.finished.verify, m->u.finished.n);
906
		p += m->u.finished.n;
907
		break;
908
	}
909
 
910
	// go back and fill in size
911
	n = p-sendp;
912
	assert(p <= sendbuf+sizeof(sendbuf));
913
	put24(sendp+1, n-4);
914
 
915
	// remember hash of Handshake messages
916
	if(m->tag != HHelloRequest) {
917
		md5(sendp, n, 0, &c->hsmd5);
918
		sha1(sendp, n, 0, &c->hssha1);
919
	}
920
 
921
	sendp = p;
922
	if(act == AFlush){
923
		sendp = sendbuf;
924
		if(write(c->hand, sendbuf, p-sendbuf) < 0){
925
			fprint(2, "write error: %r\n");
926
			goto Err;
927
		}
928
	}
929
	msgClear(m);
930
	return 1;
931
Err:
932
	msgClear(m);
933
	return 0;
934
}
935
 
936
static uchar*
937
tlsReadN(TlsConnection *c, int n)
938
{
939
	uchar *p;
940
	int nn, nr;
941
 
942
	nn = c->ep - c->rp;
943
	if(nn < n){
944
		if(c->rp != c->buf){
945
			memmove(c->buf, c->rp, nn);
946
			c->rp = c->buf;
947
			c->ep = &c->buf[nn];
948
		}
949
		for(; nn < n; nn += nr) {
950
			nr = read(c->hand, &c->rp[nn], n - nn);
951
			if(nr <= 0)
952
				return nil;
953
			c->ep += nr;
954
		}
955
	}
956
	p = c->rp;
957
	c->rp += n;
958
	return p;
959
}
960
 
961
static int
962
msgRecv(TlsConnection *c, Msg *m)
963
{
964
	uchar *p;
965
	int type, n, nn, i, nsid, nrandom, nciph;
966
 
967
	for(;;) {
968
		p = tlsReadN(c, 4);
969
		if(p == nil)
970
			return 0;
971
		type = p[0];
972
		n = get24(p+1);
973
 
974
		if(type != HHelloRequest)
975
			break;
976
		if(n != 0) {
977
			tlsError(c, EDecodeError, "invalid hello request during handshake");
978
			return 0;
979
		}
980
	}
981
 
982
	if(n > sizeof(c->buf)) {
983
		tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->buf));
984
		return 0;
985
	}
986
 
987
	if(type == HSSL2ClientHello){
988
		/* Cope with an SSL3 ClientHello expressed in SSL2 record format.
989
			This is sent by some clients that we must interoperate
990
			with, such as Java's JSSE and Microsoft's Internet Explorer. */
991
		p = tlsReadN(c, n);
992
		if(p == nil)
993
			return 0;
994
		md5(p, n, 0, &c->hsmd5);
995
		sha1(p, n, 0, &c->hssha1);
996
		m->tag = HClientHello;
997
		if(n < 22)
998
			goto Short;
999
		m->u.clientHello.version = get16(p+1);
1000
		p += 3;
1001
		n -= 3;
1002
		nn = get16(p); /* cipher_spec_len */
1003
		nsid = get16(p + 2);
1004
		nrandom = get16(p + 4);
1005
		p += 6;
1006
		n -= 6;
1007
		if(nsid != 0 	/* no sid's, since shouldn't restart using ssl2 header */
1008
				|| nrandom < 16 || nn % 3)
1009
			goto Err;
1010
		if(c->trace && (n - nrandom != nn))
1011
			c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
1012
		/* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
1013
		nciph = 0;
1014
		for(i = 0; i < nn; i += 3)
1015
			if(p[i] == 0)
1016
				nciph++;
1017
		m->u.clientHello.ciphers = newints(nciph);
1018
		nciph = 0;
1019
		for(i = 0; i < nn; i += 3)
1020
			if(p[i] == 0)
1021
				m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1022
		p += nn;
1023
		m->u.clientHello.sid = makebytes(nil, 0);
1024
		if(nrandom > RandomSize)
1025
			nrandom = RandomSize;
1026
		memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1027
		memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1028
		m->u.clientHello.compressors = newbytes(1);
1029
		m->u.clientHello.compressors->data[0] = CompressionNull;
1030
		goto Ok;
1031
	}
1032
 
1033
	md5(p, 4, 0, &c->hsmd5);
1034
	sha1(p, 4, 0, &c->hssha1);
1035
 
1036
	p = tlsReadN(c, n);
1037
	if(p == nil)
1038
		return 0;
1039
 
1040
	md5(p, n, 0, &c->hsmd5);
1041
	sha1(p, n, 0, &c->hssha1);
1042
 
1043
	m->tag = type;
1044
 
1045
	switch(type) {
1046
	default:
1047
		tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1048
		goto Err;
1049
	case HClientHello:
1050
		if(n < 2)
1051
			goto Short;
1052
		m->u.clientHello.version = get16(p);
1053
		p += 2;
1054
		n -= 2;
1055
 
1056
		if(n < RandomSize)
1057
			goto Short;
1058
		memmove(m->u.clientHello.random, p, RandomSize);
1059
		p += RandomSize;
1060
		n -= RandomSize;
1061
		if(n < 1 || n < p[0]+1)
1062
			goto Short;
1063
		m->u.clientHello.sid = makebytes(p+1, p[0]);
1064
		p += m->u.clientHello.sid->len+1;
1065
		n -= m->u.clientHello.sid->len+1;
1066
 
1067
		if(n < 2)
1068
			goto Short;
1069
		nn = get16(p);
1070
		p += 2;
1071
		n -= 2;
1072
 
1073
		if((nn & 1) || n < nn || nn < 2)
1074
			goto Short;
1075
		m->u.clientHello.ciphers = newints(nn >> 1);
1076
		for(i = 0; i < nn; i += 2)
1077
			m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1078
		p += nn;
1079
		n -= nn;
1080
 
1081
		if(n < 1 || n < p[0]+1 || p[0] == 0)
1082
			goto Short;
1083
		nn = p[0];
1084
		m->u.clientHello.compressors = newbytes(nn);
1085
		memmove(m->u.clientHello.compressors->data, p+1, nn);
1086
		n -= nn + 1;
1087
		break;
1088
	case HServerHello:
1089
		if(n < 2)
1090
			goto Short;
1091
		m->u.serverHello.version = get16(p);
1092
		p += 2;
1093
		n -= 2;
1094
 
1095
		if(n < RandomSize)
1096
			goto Short;
1097
		memmove(m->u.serverHello.random, p, RandomSize);
1098
		p += RandomSize;
1099
		n -= RandomSize;
1100
 
1101
		if(n < 1 || n < p[0]+1)
1102
			goto Short;
1103
		m->u.serverHello.sid = makebytes(p+1, p[0]);
1104
		p += m->u.serverHello.sid->len+1;
1105
		n -= m->u.serverHello.sid->len+1;
1106
 
1107
		if(n < 3)
1108
			goto Short;
1109
		m->u.serverHello.cipher = get16(p);
1110
		m->u.serverHello.compressor = p[2];
1111
		n -= 3;
1112
		break;
1113
	case HCertificate:
1114
		if(n < 3)
1115
			goto Short;
1116
		nn = get24(p);
1117
		p += 3;
1118
		n -= 3;
1119
		if(n != nn)
1120
			goto Short;
1121
		/* certs */
1122
		i = 0;
1123
		while(n > 0) {
1124
			if(n < 3)
1125
				goto Short;
1126
			nn = get24(p);
1127
			p += 3;
1128
			n -= 3;
1129
			if(nn > n)
1130
				goto Short;
1131
			m->u.certificate.ncert = i+1;
1132
			m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes));
1133
			m->u.certificate.certs[i] = makebytes(p, nn);
1134
			p += nn;
1135
			n -= nn;
1136
			i++;
1137
		}
1138
		break;
1139
	case HCertificateRequest:
1140
		if(n < 1)
1141
			goto Short;
1142
		nn = p[0];
1143
		p += 1;
1144
		n -= 1;
1145
		if(nn < 1 || nn > n)
1146
			goto Short;
1147
		m->u.certificateRequest.types = makebytes(p, nn);
1148
		p += nn;
1149
		n -= nn;
1150
		if(n < 2)
1151
			goto Short;
1152
		nn = get16(p);
1153
		p += 2;
1154
		n -= 2;
1155
		/* nn == 0 can happen; yahoo's servers do it */
1156
		if(nn != n)
1157
			goto Short;
1158
		/* cas */
1159
		i = 0;
1160
		while(n > 0) {
1161
			if(n < 2)
1162
				goto Short;
1163
			nn = get16(p);
1164
			p += 2;
1165
			n -= 2;
1166
			if(nn < 1 || nn > n)
1167
				goto Short;
1168
			m->u.certificateRequest.nca = i+1;
1169
			m->u.certificateRequest.cas = erealloc(
1170
				m->u.certificateRequest.cas, (i+1)*sizeof(Bytes));
1171
			m->u.certificateRequest.cas[i] = makebytes(p, nn);
1172
			p += nn;
1173
			n -= nn;
1174
			i++;
1175
		}
1176
		break;
1177
	case HServerHelloDone:
1178
		break;
1179
	case HClientKeyExchange:
1180
		/*
1181
		 * this message depends upon the encryption selected
1182
		 * assume rsa.
1183
		 */
1184
		if(c->version == SSL3Version)
1185
			nn = n;
1186
		else{
1187
			if(n < 2)
1188
				goto Short;
1189
			nn = get16(p);
1190
			p += 2;
1191
			n -= 2;
1192
		}
1193
		if(n < nn)
1194
			goto Short;
1195
		m->u.clientKeyExchange.key = makebytes(p, nn);
1196
		n -= nn;
1197
		break;
1198
	case HFinished:
1199
		m->u.finished.n = c->finished.n;
1200
		if(n < m->u.finished.n)
1201
			goto Short;
1202
		memmove(m->u.finished.verify, p, m->u.finished.n);
1203
		n -= m->u.finished.n;
1204
		break;
1205
	}
1206
 
1207
	if(type != HClientHello && n != 0)
1208
		goto Short;
1209
Ok:
1210
	if(c->trace){
1211
		char *buf;
1212
		buf = emalloc(8000);
1213
		c->trace("recv %s", msgPrint(buf, 8000, m));
1214
		free(buf);
1215
	}
1216
	return 1;
1217
Short:
1218
	tlsError(c, EDecodeError, "handshake message has invalid length");
1219
Err:
1220
	msgClear(m);
1221
	return 0;
1222
}
1223
 
1224
static void
1225
msgClear(Msg *m)
1226
{
1227
	int i;
1228
 
1229
	switch(m->tag) {
1230
	default:
1231
		sysfatal("msgClear: unknown message type: %d", m->tag);
1232
	case HHelloRequest:
1233
		break;
1234
	case HClientHello:
1235
		freebytes(m->u.clientHello.sid);
1236
		freeints(m->u.clientHello.ciphers);
1237
		freebytes(m->u.clientHello.compressors);
1238
		break;
1239
	case HServerHello:
1240
		freebytes(m->u.clientHello.sid);
1241
		break;
1242
	case HCertificate:
1243
		for(i=0; i<m->u.certificate.ncert; i++)
1244
			freebytes(m->u.certificate.certs[i]);
1245
		free(m->u.certificate.certs);
1246
		break;
1247
	case HCertificateRequest:
1248
		freebytes(m->u.certificateRequest.types);
1249
		for(i=0; i<m->u.certificateRequest.nca; i++)
1250
			freebytes(m->u.certificateRequest.cas[i]);
1251
		free(m->u.certificateRequest.cas);
1252
		break;
1253
	case HServerHelloDone:
1254
		break;
1255
	case HClientKeyExchange:
1256
		freebytes(m->u.clientKeyExchange.key);
1257
		break;
1258
	case HFinished:
1259
		break;
1260
	}
1261
	memset(m, 0, sizeof(Msg));
1262
}
1263
 
1264
static char *
1265
bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1266
{
1267
	int i;
1268
 
1269
	if(s0)
1270
		bs = seprint(bs, be, "%s", s0);
1271
	bs = seprint(bs, be, "[");
1272
	if(b == nil)
1273
		bs = seprint(bs, be, "nil");
1274
	else
1275
		for(i=0; i<b->len; i++)
1276
			bs = seprint(bs, be, "%.2x ", b->data[i]);
1277
	bs = seprint(bs, be, "]");
1278
	if(s1)
1279
		bs = seprint(bs, be, "%s", s1);
1280
	return bs;
1281
}
1282
 
1283
static char *
1284
intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1285
{
1286
	int i;
1287
 
1288
	if(s0)
1289
		bs = seprint(bs, be, "%s", s0);
1290
	bs = seprint(bs, be, "[");
1291
	if(b == nil)
1292
		bs = seprint(bs, be, "nil");
1293
	else
1294
		for(i=0; i<b->len; i++)
1295
			bs = seprint(bs, be, "%x ", b->data[i]);
1296
	bs = seprint(bs, be, "]");
1297
	if(s1)
1298
		bs = seprint(bs, be, "%s", s1);
1299
	return bs;
1300
}
1301
 
1302
static char*
1303
msgPrint(char *buf, int n, Msg *m)
1304
{
1305
	int i;
1306
	char *bs = buf, *be = buf+n;
1307
 
1308
	switch(m->tag) {
1309
	default:
1310
		bs = seprint(bs, be, "unknown %d\n", m->tag);
1311
		break;
1312
	case HClientHello:
1313
		bs = seprint(bs, be, "ClientHello\n");
1314
		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1315
		bs = seprint(bs, be, "\trandom: ");
1316
		for(i=0; i<RandomSize; i++)
1317
			bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1318
		bs = seprint(bs, be, "\n");
1319
		bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1320
		bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1321
		bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1322
		break;
1323
	case HServerHello:
1324
		bs = seprint(bs, be, "ServerHello\n");
1325
		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1326
		bs = seprint(bs, be, "\trandom: ");
1327
		for(i=0; i<RandomSize; i++)
1328
			bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1329
		bs = seprint(bs, be, "\n");
1330
		bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1331
		bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1332
		bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1333
		break;
1334
	case HCertificate:
1335
		bs = seprint(bs, be, "Certificate\n");
1336
		for(i=0; i<m->u.certificate.ncert; i++)
1337
			bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1338
		break;
1339
	case HCertificateRequest:
1340
		bs = seprint(bs, be, "CertificateRequest\n");
1341
		bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1342
		bs = seprint(bs, be, "\tcertificateauthorities\n");
1343
		for(i=0; i<m->u.certificateRequest.nca; i++)
1344
			bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1345
		break;
1346
	case HServerHelloDone:
1347
		bs = seprint(bs, be, "ServerHelloDone\n");
1348
		break;
1349
	case HClientKeyExchange:
1350
		bs = seprint(bs, be, "HClientKeyExchange\n");
1351
		bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1352
		break;
1353
	case HFinished:
1354
		bs = seprint(bs, be, "HFinished\n");
1355
		for(i=0; i<m->u.finished.n; i++)
1356
			bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1357
		bs = seprint(bs, be, "\n");
1358
		break;
1359
	}
1360
	USED(bs);
1361
	return buf;
1362
}
1363
 
1364
static void
1365
tlsError(TlsConnection *c, int err, char *fmt, ...)
1366
{
1367
	char msg[512];
1368
	va_list arg;
1369
 
1370
	va_start(arg, fmt);
1371
	vseprint(msg, msg+sizeof(msg), fmt, arg);
1372
	va_end(arg);
1373
	if(c->trace)
1374
		c->trace("tlsError: %s\n", msg);
1375
	else if(c->erred)
1376
		fprint(2, "double error: %r, %s", msg);
1377
	else
1378
		werrstr("tls: local %s", msg);
1379
	c->erred = 1;
1380
	fprint(c->ctl, "alert %d", err);
1381
}
1382
 
1383
// commit to specific version number
1384
static int
1385
setVersion(TlsConnection *c, int version)
1386
{
1387
	if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1388
		return -1;
1389
	if(version > c->version)
1390
		version = c->version;
1391
	if(version == SSL3Version) {
1392
		c->version = version;
1393
		c->finished.n = SSL3FinishedLen;
1394
	}else if(version == TLSVersion){
1395
		c->version = version;
1396
		c->finished.n = TLSFinishedLen;
1397
	}else
1398
		return -1;
1399
	c->verset = 1;
1400
	return fprint(c->ctl, "version 0x%x", version);
1401
}
1402
 
1403
// confirm that received Finished message matches the expected value
1404
static int
1405
finishedMatch(TlsConnection *c, Finished *f)
1406
{
1407
	return memcmp(f->verify, c->finished.verify, f->n) == 0;
1408
}
1409
 
1410
// free memory associated with TlsConnection struct
1411
//		(but don't close the TLS channel itself)
1412
static void
1413
tlsConnectionFree(TlsConnection *c)
1414
{
1415
	tlsSecClose(c->sec);
1416
	freebytes(c->sid);
1417
	freebytes(c->cert);
1418
	memset(c, 0, sizeof(c));
1419
	free(c);
1420
}
1421
 
1422
 
1423
//================= cipher choices ========================
1424
 
1425
static int weakCipher[CipherMax] =
1426
{
1427
	1,	/* TLS_NULL_WITH_NULL_NULL */
1428
	1,	/* TLS_RSA_WITH_NULL_MD5 */
1429
	1,	/* TLS_RSA_WITH_NULL_SHA */
1430
	1,	/* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
1431
	0,	/* TLS_RSA_WITH_RC4_128_MD5 */
1432
	0,	/* TLS_RSA_WITH_RC4_128_SHA */
1433
	1,	/* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
1434
	0,	/* TLS_RSA_WITH_IDEA_CBC_SHA */
1435
	1,	/* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
1436
	0,	/* TLS_RSA_WITH_DES_CBC_SHA */
1437
	0,	/* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
1438
	1,	/* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
1439
	0,	/* TLS_DH_DSS_WITH_DES_CBC_SHA */
1440
	0,	/* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
1441
	1,	/* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
1442
	0,	/* TLS_DH_RSA_WITH_DES_CBC_SHA */
1443
	0,	/* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
1444
	1,	/* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
1445
	0,	/* TLS_DHE_DSS_WITH_DES_CBC_SHA */
1446
	0,	/* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
1447
	1,	/* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
1448
	0,	/* TLS_DHE_RSA_WITH_DES_CBC_SHA */
1449
	0,	/* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
1450
	1,	/* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
1451
	1,	/* TLS_DH_anon_WITH_RC4_128_MD5 */
1452
	1,	/* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
1453
	1,	/* TLS_DH_anon_WITH_DES_CBC_SHA */
1454
	1,	/* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
1455
};
1456
 
1457
static int
1458
setAlgs(TlsConnection *c, int a)
1459
{
1460
	int i;
1461
 
1462
	for(i = 0; i < nelem(cipherAlgs); i++){
1463
		if(cipherAlgs[i].tlsid == a){
1464
			c->enc = cipherAlgs[i].enc;
1465
			c->digest = cipherAlgs[i].digest;
1466
			c->nsecret = cipherAlgs[i].nsecret;
1467
			if(c->nsecret > MaxKeyData)
1468
				return 0;
1469
			return 1;
1470
		}
1471
	}
1472
	return 0;
1473
}
1474
 
1475
static int
1476
okCipher(Ints *cv)
1477
{
1478
	int weak, i, j, c;
1479
 
1480
	weak = 1;
1481
	for(i = 0; i < cv->len; i++) {
1482
		c = cv->data[i];
1483
		if(c >= CipherMax)
1484
			weak = 0;
1485
		else
1486
			weak &= weakCipher[c];
1487
		for(j = 0; j < nelem(cipherAlgs); j++)
1488
			if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
1489
				return c;
1490
	}
1491
	if(weak)
1492
		return -2;
1493
	return -1;
1494
}
1495
 
1496
static int
1497
okCompression(Bytes *cv)
1498
{
1499
	int i, j, c;
1500
 
1501
	for(i = 0; i < cv->len; i++) {
1502
		c = cv->data[i];
1503
		for(j = 0; j < nelem(compressors); j++) {
1504
			if(compressors[j] == c)
1505
				return c;
1506
		}
1507
	}
1508
	return -1;
1509
}
1510
 
1511
static Lock	ciphLock;
1512
static int	nciphers;
1513
 
1514
static int
1515
initCiphers(void)
1516
{
1517
	enum {MaxAlgF = 1024, MaxAlgs = 10};
1518
	char s[MaxAlgF], *flds[MaxAlgs];
1519
	int i, j, n, ok;
1520
 
1521
	lock(&ciphLock);
1522
	if(nciphers){
1523
		unlock(&ciphLock);
1524
		return nciphers;
1525
	}
1526
	j = open("#a/tls/encalgs", OREAD);
1527
	if(j < 0){
1528
		werrstr("can't open #a/tls/encalgs: %r");
1529
		return 0;
1530
	}
1531
	n = read(j, s, MaxAlgF-1);
1532
	close(j);
1533
	if(n <= 0){
1534
		werrstr("nothing in #a/tls/encalgs: %r");
1535
		return 0;
1536
	}
1537
	s[n] = 0;
1538
	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1539
	for(i = 0; i < nelem(cipherAlgs); i++){
1540
		ok = 0;
1541
		for(j = 0; j < n; j++){
1542
			if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
1543
				ok = 1;
1544
				break;
1545
			}
1546
		}
1547
		cipherAlgs[i].ok = ok;
1548
	}
1549
 
1550
	j = open("#a/tls/hashalgs", OREAD);
1551
	if(j < 0){
1552
		werrstr("can't open #a/tls/hashalgs: %r");
1553
		return 0;
1554
	}
1555
	n = read(j, s, MaxAlgF-1);
1556
	close(j);
1557
	if(n <= 0){
1558
		werrstr("nothing in #a/tls/hashalgs: %r");
1559
		return 0;
1560
	}
1561
	s[n] = 0;
1562
	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1563
	for(i = 0; i < nelem(cipherAlgs); i++){
1564
		ok = 0;
1565
		for(j = 0; j < n; j++){
1566
			if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
1567
				ok = 1;
1568
				break;
1569
			}
1570
		}
1571
		cipherAlgs[i].ok &= ok;
1572
		if(cipherAlgs[i].ok)
1573
			nciphers++;
1574
	}
1575
	unlock(&ciphLock);
1576
	return nciphers;
1577
}
1578
 
1579
static Ints*
1580
makeciphers(void)
1581
{
1582
	Ints *is;
1583
	int i, j;
1584
 
1585
	is = newints(nciphers);
1586
	j = 0;
1587
	for(i = 0; i < nelem(cipherAlgs); i++){
1588
		if(cipherAlgs[i].ok)
1589
			is->data[j++] = cipherAlgs[i].tlsid;
1590
	}
1591
	return is;
1592
}
1593
 
1594
 
1595
 
1596
//================= security functions ========================
1597
 
1598
// given X.509 certificate, set up connection to factotum
1599
//	for using corresponding private key
1600
static AuthRpc*
1601
factotum_rsa_open(uchar *cert, int certlen)
1602
{
1603
	int afd;
1604
	char *s;
1605
	mpint *pub = nil;
1606
	RSApub *rsapub;
1607
	AuthRpc *rpc;
1608
 
1609
	// start talking to factotum
1610
	if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
1611
		return nil;
1612
	if((rpc = auth_allocrpc(afd)) == nil){
1613
		close(afd);
1614
		return nil;
1615
	}
1616
	s = "proto=rsa service=tls role=client";
1617
	if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
1618
		factotum_rsa_close(rpc);
1619
		return nil;
1620
	}
1621
 
1622
	// roll factotum keyring around to match certificate
1623
	rsapub = X509toRSApub(cert, certlen, nil, 0);
1624
	while(1){
1625
		if(auth_rpc(rpc, "read", nil, 0) != ARok){
1626
			factotum_rsa_close(rpc);
1627
			rpc = nil;
1628
			goto done;
1629
		}
1630
		pub = strtomp(rpc->arg, nil, 16, nil);
1631
		assert(pub != nil);
1632
		if(mpcmp(pub,rsapub->n) == 0)
1633
			break;
1634
	}
1635
done:
1636
	mpfree(pub);
1637
	rsapubfree(rsapub);
1638
	return rpc;
1639
}
1640
 
1641
static mpint*
1642
factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
1643
{
1644
	char *p;
1645
	int rv;
1646
 
1647
	if((p = mptoa(cipher, 16, nil, 0)) == nil)
1648
		return nil;
1649
	rv = auth_rpc(rpc, "write", p, strlen(p));
1650
	free(p);
1651
	if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
1652
		return nil;
1653
	mpfree(cipher);
1654
	return strtomp(rpc->arg, nil, 16, nil);
1655
}
1656
 
1657
static void
1658
factotum_rsa_close(AuthRpc*rpc)
1659
{
1660
	if(!rpc)
1661
		return;
1662
	close(rpc->afd);
1663
	auth_freerpc(rpc);
1664
}
1665
 
1666
static void
1667
tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1668
{
1669
	uchar ai[MD5dlen], tmp[MD5dlen];
1670
	int i, n;
1671
	MD5state *s;
1672
 
1673
	// generate a1
1674
	s = hmac_md5(label, nlabel, key, nkey, nil, nil);
1675
	s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1676
	hmac_md5(seed1, nseed1, key, nkey, ai, s);
1677
 
1678
	while(nbuf > 0) {
1679
		s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
1680
		s = hmac_md5(label, nlabel, key, nkey, nil, s);
1681
		s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1682
		hmac_md5(seed1, nseed1, key, nkey, tmp, s);
1683
		n = MD5dlen;
1684
		if(n > nbuf)
1685
			n = nbuf;
1686
		for(i = 0; i < n; i++)
1687
			buf[i] ^= tmp[i];
1688
		buf += n;
1689
		nbuf -= n;
1690
		hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
1691
		memmove(ai, tmp, MD5dlen);
1692
	}
1693
}
1694
 
1695
static void
1696
tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1697
{
1698
	uchar ai[SHA1dlen], tmp[SHA1dlen];
1699
	int i, n;
1700
	SHAstate *s;
1701
 
1702
	// generate a1
1703
	s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
1704
	s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1705
	hmac_sha1(seed1, nseed1, key, nkey, ai, s);
1706
 
1707
	while(nbuf > 0) {
1708
		s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
1709
		s = hmac_sha1(label, nlabel, key, nkey, nil, s);
1710
		s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1711
		hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
1712
		n = SHA1dlen;
1713
		if(n > nbuf)
1714
			n = nbuf;
1715
		for(i = 0; i < n; i++)
1716
			buf[i] ^= tmp[i];
1717
		buf += n;
1718
		nbuf -= n;
1719
		hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
1720
		memmove(ai, tmp, SHA1dlen);
1721
	}
1722
}
1723
 
1724
// fill buf with md5(args)^sha1(args)
1725
static void
1726
tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1727
{
1728
	int i;
1729
	int nlabel = strlen(label);
1730
	int n = (nkey + 1) >> 1;
1731
 
1732
	for(i = 0; i < nbuf; i++)
1733
		buf[i] = 0;
1734
	tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1735
	tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1736
}
1737
 
1738
/*
1739
 * for setting server session id's
1740
 */
1741
static Lock	sidLock;
1742
static long	maxSid = 1;
1743
 
1744
/* the keys are verified to have the same public components
1745
 * and to function correctly with pkcs 1 encryption and decryption. */
1746
static TlsSec*
1747
tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
1748
{
1749
	TlsSec *sec = emalloc(sizeof(*sec));
1750
 
1751
	USED(csid); USED(ncsid);  // ignore csid for now
1752
 
1753
	memmove(sec->crandom, crandom, RandomSize);
1754
	sec->clientVers = cvers;
1755
 
1756
	put32(sec->srandom, time(0));
1757
	genrandom(sec->srandom+4, RandomSize-4);
1758
	memmove(srandom, sec->srandom, RandomSize);
1759
 
1760
	/*
1761
	 * make up a unique sid: use our pid, and and incrementing id
1762
	 * can signal no sid by setting nssid to 0.
1763
	 */
1764
	memset(ssid, 0, SidSize);
1765
	put32(ssid, getpid());
1766
	lock(&sidLock);
1767
	put32(ssid+4, maxSid++);
1768
	unlock(&sidLock);
1769
	*nssid = SidSize;
1770
	return sec;
1771
}
1772
 
1773
static int
1774
tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd)
1775
{
1776
	if(epm != nil){
1777
		if(setVers(sec, vers) < 0)
1778
			goto Err;
1779
		serverMasterSecret(sec, epm, nepm);
1780
	}else if(sec->vers != vers){
1781
		werrstr("mismatched session versions");
1782
		goto Err;
1783
	}
1784
	setSecrets(sec, kd, nkd);
1785
	return 0;
1786
Err:
1787
	sec->ok = -1;
1788
	return -1;
1789
}
1790
 
1791
static TlsSec*
1792
tlsSecInitc(int cvers, uchar *crandom)
1793
{
1794
	TlsSec *sec = emalloc(sizeof(*sec));
1795
	sec->clientVers = cvers;
1796
	put32(sec->crandom, time(0));
1797
	genrandom(sec->crandom+4, RandomSize-4);
1798
	memmove(crandom, sec->crandom, RandomSize);
1799
	return sec;
1800
}
1801
 
1802
static int
1803
tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd)
1804
{
1805
	RSApub *pub;
1806
 
1807
	pub = nil;
1808
 
1809
	USED(sid);
1810
	USED(nsid);
1811
 
1812
	memmove(sec->srandom, srandom, RandomSize);
1813
 
1814
	if(setVers(sec, vers) < 0)
1815
		goto Err;
1816
 
1817
	pub = X509toRSApub(cert, ncert, nil, 0);
1818
	if(pub == nil){
1819
		werrstr("invalid x509/rsa certificate");
1820
		goto Err;
1821
	}
1822
	if(clientMasterSecret(sec, pub, epm, nepm) < 0)
1823
		goto Err;
1824
	rsapubfree(pub);
1825
	setSecrets(sec, kd, nkd);
1826
	return 0;
1827
 
1828
Err:
1829
	if(pub != nil)
1830
		rsapubfree(pub);
1831
	sec->ok = -1;
1832
	return -1;
1833
}
1834
 
1835
static int
1836
tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
1837
{
1838
	if(sec->nfin != nfin){
1839
		sec->ok = -1;
1840
		werrstr("invalid finished exchange");
1841
		return -1;
1842
	}
1843
	md5.malloced = 0;
1844
	sha1.malloced = 0;
1845
	(*sec->setFinished)(sec, md5, sha1, fin, isclient);
1846
	return 1;
1847
}
1848
 
1849
static void
1850
tlsSecOk(TlsSec *sec)
1851
{
1852
	if(sec->ok == 0)
1853
		sec->ok = 1;
1854
}
1855
 
1856
static void
1857
tlsSecKill(TlsSec *sec)
1858
{
1859
	if(!sec)
1860
		return;
1861
	factotum_rsa_close(sec->rpc);
1862
	sec->ok = -1;
1863
}
1864
 
1865
static void
1866
tlsSecClose(TlsSec *sec)
1867
{
1868
	if(!sec)
1869
		return;
1870
	factotum_rsa_close(sec->rpc);
1871
	free(sec->server);
1872
	free(sec);
1873
}
1874
 
1875
static int
1876
setVers(TlsSec *sec, int v)
1877
{
1878
	if(v == SSL3Version){
1879
		sec->setFinished = sslSetFinished;
1880
		sec->nfin = SSL3FinishedLen;
1881
		sec->prf = sslPRF;
1882
	}else if(v == TLSVersion){
1883
		sec->setFinished = tlsSetFinished;
1884
		sec->nfin = TLSFinishedLen;
1885
		sec->prf = tlsPRF;
1886
	}else{
1887
		werrstr("invalid version");
1888
		return -1;
1889
	}
1890
	sec->vers = v;
1891
	return 0;
1892
}
1893
 
1894
/*
1895
 * generate secret keys from the master secret.
1896
 *
1897
 * different crypto selections will require different amounts
1898
 * of key expansion and use of key expansion data,
1899
 * but it's all generated using the same function.
1900
 */
1901
static void
1902
setSecrets(TlsSec *sec, uchar *kd, int nkd)
1903
{
1904
	(*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
1905
			sec->srandom, RandomSize, sec->crandom, RandomSize);
1906
}
1907
 
1908
/*
1909
 * set the master secret from the pre-master secret.
1910
 */
1911
static void
1912
setMasterSecret(TlsSec *sec, Bytes *pm)
1913
{
1914
	(*sec->prf)(sec->sec, MasterSecretSize, pm->data, MasterSecretSize, "master secret",
1915
			sec->crandom, RandomSize, sec->srandom, RandomSize);
1916
}
1917
 
1918
static void
1919
serverMasterSecret(TlsSec *sec, uchar *epm, int nepm)
1920
{
1921
	Bytes *pm;
1922
 
1923
	pm = pkcs1_decrypt(sec, epm, nepm);
1924
 
1925
	// if the client messed up, just continue as if everything is ok,
1926
	// to prevent attacks to check for correctly formatted messages.
1927
	// Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
1928
	if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
1929
		fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
1930
			sec->ok, pm, pm ? get16(pm->data) : -1, sec->clientVers, nepm);
1931
		sec->ok = -1;
1932
		if(pm != nil)
1933
			freebytes(pm);
1934
		pm = newbytes(MasterSecretSize);
1935
		genrandom(pm->data, MasterSecretSize);
1936
	}
1937
	setMasterSecret(sec, pm);
1938
	memset(pm->data, 0, pm->len);	
1939
	freebytes(pm);
1940
}
1941
 
1942
static int
1943
clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
1944
{
1945
	Bytes *pm, *key;
1946
 
1947
	pm = newbytes(MasterSecretSize);
1948
	put16(pm->data, sec->clientVers);
1949
	genrandom(pm->data+2, MasterSecretSize - 2);
1950
 
1951
	setMasterSecret(sec, pm);
1952
 
1953
	key = pkcs1_encrypt(pm, pub, 2);
1954
	memset(pm->data, 0, pm->len);
1955
	freebytes(pm);
1956
	if(key == nil){
1957
		werrstr("tls pkcs1_encrypt failed");
1958
		return -1;
1959
	}
1960
 
1961
	*nepm = key->len;
1962
	*epm = malloc(*nepm);
1963
	if(*epm == nil){
1964
		freebytes(key);
1965
		werrstr("out of memory");
1966
		return -1;
1967
	}
1968
	memmove(*epm, key->data, *nepm);
1969
 
1970
	freebytes(key);
1971
 
1972
	return 1;
1973
}
1974
 
1975
static void
1976
sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
1977
{
1978
	DigestState *s;
1979
	uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
1980
	char *label;
1981
 
1982
	if(isClient)
1983
		label = "CLNT";
1984
	else
1985
		label = "SRVR";
1986
 
1987
	md5((uchar*)label, 4, nil, &hsmd5);
1988
	md5(sec->sec, MasterSecretSize, nil, &hsmd5);
1989
	memset(pad, 0x36, 48);
1990
	md5(pad, 48, nil, &hsmd5);
1991
	md5(nil, 0, h0, &hsmd5);
1992
	memset(pad, 0x5C, 48);
1993
	s = md5(sec->sec, MasterSecretSize, nil, nil);
1994
	s = md5(pad, 48, nil, s);
1995
	md5(h0, MD5dlen, finished, s);
1996
 
1997
	sha1((uchar*)label, 4, nil, &hssha1);
1998
	sha1(sec->sec, MasterSecretSize, nil, &hssha1);
1999
	memset(pad, 0x36, 40);
2000
	sha1(pad, 40, nil, &hssha1);
2001
	sha1(nil, 0, h1, &hssha1);
2002
	memset(pad, 0x5C, 40);
2003
	s = sha1(sec->sec, MasterSecretSize, nil, nil);
2004
	s = sha1(pad, 40, nil, s);
2005
	sha1(h1, SHA1dlen, finished + MD5dlen, s);
2006
}
2007
 
2008
// fill "finished" arg with md5(args)^sha1(args)
2009
static void
2010
tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
2011
{
2012
	uchar h0[MD5dlen], h1[SHA1dlen];
2013
	char *label;
2014
 
2015
	// get current hash value, but allow further messages to be hashed in
2016
	md5(nil, 0, h0, &hsmd5);
2017
	sha1(nil, 0, h1, &hssha1);
2018
 
2019
	if(isClient)
2020
		label = "client finished";
2021
	else
2022
		label = "server finished";
2023
	tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
2024
}
2025
 
2026
static void
2027
sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2028
{
2029
	DigestState *s;
2030
	uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
2031
	int i, n, len;
2032
 
2033
	USED(label);
2034
	len = 1;
2035
	while(nbuf > 0){
2036
		if(len > 26)
2037
			return;
2038
		for(i = 0; i < len; i++)
2039
			tmp[i] = 'A' - 1 + len;
2040
		s = sha1(tmp, len, nil, nil);
2041
		s = sha1(key, nkey, nil, s);
2042
		s = sha1(seed0, nseed0, nil, s);
2043
		sha1(seed1, nseed1, sha1dig, s);
2044
		s = md5(key, nkey, nil, nil);
2045
		md5(sha1dig, SHA1dlen, md5dig, s);
2046
		n = MD5dlen;
2047
		if(n > nbuf)
2048
			n = nbuf;
2049
		memmove(buf, md5dig, n);
2050
		buf += n;
2051
		nbuf -= n;
2052
		len++;
2053
	}
2054
}
2055
 
2056
static mpint*
2057
bytestomp(Bytes* bytes)
2058
{
2059
	mpint* ans;
2060
 
2061
	ans = betomp(bytes->data, bytes->len, nil);
2062
	return ans;
2063
}
2064
 
2065
/*
2066
 * Convert mpint* to Bytes, putting high order byte first.
2067
 */
2068
static Bytes*
2069
mptobytes(mpint* big)
2070
{
2071
	int n, m;
2072
	uchar *a;
2073
	Bytes* ans;
2074
 
2075
	a = nil;
2076
	n = (mpsignif(big)+7)/8;
2077
	m = mptobe(big, nil, n, &a);
2078
	ans = makebytes(a, m);
2079
	if(a != nil)
2080
		free(a);
2081
	return ans;
2082
}
2083
 
2084
// Do RSA computation on block according to key, and pad
2085
// result on left with zeros to make it modlen long.
2086
static Bytes*
2087
rsacomp(Bytes* block, RSApub* key, int modlen)
2088
{
2089
	mpint *x, *y;
2090
	Bytes *a, *ybytes;
2091
	int ylen;
2092
 
2093
	x = bytestomp(block);
2094
	y = rsaencrypt(key, x, nil);
2095
	mpfree(x);
2096
	ybytes = mptobytes(y);
2097
	ylen = ybytes->len;
2098
 
2099
	if(ylen < modlen) {
2100
		a = newbytes(modlen);
2101
		memset(a->data, 0, modlen-ylen);
2102
		memmove(a->data+modlen-ylen, ybytes->data, ylen);
2103
		freebytes(ybytes);
2104
		ybytes = a;
2105
	}
2106
	else if(ylen > modlen) {
2107
		// assume it has leading zeros (mod should make it so)
2108
		a = newbytes(modlen);
2109
		memmove(a->data, ybytes->data, modlen);
2110
		freebytes(ybytes);
2111
		ybytes = a;
2112
	}
2113
	mpfree(y);
2114
	return ybytes;
2115
}
2116
 
2117
// encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2118
static Bytes*
2119
pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2120
{
2121
	Bytes *pad, *eb, *ans;
2122
	int i, dlen, padlen, modlen;
2123
 
2124
	modlen = (mpsignif(key->n)+7)/8;
2125
	dlen = data->len;
2126
	if(modlen < 12 || dlen > modlen - 11)
2127
		return nil;
2128
	padlen = modlen - 3 - dlen;
2129
	pad = newbytes(padlen);
2130
	genrandom(pad->data, padlen);
2131
	for(i = 0; i < padlen; i++) {
2132
		if(blocktype == 0)
2133
			pad->data[i] = 0;
2134
		else if(blocktype == 1)
2135
			pad->data[i] = 255;
2136
		else if(pad->data[i] == 0)
2137
			pad->data[i] = 1;
2138
	}
2139
	eb = newbytes(modlen);
2140
	eb->data[0] = 0;
2141
	eb->data[1] = blocktype;
2142
	memmove(eb->data+2, pad->data, padlen);
2143
	eb->data[padlen+2] = 0;
2144
	memmove(eb->data+padlen+3, data->data, dlen);
2145
	ans = rsacomp(eb, key, modlen);
2146
	freebytes(eb);
2147
	freebytes(pad);
2148
	return ans;
2149
}
2150
 
2151
// decrypt data according to PKCS#1, with given key.
2152
// expect a block type of 2.
2153
static Bytes*
2154
pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm)
2155
{
2156
	Bytes *eb, *ans = nil;
2157
	int i, modlen;
2158
	mpint *x, *y;
2159
 
2160
	modlen = (mpsignif(sec->rsapub->n)+7)/8;
2161
	if(nepm != modlen)
2162
		return nil;
2163
	x = betomp(epm, nepm, nil);
2164
	y = factotum_rsa_decrypt(sec->rpc, x);
2165
	if(y == nil)
2166
		return nil;
2167
	eb = mptobytes(y);
2168
	if(eb->len < modlen){ // pad on left with zeros
2169
		ans = newbytes(modlen);
2170
		memset(ans->data, 0, modlen-eb->len);
2171
		memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2172
		freebytes(eb);
2173
		eb = ans;
2174
	}
2175
	if(eb->data[0] == 0 && eb->data[1] == 2) {
2176
		for(i = 2; i < modlen; i++)
2177
			if(eb->data[i] == 0)
2178
				break;
2179
		if(i < modlen - 1)
2180
			ans = makebytes(eb->data+i+1, modlen-(i+1));
2181
	}
2182
	freebytes(eb);
2183
	return ans;
2184
}
2185
 
2186
 
2187
//================= general utility functions ========================
2188
 
2189
static void *
2190
emalloc(int n)
2191
{
2192
	void *p;
2193
	if(n==0)
2194
		n=1;
2195
	p = malloc(n);
2196
	if(p == nil){
2197
		exits("out of memory");
2198
	}
2199
	memset(p, 0, n);
2200
	return p;
2201
}
2202
 
2203
static void *
2204
erealloc(void *ReallocP, int ReallocN)
2205
{
2206
	if(ReallocN == 0)
2207
		ReallocN = 1;
2208
	if(!ReallocP)
2209
		ReallocP = emalloc(ReallocN);
2210
	else if(!(ReallocP = realloc(ReallocP, ReallocN))){
2211
		exits("out of memory");
2212
	}
2213
	return(ReallocP);
2214
}
2215
 
2216
static void
2217
put32(uchar *p, u32int x)
2218
{
2219
	p[0] = x>>24;
2220
	p[1] = x>>16;
2221
	p[2] = x>>8;
2222
	p[3] = x;
2223
}
2224
 
2225
static void
2226
put24(uchar *p, int x)
2227
{
2228
	p[0] = x>>16;
2229
	p[1] = x>>8;
2230
	p[2] = x;
2231
}
2232
 
2233
static void
2234
put16(uchar *p, int x)
2235
{
2236
	p[0] = x>>8;
2237
	p[1] = x;
2238
}
2239
 
2240
static u32int
2241
get32(uchar *p)
2242
{
2243
	return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2244
}
2245
 
2246
static int
2247
get24(uchar *p)
2248
{
2249
	return (p[0]<<16)|(p[1]<<8)|p[2];
2250
}
2251
 
2252
static int
2253
get16(uchar *p)
2254
{
2255
	return (p[0]<<8)|p[1];
2256
}
2257
 
2258
/* byte offset of member within struct type (arguments reversed from offsetof) */
#define OFFSET(member, type) offsetof(type, member)
2259
 
2260
/*
2261
 * malloc and return a new Bytes structure capable of
2262
 * holding len bytes. (len >= 0)
2263
 * Used to use crypt_malloc, which aborts if malloc fails.
2264
 */
2265
static Bytes*
2266
newbytes(int len)
2267
{
2268
	Bytes* ans;
2269
 
2270
	ans = (Bytes*)malloc(OFFSET(data[0], Bytes) + len);
2271
	ans->len = len;
2272
	return ans;
2273
}
2274
 
2275
/*
2276
 * newbytes(len), with data initialized from buf
2277
 */
2278
static Bytes*
2279
makebytes(uchar* buf, int len)
2280
{
2281
	Bytes* ans;
2282
 
2283
	ans = newbytes(len);
2284
	memmove(ans->data, buf, len);
2285
	return ans;
2286
}
2287
 
2288
static void
2289
freebytes(Bytes* b)
2290
{
2291
	if(b != nil)
2292
		free(b);
2293
}
2294
 
2295
/* len is number of ints */
2296
static Ints*
2297
newints(int len)
2298
{
2299
	Ints* ans;
2300
 
2301
	ans = (Ints*)malloc(OFFSET(data[0], Ints) + len*sizeof(int));
2302
	ans->len = len;
2303
	return ans;
2304
}
2305
 
2306
static Ints*
2307
makeints(int* buf, int len)
2308
{
2309
	Ints* ans;
2310
 
2311
	ans = newints(len);
2312
	if(len > 0)
2313
		memmove(ans->data, buf, len*sizeof(int));
2314
	return ans;
2315
}
2316
 
2317
static void
2318
freeints(Ints* b)
2319
{
2320
	if(b != nil)
2321
		free(b);
2322
}