Linux Audio

Check our new training course

Loading...
Note: File does not exist in v3.1.
   1// SPDX-License-Identifier: GPL-2.0
   2
   3#define _GNU_SOURCE
   4
   5#include <arpa/inet.h>
   6#include <errno.h>
   7#include <error.h>
   8#include <fcntl.h>
   9#include <poll.h>
  10#include <stdio.h>
  11#include <stdlib.h>
  12#include <unistd.h>
  13
  14#include <linux/tls.h>
  15#include <linux/tcp.h>
  16#include <linux/socket.h>
  17
  18#include <sys/epoll.h>
  19#include <sys/types.h>
  20#include <sys/sendfile.h>
  21#include <sys/socket.h>
  22#include <sys/stat.h>
  23
  24#include "../kselftest_harness.h"
  25
  26#define TLS_PAYLOAD_MAX_LEN 16384
  27#define SOL_TLS 282
  28
  29static int fips_enabled;
  30
  31struct tls_crypto_info_keys {
  32	union {
  33		struct tls_crypto_info crypto_info;
  34		struct tls12_crypto_info_aes_gcm_128 aes128;
  35		struct tls12_crypto_info_chacha20_poly1305 chacha20;
  36		struct tls12_crypto_info_sm4_gcm sm4gcm;
  37		struct tls12_crypto_info_sm4_ccm sm4ccm;
  38		struct tls12_crypto_info_aes_ccm_128 aesccm128;
  39		struct tls12_crypto_info_aes_gcm_256 aesgcm256;
  40		struct tls12_crypto_info_aria_gcm_128 ariagcm128;
  41		struct tls12_crypto_info_aria_gcm_256 ariagcm256;
  42	};
  43	size_t len;
  44};
  45
  46static void tls_crypto_info_init(uint16_t tls_version, uint16_t cipher_type,
  47				 struct tls_crypto_info_keys *tls12)
  48{
  49	memset(tls12, 0, sizeof(*tls12));
  50
  51	switch (cipher_type) {
  52	case TLS_CIPHER_CHACHA20_POLY1305:
  53		tls12->len = sizeof(struct tls12_crypto_info_chacha20_poly1305);
  54		tls12->chacha20.info.version = tls_version;
  55		tls12->chacha20.info.cipher_type = cipher_type;
  56		break;
  57	case TLS_CIPHER_AES_GCM_128:
  58		tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_128);
  59		tls12->aes128.info.version = tls_version;
  60		tls12->aes128.info.cipher_type = cipher_type;
  61		break;
  62	case TLS_CIPHER_SM4_GCM:
  63		tls12->len = sizeof(struct tls12_crypto_info_sm4_gcm);
  64		tls12->sm4gcm.info.version = tls_version;
  65		tls12->sm4gcm.info.cipher_type = cipher_type;
  66		break;
  67	case TLS_CIPHER_SM4_CCM:
  68		tls12->len = sizeof(struct tls12_crypto_info_sm4_ccm);
  69		tls12->sm4ccm.info.version = tls_version;
  70		tls12->sm4ccm.info.cipher_type = cipher_type;
  71		break;
  72	case TLS_CIPHER_AES_CCM_128:
  73		tls12->len = sizeof(struct tls12_crypto_info_aes_ccm_128);
  74		tls12->aesccm128.info.version = tls_version;
  75		tls12->aesccm128.info.cipher_type = cipher_type;
  76		break;
  77	case TLS_CIPHER_AES_GCM_256:
  78		tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_256);
  79		tls12->aesgcm256.info.version = tls_version;
  80		tls12->aesgcm256.info.cipher_type = cipher_type;
  81		break;
  82	case TLS_CIPHER_ARIA_GCM_128:
  83		tls12->len = sizeof(struct tls12_crypto_info_aria_gcm_128);
  84		tls12->ariagcm128.info.version = tls_version;
  85		tls12->ariagcm128.info.cipher_type = cipher_type;
  86		break;
  87	case TLS_CIPHER_ARIA_GCM_256:
  88		tls12->len = sizeof(struct tls12_crypto_info_aria_gcm_256);
  89		tls12->ariagcm256.info.version = tls_version;
  90		tls12->ariagcm256.info.cipher_type = cipher_type;
  91		break;
  92	default:
  93		break;
  94	}
  95}
  96
  97static void memrnd(void *s, size_t n)
  98{
  99	int *dword = s;
 100	char *byte;
 101
 102	for (; n >= 4; n -= 4)
 103		*dword++ = rand();
 104	byte = (void *)dword;
 105	while (n--)
 106		*byte++ = rand();
 107}
 108
 109static void ulp_sock_pair(struct __test_metadata *_metadata,
 110			  int *fd, int *cfd, bool *notls)
 111{
 112	struct sockaddr_in addr;
 113	socklen_t len;
 114	int sfd, ret;
 115
 116	*notls = false;
 117	len = sizeof(addr);
 118
 119	addr.sin_family = AF_INET;
 120	addr.sin_addr.s_addr = htonl(INADDR_ANY);
 121	addr.sin_port = 0;
 122
 123	*fd = socket(AF_INET, SOCK_STREAM, 0);
 124	sfd = socket(AF_INET, SOCK_STREAM, 0);
 125
 126	ret = bind(sfd, &addr, sizeof(addr));
 127	ASSERT_EQ(ret, 0);
 128	ret = listen(sfd, 10);
 129	ASSERT_EQ(ret, 0);
 130
 131	ret = getsockname(sfd, &addr, &len);
 132	ASSERT_EQ(ret, 0);
 133
 134	ret = connect(*fd, &addr, sizeof(addr));
 135	ASSERT_EQ(ret, 0);
 136
 137	*cfd = accept(sfd, &addr, &len);
 138	ASSERT_GE(*cfd, 0);
 139
 140	close(sfd);
 141
 142	ret = setsockopt(*fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
 143	if (ret != 0) {
 144		ASSERT_EQ(errno, ENOENT);
 145		*notls = true;
 146		printf("Failure setting TCP_ULP, testing without tls\n");
 147		return;
 148	}
 149
 150	ret = setsockopt(*cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
 151	ASSERT_EQ(ret, 0);
 152}
 153
 154/* Produce a basic cmsg */
 155static int tls_send_cmsg(int fd, unsigned char record_type,
 156			 void *data, size_t len, int flags)
 157{
 158	char cbuf[CMSG_SPACE(sizeof(char))];
 159	int cmsg_len = sizeof(char);
 160	struct cmsghdr *cmsg;
 161	struct msghdr msg;
 162	struct iovec vec;
 163
 164	vec.iov_base = data;
 165	vec.iov_len = len;
 166	memset(&msg, 0, sizeof(struct msghdr));
 167	msg.msg_iov = &vec;
 168	msg.msg_iovlen = 1;
 169	msg.msg_control = cbuf;
 170	msg.msg_controllen = sizeof(cbuf);
 171	cmsg = CMSG_FIRSTHDR(&msg);
 172	cmsg->cmsg_level = SOL_TLS;
 173	/* test sending non-record types. */
 174	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
 175	cmsg->cmsg_len = CMSG_LEN(cmsg_len);
 176	*CMSG_DATA(cmsg) = record_type;
 177	msg.msg_controllen = cmsg->cmsg_len;
 178
 179	return sendmsg(fd, &msg, flags);
 180}
 181
 182static int tls_recv_cmsg(struct __test_metadata *_metadata,
 183			 int fd, unsigned char record_type,
 184			 void *data, size_t len, int flags)
 185{
 186	char cbuf[CMSG_SPACE(sizeof(char))];
 187	struct cmsghdr *cmsg;
 188	unsigned char ctype;
 189	struct msghdr msg;
 190	struct iovec vec;
 191	int n;
 192
 193	vec.iov_base = data;
 194	vec.iov_len = len;
 195	memset(&msg, 0, sizeof(struct msghdr));
 196	msg.msg_iov = &vec;
 197	msg.msg_iovlen = 1;
 198	msg.msg_control = cbuf;
 199	msg.msg_controllen = sizeof(cbuf);
 200
 201	n = recvmsg(fd, &msg, flags);
 202
 203	cmsg = CMSG_FIRSTHDR(&msg);
 204	EXPECT_NE(cmsg, NULL);
 205	EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
 206	EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
 207	ctype = *((unsigned char *)CMSG_DATA(cmsg));
 208	EXPECT_EQ(ctype, record_type);
 209
 210	return n;
 211}
 212
 213FIXTURE(tls_basic)
 214{
 215	int fd, cfd;
 216	bool notls;
 217};
 218
 219FIXTURE_SETUP(tls_basic)
 220{
 221	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
 222}
 223
 224FIXTURE_TEARDOWN(tls_basic)
 225{
 226	close(self->fd);
 227	close(self->cfd);
 228}
 229
 230/* Send some data through with ULP but no keys */
 231TEST_F(tls_basic, base_base)
 232{
 233	char const *test_str = "test_read";
 234	int send_len = 10;
 235	char buf[10];
 236
 237	ASSERT_EQ(strlen(test_str) + 1, send_len);
 238
 239	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
 240	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
 241	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
 242};
 243
 244TEST_F(tls_basic, bad_cipher)
 245{
 246	struct tls_crypto_info_keys tls12;
 247
 248	tls12.crypto_info.version = 200;
 249	tls12.crypto_info.cipher_type = TLS_CIPHER_AES_GCM_128;
 250	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
 251
 252	tls12.crypto_info.version = TLS_1_2_VERSION;
 253	tls12.crypto_info.cipher_type = 50;
 254	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
 255
 256	tls12.crypto_info.version = TLS_1_2_VERSION;
 257	tls12.crypto_info.cipher_type = 59;
 258	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
 259
 260	tls12.crypto_info.version = TLS_1_2_VERSION;
 261	tls12.crypto_info.cipher_type = 10;
 262	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
 263
 264	tls12.crypto_info.version = TLS_1_2_VERSION;
 265	tls12.crypto_info.cipher_type = 70;
 266	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
 267}
 268
 269TEST_F(tls_basic, recseq_wrap)
 270{
 271	struct tls_crypto_info_keys tls12;
 272	char const *test_str = "test_read";
 273	int send_len = 10;
 274
 275	if (self->notls)
 276		SKIP(return, "no TLS support");
 277
 278	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12);
 279	memset(&tls12.aes128.rec_seq, 0xff, sizeof(tls12.aes128.rec_seq));
 280
 281	ASSERT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
 282	ASSERT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
 283
 284	EXPECT_EQ(send(self->fd, test_str, send_len, 0), -1);
 285	EXPECT_EQ(errno, EBADMSG);
 286}
 287
 288FIXTURE(tls)
 289{
 290	int fd, cfd;
 291	bool notls;
 292};
 293
 294FIXTURE_VARIANT(tls)
 295{
 296	uint16_t tls_version;
 297	uint16_t cipher_type;
 298	bool nopad, fips_non_compliant;
 299};
 300
 301FIXTURE_VARIANT_ADD(tls, 12_aes_gcm)
 302{
 303	.tls_version = TLS_1_2_VERSION,
 304	.cipher_type = TLS_CIPHER_AES_GCM_128,
 305};
 306
 307FIXTURE_VARIANT_ADD(tls, 13_aes_gcm)
 308{
 309	.tls_version = TLS_1_3_VERSION,
 310	.cipher_type = TLS_CIPHER_AES_GCM_128,
 311};
 312
 313FIXTURE_VARIANT_ADD(tls, 12_chacha)
 314{
 315	.tls_version = TLS_1_2_VERSION,
 316	.cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
 317	.fips_non_compliant = true,
 318};
 319
 320FIXTURE_VARIANT_ADD(tls, 13_chacha)
 321{
 322	.tls_version = TLS_1_3_VERSION,
 323	.cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
 324	.fips_non_compliant = true,
 325};
 326
 327FIXTURE_VARIANT_ADD(tls, 13_sm4_gcm)
 328{
 329	.tls_version = TLS_1_3_VERSION,
 330	.cipher_type = TLS_CIPHER_SM4_GCM,
 331	.fips_non_compliant = true,
 332};
 333
 334FIXTURE_VARIANT_ADD(tls, 13_sm4_ccm)
 335{
 336	.tls_version = TLS_1_3_VERSION,
 337	.cipher_type = TLS_CIPHER_SM4_CCM,
 338	.fips_non_compliant = true,
 339};
 340
 341FIXTURE_VARIANT_ADD(tls, 12_aes_ccm)
 342{
 343	.tls_version = TLS_1_2_VERSION,
 344	.cipher_type = TLS_CIPHER_AES_CCM_128,
 345};
 346
 347FIXTURE_VARIANT_ADD(tls, 13_aes_ccm)
 348{
 349	.tls_version = TLS_1_3_VERSION,
 350	.cipher_type = TLS_CIPHER_AES_CCM_128,
 351};
 352
 353FIXTURE_VARIANT_ADD(tls, 12_aes_gcm_256)
 354{
 355	.tls_version = TLS_1_2_VERSION,
 356	.cipher_type = TLS_CIPHER_AES_GCM_256,
 357};
 358
 359FIXTURE_VARIANT_ADD(tls, 13_aes_gcm_256)
 360{
 361	.tls_version = TLS_1_3_VERSION,
 362	.cipher_type = TLS_CIPHER_AES_GCM_256,
 363};
 364
 365FIXTURE_VARIANT_ADD(tls, 13_nopad)
 366{
 367	.tls_version = TLS_1_3_VERSION,
 368	.cipher_type = TLS_CIPHER_AES_GCM_128,
 369	.nopad = true,
 370};
 371
 372FIXTURE_VARIANT_ADD(tls, 12_aria_gcm)
 373{
 374	.tls_version = TLS_1_2_VERSION,
 375	.cipher_type = TLS_CIPHER_ARIA_GCM_128,
 376};
 377
 378FIXTURE_VARIANT_ADD(tls, 12_aria_gcm_256)
 379{
 380	.tls_version = TLS_1_2_VERSION,
 381	.cipher_type = TLS_CIPHER_ARIA_GCM_256,
 382};
 383
 384FIXTURE_SETUP(tls)
 385{
 386	struct tls_crypto_info_keys tls12;
 387	int one = 1;
 388	int ret;
 389
 390	if (fips_enabled && variant->fips_non_compliant)
 391		SKIP(return, "Unsupported cipher in FIPS mode");
 392
 393	tls_crypto_info_init(variant->tls_version, variant->cipher_type,
 394			     &tls12);
 395
 396	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
 397
 398	if (self->notls)
 399		return;
 400
 401	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
 402	ASSERT_EQ(ret, 0);
 403
 404	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len);
 405	ASSERT_EQ(ret, 0);
 406
 407	if (variant->nopad) {
 408		ret = setsockopt(self->cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
 409				 (void *)&one, sizeof(one));
 410		ASSERT_EQ(ret, 0);
 411	}
 412}
 413
 414FIXTURE_TEARDOWN(tls)
 415{
 416	close(self->fd);
 417	close(self->cfd);
 418}
 419
 420TEST_F(tls, sendfile)
 421{
 422	int filefd = open("/proc/self/exe", O_RDONLY);
 423	struct stat st;
 424
 425	EXPECT_GE(filefd, 0);
 426	fstat(filefd, &st);
 427	EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
 428}
 429
 430TEST_F(tls, send_then_sendfile)
 431{
 432	int filefd = open("/proc/self/exe", O_RDONLY);
 433	char const *test_str = "test_send";
 434	int to_send = strlen(test_str) + 1;
 435	char recv_buf[10];
 436	struct stat st;
 437	char *buf;
 438
 439	EXPECT_GE(filefd, 0);
 440	fstat(filefd, &st);
 441	buf = (char *)malloc(st.st_size);
 442
 443	EXPECT_EQ(send(self->fd, test_str, to_send, 0), to_send);
 444	EXPECT_EQ(recv(self->cfd, recv_buf, to_send, MSG_WAITALL), to_send);
 445	EXPECT_EQ(memcmp(test_str, recv_buf, to_send), 0);
 446
 447	EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
 448	EXPECT_EQ(recv(self->cfd, buf, st.st_size, MSG_WAITALL), st.st_size);
 449}
 450
 451static void chunked_sendfile(struct __test_metadata *_metadata,
 452			     struct _test_data_tls *self,
 453			     uint16_t chunk_size,
 454			     uint16_t extra_payload_size)
 455{
 456	char buf[TLS_PAYLOAD_MAX_LEN];
 457	uint16_t test_payload_size;
 458	int size = 0;
 459	int ret;
 460	char filename[] = "/tmp/mytemp.XXXXXX";
 461	int fd = mkstemp(filename);
 462	off_t offset = 0;
 463
 464	unlink(filename);
 465	ASSERT_GE(fd, 0);
 466	EXPECT_GE(chunk_size, 1);
 467	test_payload_size = chunk_size + extra_payload_size;
 468	ASSERT_GE(TLS_PAYLOAD_MAX_LEN, test_payload_size);
 469	memset(buf, 1, test_payload_size);
 470	size = write(fd, buf, test_payload_size);
 471	EXPECT_EQ(size, test_payload_size);
 472	fsync(fd);
 473
 474	while (size > 0) {
 475		ret = sendfile(self->fd, fd, &offset, chunk_size);
 476		EXPECT_GE(ret, 0);
 477		size -= ret;
 478	}
 479
 480	EXPECT_EQ(recv(self->cfd, buf, test_payload_size, MSG_WAITALL),
 481		  test_payload_size);
 482
 483	close(fd);
 484}
 485
 486TEST_F(tls, multi_chunk_sendfile)
 487{
 488	chunked_sendfile(_metadata, self, 4096, 4096);
 489	chunked_sendfile(_metadata, self, 4096, 0);
 490	chunked_sendfile(_metadata, self, 4096, 1);
 491	chunked_sendfile(_metadata, self, 4096, 2048);
 492	chunked_sendfile(_metadata, self, 8192, 2048);
 493	chunked_sendfile(_metadata, self, 4096, 8192);
 494	chunked_sendfile(_metadata, self, 8192, 4096);
 495	chunked_sendfile(_metadata, self, 12288, 1024);
 496	chunked_sendfile(_metadata, self, 12288, 2000);
 497	chunked_sendfile(_metadata, self, 15360, 100);
 498	chunked_sendfile(_metadata, self, 15360, 300);
 499	chunked_sendfile(_metadata, self, 1, 4096);
 500	chunked_sendfile(_metadata, self, 2048, 4096);
 501	chunked_sendfile(_metadata, self, 2048, 8192);
 502	chunked_sendfile(_metadata, self, 4096, 8192);
 503	chunked_sendfile(_metadata, self, 1024, 12288);
 504	chunked_sendfile(_metadata, self, 2000, 12288);
 505	chunked_sendfile(_metadata, self, 100, 15360);
 506	chunked_sendfile(_metadata, self, 300, 15360);
 507}
 508
 509TEST_F(tls, recv_max)
 510{
 511	unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
 512	char recv_mem[TLS_PAYLOAD_MAX_LEN];
 513	char buf[TLS_PAYLOAD_MAX_LEN];
 514
 515	memrnd(buf, sizeof(buf));
 516
 517	EXPECT_GE(send(self->fd, buf, send_len, 0), 0);
 518	EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
 519	EXPECT_EQ(memcmp(buf, recv_mem, send_len), 0);
 520}
 521
 522TEST_F(tls, recv_small)
 523{
 524	char const *test_str = "test_read";
 525	int send_len = 10;
 526	char buf[10];
 527
 528	send_len = strlen(test_str) + 1;
 529	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
 530	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
 531	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
 532}
 533
 534TEST_F(tls, msg_more)
 535{
 536	char const *test_str = "test_read";
 537	int send_len = 10;
 538	char buf[10 * 2];
 539
 540	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
 541	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
 542	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
 543	EXPECT_EQ(recv(self->cfd, buf, send_len * 2, MSG_WAITALL),
 544		  send_len * 2);
 545	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
 546}
 547
 548TEST_F(tls, msg_more_unsent)
 549{
 550	char const *test_str = "test_read";
 551	int send_len = 10;
 552	char buf[10];
 553
 554	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
 555	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
 556}
 557
 558TEST_F(tls, msg_eor)
 559{
 560	char const *test_str = "test_read";
 561	int send_len = 10;
 562	char buf[10];
 563
 564	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_EOR), send_len);
 565	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
 566	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
 567}
 568
 569TEST_F(tls, sendmsg_single)
 570{
 571	struct msghdr msg;
 572
 573	char const *test_str = "test_sendmsg";
 574	size_t send_len = 13;
 575	struct iovec vec;
 576	char buf[13];
 577
 578	vec.iov_base = (char *)test_str;
 579	vec.iov_len = send_len;
 580	memset(&msg, 0, sizeof(struct msghdr));
 581	msg.msg_iov = &vec;
 582	msg.msg_iovlen = 1;
 583	EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
 584	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
 585	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
 586}
 587
 588#define MAX_FRAGS	64
 589#define SEND_LEN	13
 590TEST_F(tls, sendmsg_fragmented)
 591{
 592	char const *test_str = "test_sendmsg";
 593	char buf[SEND_LEN * MAX_FRAGS];
 594	struct iovec vec[MAX_FRAGS];
 595	struct msghdr msg;
 596	int i, frags;
 597
 598	for (frags = 1; frags <= MAX_FRAGS; frags++) {
 599		for (i = 0; i < frags; i++) {
 600			vec[i].iov_base = (char *)test_str;
 601			vec[i].iov_len = SEND_LEN;
 602		}
 603
 604		memset(&msg, 0, sizeof(struct msghdr));
 605		msg.msg_iov = vec;
 606		msg.msg_iovlen = frags;
 607
 608		EXPECT_EQ(sendmsg(self->fd, &msg, 0), SEND_LEN * frags);
 609		EXPECT_EQ(recv(self->cfd, buf, SEND_LEN * frags, MSG_WAITALL),
 610			  SEND_LEN * frags);
 611
 612		for (i = 0; i < frags; i++)
 613			EXPECT_EQ(memcmp(buf + SEND_LEN * i,
 614					 test_str, SEND_LEN), 0);
 615	}
 616}
 617#undef MAX_FRAGS
 618#undef SEND_LEN
 619
 620TEST_F(tls, sendmsg_large)
 621{
 622	void *mem = malloc(16384);
 623	size_t send_len = 16384;
 624	size_t sends = 128;
 625	struct msghdr msg;
 626	size_t recvs = 0;
 627	size_t sent = 0;
 628
 629	memset(&msg, 0, sizeof(struct msghdr));
 630	while (sent++ < sends) {
 631		struct iovec vec = { (void *)mem, send_len };
 632
 633		msg.msg_iov = &vec;
 634		msg.msg_iovlen = 1;
 635		EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
 636	}
 637
 638	while (recvs++ < sends) {
 639		EXPECT_NE(recv(self->cfd, mem, send_len, 0), -1);
 640	}
 641
 642	free(mem);
 643}
 644
 645TEST_F(tls, sendmsg_multiple)
 646{
 647	char const *test_str = "test_sendmsg_multiple";
 648	struct iovec vec[5];
 649	char *test_strs[5];
 650	struct msghdr msg;
 651	int total_len = 0;
 652	int len_cmp = 0;
 653	int iov_len = 5;
 654	char *buf;
 655	int i;
 656
 657	memset(&msg, 0, sizeof(struct msghdr));
 658	for (i = 0; i < iov_len; i++) {
 659		test_strs[i] = (char *)malloc(strlen(test_str) + 1);
 660		snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
 661		vec[i].iov_base = (void *)test_strs[i];
 662		vec[i].iov_len = strlen(test_strs[i]) + 1;
 663		total_len += vec[i].iov_len;
 664	}
 665	msg.msg_iov = vec;
 666	msg.msg_iovlen = iov_len;
 667
 668	EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
 669	buf = malloc(total_len);
 670	EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
 671	for (i = 0; i < iov_len; i++) {
 672		EXPECT_EQ(memcmp(test_strs[i], buf + len_cmp,
 673				 strlen(test_strs[i])),
 674			  0);
 675		len_cmp += strlen(buf + len_cmp) + 1;
 676	}
 677	for (i = 0; i < iov_len; i++)
 678		free(test_strs[i]);
 679	free(buf);
 680}
 681
 682TEST_F(tls, sendmsg_multiple_stress)
 683{
 684	char const *test_str = "abcdefghijklmno";
 685	struct iovec vec[1024];
 686	char *test_strs[1024];
 687	int iov_len = 1024;
 688	int total_len = 0;
 689	char buf[1 << 14];
 690	struct msghdr msg;
 691	int len_cmp = 0;
 692	int i;
 693
 694	memset(&msg, 0, sizeof(struct msghdr));
 695	for (i = 0; i < iov_len; i++) {
 696		test_strs[i] = (char *)malloc(strlen(test_str) + 1);
 697		snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
 698		vec[i].iov_base = (void *)test_strs[i];
 699		vec[i].iov_len = strlen(test_strs[i]) + 1;
 700		total_len += vec[i].iov_len;
 701	}
 702	msg.msg_iov = vec;
 703	msg.msg_iovlen = iov_len;
 704
 705	EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
 706	EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
 707
 708	for (i = 0; i < iov_len; i++)
 709		len_cmp += strlen(buf + len_cmp) + 1;
 710
 711	for (i = 0; i < iov_len; i++)
 712		free(test_strs[i]);
 713}
 714
 715TEST_F(tls, splice_from_pipe)
 716{
 717	int send_len = TLS_PAYLOAD_MAX_LEN;
 718	char mem_send[TLS_PAYLOAD_MAX_LEN];
 719	char mem_recv[TLS_PAYLOAD_MAX_LEN];
 720	int p[2];
 721
 722	ASSERT_GE(pipe(p), 0);
 723	EXPECT_GE(write(p[1], mem_send, send_len), 0);
 724	EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), 0);
 725	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
 726	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
 727}
 728
 729TEST_F(tls, splice_more)
 730{
 731	unsigned int f = SPLICE_F_NONBLOCK | SPLICE_F_MORE | SPLICE_F_GIFT;
 732	int send_len = TLS_PAYLOAD_MAX_LEN;
 733	char mem_send[TLS_PAYLOAD_MAX_LEN];
 734	int i, send_pipe = 1;
 735	int p[2];
 736
 737	ASSERT_GE(pipe(p), 0);
 738	EXPECT_GE(write(p[1], mem_send, send_len), 0);
 739	for (i = 0; i < 32; i++)
 740		EXPECT_EQ(splice(p[0], NULL, self->fd, NULL, send_pipe, f), 1);
 741}
 742
 743TEST_F(tls, splice_from_pipe2)
 744{
 745	int send_len = 16000;
 746	char mem_send[16000];
 747	char mem_recv[16000];
 748	int p2[2];
 749	int p[2];
 750
 751	memrnd(mem_send, sizeof(mem_send));
 752
 753	ASSERT_GE(pipe(p), 0);
 754	ASSERT_GE(pipe(p2), 0);
 755	EXPECT_EQ(write(p[1], mem_send, 8000), 8000);
 756	EXPECT_EQ(splice(p[0], NULL, self->fd, NULL, 8000, 0), 8000);
 757	EXPECT_EQ(write(p2[1], mem_send + 8000, 8000), 8000);
 758	EXPECT_EQ(splice(p2[0], NULL, self->fd, NULL, 8000, 0), 8000);
 759	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
 760	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
 761}
 762
 763TEST_F(tls, send_and_splice)
 764{
 765	int send_len = TLS_PAYLOAD_MAX_LEN;
 766	char mem_send[TLS_PAYLOAD_MAX_LEN];
 767	char mem_recv[TLS_PAYLOAD_MAX_LEN];
 768	char const *test_str = "test_read";
 769	int send_len2 = 10;
 770	char buf[10];
 771	int p[2];
 772
 773	ASSERT_GE(pipe(p), 0);
 774	EXPECT_EQ(send(self->fd, test_str, send_len2, 0), send_len2);
 775	EXPECT_EQ(recv(self->cfd, buf, send_len2, MSG_WAITALL), send_len2);
 776	EXPECT_EQ(memcmp(test_str, buf, send_len2), 0);
 777
 778	EXPECT_GE(write(p[1], mem_send, send_len), send_len);
 779	EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), send_len);
 780
 781	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
 782	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
 783}
 784
 785TEST_F(tls, splice_to_pipe)
 786{
 787	int send_len = TLS_PAYLOAD_MAX_LEN;
 788	char mem_send[TLS_PAYLOAD_MAX_LEN];
 789	char mem_recv[TLS_PAYLOAD_MAX_LEN];
 790	int p[2];
 791
 792	memrnd(mem_send, sizeof(mem_send));
 793
 794	ASSERT_GE(pipe(p), 0);
 795	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
 796	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), send_len);
 797	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
 798	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
 799}
 800
 801TEST_F(tls, splice_cmsg_to_pipe)
 802{
 803	char *test_str = "test_read";
 804	char record_type = 100;
 805	int send_len = 10;
 806	char buf[10];
 807	int p[2];
 808
 809	if (self->notls)
 810		SKIP(return, "no TLS support");
 811
 812	ASSERT_GE(pipe(p), 0);
 813	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
 814	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
 815	EXPECT_EQ(errno, EINVAL);
 816	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
 817	EXPECT_EQ(errno, EIO);
 818	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
 819				buf, sizeof(buf), MSG_WAITALL),
 820		  send_len);
 821	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
 822}
 823
 824TEST_F(tls, splice_dec_cmsg_to_pipe)
 825{
 826	char *test_str = "test_read";
 827	char record_type = 100;
 828	int send_len = 10;
 829	char buf[10];
 830	int p[2];
 831
 832	if (self->notls)
 833		SKIP(return, "no TLS support");
 834
 835	ASSERT_GE(pipe(p), 0);
 836	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
 837	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
 838	EXPECT_EQ(errno, EIO);
 839	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
 840	EXPECT_EQ(errno, EINVAL);
 841	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
 842				buf, sizeof(buf), MSG_WAITALL),
 843		  send_len);
 844	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
 845}
 846
 847TEST_F(tls, recv_and_splice)
 848{
 849	int send_len = TLS_PAYLOAD_MAX_LEN;
 850	char mem_send[TLS_PAYLOAD_MAX_LEN];
 851	char mem_recv[TLS_PAYLOAD_MAX_LEN];
 852	int half = send_len / 2;
 853	int p[2];
 854
 855	ASSERT_GE(pipe(p), 0);
 856	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
 857	/* Recv hald of the record, splice the other half */
 858	EXPECT_EQ(recv(self->cfd, mem_recv, half, MSG_WAITALL), half);
 859	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, half, SPLICE_F_NONBLOCK),
 860		  half);
 861	EXPECT_EQ(read(p[0], &mem_recv[half], half), half);
 862	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
 863}
 864
 865TEST_F(tls, peek_and_splice)
 866{
 867	int send_len = TLS_PAYLOAD_MAX_LEN;
 868	char mem_send[TLS_PAYLOAD_MAX_LEN];
 869	char mem_recv[TLS_PAYLOAD_MAX_LEN];
 870	int chunk = TLS_PAYLOAD_MAX_LEN / 4;
 871	int n, i, p[2];
 872
 873	memrnd(mem_send, sizeof(mem_send));
 874
 875	ASSERT_GE(pipe(p), 0);
 876	for (i = 0; i < 4; i++)
 877		EXPECT_EQ(send(self->fd, &mem_send[chunk * i], chunk, 0),
 878			  chunk);
 879
 880	EXPECT_EQ(recv(self->cfd, mem_recv, chunk * 5 / 2,
 881		       MSG_WAITALL | MSG_PEEK),
 882		  chunk * 5 / 2);
 883	EXPECT_EQ(memcmp(mem_send, mem_recv, chunk * 5 / 2), 0);
 884
 885	n = 0;
 886	while (n < send_len) {
 887		i = splice(self->cfd, NULL, p[1], NULL, send_len - n, 0);
 888		EXPECT_GT(i, 0);
 889		n += i;
 890	}
 891	EXPECT_EQ(n, send_len);
 892	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
 893	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
 894}
 895
 896TEST_F(tls, recvmsg_single)
 897{
 898	char const *test_str = "test_recvmsg_single";
 899	int send_len = strlen(test_str) + 1;
 900	char buf[20];
 901	struct msghdr hdr;
 902	struct iovec vec;
 903
 904	memset(&hdr, 0, sizeof(hdr));
 905	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
 906	vec.iov_base = (char *)buf;
 907	vec.iov_len = send_len;
 908	hdr.msg_iovlen = 1;
 909	hdr.msg_iov = &vec;
 910	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
 911	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
 912}
 913
 914TEST_F(tls, recvmsg_single_max)
 915{
 916	int send_len = TLS_PAYLOAD_MAX_LEN;
 917	char send_mem[TLS_PAYLOAD_MAX_LEN];
 918	char recv_mem[TLS_PAYLOAD_MAX_LEN];
 919	struct iovec vec;
 920	struct msghdr hdr;
 921
 922	memrnd(send_mem, sizeof(send_mem));
 923
 924	EXPECT_EQ(send(self->fd, send_mem, send_len, 0), send_len);
 925	vec.iov_base = (char *)recv_mem;
 926	vec.iov_len = TLS_PAYLOAD_MAX_LEN;
 927
 928	hdr.msg_iovlen = 1;
 929	hdr.msg_iov = &vec;
 930	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
 931	EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
 932}
 933
 934TEST_F(tls, recvmsg_multiple)
 935{
 936	unsigned int msg_iovlen = 1024;
 937	struct iovec vec[1024];
 938	char *iov_base[1024];
 939	unsigned int iov_len = 16;
 940	int send_len = 1 << 14;
 941	char buf[1 << 14];
 942	struct msghdr hdr;
 943	int i;
 944
 945	memrnd(buf, sizeof(buf));
 946
 947	EXPECT_EQ(send(self->fd, buf, send_len, 0), send_len);
 948	for (i = 0; i < msg_iovlen; i++) {
 949		iov_base[i] = (char *)malloc(iov_len);
 950		vec[i].iov_base = iov_base[i];
 951		vec[i].iov_len = iov_len;
 952	}
 953
 954	hdr.msg_iovlen = msg_iovlen;
 955	hdr.msg_iov = vec;
 956	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
 957
 958	for (i = 0; i < msg_iovlen; i++)
 959		free(iov_base[i]);
 960}
 961
 962TEST_F(tls, single_send_multiple_recv)
 963{
 964	unsigned int total_len = TLS_PAYLOAD_MAX_LEN * 2;
 965	unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
 966	char send_mem[TLS_PAYLOAD_MAX_LEN * 2];
 967	char recv_mem[TLS_PAYLOAD_MAX_LEN * 2];
 968
 969	memrnd(send_mem, sizeof(send_mem));
 970
 971	EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
 972	memset(recv_mem, 0, total_len);
 973
 974	EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
 975	EXPECT_NE(recv(self->cfd, recv_mem + send_len, send_len, 0), -1);
 976	EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
 977}
 978
 979TEST_F(tls, multiple_send_single_recv)
 980{
 981	unsigned int total_len = 2 * 10;
 982	unsigned int send_len = 10;
 983	char recv_mem[2 * 10];
 984	char send_mem[10];
 985
 986	memrnd(send_mem, sizeof(send_mem));
 987
 988	EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
 989	EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
 990	memset(recv_mem, 0, total_len);
 991	EXPECT_EQ(recv(self->cfd, recv_mem, total_len, MSG_WAITALL), total_len);
 992
 993	EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
 994	EXPECT_EQ(memcmp(send_mem, recv_mem + send_len, send_len), 0);
 995}
 996
 997TEST_F(tls, single_send_multiple_recv_non_align)
 998{
 999	const unsigned int total_len = 15;
1000	const unsigned int recv_len = 10;
1001	char recv_mem[recv_len * 2];
1002	char send_mem[total_len];
1003
1004	memrnd(send_mem, sizeof(send_mem));
1005
1006	EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
1007	memset(recv_mem, 0, total_len);
1008
1009	EXPECT_EQ(recv(self->cfd, recv_mem, recv_len, 0), recv_len);
1010	EXPECT_EQ(recv(self->cfd, recv_mem + recv_len, recv_len, 0), 5);
1011	EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
1012}
1013
1014TEST_F(tls, recv_partial)
1015{
1016	char const *test_str = "test_read_partial";
1017	char const *test_str_first = "test_read";
1018	char const *test_str_second = "_partial";
1019	int send_len = strlen(test_str) + 1;
1020	char recv_mem[18];
1021
1022	memset(recv_mem, 0, sizeof(recv_mem));
1023	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1024	EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_first),
1025		       MSG_WAITALL), strlen(test_str_first));
1026	EXPECT_EQ(memcmp(test_str_first, recv_mem, strlen(test_str_first)), 0);
1027	memset(recv_mem, 0, sizeof(recv_mem));
1028	EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_second),
1029		       MSG_WAITALL), strlen(test_str_second));
1030	EXPECT_EQ(memcmp(test_str_second, recv_mem, strlen(test_str_second)),
1031		  0);
1032}
1033
1034TEST_F(tls, recv_nonblock)
1035{
1036	char buf[4096];
1037	bool err;
1038
1039	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
1040	err = (errno == EAGAIN || errno == EWOULDBLOCK);
1041	EXPECT_EQ(err, true);
1042}
1043
1044TEST_F(tls, recv_peek)
1045{
1046	char const *test_str = "test_read_peek";
1047	int send_len = strlen(test_str) + 1;
1048	char buf[15];
1049
1050	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1051	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_PEEK), send_len);
1052	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1053	memset(buf, 0, sizeof(buf));
1054	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1055	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1056}
1057
1058TEST_F(tls, recv_peek_multiple)
1059{
1060	char const *test_str = "test_read_peek";
1061	int send_len = strlen(test_str) + 1;
1062	unsigned int num_peeks = 100;
1063	char buf[15];
1064	int i;
1065
1066	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1067	for (i = 0; i < num_peeks; i++) {
1068		EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
1069		EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1070		memset(buf, 0, sizeof(buf));
1071	}
1072	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1073	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1074}
1075
1076TEST_F(tls, recv_peek_multiple_records)
1077{
1078	char const *test_str = "test_read_peek_mult_recs";
1079	char const *test_str_first = "test_read_peek";
1080	char const *test_str_second = "_mult_recs";
1081	int len;
1082	char buf[64];
1083
1084	len = strlen(test_str_first);
1085	EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
1086
1087	len = strlen(test_str_second) + 1;
1088	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1089
1090	len = strlen(test_str_first);
1091	memset(buf, 0, len);
1092	EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
1093
1094	/* MSG_PEEK can only peek into the current record. */
1095	len = strlen(test_str_first);
1096	EXPECT_EQ(memcmp(test_str_first, buf, len), 0);
1097
1098	len = strlen(test_str) + 1;
1099	memset(buf, 0, len);
1100	EXPECT_EQ(recv(self->cfd, buf, len, MSG_WAITALL), len);
1101
1102	/* Non-MSG_PEEK will advance strparser (and therefore record)
1103	 * however.
1104	 */
1105	len = strlen(test_str) + 1;
1106	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1107
1108	/* MSG_MORE will hold current record open, so later MSG_PEEK
1109	 * will see everything.
1110	 */
1111	len = strlen(test_str_first);
1112	EXPECT_EQ(send(self->fd, test_str_first, len, MSG_MORE), len);
1113
1114	len = strlen(test_str_second) + 1;
1115	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1116
1117	len = strlen(test_str) + 1;
1118	memset(buf, 0, len);
1119	EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
1120
1121	len = strlen(test_str) + 1;
1122	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1123}
1124
1125TEST_F(tls, recv_peek_large_buf_mult_recs)
1126{
1127	char const *test_str = "test_read_peek_mult_recs";
1128	char const *test_str_first = "test_read_peek";
1129	char const *test_str_second = "_mult_recs";
1130	int len;
1131	char buf[64];
1132
1133	len = strlen(test_str_first);
1134	EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
1135
1136	len = strlen(test_str_second) + 1;
1137	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1138
1139	len = strlen(test_str) + 1;
1140	memset(buf, 0, len);
1141	EXPECT_NE((len = recv(self->cfd, buf, len,
1142			      MSG_PEEK | MSG_WAITALL)), -1);
1143	len = strlen(test_str) + 1;
1144	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1145}
1146
1147TEST_F(tls, recv_lowat)
1148{
1149	char send_mem[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
1150	char recv_mem[20];
1151	int lowat = 8;
1152
1153	EXPECT_EQ(send(self->fd, send_mem, 10, 0), 10);
1154	EXPECT_EQ(send(self->fd, send_mem, 5, 0), 5);
1155
1156	memset(recv_mem, 0, 20);
1157	EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVLOWAT,
1158			     &lowat, sizeof(lowat)), 0);
1159	EXPECT_EQ(recv(self->cfd, recv_mem, 1, MSG_WAITALL), 1);
1160	EXPECT_EQ(recv(self->cfd, recv_mem + 1, 6, MSG_WAITALL), 6);
1161	EXPECT_EQ(recv(self->cfd, recv_mem + 7, 10, 0), 8);
1162
1163	EXPECT_EQ(memcmp(send_mem, recv_mem, 10), 0);
1164	EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
1165}
1166
1167TEST_F(tls, bidir)
1168{
1169	char const *test_str = "test_read";
1170	int send_len = 10;
1171	char buf[10];
1172	int ret;
1173
1174	if (!self->notls) {
1175		struct tls_crypto_info_keys tls12;
1176
1177		tls_crypto_info_init(variant->tls_version, variant->cipher_type,
1178				     &tls12);
1179
1180		ret = setsockopt(self->fd, SOL_TLS, TLS_RX, &tls12,
1181				 tls12.len);
1182		ASSERT_EQ(ret, 0);
1183
1184		ret = setsockopt(self->cfd, SOL_TLS, TLS_TX, &tls12,
1185				 tls12.len);
1186		ASSERT_EQ(ret, 0);
1187	}
1188
1189	ASSERT_EQ(strlen(test_str) + 1, send_len);
1190
1191	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1192	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1193	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1194
1195	memset(buf, 0, sizeof(buf));
1196
1197	EXPECT_EQ(send(self->cfd, test_str, send_len, 0), send_len);
1198	EXPECT_NE(recv(self->fd, buf, send_len, 0), -1);
1199	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1200};
1201
1202TEST_F(tls, pollin)
1203{
1204	char const *test_str = "test_poll";
1205	struct pollfd fd = { 0, 0, 0 };
1206	char buf[10];
1207	int send_len = 10;
1208
1209	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1210	fd.fd = self->cfd;
1211	fd.events = POLLIN;
1212
1213	EXPECT_EQ(poll(&fd, 1, 20), 1);
1214	EXPECT_EQ(fd.revents & POLLIN, 1);
1215	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
1216	/* Test timing out */
1217	EXPECT_EQ(poll(&fd, 1, 20), 0);
1218}
1219
1220TEST_F(tls, poll_wait)
1221{
1222	char const *test_str = "test_poll_wait";
1223	int send_len = strlen(test_str) + 1;
1224	struct pollfd fd = { 0, 0, 0 };
1225	char recv_mem[15];
1226
1227	fd.fd = self->cfd;
1228	fd.events = POLLIN;
1229	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1230	/* Set timeout to inf. secs */
1231	EXPECT_EQ(poll(&fd, 1, -1), 1);
1232	EXPECT_EQ(fd.revents & POLLIN, 1);
1233	EXPECT_EQ(recv(self->cfd, recv_mem, send_len, MSG_WAITALL), send_len);
1234}
1235
1236TEST_F(tls, poll_wait_split)
1237{
1238	struct pollfd fd = { 0, 0, 0 };
1239	char send_mem[20] = {};
1240	char recv_mem[15];
1241
1242	fd.fd = self->cfd;
1243	fd.events = POLLIN;
1244	/* Send 20 bytes */
1245	EXPECT_EQ(send(self->fd, send_mem, sizeof(send_mem), 0),
1246		  sizeof(send_mem));
1247	/* Poll with inf. timeout */
1248	EXPECT_EQ(poll(&fd, 1, -1), 1);
1249	EXPECT_EQ(fd.revents & POLLIN, 1);
1250	EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), MSG_WAITALL),
1251		  sizeof(recv_mem));
1252
1253	/* Now the remaining 5 bytes of record data are in TLS ULP */
1254	fd.fd = self->cfd;
1255	fd.events = POLLIN;
1256	EXPECT_EQ(poll(&fd, 1, -1), 1);
1257	EXPECT_EQ(fd.revents & POLLIN, 1);
1258	EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0),
1259		  sizeof(send_mem) - sizeof(recv_mem));
1260}
1261
1262TEST_F(tls, blocking)
1263{
1264	size_t data = 100000;
1265	int res = fork();
1266
1267	EXPECT_NE(res, -1);
1268
1269	if (res) {
1270		/* parent */
1271		size_t left = data;
1272		char buf[16384];
1273		int status;
1274		int pid2;
1275
1276		while (left) {
1277			int res = send(self->fd, buf,
1278				       left > 16384 ? 16384 : left, 0);
1279
1280			EXPECT_GE(res, 0);
1281			left -= res;
1282		}
1283
1284		pid2 = wait(&status);
1285		EXPECT_EQ(status, 0);
1286		EXPECT_EQ(res, pid2);
1287	} else {
1288		/* child */
1289		size_t left = data;
1290		char buf[16384];
1291
1292		while (left) {
1293			int res = recv(self->cfd, buf,
1294				       left > 16384 ? 16384 : left, 0);
1295
1296			EXPECT_GE(res, 0);
1297			left -= res;
1298		}
1299	}
1300}
1301
1302TEST_F(tls, nonblocking)
1303{
1304	size_t data = 100000;
1305	int sendbuf = 100;
1306	int flags;
1307	int res;
1308
1309	flags = fcntl(self->fd, F_GETFL, 0);
1310	fcntl(self->fd, F_SETFL, flags | O_NONBLOCK);
1311	fcntl(self->cfd, F_SETFL, flags | O_NONBLOCK);
1312
1313	/* Ensure nonblocking behavior by imposing a small send
1314	 * buffer.
1315	 */
1316	EXPECT_EQ(setsockopt(self->fd, SOL_SOCKET, SO_SNDBUF,
1317			     &sendbuf, sizeof(sendbuf)), 0);
1318
1319	res = fork();
1320	EXPECT_NE(res, -1);
1321
1322	if (res) {
1323		/* parent */
1324		bool eagain = false;
1325		size_t left = data;
1326		char buf[16384];
1327		int status;
1328		int pid2;
1329
1330		while (left) {
1331			int res = send(self->fd, buf,
1332				       left > 16384 ? 16384 : left, 0);
1333
1334			if (res == -1 && errno == EAGAIN) {
1335				eagain = true;
1336				usleep(10000);
1337				continue;
1338			}
1339			EXPECT_GE(res, 0);
1340			left -= res;
1341		}
1342
1343		EXPECT_TRUE(eagain);
1344		pid2 = wait(&status);
1345
1346		EXPECT_EQ(status, 0);
1347		EXPECT_EQ(res, pid2);
1348	} else {
1349		/* child */
1350		bool eagain = false;
1351		size_t left = data;
1352		char buf[16384];
1353
1354		while (left) {
1355			int res = recv(self->cfd, buf,
1356				       left > 16384 ? 16384 : left, 0);
1357
1358			if (res == -1 && errno == EAGAIN) {
1359				eagain = true;
1360				usleep(10000);
1361				continue;
1362			}
1363			EXPECT_GE(res, 0);
1364			left -= res;
1365		}
1366		EXPECT_TRUE(eagain);
1367	}
1368}
1369
1370static void
1371test_mutliproc(struct __test_metadata *_metadata, struct _test_data_tls *self,
1372	       bool sendpg, unsigned int n_readers, unsigned int n_writers)
1373{
1374	const unsigned int n_children = n_readers + n_writers;
1375	const size_t data = 6 * 1000 * 1000;
1376	const size_t file_sz = data / 100;
1377	size_t read_bias, write_bias;
1378	int i, fd, child_id;
1379	char buf[file_sz];
1380	pid_t pid;
1381
1382	/* Only allow multiples for simplicity */
1383	ASSERT_EQ(!(n_readers % n_writers) || !(n_writers % n_readers), true);
1384	read_bias = n_writers / n_readers ?: 1;
1385	write_bias = n_readers / n_writers ?: 1;
1386
1387	/* prep a file to send */
1388	fd = open("/tmp/", O_TMPFILE | O_RDWR, 0600);
1389	ASSERT_GE(fd, 0);
1390
1391	memset(buf, 0xac, file_sz);
1392	ASSERT_EQ(write(fd, buf, file_sz), file_sz);
1393
1394	/* spawn children */
1395	for (child_id = 0; child_id < n_children; child_id++) {
1396		pid = fork();
1397		ASSERT_NE(pid, -1);
1398		if (!pid)
1399			break;
1400	}
1401
1402	/* parent waits for all children */
1403	if (pid) {
1404		for (i = 0; i < n_children; i++) {
1405			int status;
1406
1407			wait(&status);
1408			EXPECT_EQ(status, 0);
1409		}
1410
1411		return;
1412	}
1413
1414	/* Split threads for reading and writing */
1415	if (child_id < n_readers) {
1416		size_t left = data * read_bias;
1417		char rb[8001];
1418
1419		while (left) {
1420			int res;
1421
1422			res = recv(self->cfd, rb,
1423				   left > sizeof(rb) ? sizeof(rb) : left, 0);
1424
1425			EXPECT_GE(res, 0);
1426			left -= res;
1427		}
1428	} else {
1429		size_t left = data * write_bias;
1430
1431		while (left) {
1432			int res;
1433
1434			ASSERT_EQ(lseek(fd, 0, SEEK_SET), 0);
1435			if (sendpg)
1436				res = sendfile(self->fd, fd, NULL,
1437					       left > file_sz ? file_sz : left);
1438			else
1439				res = send(self->fd, buf,
1440					   left > file_sz ? file_sz : left, 0);
1441
1442			EXPECT_GE(res, 0);
1443			left -= res;
1444		}
1445	}
1446}
1447
1448TEST_F(tls, mutliproc_even)
1449{
1450	test_mutliproc(_metadata, self, false, 6, 6);
1451}
1452
1453TEST_F(tls, mutliproc_readers)
1454{
1455	test_mutliproc(_metadata, self, false, 4, 12);
1456}
1457
1458TEST_F(tls, mutliproc_writers)
1459{
1460	test_mutliproc(_metadata, self, false, 10, 2);
1461}
1462
1463TEST_F(tls, mutliproc_sendpage_even)
1464{
1465	test_mutliproc(_metadata, self, true, 6, 6);
1466}
1467
1468TEST_F(tls, mutliproc_sendpage_readers)
1469{
1470	test_mutliproc(_metadata, self, true, 4, 12);
1471}
1472
1473TEST_F(tls, mutliproc_sendpage_writers)
1474{
1475	test_mutliproc(_metadata, self, true, 10, 2);
1476}
1477
1478TEST_F(tls, control_msg)
1479{
1480	char *test_str = "test_read";
1481	char record_type = 100;
1482	int send_len = 10;
1483	char buf[10];
1484
1485	if (self->notls)
1486		SKIP(return, "no TLS support");
1487
1488	EXPECT_EQ(tls_send_cmsg(self->fd, record_type, test_str, send_len, 0),
1489		  send_len);
1490	/* Should fail because we didn't provide a control message */
1491	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1492
1493	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1494				buf, sizeof(buf), MSG_WAITALL | MSG_PEEK),
1495		  send_len);
1496	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1497
1498	/* Recv the message again without MSG_PEEK */
1499	memset(buf, 0, sizeof(buf));
1500
1501	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1502				buf, sizeof(buf), MSG_WAITALL),
1503		  send_len);
1504	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1505}
1506
1507TEST_F(tls, control_msg_nomerge)
1508{
1509	char *rec1 = "1111";
1510	char *rec2 = "2222";
1511	int send_len = 5;
1512	char buf[15];
1513
1514	if (self->notls)
1515		SKIP(return, "no TLS support");
1516
1517	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec1, send_len, 0), send_len);
1518	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec2, send_len, 0), send_len);
1519
1520	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), MSG_PEEK), send_len);
1521	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1522
1523	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), MSG_PEEK), send_len);
1524	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1525
1526	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), 0), send_len);
1527	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1528
1529	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), 0), send_len);
1530	EXPECT_EQ(memcmp(buf, rec2, send_len), 0);
1531}
1532
1533TEST_F(tls, data_control_data)
1534{
1535	char *rec1 = "1111";
1536	char *rec2 = "2222";
1537	char *rec3 = "3333";
1538	int send_len = 5;
1539	char buf[15];
1540
1541	if (self->notls)
1542		SKIP(return, "no TLS support");
1543
1544	EXPECT_EQ(send(self->fd, rec1, send_len, 0), send_len);
1545	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec2, send_len, 0), send_len);
1546	EXPECT_EQ(send(self->fd, rec3, send_len, 0), send_len);
1547
1548	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1549	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1550}
1551
1552TEST_F(tls, shutdown)
1553{
1554	char const *test_str = "test_read";
1555	int send_len = 10;
1556	char buf[10];
1557
1558	ASSERT_EQ(strlen(test_str) + 1, send_len);
1559
1560	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1561	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1562	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1563
1564	shutdown(self->fd, SHUT_RDWR);
1565	shutdown(self->cfd, SHUT_RDWR);
1566}
1567
1568TEST_F(tls, shutdown_unsent)
1569{
1570	char const *test_str = "test_read";
1571	int send_len = 10;
1572
1573	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
1574
1575	shutdown(self->fd, SHUT_RDWR);
1576	shutdown(self->cfd, SHUT_RDWR);
1577}
1578
1579TEST_F(tls, shutdown_reuse)
1580{
1581	struct sockaddr_in addr;
1582	int ret;
1583
1584	shutdown(self->fd, SHUT_RDWR);
1585	shutdown(self->cfd, SHUT_RDWR);
1586	close(self->cfd);
1587
1588	addr.sin_family = AF_INET;
1589	addr.sin_addr.s_addr = htonl(INADDR_ANY);
1590	addr.sin_port = 0;
1591
1592	ret = bind(self->fd, &addr, sizeof(addr));
1593	EXPECT_EQ(ret, 0);
1594	ret = listen(self->fd, 10);
1595	EXPECT_EQ(ret, -1);
1596	EXPECT_EQ(errno, EINVAL);
1597
1598	ret = connect(self->fd, &addr, sizeof(addr));
1599	EXPECT_EQ(ret, -1);
1600	EXPECT_EQ(errno, EISCONN);
1601}
1602
1603TEST_F(tls, getsockopt)
1604{
1605	struct tls_crypto_info_keys expect, get;
1606	socklen_t len;
1607
1608	/* get only the version/cipher */
1609	len = sizeof(struct tls_crypto_info);
1610	memrnd(&get, sizeof(get));
1611	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), 0);
1612	EXPECT_EQ(len, sizeof(struct tls_crypto_info));
1613	EXPECT_EQ(get.crypto_info.version, variant->tls_version);
1614	EXPECT_EQ(get.crypto_info.cipher_type, variant->cipher_type);
1615
1616	/* get the full crypto_info */
1617	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &expect);
1618	len = expect.len;
1619	memrnd(&get, sizeof(get));
1620	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), 0);
1621	EXPECT_EQ(len, expect.len);
1622	EXPECT_EQ(get.crypto_info.version, variant->tls_version);
1623	EXPECT_EQ(get.crypto_info.cipher_type, variant->cipher_type);
1624	EXPECT_EQ(memcmp(&get, &expect, expect.len), 0);
1625
1626	/* short get should fail */
1627	len = sizeof(struct tls_crypto_info) - 1;
1628	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), -1);
1629	EXPECT_EQ(errno, EINVAL);
1630
1631	/* partial get of the cipher data should fail */
1632	len = expect.len - 1;
1633	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), -1);
1634	EXPECT_EQ(errno, EINVAL);
1635}
1636
1637TEST_F(tls, recv_efault)
1638{
1639	char *rec1 = "1111111111";
1640	char *rec2 = "2222222222";
1641	struct msghdr hdr = {};
1642	struct iovec iov[2];
1643	char recv_mem[12];
1644	int ret;
1645
1646	if (self->notls)
1647		SKIP(return, "no TLS support");
1648
1649	EXPECT_EQ(send(self->fd, rec1, 10, 0), 10);
1650	EXPECT_EQ(send(self->fd, rec2, 10, 0), 10);
1651
1652	iov[0].iov_base = recv_mem;
1653	iov[0].iov_len = sizeof(recv_mem);
1654	iov[1].iov_base = NULL; /* broken iov to make process_rx_list fail */
1655	iov[1].iov_len = 1;
1656
1657	hdr.msg_iovlen = 2;
1658	hdr.msg_iov = iov;
1659
1660	EXPECT_EQ(recv(self->cfd, recv_mem, 1, 0), 1);
1661	EXPECT_EQ(recv_mem[0], rec1[0]);
1662
1663	ret = recvmsg(self->cfd, &hdr, 0);
1664	EXPECT_LE(ret, sizeof(recv_mem));
1665	EXPECT_GE(ret, 9);
1666	EXPECT_EQ(memcmp(rec1, recv_mem, 9), 0);
1667	if (ret > 9)
1668		EXPECT_EQ(memcmp(rec2, recv_mem + 9, ret - 9), 0);
1669}
1670
1671FIXTURE(tls_err)
1672{
1673	int fd, cfd;
1674	int fd2, cfd2;
1675	bool notls;
1676};
1677
1678FIXTURE_VARIANT(tls_err)
1679{
1680	uint16_t tls_version;
1681};
1682
1683FIXTURE_VARIANT_ADD(tls_err, 12_aes_gcm)
1684{
1685	.tls_version = TLS_1_2_VERSION,
1686};
1687
1688FIXTURE_VARIANT_ADD(tls_err, 13_aes_gcm)
1689{
1690	.tls_version = TLS_1_3_VERSION,
1691};
1692
1693FIXTURE_SETUP(tls_err)
1694{
1695	struct tls_crypto_info_keys tls12;
1696	int ret;
1697
1698	tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_128,
1699			     &tls12);
1700
1701	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
1702	ulp_sock_pair(_metadata, &self->fd2, &self->cfd2, &self->notls);
1703	if (self->notls)
1704		return;
1705
1706	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
1707	ASSERT_EQ(ret, 0);
1708
1709	ret = setsockopt(self->cfd2, SOL_TLS, TLS_RX, &tls12, tls12.len);
1710	ASSERT_EQ(ret, 0);
1711}
1712
1713FIXTURE_TEARDOWN(tls_err)
1714{
1715	close(self->fd);
1716	close(self->cfd);
1717	close(self->fd2);
1718	close(self->cfd2);
1719}
1720
1721TEST_F(tls_err, bad_rec)
1722{
1723	char buf[64];
1724
1725	if (self->notls)
1726		SKIP(return, "no TLS support");
1727
1728	memset(buf, 0x55, sizeof(buf));
1729	EXPECT_EQ(send(self->fd2, buf, sizeof(buf), 0), sizeof(buf));
1730	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1731	EXPECT_EQ(errno, EMSGSIZE);
1732	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), MSG_DONTWAIT), -1);
1733	EXPECT_EQ(errno, EAGAIN);
1734}
1735
1736TEST_F(tls_err, bad_auth)
1737{
1738	char buf[128];
1739	int n;
1740
1741	if (self->notls)
1742		SKIP(return, "no TLS support");
1743
1744	memrnd(buf, sizeof(buf) / 2);
1745	EXPECT_EQ(send(self->fd, buf, sizeof(buf) / 2, 0), sizeof(buf) / 2);
1746	n = recv(self->cfd, buf, sizeof(buf), 0);
1747	EXPECT_GT(n, sizeof(buf) / 2);
1748
1749	buf[n - 1]++;
1750
1751	EXPECT_EQ(send(self->fd2, buf, n, 0), n);
1752	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1753	EXPECT_EQ(errno, EBADMSG);
1754	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1755	EXPECT_EQ(errno, EBADMSG);
1756}
1757
1758TEST_F(tls_err, bad_in_large_read)
1759{
1760	char txt[3][64];
1761	char cip[3][128];
1762	char buf[3 * 128];
1763	int i, n;
1764
1765	if (self->notls)
1766		SKIP(return, "no TLS support");
1767
1768	/* Put 3 records in the sockets */
1769	for (i = 0; i < 3; i++) {
1770		memrnd(txt[i], sizeof(txt[i]));
1771		EXPECT_EQ(send(self->fd, txt[i], sizeof(txt[i]), 0),
1772			  sizeof(txt[i]));
1773		n = recv(self->cfd, cip[i], sizeof(cip[i]), 0);
1774		EXPECT_GT(n, sizeof(txt[i]));
1775		/* Break the third message */
1776		if (i == 2)
1777			cip[2][n - 1]++;
1778		EXPECT_EQ(send(self->fd2, cip[i], n, 0), n);
1779	}
1780
1781	/* We should be able to receive the first two messages */
1782	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt[0]) * 2);
1783	EXPECT_EQ(memcmp(buf, txt[0], sizeof(txt[0])), 0);
1784	EXPECT_EQ(memcmp(buf + sizeof(txt[0]), txt[1], sizeof(txt[1])), 0);
1785	/* Third mesasge is bad */
1786	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1787	EXPECT_EQ(errno, EBADMSG);
1788	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1789	EXPECT_EQ(errno, EBADMSG);
1790}
1791
1792TEST_F(tls_err, bad_cmsg)
1793{
1794	char *test_str = "test_read";
1795	int send_len = 10;
1796	char cip[128];
1797	char buf[128];
1798	char txt[64];
1799	int n;
1800
1801	if (self->notls)
1802		SKIP(return, "no TLS support");
1803
1804	/* Queue up one data record */
1805	memrnd(txt, sizeof(txt));
1806	EXPECT_EQ(send(self->fd, txt, sizeof(txt), 0), sizeof(txt));
1807	n = recv(self->cfd, cip, sizeof(cip), 0);
1808	EXPECT_GT(n, sizeof(txt));
1809	EXPECT_EQ(send(self->fd2, cip, n, 0), n);
1810
1811	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
1812	n = recv(self->cfd, cip, sizeof(cip), 0);
1813	cip[n - 1]++; /* Break it */
1814	EXPECT_GT(n, send_len);
1815	EXPECT_EQ(send(self->fd2, cip, n, 0), n);
1816
1817	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt));
1818	EXPECT_EQ(memcmp(buf, txt, sizeof(txt)), 0);
1819	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1820	EXPECT_EQ(errno, EBADMSG);
1821	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1822	EXPECT_EQ(errno, EBADMSG);
1823}
1824
1825TEST_F(tls_err, timeo)
1826{
1827	struct timeval tv = { .tv_usec = 10000, };
1828	char buf[128];
1829	int ret;
1830
1831	if (self->notls)
1832		SKIP(return, "no TLS support");
1833
1834	ret = setsockopt(self->cfd2, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
1835	ASSERT_EQ(ret, 0);
1836
1837	ret = fork();
1838	ASSERT_GE(ret, 0);
1839
1840	if (ret) {
1841		usleep(1000); /* Give child a head start */
1842
1843		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1844		EXPECT_EQ(errno, EAGAIN);
1845
1846		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1847		EXPECT_EQ(errno, EAGAIN);
1848
1849		wait(&ret);
1850	} else {
1851		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1852		EXPECT_EQ(errno, EAGAIN);
1853		exit(0);
1854	}
1855}
1856
1857TEST_F(tls_err, poll_partial_rec)
1858{
1859	struct pollfd pfd = { };
1860	ssize_t rec_len;
1861	char rec[256];
1862	char buf[128];
1863
1864	if (self->notls)
1865		SKIP(return, "no TLS support");
1866
1867	pfd.fd = self->cfd2;
1868	pfd.events = POLLIN;
1869	EXPECT_EQ(poll(&pfd, 1, 1), 0);
1870
1871	memrnd(buf, sizeof(buf));
1872	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
1873	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
1874	EXPECT_GT(rec_len, sizeof(buf));
1875
1876	/* Write 100B, not the full record ... */
1877	EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
1878	/* ... no full record should mean no POLLIN */
1879	pfd.fd = self->cfd2;
1880	pfd.events = POLLIN;
1881	EXPECT_EQ(poll(&pfd, 1, 1), 0);
1882	/* Now write the rest, and it should all pop out of the other end. */
1883	EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0), rec_len - 100);
1884	pfd.fd = self->cfd2;
1885	pfd.events = POLLIN;
1886	EXPECT_EQ(poll(&pfd, 1, 1), 1);
1887	EXPECT_EQ(recv(self->cfd2, rec, sizeof(rec), 0), sizeof(buf));
1888	EXPECT_EQ(memcmp(buf, rec, sizeof(buf)), 0);
1889}
1890
1891TEST_F(tls_err, epoll_partial_rec)
1892{
1893	struct epoll_event ev, events[10];
1894	ssize_t rec_len;
1895	char rec[256];
1896	char buf[128];
1897	int epollfd;
1898
1899	if (self->notls)
1900		SKIP(return, "no TLS support");
1901
1902	epollfd = epoll_create1(0);
1903	ASSERT_GE(epollfd, 0);
1904
1905	memset(&ev, 0, sizeof(ev));
1906	ev.events = EPOLLIN;
1907	ev.data.fd = self->cfd2;
1908	ASSERT_GE(epoll_ctl(epollfd, EPOLL_CTL_ADD, self->cfd2, &ev), 0);
1909
1910	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 0);
1911
1912	memrnd(buf, sizeof(buf));
1913	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
1914	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
1915	EXPECT_GT(rec_len, sizeof(buf));
1916
1917	/* Write 100B, not the full record ... */
1918	EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
1919	/* ... no full record should mean no POLLIN */
1920	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 0);
1921	/* Now write the rest, and it should all pop out of the other end. */
1922	EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0), rec_len - 100);
1923	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 1);
1924	EXPECT_EQ(recv(self->cfd2, rec, sizeof(rec), 0), sizeof(buf));
1925	EXPECT_EQ(memcmp(buf, rec, sizeof(buf)), 0);
1926
1927	close(epollfd);
1928}
1929
1930TEST_F(tls_err, poll_partial_rec_async)
1931{
1932	struct pollfd pfd = { };
1933	ssize_t rec_len;
1934	char rec[256];
1935	char buf[128];
1936	char token;
1937	int p[2];
1938	int ret;
1939
1940	if (self->notls)
1941		SKIP(return, "no TLS support");
1942
1943	ASSERT_GE(pipe(p), 0);
1944
1945	memrnd(buf, sizeof(buf));
1946	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
1947	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
1948	EXPECT_GT(rec_len, sizeof(buf));
1949
1950	ret = fork();
1951	ASSERT_GE(ret, 0);
1952
1953	if (ret) {
1954		int status, pid2;
1955
1956		close(p[1]);
1957		usleep(1000); /* Give child a head start */
1958
1959		EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
1960
1961		EXPECT_EQ(read(p[0], &token, 1), 1); /* Barrier #1 */
1962
1963		EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0),
1964			  rec_len - 100);
1965
1966		pid2 = wait(&status);
1967		EXPECT_EQ(pid2, ret);
1968		EXPECT_EQ(status, 0);
1969	} else {
1970		close(p[0]);
1971
1972		/* Child should sleep in poll(), never get a wake */
1973		pfd.fd = self->cfd2;
1974		pfd.events = POLLIN;
1975		EXPECT_EQ(poll(&pfd, 1, 20), 0);
1976
1977		EXPECT_EQ(write(p[1], &token, 1), 1); /* Barrier #1 */
1978
1979		pfd.fd = self->cfd2;
1980		pfd.events = POLLIN;
1981		EXPECT_EQ(poll(&pfd, 1, 20), 1);
1982
1983		exit(!__test_passed(_metadata));
1984	}
1985}
1986
1987TEST(non_established) {
1988	struct tls12_crypto_info_aes_gcm_256 tls12;
1989	struct sockaddr_in addr;
1990	int sfd, ret, fd;
1991	socklen_t len;
1992
1993	len = sizeof(addr);
1994
1995	memset(&tls12, 0, sizeof(tls12));
1996	tls12.info.version = TLS_1_2_VERSION;
1997	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
1998
1999	addr.sin_family = AF_INET;
2000	addr.sin_addr.s_addr = htonl(INADDR_ANY);
2001	addr.sin_port = 0;
2002
2003	fd = socket(AF_INET, SOCK_STREAM, 0);
2004	sfd = socket(AF_INET, SOCK_STREAM, 0);
2005
2006	ret = bind(sfd, &addr, sizeof(addr));
2007	ASSERT_EQ(ret, 0);
2008	ret = listen(sfd, 10);
2009	ASSERT_EQ(ret, 0);
2010
2011	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2012	EXPECT_EQ(ret, -1);
2013	/* TLS ULP not supported */
2014	if (errno == ENOENT)
2015		return;
2016	EXPECT_EQ(errno, ENOTCONN);
2017
2018	ret = setsockopt(sfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2019	EXPECT_EQ(ret, -1);
2020	EXPECT_EQ(errno, ENOTCONN);
2021
2022	ret = getsockname(sfd, &addr, &len);
2023	ASSERT_EQ(ret, 0);
2024
2025	ret = connect(fd, &addr, sizeof(addr));
2026	ASSERT_EQ(ret, 0);
2027
2028	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2029	ASSERT_EQ(ret, 0);
2030
2031	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2032	EXPECT_EQ(ret, -1);
2033	EXPECT_EQ(errno, EEXIST);
2034
2035	close(fd);
2036	close(sfd);
2037}
2038
2039TEST(keysizes) {
2040	struct tls12_crypto_info_aes_gcm_256 tls12;
2041	int ret, fd, cfd;
2042	bool notls;
2043
2044	memset(&tls12, 0, sizeof(tls12));
2045	tls12.info.version = TLS_1_2_VERSION;
2046	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2047
2048	ulp_sock_pair(_metadata, &fd, &cfd, &notls);
2049
2050	if (!notls) {
2051		ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
2052				 sizeof(tls12));
2053		EXPECT_EQ(ret, 0);
2054
2055		ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
2056				 sizeof(tls12));
2057		EXPECT_EQ(ret, 0);
2058	}
2059
2060	close(fd);
2061	close(cfd);
2062}
2063
2064TEST(no_pad) {
2065	struct tls12_crypto_info_aes_gcm_256 tls12;
2066	int ret, fd, cfd, val;
2067	socklen_t len;
2068	bool notls;
2069
2070	memset(&tls12, 0, sizeof(tls12));
2071	tls12.info.version = TLS_1_3_VERSION;
2072	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2073
2074	ulp_sock_pair(_metadata, &fd, &cfd, &notls);
2075
2076	if (notls)
2077		exit(KSFT_SKIP);
2078
2079	ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, sizeof(tls12));
2080	EXPECT_EQ(ret, 0);
2081
2082	ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12, sizeof(tls12));
2083	EXPECT_EQ(ret, 0);
2084
2085	val = 1;
2086	ret = setsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2087			 (void *)&val, sizeof(val));
2088	EXPECT_EQ(ret, 0);
2089
2090	len = sizeof(val);
2091	val = 2;
2092	ret = getsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2093			 (void *)&val, &len);
2094	EXPECT_EQ(ret, 0);
2095	EXPECT_EQ(val, 1);
2096	EXPECT_EQ(len, 4);
2097
2098	val = 0;
2099	ret = setsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2100			 (void *)&val, sizeof(val));
2101	EXPECT_EQ(ret, 0);
2102
2103	len = sizeof(val);
2104	val = 2;
2105	ret = getsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2106			 (void *)&val, &len);
2107	EXPECT_EQ(ret, 0);
2108	EXPECT_EQ(val, 0);
2109	EXPECT_EQ(len, 4);
2110
2111	close(fd);
2112	close(cfd);
2113}
2114
2115TEST(tls_v6ops) {
2116	struct tls_crypto_info_keys tls12;
2117	struct sockaddr_in6 addr, addr2;
2118	int sfd, ret, fd;
2119	socklen_t len, len2;
2120
2121	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12);
2122
2123	addr.sin6_family = AF_INET6;
2124	addr.sin6_addr = in6addr_any;
2125	addr.sin6_port = 0;
2126
2127	fd = socket(AF_INET6, SOCK_STREAM, 0);
2128	sfd = socket(AF_INET6, SOCK_STREAM, 0);
2129
2130	ret = bind(sfd, &addr, sizeof(addr));
2131	ASSERT_EQ(ret, 0);
2132	ret = listen(sfd, 10);
2133	ASSERT_EQ(ret, 0);
2134
2135	len = sizeof(addr);
2136	ret = getsockname(sfd, &addr, &len);
2137	ASSERT_EQ(ret, 0);
2138
2139	ret = connect(fd, &addr, sizeof(addr));
2140	ASSERT_EQ(ret, 0);
2141
2142	len = sizeof(addr);
2143	ret = getsockname(fd, &addr, &len);
2144	ASSERT_EQ(ret, 0);
2145
2146	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2147	if (ret) {
2148		ASSERT_EQ(errno, ENOENT);
2149		SKIP(return, "no TLS support");
2150	}
2151	ASSERT_EQ(ret, 0);
2152
2153	ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
2154	ASSERT_EQ(ret, 0);
2155
2156	ret = setsockopt(fd, SOL_TLS, TLS_RX, &tls12, tls12.len);
2157	ASSERT_EQ(ret, 0);
2158
2159	len2 = sizeof(addr2);
2160	ret = getsockname(fd, &addr2, &len2);
2161	ASSERT_EQ(ret, 0);
2162
2163	EXPECT_EQ(len2, len);
2164	EXPECT_EQ(memcmp(&addr, &addr2, len), 0);
2165
2166	close(fd);
2167	close(sfd);
2168}
2169
2170TEST(prequeue) {
2171	struct tls_crypto_info_keys tls12;
2172	char buf[20000], buf2[20000];
2173	struct sockaddr_in addr;
2174	int sfd, cfd, ret, fd;
2175	socklen_t len;
2176
2177	len = sizeof(addr);
2178	memrnd(buf, sizeof(buf));
2179
2180	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_256, &tls12);
2181
2182	addr.sin_family = AF_INET;
2183	addr.sin_addr.s_addr = htonl(INADDR_ANY);
2184	addr.sin_port = 0;
2185
2186	fd = socket(AF_INET, SOCK_STREAM, 0);
2187	sfd = socket(AF_INET, SOCK_STREAM, 0);
2188
2189	ASSERT_EQ(bind(sfd, &addr, sizeof(addr)), 0);
2190	ASSERT_EQ(listen(sfd, 10), 0);
2191	ASSERT_EQ(getsockname(sfd, &addr, &len), 0);
2192	ASSERT_EQ(connect(fd, &addr, sizeof(addr)), 0);
2193	ASSERT_GE(cfd = accept(sfd, &addr, &len), 0);
2194	close(sfd);
2195
2196	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2197	if (ret) {
2198		ASSERT_EQ(errno, ENOENT);
2199		SKIP(return, "no TLS support");
2200	}
2201
2202	ASSERT_EQ(setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2203	EXPECT_EQ(send(fd, buf, sizeof(buf), MSG_DONTWAIT), sizeof(buf));
2204
2205	ASSERT_EQ(setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")), 0);
2206	ASSERT_EQ(setsockopt(cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2207	EXPECT_EQ(recv(cfd, buf2, sizeof(buf2), MSG_WAITALL), sizeof(buf2));
2208
2209	EXPECT_EQ(memcmp(buf, buf2, sizeof(buf)), 0);
2210
2211	close(fd);
2212	close(cfd);
2213}
2214
2215static void __attribute__((constructor)) fips_check(void) {
2216	int res;
2217	FILE *f;
2218
2219	f = fopen("/proc/sys/crypto/fips_enabled", "r");
2220	if (f) {
2221		res = fscanf(f, "%d", &fips_enabled);
2222		if (res != 1)
2223			ksft_print_msg("ERROR: Couldn't read /proc/sys/crypto/fips_enabled\n");
2224		fclose(f);
2225	}
2226}
2227
2228TEST_HARNESS_MAIN