/*
 * Subversion Repositories planix.SVN
 * Rev 22 | Go to most recent revision | Details | Compare with Previous | Last modification | View Log | RSS feed
 * Rev Author Line No. Line
 */
#include <u.h>
2
#include <libc.h>
3
#include <auth.h>
4
#include <mp.h>
5
#include <libsec.h>
6
 
7
// The main groups of functions are:
//		client/server - main handshake protocol definition
//		message functions - formatting handshake messages
//		cipher choices - catalog of digest and encrypt algorithms
//		security functions - PKCS#1, sslHMAC, session keygen
//		general utility functions - malloc, serialization
// The handshake protocol builds on the TLS/SSL3 record layer protocol,
// which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.
15
 
16
// Sizes used by the handshake code and the two msgSend() actions.
enum {
	TLSFinishedLen = 12,		// length of TLS Finished verify_data
	SSL3FinishedLen = MD5dlen+SHA1dlen,	// SSL3 Finished: MD5 digest followed by SHA1 digest
	MaxKeyData = 160,	// amount of secret we may need
	MAXdlen = SHA2_512dlen,		// largest digest any hashfun[] entry produces
	RandomSize = 32,		// ClientHello/ServerHello random
	MasterSecretSize = 48,		// size of TlsSec.sec
	AQueue = 0,			// msgSend act: queue the message (see c->sendp/c->buf)
	AFlush = 1,			// msgSend act: presumably queue then write everything out — confirm in msgSend
};

// Counted byte string; the data bytes follow the header in one allocation.
typedef struct Bytes{
	int len;
	uchar data[];
} Bytes;

// Counted vector of ints (e.g. cipher suite ids, signature algorithm ids).
typedef struct Ints{
	int len;
	int data[];
} Ints;

// One entry of the cipher suite catalog (see cipherAlgs[]).
typedef struct Algs{
	char *enc;	// name of encryption algorithm
	char *digest;	// name of MAC digest; "clear" for the AEAD suites
	int nsecret;	// bytes of key material needed to init keys
	int tlsid;	// cipher suite id sent on the wire
	int ok;		// presumably set by initCiphers() when usable — TODO confirm
} Algs;

// A supported named elliptic curve: its TLS id and the libsec
// function that initializes the curve's domain parameters.
typedef struct Namedcurve{
	int tlsid;
	void (*init)(mpint *p, mpint *a, mpint *b, mpint *x, mpint *y, mpint *n, mpint *h);
} Namedcurve;

// Finished verify data; the buffer is sized for the larger SSL3 form.
typedef struct Finished{
	uchar verify[SSL3FinishedLen];
	int n;		// number of valid bytes in verify
} Finished;

// Running hashes over the handshake messages, one per digest that the
// Finished computation may use.
typedef struct HandshakeHash {
	MD5state	md5;
	SHAstate	sha1;
	SHA2_256state	sha2_256;
} HandshakeHash;

// Key exchange and key derivation state for one connection.
typedef struct TlsSec TlsSec;
struct TlsSec {
	RSApub *rsapub;	// public key of our own certificate (X509toRSApub)
	AuthRpc *rpc;	// factotum for rsa private key
	uchar *psk;	// pre-shared key; assigned from the caller's pointer, not copied
	int psklen;
	int clientVers;			// version in ClientHello
	uchar sec[MasterSecretSize];	// master secret
	uchar crandom[RandomSize];	// client random
	uchar srandom[RandomSize];	// server random

	// diffie hellman state
	DHstate dh;	// finite-field DHE state (dh_new/dh_finish in tlsSecDHEc)
	struct {
		ECdomain dom;	// curve parameters (ecdominit)
		ECpriv Q;	// our ephemeral EC key pair
	} ec;

	// byte generation and handshake checksum
	// prf(out, outlen, secret, secretlen, label, seed1, seed1len, seed2, seed2len)
	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
	void (*setFinished)(TlsSec*, HandshakeHash, uchar*, int);
	int nfin;	// presumably the Finished length for the negotiated version — confirm in tlsSecVers
};

// State of one handshake in progress.
typedef struct TlsConnection{
	TlsSec sec[1];	// security management goo (array of one so c->sec is a pointer)
	int hand, ctl;	// record layer file descriptors
	int erred;		// set when tlsError called
	int (*trace)(char*fmt, ...); // for debugging
	int version;	// protocol we are speaking
	Bytes *cert;	// server certificate; only last - no chain

	int cipher;
	int nsecret;	// amount of secret data to init keys
	char *digest;	// name of digest algorithm to use
	char *enc;	// name of encryption algorithm to use

	// for finished messages
	HandshakeHash	handhash;
	Finished	finished;

	uchar *sendp;		// next free byte in buf for queued handshake output
	uchar buf[1<<16];	// outgoing handshake message buffer
} TlsConnection;

// A decoded (or to-be-encoded) handshake message.  tag is one of the
// H* handshake types and selects which member of u is meaningful.
typedef struct Msg{
	int tag;
	union {
		struct {
			int version;
			uchar 	random[RandomSize];
			Bytes*	sid;
			Ints*	ciphers;
			Bytes*	compressors;
			Bytes*	extensions;
		} clientHello;
		struct {
			int version;
			uchar	random[RandomSize];
			Bytes*	sid;
			int	cipher;
			int	compressor;
			Bytes*	extensions;
		} serverHello;
		struct {
			int ncert;
			Bytes **certs;	// ncert certificates, ours first
		} certificate;
		struct {
			Bytes *types;
			Ints *sigalgs;
			int nca;
			Bytes **cas;
		} certificateRequest;
		struct {
			Bytes *pskid;
			Bytes *key;	// encrypted premaster secret or client's public (EC)DH value
		} clientKeyExchange;
		struct {
			Bytes *pskid;
			Bytes *dh_p;
			Bytes *dh_g;
			Bytes *dh_Ys;
			Bytes *dh_parameters;	// the bytes that are signed (see digestDHparams)
			Bytes *dh_signature;
			int sigalg;
			int curve;
		} serverKeyExchange;
		struct {
			int sigalg;
			Bytes *signature;
		} certificateVerify;		
		Finished finished;
	} u;
} Msg;


// Protocol version numbers as sent on the wire (major<<8 | minor).
enum {
	SSL3Version	= 0x0300,
	TLS10Version	= 0x0301,
	TLS11Version	= 0x0302,
	TLS12Version	= 0x0303,
	ProtocolVersion	= TLS12Version,	// maximum version we speak
	MinProtoVersion	= 0x0300,	// limits on version we accept
	MaxProtoVersion	= 0x03ff,
};

// handshake type
// (HServerKeyExchange..HClientKeyExchange take the implicit values 12..16)
enum {
	HHelloRequest,
	HClientHello,
	HServerHello,
	HSSL2ClientHello = 9,  /* local convention;  see devtls.c */
	HCertificate = 11,
	HServerKeyExchange,
	HCertificateRequest,
	HServerHelloDone,
	HCertificateVerify,
	HClientKeyExchange,
	HFinished = 20,
	HMax
};

// alerts
// (alert description codes; EMax is one past the largest one-byte code)
enum {
	ECloseNotify = 0,
	EUnexpectedMessage = 10,
	EBadRecordMac = 20,
	EDecryptionFailed = 21,
	ERecordOverflow = 22,
	EDecompressionFailure = 30,
	EHandshakeFailure = 40,
	ENoCertificate = 41,
	EBadCertificate = 42,
	EUnsupportedCertificate = 43,
	ECertificateRevoked = 44,
	ECertificateExpired = 45,
	ECertificateUnknown = 46,
	EIllegalParameter = 47,
	EUnknownCa = 48,
	EAccessDenied = 49,
	EDecodeError = 50,
	EDecryptError = 51,
	EExportRestriction = 60,
	EProtocolVersion = 70,
	EInsufficientSecurity = 71,
	EInternalError = 80,
	EInappropriateFallback = 86,	// sent when TLS_FALLBACK_SCSV accompanies a downgraded hello
	EUserCanceled = 90,
	ENoRenegotiation = 100,
	EUnknownPSKidentity = 115,
	EMax = 256
};

// cipher suites
// (cipher suite ids as sent on the wire; only those referenced by
// cipherAlgs[] below are actually offered/accepted)
enum {
	TLS_NULL_WITH_NULL_NULL			= 0x0000,
	TLS_RSA_WITH_NULL_MD5			= 0x0001,
	TLS_RSA_WITH_NULL_SHA			= 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5		= 0x0003,
	TLS_RSA_WITH_RC4_128_MD5		= 0x0004,
	TLS_RSA_WITH_RC4_128_SHA		= 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5	= 0X0006,
	TLS_RSA_WITH_IDEA_CBC_SHA		= 0X0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0X0008,
	TLS_RSA_WITH_DES_CBC_SHA		= 0X0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA		= 0X000A,
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0X000B,
	TLS_DH_DSS_WITH_DES_CBC_SHA		= 0X000C,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA	= 0X000D,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0X000E,
	TLS_DH_RSA_WITH_DES_CBC_SHA		= 0X000F,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA	= 0X0010,
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0X0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA		= 0X0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA	= 0X0013,	// ZZZ must be implemented for tls1.0 compliance
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0X0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA		= 0X0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA	= 0X0016,
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5	= 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5		= 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA	= 0X0019,
	TLS_DH_anon_WITH_DES_CBC_SHA		= 0X001A,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA	= 0X001B,
	TLS_RSA_WITH_AES_128_CBC_SHA		= 0X002F,	// aes, aka rijndael with 128 bit blocks
	TLS_DH_DSS_WITH_AES_128_CBC_SHA		= 0X0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA		= 0X0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA	= 0X0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA	= 0X0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA	= 0X0034,
	TLS_RSA_WITH_AES_256_CBC_SHA		= 0X0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA		= 0X0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA		= 0X0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA	= 0X0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA	= 0X0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA	= 0X003A,
	TLS_RSA_WITH_AES_128_CBC_SHA256		= 0X003C,
	TLS_RSA_WITH_AES_256_CBC_SHA256		= 0X003D,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA256	= 0X0067,

	TLS_RSA_WITH_AES_128_GCM_SHA256		= 0x009C,
	TLS_RSA_WITH_AES_256_GCM_SHA384		= 0x009D,
	TLS_DHE_RSA_WITH_AES_128_GCM_SHA256	= 0x009E,
	TLS_DHE_RSA_WITH_AES_256_GCM_SHA384	= 0x009F,
	TLS_DH_RSA_WITH_AES_128_GCM_SHA256	= 0x00A0,
	TLS_DH_RSA_WITH_AES_256_GCM_SHA384	= 0x00A1,
	TLS_DHE_DSS_WITH_AES_128_GCM_SHA256	= 0x00A2,
	TLS_DHE_DSS_WITH_AES_256_GCM_SHA384	= 0x00A3,
	TLS_DH_DSS_WITH_AES_128_GCM_SHA256	= 0x00A4,
	TLS_DH_DSS_WITH_AES_256_GCM_SHA384	= 0x00A5,
	TLS_DH_anon_WITH_AES_128_GCM_SHA256	= 0x00A6,
	TLS_DH_anon_WITH_AES_256_GCM_SHA384	= 0x00A7,

	TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xC02B,
	TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256	= 0xC02F,

	TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA	= 0xC013,
	TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA	= 0xC014,
	TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256	= 0xC027,
	TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256	= 0xC023,

	TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305	= 0xCCA8,
	TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305	= 0xCCA9,
	TLS_DHE_RSA_WITH_CHACHA20_POLY1305	= 0xCCAA,

	// pre-standard ChaCha20-Poly1305 ids (paired with "ccpoly64_aead" below)
	GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305		= 0xCC13,
	GOOGLE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305	= 0xCC14,
	GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305		= 0xCC15,

	TLS_PSK_WITH_CHACHA20_POLY1305		= 0xCCAB,
	TLS_PSK_WITH_AES_128_CBC_SHA256		= 0x00AE,
	TLS_PSK_WITH_AES_128_CBC_SHA		= 0x008C,

	// signalling value, not a real suite: a client retrying with a lower
	// version includes it (checked in tlsServer2)
	TLS_FALLBACK_SCSV = 0x5600,
};

// compression methods
enum {
	CompressionNull = 0,
	CompressionMax
};

