Linux Audio

Check our new training course

Loading...
Note: File does not exist in v4.10.11.
   1// SPDX-License-Identifier: GPL-2.0
   2
   3#define _GNU_SOURCE
   4
   5#include <arpa/inet.h>
   6#include <errno.h>
   7#include <error.h>
   8#include <fcntl.h>
   9#include <poll.h>
  10#include <stdio.h>
  11#include <stdlib.h>
  12#include <unistd.h>
  13
  14#include <linux/tls.h>
  15#include <linux/tcp.h>
  16#include <linux/socket.h>
  17
  18#include <sys/epoll.h>
  19#include <sys/types.h>
  20#include <sys/sendfile.h>
  21#include <sys/socket.h>
  22#include <sys/stat.h>
  23
  24#include "../kselftest_harness.h"
  25
  26#define TLS_PAYLOAD_MAX_LEN 16384
  27#define SOL_TLS 282
  28
  29static int fips_enabled;
  30
  31struct tls_crypto_info_keys {
  32	union {
  33		struct tls_crypto_info crypto_info;
  34		struct tls12_crypto_info_aes_gcm_128 aes128;
  35		struct tls12_crypto_info_chacha20_poly1305 chacha20;
  36		struct tls12_crypto_info_sm4_gcm sm4gcm;
  37		struct tls12_crypto_info_sm4_ccm sm4ccm;
  38		struct tls12_crypto_info_aes_ccm_128 aesccm128;
  39		struct tls12_crypto_info_aes_gcm_256 aesgcm256;
  40		struct tls12_crypto_info_aria_gcm_128 ariagcm128;
  41		struct tls12_crypto_info_aria_gcm_256 ariagcm256;
  42	};
  43	size_t len;
  44};
  45
  46static void tls_crypto_info_init(uint16_t tls_version, uint16_t cipher_type,
  47				 struct tls_crypto_info_keys *tls12)
  48{
  49	memset(tls12, 0, sizeof(*tls12));
  50
  51	switch (cipher_type) {
  52	case TLS_CIPHER_CHACHA20_POLY1305:
  53		tls12->len = sizeof(struct tls12_crypto_info_chacha20_poly1305);
  54		tls12->chacha20.info.version = tls_version;
  55		tls12->chacha20.info.cipher_type = cipher_type;
  56		break;
  57	case TLS_CIPHER_AES_GCM_128:
  58		tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_128);
  59		tls12->aes128.info.version = tls_version;
  60		tls12->aes128.info.cipher_type = cipher_type;
  61		break;
  62	case TLS_CIPHER_SM4_GCM:
  63		tls12->len = sizeof(struct tls12_crypto_info_sm4_gcm);
  64		tls12->sm4gcm.info.version = tls_version;
  65		tls12->sm4gcm.info.cipher_type = cipher_type;
  66		break;
  67	case TLS_CIPHER_SM4_CCM:
  68		tls12->len = sizeof(struct tls12_crypto_info_sm4_ccm);
  69		tls12->sm4ccm.info.version = tls_version;
  70		tls12->sm4ccm.info.cipher_type = cipher_type;
  71		break;
  72	case TLS_CIPHER_AES_CCM_128:
  73		tls12->len = sizeof(struct tls12_crypto_info_aes_ccm_128);
  74		tls12->aesccm128.info.version = tls_version;
  75		tls12->aesccm128.info.cipher_type = cipher_type;
  76		break;
  77	case TLS_CIPHER_AES_GCM_256:
  78		tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_256);
  79		tls12->aesgcm256.info.version = tls_version;
  80		tls12->aesgcm256.info.cipher_type = cipher_type;
  81		break;
  82	case TLS_CIPHER_ARIA_GCM_128:
  83		tls12->len = sizeof(struct tls12_crypto_info_aria_gcm_128);
  84		tls12->ariagcm128.info.version = tls_version;
  85		tls12->ariagcm128.info.cipher_type = cipher_type;
  86		break;
  87	case TLS_CIPHER_ARIA_GCM_256:
  88		tls12->len = sizeof(struct tls12_crypto_info_aria_gcm_256);
  89		tls12->ariagcm256.info.version = tls_version;
  90		tls12->ariagcm256.info.cipher_type = cipher_type;
  91		break;
  92	default:
  93		break;
  94	}
  95}
  96
  97static void memrnd(void *s, size_t n)
  98{
  99	int *dword = s;
 100	char *byte;
 101
 102	for (; n >= 4; n -= 4)
 103		*dword++ = rand();
 104	byte = (void *)dword;
 105	while (n--)
 106		*byte++ = rand();
 107}
 108
 109static void ulp_sock_pair(struct __test_metadata *_metadata,
 110			  int *fd, int *cfd, bool *notls)
 111{
 112	struct sockaddr_in addr;
 113	socklen_t len;
 114	int sfd, ret;
 115
 116	*notls = false;
 117	len = sizeof(addr);
 118
 119	addr.sin_family = AF_INET;
 120	addr.sin_addr.s_addr = htonl(INADDR_ANY);
 121	addr.sin_port = 0;
 122
 123	*fd = socket(AF_INET, SOCK_STREAM, 0);
 124	sfd = socket(AF_INET, SOCK_STREAM, 0);
 125
 126	ret = bind(sfd, &addr, sizeof(addr));
 127	ASSERT_EQ(ret, 0);
 128	ret = listen(sfd, 10);
 129	ASSERT_EQ(ret, 0);
 130
 131	ret = getsockname(sfd, &addr, &len);
 132	ASSERT_EQ(ret, 0);
 133
 134	ret = connect(*fd, &addr, sizeof(addr));
 135	ASSERT_EQ(ret, 0);
 136
 137	*cfd = accept(sfd, &addr, &len);
 138	ASSERT_GE(*cfd, 0);
 139
 140	close(sfd);
 141
 142	ret = setsockopt(*fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
 143	if (ret != 0) {
 144		ASSERT_EQ(errno, ENOENT);
 145		*notls = true;
 146		printf("Failure setting TCP_ULP, testing without tls\n");
 147		return;
 148	}
 149
 150	ret = setsockopt(*cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
 151	ASSERT_EQ(ret, 0);
 152}
 153
 154/* Produce a basic cmsg */
 155static int tls_send_cmsg(int fd, unsigned char record_type,
 156			 void *data, size_t len, int flags)
 157{
 158	char cbuf[CMSG_SPACE(sizeof(char))];
 159	int cmsg_len = sizeof(char);
 160	struct cmsghdr *cmsg;
 161	struct msghdr msg;
 162	struct iovec vec;
 163
 164	vec.iov_base = data;
 165	vec.iov_len = len;
 166	memset(&msg, 0, sizeof(struct msghdr));
 167	msg.msg_iov = &vec;
 168	msg.msg_iovlen = 1;
 169	msg.msg_control = cbuf;
 170	msg.msg_controllen = sizeof(cbuf);
 171	cmsg = CMSG_FIRSTHDR(&msg);
 172	cmsg->cmsg_level = SOL_TLS;
 173	/* test sending non-record types. */
 174	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
 175	cmsg->cmsg_len = CMSG_LEN(cmsg_len);
 176	*CMSG_DATA(cmsg) = record_type;
 177	msg.msg_controllen = cmsg->cmsg_len;
 178
 179	return sendmsg(fd, &msg, flags);
 180}
 181
 182static int tls_recv_cmsg(struct __test_metadata *_metadata,
 183			 int fd, unsigned char record_type,
 184			 void *data, size_t len, int flags)
 185{
 186	char cbuf[CMSG_SPACE(sizeof(char))];
 187	struct cmsghdr *cmsg;
 188	unsigned char ctype;
 189	struct msghdr msg;
 190	struct iovec vec;
 191	int n;
 192
 193	vec.iov_base = data;
 194	vec.iov_len = len;
 195	memset(&msg, 0, sizeof(struct msghdr));
 196	msg.msg_iov = &vec;
 197	msg.msg_iovlen = 1;
 198	msg.msg_control = cbuf;
 199	msg.msg_controllen = sizeof(cbuf);
 200
 201	n = recvmsg(fd, &msg, flags);
 202
 203	cmsg = CMSG_FIRSTHDR(&msg);
 204	EXPECT_NE(cmsg, NULL);
 205	EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
 206	EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
 207	ctype = *((unsigned char *)CMSG_DATA(cmsg));
 208	EXPECT_EQ(ctype, record_type);
 209
 210	return n;
 211}
 212
 213FIXTURE(tls_basic)
 214{
 215	int fd, cfd;
 216	bool notls;
 217};
 218
 219FIXTURE_SETUP(tls_basic)
 220{
 221	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
 222}
 223
 224FIXTURE_TEARDOWN(tls_basic)
 225{
 226	close(self->fd);
 227	close(self->cfd);
 228}
 229
 230/* Send some data through with ULP but no keys */
 231TEST_F(tls_basic, base_base)
 232{
 233	char const *test_str = "test_read";
 234	int send_len = 10;
 235	char buf[10];
 236
 237	ASSERT_EQ(strlen(test_str) + 1, send_len);
 238
 239	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
 240	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
 241	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
 242};
 243
 244TEST_F(tls_basic, bad_cipher)
 245{
 246	struct tls_crypto_info_keys tls12;
 247
 248	tls12.crypto_info.version = 200;
 249	tls12.crypto_info.cipher_type = TLS_CIPHER_AES_GCM_128;
 250	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
 251
 252	tls12.crypto_info.version = TLS_1_2_VERSION;
 253	tls12.crypto_info.cipher_type = 50;
 254	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
 255
 256	tls12.crypto_info.version = TLS_1_2_VERSION;
 257	tls12.crypto_info.cipher_type = 59;
 258	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
 259
 260	tls12.crypto_info.version = TLS_1_2_VERSION;
 261	tls12.crypto_info.cipher_type = 10;
 262	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
 263
 264	tls12.crypto_info.version = TLS_1_2_VERSION;
 265	tls12.crypto_info.cipher_type = 70;
 266	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
 267}
 268
 269FIXTURE(tls)
 270{
 271	int fd, cfd;
 272	bool notls;
 273};
 274
 275FIXTURE_VARIANT(tls)
 276{
 277	uint16_t tls_version;
 278	uint16_t cipher_type;
 279	bool nopad, fips_non_compliant;
 280};
 281
 282FIXTURE_VARIANT_ADD(tls, 12_aes_gcm)
 283{
 284	.tls_version = TLS_1_2_VERSION,
 285	.cipher_type = TLS_CIPHER_AES_GCM_128,
 286};
 287
 288FIXTURE_VARIANT_ADD(tls, 13_aes_gcm)
 289{
 290	.tls_version = TLS_1_3_VERSION,
 291	.cipher_type = TLS_CIPHER_AES_GCM_128,
 292};
 293
 294FIXTURE_VARIANT_ADD(tls, 12_chacha)
 295{
 296	.tls_version = TLS_1_2_VERSION,
 297	.cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
 298	.fips_non_compliant = true,
 299};
 300
 301FIXTURE_VARIANT_ADD(tls, 13_chacha)
 302{
 303	.tls_version = TLS_1_3_VERSION,
 304	.cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
 305	.fips_non_compliant = true,
 306};
 307
 308FIXTURE_VARIANT_ADD(tls, 13_sm4_gcm)
 309{
 310	.tls_version = TLS_1_3_VERSION,
 311	.cipher_type = TLS_CIPHER_SM4_GCM,
 312	.fips_non_compliant = true,
 313};
 314
 315FIXTURE_VARIANT_ADD(tls, 13_sm4_ccm)
 316{
 317	.tls_version = TLS_1_3_VERSION,
 318	.cipher_type = TLS_CIPHER_SM4_CCM,
 319	.fips_non_compliant = true,
 320};
 321
 322FIXTURE_VARIANT_ADD(tls, 12_aes_ccm)
 323{
 324	.tls_version = TLS_1_2_VERSION,
 325	.cipher_type = TLS_CIPHER_AES_CCM_128,
 326};
 327
 328FIXTURE_VARIANT_ADD(tls, 13_aes_ccm)
 329{
 330	.tls_version = TLS_1_3_VERSION,
 331	.cipher_type = TLS_CIPHER_AES_CCM_128,
 332};
 333
 334FIXTURE_VARIANT_ADD(tls, 12_aes_gcm_256)
 335{
 336	.tls_version = TLS_1_2_VERSION,
 337	.cipher_type = TLS_CIPHER_AES_GCM_256,
 338};
 339
 340FIXTURE_VARIANT_ADD(tls, 13_aes_gcm_256)
 341{
 342	.tls_version = TLS_1_3_VERSION,
 343	.cipher_type = TLS_CIPHER_AES_GCM_256,
 344};
 345
 346FIXTURE_VARIANT_ADD(tls, 13_nopad)
 347{
 348	.tls_version = TLS_1_3_VERSION,
 349	.cipher_type = TLS_CIPHER_AES_GCM_128,
 350	.nopad = true,
 351};
 352
 353FIXTURE_VARIANT_ADD(tls, 12_aria_gcm)
 354{
 355	.tls_version = TLS_1_2_VERSION,
 356	.cipher_type = TLS_CIPHER_ARIA_GCM_128,
 357};
 358
 359FIXTURE_VARIANT_ADD(tls, 12_aria_gcm_256)
 360{
 361	.tls_version = TLS_1_2_VERSION,
 362	.cipher_type = TLS_CIPHER_ARIA_GCM_256,
 363};
 364
 365FIXTURE_SETUP(tls)
 366{
 367	struct tls_crypto_info_keys tls12;
 368	int one = 1;
 369	int ret;
 370
 371	if (fips_enabled && variant->fips_non_compliant)
 372		SKIP(return, "Unsupported cipher in FIPS mode");
 373
 374	tls_crypto_info_init(variant->tls_version, variant->cipher_type,
 375			     &tls12);
 376
 377	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
 378
 379	if (self->notls)
 380		return;
 381
 382	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
 383	ASSERT_EQ(ret, 0);
 384
 385	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len);
 386	ASSERT_EQ(ret, 0);
 387
 388	if (variant->nopad) {
 389		ret = setsockopt(self->cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
 390				 (void *)&one, sizeof(one));
 391		ASSERT_EQ(ret, 0);
 392	}
 393}
 394
 395FIXTURE_TEARDOWN(tls)
 396{
 397	close(self->fd);
 398	close(self->cfd);
 399}
 400
 401TEST_F(tls, sendfile)
 402{
 403	int filefd = open("/proc/self/exe", O_RDONLY);
 404	struct stat st;
 405
 406	EXPECT_GE(filefd, 0);
 407	fstat(filefd, &st);
 408	EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
 409}
 410
 411TEST_F(tls, send_then_sendfile)
 412{
 413	int filefd = open("/proc/self/exe", O_RDONLY);
 414	char const *test_str = "test_send";
 415	int to_send = strlen(test_str) + 1;
 416	char recv_buf[10];
 417	struct stat st;
 418	char *buf;
 419
 420	EXPECT_GE(filefd, 0);
 421	fstat(filefd, &st);
 422	buf = (char *)malloc(st.st_size);
 423
 424	EXPECT_EQ(send(self->fd, test_str, to_send, 0), to_send);
 425	EXPECT_EQ(recv(self->cfd, recv_buf, to_send, MSG_WAITALL), to_send);
 426	EXPECT_EQ(memcmp(test_str, recv_buf, to_send), 0);
 427
 428	EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
 429	EXPECT_EQ(recv(self->cfd, buf, st.st_size, MSG_WAITALL), st.st_size);
 430}
 431
 432static void chunked_sendfile(struct __test_metadata *_metadata,
 433			     struct _test_data_tls *self,
 434			     uint16_t chunk_size,
 435			     uint16_t extra_payload_size)
 436{
 437	char buf[TLS_PAYLOAD_MAX_LEN];
 438	uint16_t test_payload_size;
 439	int size = 0;
 440	int ret;
 441	char filename[] = "/tmp/mytemp.XXXXXX";
 442	int fd = mkstemp(filename);
 443	off_t offset = 0;
 444
 445	unlink(filename);
 446	ASSERT_GE(fd, 0);
 447	EXPECT_GE(chunk_size, 1);
 448	test_payload_size = chunk_size + extra_payload_size;
 449	ASSERT_GE(TLS_PAYLOAD_MAX_LEN, test_payload_size);
 450	memset(buf, 1, test_payload_size);
 451	size = write(fd, buf, test_payload_size);
 452	EXPECT_EQ(size, test_payload_size);
 453	fsync(fd);
 454
 455	while (size > 0) {
 456		ret = sendfile(self->fd, fd, &offset, chunk_size);
 457		EXPECT_GE(ret, 0);
 458		size -= ret;
 459	}
 460
 461	EXPECT_EQ(recv(self->cfd, buf, test_payload_size, MSG_WAITALL),
 462		  test_payload_size);
 463
 464	close(fd);
 465}
 466
 467TEST_F(tls, multi_chunk_sendfile)
 468{
 469	chunked_sendfile(_metadata, self, 4096, 4096);
 470	chunked_sendfile(_metadata, self, 4096, 0);
 471	chunked_sendfile(_metadata, self, 4096, 1);
 472	chunked_sendfile(_metadata, self, 4096, 2048);
 473	chunked_sendfile(_metadata, self, 8192, 2048);
 474	chunked_sendfile(_metadata, self, 4096, 8192);
 475	chunked_sendfile(_metadata, self, 8192, 4096);
 476	chunked_sendfile(_metadata, self, 12288, 1024);
 477	chunked_sendfile(_metadata, self, 12288, 2000);
 478	chunked_sendfile(_metadata, self, 15360, 100);
 479	chunked_sendfile(_metadata, self, 15360, 300);
 480	chunked_sendfile(_metadata, self, 1, 4096);
 481	chunked_sendfile(_metadata, self, 2048, 4096);
 482	chunked_sendfile(_metadata, self, 2048, 8192);
 483	chunked_sendfile(_metadata, self, 4096, 8192);
 484	chunked_sendfile(_metadata, self, 1024, 12288);
 485	chunked_sendfile(_metadata, self, 2000, 12288);
 486	chunked_sendfile(_metadata, self, 100, 15360);
 487	chunked_sendfile(_metadata, self, 300, 15360);
 488}
 489
 490TEST_F(tls, recv_max)
 491{
 492	unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
 493	char recv_mem[TLS_PAYLOAD_MAX_LEN];
 494	char buf[TLS_PAYLOAD_MAX_LEN];
 495
 496	memrnd(buf, sizeof(buf));
 497
 498	EXPECT_GE(send(self->fd, buf, send_len, 0), 0);
 499	EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
 500	EXPECT_EQ(memcmp(buf, recv_mem, send_len), 0);
 501}
 502
 503TEST_F(tls, recv_small)
 504{
 505	char const *test_str = "test_read";
 506	int send_len = 10;
 507	char buf[10];
 508
 509	send_len = strlen(test_str) + 1;
 510	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
 511	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
 512	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
 513}
 514
 515TEST_F(tls, msg_more)
 516{
 517	char const *test_str = "test_read";
 518	int send_len = 10;
 519	char buf[10 * 2];
 520
 521	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
 522	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
 523	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
 524	EXPECT_EQ(recv(self->cfd, buf, send_len * 2, MSG_WAITALL),
 525		  send_len * 2);
 526	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
 527}
 528
 529TEST_F(tls, msg_more_unsent)
 530{
 531	char const *test_str = "test_read";
 532	int send_len = 10;
 533	char buf[10];
 534
 535	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
 536	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
 537}
 538
 539TEST_F(tls, msg_eor)
 540{
 541	char const *test_str = "test_read";
 542	int send_len = 10;
 543	char buf[10];
 544
 545	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_EOR), send_len);
 546	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
 547	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
 548}
 549
 550TEST_F(tls, sendmsg_single)
 551{
 552	struct msghdr msg;
 553
 554	char const *test_str = "test_sendmsg";
 555	size_t send_len = 13;
 556	struct iovec vec;
 557	char buf[13];
 558
 559	vec.iov_base = (char *)test_str;
 560	vec.iov_len = send_len;
 561	memset(&msg, 0, sizeof(struct msghdr));
 562	msg.msg_iov = &vec;
 563	msg.msg_iovlen = 1;
 564	EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
 565	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
 566	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
 567}
 568
 569#define MAX_FRAGS	64
 570#define SEND_LEN	13
 571TEST_F(tls, sendmsg_fragmented)
 572{
 573	char const *test_str = "test_sendmsg";
 574	char buf[SEND_LEN * MAX_FRAGS];
 575	struct iovec vec[MAX_FRAGS];
 576	struct msghdr msg;
 577	int i, frags;
 578
 579	for (frags = 1; frags <= MAX_FRAGS; frags++) {
 580		for (i = 0; i < frags; i++) {
 581			vec[i].iov_base = (char *)test_str;
 582			vec[i].iov_len = SEND_LEN;
 583		}
 584
 585		memset(&msg, 0, sizeof(struct msghdr));
 586		msg.msg_iov = vec;
 587		msg.msg_iovlen = frags;
 588
 589		EXPECT_EQ(sendmsg(self->fd, &msg, 0), SEND_LEN * frags);
 590		EXPECT_EQ(recv(self->cfd, buf, SEND_LEN * frags, MSG_WAITALL),
 591			  SEND_LEN * frags);
 592
 593		for (i = 0; i < frags; i++)
 594			EXPECT_EQ(memcmp(buf + SEND_LEN * i,
 595					 test_str, SEND_LEN), 0);
 596	}
 597}
 598#undef MAX_FRAGS
 599#undef SEND_LEN
 600
 601TEST_F(tls, sendmsg_large)
 602{
 603	void *mem = malloc(16384);
 604	size_t send_len = 16384;
 605	size_t sends = 128;
 606	struct msghdr msg;
 607	size_t recvs = 0;
 608	size_t sent = 0;
 609
 610	memset(&msg, 0, sizeof(struct msghdr));
 611	while (sent++ < sends) {
 612		struct iovec vec = { (void *)mem, send_len };
 613
 614		msg.msg_iov = &vec;
 615		msg.msg_iovlen = 1;
 616		EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
 617	}
 618
 619	while (recvs++ < sends) {
 620		EXPECT_NE(recv(self->cfd, mem, send_len, 0), -1);
 621	}
 622
 623	free(mem);
 624}
 625
 626TEST_F(tls, sendmsg_multiple)
 627{
 628	char const *test_str = "test_sendmsg_multiple";
 629	struct iovec vec[5];
 630	char *test_strs[5];
 631	struct msghdr msg;
 632	int total_len = 0;
 633	int len_cmp = 0;
 634	int iov_len = 5;
 635	char *buf;
 636	int i;
 637
 638	memset(&msg, 0, sizeof(struct msghdr));
 639	for (i = 0; i < iov_len; i++) {
 640		test_strs[i] = (char *)malloc(strlen(test_str) + 1);
 641		snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
 642		vec[i].iov_base = (void *)test_strs[i];
 643		vec[i].iov_len = strlen(test_strs[i]) + 1;
 644		total_len += vec[i].iov_len;
 645	}
 646	msg.msg_iov = vec;
 647	msg.msg_iovlen = iov_len;
 648
 649	EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
 650	buf = malloc(total_len);
 651	EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
 652	for (i = 0; i < iov_len; i++) {
 653		EXPECT_EQ(memcmp(test_strs[i], buf + len_cmp,
 654				 strlen(test_strs[i])),
 655			  0);
 656		len_cmp += strlen(buf + len_cmp) + 1;
 657	}
 658	for (i = 0; i < iov_len; i++)
 659		free(test_strs[i]);
 660	free(buf);
 661}
 662
 663TEST_F(tls, sendmsg_multiple_stress)
 664{
 665	char const *test_str = "abcdefghijklmno";
 666	struct iovec vec[1024];
 667	char *test_strs[1024];
 668	int iov_len = 1024;
 669	int total_len = 0;
 670	char buf[1 << 14];
 671	struct msghdr msg;
 672	int len_cmp = 0;
 673	int i;
 674
 675	memset(&msg, 0, sizeof(struct msghdr));
 676	for (i = 0; i < iov_len; i++) {
 677		test_strs[i] = (char *)malloc(strlen(test_str) + 1);
 678		snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
 679		vec[i].iov_base = (void *)test_strs[i];
 680		vec[i].iov_len = strlen(test_strs[i]) + 1;
 681		total_len += vec[i].iov_len;
 682	}
 683	msg.msg_iov = vec;
 684	msg.msg_iovlen = iov_len;
 685
 686	EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
 687	EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
 688
 689	for (i = 0; i < iov_len; i++)
 690		len_cmp += strlen(buf + len_cmp) + 1;
 691
 692	for (i = 0; i < iov_len; i++)
 693		free(test_strs[i]);
 694}
 695
 696TEST_F(tls, splice_from_pipe)
 697{
 698	int send_len = TLS_PAYLOAD_MAX_LEN;
 699	char mem_send[TLS_PAYLOAD_MAX_LEN];
 700	char mem_recv[TLS_PAYLOAD_MAX_LEN];
 701	int p[2];
 702
 703	ASSERT_GE(pipe(p), 0);
 704	EXPECT_GE(write(p[1], mem_send, send_len), 0);
 705	EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), 0);
 706	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
 707	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
 708}
 709
 710TEST_F(tls, splice_more)
 711{
 712	unsigned int f = SPLICE_F_NONBLOCK | SPLICE_F_MORE | SPLICE_F_GIFT;
 713	int send_len = TLS_PAYLOAD_MAX_LEN;
 714	char mem_send[TLS_PAYLOAD_MAX_LEN];
 715	int i, send_pipe = 1;
 716	int p[2];
 717
 718	ASSERT_GE(pipe(p), 0);
 719	EXPECT_GE(write(p[1], mem_send, send_len), 0);
 720	for (i = 0; i < 32; i++)
 721		EXPECT_EQ(splice(p[0], NULL, self->fd, NULL, send_pipe, f), 1);
 722}
 723
 724TEST_F(tls, splice_from_pipe2)
 725{
 726	int send_len = 16000;
 727	char mem_send[16000];
 728	char mem_recv[16000];
 729	int p2[2];
 730	int p[2];
 731
 732	memrnd(mem_send, sizeof(mem_send));
 733
 734	ASSERT_GE(pipe(p), 0);
 735	ASSERT_GE(pipe(p2), 0);
 736	EXPECT_EQ(write(p[1], mem_send, 8000), 8000);
 737	EXPECT_EQ(splice(p[0], NULL, self->fd, NULL, 8000, 0), 8000);
 738	EXPECT_EQ(write(p2[1], mem_send + 8000, 8000), 8000);
 739	EXPECT_EQ(splice(p2[0], NULL, self->fd, NULL, 8000, 0), 8000);
 740	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
 741	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
 742}
 743
 744TEST_F(tls, send_and_splice)
 745{
 746	int send_len = TLS_PAYLOAD_MAX_LEN;
 747	char mem_send[TLS_PAYLOAD_MAX_LEN];
 748	char mem_recv[TLS_PAYLOAD_MAX_LEN];
 749	char const *test_str = "test_read";
 750	int send_len2 = 10;
 751	char buf[10];
 752	int p[2];
 753
 754	ASSERT_GE(pipe(p), 0);
 755	EXPECT_EQ(send(self->fd, test_str, send_len2, 0), send_len2);
 756	EXPECT_EQ(recv(self->cfd, buf, send_len2, MSG_WAITALL), send_len2);
 757	EXPECT_EQ(memcmp(test_str, buf, send_len2), 0);
 758
 759	EXPECT_GE(write(p[1], mem_send, send_len), send_len);
 760	EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), send_len);
 761
 762	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
 763	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
 764}
 765
 766TEST_F(tls, splice_to_pipe)
 767{
 768	int send_len = TLS_PAYLOAD_MAX_LEN;
 769	char mem_send[TLS_PAYLOAD_MAX_LEN];
 770	char mem_recv[TLS_PAYLOAD_MAX_LEN];
 771	int p[2];
 772
 773	memrnd(mem_send, sizeof(mem_send));
 774
 775	ASSERT_GE(pipe(p), 0);
 776	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
 777	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), send_len);
 778	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
 779	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
 780}
 781
 782TEST_F(tls, splice_cmsg_to_pipe)
 783{
 784	char *test_str = "test_read";
 785	char record_type = 100;
 786	int send_len = 10;
 787	char buf[10];
 788	int p[2];
 789
 790	if (self->notls)
 791		SKIP(return, "no TLS support");
 792
 793	ASSERT_GE(pipe(p), 0);
 794	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
 795	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
 796	EXPECT_EQ(errno, EINVAL);
 797	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
 798	EXPECT_EQ(errno, EIO);
 799	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
 800				buf, sizeof(buf), MSG_WAITALL),
 801		  send_len);
 802	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
 803}
 804
 805TEST_F(tls, splice_dec_cmsg_to_pipe)
 806{
 807	char *test_str = "test_read";
 808	char record_type = 100;
 809	int send_len = 10;
 810	char buf[10];
 811	int p[2];
 812
 813	if (self->notls)
 814		SKIP(return, "no TLS support");
 815
 816	ASSERT_GE(pipe(p), 0);
 817	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
 818	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
 819	EXPECT_EQ(errno, EIO);
 820	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
 821	EXPECT_EQ(errno, EINVAL);
 822	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
 823				buf, sizeof(buf), MSG_WAITALL),
 824		  send_len);
 825	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
 826}
 827
 828TEST_F(tls, recv_and_splice)
 829{
 830	int send_len = TLS_PAYLOAD_MAX_LEN;
 831	char mem_send[TLS_PAYLOAD_MAX_LEN];
 832	char mem_recv[TLS_PAYLOAD_MAX_LEN];
 833	int half = send_len / 2;
 834	int p[2];
 835
 836	ASSERT_GE(pipe(p), 0);
 837	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
 838	/* Recv hald of the record, splice the other half */
 839	EXPECT_EQ(recv(self->cfd, mem_recv, half, MSG_WAITALL), half);
 840	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, half, SPLICE_F_NONBLOCK),
 841		  half);
 842	EXPECT_EQ(read(p[0], &mem_recv[half], half), half);
 843	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
 844}
 845
 846TEST_F(tls, peek_and_splice)
 847{
 848	int send_len = TLS_PAYLOAD_MAX_LEN;
 849	char mem_send[TLS_PAYLOAD_MAX_LEN];
 850	char mem_recv[TLS_PAYLOAD_MAX_LEN];
 851	int chunk = TLS_PAYLOAD_MAX_LEN / 4;
 852	int n, i, p[2];
 853
 854	memrnd(mem_send, sizeof(mem_send));
 855
 856	ASSERT_GE(pipe(p), 0);
 857	for (i = 0; i < 4; i++)
 858		EXPECT_EQ(send(self->fd, &mem_send[chunk * i], chunk, 0),
 859			  chunk);
 860
 861	EXPECT_EQ(recv(self->cfd, mem_recv, chunk * 5 / 2,
 862		       MSG_WAITALL | MSG_PEEK),
 863		  chunk * 5 / 2);
 864	EXPECT_EQ(memcmp(mem_send, mem_recv, chunk * 5 / 2), 0);
 865
 866	n = 0;
 867	while (n < send_len) {
 868		i = splice(self->cfd, NULL, p[1], NULL, send_len - n, 0);
 869		EXPECT_GT(i, 0);
 870		n += i;
 871	}
 872	EXPECT_EQ(n, send_len);
 873	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
 874	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
 875}
 876
 877TEST_F(tls, recvmsg_single)
 878{
 879	char const *test_str = "test_recvmsg_single";
 880	int send_len = strlen(test_str) + 1;
 881	char buf[20];
 882	struct msghdr hdr;
 883	struct iovec vec;
 884
 885	memset(&hdr, 0, sizeof(hdr));
 886	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
 887	vec.iov_base = (char *)buf;
 888	vec.iov_len = send_len;
 889	hdr.msg_iovlen = 1;
 890	hdr.msg_iov = &vec;
 891	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
 892	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
 893}
 894
 895TEST_F(tls, recvmsg_single_max)
 896{
 897	int send_len = TLS_PAYLOAD_MAX_LEN;
 898	char send_mem[TLS_PAYLOAD_MAX_LEN];
 899	char recv_mem[TLS_PAYLOAD_MAX_LEN];
 900	struct iovec vec;
 901	struct msghdr hdr;
 902
 903	memrnd(send_mem, sizeof(send_mem));
 904
 905	EXPECT_EQ(send(self->fd, send_mem, send_len, 0), send_len);
 906	vec.iov_base = (char *)recv_mem;
 907	vec.iov_len = TLS_PAYLOAD_MAX_LEN;
 908
 909	hdr.msg_iovlen = 1;
 910	hdr.msg_iov = &vec;
 911	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
 912	EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
 913}
 914
 915TEST_F(tls, recvmsg_multiple)
 916{
 917	unsigned int msg_iovlen = 1024;
 918	struct iovec vec[1024];
 919	char *iov_base[1024];
 920	unsigned int iov_len = 16;
 921	int send_len = 1 << 14;
 922	char buf[1 << 14];
 923	struct msghdr hdr;
 924	int i;
 925
 926	memrnd(buf, sizeof(buf));
 927
 928	EXPECT_EQ(send(self->fd, buf, send_len, 0), send_len);
 929	for (i = 0; i < msg_iovlen; i++) {
 930		iov_base[i] = (char *)malloc(iov_len);
 931		vec[i].iov_base = iov_base[i];
 932		vec[i].iov_len = iov_len;
 933	}
 934
 935	hdr.msg_iovlen = msg_iovlen;
 936	hdr.msg_iov = vec;
 937	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
 938
 939	for (i = 0; i < msg_iovlen; i++)
 940		free(iov_base[i]);
 941}
 942
 943TEST_F(tls, single_send_multiple_recv)
 944{
 945	unsigned int total_len = TLS_PAYLOAD_MAX_LEN * 2;
 946	unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
 947	char send_mem[TLS_PAYLOAD_MAX_LEN * 2];
 948	char recv_mem[TLS_PAYLOAD_MAX_LEN * 2];
 949
 950	memrnd(send_mem, sizeof(send_mem));
 951
 952	EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
 953	memset(recv_mem, 0, total_len);
 954
 955	EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
 956	EXPECT_NE(recv(self->cfd, recv_mem + send_len, send_len, 0), -1);
 957	EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
 958}
 959
 960TEST_F(tls, multiple_send_single_recv)
 961{
 962	unsigned int total_len = 2 * 10;
 963	unsigned int send_len = 10;
 964	char recv_mem[2 * 10];
 965	char send_mem[10];
 966
 967	memrnd(send_mem, sizeof(send_mem));
 968
 969	EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
 970	EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
 971	memset(recv_mem, 0, total_len);
 972	EXPECT_EQ(recv(self->cfd, recv_mem, total_len, MSG_WAITALL), total_len);
 973
 974	EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
 975	EXPECT_EQ(memcmp(send_mem, recv_mem + send_len, send_len), 0);
 976}
 977
 978TEST_F(tls, single_send_multiple_recv_non_align)
 979{
 980	const unsigned int total_len = 15;
 981	const unsigned int recv_len = 10;
 982	char recv_mem[recv_len * 2];
 983	char send_mem[total_len];
 984
 985	memrnd(send_mem, sizeof(send_mem));
 986
 987	EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
 988	memset(recv_mem, 0, total_len);
 989
 990	EXPECT_EQ(recv(self->cfd, recv_mem, recv_len, 0), recv_len);
 991	EXPECT_EQ(recv(self->cfd, recv_mem + recv_len, recv_len, 0), 5);
 992	EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
 993}
 994
 995TEST_F(tls, recv_partial)
 996{
 997	char const *test_str = "test_read_partial";
 998	char const *test_str_first = "test_read";
 999	char const *test_str_second = "_partial";
1000	int send_len = strlen(test_str) + 1;
1001	char recv_mem[18];
1002
1003	memset(recv_mem, 0, sizeof(recv_mem));
1004	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1005	EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_first),
1006		       MSG_WAITALL), strlen(test_str_first));
1007	EXPECT_EQ(memcmp(test_str_first, recv_mem, strlen(test_str_first)), 0);
1008	memset(recv_mem, 0, sizeof(recv_mem));
1009	EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_second),
1010		       MSG_WAITALL), strlen(test_str_second));
1011	EXPECT_EQ(memcmp(test_str_second, recv_mem, strlen(test_str_second)),
1012		  0);
1013}
1014
1015TEST_F(tls, recv_nonblock)
1016{
1017	char buf[4096];
1018	bool err;
1019
1020	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
1021	err = (errno == EAGAIN || errno == EWOULDBLOCK);
1022	EXPECT_EQ(err, true);
1023}
1024
1025TEST_F(tls, recv_peek)
1026{
1027	char const *test_str = "test_read_peek";
1028	int send_len = strlen(test_str) + 1;
1029	char buf[15];
1030
1031	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1032	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_PEEK), send_len);
1033	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1034	memset(buf, 0, sizeof(buf));
1035	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1036	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1037}
1038
1039TEST_F(tls, recv_peek_multiple)
1040{
1041	char const *test_str = "test_read_peek";
1042	int send_len = strlen(test_str) + 1;
1043	unsigned int num_peeks = 100;
1044	char buf[15];
1045	int i;
1046
1047	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1048	for (i = 0; i < num_peeks; i++) {
1049		EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
1050		EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1051		memset(buf, 0, sizeof(buf));
1052	}
1053	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1054	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1055}
1056
1057TEST_F(tls, recv_peek_multiple_records)
1058{
1059	char const *test_str = "test_read_peek_mult_recs";
1060	char const *test_str_first = "test_read_peek";
1061	char const *test_str_second = "_mult_recs";
1062	int len;
1063	char buf[64];
1064
1065	len = strlen(test_str_first);
1066	EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
1067
1068	len = strlen(test_str_second) + 1;
1069	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1070
1071	len = strlen(test_str_first);
1072	memset(buf, 0, len);
1073	EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
1074
1075	/* MSG_PEEK can only peek into the current record. */
1076	len = strlen(test_str_first);
1077	EXPECT_EQ(memcmp(test_str_first, buf, len), 0);
1078
1079	len = strlen(test_str) + 1;
1080	memset(buf, 0, len);
1081	EXPECT_EQ(recv(self->cfd, buf, len, MSG_WAITALL), len);
1082
1083	/* Non-MSG_PEEK will advance strparser (and therefore record)
1084	 * however.
1085	 */
1086	len = strlen(test_str) + 1;
1087	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1088
1089	/* MSG_MORE will hold current record open, so later MSG_PEEK
1090	 * will see everything.
1091	 */
1092	len = strlen(test_str_first);
1093	EXPECT_EQ(send(self->fd, test_str_first, len, MSG_MORE), len);
1094
1095	len = strlen(test_str_second) + 1;
1096	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1097
1098	len = strlen(test_str) + 1;
1099	memset(buf, 0, len);
1100	EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
1101
1102	len = strlen(test_str) + 1;
1103	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1104}
1105
1106TEST_F(tls, recv_peek_large_buf_mult_recs)
1107{
1108	char const *test_str = "test_read_peek_mult_recs";
1109	char const *test_str_first = "test_read_peek";
1110	char const *test_str_second = "_mult_recs";
1111	int len;
1112	char buf[64];
1113
1114	len = strlen(test_str_first);
1115	EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
1116
1117	len = strlen(test_str_second) + 1;
1118	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1119
1120	len = strlen(test_str) + 1;
1121	memset(buf, 0, len);
1122	EXPECT_NE((len = recv(self->cfd, buf, len,
1123			      MSG_PEEK | MSG_WAITALL)), -1);
1124	len = strlen(test_str) + 1;
1125	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1126}
1127
1128TEST_F(tls, recv_lowat)
1129{
1130	char send_mem[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
1131	char recv_mem[20];
1132	int lowat = 8;
1133
1134	EXPECT_EQ(send(self->fd, send_mem, 10, 0), 10);
1135	EXPECT_EQ(send(self->fd, send_mem, 5, 0), 5);
1136
1137	memset(recv_mem, 0, 20);
1138	EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVLOWAT,
1139			     &lowat, sizeof(lowat)), 0);
1140	EXPECT_EQ(recv(self->cfd, recv_mem, 1, MSG_WAITALL), 1);
1141	EXPECT_EQ(recv(self->cfd, recv_mem + 1, 6, MSG_WAITALL), 6);
1142	EXPECT_EQ(recv(self->cfd, recv_mem + 7, 10, 0), 8);
1143
1144	EXPECT_EQ(memcmp(send_mem, recv_mem, 10), 0);
1145	EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
1146}
1147
1148TEST_F(tls, bidir)
1149{
1150	char const *test_str = "test_read";
1151	int send_len = 10;
1152	char buf[10];
1153	int ret;
1154
1155	if (!self->notls) {
1156		struct tls_crypto_info_keys tls12;
1157
1158		tls_crypto_info_init(variant->tls_version, variant->cipher_type,
1159				     &tls12);
1160
1161		ret = setsockopt(self->fd, SOL_TLS, TLS_RX, &tls12,
1162				 tls12.len);
1163		ASSERT_EQ(ret, 0);
1164
1165		ret = setsockopt(self->cfd, SOL_TLS, TLS_TX, &tls12,
1166				 tls12.len);
1167		ASSERT_EQ(ret, 0);
1168	}
1169
1170	ASSERT_EQ(strlen(test_str) + 1, send_len);
1171
1172	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1173	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1174	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1175
1176	memset(buf, 0, sizeof(buf));
1177
1178	EXPECT_EQ(send(self->cfd, test_str, send_len, 0), send_len);
1179	EXPECT_NE(recv(self->fd, buf, send_len, 0), -1);
1180	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1181};
1182
1183TEST_F(tls, pollin)
1184{
1185	char const *test_str = "test_poll";
1186	struct pollfd fd = { 0, 0, 0 };
1187	char buf[10];
1188	int send_len = 10;
1189
1190	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1191	fd.fd = self->cfd;
1192	fd.events = POLLIN;
1193
1194	EXPECT_EQ(poll(&fd, 1, 20), 1);
1195	EXPECT_EQ(fd.revents & POLLIN, 1);
1196	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
1197	/* Test timing out */
1198	EXPECT_EQ(poll(&fd, 1, 20), 0);
1199}
1200
1201TEST_F(tls, poll_wait)
1202{
1203	char const *test_str = "test_poll_wait";
1204	int send_len = strlen(test_str) + 1;
1205	struct pollfd fd = { 0, 0, 0 };
1206	char recv_mem[15];
1207
1208	fd.fd = self->cfd;
1209	fd.events = POLLIN;
1210	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1211	/* Set timeout to inf. secs */
1212	EXPECT_EQ(poll(&fd, 1, -1), 1);
1213	EXPECT_EQ(fd.revents & POLLIN, 1);
1214	EXPECT_EQ(recv(self->cfd, recv_mem, send_len, MSG_WAITALL), send_len);
1215}
1216
1217TEST_F(tls, poll_wait_split)
1218{
1219	struct pollfd fd = { 0, 0, 0 };
1220	char send_mem[20] = {};
1221	char recv_mem[15];
1222
1223	fd.fd = self->cfd;
1224	fd.events = POLLIN;
1225	/* Send 20 bytes */
1226	EXPECT_EQ(send(self->fd, send_mem, sizeof(send_mem), 0),
1227		  sizeof(send_mem));
1228	/* Poll with inf. timeout */
1229	EXPECT_EQ(poll(&fd, 1, -1), 1);
1230	EXPECT_EQ(fd.revents & POLLIN, 1);
1231	EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), MSG_WAITALL),
1232		  sizeof(recv_mem));
1233
1234	/* Now the remaining 5 bytes of record data are in TLS ULP */
1235	fd.fd = self->cfd;
1236	fd.events = POLLIN;
1237	EXPECT_EQ(poll(&fd, 1, -1), 1);
1238	EXPECT_EQ(fd.revents & POLLIN, 1);
1239	EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0),
1240		  sizeof(send_mem) - sizeof(recv_mem));
1241}
1242
1243TEST_F(tls, blocking)
1244{
1245	size_t data = 100000;
1246	int res = fork();
1247
1248	EXPECT_NE(res, -1);
1249
1250	if (res) {
1251		/* parent */
1252		size_t left = data;
1253		char buf[16384];
1254		int status;
1255		int pid2;
1256
1257		while (left) {
1258			int res = send(self->fd, buf,
1259				       left > 16384 ? 16384 : left, 0);
1260
1261			EXPECT_GE(res, 0);
1262			left -= res;
1263		}
1264
1265		pid2 = wait(&status);
1266		EXPECT_EQ(status, 0);
1267		EXPECT_EQ(res, pid2);
1268	} else {
1269		/* child */
1270		size_t left = data;
1271		char buf[16384];
1272
1273		while (left) {
1274			int res = recv(self->cfd, buf,
1275				       left > 16384 ? 16384 : left, 0);
1276
1277			EXPECT_GE(res, 0);
1278			left -= res;
1279		}
1280	}
1281}
1282
1283TEST_F(tls, nonblocking)
1284{
1285	size_t data = 100000;
1286	int sendbuf = 100;
1287	int flags;
1288	int res;
1289
1290	flags = fcntl(self->fd, F_GETFL, 0);
1291	fcntl(self->fd, F_SETFL, flags | O_NONBLOCK);
1292	fcntl(self->cfd, F_SETFL, flags | O_NONBLOCK);
1293
1294	/* Ensure nonblocking behavior by imposing a small send
1295	 * buffer.
1296	 */
1297	EXPECT_EQ(setsockopt(self->fd, SOL_SOCKET, SO_SNDBUF,
1298			     &sendbuf, sizeof(sendbuf)), 0);
1299
1300	res = fork();
1301	EXPECT_NE(res, -1);
1302
1303	if (res) {
1304		/* parent */
1305		bool eagain = false;
1306		size_t left = data;
1307		char buf[16384];
1308		int status;
1309		int pid2;
1310
1311		while (left) {
1312			int res = send(self->fd, buf,
1313				       left > 16384 ? 16384 : left, 0);
1314
1315			if (res == -1 && errno == EAGAIN) {
1316				eagain = true;
1317				usleep(10000);
1318				continue;
1319			}
1320			EXPECT_GE(res, 0);
1321			left -= res;
1322		}
1323
1324		EXPECT_TRUE(eagain);
1325		pid2 = wait(&status);
1326
1327		EXPECT_EQ(status, 0);
1328		EXPECT_EQ(res, pid2);
1329	} else {
1330		/* child */
1331		bool eagain = false;
1332		size_t left = data;
1333		char buf[16384];
1334
1335		while (left) {
1336			int res = recv(self->cfd, buf,
1337				       left > 16384 ? 16384 : left, 0);
1338
1339			if (res == -1 && errno == EAGAIN) {
1340				eagain = true;
1341				usleep(10000);
1342				continue;
1343			}
1344			EXPECT_GE(res, 0);
1345			left -= res;
1346		}
1347		EXPECT_TRUE(eagain);
1348	}
1349}
1350
1351static void
1352test_mutliproc(struct __test_metadata *_metadata, struct _test_data_tls *self,
1353	       bool sendpg, unsigned int n_readers, unsigned int n_writers)
1354{
1355	const unsigned int n_children = n_readers + n_writers;
1356	const size_t data = 6 * 1000 * 1000;
1357	const size_t file_sz = data / 100;
1358	size_t read_bias, write_bias;
1359	int i, fd, child_id;
1360	char buf[file_sz];
1361	pid_t pid;
1362
1363	/* Only allow multiples for simplicity */
1364	ASSERT_EQ(!(n_readers % n_writers) || !(n_writers % n_readers), true);
1365	read_bias = n_writers / n_readers ?: 1;
1366	write_bias = n_readers / n_writers ?: 1;
1367
1368	/* prep a file to send */
1369	fd = open("/tmp/", O_TMPFILE | O_RDWR, 0600);
1370	ASSERT_GE(fd, 0);
1371
1372	memset(buf, 0xac, file_sz);
1373	ASSERT_EQ(write(fd, buf, file_sz), file_sz);
1374
1375	/* spawn children */
1376	for (child_id = 0; child_id < n_children; child_id++) {
1377		pid = fork();
1378		ASSERT_NE(pid, -1);
1379		if (!pid)
1380			break;
1381	}
1382
1383	/* parent waits for all children */
1384	if (pid) {
1385		for (i = 0; i < n_children; i++) {
1386			int status;
1387
1388			wait(&status);
1389			EXPECT_EQ(status, 0);
1390		}
1391
1392		return;
1393	}
1394
1395	/* Split threads for reading and writing */
1396	if (child_id < n_readers) {
1397		size_t left = data * read_bias;
1398		char rb[8001];
1399
1400		while (left) {
1401			int res;
1402
1403			res = recv(self->cfd, rb,
1404				   left > sizeof(rb) ? sizeof(rb) : left, 0);
1405
1406			EXPECT_GE(res, 0);
1407			left -= res;
1408		}
1409	} else {
1410		size_t left = data * write_bias;
1411
1412		while (left) {
1413			int res;
1414
1415			ASSERT_EQ(lseek(fd, 0, SEEK_SET), 0);
1416			if (sendpg)
1417				res = sendfile(self->fd, fd, NULL,
1418					       left > file_sz ? file_sz : left);
1419			else
1420				res = send(self->fd, buf,
1421					   left > file_sz ? file_sz : left, 0);
1422
1423			EXPECT_GE(res, 0);
1424			left -= res;
1425		}
1426	}
1427}
1428
1429TEST_F(tls, mutliproc_even)
1430{
1431	test_mutliproc(_metadata, self, false, 6, 6);
1432}
1433
1434TEST_F(tls, mutliproc_readers)
1435{
1436	test_mutliproc(_metadata, self, false, 4, 12);
1437}
1438
1439TEST_F(tls, mutliproc_writers)
1440{
1441	test_mutliproc(_metadata, self, false, 10, 2);
1442}
1443
1444TEST_F(tls, mutliproc_sendpage_even)
1445{
1446	test_mutliproc(_metadata, self, true, 6, 6);
1447}
1448
1449TEST_F(tls, mutliproc_sendpage_readers)
1450{
1451	test_mutliproc(_metadata, self, true, 4, 12);
1452}
1453
1454TEST_F(tls, mutliproc_sendpage_writers)
1455{
1456	test_mutliproc(_metadata, self, true, 10, 2);
1457}
1458
1459TEST_F(tls, control_msg)
1460{
1461	char *test_str = "test_read";
1462	char record_type = 100;
1463	int send_len = 10;
1464	char buf[10];
1465
1466	if (self->notls)
1467		SKIP(return, "no TLS support");
1468
1469	EXPECT_EQ(tls_send_cmsg(self->fd, record_type, test_str, send_len, 0),
1470		  send_len);
1471	/* Should fail because we didn't provide a control message */
1472	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1473
1474	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1475				buf, sizeof(buf), MSG_WAITALL | MSG_PEEK),
1476		  send_len);
1477	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1478
1479	/* Recv the message again without MSG_PEEK */
1480	memset(buf, 0, sizeof(buf));
1481
1482	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1483				buf, sizeof(buf), MSG_WAITALL),
1484		  send_len);
1485	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1486}
1487
1488TEST_F(tls, control_msg_nomerge)
1489{
1490	char *rec1 = "1111";
1491	char *rec2 = "2222";
1492	int send_len = 5;
1493	char buf[15];
1494
1495	if (self->notls)
1496		SKIP(return, "no TLS support");
1497
1498	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec1, send_len, 0), send_len);
1499	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec2, send_len, 0), send_len);
1500
1501	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), MSG_PEEK), send_len);
1502	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1503
1504	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), MSG_PEEK), send_len);
1505	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1506
1507	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), 0), send_len);
1508	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1509
1510	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), 0), send_len);
1511	EXPECT_EQ(memcmp(buf, rec2, send_len), 0);
1512}
1513
1514TEST_F(tls, data_control_data)
1515{
1516	char *rec1 = "1111";
1517	char *rec2 = "2222";
1518	char *rec3 = "3333";
1519	int send_len = 5;
1520	char buf[15];
1521
1522	if (self->notls)
1523		SKIP(return, "no TLS support");
1524
1525	EXPECT_EQ(send(self->fd, rec1, send_len, 0), send_len);
1526	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec2, send_len, 0), send_len);
1527	EXPECT_EQ(send(self->fd, rec3, send_len, 0), send_len);
1528
1529	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1530	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1531}
1532
1533TEST_F(tls, shutdown)
1534{
1535	char const *test_str = "test_read";
1536	int send_len = 10;
1537	char buf[10];
1538
1539	ASSERT_EQ(strlen(test_str) + 1, send_len);
1540
1541	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1542	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1543	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1544
1545	shutdown(self->fd, SHUT_RDWR);
1546	shutdown(self->cfd, SHUT_RDWR);
1547}
1548
1549TEST_F(tls, shutdown_unsent)
1550{
1551	char const *test_str = "test_read";
1552	int send_len = 10;
1553
1554	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
1555
1556	shutdown(self->fd, SHUT_RDWR);
1557	shutdown(self->cfd, SHUT_RDWR);
1558}
1559
1560TEST_F(tls, shutdown_reuse)
1561{
1562	struct sockaddr_in addr;
1563	int ret;
1564
1565	shutdown(self->fd, SHUT_RDWR);
1566	shutdown(self->cfd, SHUT_RDWR);
1567	close(self->cfd);
1568
1569	addr.sin_family = AF_INET;
1570	addr.sin_addr.s_addr = htonl(INADDR_ANY);
1571	addr.sin_port = 0;
1572
1573	ret = bind(self->fd, &addr, sizeof(addr));
1574	EXPECT_EQ(ret, 0);
1575	ret = listen(self->fd, 10);
1576	EXPECT_EQ(ret, -1);
1577	EXPECT_EQ(errno, EINVAL);
1578
1579	ret = connect(self->fd, &addr, sizeof(addr));
1580	EXPECT_EQ(ret, -1);
1581	EXPECT_EQ(errno, EISCONN);
1582}
1583
1584TEST_F(tls, getsockopt)
1585{
1586	struct tls_crypto_info_keys expect, get;
1587	socklen_t len;
1588
1589	/* get only the version/cipher */
1590	len = sizeof(struct tls_crypto_info);
1591	memrnd(&get, sizeof(get));
1592	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), 0);
1593	EXPECT_EQ(len, sizeof(struct tls_crypto_info));
1594	EXPECT_EQ(get.crypto_info.version, variant->tls_version);
1595	EXPECT_EQ(get.crypto_info.cipher_type, variant->cipher_type);
1596
1597	/* get the full crypto_info */
1598	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &expect);
1599	len = expect.len;
1600	memrnd(&get, sizeof(get));
1601	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), 0);
1602	EXPECT_EQ(len, expect.len);
1603	EXPECT_EQ(get.crypto_info.version, variant->tls_version);
1604	EXPECT_EQ(get.crypto_info.cipher_type, variant->cipher_type);
1605	EXPECT_EQ(memcmp(&get, &expect, expect.len), 0);
1606
1607	/* short get should fail */
1608	len = sizeof(struct tls_crypto_info) - 1;
1609	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), -1);
1610	EXPECT_EQ(errno, EINVAL);
1611
1612	/* partial get of the cipher data should fail */
1613	len = expect.len - 1;
1614	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), -1);
1615	EXPECT_EQ(errno, EINVAL);
1616}
1617
1618FIXTURE(tls_err)
1619{
1620	int fd, cfd;
1621	int fd2, cfd2;
1622	bool notls;
1623};
1624
1625FIXTURE_VARIANT(tls_err)
1626{
1627	uint16_t tls_version;
1628};
1629
1630FIXTURE_VARIANT_ADD(tls_err, 12_aes_gcm)
1631{
1632	.tls_version = TLS_1_2_VERSION,
1633};
1634
1635FIXTURE_VARIANT_ADD(tls_err, 13_aes_gcm)
1636{
1637	.tls_version = TLS_1_3_VERSION,
1638};
1639
1640FIXTURE_SETUP(tls_err)
1641{
1642	struct tls_crypto_info_keys tls12;
1643	int ret;
1644
1645	tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_128,
1646			     &tls12);
1647
1648	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
1649	ulp_sock_pair(_metadata, &self->fd2, &self->cfd2, &self->notls);
1650	if (self->notls)
1651		return;
1652
1653	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
1654	ASSERT_EQ(ret, 0);
1655
1656	ret = setsockopt(self->cfd2, SOL_TLS, TLS_RX, &tls12, tls12.len);
1657	ASSERT_EQ(ret, 0);
1658}
1659
1660FIXTURE_TEARDOWN(tls_err)
1661{
1662	close(self->fd);
1663	close(self->cfd);
1664	close(self->fd2);
1665	close(self->cfd2);
1666}
1667
1668TEST_F(tls_err, bad_rec)
1669{
1670	char buf[64];
1671
1672	if (self->notls)
1673		SKIP(return, "no TLS support");
1674
1675	memset(buf, 0x55, sizeof(buf));
1676	EXPECT_EQ(send(self->fd2, buf, sizeof(buf), 0), sizeof(buf));
1677	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1678	EXPECT_EQ(errno, EMSGSIZE);
1679	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), MSG_DONTWAIT), -1);
1680	EXPECT_EQ(errno, EAGAIN);
1681}
1682
1683TEST_F(tls_err, bad_auth)
1684{
1685	char buf[128];
1686	int n;
1687
1688	if (self->notls)
1689		SKIP(return, "no TLS support");
1690
1691	memrnd(buf, sizeof(buf) / 2);
1692	EXPECT_EQ(send(self->fd, buf, sizeof(buf) / 2, 0), sizeof(buf) / 2);
1693	n = recv(self->cfd, buf, sizeof(buf), 0);
1694	EXPECT_GT(n, sizeof(buf) / 2);
1695
1696	buf[n - 1]++;
1697
1698	EXPECT_EQ(send(self->fd2, buf, n, 0), n);
1699	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1700	EXPECT_EQ(errno, EBADMSG);
1701	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1702	EXPECT_EQ(errno, EBADMSG);
1703}
1704
1705TEST_F(tls_err, bad_in_large_read)
1706{
1707	char txt[3][64];
1708	char cip[3][128];
1709	char buf[3 * 128];
1710	int i, n;
1711
1712	if (self->notls)
1713		SKIP(return, "no TLS support");
1714
1715	/* Put 3 records in the sockets */
1716	for (i = 0; i < 3; i++) {
1717		memrnd(txt[i], sizeof(txt[i]));
1718		EXPECT_EQ(send(self->fd, txt[i], sizeof(txt[i]), 0),
1719			  sizeof(txt[i]));
1720		n = recv(self->cfd, cip[i], sizeof(cip[i]), 0);
1721		EXPECT_GT(n, sizeof(txt[i]));
1722		/* Break the third message */
1723		if (i == 2)
1724			cip[2][n - 1]++;
1725		EXPECT_EQ(send(self->fd2, cip[i], n, 0), n);
1726	}
1727
1728	/* We should be able to receive the first two messages */
1729	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt[0]) * 2);
1730	EXPECT_EQ(memcmp(buf, txt[0], sizeof(txt[0])), 0);
1731	EXPECT_EQ(memcmp(buf + sizeof(txt[0]), txt[1], sizeof(txt[1])), 0);
1732	/* Third mesasge is bad */
1733	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1734	EXPECT_EQ(errno, EBADMSG);
1735	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1736	EXPECT_EQ(errno, EBADMSG);
1737}
1738
1739TEST_F(tls_err, bad_cmsg)
1740{
1741	char *test_str = "test_read";
1742	int send_len = 10;
1743	char cip[128];
1744	char buf[128];
1745	char txt[64];
1746	int n;
1747
1748	if (self->notls)
1749		SKIP(return, "no TLS support");
1750
1751	/* Queue up one data record */
1752	memrnd(txt, sizeof(txt));
1753	EXPECT_EQ(send(self->fd, txt, sizeof(txt), 0), sizeof(txt));
1754	n = recv(self->cfd, cip, sizeof(cip), 0);
1755	EXPECT_GT(n, sizeof(txt));
1756	EXPECT_EQ(send(self->fd2, cip, n, 0), n);
1757
1758	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
1759	n = recv(self->cfd, cip, sizeof(cip), 0);
1760	cip[n - 1]++; /* Break it */
1761	EXPECT_GT(n, send_len);
1762	EXPECT_EQ(send(self->fd2, cip, n, 0), n);
1763
1764	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt));
1765	EXPECT_EQ(memcmp(buf, txt, sizeof(txt)), 0);
1766	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1767	EXPECT_EQ(errno, EBADMSG);
1768	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1769	EXPECT_EQ(errno, EBADMSG);
1770}
1771
1772TEST_F(tls_err, timeo)
1773{
1774	struct timeval tv = { .tv_usec = 10000, };
1775	char buf[128];
1776	int ret;
1777
1778	if (self->notls)
1779		SKIP(return, "no TLS support");
1780
1781	ret = setsockopt(self->cfd2, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
1782	ASSERT_EQ(ret, 0);
1783
1784	ret = fork();
1785	ASSERT_GE(ret, 0);
1786
1787	if (ret) {
1788		usleep(1000); /* Give child a head start */
1789
1790		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1791		EXPECT_EQ(errno, EAGAIN);
1792
1793		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1794		EXPECT_EQ(errno, EAGAIN);
1795
1796		wait(&ret);
1797	} else {
1798		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1799		EXPECT_EQ(errno, EAGAIN);
1800		exit(0);
1801	}
1802}
1803
1804TEST_F(tls_err, poll_partial_rec)
1805{
1806	struct pollfd pfd = { };
1807	ssize_t rec_len;
1808	char rec[256];
1809	char buf[128];
1810
1811	if (self->notls)
1812		SKIP(return, "no TLS support");
1813
1814	pfd.fd = self->cfd2;
1815	pfd.events = POLLIN;
1816	EXPECT_EQ(poll(&pfd, 1, 1), 0);
1817
1818	memrnd(buf, sizeof(buf));
1819	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
1820	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
1821	EXPECT_GT(rec_len, sizeof(buf));
1822
1823	/* Write 100B, not the full record ... */
1824	EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
1825	/* ... no full record should mean no POLLIN */
1826	pfd.fd = self->cfd2;
1827	pfd.events = POLLIN;
1828	EXPECT_EQ(poll(&pfd, 1, 1), 0);
1829	/* Now write the rest, and it should all pop out of the other end. */
1830	EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0), rec_len - 100);
1831	pfd.fd = self->cfd2;
1832	pfd.events = POLLIN;
1833	EXPECT_EQ(poll(&pfd, 1, 1), 1);
1834	EXPECT_EQ(recv(self->cfd2, rec, sizeof(rec), 0), sizeof(buf));
1835	EXPECT_EQ(memcmp(buf, rec, sizeof(buf)), 0);
1836}
1837
1838TEST_F(tls_err, epoll_partial_rec)
1839{
1840	struct epoll_event ev, events[10];
1841	ssize_t rec_len;
1842	char rec[256];
1843	char buf[128];
1844	int epollfd;
1845
1846	if (self->notls)
1847		SKIP(return, "no TLS support");
1848
1849	epollfd = epoll_create1(0);
1850	ASSERT_GE(epollfd, 0);
1851
1852	memset(&ev, 0, sizeof(ev));
1853	ev.events = EPOLLIN;
1854	ev.data.fd = self->cfd2;
1855	ASSERT_GE(epoll_ctl(epollfd, EPOLL_CTL_ADD, self->cfd2, &ev), 0);
1856
1857	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 0);
1858
1859	memrnd(buf, sizeof(buf));
1860	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
1861	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
1862	EXPECT_GT(rec_len, sizeof(buf));
1863
1864	/* Write 100B, not the full record ... */
1865	EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
1866	/* ... no full record should mean no POLLIN */
1867	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 0);
1868	/* Now write the rest, and it should all pop out of the other end. */
1869	EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0), rec_len - 100);
1870	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 1);
1871	EXPECT_EQ(recv(self->cfd2, rec, sizeof(rec), 0), sizeof(buf));
1872	EXPECT_EQ(memcmp(buf, rec, sizeof(buf)), 0);
1873
1874	close(epollfd);
1875}
1876
1877TEST_F(tls_err, poll_partial_rec_async)
1878{
1879	struct pollfd pfd = { };
1880	ssize_t rec_len;
1881	char rec[256];
1882	char buf[128];
1883	char token;
1884	int p[2];
1885	int ret;
1886
1887	if (self->notls)
1888		SKIP(return, "no TLS support");
1889
1890	ASSERT_GE(pipe(p), 0);
1891
1892	memrnd(buf, sizeof(buf));
1893	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
1894	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
1895	EXPECT_GT(rec_len, sizeof(buf));
1896
1897	ret = fork();
1898	ASSERT_GE(ret, 0);
1899
1900	if (ret) {
1901		int status, pid2;
1902
1903		close(p[1]);
1904		usleep(1000); /* Give child a head start */
1905
1906		EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
1907
1908		EXPECT_EQ(read(p[0], &token, 1), 1); /* Barrier #1 */
1909
1910		EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0),
1911			  rec_len - 100);
1912
1913		pid2 = wait(&status);
1914		EXPECT_EQ(pid2, ret);
1915		EXPECT_EQ(status, 0);
1916	} else {
1917		close(p[0]);
1918
1919		/* Child should sleep in poll(), never get a wake */
1920		pfd.fd = self->cfd2;
1921		pfd.events = POLLIN;
1922		EXPECT_EQ(poll(&pfd, 1, 20), 0);
1923
1924		EXPECT_EQ(write(p[1], &token, 1), 1); /* Barrier #1 */
1925
1926		pfd.fd = self->cfd2;
1927		pfd.events = POLLIN;
1928		EXPECT_EQ(poll(&pfd, 1, 20), 1);
1929
1930		exit(!_metadata->passed);
1931	}
1932}
1933
1934TEST(non_established) {
1935	struct tls12_crypto_info_aes_gcm_256 tls12;
1936	struct sockaddr_in addr;
1937	int sfd, ret, fd;
1938	socklen_t len;
1939
1940	len = sizeof(addr);
1941
1942	memset(&tls12, 0, sizeof(tls12));
1943	tls12.info.version = TLS_1_2_VERSION;
1944	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
1945
1946	addr.sin_family = AF_INET;
1947	addr.sin_addr.s_addr = htonl(INADDR_ANY);
1948	addr.sin_port = 0;
1949
1950	fd = socket(AF_INET, SOCK_STREAM, 0);
1951	sfd = socket(AF_INET, SOCK_STREAM, 0);
1952
1953	ret = bind(sfd, &addr, sizeof(addr));
1954	ASSERT_EQ(ret, 0);
1955	ret = listen(sfd, 10);
1956	ASSERT_EQ(ret, 0);
1957
1958	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1959	EXPECT_EQ(ret, -1);
1960	/* TLS ULP not supported */
1961	if (errno == ENOENT)
1962		return;
1963	EXPECT_EQ(errno, ENOTCONN);
1964
1965	ret = setsockopt(sfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1966	EXPECT_EQ(ret, -1);
1967	EXPECT_EQ(errno, ENOTCONN);
1968
1969	ret = getsockname(sfd, &addr, &len);
1970	ASSERT_EQ(ret, 0);
1971
1972	ret = connect(fd, &addr, sizeof(addr));
1973	ASSERT_EQ(ret, 0);
1974
1975	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1976	ASSERT_EQ(ret, 0);
1977
1978	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1979	EXPECT_EQ(ret, -1);
1980	EXPECT_EQ(errno, EEXIST);
1981
1982	close(fd);
1983	close(sfd);
1984}
1985
1986TEST(keysizes) {
1987	struct tls12_crypto_info_aes_gcm_256 tls12;
1988	int ret, fd, cfd;
1989	bool notls;
1990
1991	memset(&tls12, 0, sizeof(tls12));
1992	tls12.info.version = TLS_1_2_VERSION;
1993	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
1994
1995	ulp_sock_pair(_metadata, &fd, &cfd, &notls);
1996
1997	if (!notls) {
1998		ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
1999				 sizeof(tls12));
2000		EXPECT_EQ(ret, 0);
2001
2002		ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
2003				 sizeof(tls12));
2004		EXPECT_EQ(ret, 0);
2005	}
2006
2007	close(fd);
2008	close(cfd);
2009}
2010
2011TEST(no_pad) {
2012	struct tls12_crypto_info_aes_gcm_256 tls12;
2013	int ret, fd, cfd, val;
2014	socklen_t len;
2015	bool notls;
2016
2017	memset(&tls12, 0, sizeof(tls12));
2018	tls12.info.version = TLS_1_3_VERSION;
2019	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2020
2021	ulp_sock_pair(_metadata, &fd, &cfd, &notls);
2022
2023	if (notls)
2024		exit(KSFT_SKIP);
2025
2026	ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, sizeof(tls12));
2027	EXPECT_EQ(ret, 0);
2028
2029	ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12, sizeof(tls12));
2030	EXPECT_EQ(ret, 0);
2031
2032	val = 1;
2033	ret = setsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2034			 (void *)&val, sizeof(val));
2035	EXPECT_EQ(ret, 0);
2036
2037	len = sizeof(val);
2038	val = 2;
2039	ret = getsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2040			 (void *)&val, &len);
2041	EXPECT_EQ(ret, 0);
2042	EXPECT_EQ(val, 1);
2043	EXPECT_EQ(len, 4);
2044
2045	val = 0;
2046	ret = setsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2047			 (void *)&val, sizeof(val));
2048	EXPECT_EQ(ret, 0);
2049
2050	len = sizeof(val);
2051	val = 2;
2052	ret = getsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2053			 (void *)&val, &len);
2054	EXPECT_EQ(ret, 0);
2055	EXPECT_EQ(val, 0);
2056	EXPECT_EQ(len, 4);
2057
2058	close(fd);
2059	close(cfd);
2060}
2061
2062TEST(tls_v6ops) {
2063	struct tls_crypto_info_keys tls12;
2064	struct sockaddr_in6 addr, addr2;
2065	int sfd, ret, fd;
2066	socklen_t len, len2;
2067
2068	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12);
2069
2070	addr.sin6_family = AF_INET6;
2071	addr.sin6_addr = in6addr_any;
2072	addr.sin6_port = 0;
2073
2074	fd = socket(AF_INET6, SOCK_STREAM, 0);
2075	sfd = socket(AF_INET6, SOCK_STREAM, 0);
2076
2077	ret = bind(sfd, &addr, sizeof(addr));
2078	ASSERT_EQ(ret, 0);
2079	ret = listen(sfd, 10);
2080	ASSERT_EQ(ret, 0);
2081
2082	len = sizeof(addr);
2083	ret = getsockname(sfd, &addr, &len);
2084	ASSERT_EQ(ret, 0);
2085
2086	ret = connect(fd, &addr, sizeof(addr));
2087	ASSERT_EQ(ret, 0);
2088
2089	len = sizeof(addr);
2090	ret = getsockname(fd, &addr, &len);
2091	ASSERT_EQ(ret, 0);
2092
2093	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2094	if (ret) {
2095		ASSERT_EQ(errno, ENOENT);
2096		SKIP(return, "no TLS support");
2097	}
2098	ASSERT_EQ(ret, 0);
2099
2100	ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
2101	ASSERT_EQ(ret, 0);
2102
2103	ret = setsockopt(fd, SOL_TLS, TLS_RX, &tls12, tls12.len);
2104	ASSERT_EQ(ret, 0);
2105
2106	len2 = sizeof(addr2);
2107	ret = getsockname(fd, &addr2, &len2);
2108	ASSERT_EQ(ret, 0);
2109
2110	EXPECT_EQ(len2, len);
2111	EXPECT_EQ(memcmp(&addr, &addr2, len), 0);
2112
2113	close(fd);
2114	close(sfd);
2115}
2116
2117TEST(prequeue) {
2118	struct tls_crypto_info_keys tls12;
2119	char buf[20000], buf2[20000];
2120	struct sockaddr_in addr;
2121	int sfd, cfd, ret, fd;
2122	socklen_t len;
2123
2124	len = sizeof(addr);
2125	memrnd(buf, sizeof(buf));
2126
2127	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_256, &tls12);
2128
2129	addr.sin_family = AF_INET;
2130	addr.sin_addr.s_addr = htonl(INADDR_ANY);
2131	addr.sin_port = 0;
2132
2133	fd = socket(AF_INET, SOCK_STREAM, 0);
2134	sfd = socket(AF_INET, SOCK_STREAM, 0);
2135
2136	ASSERT_EQ(bind(sfd, &addr, sizeof(addr)), 0);
2137	ASSERT_EQ(listen(sfd, 10), 0);
2138	ASSERT_EQ(getsockname(sfd, &addr, &len), 0);
2139	ASSERT_EQ(connect(fd, &addr, sizeof(addr)), 0);
2140	ASSERT_GE(cfd = accept(sfd, &addr, &len), 0);
2141	close(sfd);
2142
2143	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2144	if (ret) {
2145		ASSERT_EQ(errno, ENOENT);
2146		SKIP(return, "no TLS support");
2147	}
2148
2149	ASSERT_EQ(setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2150	EXPECT_EQ(send(fd, buf, sizeof(buf), MSG_DONTWAIT), sizeof(buf));
2151
2152	ASSERT_EQ(setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")), 0);
2153	ASSERT_EQ(setsockopt(cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2154	EXPECT_EQ(recv(cfd, buf2, sizeof(buf2), MSG_WAITALL), sizeof(buf2));
2155
2156	EXPECT_EQ(memcmp(buf, buf2, sizeof(buf)), 0);
2157
2158	close(fd);
2159	close(cfd);
2160}
2161
2162static void __attribute__((constructor)) fips_check(void) {
2163	int res;
2164	FILE *f;
2165
2166	f = fopen("/proc/sys/crypto/fips_enabled", "r");
2167	if (f) {
2168		res = fscanf(f, "%d", &fips_enabled);
2169		if (res != 1)
2170			ksft_print_msg("ERROR: Couldn't read /proc/sys/crypto/fips_enabled\n");
2171		fclose(f);
2172	}
2173}
2174
2175TEST_HARNESS_MAIN