Subversion Repositories planix.SVN

Rev

Rev 2 | Details | Compare with Previous | Last modification | View Log | RSS feed

Rev Author Line No. Line
2 - 1
#include <u.h>
2
#include <libc.h>
3
#include <bio.h>
4
#include <auth.h>
5
#include <mp.h>
6
#include <libsec.h>
7
 
8
// The main groups of functions are:
9
//		client/server - main handshake protocol definition
10
//		message functions - formatting handshake messages
11
//		cipher choices - catalog of digest and encrypt algorithms
12
//		security functions - PKCS#1, sslHMAC, session keygen
13
//		general utility functions - malloc, serialization
14
// The handshake protocol builds on the TLS/SSL3 record layer protocol,
15
// which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.
16
 
17
enum {
18
	TLSFinishedLen = 12,
19
	SSL3FinishedLen = MD5dlen+SHA1dlen,
20
	MaxKeyData = 104,	// amount of secret we may need
21
	MaxChunk = 1<<14,
22
	RandomSize = 32,
23
	SidSize = 32,
24
	MasterSecretSize = 48,
25
	AQueue = 0,
26
	AFlush = 1,
27
};
28
 
29
typedef struct TlsSec TlsSec;
30
 
31
typedef struct Bytes{
32
	int len;
33
	uchar data[1];  // [len]
34
} Bytes;
35
 
36
typedef struct Ints{
37
	int len;
38
	int data[1];  // [len]
39
} Ints;
40
 
41
typedef struct Algs{
42
	char *enc;
43
	char *digest;
44
	int nsecret;
45
	int tlsid;
46
	int ok;
47
} Algs;
48
 
49
typedef struct Finished{
50
	uchar verify[SSL3FinishedLen];
51
	int n;
52
} Finished;
53
 
54
typedef struct TlsConnection{
55
	TlsSec *sec;	// security management goo
56
	int hand, ctl;	// record layer file descriptors
57
	int erred;		// set when tlsError called
58
	int (*trace)(char*fmt, ...); // for debugging
59
	int version;	// protocol we are speaking
60
	int verset;		// version has been set
61
	int ver2hi;		// server got a version 2 hello
62
	int isClient;	// is this the client or server?
63
	Bytes *sid;		// SessionID
64
	Bytes *cert;	// only last - no chain
65
 
66
	Lock statelk;
67
	int state;		// must be set using setstate
68
 
69
	// input buffer for handshake messages
70
	uchar buf[MaxChunk+2048];
71
	uchar *rp, *ep;
72
 
73
	uchar crandom[RandomSize];	// client random
74
	uchar srandom[RandomSize];	// server random
75
	int clientVersion;	// version in ClientHello
76
	char *digest;	// name of digest algorithm to use
77
	char *enc;		// name of encryption algorithm to use
78
	int nsecret;	// amount of secret data to init keys
79
 
80
	// for finished messages
81
	MD5state	hsmd5;	// handshake hash
82
	SHAstate	hssha1;	// handshake hash
83
	Finished	finished;
84
} TlsConnection;
85
 
86
typedef struct Msg{
87
	int tag;
88
	union {
89
		struct {
90
			int version;
91
			uchar 	random[RandomSize];
92
			Bytes*	sid;
93
			Ints*	ciphers;
94
			Bytes*	compressors;
95
		} clientHello;
96
		struct {
97
			int version;
98
			uchar 	random[RandomSize];
99
			Bytes*	sid;
100
			int cipher;
101
			int compressor;
102
		} serverHello;
103
		struct {
104
			int ncert;
105
			Bytes **certs;
106
		} certificate;
107
		struct {
108
			Bytes *types;
109
			int nca;
110
			Bytes **cas;
111
		} certificateRequest;
112
		struct {
113
			Bytes *key;
114
		} clientKeyExchange;
115
		Finished finished;
116
	} u;
117
} Msg;
118
 
119
typedef struct TlsSec{
120
	char *server;	// name of remote; nil for server
121
	int ok;	// <0 killed; ==0 in progress; >0 reusable
122
	RSApub *rsapub;
123
	AuthRpc *rpc;	// factotum for rsa private key
124
	uchar sec[MasterSecretSize];	// master secret
125
	uchar crandom[RandomSize];	// client random
126
	uchar srandom[RandomSize];	// server random
127
	int clientVers;		// version in ClientHello
128
	int vers;			// final version
129
	// byte generation and handshake checksum
130
	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
131
	void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
132
	int nfin;
133
} TlsSec;
134
 
135
 
136
// protocol version numbers (major<<8 | minor)
enum {
	SSL3Version		= 0x0300,
	TLSVersion		= 0x0301,
	ProtocolVersion		= TLSVersion,	// maximum version we speak
	MinProtoVersion		= SSL3Version,	// limits on version we accept
	MaxProtoVersion		= 0x03ff,
};
143
 
144
// handshake type: the first byte of every handshake message header
enum {
	HHelloRequest		= 0,
	HClientHello		= 1,
	HServerHello		= 2,
	HSSL2ClientHello	= 9,	/* local convention;  see devtls.c */
	HCertificate		= 11,
	HServerKeyExchange	= 12,
	HCertificateRequest	= 13,
	HServerHelloDone	= 14,
	HCertificateVerify	= 15,
	HClientKeyExchange	= 16,
	HFinished		= 20,
	HMax			= 21,
};
159
 
160
// alert descriptions, passed as the err argument of tlsError
enum {
	// closure
	ECloseNotify		= 0,

	// record layer problems
	EUnexpectedMessage	= 10,
	EBadRecordMac		= 20,
	EDecryptionFailed	= 21,
	ERecordOverflow		= 22,
	EDecompressionFailure	= 30,

	// handshake and certificate problems
	EHandshakeFailure	= 40,
	ENoCertificate		= 41,
	EBadCertificate		= 42,
	EUnsupportedCertificate	= 43,
	ECertificateRevoked	= 44,
	ECertificateExpired	= 45,
	ECertificateUnknown	= 46,
	EIllegalParameter	= 47,
	EUnknownCa		= 48,
	EAccessDenied		= 49,
	EDecodeError		= 50,
	EDecryptError		= 51,
	EExportRestriction	= 60,
	EProtocolVersion	= 70,
	EInsufficientSecurity	= 71,
	EInternalError		= 80,
	EUserCanceled		= 90,
	ENoRenegotiation	= 100,

	EMax			= 256
};
188
 
189
// cipher suite numbers as they appear on the wire
enum {
	TLS_NULL_WITH_NULL_NULL			= 0x0000,
	TLS_RSA_WITH_NULL_MD5			= 0x0001,
	TLS_RSA_WITH_NULL_SHA			= 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5		= 0x0003,
	TLS_RSA_WITH_RC4_128_MD5		= 0x0004,
	TLS_RSA_WITH_RC4_128_SHA		= 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5	= 0x0006,
	TLS_RSA_WITH_IDEA_CBC_SHA		= 0x0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0008,
	TLS_RSA_WITH_DES_CBC_SHA		= 0x0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA		= 0x000a,
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x000b,
	TLS_DH_DSS_WITH_DES_CBC_SHA		= 0x000c,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA	= 0x000d,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x000e,
	TLS_DH_RSA_WITH_DES_CBC_SHA		= 0x000f,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0010,
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA		= 0x0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA	= 0x0013,	// ZZZ must be implemented for tls1.0 compliance
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA		= 0x0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0016,
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5	= 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5		= 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA	= 0x0019,
	TLS_DH_anon_WITH_DES_CBC_SHA		= 0x001a,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA	= 0x001b,

	// aes, aka rijndael with 128 bit blocks
	TLS_RSA_WITH_AES_128_CBC_SHA		= 0x002f,
	TLS_DH_DSS_WITH_AES_128_CBC_SHA		= 0x0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA		= 0x0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA	= 0x0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA	= 0x0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA	= 0x0034,
	TLS_RSA_WITH_AES_256_CBC_SHA		= 0x0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA		= 0x0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA		= 0x0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA	= 0x0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA	= 0x0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA	= 0x003a,

	CipherMax				= 0x003b
};
234
 
235
// compression methods; only the null method is supported
enum {
	CompressionNull	= 0,
	CompressionMax	= 1
};
240
 
241
static Algs cipherAlgs[] = {
242
	{"rc4_128", "md5",	2 * (16 + MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
243
	{"rc4_128", "sha1",	2 * (16 + SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
244
	{"3des_ede_cbc","sha1",2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
245
};
246
 
247
static uchar compressors[] = {
248
	CompressionNull,
249
};
250
 
251
static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...));
252
static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...));
253
 
254
static void	msgClear(Msg *m);
255
static char* msgPrint(char *buf, int n, Msg *m);
256
static int	msgRecv(TlsConnection *c, Msg *m);
257
static int	msgSend(TlsConnection *c, Msg *m, int act);
258
static void	tlsError(TlsConnection *c, int err, char *msg, ...);
259
#pragma	varargck argpos	tlsError 3
260
static int setVersion(TlsConnection *c, int version);
261
static int finishedMatch(TlsConnection *c, Finished *f);
262
static void tlsConnectionFree(TlsConnection *c);
263
 
264
static int setAlgs(TlsConnection *c, int a);
265
static int okCipher(Ints *cv);
266
static int okCompression(Bytes *cv);
267
static int initCiphers(void);
268
static Ints* makeciphers(void);
269
 
270
static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
271
static int	tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
272
static TlsSec*	tlsSecInitc(int cvers, uchar *crandom);
273
static int	tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
274
static int	tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
275
static void	tlsSecOk(TlsSec *sec);
276
static void	tlsSecKill(TlsSec *sec);
277
static void	tlsSecClose(TlsSec *sec);
278
static void	setMasterSecret(TlsSec *sec, Bytes *pm);
279
static void	serverMasterSecret(TlsSec *sec, uchar *epm, int nepm);
280
static void	setSecrets(TlsSec *sec, uchar *kd, int nkd);
281
static int	clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
282
static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
283
static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
284
static void	tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
285
static void	sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
286
static void	sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
287
			uchar *seed0, int nseed0, uchar *seed1, int nseed1);