// Catalog of supported cipher suites.
// nsecret = 2 * (encryption key + fixed IV [+ MAC key]), i.e. key
// material for both directions.  NOTE(review): entries are presumably
// scanned in order by okCipher/makeciphers, so earlier entries are
// preferred — confirm before reordering.
static Algs cipherAlgs[] = {
	// ECDHE-ECDSA
	{"ccpoly96_aead", "clear", 2*(32+12), TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305},
	{"ccpoly64_aead", "clear", 2*32, GOOGLE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305},
	{"aes_128_gcm_aead", "clear", 2*(16+4), TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
	{"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256},

	// ECDHE-RSA
	{"ccpoly96_aead", "clear", 2*(32+12), TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305},
	{"ccpoly64_aead", "clear", 2*32, GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305},
	{"aes_128_gcm_aead", "clear", 2*(16+4), TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
	{"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256},
	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA},

	// DHE-RSA
	{"ccpoly96_aead", "clear", 2*(32+12), TLS_DHE_RSA_WITH_CHACHA20_POLY1305},
	{"ccpoly64_aead", "clear", 2*32, GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305},
	{"aes_128_gcm_aead", "clear", 2*(16+4), TLS_DHE_RSA_WITH_AES_128_GCM_SHA256},
	{"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_DHE_RSA_WITH_AES_128_CBC_SHA256},
	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_128_CBC_SHA},
	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_256_CBC_SHA},
	{"3des_ede_cbc","sha1",	2*(4*8+SHA1dlen), TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA},

	// RSA
	{"aes_128_gcm_aead", "clear", 2*(16+4), TLS_RSA_WITH_AES_128_GCM_SHA256},
	{"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_RSA_WITH_AES_128_CBC_SHA256},
	{"aes_256_cbc", "sha256", 2*(32+16+SHA2_256dlen), TLS_RSA_WITH_AES_256_CBC_SHA256},
	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_RSA_WITH_AES_128_CBC_SHA},
	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA},
	{"3des_ede_cbc","sha1",	2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},

	// PSK
	{"ccpoly96_aead", "clear", 2*(32+12), TLS_PSK_WITH_CHACHA20_POLY1305},
	{"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_PSK_WITH_AES_128_CBC_SHA256},
	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_PSK_WITH_AES_128_CBC_SHA},
};

// Compression methods offered in ClientHello (null only).
static uchar compressors[] = {
	CompressionNull,
};

// Supported curves.  tlsServer2 always uses namedcurves[0], so keep
// secp256r1 first.
static Namedcurve namedcurves[] = {
	0x0017, secp256r1,
	0x0018, secp384r1,
};

// EC point formats advertised in the ec_point_formats extension.
static uchar pointformats[] = {
	CompressionNull /* support of uncompressed point format is mandatory */
};

// Digest functions indexed by the hash byte of a signature algorithm id
// (the high byte of the sigalgs[] entries below).
static struct {
	DigestState* (*fun)(uchar*, ulong, uchar*, DigestState*);
	int len;
} hashfun[] = {
/*	[0x00]  is reserved for MD5+SHA1 for < TLS1.2 */
	[0x01]	{md5,		MD5dlen},
	[0x02]	{sha1,		SHA1dlen},
	[0x03]	{sha2_224,	SHA2_224dlen},
	[0x04]	{sha2_256,	SHA2_256dlen},
	[0x05]	{sha2_384,	SHA2_384dlen},
	[0x06]	{sha2_512,	SHA2_512dlen},
};

// signature algorithms (only RSA and ECDSA at the moment)
// (high byte: hash, index into hashfun[]; low byte: signature scheme)
static int sigalgs[] = {
	0x0603,		/* SHA512 ECDSA */
	0x0503,		/* SHA384 ECDSA */
	0x0403,		/* SHA256 ECDSA */
	0x0203,		/* SHA1 ECDSA */

	0x0601,		/* SHA512 RSA */
	0x0501,		/* SHA384 RSA */
	0x0401,		/* SHA256 RSA */
	0x0201,		/* SHA1 RSA */
};

// handshake drivers (server and client side)
static TlsConnection *tlsServer2(int ctl, int hand,
	uchar *cert, int certlen,
	char *pskid, uchar *psk, int psklen,
	int (*trace)(char*fmt, ...), PEMChain *chain);
static TlsConnection *tlsClient2(int ctl, int hand,
	uchar *cert, int certlen,
	char *pskid, uchar *psk, int psklen,
	uchar *ext, int extlen, int (*trace)(char*fmt, ...));

// handshake message I/O and connection state
static void	msgClear(Msg *m);
static char* msgPrint(char *buf, int n, Msg *m);
static int	msgRecv(TlsConnection *c, Msg *m);
static int	msgSend(TlsConnection *c, Msg *m, int act);
static void	tlsError(TlsConnection *c, int err, char *msg, ...);
#pragma	varargck argpos	tlsError 3
static int setVersion(TlsConnection *c, int version);
static int setSecrets(TlsConnection *c, int isclient);
static int finishedMatch(TlsConnection *c, Finished *f);
static void tlsConnectionFree(TlsConnection *c);

// cipher suite classification by tls id
static int isDHE(int tlsid);
static int isECDHE(int tlsid);
static int isPSK(int tlsid);
static int isECDSA(int tlsid);

// cipher suite negotiation
static int setAlgs(TlsConnection *c, int a);
static int okCipher(Ints *cv, int ispsk);
static int okCompression(Bytes *cv);
static int initCiphers(void);
static Ints* makeciphers(int ispsk);

// RSA private-key operations through factotum
static AuthRpc* factotum_rsa_open(RSApub *rsapub);
static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
static void factotum_rsa_close(AuthRpc *rpc);

// key exchange and secret derivation (suffix s: server side, c: client side)
static void	tlsSecInits(TlsSec *sec, int cvers, uchar *crandom);
static int	tlsSecRSAs(TlsSec *sec, Bytes *epm);
static Bytes*	tlsSecECDHEs1(TlsSec *sec, Namedcurve *nc);
static int	tlsSecECDHEs2(TlsSec *sec, Bytes *Yc);
static void	tlsSecInitc(TlsSec *sec, int cvers);
static Bytes*	tlsSecRSAc(TlsSec *sec, uchar *cert, int ncert);
static Bytes*	tlsSecDHEc(TlsSec *sec, Bytes *p, Bytes *g, Bytes *Ys);
static Bytes*	tlsSecECDHEc(TlsSec *sec, int curve, Bytes *Ys);
static void	tlsSecVers(TlsSec *sec, int v);
static int	tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient);
static void	setMasterSecret(TlsSec *sec, Bytes *pm);
static int	digestDHparams(TlsSec *sec, Bytes *par, uchar digest[MAXdlen], int sigalg);
static char*	verifyDHparams(TlsSec *sec, Bytes *par, Bytes *cert, Bytes *sig, int sigalg);

// PKCS#1
static Bytes*	pkcs1_encrypt(Bytes* data, RSApub* key);
static Bytes*	pkcs1_decrypt(TlsSec *sec, Bytes *data);
static Bytes*	pkcs1_sign(TlsSec *sec, uchar *digest, int digestlen, int sigalg);

// allocation and serialization utilities
static void* emalloc(int);
static void* erealloc(void*, int);
static void put32(uchar *p, u32int);
static void put24(uchar *p, int);
static void put16(uchar *p, int);
static int get24(uchar *p);
static int get16(uchar *p);
static Bytes* newbytes(int len);
static Bytes* makebytes(uchar* buf, int len);
static Bytes* mptobytes(mpint* big, int len);
static mpint* bytestomp(Bytes* bytes);
static void freebytes(Bytes* b);
static Ints* newints(int len);
static void freeints(Ints* b);
static int lookupid(Ints* b, int id);

//================= client/server ========================

//	push TLS onto fd, returning new (application) file descriptor
451
//		or -1 if error.
452
int
453
tlsServer(int fd, TLSconn *conn)
454
{
455
	char buf[8];
456
	char dname[64];
457
	int n, data, ctl, hand;
458
	TlsConnection *tls;
459
 
460
	if(conn == nil)
461
		return -1;
462
	ctl = open("#a/tls/clone", ORDWR);
463
	if(ctl < 0)
464
		return -1;
465
	n = read(ctl, buf, sizeof(buf)-1);
466
	if(n < 0){
467
		close(ctl);
468
		return -1;
469
	}
470
	buf[n] = 0;
26 7u83 471
	snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", buf);
472
	snprint(dname, sizeof(dname), "#a/tls/%s/hand", buf);
2 - 473
	hand = open(dname, ORDWR);
474
	if(hand < 0){
475
		close(ctl);
476
		return -1;
477
	}
26 7u83 478
	data = -1;
2 - 479
	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
26 7u83 480
	tls = tlsServer2(ctl, hand,
481
		conn->cert, conn->certlen,
482
		conn->pskID, conn->psk, conn->psklen,
483
		conn->trace, conn->chain);
484
	if(tls != nil){
485
		snprint(dname, sizeof(dname), "#a/tls/%s/data", buf);
486
		data = open(dname, ORDWR);
487
	}
2 - 488
	close(hand);
489
	close(ctl);
490
	if(data < 0){
26 7u83 491
		tlsConnectionFree(tls);
2 - 492
		return -1;
493
	}
26 7u83 494
	free(conn->cert);
495
	conn->cert = nil;  // client certificates are not yet implemented
2 - 496
	conn->certlen = 0;
26 7u83 497
	conn->sessionIDlen = 0;
498
	conn->sessionID = nil;
499
	if(conn->sessionKey != nil
500
	&& conn->sessionType != nil
501
	&& strcmp(conn->sessionType, "ttls") == 0)
502
		tls->sec->prf(
503
			conn->sessionKey, conn->sessionKeylen,
504
			tls->sec->sec, MasterSecretSize,
505
			conn->sessionConst, 
506
			tls->sec->crandom, RandomSize,
507
			tls->sec->srandom, RandomSize);
2 - 508
	tlsConnectionFree(tls);
26 7u83 509
	close(fd);
2 - 510
	return data;
511
}
512
 
26 7u83 513
static uchar*
514
tlsClientExtensions(TLSconn *conn, int *plen)
515
{
516
	uchar *b, *p;
517
	int i, n, m;
518
 
519
	p = b = nil;
520
 
521
	// RFC6066 - Server Name Identification
522
	if(conn->serverName != nil){
523
		n = strlen(conn->serverName);
524
 
525
		m = p - b;
526
		b = erealloc(b, m + 2+2+2+1+2+n);
527
		p = b + m;
528
 
529
		put16(p, 0), p += 2;		/* Type: server_name */
530
		put16(p, 2+1+2+n), p += 2;	/* Length */
531
		put16(p, 1+2+n), p += 2;	/* Server Name list length */
532
		*p++ = 0;			/* Server Name Type: host_name */
533
		put16(p, n), p += 2;		/* Server Name length */
534
		memmove(p, conn->serverName, n);
535
		p += n;
536
	}
537
 
538
	// ECDHE
539
	if(ProtocolVersion >= TLS10Version){
540
		m = p - b;
541
		b = erealloc(b, m + 2+2+2+nelem(namedcurves)*2 + 2+2+1+nelem(pointformats));
542
		p = b + m;
543
 
544
		n = nelem(namedcurves);
545
		put16(p, 0x000a), p += 2;	/* Type: elliptic_curves */
546
		put16(p, (n+1)*2), p += 2;	/* Length */
547
		put16(p, n*2), p += 2;		/* Elliptic Curves Length */
548
		for(i=0; i < n; i++){		/* Elliptic curves */
549
			put16(p, namedcurves[i].tlsid);
550
			p += 2;
551
		}
552
 
553
		n = nelem(pointformats);
554
		put16(p, 0x000b), p += 2;	/* Type: ec_point_formats */
555
		put16(p, n+1), p += 2;		/* Length */
556
		*p++ = n;			/* EC point formats Length */
557
		for(i=0; i < n; i++)		/* Elliptic curves point formats */
558
			*p++ = pointformats[i];
559
	}
560
 
561
	// signature algorithms
562
	if(ProtocolVersion >= TLS12Version){
563
		n = nelem(sigalgs);
564
 
565
		m = p - b;
566
		b = erealloc(b, m + 2+2+2+n*2);
567
		p = b + m;
568
 
569
		put16(p, 0x000d), p += 2;
570
		put16(p, n*2 + 2), p += 2;
571
		put16(p, n*2), p += 2;
572
		for(i=0; i < n; i++){
573
			put16(p, sigalgs[i]);
574
			p += 2;
575
		}
576
	}
577
 
578
	*plen = p - b;
579
	return b;
580
}
581
 
2 - 582
//	push TLS onto fd, returning new (application) file descriptor
583
//		or -1 if error.
584
int
585
tlsClient(int fd, TLSconn *conn)
586
{
587
	char buf[8];
588
	char dname[64];
589
	int n, data, ctl, hand;
590
	TlsConnection *tls;
26 7u83 591
	uchar *ext;
2 - 592
 
26 7u83 593
	if(conn == nil)
2 - 594
		return -1;
595
	ctl = open("#a/tls/clone", ORDWR);
596
	if(ctl < 0)
597
		return -1;
598
	n = read(ctl, buf, sizeof(buf)-1);
599
	if(n < 0){
600
		close(ctl);
601
		return -1;
602
	}
603
	buf[n] = 0;
26 7u83 604
	snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", buf);
605
	snprint(dname, sizeof(dname), "#a/tls/%s/hand", buf);
2 - 606
	hand = open(dname, ORDWR);
607
	if(hand < 0){
608
		close(ctl);
609
		return -1;
610
	}
26 7u83 611
	snprint(dname, sizeof(dname), "#a/tls/%s/data", buf);
2 - 612
	data = open(dname, ORDWR);
26 7u83 613
	if(data < 0){
614
		close(hand);
615
		close(ctl);
2 - 616
		return -1;
26 7u83 617
	}
2 - 618
	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
26 7u83 619
	ext = tlsClientExtensions(conn, &n);
620
	tls = tlsClient2(ctl, hand,
621
		conn->cert, conn->certlen, 
622
		conn->pskID, conn->psk, conn->psklen,
623
		ext, n, conn->trace);
624
	free(ext);
2 - 625
	close(hand);
626
	close(ctl);
627
	if(tls == nil){
628
		close(data);
629
		return -1;
630
	}
26 7u83 631
	free(conn->cert);
632
	if(tls->cert != nil){
633
		conn->certlen = tls->cert->len;
634
		conn->cert = emalloc(conn->certlen);
635
		memcpy(conn->cert, tls->cert->data, conn->certlen);
636
	} else {
637
		conn->certlen = 0;
638
		conn->cert = nil;
639
	}
640
	conn->sessionIDlen = 0;
641
	conn->sessionID = nil;
642
	if(conn->sessionKey != nil
643
	&& conn->sessionType != nil
644
	&& strcmp(conn->sessionType, "ttls") == 0)
645
		tls->sec->prf(
646
			conn->sessionKey, conn->sessionKeylen,
647
			tls->sec->sec, MasterSecretSize,
648
			conn->sessionConst, 
649
			tls->sec->crandom, RandomSize,
650
			tls->sec->srandom, RandomSize);
2 - 651
	tlsConnectionFree(tls);
26 7u83 652
	close(fd);
2 - 653
	return data;
654
}
655
 
656
static int
657
countchain(PEMChain *p)
658
{
659
	int i = 0;
660
 
661
	while (p) {
662
		i++;
663
		p = p->next;
664
	}
665
	return i;
666
}
667
 
668
// Run the server side of the handshake on the record-layer files
// ctl/hand.
//
// cert/certlen is our RSA certificate (its private key is used through
// factotum); chp optionally supplies more certificates to send after it.
// pskid/psk/psklen configure pre-shared-key operation; when pskid is
// non-nil the client's identity must match it exactly.
// Returns the connection after "opened" has been written to ctl, or nil
// on failure (an alert is sent through tlsError where applicable).
static TlsConnection *
tlsServer2(int ctl, int hand,
	uchar *cert, int certlen,
	char *pskid, uchar *psk, int psklen,
	int (*trace)(char*fmt, ...), PEMChain *chp)
{
	int cipher, compressor, numcerts, i;
	TlsConnection *c;
	Msg m;

	if(trace)
		trace("tlsServer2\n");
	if(!initCiphers())
		return nil;

	c = emalloc(sizeof(TlsConnection));
	c->ctl = ctl;
	c->hand = hand;
	c->trace = trace;
	c->version = ProtocolVersion;
	c->sendp = c->buf;

	/* ClientHello: negotiate version, cipher suite and compression */
	memset(&m, 0, sizeof(m));
	if(!msgRecv(c, &m)){
		if(trace)
			trace("initial msgRecv failed\n");
		goto Err;
	}
	if(m.tag != HClientHello) {
		tlsError(c, EUnexpectedMessage, "expected a client hello");
		goto Err;
	}
	if(trace)
		trace("ClientHello version %x\n", m.u.clientHello.version);
	if(setVersion(c, m.u.clientHello.version) < 0) {
		tlsError(c, EIllegalParameter, "incompatible version");
		goto Err;
	}
	/* a downgraded retry carrying TLS_FALLBACK_SCSV is refused */
	if(c->version < ProtocolVersion
	&& lookupid(m.u.clientHello.ciphers, TLS_FALLBACK_SCSV) >= 0){
		tlsError(c, EInappropriateFallback, "inappropriate fallback");
		goto Err;
	}
	cipher = okCipher(m.u.clientHello.ciphers, psklen > 0);
	if(cipher < 0 || !setAlgs(c, cipher)) {
		tlsError(c, EHandshakeFailure, "no matching cipher suite");
		goto Err;
	}
	compressor = okCompression(m.u.clientHello.compressors);
	if(compressor < 0) {
		tlsError(c, EHandshakeFailure, "no matching compressor");
		goto Err;
	}
	if(trace)
		trace("  cipher %x, compressor %x\n", cipher, compressor);

	/* the client's offered version, not the negotiated one, goes to tlsSecInits */
	tlsSecInits(c->sec, m.u.clientHello.version, m.u.clientHello.random);
	tlsSecVers(c->sec, c->version);
	if(psklen > 0){
		c->sec->psk = psk;
		c->sec->psklen = psklen;
	}
	if(certlen > 0){
		/* server certificate */
		c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
		if(c->sec->rsapub == nil){
			tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
			goto Err;
		}
		c->sec->rpc = factotum_rsa_open(c->sec->rsapub);
		if(c->sec->rpc == nil){
			tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
			goto Err;
		}
	}
	msgClear(&m);

	/*
	 * ServerHello [Certificate] [ServerKeyExchange] ServerHelloDone,
	 * queued and written together by the final AFlush.
	 * NOTE(review): m's union is refilled after each msgSend without an
	 * explicit msgClear here; presumably msgSend clears m — confirm.
	 */
	m.tag = HServerHello;
	m.u.serverHello.version = c->version;
	memmove(m.u.serverHello.random, c->sec->srandom, RandomSize);
	m.u.serverHello.cipher = cipher;
	m.u.serverHello.compressor = compressor;
	m.u.serverHello.sid = makebytes(nil, 0);
	if(!msgSend(c, &m, AQueue))
		goto Err;

	if(certlen > 0){
		/* our certificate first, then any chain certificates */
		m.tag = HCertificate;
		numcerts = countchain(chp);
		m.u.certificate.ncert = 1 + numcerts;
		m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
		m.u.certificate.certs[0] = makebytes(cert, certlen);
		for (i = 0; i < numcerts && chp; i++, chp = chp->next)
			m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
		if(!msgSend(c, &m, AQueue))
			goto Err;
	}

	if(isECDHE(cipher)){
		/* always secp256r1; the client's supported curves are not consulted here */
		Namedcurve *nc = &namedcurves[0];	/* secp256r1 */

		m.tag = HServerKeyExchange;
		m.u.serverKeyExchange.curve = nc->tlsid;
		m.u.serverKeyExchange.dh_parameters = tlsSecECDHEs1(c->sec, nc);
		if(m.u.serverKeyExchange.dh_parameters == nil){
			tlsError(c, EInternalError, "can't set DH parameters");
			goto Err;
		}

		/* sign the DH parameters */
		if(certlen > 0){
			uchar digest[MAXdlen];
			int digestlen;

			/*
			 * TLS 1.2 always signs with RSA/SHA256.  Below 1.2 sigalg
			 * is left as m holds it (presumably 0, the MD5+SHA1
			 * convention noted at hashfun[]) — confirm.
			 */
			if(c->version >= TLS12Version)
				m.u.serverKeyExchange.sigalg = 0x0401;	/* RSA SHA256 */
			digestlen = digestDHparams(c->sec, m.u.serverKeyExchange.dh_parameters,
				digest, m.u.serverKeyExchange.sigalg);
			if((m.u.serverKeyExchange.dh_signature = pkcs1_sign(c->sec, digest, digestlen,
				m.u.serverKeyExchange.sigalg)) == nil){
				tlsError(c, EHandshakeFailure, "pkcs1_sign: %r");
				goto Err;
			}
		}
		if(!msgSend(c, &m, AQueue))
			goto Err;
	}

	m.tag = HServerHelloDone;
	if(!msgSend(c, &m, AFlush))
		goto Err;

	/* ClientKeyExchange: derive the premaster secret */
	if(!msgRecv(c, &m))
		goto Err;
	if(m.tag != HClientKeyExchange) {
		tlsError(c, EUnexpectedMessage, "expected a client key exchange");
		goto Err;
	}
	if(pskid != nil){
		if(m.u.clientKeyExchange.pskid == nil
		|| m.u.clientKeyExchange.pskid->len != strlen(pskid)
		|| memcmp(pskid, m.u.clientKeyExchange.pskid->data, m.u.clientKeyExchange.pskid->len) != 0){
			tlsError(c, EUnknownPSKidentity, "unknown or missing pskid");
			goto Err;
		}
	}
	if(isECDHE(cipher)){
		if(tlsSecECDHEs2(c->sec, m.u.clientKeyExchange.key) < 0){
			tlsError(c, EHandshakeFailure, "couldn't set keys: %r");
			goto Err;
		}
	} else if(certlen > 0){
		if(tlsSecRSAs(c->sec, m.u.clientKeyExchange.key) < 0){
			tlsError(c, EHandshakeFailure, "couldn't set keys: %r");
			goto Err;
		}
	} else if(psklen > 0){
		/* plain PSK: psklen-byte buffer from newbytes as the "other secret" */
		setMasterSecret(c->sec, newbytes(psklen));
	} else {
		tlsError(c, EInternalError, "no psk or certificate");
		goto Err;
	}

	if(trace)
		trace("tls secrets\n");
	if(setSecrets(c, 0) < 0){
		tlsError(c, EHandshakeFailure, "can't set secrets: %r");
		goto Err;
	}

	/* no CertificateVerify; skip to Finished */
	/* compute the client's expected Finished (isclient=1) and compare */
	if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){
		tlsError(c, EInternalError, "can't set finished: %r");
		goto Err;
	}
	if(!msgRecv(c, &m))
		goto Err;
	if(m.tag != HFinished) {
		tlsError(c, EUnexpectedMessage, "expected a finished");
		goto Err;
	}
	if(!finishedMatch(c, &m.u.finished)) {
		tlsError(c, EHandshakeFailure, "finished verification failed");
		goto Err;
	}
	msgClear(&m);

	/* change cipher spec */
	if(fprint(c->ctl, "changecipher") < 0){
		tlsError(c, EInternalError, "can't enable cipher: %r");
		goto Err;
	}

	/* our own Finished (isclient=0) */
	if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){
		tlsError(c, EInternalError, "can't set finished: %r");
		goto Err;
	}
	m.tag = HFinished;
	m.u.finished = c->finished;
	if(!msgSend(c, &m, AFlush))
		goto Err;
	if(trace)
		trace("tls finished\n");

	if(fprint(c->ctl, "opened") < 0)
		goto Err;
	return c;

Err:
	msgClear(&m);
	tlsConnectionFree(c);
	return nil;
}

static Bytes*
883
tlsSecDHEc(TlsSec *sec, Bytes *p, Bytes *g, Bytes *Ys)
884
{
885
	DHstate *dh = &sec->dh;
886
	mpint *G, *P, *Y, *K;
887
	Bytes *Yc;
888
	int n;
889
 
890
	if(p == nil || g == nil || Ys == nil)
891
		return nil;
892
 
893
	Yc = nil;
894
	P = bytestomp(p);
895
	G = bytestomp(g);
896
	Y = bytestomp(Ys);
897
	K = nil;
898
 
899
	if(dh_new(dh, P, nil, G) == nil)
900
		goto Out;
901
	n = (mpsignif(P)+7)/8;
902
	Yc = mptobytes(dh->y, n);
903
	K = dh_finish(dh, Y);	/* zeros dh */
904
	if(K == nil){
905
		freebytes(Yc);
906
		Yc = nil;
907
		goto Out;
908
	}
909
	setMasterSecret(sec, mptobytes(K, n));
910
 
911
Out:
912
	mpfree(K);
913
	mpfree(Y);
914
	mpfree(G);
915
	mpfree(P);
916
 
917
	return Yc;
918
}
919
 
920
static Bytes*
921
tlsSecECDHEc(TlsSec *sec, int curve, Bytes *Ys)
922
{
923
	ECdomain *dom = &sec->ec.dom;
924
	ECpriv *Q = &sec->ec.Q;
925
	Namedcurve *nc;
926
	ECpub *pub;
927
	ECpoint K;
928
	Bytes *Yc;
929
	int n;
930
 
931
	if(Ys == nil)
932
		return nil;
933
	for(nc = namedcurves; nc != &namedcurves[nelem(namedcurves)]; nc++)
934
		if(nc->tlsid == curve)
935
			goto Found;
936
	return nil;
937
 
938
Found:
939
	ecdominit(dom, nc->init);
940
	pub = ecdecodepub(dom, Ys->data, Ys->len);
941
	if(pub == nil)
942
		return nil;
943
 
944
	memset(Q, 0, sizeof(*Q));
945
	Q->x = mpnew(0);
946
	Q->y = mpnew(0);
947
	Q->d = mpnew(0);
948
 
949
	memset(&K, 0, sizeof(K));
950
	K.x = mpnew(0);
951
	K.y = mpnew(0);
952
 
953
	ecgen(dom, Q);
954
	ecmul(dom, pub, Q->d, &K);
955
 
956
	n = (mpsignif(dom->p)+7)/8;
957
	setMasterSecret(sec, mptobytes(K.x, n));
958
	Yc = newbytes(1 + 2*n);
959
	Yc->len = ecencodepub(dom, Q, Yc->data, Yc->len);
960
 
961
	mpfree(K.x);
962
	mpfree(K.y);
963
 
964
	ecpubfree(pub);
965
 
966
	return Yc;
967
}
968
 
2 - 969
static TlsConnection *
26 7u83 970
tlsClient2(int ctl, int hand,
971
	uchar *cert, int certlen,
972
	char *pskid, uchar *psk, int psklen,
973
	uchar *ext, int extlen,
974
	int (*trace)(char*fmt, ...))
2 - 975
{
26 7u83 976
	int creq, dhx, cipher;
2 - 977
	TlsConnection *c;
26 7u83 978
	Bytes *epm;
2 - 979
	Msg m;
980
 
981
	if(!initCiphers())
982
		return nil;
26 7u83 983
 
2 - 984
	epm = nil;
26 7u83 985
	memset(&m, 0, sizeof(m));
2 - 986
	c = emalloc(sizeof(TlsConnection));
26 7u83 987
 
2 - 988
	c->ctl = ctl;
989
	c->hand = hand;
990
	c->trace = trace;
26 7u83 991
	c->cert = nil;
992
	c->sendp = c->buf;
2 - 993
 
26 7u83 994
	c->version = ProtocolVersion;
995
	tlsSecInitc(c->sec, c->version);
996
	if(psklen > 0){
997
		c->sec->psk = psk;
998
		c->sec->psklen = psklen;
999
	}
1000
	if(certlen > 0){
1001
		/* client certificate */
1002
		c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
1003
		if(c->sec->rsapub == nil){
1004
			tlsError(c, EInternalError, "invalid X509/rsa certificate");
1005
			goto Err;
1006
		}
1007
		c->sec->rpc = factotum_rsa_open(c->sec->rsapub);
1008
		if(c->sec->rpc == nil){
1009
			tlsError(c, EInternalError, "factotum_rsa_open: %r");
1010
			goto Err;
1011
		}
1012
	}
2 - 1013
 
1014
	/* client hello */
1015
	m.tag = HClientHello;
26 7u83 1016
	m.u.clientHello.version = c->version;
1017
	memmove(m.u.clientHello.random, c->sec->crandom, RandomSize);
1018
	m.u.clientHello.sid = makebytes(nil, 0);
1019
	m.u.clientHello.ciphers = makeciphers(psklen > 0);
2 - 1020
	m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
26 7u83 1021
	m.u.clientHello.extensions = makebytes(ext, extlen);
2 - 1022
	if(!msgSend(c, &m, AFlush))
1023
		goto Err;
1024
 
1025
	/* server hello */
1026
	if(!msgRecv(c, &m))
1027
		goto Err;
1028
	if(m.tag != HServerHello) {
1029
		tlsError(c, EUnexpectedMessage, "expected a server hello");
1030
		goto Err;
1031
	}
1032
	if(setVersion(c, m.u.serverHello.version) < 0) {
26 7u83 1033
		tlsError(c, EIllegalParameter, "incompatible version: %r");
2 - 1034
		goto Err;
1035
	}
26 7u83 1036
	tlsSecVers(c->sec, c->version);
1037
	memmove(c->sec->srandom, m.u.serverHello.random, RandomSize);
1038
 
1039
	cipher = m.u.serverHello.cipher;
1040
	if((psklen > 0) != isPSK(cipher) || !setAlgs(c, cipher)) {
2 - 1041
		tlsError(c, EIllegalParameter, "invalid cipher suite");
1042
		goto Err;
1043
	}
1044
	if(m.u.serverHello.compressor != CompressionNull) {
1045
		tlsError(c, EIllegalParameter, "invalid compression");
1046
		goto Err;
1047
	}
1048
 
26 7u83 1049
	dhx = isDHE(cipher) || isECDHE(cipher);
1050
	if(!msgRecv(c, &m))
1051
		goto Err;
1052
	if(m.tag == HCertificate){
1053
		if(m.u.certificate.ncert < 1) {
1054
			tlsError(c, EIllegalParameter, "runt certificate");
1055
			goto Err;
1056
		}
1057
		c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
1058
		if(!msgRecv(c, &m))
1059
			goto Err;
1060
	} else if(psklen == 0) {
2 - 1061
		tlsError(c, EUnexpectedMessage, "expected a certificate");
1062
		goto Err;
1063
	}
1064
	if(m.tag == HServerKeyExchange) {
26 7u83 1065
		if(dhx){
1066
			char *err = verifyDHparams(c->sec,
1067
				m.u.serverKeyExchange.dh_parameters,
1068
				c->cert,
1069
				m.u.serverKeyExchange.dh_signature,
1070
				c->version<TLS12Version ? 0x01 : m.u.serverKeyExchange.sigalg);
1071
			if(err != nil){
1072
				tlsError(c, EBadCertificate, "can't verify DH parameters: %s", err);
1073
				goto Err;
1074
			}
1075
			if(isECDHE(cipher))
1076
				epm = tlsSecECDHEc(c->sec,
1077
					m.u.serverKeyExchange.curve,
1078
					m.u.serverKeyExchange.dh_Ys);
1079
			else
1080
				epm = tlsSecDHEc(c->sec,
1081
					m.u.serverKeyExchange.dh_p, 
1082
					m.u.serverKeyExchange.dh_g,
1083
					m.u.serverKeyExchange.dh_Ys);
1084
			if(epm == nil){
1085
				tlsError(c, EHandshakeFailure, "bad DH parameters");
1086
				goto Err;
1087
			}
1088
		} else if(psklen == 0){
1089
			tlsError(c, EUnexpectedMessage, "got an server key exchange");
1090
			goto Err;
1091
		}
1092
		if(!msgRecv(c, &m))
1093
			goto Err;
1094
	} else if(dhx){
1095
		tlsError(c, EUnexpectedMessage, "expected server key exchange");
2 - 1096
		goto Err;
1097
	}
1098
 
1099
	/* certificate request (optional) */
1100
	creq = 0;
1101
	if(m.tag == HCertificateRequest) {
1102
		creq = 1;
1103
		if(!msgRecv(c, &m))
1104
			goto Err;
1105
	}
1106
 
1107
	if(m.tag != HServerHelloDone) {
1108
		tlsError(c, EUnexpectedMessage, "expected a server hello done");
1109
		goto Err;
1110
	}
1111
	msgClear(&m);
1112
 
26 7u83 1113
	if(!dhx){
1114
		if(c->cert != nil){
1115
			epm = tlsSecRSAc(c->sec, c->cert->data, c->cert->len);
1116
			if(epm == nil){
1117
				tlsError(c, EBadCertificate, "bad certificate: %r");
1118
				goto Err;
1119
			}
1120
		} else if(psklen > 0){
1121
			setMasterSecret(c->sec, newbytes(psklen));
1122
		} else {
1123
			tlsError(c, EInternalError, "no psk or certificate");
1124
			goto Err;
1125
		}
2 - 1126
	}
26 7u83 1127
 
1128
	if(trace)
1129
		trace("tls secrets\n");
1130
	if(setSecrets(c, 1) < 0){
1131
		tlsError(c, EHandshakeFailure, "can't set secrets: %r");
2 - 1132
		goto Err;
1133
	}
1134
 
1135
	if(creq) {
1136
		m.tag = HCertificate;
26 7u83 1137
		if(certlen > 0){
1138
			m.u.certificate.ncert = 1;
1139
			m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
1140
			m.u.certificate.certs[0] = makebytes(cert, certlen);
1141
		}		
2 - 1142
		if(!msgSend(c, &m, AFlush))
1143
			goto Err;
1144
	}
1145
 
1146
	/* client key exchange */
1147
	m.tag = HClientKeyExchange;
26 7u83 1148
	if(psklen > 0){
1149
		if(pskid == nil)
1150
			pskid = "";
1151
		m.u.clientKeyExchange.pskid = makebytes((uchar*)pskid, strlen(pskid));
1152
	}
1153
	m.u.clientKeyExchange.key = epm;
2 - 1154
	epm = nil;
26 7u83 1155
 
2 - 1156
	if(!msgSend(c, &m, AFlush))
1157
		goto Err;
1158
 
26 7u83 1159
	/* certificate verify */
1160
	if(creq && certlen > 0) {
1161
		HandshakeHash hsave;
1162
		uchar digest[MAXdlen];
1163
		int digestlen;
1164
 
1165
		/* save the state for the Finish message */
1166
		hsave = c->handhash;
1167
		if(c->version < TLS12Version){
1168
			md5(nil, 0, digest, &c->handhash.md5);
1169
			sha1(nil, 0, digest+MD5dlen, &c->handhash.sha1);
1170
			digestlen = MD5dlen+SHA1dlen;
1171
		} else {
1172
			m.u.certificateVerify.sigalg = 0x0401;	/* RSA SHA256 */
1173
			sha2_256(nil, 0, digest, &c->handhash.sha2_256);
1174
			digestlen = SHA2_256dlen;
1175
		}
1176
		c->handhash = hsave;
1177
 
1178
		if((m.u.certificateVerify.signature = pkcs1_sign(c->sec, digest, digestlen,
1179
			m.u.certificateVerify.sigalg)) == nil){
1180
			tlsError(c, EHandshakeFailure, "pkcs1_sign: %r");
1181
			goto Err;
1182
		}
1183
 
1184
		m.tag = HCertificateVerify;
1185
		if(!msgSend(c, &m, AFlush))
1186
			goto Err;
1187
	} 
1188
 
2 - 1189
	/* change cipher spec */
1190
	if(fprint(c->ctl, "changecipher") < 0){
1191
		tlsError(c, EInternalError, "can't enable cipher: %r");
1192
		goto Err;
1193
	}
1194
 
1195
	// Cipherchange must occur immediately before Finished to avoid
1196
	// potential hole;  see section 4.3 of Wagner Schneier 1996.
26 7u83 1197
	if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){
2 - 1198
		tlsError(c, EInternalError, "can't set finished 1: %r");
1199
		goto Err;
1200
	}