288
static int setVers(TlsSec *sec, int version);
289
 
290
static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
291
static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
292
static void factotum_rsa_close(AuthRpc*rpc);
293
 
294
static void* emalloc(int);
295
static void* erealloc(void*, int);
296
static void put32(uchar *p, u32int);
297
static void put24(uchar *p, int);
298
static void put16(uchar *p, int);
299
static u32int get32(uchar *p);
300
static int get24(uchar *p);
301
static int get16(uchar *p);
302
static Bytes* newbytes(int len);
303
static Bytes* makebytes(uchar* buf, int len);
304
static void freebytes(Bytes* b);
305
static Ints* newints(int len);
306
static Ints* makeints(int* buf, int len);
307
static void freeints(Ints* b);
308
 
309
//================= client/server ========================
310
 
311
//	push TLS onto fd, returning new (application) file descriptor
312
//		or -1 if error.
313
int
314
tlsServer(int fd, TLSconn *conn)
315
{
316
	char buf[8];
317
	char dname[64];
318
	int n, data, ctl, hand;
319
	TlsConnection *tls;
320
 
321
	if(conn == nil)
322
		return -1;
323
	ctl = open("#a/tls/clone", ORDWR);
324
	if(ctl < 0)
325
		return -1;
326
	n = read(ctl, buf, sizeof(buf)-1);
327
	if(n < 0){
328
		close(ctl);
329
		return -1;
330
	}
331
	buf[n] = 0;
332
	sprint(conn->dir, "#a/tls/%s", buf);
333
	sprint(dname, "#a/tls/%s/hand", buf);
334
	hand = open(dname, ORDWR);
335
	if(hand < 0){
336
		close(ctl);
337
		return -1;
338
	}
339
	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
340
	tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace);
341
	sprint(dname, "#a/tls/%s/data", buf);
342
	data = open(dname, ORDWR);
343
	close(fd);
344
	close(hand);
345
	close(ctl);
346
	if(data < 0){
347
		return -1;
348
	}
349
	if(tls == nil){
350
		close(data);
351
		return -1;
352
	}
353
	if(conn->cert)
354
		free(conn->cert);
355
	conn->cert = 0;  // client certificates are not yet implemented
356
	conn->certlen = 0;
357
	conn->sessionIDlen = tls->sid->len;
358
	conn->sessionID = emalloc(conn->sessionIDlen);
359
	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
360
	tlsConnectionFree(tls);
361
	return data;
362
}
363
 
364
//	push TLS onto fd, returning new (application) file descriptor
365
//		or -1 if error.
366
int
367
tlsClient(int fd, TLSconn *conn)
368
{
369
	char buf[8];
370
	char dname[64];
371
	int n, data, ctl, hand;
372
	TlsConnection *tls;
373
 
374
	if(!conn)
375
		return -1;
376
	ctl = open("#a/tls/clone", ORDWR);
377
	if(ctl < 0)
378
		return -1;
379
	n = read(ctl, buf, sizeof(buf)-1);
380
	if(n < 0){
381
		close(ctl);
382
		return -1;
383
	}
384
	buf[n] = 0;
385
	sprint(conn->dir, "#a/tls/%s", buf);
386
	sprint(dname, "#a/tls/%s/hand", buf);
387
	hand = open(dname, ORDWR);
388
	if(hand < 0){
389
		close(ctl);
390
		return -1;
391
	}
392
	sprint(dname, "#a/tls/%s/data", buf);
393
	data = open(dname, ORDWR);
394
	if(data < 0)
395
		return -1;
396
	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
397
	tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->trace);
398
	close(fd);
399
	close(hand);
400
	close(ctl);
401
	if(tls == nil){
402
		close(data);
403
		return -1;
404
	}
405
	conn->certlen = tls->cert->len;
406
	conn->cert = emalloc(conn->certlen);
407
	memcpy(conn->cert, tls->cert->data, conn->certlen);
408
	conn->sessionIDlen = tls->sid->len;
409
	conn->sessionID = emalloc(conn->sessionIDlen);
410
	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
411
	tlsConnectionFree(tls);
412
	return data;
413
}
414
 
415
static TlsConnection *
416
tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...))
417
{
418
	TlsConnection *c;
419
	Msg m;
420
	Bytes *csid;
421
	uchar sid[SidSize], kd[MaxKeyData];
422
	char *secrets;
423
	int cipher, compressor, nsid, rv;
424
 
425
	if(trace)
426
		trace("tlsServer2\n");
427
	if(!initCiphers())
428
		return nil;
429
	c = emalloc(sizeof(TlsConnection));
430
	c->ctl = ctl;
431
	c->hand = hand;
432
	c->trace = trace;
433
	c->version = ProtocolVersion;
434
 
435
	memset(&m, 0, sizeof(m));
436
	if(!msgRecv(c, &m)){
437
		if(trace)
438
			trace("initial msgRecv failed\n");
439
		goto Err;
440
	}
441
	if(m.tag != HClientHello) {
442
		tlsError(c, EUnexpectedMessage, "expected a client hello");
443
		goto Err;
444
	}
445
	c->clientVersion = m.u.clientHello.version;
446
	if(trace)
447
		trace("ClientHello version %x\n", c->clientVersion);
448
	if(setVersion(c, m.u.clientHello.version) < 0) {
449
		tlsError(c, EIllegalParameter, "incompatible version");
450
		goto Err;
451
	}
452
 
453
	memmove(c->crandom, m.u.clientHello.random, RandomSize);
454
	cipher = okCipher(m.u.clientHello.ciphers);
455
	if(cipher < 0) {
456
		// reply with EInsufficientSecurity if we know that's the case
457
		if(cipher == -2)
458
			tlsError(c, EInsufficientSecurity, "cipher suites too weak");
459
		else
460
			tlsError(c, EHandshakeFailure, "no matching cipher suite");
461
		goto Err;
462
	}
463
	if(!setAlgs(c, cipher)){
464
		tlsError(c, EHandshakeFailure, "no matching cipher suite");
465
		goto Err;
466
	}
467
	compressor = okCompression(m.u.clientHello.compressors);
468
	if(compressor < 0) {
469
		tlsError(c, EHandshakeFailure, "no matching compressor");
470
		goto Err;
471
	}
472
 
473
	csid = m.u.clientHello.sid;
474
	if(trace)
475
		trace("  cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
476
	c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
477
	if(c->sec == nil){
478
		tlsError(c, EHandshakeFailure, "can't initialize security: %r");
479
		goto Err;
480
	}
481
	c->sec->rpc = factotum_rsa_open(cert, ncert);
482
	if(c->sec->rpc == nil){
483
		tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
484
		goto Err;
485
	}
486
	c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0);
487
	msgClear(&m);
488
 
489
	m.tag = HServerHello;
490
	m.u.serverHello.version = c->version;
491
	memmove(m.u.serverHello.random, c->srandom, RandomSize);
492
	m.u.serverHello.cipher = cipher;
493
	m.u.serverHello.compressor = compressor;
494
	c->sid = makebytes(sid, nsid);
495
	m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
496
	if(!msgSend(c, &m, AQueue))
497
		goto Err;
498
	msgClear(&m);
499
 
500
	m.tag = HCertificate;
501
	m.u.certificate.ncert = 1;
502
	m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes));
503
	m.u.certificate.certs[0] = makebytes(cert, ncert);
504
	if(!msgSend(c, &m, AQueue))
505
		goto Err;
506
	msgClear(&m);
507
 
508
	m.tag = HServerHelloDone;
509
	if(!msgSend(c, &m, AFlush))
510
		goto Err;
511
	msgClear(&m);
512
 
513
	if(!msgRecv(c, &m))
514
		goto Err;
515
	if(m.tag != HClientKeyExchange) {
516
		tlsError(c, EUnexpectedMessage, "expected a client key exchange");
517
		goto Err;
518
	}
519
	if(tlsSecSecrets(c->sec, c->version, m.u.clientKeyExchange.key->data, m.u.clientKeyExchange.key->len, kd, c->nsecret) < 0){
520
		tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
521
		goto Err;
522
	}
523
	if(trace)
524
		trace("tls secrets\n");
525
	secrets = (char*)emalloc(2*c->nsecret);
526
	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
527
	rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
528
	memset(secrets, 0, 2*c->nsecret);
529
	free(secrets);
530
	memset(kd, 0, c->nsecret);
531
	if(rv < 0){
532
		tlsError(c, EHandshakeFailure, "can't set keys: %r");
533
		goto Err;
534
	}
535
	msgClear(&m);
536
 
537
	/* no CertificateVerify; skip to Finished */
538
	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
539
		tlsError(c, EInternalError, "can't set finished: %r");
540
		goto Err;
541
	}
542
	if(!msgRecv(c, &m))
543
		goto Err;
544
	if(m.tag != HFinished) {
545
		tlsError(c, EUnexpectedMessage, "expected a finished");
546
		goto Err;
547
	}
548
	if(!finishedMatch(c, &m.u.finished)) {
549
		tlsError(c, EHandshakeFailure, "finished verification failed");
550
		goto Err;
551
	}
552
	msgClear(&m);
553
 
554
	/* change cipher spec */
555
	if(fprint(c->ctl, "changecipher") < 0){
556
		tlsError(c, EInternalError, "can't enable cipher: %r");
557
		goto Err;
558
	}
559
 
560
	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
561
		tlsError(c, EInternalError, "can't set finished: %r");
562
		goto Err;
563
	}
564
	m.tag = HFinished;
565
	m.u.finished = c->finished;
566
	if(!msgSend(c, &m, AFlush))
567
		goto Err;
568
	msgClear(&m);
569
	if(trace)
570
		trace("tls finished\n");
571
 
572
	if(fprint(c->ctl, "opened") < 0)
573
		goto Err;
574
	tlsSecOk(c->sec);
575
	return c;
576
 
577
Err:
578
	msgClear(&m);
579
	tlsConnectionFree(c);
580
	return 0;
581
}
582
 
583
static TlsConnection *
584
tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...))
585
{
586
	TlsConnection *c;
587
	Msg m;
588
	uchar kd[MaxKeyData], *epm;
589
	char *secrets;
590
	int creq, nepm, rv;
591
 
592
	if(!initCiphers())
593
		return nil;
594
	epm = nil;
595
	c = emalloc(sizeof(TlsConnection));
596
	c->version = ProtocolVersion;
597
	c->ctl = ctl;
598
	c->hand = hand;
599
	c->trace = trace;
600
	c->isClient = 1;
601
	c->clientVersion = c->version;
602
 
603
	c->sec = tlsSecInitc(c->clientVersion, c->crandom);
604
	if(c->sec == nil)
605
		goto Err;
606
 
607
	/* client hello */
608
	memset(&m, 0, sizeof(m));
609
	m.tag = HClientHello;
610
	m.u.clientHello.version = c->clientVersion;
611
	memmove(m.u.clientHello.random, c->crandom, RandomSize);
612
	m.u.clientHello.sid = makebytes(csid, ncsid);
613
	m.u.clientHello.ciphers = makeciphers();
614
	m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
615
	if(!msgSend(c, &m, AFlush))
616
		goto Err;
617
	msgClear(&m);
618
 
619
	/* server hello */
620
	if(!msgRecv(c, &m))
621
		goto Err;
622
	if(m.tag != HServerHello) {
623
		tlsError(c, EUnexpectedMessage, "expected a server hello");
624
		goto Err;
625
	}
626
	if(setVersion(c, m.u.serverHello.version) < 0) {
627
		tlsError(c, EIllegalParameter, "incompatible version %r");
628
		goto Err;
629
	}
630
	memmove(c->srandom, m.u.serverHello.random, RandomSize);
631
	c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
632
	if(c->sid->len != 0 && c->sid->len != SidSize) {
633
		tlsError(c, EIllegalParameter, "invalid server session identifier");
634
		goto Err;
635
	}
636
	if(!setAlgs(c, m.u.serverHello.cipher)) {
637
		tlsError(c, EIllegalParameter, "invalid cipher suite");
638
		goto Err;
639
	}
640
	if(m.u.serverHello.compressor != CompressionNull) {
641
		tlsError(c, EIllegalParameter, "invalid compression");
642
		goto Err;
643
	}
644
	msgClear(&m);
645
 
646
	/* certificate */
647
	if(!msgRecv(c, &m) || m.tag != HCertificate) {
648
		tlsError(c, EUnexpectedMessage, "expected a certificate");
649
		goto Err;
650
	}
651
	if(m.u.certificate.ncert < 1) {
652
		tlsError(c, EIllegalParameter, "runt certificate");
653
		goto Err;
654
	}
655
	c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
656
	msgClear(&m);
657
 
658
	/* server key exchange (optional) */
659
	if(!msgRecv(c, &m))
660
		goto Err;
661
	if(m.tag == HServerKeyExchange) {
662
		tlsError(c, EUnexpectedMessage, "got an server key exchange");
663
		goto Err;
664
		// If implementing this later, watch out for rollback attack
665
		// described in Wagner Schneier 1996, section 4.4.
666
	}
667
 
668
	/* certificate request (optional) */
669
	creq = 0;
670
	if(m.tag == HCertificateRequest) {
671
		creq = 1;
672
		msgClear(&m);
673
		if(!msgRecv(c, &m))
674
			goto Err;
675
	}
676
 
677
	if(m.tag != HServerHelloDone) {
678
		tlsError(c, EUnexpectedMessage, "expected a server hello done");
679
		goto Err;
680
	}
681
	msgClear(&m);
682
 
683
	if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom,
684
			c->cert->data, c->cert->len, c->version, &epm, &nepm,
685
			kd, c->nsecret) < 0){
686
		tlsError(c, EBadCertificate, "invalid x509/rsa certificate");
687
		goto Err;
688
	}
689
	secrets = (char*)emalloc(2*c->nsecret);
690
	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
691
	rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
692
	memset(secrets, 0, 2*c->nsecret);
693
	free(secrets);
694
	memset(kd, 0, c->nsecret);
695
	if(rv < 0){
696
		tlsError(c, EHandshakeFailure, "can't set keys: %r");
697
		goto Err;
698
	}
699
 
700
	if(creq) {
701
		/* send a zero length certificate */
702
		m.tag = HCertificate;
703
		if(!msgSend(c, &m, AFlush))
704
			goto Err;
705
		msgClear(&m);
706
	}
707
 
708
	/* client key exchange */
709
	m.tag = HClientKeyExchange;
710
	m.u.clientKeyExchange.key = makebytes(epm, nepm);
711
	free(epm);
712
	epm = nil;
713
	if(m.u.clientKeyExchange.key == nil) {
714
		tlsError(c, EHandshakeFailure, "can't set secret: %r");
715
		goto Err;
716
	}
717
	if(!msgSend(c, &m, AFlush))
718
		goto Err;
719
	msgClear(&m);
720
 
721
	/* change cipher spec */
722
	if(fprint(c->ctl, "changecipher") < 0){
723
		tlsError(c, EInternalError, "can't enable cipher: %r");
724
		goto Err;
725
	}
726
 
727
	// Cipherchange must occur immediately before Finished to avoid
728
	// potential hole;  see section 4.3 of Wagner Schneier 1996.
729
	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
730
		tlsError(c, EInternalError, "can't set finished 1: %r");
731
		goto Err;
732
	}
733
	m.tag = HFinished;
734
	m.u.finished = c->finished;
735
 
736
	if(!msgSend(c, &m, AFlush)) {
737
		fprint(2, "tlsClient nepm=%d\n", nepm);
738
		tlsError(c, EInternalError, "can't flush after client Finished: %r");
739
		goto Err;
740
	}
741
	msgClear(&m);
742
 
743
	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
744
		fprint(2, "tlsClient nepm=%d\n", nepm);
745
		tlsError(c, EInternalError, "can't set finished 0: %r");
746
		goto Err;
747
	}
748
	if(!msgRecv(c, &m)) {
749
		fprint(2, "tlsClient nepm=%d\n", nepm);
750
		tlsError(c, EInternalError, "can't read server Finished: %r");
751
		goto Err;
752
	}
753
	if(m.tag != HFinished) {
754
		fprint(2, "tlsClient nepm=%d\n", nepm);
755
		tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
756
		goto Err;
757
	}
758
 
759
	if(!finishedMatch(c, &m.u.finished)) {
760
		tlsError(c, EHandshakeFailure, "finished verification failed");
761
		goto Err;
762
	}
763
	msgClear(&m);
764
 
765
	if(fprint(c->ctl, "opened") < 0){
766
		if(trace)
767
			trace("unable to do final open: %r\n");
768
		goto Err;
769
	}
770
	tlsSecOk(c->sec);
771
	return c;
772
 
773
Err:
774
	free(epm);
775
	msgClear(&m);
776
	tlsConnectionFree(c);
777
	return 0;
778
}
779
 
780
 
781
//================= message functions ========================
782
 
783
static uchar sendbuf[9000], *sendp;
784
 