1201
	m.tag = HFinished;
1202
	m.u.finished = c->finished;
1203
	if(!msgSend(c, &m, AFlush)) {
1204
		tlsError(c, EInternalError, "can't flush after client Finished: %r");
1205
		goto Err;
1206
	}
1207
 
26 7u83 1208
	if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){
2 - 1209
		tlsError(c, EInternalError, "can't set finished 0: %r");
1210
		goto Err;
1211
	}
1212
	if(!msgRecv(c, &m)) {
1213
		tlsError(c, EInternalError, "can't read server Finished: %r");
1214
		goto Err;
1215
	}
1216
	if(m.tag != HFinished) {
1217
		tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
1218
		goto Err;
1219
	}
1220
 
1221
	if(!finishedMatch(c, &m.u.finished)) {
1222
		tlsError(c, EHandshakeFailure, "finished verification failed");
1223
		goto Err;
1224
	}
1225
	msgClear(&m);
1226
 
1227
	if(fprint(c->ctl, "opened") < 0){
1228
		if(trace)
1229
			trace("unable to do final open: %r\n");
1230
		goto Err;
1231
	}
1232
	return c;
1233
 
1234
Err:
1235
	free(epm);
1236
	msgClear(&m);
1237
	tlsConnectionFree(c);
26 7u83 1238
	return nil;
2 - 1239
}
1240
 
1241
 
1242
//================= message functions ========================
1243
 
26 7u83 1244
static void
1245
msgHash(TlsConnection *c, uchar *p, int n)
1246
{
1247
	md5(p, n, 0, &c->handhash.md5);
1248
	sha1(p, n, 0, &c->handhash.sha1);
1249
	if(c->version >= TLS12Version)
1250
		sha2_256(p, n, 0, &c->handhash.sha2_256);
1251
}
2 - 1252
 
1253
static int
1254
msgSend(TlsConnection *c, Msg *m, int act)
1255
{
26 7u83 1256
	uchar *p, *e; // sendp = start of new message;  p = write pointer; e = end pointer
1257
	int n, i;
2 - 1258
 
26 7u83 1259
	p = c->sendp;
1260
	e = &c->buf[sizeof(c->buf)];
2 - 1261
	if(c->trace)
26 7u83 1262
		c->trace("send %s", msgPrint((char*)p, e - p, m));
2 - 1263
 
1264
	p[0] = m->tag;	// header - fill in size later
1265
	p += 4;
1266
 
1267
	switch(m->tag) {
1268
	default:
1269
		tlsError(c, EInternalError, "can't encode a %d", m->tag);
1270
		goto Err;
1271
	case HClientHello:
26 7u83 1272
		if(p+2+RandomSize > e)
1273
			goto Overflow;
1274
		put16(p, m->u.clientHello.version), p += 2;
2 - 1275
		memmove(p, m->u.clientHello.random, RandomSize);
1276
		p += RandomSize;
1277
 
26 7u83 1278
		if(p+1+(n = m->u.clientHello.sid->len) > e)
1279
			goto Overflow;
1280
		*p++ = n;
1281
		memmove(p, m->u.clientHello.sid->data, n);
1282
		p += n;
2 - 1283
 
26 7u83 1284
		if(p+2+(n = m->u.clientHello.ciphers->len) > e)
1285
			goto Overflow;
1286
		put16(p, n*2), p += 2;
1287
		for(i=0; i<n; i++)
1288
			put16(p, m->u.clientHello.ciphers->data[i]), p += 2;
2 - 1289
 
26 7u83 1290
		if(p+1+(n = m->u.clientHello.compressors->len) > e)
1291
			goto Overflow;
1292
		*p++ = n;
1293
		memmove(p, m->u.clientHello.compressors->data, n);
1294
		p += n;
1295
 
1296
		if(m->u.clientHello.extensions == nil
1297
		|| (n = m->u.clientHello.extensions->len) == 0)
1298
			break;
1299
		if(p+2+n > e)
1300
			goto Overflow;
1301
		put16(p, n), p += 2;
1302
		memmove(p, m->u.clientHello.extensions->data, n);
1303
		p += n;
2 - 1304
		break;
1305
	case HServerHello:
26 7u83 1306
		if(p+2+RandomSize > e)
1307
			goto Overflow;
1308
		put16(p, m->u.serverHello.version), p += 2;
2 - 1309
		memmove(p, m->u.serverHello.random, RandomSize);
1310
		p += RandomSize;
1311
 
26 7u83 1312
		if(p+1+(n = m->u.serverHello.sid->len) > e)
1313
			goto Overflow;
1314
		*p++ = n;
1315
		memmove(p, m->u.serverHello.sid->data, n);
1316
		p += n;
2 - 1317
 
26 7u83 1318
		if(p+2+1 > e)
1319
			goto Overflow;
1320
		put16(p, m->u.serverHello.cipher), p += 2;
1321
		*p++ = m->u.serverHello.compressor;
1322
 
1323
		if(m->u.serverHello.extensions == nil
1324
		|| (n = m->u.serverHello.extensions->len) == 0)
1325
			break;
1326
		if(p+2+n > e)
1327
			goto Overflow;
1328
		put16(p, n), p += 2;
1329
		memmove(p, m->u.serverHello.extensions->data, n);
1330
		p += n;
2 - 1331
		break;
1332
	case HServerHelloDone:
1333
		break;
1334
	case HCertificate:
26 7u83 1335
		n = 0;
2 - 1336
		for(i = 0; i < m->u.certificate.ncert; i++)
26 7u83 1337
			n += 3 + m->u.certificate.certs[i]->len;
1338
		if(p+3+n > e)
1339
			goto Overflow;
1340
		put24(p, n), p += 3;
2 - 1341
		for(i = 0; i < m->u.certificate.ncert; i++){
26 7u83 1342
			n = m->u.certificate.certs[i]->len;
1343
			put24(p, n), p += 3;
1344
			memmove(p, m->u.certificate.certs[i]->data, n);
1345
			p += n;
2 - 1346
		}
1347
		break;
26 7u83 1348
	case HCertificateVerify:
1349
		if(p+2+2+(n = m->u.certificateVerify.signature->len) > e)
1350
			goto Overflow;
1351
		if(m->u.certificateVerify.sigalg != 0)
1352
			put16(p, m->u.certificateVerify.sigalg), p += 2;
1353
		put16(p, n), p += 2;
1354
		memmove(p, m->u.certificateVerify.signature->data, n);
1355
		p += n;
1356
		break;
1357
	case HServerKeyExchange:
1358
		if(m->u.serverKeyExchange.pskid != nil){
1359
			if(p+2+(n = m->u.serverKeyExchange.pskid->len) > e)
1360
				goto Overflow;
1361
			put16(p, n), p += 2;
1362
			memmove(p, m->u.serverKeyExchange.pskid->data, n);
1363
			p += n;
1364
		}
1365
		if(m->u.serverKeyExchange.dh_parameters == nil)
1366
			break;
1367
		if(p+(n = m->u.serverKeyExchange.dh_parameters->len) > e)
1368
			goto Overflow;
1369
		memmove(p, m->u.serverKeyExchange.dh_parameters->data, n);
1370
		p += n;
1371
		if(m->u.serverKeyExchange.dh_signature == nil)
1372
			break;
1373
		if(p+2+2+(n = m->u.serverKeyExchange.dh_signature->len) > e)
1374
			goto Overflow;
1375
		if(c->version >= TLS12Version)
1376
			put16(p, m->u.serverKeyExchange.sigalg), p += 2;
1377
		put16(p, n), p += 2;
1378
		memmove(p, m->u.serverKeyExchange.dh_signature->data, n);
1379
		p += n;
1380
		break;
2 - 1381
	case HClientKeyExchange:
26 7u83 1382
		if(m->u.clientKeyExchange.pskid != nil){
1383
			if(p+2+(n = m->u.clientKeyExchange.pskid->len) > e)
1384
				goto Overflow;
1385
			put16(p, n), p += 2;
1386
			memmove(p, m->u.clientKeyExchange.pskid->data, n);
1387
			p += n;
2 - 1388
		}
26 7u83 1389
		if(m->u.clientKeyExchange.key == nil)
1390
			break;
1391
		if(p+2+(n = m->u.clientKeyExchange.key->len) > e)
1392
			goto Overflow;
1393
		if(isECDHE(c->cipher))
1394
			*p++ = n;
1395
		else if(isDHE(c->cipher) || c->version != SSL3Version)
1396
			put16(p, n), p += 2;
2 - 1397
		memmove(p, m->u.clientKeyExchange.key->data, n);
1398
		p += n;
1399
		break;
1400
	case HFinished:
26 7u83 1401
		if(p+m->u.finished.n > e)
1402
			goto Overflow;
2 - 1403
		memmove(p, m->u.finished.verify, m->u.finished.n);
1404
		p += m->u.finished.n;
1405
		break;
1406
	}
1407
 
1408
	// go back and fill in size
26 7u83 1409
	n = p - c->sendp;
1410
	put24(c->sendp+1, n-4);
2 - 1411
 
1412
	// remember hash of Handshake messages
26 7u83 1413
	if(m->tag != HHelloRequest)
1414
		msgHash(c, c->sendp, n);
2 - 1415
 
26 7u83 1416
	c->sendp = p;
2 - 1417
	if(act == AFlush){
26 7u83 1418
		c->sendp = c->buf;
1419
		if(write(c->hand, c->buf, p - c->buf) < 0){
2 - 1420
			fprint(2, "write error: %r\n");
1421
			goto Err;
1422
		}
1423
	}
1424
	msgClear(m);
1425
	return 1;
26 7u83 1426
Overflow:
1427
	tlsError(c, EInternalError, "not enougth send buffer for message (%d)", m->tag);
2 - 1428
Err:
1429
	msgClear(m);
1430
	return 0;
1431
}
1432
 
1433
static uchar*
1434
tlsReadN(TlsConnection *c, int n)
1435
{
26 7u83 1436
	uchar *p, *w, *e;
2 - 1437
 
26 7u83 1438
	e = &c->buf[sizeof(c->buf)];
1439
	p = e - n;
1440
	if(n > sizeof(c->buf) || p < c->sendp){
1441
		tlsError(c, EDecodeError, "handshake message too long %d", n);
1442
		return nil;
2 - 1443
	}
26 7u83 1444
	for(w = p; w < e; w += n)
1445
		if((n = read(c->hand, w, e - w)) <= 0)
1446
			return nil;
2 - 1447
	return p;
1448
}
1449
 