785
static int
786
msgSend(TlsConnection *c, Msg *m, int act)
787
{
788
	uchar *p; // sendp = start of new message;  p = write pointer
789
	int nn, n, i;
790
 
791
	if(sendp == nil)
792
		sendp = sendbuf;
793
	p = sendp;
794
	if(c->trace)
795
		c->trace("send %s", msgPrint((char*)p, (sizeof sendbuf) - (p-sendbuf), m));
796
 
797
	p[0] = m->tag;	// header - fill in size later
798
	p += 4;
799
 
800
	switch(m->tag) {
801
	default:
802
		tlsError(c, EInternalError, "can't encode a %d", m->tag);
803
		goto Err;
804
	case HClientHello:
805
		// version
806
		put16(p, m->u.clientHello.version);
807
		p += 2;
808
 
809
		// random
810
		memmove(p, m->u.clientHello.random, RandomSize);
811
		p += RandomSize;
812
 
813
		// sid
814
		n = m->u.clientHello.sid->len;
815
		assert(n < 256);
816
		p[0] = n;
817
		memmove(p+1, m->u.clientHello.sid->data, n);
818
		p += n+1;
819
 
820
		n = m->u.clientHello.ciphers->len;
821
		assert(n > 0 && n < 200);
822
		put16(p, n*2);
823
		p += 2;
824
		for(i=0; i<n; i++) {
825
			put16(p, m->u.clientHello.ciphers->data[i]);
826
			p += 2;
827
		}
828
 
829
		n = m->u.clientHello.compressors->len;
830
		assert(n > 0);
831
		p[0] = n;
832
		memmove(p+1, m->u.clientHello.compressors->data, n);
833
		p += n+1;
834
		break;
835
	case HServerHello:
836
		put16(p, m->u.serverHello.version);
837
		p += 2;
838
 
839
		// random
840
		memmove(p, m->u.serverHello.random, RandomSize);
841
		p += RandomSize;
842
 
843
		// sid
844
		n = m->u.serverHello.sid->len;
845
		assert(n < 256);
846
		p[0] = n;
847
		memmove(p+1, m->u.serverHello.sid->data, n);
848
		p += n+1;
849
 
850
		put16(p, m->u.serverHello.cipher);
851
		p += 2;
852
		p[0] = m->u.serverHello.compressor;
853
		p += 1;
854
		break;
855
	case HServerHelloDone:
856
		break;
857
	case HCertificate:
858
		nn = 0;
859
		for(i = 0; i < m->u.certificate.ncert; i++)
860
			nn += 3 + m->u.certificate.certs[i]->len;
861
		if(p + 3 + nn - sendbuf > sizeof(sendbuf)) {
862
			tlsError(c, EInternalError, "output buffer too small for certificate");
863
			goto Err;
864
		}
865
		put24(p, nn);
866
		p += 3;
867
		for(i = 0; i < m->u.certificate.ncert; i++){
868
			put24(p, m->u.certificate.certs[i]->len);
869
			p += 3;
870
			memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
871
			p += m->u.certificate.certs[i]->len;
872
		}
873
		break;
874
	case HClientKeyExchange:
875
		n = m->u.clientKeyExchange.key->len;
876
		if(c->version != SSL3Version){
877
			put16(p, n);
878
			p += 2;
879
		}
880
		memmove(p, m->u.clientKeyExchange.key->data, n);
881
		p += n;
882
		break;
883
	case HFinished:
884
		memmove(p, m->u.finished.verify, m->u.finished.n);
885
		p += m->u.finished.n;
886
		break;
887
	}
888
 
889
	// go back and fill in size
890
	n = p-sendp;
891
	assert(p <= sendbuf+sizeof(sendbuf));
892
	put24(sendp+1, n-4);
893
 
894
	// remember hash of Handshake messages
895
	if(m->tag != HHelloRequest) {
896
		md5(sendp, n, 0, &c->hsmd5);
897
		sha1(sendp, n, 0, &c->hssha1);
898
	}
899
 
900
	sendp = p;
901
	if(act == AFlush){
902
		sendp = sendbuf;
903
		if(write(c->hand, sendbuf, p-sendbuf) < 0){
904
			fprint(2, "write error: %r\n");
905
			goto Err;
906
		}
907
	}
908
	msgClear(m);
909
	return 1;
910
Err:
911
	msgClear(m);
912
	return 0;
913
}
914
 
915
static uchar*
916
tlsReadN(TlsConnection *c, int n)
917
{
918
	uchar *p;
919
	int nn, nr;
920
 
921
	nn = c->ep - c->rp;
922
	if(nn < n){
923
		if(c->rp != c->buf){
924
			memmove(c->buf, c->rp, nn);
925
			c->rp = c->buf;
926
			c->ep = &c->buf[nn];
927
		}
928
		for(; nn < n; nn += nr) {
929
			nr = read(c->hand, &c->rp[nn], n - nn);
930
			if(nr <= 0)
931
				return nil;
932
			c->ep += nr;
933
		}
934
	}
935
	p = c->rp;
936
	c->rp += n;
937
	return p;
938
}
939
 
940
static int
941
msgRecv(TlsConnection *c, Msg *m)
942
{
943
	uchar *p;
944
	int type, n, nn, i, nsid, nrandom, nciph;
945
 
946
	for(;;) {
947
		p = tlsReadN(c, 4);
948
		if(p == nil)
949
			return 0;
950
		type = p[0];
951
		n = get24(p+1);
952
 
953
		if(type != HHelloRequest)
954
			break;
955
		if(n != 0) {
956
			tlsError(c, EDecodeError, "invalid hello request during handshake");
957
			return 0;
958
		}
959
	}
960
 
961
	if(n > sizeof(c->buf)) {
962
		tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->buf));
963
		return 0;
964
	}
965
 
966
	if(type == HSSL2ClientHello){
967
		/* Cope with an SSL3 ClientHello expressed in SSL2 record format.
968
			This is sent by some clients that we must interoperate
969
			with, such as Java's JSSE and Microsoft's Internet Explorer. */
970
		p = tlsReadN(c, n);
971
		if(p == nil)
972
			return 0;
973
		md5(p, n, 0, &c->hsmd5);
974
		sha1(p, n, 0, &c->hssha1);
975
		m->tag = HClientHello;
976
		if(n < 22)
977
			goto Short;
978
		m->u.clientHello.version = get16(p+1);
979
		p += 3;
980
		n -= 3;
981
		nn = get16(p); /* cipher_spec_len */
982
		nsid = get16(p + 2);
983
		nrandom = get16(p + 4);
984
		p += 6;
985
		n -= 6;
986
		if(nsid != 0 	/* no sid's, since shouldn't restart using ssl2 header */
987
				|| nrandom < 16 || nn % 3)
988
			goto Err;
989
		if(c->trace && (n - nrandom != nn))
990
			c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
991
		/* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
992
		nciph = 0;
993
		for(i = 0; i < nn; i += 3)
994
			if(p[i] == 0)
995
				nciph++;
996
		m->u.clientHello.ciphers = newints(nciph);
997
		nciph = 0;
998
		for(i = 0; i < nn; i += 3)
999
			if(p[i] == 0)
1000
				m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1001
		p += nn;
1002
		m->u.clientHello.sid = makebytes(nil, 0);
1003
		if(nrandom > RandomSize)
1004
			nrandom = RandomSize;
1005
		memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1006
		memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1007
		m->u.clientHello.compressors = newbytes(1);
1008
		m->u.clientHello.compressors->data[0] = CompressionNull;
1009
		goto Ok;
1010
	}
1011
 
1012
	md5(p, 4, 0, &c->hsmd5);
1013
	sha1(p, 4, 0, &c->hssha1);
1014
 
1015
	p = tlsReadN(c, n);
1016
	if(p == nil)
1017
		return 0;
1018
 
1019
	md5(p, n, 0, &c->hsmd5);
1020
	sha1(p, n, 0, &c->hssha1);
1021
 
1022
	m->tag = type;
1023
 
1024
	switch(type) {
1025
	default:
1026
		tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1027
		goto Err;
1028
	case HClientHello:
1029
		if(n < 2)
1030
			goto Short;
1031
		m->u.clientHello.version = get16(p);
1032
		p += 2;
1033
		n -= 2;
1034
 
1035
		if(n < RandomSize)
1036
			goto Short;
1037
		memmove(m->u.clientHello.random, p, RandomSize);
1038
		p += RandomSize;
1039
		n -= RandomSize;
1040
		if(n < 1 || n < p[0]+1)
1041
			goto Short;
1042
		m->u.clientHello.sid = makebytes(p+1, p[0]);
1043
		p += m->u.clientHello.sid->len+1;
1044
		n -= m->u.clientHello.sid->len+1;
1045
 
1046
		if(n < 2)
1047
			goto Short;
1048
		nn = get16(p);
1049
		p += 2;
1050
		n -= 2;
1051
 
1052
		if((nn & 1) || n < nn || nn < 2)
1053
			goto Short;
1054
		m->u.clientHello.ciphers = newints(nn >> 1);
1055
		for(i = 0; i < nn; i += 2)
1056
			m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1057
		p += nn;
1058
		n -= nn;
1059
 
1060
		if(n < 1 || n < p[0]+1 || p[0] == 0)
1061
			goto Short;
1062
		nn = p[0];
1063
		m->u.clientHello.compressors = newbytes(nn);
1064
		memmove(m->u.clientHello.compressors->data, p+1, nn);
1065
		n -= nn + 1;
1066
		break;
1067
	case HServerHello:
1068
		if(n < 2)
1069
			goto Short;
1070
		m->u.serverHello.version = get16(p);
1071
		p += 2;
1072
		n -= 2;
1073
 
1074
		if(n < RandomSize)
1075
			goto Short;
1076
		memmove(m->u.serverHello.random, p, RandomSize);
1077
		p += RandomSize;
1078
		n -= RandomSize;
1079
 
1080
		if(n < 1 || n < p[0]+1)
1081
			goto Short;
1082
		m->u.serverHello.sid = makebytes(p+1, p[0]);
1083
		p += m->u.serverHello.sid->len+1;
1084
		n -= m->u.serverHello.sid->len+1;
1085
 
1086
		if(n < 3)
1087
			goto Short;
1088
		m->u.serverHello.cipher = get16(p);
1089
		m->u.serverHello.compressor = p[2];
1090
		n -= 3;
1091
		break;
1092
	case HCertificate:
1093
		if(n < 3)
1094
			goto Short;
1095
		nn = get24(p);
1096
		p += 3;
1097
		n -= 3;
1098
		if(n != nn)
1099
			goto Short;
1100
		/* certs */
1101
		i = 0;
1102
		while(n > 0) {
1103
			if(n < 3)
1104
				goto Short;
1105
			nn = get24(p);
1106
			p += 3;
1107
			n -= 3;
1108
			if(nn > n)
1109
				goto Short;
1110
			m->u.certificate.ncert = i+1;
1111
			m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes));
1112
			m->u.certificate.certs[i] = makebytes(p, nn);
1113
			p += nn;
1114
			n -= nn;
1115
			i++;
1116
		}
1117
		break;
1118
	case HCertificateRequest:
1119
		if(n < 2)
1120
			goto Short;
1121
		nn = get16(p);
1122
		p += 2;
1123
		n -= 2;
1124
		if(nn < 1 || nn > n)
1125
			goto Short;
1126
		m->u.certificateRequest.types = makebytes(p, nn);
1127
		nn = get24(p);
1128
		p += 3;
1129
		n -= 3;
1130
		if(nn == 0 || n != nn)
1131
			goto Short;
1132
		/* cas */
1133
		i = 0;
1134
		while(n > 0) {
1135
			if(n < 2)
1136
				goto Short;
1137
			nn = get16(p);
1138
			p += 2;
1139
			n -= 2;
1140
			if(nn < 1 || nn > n)
1141
				goto Short;
1142
			m->u.certificateRequest.nca = i+1;
1143
			m->u.certificateRequest.cas = erealloc(m->u.certificateRequest.cas, (i+1)*sizeof(Bytes));
1144
			m->u.certificateRequest.cas[i] = makebytes(p, nn);
1145
			p += nn;
1146
			n -= nn;
1147
			i++;
1148
		}
1149
		break;
1150
	case HServerHelloDone:
1151
		break;
1152
	case HClientKeyExchange:
1153
		/*
1154
		 * this message depends upon the encryption selected
1155
		 * assume rsa.
1156
		 */
1157
		if(c->version == SSL3Version)
1158
			nn = n;
1159
		else{
1160
			if(n < 2)
1161
				goto Short;
1162
			nn = get16(p);
1163
			p += 2;
1164
			n -= 2;
1165
		}
1166
		if(n < nn)
1167
			goto Short;
1168
		m->u.clientKeyExchange.key = makebytes(p, nn);
1169
		n -= nn;
1170
		break;
1171
	case HFinished:
1172
		m->u.finished.n = c->finished.n;
1173
		if(n < m->u.finished.n)
1174
			goto Short;
1175
		memmove(m->u.finished.verify, p, m->u.finished.n);
1176
		n -= m->u.finished.n;
1177
		break;
1178
	}
1179
 
1180
	if(type != HClientHello && n != 0)
1181
		goto Short;
1182
Ok:
1183
	if(c->trace){
1184
		char buf[8000];
1185
		c->trace("recv %s", msgPrint(buf, sizeof buf, m));
1186
	}
1187
	return 1;
1188
Short:
1189
	tlsError(c, EDecodeError, "handshake message has invalid length");
1190
Err:
1191
	msgClear(m);
1192
	return 0;
1193
}
1194
 
1195
static void
1196
msgClear(Msg *m)
1197
{
1198
	int i;
1199
 
1200
	switch(m->tag) {
1201
	default:
1202
		sysfatal("msgClear: unknown message type: %d", m->tag);
1203
	case HHelloRequest:
1204
		break;
1205
	case HClientHello:
1206
		freebytes(m->u.clientHello.sid);
1207
		freeints(m->u.clientHello.ciphers);
1208
		freebytes(m->u.clientHello.compressors);
1209
		break;
1210
	case HServerHello:
1211
		freebytes(m->u.clientHello.sid);
1212
		break;
1213
	case HCertificate:
1214
		for(i=0; i<m->u.certificate.ncert; i++)
1215
			freebytes(m->u.certificate.certs[i]);
1216
		free(m->u.certificate.certs);
1217
		break;
1218
	case HCertificateRequest:
1219
		freebytes(m->u.certificateRequest.types);
1220
		for(i=0; i<m->u.certificateRequest.nca; i++)
1221
			freebytes(m->u.certificateRequest.cas[i]);
1222
		free(m->u.certificateRequest.cas);
1223
		break;
1224
	case HServerHelloDone:
1225
		break;
1226
	case HClientKeyExchange:
1227
		freebytes(m->u.clientKeyExchange.key);
1228
		break;
1229
	case HFinished:
1230
		break;
1231
	}
1232
	memset(m, 0, sizeof(Msg));
1233
}
1234
 
1235
static char *
1236
bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1237
{
1238
	int i;
1239
 
1240
	if(s0)
1241
		bs = seprint(bs, be, "%s", s0);
1242
	bs = seprint(bs, be, "[");
1243
	if(b == nil)
1244
		bs = seprint(bs, be, "nil");
1245
	else
1246
		for(i=0; i<b->len; i++)
1247
			bs = seprint(bs, be, "%.2x ", b->data[i]);
1248
	bs = seprint(bs, be, "]");
1249
	if(s1)
1250
		bs = seprint(bs, be, "%s", s1);
1251
	return bs;
1252
}
1253
 
1254
static char *
1255
intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1256
{
1257
	int i;
1258
 
1259
	if(s0)
1260
		bs = seprint(bs, be, "%s", s0);
1261
	bs = seprint(bs, be, "[");
1262
	if(b == nil)
1263
		bs = seprint(bs, be, "nil");
1264
	else
1265
		for(i=0; i<b->len; i++)
1266
			bs = seprint(bs, be, "%x ", b->data[i]);
1267
	bs = seprint(bs, be, "]");
1268
	if(s1)
1269
		bs = seprint(bs, be, "%s", s1);
1270
	return bs;
1271
}
1272
 
1273
static char*
1274
msgPrint(char *buf, int n, Msg *m)
1275
{
1276
	int i;
1277
	char *bs = buf, *be = buf+n;
1278
 
1279
	switch(m->tag) {
1280
	default:
1281
		bs = seprint(bs, be, "unknown %d\n", m->tag);
1282
		break;
1283
	case HClientHello:
1284
		bs = seprint(bs, be, "ClientHello\n");
1285
		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1286
		bs = seprint(bs, be, "\trandom: ");
1287
		for(i=0; i<RandomSize; i++)
1288
			bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1289
		bs = seprint(bs, be, "\n");
1290
		bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1291
		bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1292
		bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1293
		break;
1294
	case HServerHello:
1295
		bs = seprint(bs, be, "ServerHello\n");
1296
		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1297
		bs = seprint(bs, be, "\trandom: ");
1298
		for(i=0; i<RandomSize; i++)
1299
			bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1300
		bs = seprint(bs, be, "\n");
1301
		bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1302
		bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1303
		bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1304
		break;
1305
	case HCertificate:
1306
		bs = seprint(bs, be, "Certificate\n");
1307
		for(i=0; i<m->u.certificate.ncert; i++)
1308
			bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1309
		break;
1310
	case HCertificateRequest:
1311
		bs = seprint(bs, be, "CertificateRequest\n");
1312
		bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1313
		bs = seprint(bs, be, "\tcertificateauthorities\n");
1314
		for(i=0; i<m->u.certificateRequest.nca; i++)
1315
			bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1316
		break;
1317
	case HServerHelloDone:
1318
		bs = seprint(bs, be, "ServerHelloDone\n");
1319
		break;
1320
	case HClientKeyExchange:
1321
		bs = seprint(bs, be, "HClientKeyExchange\n");
1322
		bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1323
		break;
1324
	case HFinished:
1325
		bs = seprint(bs, be, "HFinished\n");
1326
		for(i=0; i<m->u.finished.n; i++)
1327
			bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1328
		bs = seprint(bs, be, "\n");
1329
		break;
1330
	}
1331
	USED(bs);
1332
	return buf;
1333
}
1334
 
1335
static void
1336
tlsError(TlsConnection *c, int err, char *fmt, ...)
1337
{
1338
	char msg[512];
1339
	va_list arg;
1340
 
1341
	va_start(arg, fmt);
1342
	vseprint(msg, msg+sizeof(msg), fmt, arg);
1343
	va_end(arg);
1344
	if(c->trace)
1345
		c->trace("tlsError: %s\n", msg);
1346
	else if(c->erred)
1347
		fprint(2, "double error: %r, %s", msg);
1348
	else
1349
		werrstr("tls: local %s", msg);
1350
	c->erred = 1;
1351
	fprint(c->ctl, "alert %d", err);
1352
}
1353
 
1354
// commit to specific version number
1355
static int
1356
setVersion(TlsConnection *c, int version)
1357
{
1358
	if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1359
		return -1;
1360
	if(version > c->version)
1361
		version = c->version;
1362
	if(version == SSL3Version) {
1363
		c->version = version;
1364
		c->finished.n = SSL3FinishedLen;
1365
	}else if(version == TLSVersion){
1366
		c->version = version;
1367
		c->finished.n = TLSFinishedLen;
1368
	}else
1369
		return -1;
1370
	c->verset = 1;
1371
	return fprint(c->ctl, "version 0x%x", version);
1372
}
1373
 
1374
// confirm that received Finished message matches the expected value
1375
static int
1376
finishedMatch(TlsConnection *c, Finished *f)
1377
{
1378
	return memcmp(f->verify, c->finished.verify, f->n) == 0;
1379
}
1380
 
1381
// free memory associated with TlsConnection struct
1382
//		(but don't close the TLS channel itself)
1383
static void
1384
tlsConnectionFree(TlsConnection *c)
1385
{
1386
	tlsSecClose(c->sec);
1387
	freebytes(c->sid);
1388
	freebytes(c->cert);
1389
	memset(c, 0, sizeof(c));
1390
	free(c);
1391
}
1392
 
1393
 
1394
//================= cipher choices ========================
1395
 