1450
static int
1451
msgRecv(TlsConnection *c, Msg *m)
1452
{
26 7u83 1453
	uchar *p, *s;
1454
	int type, n, nn, i;
2 - 1455
 
26 7u83 1456
	msgClear(m);
2 - 1457
	for(;;) {
1458
		p = tlsReadN(c, 4);
1459
		if(p == nil)
1460
			return 0;
1461
		type = p[0];
1462
		n = get24(p+1);
1463
 
1464
		if(type != HHelloRequest)
1465
			break;
1466
		if(n != 0) {
1467
			tlsError(c, EDecodeError, "invalid hello request during handshake");
1468
			return 0;
1469
		}
1470
	}
1471
 
1472
	if(type == HSSL2ClientHello){
1473
		/* Cope with an SSL3 ClientHello expressed in SSL2 record format.
1474
			This is sent by some clients that we must interoperate
1475
			with, such as Java's JSSE and Microsoft's Internet Explorer. */
26 7u83 1476
		int nsid, nrandom, nciph;
1477
 
2 - 1478
		p = tlsReadN(c, n);
1479
		if(p == nil)
1480
			return 0;
26 7u83 1481
		msgHash(c, p, n);
2 - 1482
		m->tag = HClientHello;
1483
		if(n < 22)
1484
			goto Short;
1485
		m->u.clientHello.version = get16(p+1);
1486
		p += 3;
1487
		n -= 3;
1488
		nn = get16(p); /* cipher_spec_len */
1489
		nsid = get16(p + 2);
1490
		nrandom = get16(p + 4);
1491
		p += 6;
1492
		n -= 6;
1493
		if(nsid != 0 	/* no sid's, since shouldn't restart using ssl2 header */
26 7u83 1494
		|| nrandom < 16 || nn % 3 || n - nrandom < nn)
2 - 1495
			goto Err;
1496
		/* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
1497
		nciph = 0;
1498
		for(i = 0; i < nn; i += 3)
1499
			if(p[i] == 0)
1500
				nciph++;
1501
		m->u.clientHello.ciphers = newints(nciph);
1502
		nciph = 0;
1503
		for(i = 0; i < nn; i += 3)
1504
			if(p[i] == 0)
1505
				m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1506
		p += nn;
1507
		m->u.clientHello.sid = makebytes(nil, 0);
1508
		if(nrandom > RandomSize)
1509
			nrandom = RandomSize;
1510
		memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1511
		memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1512
		m->u.clientHello.compressors = newbytes(1);
1513
		m->u.clientHello.compressors->data[0] = CompressionNull;
1514
		goto Ok;
1515
	}
26 7u83 1516
	msgHash(c, p, 4);
2 - 1517
 
1518
	p = tlsReadN(c, n);
1519
	if(p == nil)
1520
		return 0;
1521
 
26 7u83 1522
	msgHash(c, p, n);
2 - 1523
 
1524
	m->tag = type;
1525
 
1526
	switch(type) {
1527
	default:
1528
		tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1529
		goto Err;
1530
	case HClientHello:
1531
		if(n < 2)
1532
			goto Short;
1533
		m->u.clientHello.version = get16(p);
26 7u83 1534
		p += 2, n -= 2;
2 - 1535
 
1536
		if(n < RandomSize)
1537
			goto Short;
1538
		memmove(m->u.clientHello.random, p, RandomSize);
26 7u83 1539
		p += RandomSize, n -= RandomSize;
2 - 1540
		if(n < 1 || n < p[0]+1)
1541
			goto Short;
1542
		m->u.clientHello.sid = makebytes(p+1, p[0]);
1543
		p += m->u.clientHello.sid->len+1;
1544
		n -= m->u.clientHello.sid->len+1;
1545
 
1546
		if(n < 2)
1547
			goto Short;
1548
		nn = get16(p);
26 7u83 1549
		p += 2, n -= 2;
2 - 1550
 
1551
		if((nn & 1) || n < nn || nn < 2)
1552
			goto Short;
1553
		m->u.clientHello.ciphers = newints(nn >> 1);
1554
		for(i = 0; i < nn; i += 2)
1555
			m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
26 7u83 1556
		p += nn, n -= nn;
2 - 1557
 
1558
		if(n < 1 || n < p[0]+1 || p[0] == 0)
1559
			goto Short;
1560
		nn = p[0];
26 7u83 1561
		m->u.clientHello.compressors = makebytes(p+1, nn);
1562
		p += nn + 1, n -= nn + 1;
1563
 
1564
		if(n < 2)
1565
			break;
1566
		nn = get16(p);
1567
		if(nn > n-2)
1568
			goto Short;
1569
		m->u.clientHello.extensions = makebytes(p+2, nn);
1570
		n -= nn + 2;
2 - 1571
		break;
1572
	case HServerHello:
1573
		if(n < 2)
1574
			goto Short;
1575
		m->u.serverHello.version = get16(p);
26 7u83 1576
		p += 2, n -= 2;
2 - 1577
 
1578
		if(n < RandomSize)
1579
			goto Short;
1580
		memmove(m->u.serverHello.random, p, RandomSize);
26 7u83 1581
		p += RandomSize, n -= RandomSize;
2 - 1582
 
1583
		if(n < 1 || n < p[0]+1)
1584
			goto Short;
1585
		m->u.serverHello.sid = makebytes(p+1, p[0]);
1586
		p += m->u.serverHello.sid->len+1;
1587
		n -= m->u.serverHello.sid->len+1;
1588
 
1589
		if(n < 3)
1590
			goto Short;
1591
		m->u.serverHello.cipher = get16(p);
1592
		m->u.serverHello.compressor = p[2];
26 7u83 1593
		p += 3, n -= 3;
1594
 
1595
		if(n < 2)
1596
			break;
1597
		nn = get16(p);
1598
		if(nn > n-2)
1599
			goto Short;
1600
		m->u.serverHello.extensions = makebytes(p+2, nn);
1601
		n -= nn + 2;
2 - 1602
		break;
1603
	case HCertificate:
1604
		if(n < 3)
1605
			goto Short;
1606
		nn = get24(p);
26 7u83 1607
		p += 3, n -= 3;
1608
		if(nn == 0 && n > 0)
2 - 1609
			goto Short;
1610
		/* certs */
1611
		i = 0;
1612
		while(n > 0) {
1613
			if(n < 3)
1614
				goto Short;
1615
			nn = get24(p);
26 7u83 1616
			p += 3, n -= 3;
2 - 1617
			if(nn > n)
1618
				goto Short;
1619
			m->u.certificate.ncert = i+1;
26 7u83 1620
			m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes*));
2 - 1621
			m->u.certificate.certs[i] = makebytes(p, nn);
26 7u83 1622
			p += nn, n -= nn;
2 - 1623
			i++;
1624
		}
1625
		break;
1626
	case HCertificateRequest:
1627
		if(n < 1)
1628
			goto Short;
1629
		nn = p[0];
26 7u83 1630
		p++, n--;
1631
		if(nn > n)
2 - 1632
			goto Short;
1633
		m->u.certificateRequest.types = makebytes(p, nn);
26 7u83 1634
		p += nn, n -= nn;
1635
		if(c->version >= TLS12Version){
1636
			if(n < 2)
1637
				goto Short;
1638
			nn = get16(p);
1639
			p += 2, n -= 2;
1640
			if(nn & 1)
1641
				goto Short;
1642
			m->u.certificateRequest.sigalgs = newints(nn>>1);
1643
			for(i = 0; i < nn; i += 2)
1644
				m->u.certificateRequest.sigalgs->data[i >> 1] = get16(&p[i]);
1645
			p += nn, n -= nn;
1646
 
1647
		}
2 - 1648
		if(n < 2)
1649
			goto Short;
1650
		nn = get16(p);
26 7u83 1651
		p += 2, n -= 2;
2 - 1652
		/* nn == 0 can happen; yahoo's servers do it */
1653
		if(nn != n)
1654
			goto Short;
1655
		/* cas */
1656
		i = 0;
1657
		while(n > 0) {
1658
			if(n < 2)
1659
				goto Short;
1660
			nn = get16(p);
26 7u83 1661
			p += 2, n -= 2;
2 - 1662
			if(nn < 1 || nn > n)
1663
				goto Short;
1664
			m->u.certificateRequest.nca = i+1;
1665
			m->u.certificateRequest.cas = erealloc(
26 7u83 1666
				m->u.certificateRequest.cas, (i+1)*sizeof(Bytes*));
2 - 1667
			m->u.certificateRequest.cas[i] = makebytes(p, nn);
26 7u83 1668
			p += nn, n -= nn;
2 - 1669
			i++;
1670
		}
1671
		break;
1672
	case HServerHelloDone:
1673
		break;
26 7u83 1674
	case HServerKeyExchange:
1675
		if(isPSK(c->cipher)){
1676
			if(n < 2)
1677
				goto Short;
1678
			nn = get16(p);
1679
			p += 2, n -= 2;
1680
			if(nn > n)
1681
				goto Short;
1682
			m->u.serverKeyExchange.pskid = makebytes(p, nn);
1683
			p += nn, n -= nn;
1684
			if(n == 0)
1685
				break;
1686
		}
1687
		if(n < 2)
1688
			goto Short;
1689
		s = p;
1690
		if(isECDHE(c->cipher)){
1691
			nn = *p;
1692
			p++, n--;
1693
			if(nn != 3 || nn > n) /* not a named curve */
1694
				goto Short;
1695
			nn = get16(p);
1696
			p += 2, n -= 2;
1697
			m->u.serverKeyExchange.curve = nn;
1698
 
1699
			nn = *p++, n--;
1700
			if(nn < 1 || nn > n)
1701
				goto Short;
1702
			m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);
1703
			p += nn, n -= nn;
1704
		}else if(isDHE(c->cipher)){
1705
			nn = get16(p);
1706
			p += 2, n -= 2;
1707
			if(nn < 1 || nn > n)
1708
				goto Short;
1709
			m->u.serverKeyExchange.dh_p = makebytes(p, nn);
1710
			p += nn, n -= nn;
1711
 
1712
			if(n < 2)
1713
				goto Short;
1714
			nn = get16(p);
1715
			p += 2, n -= 2;
1716
			if(nn < 1 || nn > n)
1717
				goto Short;
1718
			m->u.serverKeyExchange.dh_g = makebytes(p, nn);
1719
			p += nn, n -= nn;
1720
 
1721
			if(n < 2)
1722
				goto Short;
1723
			nn = get16(p);
1724
			p += 2, n -= 2;
1725
			if(nn < 1 || nn > n)
1726
				goto Short;
1727
			m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);
1728
			p += nn, n -= nn;
1729
		} else {
1730
			/* should not happen */
1731
			goto Short;
1732
		}
1733
		m->u.serverKeyExchange.dh_parameters = makebytes(s, p - s);
1734
		if(n >= 2){
1735
			m->u.serverKeyExchange.sigalg = 0;
1736
			if(c->version >= TLS12Version){
1737
				m->u.serverKeyExchange.sigalg = get16(p);
1738
				p += 2, n -= 2;
1739
				if(n < 2)
1740
					goto Short;
1741
			}
1742
			nn = get16(p);
1743
			p += 2, n -= 2;
1744
			if(nn > 0 && nn <= n){
1745
				m->u.serverKeyExchange.dh_signature = makebytes(p, nn);
1746
				n -= nn;
1747
			}
1748
		}
1749
		break;		
2 - 1750
	case HClientKeyExchange:
26 7u83 1751
		if(isPSK(c->cipher)){
2 - 1752
			if(n < 2)
1753
				goto Short;
1754
			nn = get16(p);
26 7u83 1755
			p += 2, n -= 2;
1756
			if(nn > n)
1757
				goto Short;
1758
			m->u.clientKeyExchange.pskid = makebytes(p, nn);
1759
			p += nn, n -= nn;
1760
			if(n == 0)
1761
				break;
2 - 1762
		}
26 7u83 1763
		if(n < 2)
1764
			goto Short;
1765
		if(isECDHE(c->cipher))
1766
			nn = *p++, n--;
1767
		else if(isDHE(c->cipher) || c->version != SSL3Version)
1768
			nn = get16(p), p += 2, n -= 2;
1769
		else
1770
			nn = n;
2 - 1771
		if(n < nn)
1772
			goto Short;
1773
		m->u.clientKeyExchange.key = makebytes(p, nn);
1774
		n -= nn;
1775
		break;
1776
	case HFinished:
1777
		m->u.finished.n = c->finished.n;
1778
		if(n < m->u.finished.n)
1779
			goto Short;
1780
		memmove(m->u.finished.verify, p, m->u.finished.n);
1781
		n -= m->u.finished.n;
1782
		break;
1783
	}
1784
 
26 7u83 1785
	if(n != 0 && type != HClientHello && type != HServerHello)
2 - 1786
		goto Short;
1787
Ok:
26 7u83 1788
	if(c->trace)
1789
		c->trace("recv %s", msgPrint((char*)c->sendp, &c->buf[sizeof(c->buf)] - c->sendp, m));
2 - 1790
	return 1;
1791
Short:
26 7u83 1792
	tlsError(c, EDecodeError, "handshake message (%d) has invalid length", type);
2 - 1793
Err:
1794
	msgClear(m);
1795
	return 0;
1796
}
1797
 
1798
static void
1799
msgClear(Msg *m)
1800
{
1801
	int i;
1802
 
1803
	switch(m->tag) {
1804
	case HHelloRequest:
1805
		break;
1806
	case HClientHello:
1807
		freebytes(m->u.clientHello.sid);
1808
		freeints(m->u.clientHello.ciphers);
1809
		freebytes(m->u.clientHello.compressors);
26 7u83 1810
		freebytes(m->u.clientHello.extensions);
2 - 1811
		break;
1812
	case HServerHello:
26 7u83 1813
		freebytes(m->u.serverHello.sid);
1814
		freebytes(m->u.serverHello.extensions);
2 - 1815
		break;
1816
	case HCertificate:
1817
		for(i=0; i<m->u.certificate.ncert; i++)
1818
			freebytes(m->u.certificate.certs[i]);
1819
		free(m->u.certificate.certs);
1820
		break;
1821
	case HCertificateRequest:
1822
		freebytes(m->u.certificateRequest.types);
26 7u83 1823
		freeints(m->u.certificateRequest.sigalgs);
2 - 1824
		for(i=0; i<m->u.certificateRequest.nca; i++)
1825
			freebytes(m->u.certificateRequest.cas[i]);
1826
		free(m->u.certificateRequest.cas);
1827
		break;
26 7u83 1828
	case HCertificateVerify:
1829
		freebytes(m->u.certificateVerify.signature);
1830
		break;
2 - 1831
	case HServerHelloDone:
1832
		break;
26 7u83 1833
	case HServerKeyExchange:
1834
		freebytes(m->u.serverKeyExchange.pskid);
1835
		freebytes(m->u.serverKeyExchange.dh_p);
1836
		freebytes(m->u.serverKeyExchange.dh_g);
1837
		freebytes(m->u.serverKeyExchange.dh_Ys);
1838
		freebytes(m->u.serverKeyExchange.dh_parameters);
1839
		freebytes(m->u.serverKeyExchange.dh_signature);
1840
		break;
2 - 1841
	case HClientKeyExchange:
26 7u83 1842
		freebytes(m->u.clientKeyExchange.pskid);
2 - 1843
		freebytes(m->u.clientKeyExchange.key);
1844
		break;
1845
	case HFinished:
1846
		break;
1847
	}
1848
	memset(m, 0, sizeof(Msg));
1849
}
1850
 
1851
static char *
1852
bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1853
{
1854
	int i;
1855
 
1856
	if(s0)
1857
		bs = seprint(bs, be, "%s", s0);
1858
	if(b == nil)
1859
		bs = seprint(bs, be, "nil");
26 7u83 1860
	else {
1861
		bs = seprint(bs, be, "<%d> [ ", b->len);
2 - 1862
		for(i=0; i<b->len; i++)
1863
			bs = seprint(bs, be, "%.2x ", b->data[i]);
26 7u83 1864
		bs = seprint(bs, be, "]");
1865
	}
2 - 1866
	if(s1)
1867
		bs = seprint(bs, be, "%s", s1);
1868
	return bs;
1869
}
1870
 
1871
static char *
1872
intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1873
{
1874
	int i;
1875
 
1876
	if(s0)
1877
		bs = seprint(bs, be, "%s", s0);
1878
	if(b == nil)
1879
		bs = seprint(bs, be, "nil");
26 7u83 1880
	else {
1881
		bs = seprint(bs, be, "[ ");
2 - 1882
		for(i=0; i<b->len; i++)
1883
			bs = seprint(bs, be, "%x ", b->data[i]);
26 7u83 1884
		bs = seprint(bs, be, "]");
1885
	}
2 - 1886
	if(s1)
1887
		bs = seprint(bs, be, "%s", s1);
1888
	return bs;
1889
}
1890
 
1891
static char*
1892
msgPrint(char *buf, int n, Msg *m)
1893
{
1894
	int i;
1895
	char *bs = buf, *be = buf+n;
1896
 
1897
	switch(m->tag) {
1898
	default:
1899
		bs = seprint(bs, be, "unknown %d\n", m->tag);
1900
		break;
1901
	case HClientHello:
1902
		bs = seprint(bs, be, "ClientHello\n");
1903
		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1904
		bs = seprint(bs, be, "\trandom: ");
1905
		for(i=0; i<RandomSize; i++)
1906
			bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1907
		bs = seprint(bs, be, "\n");
1908
		bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1909
		bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1910
		bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
26 7u83 1911
		if(m->u.clientHello.extensions != nil)
1912
			bs = bytesPrint(bs, be, "\textensions: ", m->u.clientHello.extensions, "\n");
2 - 1913
		break;
1914
	case HServerHello:
1915
		bs = seprint(bs, be, "ServerHello\n");
1916
		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1917
		bs = seprint(bs, be, "\trandom: ");
1918
		for(i=0; i<RandomSize; i++)
1919
			bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1920
		bs = seprint(bs, be, "\n");
1921
		bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1922
		bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1923
		bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
26 7u83 1924
		if(m->u.serverHello.extensions != nil)
1925
			bs = bytesPrint(bs, be, "\textensions: ", m->u.serverHello.extensions, "\n");
2 - 1926
		break;
1927
	case HCertificate:
1928
		bs = seprint(bs, be, "Certificate\n");
1929
		for(i=0; i<m->u.certificate.ncert; i++)
1930
			bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1931
		break;
1932
	case HCertificateRequest:
1933
		bs = seprint(bs, be, "CertificateRequest\n");
1934
		bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
26 7u83 1935
		if(m->u.certificateRequest.sigalgs != nil)
1936
			bs = intsPrint(bs, be, "\tsigalgs: ", m->u.certificateRequest.sigalgs, "\n");
2 - 1937
		bs = seprint(bs, be, "\tcertificateauthorities\n");
1938
		for(i=0; i<m->u.certificateRequest.nca; i++)
1939
			bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1940
		break;
26 7u83 1941
	case HCertificateVerify:
1942
		bs = seprint(bs, be, "HCertificateVerify\n");
1943
		if(m->u.certificateVerify.sigalg != 0)
1944
			bs = seprint(bs, be, "\tsigalg: %.4x\n", m->u.certificateVerify.sigalg);
1945
		bs = bytesPrint(bs, be, "\tsignature: ", m->u.certificateVerify.signature,"\n");
1946
		break;	
2 - 1947
	case HServerHelloDone:
1948
		bs = seprint(bs, be, "ServerHelloDone\n");
1949
		break;
26 7u83 1950
	case HServerKeyExchange:
1951
		bs = seprint(bs, be, "HServerKeyExchange\n");
1952
		if(m->u.serverKeyExchange.pskid != nil)
1953
			bs = bytesPrint(bs, be, "\tpskid: ", m->u.serverKeyExchange.pskid, "\n");
1954
		if(m->u.serverKeyExchange.dh_parameters == nil)
1955
			break;
1956
		if(m->u.serverKeyExchange.curve != 0){
1957
			bs = seprint(bs, be, "\tcurve: %.4x\n", m->u.serverKeyExchange.curve);
1958
		} else {
1959
			bs = bytesPrint(bs, be, "\tdh_p: ", m->u.serverKeyExchange.dh_p, "\n");
1960
			bs = bytesPrint(bs, be, "\tdh_g: ", m->u.serverKeyExchange.dh_g, "\n");
1961
		}
1962
		bs = bytesPrint(bs, be, "\tdh_Ys: ", m->u.serverKeyExchange.dh_Ys, "\n");
1963
		if(m->u.serverKeyExchange.sigalg != 0)
1964
			bs = seprint(bs, be, "\tsigalg: %.4x\n", m->u.serverKeyExchange.sigalg);
1965
		bs = bytesPrint(bs, be, "\tdh_parameters: ", m->u.serverKeyExchange.dh_parameters, "\n");
1966
		bs = bytesPrint(bs, be, "\tdh_signature: ", m->u.serverKeyExchange.dh_signature, "\n");
1967
		break;
2 - 1968
	case HClientKeyExchange:
1969
		bs = seprint(bs, be, "HClientKeyExchange\n");
26 7u83 1970
		if(m->u.clientKeyExchange.pskid != nil)
1971
			bs = bytesPrint(bs, be, "\tpskid: ", m->u.clientKeyExchange.pskid, "\n");
1972
		if(m->u.clientKeyExchange.key != nil)
1973
			bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
2 - 1974
		break;
1975
	case HFinished:
1976
		bs = seprint(bs, be, "HFinished\n");
1977
		for(i=0; i<m->u.finished.n; i++)
1978
			bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1979
		bs = seprint(bs, be, "\n");
1980
		break;
1981
	}
1982
	USED(bs);
1983
	return buf;
1984
}
1985
 
1986
static void
1987
tlsError(TlsConnection *c, int err, char *fmt, ...)
1988
{
1989
	char msg[512];
1990
	va_list arg;
1991
 
1992
	va_start(arg, fmt);
1993
	vseprint(msg, msg+sizeof(msg), fmt, arg);
1994
	va_end(arg);
1995
	if(c->trace)
1996
		c->trace("tlsError: %s\n", msg);
26 7u83 1997
	if(c->erred)
2 - 1998
		fprint(2, "double error: %r, %s", msg);
1999
	else
26 7u83 2000
		errstr(msg, sizeof(msg));
2 - 2001
	c->erred = 1;
2002
	fprint(c->ctl, "alert %d", err);
2003
}
2004
 
2005
// commit to specific version number
2006
static int
2007
setVersion(TlsConnection *c, int version)
2008
{
26 7u83 2009
	if(version > MaxProtoVersion || version < MinProtoVersion)
2 - 2010
		return -1;
2011
	if(version > c->version)
2012
		version = c->version;
2013
	if(version == SSL3Version) {
2014
		c->version = version;
2015
		c->finished.n = SSL3FinishedLen;
26 7u83 2016
	}else {
2 - 2017
		c->version = version;
2018
		c->finished.n = TLSFinishedLen;
26 7u83 2019
	}
2 - 2020
	return fprint(c->ctl, "version 0x%x", version);
2021
}
2022
 
2023
// confirm that received Finished message matches the expected value
2024
static int
2025
finishedMatch(TlsConnection *c, Finished *f)
2026
{
26 7u83 2027
	return tsmemcmp(f->verify, c->finished.verify, f->n) == 0;
2 - 2028
}
2029
 
2030
// free memory associated with TlsConnection struct
2031
//		(but don't close the TLS channel itself)
2032
static void
2033
tlsConnectionFree(TlsConnection *c)
2034
{
26 7u83 2035
	if(c == nil)
2036
		return;
2037
 
2038
	dh_finish(&c->sec->dh, nil);
2039
 
2040
	mpfree(c->sec->ec.Q.x);
2041
	mpfree(c->sec->ec.Q.y);
2042
	mpfree(c->sec->ec.Q.d);
2043
	ecdomfree(&c->sec->ec.dom);
2044
 
2045
	factotum_rsa_close(c->sec->rpc);
2046
	rsapubfree(c->sec->rsapub);
2 - 2047
	freebytes(c->cert);
26 7u83 2048
 
2049
	memset(c, 0, sizeof(*c));
2 - 2050
	free(c);
2051
}
2052
 
2053
 
2054
//================= cipher choices ========================
2055
 
26 7u83 2056
static int
2057
isDHE(int tlsid)
2 - 2058
{
26 7u83 2059
	switch(tlsid){
2060
	case TLS_DHE_RSA_WITH_AES_128_GCM_SHA256:
2061
	case TLS_DHE_RSA_WITH_AES_128_CBC_SHA256:
2062
 	case TLS_DHE_RSA_WITH_AES_128_CBC_SHA:
2063
 	case TLS_DHE_RSA_WITH_AES_256_CBC_SHA:
2064
 	case TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA:
2065
	case TLS_DHE_RSA_WITH_CHACHA20_POLY1305:
2066
	case GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305:
2067
		return 1;
2068
	}
2069
	return 0;
2070
}
2 - 2071
 
2072
static int
26 7u83 2073
isECDHE(int tlsid)
2074
{
2075
	switch(tlsid){
2076
	case TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305:
2077
	case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305:
2078
 
2079
	case GOOGLE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305:
2080
	case GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305:
2081
 
2082
	case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
2083
	case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
2084
 
2085
	case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256:
2086
	case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
2087
	case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:
2088
	case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
2089
		return 1;
2090
	}
2091
	return 0;
2092
}
2093
 
2094
static int
2095
isPSK(int tlsid)
2096
{
2097
	switch(tlsid){
2098
	case TLS_PSK_WITH_CHACHA20_POLY1305:
2099
	case TLS_PSK_WITH_AES_128_CBC_SHA256:
2100
	case TLS_PSK_WITH_AES_128_CBC_SHA:
2101
		return 1;
2102
	}
2103
	return 0;
2104
}
2105
 
2106
static int
2107
isECDSA(int tlsid)
2108
{
2109
	switch(tlsid){
2110
	case TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305:
2111
	case GOOGLE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305:
2112
	case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
2113
	case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256:
2114
		return 1;
2115
	}
2116
	return 0;
2117
}
2118
 
2119
static int
2 - 2120
setAlgs(TlsConnection *c, int a)
2121
{
2122
	int i;
2123
 
2124
	for(i = 0; i < nelem(cipherAlgs); i++){
2125
		if(cipherAlgs[i].tlsid == a){
26 7u83 2126
			c->cipher = a;
2 - 2127
			c->enc = cipherAlgs[i].enc;
2128
			c->digest = cipherAlgs[i].digest;
2129
			c->nsecret = cipherAlgs[i].nsecret;
2130
			if(c->nsecret > MaxKeyData)
2131
				return 0;
2132
			return 1;
2133
		}
2134
	}
2135
	return 0;
2136
}
2137
 
2138
static int
26 7u83 2139
okCipher(Ints *cv, int ispsk)
2 - 2140
{
26 7u83 2141
	int i, c;
2 - 2142
 
26 7u83 2143
	for(i = 0; i < nelem(cipherAlgs); i++) {
2144
		c = cipherAlgs[i].tlsid;
2145
		if(!cipherAlgs[i].ok || isECDSA(c) || isDHE(c) || isPSK(c) != ispsk)
2146
			continue;
2147
		if(lookupid(cv, c) >= 0)
2148
			return c;
2 - 2149
	}
2150
	return -1;
2151
}
2152
 
2153
static int
2154
okCompression(Bytes *cv)
2155
{
26 7u83 2156
	int i, c;
2 - 2157
 
26 7u83 2158
	for(i = 0; i < nelem(compressors); i++) {
2159
		c = compressors[i];
2160
		if(memchr(cv->data, c, cv->len) != nil)
2161
			return c;
2 - 2162
	}
2163
	return -1;
2164
}
2165
 
2166
/*
 * ciphLock serializes initCiphers.  nciphers counts the cipherAlgs
 * entries usable with the kernel's TLS device; it stays 0 until a
 * probe succeeds.
 */
static Lock ciphLock;
static int nciphers;
2168
 
2169
static int
2170
initCiphers(void)
2171
{
2172
	enum {MaxAlgF = 1024, MaxAlgs = 10};
2173
	char s[MaxAlgF], *flds[MaxAlgs];
2174
	int i, j, n, ok;
2175
 
2176
	lock(&ciphLock);
2177
	if(nciphers){
2178
		unlock(&ciphLock);
2179
		return nciphers;
2180
	}
2181
	j = open("#a/tls/encalgs", OREAD);
2182
	if(j < 0){
2183
		werrstr("can't open #a/tls/encalgs: %r");
26 7u83 2184
		goto out;
2 - 2185
	}
2186
	n = read(j, s, MaxAlgF-1);
2187
	close(j);
2188
	if(n <= 0){
2189
		werrstr("nothing in #a/tls/encalgs: %r");
26 7u83 2190
		goto out;
2 - 2191
	}
2192
	s[n] = 0;
2193
	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
2194
	for(i = 0; i < nelem(cipherAlgs); i++){
2195
		ok = 0;
2196
		for(j = 0; j < n; j++){
2197
			if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
2198
				ok = 1;
2199
				break;
2200
			}
2201
		}
2202
		cipherAlgs[i].ok = ok;
2203
	}
2204
 
2205
	j = open("#a/tls/hashalgs", OREAD);
2206
	if(j < 0){
2207
		werrstr("can't open #a/tls/hashalgs: %r");
26 7u83 2208
		goto out;
2 - 2209
	}
2210
	n = read(j, s, MaxAlgF-1);
2211
	close(j);
2212
	if(n <= 0){
2213
		werrstr("nothing in #a/tls/hashalgs: %r");
26 7u83 2214
		goto out;
2 - 2215
	}
2216
	s[n] = 0;
2217
	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