1396
// weakCipher[id] is 1 for the cipher suites, by TLS id 0x00-0x1b, that
// use no encryption, export-grade (40-bit) keys or anonymous
// Diffie-Hellman; 0 otherwise.  entries beyond those listed are zero.
// okCipher consults it and treats ids at or above CipherMax as not weak.
static int weakCipher[CipherMax] =
{
	/* 0x00 NULL_WITH_NULL_NULL, RSA_WITH_NULL_MD5, RSA_WITH_NULL_SHA, RSA_EXPORT_WITH_RC4_40_MD5 */
	1, 1, 1, 1,
	/* 0x04 RSA_WITH_RC4_128_MD5, RSA_WITH_RC4_128_SHA, RSA_EXPORT_WITH_RC2_CBC_40_MD5, RSA_WITH_IDEA_CBC_SHA */
	0, 0, 1, 0,
	/* 0x08 RSA_EXPORT_WITH_DES40_CBC_SHA, RSA_WITH_DES_CBC_SHA, RSA_WITH_3DES_EDE_CBC_SHA */
	1, 0, 0,
	/* 0x0b DH_DSS_EXPORT_WITH_DES40_CBC_SHA, DH_DSS_WITH_DES_CBC_SHA, DH_DSS_WITH_3DES_EDE_CBC_SHA */
	1, 0, 0,
	/* 0x0e DH_RSA_EXPORT_WITH_DES40_CBC_SHA, DH_RSA_WITH_DES_CBC_SHA, DH_RSA_WITH_3DES_EDE_CBC_SHA */
	1, 0, 0,
	/* 0x11 DHE_DSS_EXPORT_WITH_DES40_CBC_SHA, DHE_DSS_WITH_DES_CBC_SHA, DHE_DSS_WITH_3DES_EDE_CBC_SHA */
	1, 0, 0,
	/* 0x14 DHE_RSA_EXPORT_WITH_DES40_CBC_SHA, DHE_RSA_WITH_DES_CBC_SHA, DHE_RSA_WITH_3DES_EDE_CBC_SHA */
	1, 0, 0,
	/* 0x17 DH_anon_EXPORT_WITH_RC4_40_MD5, DH_anon_WITH_RC4_128_MD5, DH_anon_EXPORT_WITH_DES40_CBC_SHA,
	 *      DH_anon_WITH_DES_CBC_SHA, DH_anon_WITH_3DES_EDE_CBC_SHA */
	1, 1, 1, 1, 1,
};
1427
 
1428
static int
1429
setAlgs(TlsConnection *c, int a)
1430
{
1431
	int i;
1432
 
1433
	for(i = 0; i < nelem(cipherAlgs); i++){
1434
		if(cipherAlgs[i].tlsid == a){
1435
			c->enc = cipherAlgs[i].enc;
1436
			c->digest = cipherAlgs[i].digest;
1437
			c->nsecret = cipherAlgs[i].nsecret;
1438
			if(c->nsecret > MaxKeyData)
1439
				return 0;
1440
			return 1;
1441
		}
1442
	}
1443
	return 0;
1444
}
1445
 
1446
static int
1447
okCipher(Ints *cv)
1448
{
1449
	int weak, i, j, c;
1450
 
1451
	weak = 1;
1452
	for(i = 0; i < cv->len; i++) {
1453
		c = cv->data[i];
1454
		if(c >= CipherMax)
1455
			weak = 0;
1456
		else
1457
			weak &= weakCipher[c];
1458
		for(j = 0; j < nelem(cipherAlgs); j++)
1459
			if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
1460
				return c;
1461
	}
1462
	if(weak)
1463
		return -2;
1464
	return -1;
1465
}
1466
 
1467
static int
1468
okCompression(Bytes *cv)
1469
{
1470
	int i, j, c;
1471
 
1472
	for(i = 0; i < cv->len; i++) {
1473
		c = cv->data[i];
1474
		for(j = 0; j < nelem(compressors); j++) {
1475
			if(compressors[j] == c)
1476
				return c;
1477
		}
1478
	}
1479
	return -1;
1480
}
1481
 
1482
// count of usable cipherAlgs entries, filled in once by initCiphers;
// both it and cipherAlgs[].ok are set while holding ciphLock.
static int	nciphers;
static Lock	ciphLock;
1484
 
1485
static int
1486
initCiphers(void)
1487
{
1488
	enum {MaxAlgF = 1024, MaxAlgs = 10};
1489
	char s[MaxAlgF], *flds[MaxAlgs];
1490
	int i, j, n, ok;
1491
 
1492
	lock(&ciphLock);
1493
	if(nciphers){
1494
		unlock(&ciphLock);
1495
		return nciphers;
1496
	}
1497
	j = open("#a/tls/encalgs", OREAD);
1498
	if(j < 0){
1499
		werrstr("can't open #a/tls/encalgs: %r");
1500
		return 0;
1501
	}
1502
	n = read(j, s, MaxAlgF-1);
1503
	close(j);
1504
	if(n <= 0){
1505
		werrstr("nothing in #a/tls/encalgs: %r");
1506
		return 0;
1507
	}
1508
	s[n] = 0;
1509
	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1510
	for(i = 0; i < nelem(cipherAlgs); i++){
1511
		ok = 0;
1512
		for(j = 0; j < n; j++){
1513
			if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
1514
				ok = 1;
1515
				break;
1516
			}
1517
		}
1518
		cipherAlgs[i].ok = ok;
1519
	}
1520
 
1521
	j = open("#a/tls/hashalgs", OREAD);
1522
	if(j < 0){
1523
		werrstr("can't open #a/tls/hashalgs: %r");
1524
		return 0;
1525
	}
1526
	n = read(j, s, MaxAlgF-1);
1527
	close(j);
1528
	if(n <= 0){
1529
		werrstr("nothing in #a/tls/hashalgs: %r");
1530
		return 0;
1531
	}
1532
	s[n] = 0;
1533
	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1534
	for(i = 0; i < nelem(cipherAlgs); i++){
1535
		ok = 0;
1536
		for(j = 0; j < n; j++){
1537
			if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
1538
				ok = 1;
1539
				break;
1540
			}
1541
		}
1542
		cipherAlgs[i].ok &= ok;
1543
		if(cipherAlgs[i].ok)
1544
			nciphers++;
1545
	}
1546
	unlock(&ciphLock);
1547
	return nciphers;
1548
}
1549
 
1550
static Ints*
1551
makeciphers(void)
1552
{
1553
	Ints *is;
1554
	int i, j;
1555
 
1556
	is = newints(nciphers);
1557
	j = 0;
1558
	for(i = 0; i < nelem(cipherAlgs); i++){
1559
		if(cipherAlgs[i].ok)
1560
			is->data[j++] = cipherAlgs[i].tlsid;
1561
	}
1562
	return is;
1563
}
1564
 
1565
 
1566
 
1567
//================= security functions ========================
1568
 
1569
// given X.509 certificate, set up connection to factotum
1570
//	for using corresponding private key
1571
static AuthRpc*
1572
factotum_rsa_open(uchar *cert, int certlen)
1573
{
1574
	int afd;
1575
	char *s;
1576
	mpint *pub = nil;
1577
	RSApub *rsapub;
1578
	AuthRpc *rpc;
1579
 
1580
	// start talking to factotum
1581
	if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
1582
		return nil;
1583
	if((rpc = auth_allocrpc(afd)) == nil){
1584
		close(afd);
1585
		return nil;
1586
	}
1587
	s = "proto=rsa service=tls role=client";
1588
	if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
1589
		factotum_rsa_close(rpc);
1590
		return nil;
1591
	}
1592
 
1593
	// roll factotum keyring around to match certificate
1594
	rsapub = X509toRSApub(cert, certlen, nil, 0);
1595
	while(1){
1596
		if(auth_rpc(rpc, "read", nil, 0) != ARok){
1597
			factotum_rsa_close(rpc);
1598
			rpc = nil;
1599
			goto done;
1600
		}
1601
		pub = strtomp(rpc->arg, nil, 16, nil);
1602
		assert(pub != nil);
1603
		if(mpcmp(pub,rsapub->n) == 0)
1604
			break;
1605
	}
1606
done:
1607
	mpfree(pub);
1608
	rsapubfree(rsapub);
1609
	return rpc;
1610
}
1611
 
1612
static mpint*
1613
factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
1614
{
1615
	char *p;
1616
	int rv;
1617
 
1618
	if((p = mptoa(cipher, 16, nil, 0)) == nil)
1619
		return nil;
1620
	rv = auth_rpc(rpc, "write", p, strlen(p));
1621
	free(p);
1622
	if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
1623
		return nil;
1624
	mpfree(cipher);
1625
	return strtomp(rpc->arg, nil, 16, nil);
1626
}
1627
 
1628
static void
1629
factotum_rsa_close(AuthRpc*rpc)
1630
{
1631
	if(!rpc)
1632
		return;
1633
	close(rpc->afd);
1634
	auth_freerpc(rpc);
1635
}
1636
 
1637
static void
1638
tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1639
{
1640
	uchar ai[MD5dlen], tmp[MD5dlen];
1641
	int i, n;
1642
	MD5state *s;
1643
 
1644
	// generate a1
1645
	s = hmac_md5(label, nlabel, key, nkey, nil, nil);
1646
	s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1647
	hmac_md5(seed1, nseed1, key, nkey, ai, s);
1648
 
1649
	while(nbuf > 0) {
1650
		s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
1651
		s = hmac_md5(label, nlabel, key, nkey, nil, s);
1652
		s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1653
		hmac_md5(seed1, nseed1, key, nkey, tmp, s);
1654
		n = MD5dlen;
1655
		if(n > nbuf)
1656
			n = nbuf;
1657
		for(i = 0; i < n; i++)
1658
			buf[i] ^= tmp[i];
1659
		buf += n;
1660
		nbuf -= n;
1661
		hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
1662
		memmove(ai, tmp, MD5dlen);
1663
	}
1664
}
1665
 
1666
static void
1667
tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1668
{
1669
	uchar ai[SHA1dlen], tmp[SHA1dlen];
1670
	int i, n;
1671
	SHAstate *s;
1672
 
1673
	// generate a1
1674
	s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
1675
	s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1676
	hmac_sha1(seed1, nseed1, key, nkey, ai, s);
1677
 
1678
	while(nbuf > 0) {
1679
		s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
1680
		s = hmac_sha1(label, nlabel, key, nkey, nil, s);
1681
		s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1682
		hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
1683
		n = SHA1dlen;
1684
		if(n > nbuf)
1685
			n = nbuf;
1686
		for(i = 0; i < n; i++)
1687
			buf[i] ^= tmp[i];
1688
		buf += n;
1689
		nbuf -= n;
1690
		hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
1691
		memmove(ai, tmp, SHA1dlen);
1692
	}
1693
}
1694
 