2218
	for(i = 0; i < nelem(cipherAlgs); i++){
2219
		ok = 0;
2220
		for(j = 0; j < n; j++){
2221
			if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
2222
				ok = 1;
2223
				break;
2224
			}
2225
		}
2226
		cipherAlgs[i].ok &= ok;
2227
		if(cipherAlgs[i].ok)
2228
			nciphers++;
2229
	}
26 7u83 2230
out:
2 - 2231
	unlock(&ciphLock);
2232
	return nciphers;
2233
}
2234
 
2235
static Ints*
26 7u83 2236
makeciphers(int ispsk)
2 - 2237
{
2238
	Ints *is;
2239
	int i, j;
2240
 
2241
	is = newints(nciphers);
2242
	j = 0;
26 7u83 2243
	for(i = 0; i < nelem(cipherAlgs); i++)
2244
		if(cipherAlgs[i].ok && isPSK(cipherAlgs[i].tlsid) == ispsk)
2 - 2245
			is->data[j++] = cipherAlgs[i].tlsid;
26 7u83 2246
	is->len = j;
2 - 2247
	return is;
2248
}
2249
 
2250
 
2251
//================= security functions ========================
2252
 
26 7u83 2253
// given a public key, set up connection to factotum
2254
// for using corresponding private key
2 - 2255
static AuthRpc*
26 7u83 2256
factotum_rsa_open(RSApub *rsapub)
2 - 2257
{
2258
	int afd;
2259
	char *s;
26 7u83 2260
	mpint *n;
2 - 2261
	AuthRpc *rpc;
2262
 
2263
	// start talking to factotum
2264
	if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
2265
		return nil;
2266
	if((rpc = auth_allocrpc(afd)) == nil){
2267
		close(afd);
2268
		return nil;
2269
	}
2270
	s = "proto=rsa service=tls role=client";
26 7u83 2271
	if(auth_rpc(rpc, "start", s, strlen(s)) == ARok){
2272
		// roll factotum keyring around to match public key
2273
		n = mpnew(0);
2274
		while(auth_rpc(rpc, "read", nil, 0) == ARok){
2275
			if(strtomp(rpc->arg, nil, 16, n) != nil
2276
			&& mpcmp(n, rsapub->n) == 0){
2277
				mpfree(n);
2278
				return rpc;
2279
			}
2 - 2280
		}
26 7u83 2281
		mpfree(n);
2 - 2282
	}
26 7u83 2283
	factotum_rsa_close(rpc);
2284
	return nil;
2 - 2285
}
2286
 
2287
static mpint*
2288
factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
2289
{
2290
	char *p;
2291
	int rv;
2292
 
26 7u83 2293
	if(cipher == nil)
2 - 2294
		return nil;
26 7u83 2295
	p = mptoa(cipher, 16, nil, 0);
2296
	mpfree(cipher);
2297
	if(p == nil)
2298
		return nil;
2 - 2299
	rv = auth_rpc(rpc, "write", p, strlen(p));
2300
	free(p);
2301
	if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
2302
		return nil;
2303
	return strtomp(rpc->arg, nil, 16, nil);
2304
}
2305
 
2306
static void
26 7u83 2307
factotum_rsa_close(AuthRpc *rpc)
2 - 2308
{
26 7u83 2309
	if(rpc == nil)
2 - 2310
		return;
2311
	close(rpc->afd);
2312
	auth_freerpc(rpc);
2313
}
2314
 
2315
static void
2316
tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2317
{
2318
	uchar ai[MD5dlen], tmp[MD5dlen];
2319
	int i, n;
2320
	MD5state *s;
2321
 
2322
	// generate a1
2323
	s = hmac_md5(label, nlabel, key, nkey, nil, nil);
2324
	s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
2325
	hmac_md5(seed1, nseed1, key, nkey, ai, s);
2326
 
2327
	while(nbuf > 0) {
2328
		s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
2329
		s = hmac_md5(label, nlabel, key, nkey, nil, s);
2330
		s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
2331
		hmac_md5(seed1, nseed1, key, nkey, tmp, s);
2332
		n = MD5dlen;
2333
		if(n > nbuf)
2334
			n = nbuf;
2335
		for(i = 0; i < n; i++)
2336
			buf[i] ^= tmp[i];
2337
		buf += n;
2338
		nbuf -= n;
2339
		hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
2340
		memmove(ai, tmp, MD5dlen);
2341
	}
2342
}
2343
 
2344
static void
2345
tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2346
{
2347
	uchar ai[SHA1dlen], tmp[SHA1dlen];
2348
	int i, n;
2349
	SHAstate *s;
2350
 
2351
	// generate a1
2352
	s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
2353
	s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
2354
	hmac_sha1(seed1, nseed1, key, nkey, ai, s);
2355
 
2356
	while(nbuf > 0) {
2357
		s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
2358
		s = hmac_sha1(label, nlabel, key, nkey, nil, s);
2359
		s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
2360
		hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
2361
		n = SHA1dlen;
2362
		if(n > nbuf)
2363
			n = nbuf;
2364
		for(i = 0; i < n; i++)
2365
			buf[i] ^= tmp[i];
2366
		buf += n;
2367
		nbuf -= n;
2368
		hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
2369
		memmove(ai, tmp, SHA1dlen);
2370
	}
2371
}
2372
 
26 7u83 2373
static void
2374
p_sha256(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed, int nseed)
2375
{
2376
	uchar ai[SHA2_256dlen], tmp[SHA2_256dlen];
2377
	SHAstate *s;
2378
	int n;
2379
 
2380
	// generate a1
2381
	s = hmac_sha2_256(label, nlabel, key, nkey, nil, nil);
2382
	hmac_sha2_256(seed, nseed, key, nkey, ai, s);
2383
 
2384
	while(nbuf > 0) {
2385
		s = hmac_sha2_256(ai, SHA2_256dlen, key, nkey, nil, nil);
2386
		s = hmac_sha2_256(label, nlabel, key, nkey, nil, s);
2387
		hmac_sha2_256(seed, nseed, key, nkey, tmp, s);
2388
		n = SHA2_256dlen;
2389
		if(n > nbuf)
2390
			n = nbuf;
2391
		memmove(buf, tmp, n);
2392
		buf += n;
2393
		nbuf -= n;
2394
		hmac_sha2_256(ai, SHA2_256dlen, key, nkey, tmp, nil);
2395
		memmove(ai, tmp, SHA2_256dlen);
2396
	}
2397
}
2398
 
2 - 2399
// fill buf with md5(args)^sha1(args)
2400
static void
26 7u83 2401
tls10PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2 - 2402
{
2403
	int nlabel = strlen(label);
2404
	int n = (nkey + 1) >> 1;
2405
 
26 7u83 2406
	memset(buf, 0, nbuf);
2 - 2407
	tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
2408
	tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
2409
}
2410
 
26 7u83 2411
static void
2412
tls12PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2413
{
2414
	uchar seed[2*RandomSize];
2 - 2415
 
26 7u83 2416
	assert(nseed0+nseed1 <= sizeof(seed));
2417
	memmove(seed, seed0, nseed0);
2418
	memmove(seed+nseed0, seed1, nseed1);
2419
	p_sha256(buf, nbuf, key, nkey, (uchar*)label, strlen(label), seed, nseed0+nseed1);
2420
}
2421
 
2422
static void
2423
sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2 - 2424
{
26 7u83 2425
	uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
2426
	DigestState *s;
2427
	int i, n, len;
2 - 2428
 
26 7u83 2429
	USED(label);
2430
	len = 1;
2431
	while(nbuf > 0){
2432
		if(len > 26)
2433
			return;
2434
		for(i = 0; i < len; i++)
2435
			tmp[i] = 'A' - 1 + len;
2436
		s = sha1(tmp, len, nil, nil);
2437
		s = sha1(key, nkey, nil, s);
2438
		s = sha1(seed0, nseed0, nil, s);
2439
		sha1(seed1, nseed1, sha1dig, s);
2440
		s = md5(key, nkey, nil, nil);
2441
		md5(sha1dig, SHA1dlen, md5dig, s);
2442
		n = MD5dlen;
2443
		if(n > nbuf)
2444
			n = nbuf;
2445
		memmove(buf, md5dig, n);
2446
		buf += n;
2447
		nbuf -= n;
2448
		len++;
2449
	}
2450
}
2 - 2451
 
26 7u83 2452
static void
2453
sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isclient)
2454
{
2455
	DigestState *s;
2456
	uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
2457
	char *label;
2458
 
2459
	if(isclient)
2460
		label = "CLNT";
2461
	else
2462
		label = "SRVR";
2463
 
2464
	md5((uchar*)label, 4, nil, &hsh.md5);
2465
	md5(sec->sec, MasterSecretSize, nil, &hsh.md5);
2466
	memset(pad, 0x36, 48);
2467
	md5(pad, 48, nil, &hsh.md5);
2468
	md5(nil, 0, h0, &hsh.md5);
2469
	memset(pad, 0x5C, 48);
2470
	s = md5(sec->sec, MasterSecretSize, nil, nil);
2471
	s = md5(pad, 48, nil, s);
2472
	md5(h0, MD5dlen, finished, s);
2473
 
2474
	sha1((uchar*)label, 4, nil, &hsh.sha1);
2475
	sha1(sec->sec, MasterSecretSize, nil, &hsh.sha1);
2476
	memset(pad, 0x36, 40);
2477
	sha1(pad, 40, nil, &hsh.sha1);
2478
	sha1(nil, 0, h1, &hsh.sha1);
2479
	memset(pad, 0x5C, 40);
2480
	s = sha1(sec->sec, MasterSecretSize, nil, nil);
2481
	s = sha1(pad, 40, nil, s);
2482
	sha1(h1, SHA1dlen, finished + MD5dlen, s);
2483
}
2484
 
2485
// fill "finished" arg with md5(args)^sha1(args)
2486
static void
2487
tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isclient)
2488
{
2489
	uchar h0[MD5dlen], h1[SHA1dlen];
2490
	char *label;
2491
 
2492
	// get current hash value, but allow further messages to be hashed in
2493
	md5(nil, 0, h0, &hsh.md5);
2494
	sha1(nil, 0, h1, &hsh.sha1);
2495
 
2496
	if(isclient)
2497
		label = "client finished";
2498
	else
2499
		label = "server finished";
2500
	tls10PRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
2501
}
2502
 
2503
static void
2504
tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isclient)
2505
{
2506
	uchar seed[SHA2_256dlen];
2507
	char *label;
2508
 
2509
	// get current hash value, but allow further messages to be hashed in
2510
	sha2_256(nil, 0, seed, &hsh.sha2_256);
2511
 
2512
	if(isclient)
2513
		label = "client finished";
2514
	else
2515
		label = "server finished";
2516
	p_sha256(finished, TLSFinishedLen, sec->sec, MasterSecretSize, (uchar*)label, strlen(label), seed, SHA2_256dlen);
2517
}
2518
 
2519
static void
2520
tlsSecInits(TlsSec *sec, int cvers, uchar *crandom)
2521
{
2522
	memset(sec, 0, sizeof(*sec));
2523
	sec->clientVers = cvers;
2 - 2524
	memmove(sec->crandom, crandom, RandomSize);
2525
 
26 7u83 2526
	put32(sec->srandom, time(nil));
2 - 2527
	genrandom(sec->srandom+4, RandomSize-4);
26 7u83 2528
}
2 - 2529
 
26 7u83 2530
static int
2531
tlsSecRSAs(TlsSec *sec, Bytes *epm)
2532
{
2533
	Bytes *pm;
2534
 
2535
	if(epm == nil){
2536
		werrstr("no encrypted premaster secret");
2537
		return -1;
2538
	}
2539
	// if the client messed up, just continue as if everything is ok,
2540
	// to prevent attacks to check for correctly formatted messages.
2541
	pm = pkcs1_decrypt(sec, epm);
2542
	if(pm == nil || pm->len != MasterSecretSize || get16(pm->data) != sec->clientVers){
2543
		freebytes(pm);
2544
		pm = newbytes(MasterSecretSize);
2545
		genrandom(pm->data, pm->len);
2546
	}
2547
	setMasterSecret(sec, pm);
2548
	return 0;
2 - 2549
}
2550
 
26 7u83 2551
static Bytes*
2552
tlsSecECDHEs1(TlsSec *sec, Namedcurve *nc)
2553
{
2554
	ECdomain *dom = &sec->ec.dom;
2555
	ECpriv *Q = &sec->ec.Q;
2556
	Bytes *par;
2557
	int n;
2558
 
2559
	ecdominit(dom, nc->init);
2560
	memset(Q, 0, sizeof(*Q));
2561
	Q->x = mpnew(0);
2562
	Q->y = mpnew(0);
2563
	Q->d = mpnew(0);
2564
	ecgen(dom, Q);
2565
	n = 1 + 2*((mpsignif(dom->p)+7)/8);
2566
	par = newbytes(1+2+1+n);
2567
	par->data[0] = 3;
2568
	put16(par->data+1, nc->tlsid);
2569
	n = ecencodepub(dom, Q, par->data+4, par->len-4);
2570
	par->data[3] = n;
2571
	par->len = 1+2+1+n;
2572
 
2573
	return par;
2574
}
2575
 
2 - 2576
static int
26 7u83 2577
tlsSecECDHEs2(TlsSec *sec, Bytes *Yc)
2 - 2578
{
26 7u83 2579
	ECdomain *dom = &sec->ec.dom;
2580
	ECpriv *Q = &sec->ec.Q;
2581
	ECpoint K;
2582
	ECpub *Y;
2583
 
2584
	if(Yc == nil){
2585
		werrstr("no public key");
2586
		return -1;
2 - 2587
	}
26 7u83 2588
 
2589
	if((Y = ecdecodepub(dom, Yc->data, Yc->len)) == nil){
2590
		werrstr("bad public key");
2591
		return -1;
2592
	}
2593
 
2594
	memset(&K, 0, sizeof(K));
2595
	K.x = mpnew(0);
2596
	K.y = mpnew(0);
2597
 
2598
	ecmul(dom, Y, Q->d, &K);
2599
 
2600
	setMasterSecret(sec, mptobytes(K.x, (mpsignif(dom->p)+7)/8));
2601
 
2602
	mpfree(K.x);
2603
	mpfree(K.y);
2604
 
2605
	ecpubfree(Y);
2606
 
2 - 2607
	return 0;
2608
}
2609
 
26 7u83 2610
static void
2611
tlsSecInitc(TlsSec *sec, int cvers)
2 - 2612
{
26 7u83 2613
	memset(sec, 0, sizeof(*sec));
2 - 2614
	sec->clientVers = cvers;
26 7u83 2615
	put32(sec->crandom, time(nil));
2 - 2616
	genrandom(sec->crandom+4, RandomSize-4);
2617
}
2618
 
26 7u83 2619
static Bytes*
2620
tlsSecRSAc(TlsSec *sec, uchar *cert, int ncert)
2 - 2621
{
2622
	RSApub *pub;
26 7u83 2623
	Bytes *pm, *epm;
2 - 2624
 
2625
	pub = X509toRSApub(cert, ncert, nil, 0);
2626
	if(pub == nil){
2627
		werrstr("invalid x509/rsa certificate");
26 7u83 2628
		return nil;
2 - 2629
	}
26 7u83 2630
	pm = newbytes(MasterSecretSize);
2631
	put16(pm->data, sec->clientVers);
2632
	genrandom(pm->data+2, MasterSecretSize - 2);
2633
	epm = pkcs1_encrypt(pm, pub);
2634
	setMasterSecret(sec, pm);
2 - 2635
	rsapubfree(pub);
26 7u83 2636
	return epm;
2 - 2637
}
2638
 
2639
static int
26 7u83 2640
tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient)
2 - 2641
{
2642
	if(sec->nfin != nfin){
2643
		werrstr("invalid finished exchange");
2644
		return -1;
2645
	}
26 7u83 2646
	hsh.md5.malloced = 0;
2647
	hsh.sha1.malloced = 0;
2648
	hsh.sha2_256.malloced = 0;
2649
	(*sec->setFinished)(sec, hsh, fin, isclient);
2650
	return 0;
2 - 2651
}
2652
 
2653
static void
26 7u83 2654
tlsSecVers(TlsSec *sec, int v)
2 - 2655
{
2656
	if(v == SSL3Version){
2657
		sec->setFinished = sslSetFinished;
2658
		sec->nfin = SSL3FinishedLen;
2659
		sec->prf = sslPRF;
26 7u83 2660
	}else if(v < TLS12Version) {
2661
		sec->setFinished = tls10SetFinished;
2 - 2662
		sec->nfin = TLSFinishedLen;
26 7u83 2663
		sec->prf = tls10PRF;
2664
	}else {
2665
		sec->setFinished = tls12SetFinished;
2666
		sec->nfin = TLSFinishedLen;
2667
		sec->prf = tls12PRF;
2 - 2668
	}
2669
}
2670
 
26 7u83 2671
static int
2672
setSecrets(TlsConnection *c, int isclient)
2 - 2673
{
26 7u83 2674
	uchar kd[MaxKeyData];
2675
	char *secrets;
2676
	int rv;
2677
 
2678
	assert(c->nsecret <= sizeof(kd));
2679
	secrets = emalloc(2*c->nsecret);
2680
 
2681
	/*
2682
	 * generate secret keys from the master secret.
2683
	 *
2684
	 * different cipher selections will require different amounts
2685
	 * of key expansion and use of key expansion data,
2686
	 * but it's all generated using the same function.
2687
	 */
2688
	(*c->sec->prf)(kd, c->nsecret, c->sec->sec, MasterSecretSize, "key expansion",
2689
			c->sec->srandom, RandomSize, c->sec->crandom, RandomSize);
2690
 
2691
	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
2692
	memset(kd, 0, c->nsecret);
2693
 
2694
	rv = fprint(c->ctl, "secret %s %s %d %s", c->digest, c->enc, isclient, secrets);
2695
	memset(secrets, 0, 2*c->nsecret);
2696
	free(secrets);
2697
 
2698
	return rv;
2 - 2699
}
2700
 
2701
/*
26 7u83 2702
 * set the master secret from the pre-master secret,
2703
 * destroys premaster.
2 - 2704
 */
2705
static void
2706
setMasterSecret(TlsSec *sec, Bytes *pm)
2707
{
26 7u83 2708
	if(sec->psklen > 0){
2709
		Bytes *opm = pm;
2710
		uchar *p;
2 - 2711
 
26 7u83 2712
		/* concatenate psk to pre-master secret */
2713
		pm = newbytes(4 + opm->len + sec->psklen);
2714
		p = pm->data;
2715
		put16(p, opm->len), p += 2;
2716
		memmove(p, opm->data, opm->len), p += opm->len;
2717
		put16(p, sec->psklen), p += 2;
2718
		memmove(p, sec->psk, sec->psklen);
2 - 2719
 
26 7u83 2720
		memset(opm->data, 0, opm->len);
2721
		freebytes(opm);
2722
	}
2 - 2723
 
26 7u83 2724
	(*sec->prf)(sec->sec, MasterSecretSize, pm->data, pm->len, "master secret",
2725
			sec->crandom, RandomSize, sec->srandom, RandomSize);
2726
 
2 - 2727
	memset(pm->data, 0, pm->len);	
2728
	freebytes(pm);
2729
}
2730
 
2731
static int
26 7u83 2732
digestDHparams(TlsSec *sec, Bytes *par, uchar digest[MAXdlen], int sigalg)
2 - 2733
{
26 7u83 2734
	int hashalg = (sigalg>>8) & 0xFF;
2735
	int digestlen;
2736
	Bytes *blob;
2 - 2737
 
26 7u83 2738
	blob = newbytes(2*RandomSize + par->len);
2739
	memmove(blob->data+0*RandomSize, sec->crandom, RandomSize);
2740
	memmove(blob->data+1*RandomSize, sec->srandom, RandomSize);
2741
	memmove(blob->data+2*RandomSize, par->data, par->len);
2742
	if(hashalg == 0){
2743
		digestlen = MD5dlen+SHA1dlen;
2744
		md5(blob->data, blob->len, digest, nil);
2745
		sha1(blob->data, blob->len, digest+MD5dlen, nil);
2746
	} else {
2747
		digestlen = -1;
2748
		if(hashalg < nelem(hashfun) && hashfun[hashalg].fun != nil){
2749
			digestlen = hashfun[hashalg].len;
2750
			(*hashfun[hashalg].fun)(blob->data, blob->len, digest, nil);
2751
		}
2 - 2752
	}
26 7u83 2753
	freebytes(blob);
2754
	return digestlen;
2 - 2755
}
2756
 
26 7u83 2757
static char*
2758
verifyDHparams(TlsSec *sec, Bytes *par, Bytes *cert, Bytes *sig, int sigalg)
2 - 2759
{
26 7u83 2760
	uchar digest[MAXdlen];
2761
	int digestlen;
2762
	ECdomain dom;
2763
	ECpub *ecpk;
2764
	RSApub *rsapk;
2765
	char *err;
2 - 2766
 
26 7u83 2767
	if(par == nil || par->len <= 0)
2768
		return "no DH parameters";
2 - 2769
 
26 7u83 2770
	if(sig == nil || sig->len <= 0){
2771
		if(sec->psklen > 0)
2772
			return nil;
2773
		return "no signature";
2774
	}
2 - 2775
 
26 7u83 2776
	if(cert == nil)
2777
		return "no certificate";
2 - 2778
 
26 7u83 2779
	digestlen = digestDHparams(sec, par, digest, sigalg);
2780
	if(digestlen <= 0)
2781
		return "unknown signature digest algorithm";
2782
 
2783
	switch(sigalg & 0xFF){
2784
	case 0x01:
2785
		rsapk = X509toRSApub(cert->data, cert->len, nil, 0);
2786
		if(rsapk == nil)
2787
			return "bad certificate";
2788
		err = X509rsaverifydigest(sig->data, sig->len, digest, digestlen, rsapk);
2789
		rsapubfree(rsapk);
2790
		break;
2791
	case 0x03:
2792
		ecpk = X509toECpub(cert->data, cert->len, nil, 0, &dom);
2793
		if(ecpk == nil)
2794
			return "bad certificate";
2795
		err = X509ecdsaverifydigest(sig->data, sig->len, digest, digestlen, &dom, ecpk);
2796
		ecdomfree(&dom);
2797
		ecpubfree(ecpk);
2798
		break;
2799
	default:
2800
		err = "signaure algorithm not RSA or ECDSA";
2 - 2801
	}
2802
 
26 7u83 2803
	return err;
2 - 2804
}
2805
 
26 7u83 2806
// encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2 - 2807
static Bytes*
26 7u83 2808
pkcs1_encrypt(Bytes* data, RSApub* key)
2 - 2809
{
2810
	mpint *x, *y;
2811
 
26 7u83 2812
	x = pkcs1padbuf(data->data, data->len, key->n, 2);
2813
	if(x == nil)
2814
		return nil;
2 - 2815
	y = rsaencrypt(key, x, nil);
2816
	mpfree(x);
26 7u83 2817
	data = newbytes((mpsignif(key->n)+7)/8);
2818
	mptober(y, data->data, data->len);
2 - 2819
	mpfree(y);
26 7u83 2820
	return data;
2 - 2821
}
2822
 
26 7u83 2823
// decrypt data according to PKCS#1, with given key.
2 - 2824
static Bytes*
26 7u83 2825
pkcs1_decrypt(TlsSec *sec, Bytes *data)
2 - 2826
{
26 7u83 2827
	mpint *y;
2 - 2828
 
26 7u83 2829
	if(data->len != (mpsignif(sec->rsapub->n)+7)/8)
2 - 2830
		return nil;
26 7u83 2831
	y = factotum_rsa_decrypt(sec->rpc, bytestomp(data));
2832
	if(y == nil)
2833
		return nil;
2834
	data = mptobytes(y, (mpsignif(y)+7)/8);
2835
	mpfree(y);
2836
	if((data->len = pkcs1unpadbuf(data->data, data->len, sec->rsapub->n, 2)) < 0){
2837
		freebytes(data);
2838
		return nil;
2 - 2839
	}
26 7u83 2840
	return data;
2 - 2841
}
2842
 
2843
static Bytes*
26 7u83 2844
pkcs1_sign(TlsSec *sec, uchar *digest, int digestlen, int sigalg)
2 - 2845
{
26 7u83 2846
	int hashalg = (sigalg>>8)&0xFF;
2847
	mpint *signedMP;
2848
	Bytes *signature;
2849
	uchar buf[128];
2 - 2850
 
26 7u83 2851
	if(hashalg > 0 && hashalg < nelem(hashfun) && hashfun[hashalg].len == digestlen)
2852
		digestlen = asn1encodedigest(hashfun[hashalg].fun, digest, buf, sizeof(buf));
2853
	else if(digestlen == MD5dlen+SHA1dlen)
2854
		memmove(buf, digest, digestlen);
2855
	else
2856
		digestlen = -1;
2857
	if(digestlen <= 0){
2858
		werrstr("bad digest algorithm");
2 - 2859
		return nil;
26 7u83 2860
	}
2861
 
2862
	signedMP = factotum_rsa_decrypt(sec->rpc, pkcs1padbuf(buf, digestlen, sec->rsapub->n, 1));
2863
	if(signedMP == nil)
2 - 2864
		return nil;
26 7u83 2865
	signature = mptobytes(signedMP, (mpsignif(sec->rsapub->n)+7)/8);
2866
	mpfree(signedMP);
2867
	return signature;
2 - 2868
}
2869
 
2870
 
2871
//================= general utility functions ========================
2872
 
2873
static void *
2874
emalloc(int n)
2875
{
2876
	void *p;
2877
	if(n==0)
2878
		n=1;
2879
	p = malloc(n);
26 7u83 2880
	if(p == nil)
2881
		sysfatal("out of memory");
2 - 2882
	memset(p, 0, n);
26 7u83 2883
	setmalloctag(p, getcallerpc(&n));
2 - 2884
	return p;
2885
}
2886
 
2887
static void *
2888
erealloc(void *ReallocP, int ReallocN)
2889
{
2890
	if(ReallocN == 0)
2891
		ReallocN = 1;
26 7u83 2892
	if(ReallocP == nil)
2 - 2893
		ReallocP = emalloc(ReallocN);
26 7u83 2894
	else if((ReallocP = realloc(ReallocP, ReallocN)) == nil)
2895
		sysfatal("out of memory");
2896
	setrealloctag(ReallocP, getcallerpc(&ReallocP));
2 - 2897
	return(ReallocP);
2898
}
2899
 
2900
static void
2901
put32(uchar *p, u32int x)
2902
{
2903
	p[0] = x>>24;
2904
	p[1] = x>>16;
2905
	p[2] = x>>8;
2906
	p[3] = x;
2907
}
2908
 
2909
static void
2910
put24(uchar *p, int x)
2911
{
2912
	p[0] = x>>16;
2913
	p[1] = x>>8;
2914
	p[2] = x;
2915
}
2916
 
2917
static void
2918
put16(uchar *p, int x)
2919
{
2920
	p[0] = x>>8;
2921
	p[1] = x;
2922
}
2923
 
2924
static int
2925
get24(uchar *p)
2926
{
2927
	return (p[0]<<16)|(p[1]<<8)|p[2];
2928
}
2929
 
2930
static int
2931
get16(uchar *p)
2932
{
2933
	return (p[0]<<8)|p[1];
2934
}
2935
 
2936
static Bytes*
2937
newbytes(int len)
2938
{
2939
	Bytes* ans;
2940
 
26 7u83 2941
	if(len < 0)
2942
		abort();
2943
	ans = emalloc(sizeof(Bytes) + len);
2 - 2944
	ans->len = len;
2945
	return ans;
2946
}
2947
 
2948
/*
2949
 * newbytes(len), with data initialized from buf
2950
 */
2951
static Bytes*
2952
makebytes(uchar* buf, int len)
2953
{
2954
	Bytes* ans;
2955
 
2956
	ans = newbytes(len);
2957
	memmove(ans->data, buf, len);
2958
	return ans;
2959
}
2960
 
2961
static void
2962
freebytes(Bytes* b)
2963
{
26 7u83 2964
	free(b);
2 - 2965
}
2966
 
26 7u83 2967
static mpint*
2968
bytestomp(Bytes* bytes)
2 - 2969
{
26 7u83 2970
	return betomp(bytes->data, bytes->len, nil);
2971
}
2 - 2972
 
26 7u83 2973
/*
2974
 * Convert mpint* to Bytes, putting high order byte first.
2975
 */
2976
static Bytes*
2977
mptobytes(mpint *big, int len)
2978
{
2979
	Bytes* ans;
2980
 
2981
	if(len == 0) len++;
2982
	ans = newbytes(len);
2983
	mptober(big, ans->data, ans->len);
2 - 2984
	return ans;
2985
}
2986
 
26 7u83 2987
/* len is number of ints */
2 - 2988
static Ints*
26 7u83 2989
newints(int len)
2 - 2990
{
2991
	Ints* ans;
2992
 
26 7u83 2993
	if(len < 0 || len > ((uint)-1>>1)/sizeof(int))
2994
		abort();
2995
	ans = emalloc(sizeof(Ints) + len*sizeof(int));
2996
	ans->len = len;
2 - 2997
	return ans;
2998
}
2999
 
3000
static void
3001
freeints(Ints* b)
3002
{
26 7u83 3003
	free(b);
2 - 3004
}
26 7u83 3005
 
3006
static int
3007
lookupid(Ints* b, int id)
3008
{
3009
	int i;
3010
 
3011
	for(i=0; i<b->len; i++)
3012
		if(b->data[i] == id)
3013
			return i;
3014
	return -1;
3015
}