1695
// fill buf with md5(args)^sha1(args)
1696
static void
1697
tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1698
{
1699
	int i;
1700
	int nlabel = strlen(label);
1701
	int n = (nkey + 1) >> 1;
1702
 
1703
	for(i = 0; i < nbuf; i++)
1704
		buf[i] = 0;
1705
	tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1706
	tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1707
}
1708
 
1709
/*
1710
 * for setting server session id's
1711
 */
1712
static Lock	sidLock;
1713
static long	maxSid = 1;
1714
 
1715
/* the keys are verified to have the same public components
1716
 * and to function correctly with pkcs 1 encryption and decryption. */
1717
static TlsSec*
1718
tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
1719
{
1720
	TlsSec *sec = emalloc(sizeof(*sec));
1721
 
1722
	USED(csid); USED(ncsid);  // ignore csid for now
1723
 
1724
	memmove(sec->crandom, crandom, RandomSize);
1725
	sec->clientVers = cvers;
1726
 
1727
	put32(sec->srandom, time(0));
1728
	genrandom(sec->srandom+4, RandomSize-4);
1729
	memmove(srandom, sec->srandom, RandomSize);
1730
 
1731
	/*
1732
	 * make up a unique sid: use our pid, and and incrementing id
1733
	 * can signal no sid by setting nssid to 0.
1734
	 */
1735
	memset(ssid, 0, SidSize);
1736
	put32(ssid, getpid());
1737
	lock(&sidLock);
1738
	put32(ssid+4, maxSid++);
1739
	unlock(&sidLock);
1740
	*nssid = SidSize;
1741
	return sec;
1742
}
1743
 
1744
static int
1745
tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd)
1746
{
1747
	if(epm != nil){
1748
		if(setVers(sec, vers) < 0)
1749
			goto Err;
1750
		serverMasterSecret(sec, epm, nepm);
1751
	}else if(sec->vers != vers){
1752
		werrstr("mismatched session versions");
1753
		goto Err;
1754
	}
1755
	setSecrets(sec, kd, nkd);
1756
	return 0;
1757
Err:
1758
	sec->ok = -1;
1759
	return -1;
1760
}
1761
 
1762
static TlsSec*
1763
tlsSecInitc(int cvers, uchar *crandom)
1764
{
1765
	TlsSec *sec = emalloc(sizeof(*sec));
1766
	sec->clientVers = cvers;
1767
	put32(sec->crandom, time(0));
1768
	genrandom(sec->crandom+4, RandomSize-4);
1769
	memmove(crandom, sec->crandom, RandomSize);
1770
	return sec;
1771
}
1772
 
1773
static int
1774
tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd)
1775
{
1776
	RSApub *pub;
1777
 
1778
	pub = nil;
1779
 
1780
	USED(sid);
1781
	USED(nsid);
1782
 
1783
	memmove(sec->srandom, srandom, RandomSize);
1784
 
1785
	if(setVers(sec, vers) < 0)
1786
		goto Err;
1787
 
1788
	pub = X509toRSApub(cert, ncert, nil, 0);
1789
	if(pub == nil){
1790
		werrstr("invalid x509/rsa certificate");
1791
		goto Err;
1792
	}
1793
	if(clientMasterSecret(sec, pub, epm, nepm) < 0)
1794
		goto Err;
1795
	rsapubfree(pub);
1796
	setSecrets(sec, kd, nkd);
1797
	return 0;
1798
 
1799
Err:
1800
	if(pub != nil)
1801
		rsapubfree(pub);
1802
	sec->ok = -1;
1803
	return -1;
1804
}
1805
 
1806
static int
1807
tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
1808
{
1809
	if(sec->nfin != nfin){
1810
		sec->ok = -1;
1811
		werrstr("invalid finished exchange");
1812
		return -1;
1813
	}
1814
	md5.malloced = 0;
1815
	sha1.malloced = 0;
1816
	(*sec->setFinished)(sec, md5, sha1, fin, isclient);
1817
	return 1;
1818
}
1819
 
1820
static void
1821
tlsSecOk(TlsSec *sec)
1822
{
1823
	if(sec->ok == 0)
1824
		sec->ok = 1;
1825
}
1826
 
1827
static void
1828
tlsSecKill(TlsSec *sec)
1829
{
1830
	if(!sec)
1831
		return;
1832
	factotum_rsa_close(sec->rpc);
1833
	sec->ok = -1;
1834
}
1835
 
1836
static void
1837
tlsSecClose(TlsSec *sec)
1838
{
1839
	if(!sec)
1840
		return;
1841
	factotum_rsa_close(sec->rpc);
1842
	free(sec->server);
1843
	free(sec);
1844
}
1845
 
1846
static int
1847
setVers(TlsSec *sec, int v)
1848
{
1849
	if(v == SSL3Version){
1850
		sec->setFinished = sslSetFinished;
1851
		sec->nfin = SSL3FinishedLen;
1852
		sec->prf = sslPRF;
1853
	}else if(v == TLSVersion){
1854
		sec->setFinished = tlsSetFinished;
1855
		sec->nfin = TLSFinishedLen;
1856
		sec->prf = tlsPRF;
1857
	}else{
1858
		werrstr("invalid version");
1859
		return -1;
1860
	}
1861
	sec->vers = v;
1862
	return 0;
1863
}
1864
 
1865
/*
1866
 * generate secret keys from the master secret.
1867
 *
1868
 * different crypto selections will require different amounts
1869
 * of key expansion and use of key expansion data,
1870
 * but it's all generated using the same function.
1871
 */
1872
static void
1873
setSecrets(TlsSec *sec, uchar *kd, int nkd)
1874
{
1875
	(*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
1876
			sec->srandom, RandomSize, sec->crandom, RandomSize);
1877
}
1878
 
1879
/*
1880
 * set the master secret from the pre-master secret.
1881
 */
1882
static void
1883
setMasterSecret(TlsSec *sec, Bytes *pm)
1884
{
1885
	(*sec->prf)(sec->sec, MasterSecretSize, pm->data, MasterSecretSize, "master secret",
1886
			sec->crandom, RandomSize, sec->srandom, RandomSize);
1887
}
1888
 
1889
static void
1890
serverMasterSecret(TlsSec *sec, uchar *epm, int nepm)
1891
{
1892
	Bytes *pm;
1893
 
1894
	pm = pkcs1_decrypt(sec, epm, nepm);
1895
 
1896
	// if the client messed up, just continue as if everything is ok,
1897
	// to prevent attacks to check for correctly formatted messages.
1898
	// Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
1899
	if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
1900
		fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
1901
			sec->ok, pm, pm ? get16(pm->data) : -1, sec->clientVers, nepm);
1902
		sec->ok = -1;
1903
		if(pm != nil)
1904
			freebytes(pm);
1905
		pm = newbytes(MasterSecretSize);
1906
		genrandom(pm->data, MasterSecretSize);
1907
	}
1908
	setMasterSecret(sec, pm);
1909
	memset(pm->data, 0, pm->len);	
1910
	freebytes(pm);
1911
}
1912
 
1913
static int
1914
clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
1915
{
1916
	Bytes *pm, *key;
1917
 
1918
	pm = newbytes(MasterSecretSize);
1919
	put16(pm->data, sec->clientVers);
1920
	genrandom(pm->data+2, MasterSecretSize - 2);
1921
 
1922
	setMasterSecret(sec, pm);
1923
 
1924
	key = pkcs1_encrypt(pm, pub, 2);
1925
	memset(pm->data, 0, pm->len);
1926
	freebytes(pm);
1927
	if(key == nil){
1928
		werrstr("tls pkcs1_encrypt failed");
1929
		return -1;
1930
	}
1931
 
1932
	*nepm = key->len;
1933
	*epm = malloc(*nepm);
1934
	if(*epm == nil){
1935
		freebytes(key);
1936
		werrstr("out of memory");
1937
		return -1;
1938
	}
1939
	memmove(*epm, key->data, *nepm);
1940
 
1941
	freebytes(key);
1942
 
1943
	return 1;
1944
}
1945
 
1946
static void
1947
sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
1948
{
1949
	DigestState *s;
1950
	uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
1951
	char *label;
1952
 
1953
	if(isClient)
1954
		label = "CLNT";
1955
	else
1956
		label = "SRVR";
1957
 
1958
	md5((uchar*)label, 4, nil, &hsmd5);
1959
	md5(sec->sec, MasterSecretSize, nil, &hsmd5);
1960
	memset(pad, 0x36, 48);
1961
	md5(pad, 48, nil, &hsmd5);
1962
	md5(nil, 0, h0, &hsmd5);
1963
	memset(pad, 0x5C, 48);
1964
	s = md5(sec->sec, MasterSecretSize, nil, nil);
1965
	s = md5(pad, 48, nil, s);
1966
	md5(h0, MD5dlen, finished, s);
1967
 
1968
	sha1((uchar*)label, 4, nil, &hssha1);
1969
	sha1(sec->sec, MasterSecretSize, nil, &hssha1);
1970
	memset(pad, 0x36, 40);
1971
	sha1(pad, 40, nil, &hssha1);
1972
	sha1(nil, 0, h1, &hssha1);
1973
	memset(pad, 0x5C, 40);
1974
	s = sha1(sec->sec, MasterSecretSize, nil, nil);
1975
	s = sha1(pad, 40, nil, s);
1976
	sha1(h1, SHA1dlen, finished + MD5dlen, s);
1977
}
1978
 
1979
// fill "finished" arg with md5(args)^sha1(args)
1980
static void
1981
tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
1982
{
1983
	uchar h0[MD5dlen], h1[SHA1dlen];
1984
	char *label;
1985
 
1986
	// get current hash value, but allow further messages to be hashed in
1987
	md5(nil, 0, h0, &hsmd5);
1988
	sha1(nil, 0, h1, &hssha1);
1989
 
1990
	if(isClient)
1991
		label = "client finished";
1992
	else
1993
		label = "server finished";
1994
	tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
1995
}
1996
 
1997
static void
1998
sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1999
{
2000
	DigestState *s;
2001
	uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
2002
	int i, n, len;
2003
 
2004
	USED(label);
2005
	len = 1;
2006
	while(nbuf > 0){
2007
		if(len > 26)
2008
			return;
2009
		for(i = 0; i < len; i++)
2010
			tmp[i] = 'A' - 1 + len;
2011
		s = sha1(tmp, len, nil, nil);
2012
		s = sha1(key, nkey, nil, s);
2013
		s = sha1(seed0, nseed0, nil, s);
2014
		sha1(seed1, nseed1, sha1dig, s);
2015
		s = md5(key, nkey, nil, nil);
2016
		md5(sha1dig, SHA1dlen, md5dig, s);
2017
		n = MD5dlen;
2018
		if(n > nbuf)
2019
			n = nbuf;
2020
		memmove(buf, md5dig, n);
2021
		buf += n;
2022
		nbuf -= n;
2023
		len++;
2024
	}
2025
}
2026
 
2027
static mpint*
2028
bytestomp(Bytes* bytes)
2029
{
2030
	mpint* ans;
2031
 
2032
	ans = betomp(bytes->data, bytes->len, nil);
2033
	return ans;
2034
}
2035
 
2036
/*
2037
 * Convert mpint* to Bytes, putting high order byte first.
2038
 */
2039
static Bytes*
2040
mptobytes(mpint* big)
2041
{
2042
	int n, m;
2043
	uchar *a;
2044
	Bytes* ans;
2045
 
2046
	n = (mpsignif(big)+7)/8;
2047
	m = mptobe(big, nil, n, &a);
2048
	ans = makebytes(a, m);
2049
	return ans;
2050
}
2051
 
2052
// Do RSA computation on block according to key, and pad
2053
// result on left with zeros to make it modlen long.
2054
static Bytes*
2055
rsacomp(Bytes* block, RSApub* key, int modlen)
2056
{
2057
	mpint *x, *y;
2058
	Bytes *a, *ybytes;
2059
	int ylen;
2060
 
2061
	x = bytestomp(block);
2062
	y = rsaencrypt(key, x, nil);
2063
	mpfree(x);
2064
	ybytes = mptobytes(y);
2065
	ylen = ybytes->len;
2066
 
2067
	if(ylen < modlen) {
2068
		a = newbytes(modlen);
2069
		memset(a->data, 0, modlen-ylen);
2070
		memmove(a->data+modlen-ylen, ybytes->data, ylen);
2071
		freebytes(ybytes);
2072
		ybytes = a;
2073
	}
2074
	else if(ylen > modlen) {
2075
		// assume it has leading zeros (mod should make it so)
2076
		a = newbytes(modlen);
2077
		memmove(a->data, ybytes->data, modlen);
2078
		freebytes(ybytes);
2079
		ybytes = a;
2080
	}
2081
	mpfree(y);
2082
	return ybytes;
2083
}
2084
 
2085
// encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2086
static Bytes*
2087
pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2088
{
2089
	Bytes *pad, *eb, *ans;
2090
	int i, dlen, padlen, modlen;
2091
 
2092
	modlen = (mpsignif(key->n)+7)/8;
2093
	dlen = data->len;
2094
	if(modlen < 12 || dlen > modlen - 11)
2095
		return nil;
2096
	padlen = modlen - 3 - dlen;
2097
	pad = newbytes(padlen);
2098
	genrandom(pad->data, padlen);
2099
	for(i = 0; i < padlen; i++) {
2100
		if(blocktype == 0)
2101
			pad->data[i] = 0;
2102
		else if(blocktype == 1)
2103
			pad->data[i] = 255;
2104
		else if(pad->data[i] == 0)
2105
			pad->data[i] = 1;
2106
	}
2107
	eb = newbytes(modlen);
2108
	eb->data[0] = 0;
2109
	eb->data[1] = blocktype;
2110
	memmove(eb->data+2, pad->data, padlen);
2111
	eb->data[padlen+2] = 0;
2112
	memmove(eb->data+padlen+3, data->data, dlen);
2113
	ans = rsacomp(eb, key, modlen);
2114
	freebytes(eb);
2115
	freebytes(pad);
2116
	return ans;
2117
}
2118
 
2119
// decrypt data according to PKCS#1, with given key.
2120
// expect a block type of 2.
2121
static Bytes*
2122
pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm)
2123
{
2124
	Bytes *eb, *ans = nil;
2125
	int i, modlen;
2126
	mpint *x, *y;
2127
 
2128
	modlen = (mpsignif(sec->rsapub->n)+7)/8;
2129
	if(nepm != modlen)
2130
		return nil;
2131
	x = betomp(epm, nepm, nil);
2132
	y = factotum_rsa_decrypt(sec->rpc, x);
2133
	if(y == nil)
2134
		return nil;
2135
	eb = mptobytes(y);
2136
	if(eb->len < modlen){ // pad on left with zeros
2137
		ans = newbytes(modlen);
2138
		memset(ans->data, 0, modlen-eb->len);
2139
		memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2140
		freebytes(eb);
2141
		eb = ans;
2142
	}
2143
	if(eb->data[0] == 0 && eb->data[1] == 2) {
2144
		for(i = 2; i < modlen; i++)
2145
			if(eb->data[i] == 0)
2146
				break;
2147
		if(i < modlen - 1)
2148
			ans = makebytes(eb->data+i+1, modlen-(i+1));
2149
	}
2150
	freebytes(eb);
2151
	return ans;
2152
}
2153
 
2154
 
2155
//================= general utility functions ========================
2156
 
2157
static void *
2158
emalloc(int n)
2159
{
2160
	void *p;
2161
	if(n==0)
2162
		n=1;
2163
	p = malloc(n);
2164
	if(p == nil){
2165
		exits("out of memory");
2166
	}
2167
	memset(p, 0, n);
2168
	return p;
2169
}
2170
 
2171
static void *
2172
erealloc(void *ReallocP, int ReallocN)
2173
{
2174
	if(ReallocN == 0)
2175
		ReallocN = 1;
2176
	if(!ReallocP)
2177
		ReallocP = emalloc(ReallocN);
2178
	else if(!(ReallocP = realloc(ReallocP, ReallocN))){
2179
		exits("out of memory");
2180
	}
2181
	return(ReallocP);
2182
}
2183
 
2184
static void
2185
put32(uchar *p, u32int x)
2186
{
2187
	p[0] = x>>24;
2188
	p[1] = x>>16;
2189
	p[2] = x>>8;
2190
	p[3] = x;
2191
}
2192
 
2193
static void
2194
put24(uchar *p, int x)
2195
{
2196
	p[0] = x>>16;
2197
	p[1] = x>>8;
2198
	p[2] = x;
2199
}
2200
 
2201
static void
2202
put16(uchar *p, int x)
2203
{
2204
	p[0] = x>>8;
2205
	p[1] = x;
2206
}
2207
 
2208
static u32int
2209
get32(uchar *p)
2210
{
2211
	return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2212
}
2213
 
2214
static int
2215
get24(uchar *p)
2216
{
2217
	return (p[0]<<16)|(p[1]<<8)|p[2];
2218
}
2219
 
2220
static int
2221
get16(uchar *p)
2222
{
2223
	return (p[0]<<8)|p[1];
2224
}
2225
 
2226
/*
 * ANSI offsetof(): byte offset of member x within struct type s.
 * The member address relative to a null s* is converted through
 * uintptr (wide enough for a pointer) before narrowing to int, rather
 * than casting the pointer straight to int.
 */
#define OFFSET(x, s) ((int)(uintptr)&(((s*)0)->x))
2228
 
2229
/*
2230
 * malloc and return a new Bytes structure capable of
2231
 * holding len bytes. (len >= 0)
2232
 * Used to use crypt_malloc, which aborts if malloc fails.
2233
 */
2234
static Bytes*
2235
newbytes(int len)
2236
{
2237
	Bytes* ans;
2238
 
2239
	ans = (Bytes*)malloc(OFFSET(data[0], Bytes) + len);
2240
	ans->len = len;
2241
	return ans;
2242
}
2243
 
2244
/*
2245
 * newbytes(len), with data initialized from buf
2246
 */
2247
static Bytes*
2248
makebytes(uchar* buf, int len)
2249
{
2250
	Bytes* ans;
2251
 
2252
	ans = newbytes(len);
2253
	memmove(ans->data, buf, len);
2254
	return ans;
2255
}
2256
 
2257
static void
2258
freebytes(Bytes* b)
2259
{
2260
	if(b != nil)
2261
		free(b);
2262
}
2263
 
2264
/* len is number of ints */
2265
static Ints*
2266
newints(int len)
2267
{
2268
	Ints* ans;
2269
 
2270
	ans = (Ints*)malloc(OFFSET(data[0], Ints) + len*sizeof(int));
2271
	ans->len = len;
2272
	return ans;
2273
}
2274
 
2275
static Ints*
2276
makeints(int* buf, int len)
2277
{
2278
	Ints* ans;
2279
 
2280
	ans = newints(len);
2281
	if(len > 0)
2282
		memmove(ans->data, buf, len*sizeof(int));
2283
	return ans;
2284
}
2285
 
2286
static void
2287
freeints(Ints* b)
2288
{
2289
	if(b != nil)
2290
		free(b);
2291
}