Linux Audio

Check our new training course

Loading...
Note: File does not exist in v3.5.6.
   1/*
   2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
   3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
   4 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
   5 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
   6 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
   7 *
   8 * This software is available to you under a choice of one of two
   9 * licenses.  You may choose to be licensed under the terms of the GNU
  10 * General Public License (GPL) Version 2, available from the file
  11 * COPYING in the main directory of this source tree, or the
  12 * OpenIB.org BSD license below:
  13 *
  14 *     Redistribution and use in source and binary forms, with or
  15 *     without modification, are permitted provided that the following
  16 *     conditions are met:
  17 *
  18 *      - Redistributions of source code must retain the above
  19 *        copyright notice, this list of conditions and the following
  20 *        disclaimer.
  21 *
  22 *      - Redistributions in binary form must reproduce the above
  23 *        copyright notice, this list of conditions and the following
  24 *        disclaimer in the documentation and/or other materials
  25 *        provided with the distribution.
  26 *
  27 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  28 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  29 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  30 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  31 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  32 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  33 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  34 * SOFTWARE.
  35 */
  36
  37#include <linux/sched/signal.h>
  38#include <linux/module.h>
  39#include <crypto/aead.h>
  40
  41#include <net/strparser.h>
  42#include <net/tls.h>
  43
  44#define MAX_IV_SIZE	TLS_CIPHER_AES_GCM_128_IV_SIZE
  45
  46static int tls_do_decryption(struct sock *sk,
  47			     struct scatterlist *sgin,
  48			     struct scatterlist *sgout,
  49			     char *iv_recv,
  50			     size_t data_len,
  51			     struct sk_buff *skb,
  52			     gfp_t flags)
  53{
  54	struct tls_context *tls_ctx = tls_get_ctx(sk);
  55	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
  56	struct strp_msg *rxm = strp_msg(skb);
  57	struct aead_request *aead_req;
  58
  59	int ret;
  60	unsigned int req_size = sizeof(struct aead_request) +
  61		crypto_aead_reqsize(ctx->aead_recv);
  62
  63	aead_req = kzalloc(req_size, flags);
  64	if (!aead_req)
  65		return -ENOMEM;
  66
  67	aead_request_set_tfm(aead_req, ctx->aead_recv);
  68	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
  69	aead_request_set_crypt(aead_req, sgin, sgout,
  70			       data_len + tls_ctx->rx.tag_size,
  71			       (u8 *)iv_recv);
  72	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
  73				  crypto_req_done, &ctx->async_wait);
  74
  75	ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
  76
  77	if (ret < 0)
  78		goto out;
  79
  80	rxm->offset += tls_ctx->rx.prepend_size;
  81	rxm->full_len -= tls_ctx->rx.overhead_size;
  82	tls_advance_record_sn(sk, &tls_ctx->rx);
  83
  84	ctx->decrypted = true;
  85
  86	ctx->saved_data_ready(sk);
  87
  88out:
  89	kfree(aead_req);
  90	return ret;
  91}
  92
  93static void trim_sg(struct sock *sk, struct scatterlist *sg,
  94		    int *sg_num_elem, unsigned int *sg_size, int target_size)
  95{
  96	int i = *sg_num_elem - 1;
  97	int trim = *sg_size - target_size;
  98
  99	if (trim <= 0) {
 100		WARN_ON(trim < 0);
 101		return;
 102	}
 103
 104	*sg_size = target_size;
 105	while (trim >= sg[i].length) {
 106		trim -= sg[i].length;
 107		sk_mem_uncharge(sk, sg[i].length);
 108		put_page(sg_page(&sg[i]));
 109		i--;
 110
 111		if (i < 0)
 112			goto out;
 113	}
 114
 115	sg[i].length -= trim;
 116	sk_mem_uncharge(sk, trim);
 117
 118out:
 119	*sg_num_elem = i + 1;
 120}
 121
 122static void trim_both_sgl(struct sock *sk, int target_size)
 123{
 124	struct tls_context *tls_ctx = tls_get_ctx(sk);
 125	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 126
 127	trim_sg(sk, ctx->sg_plaintext_data,
 128		&ctx->sg_plaintext_num_elem,
 129		&ctx->sg_plaintext_size,
 130		target_size);
 131
 132	if (target_size > 0)
 133		target_size += tls_ctx->tx.overhead_size;
 134
 135	trim_sg(sk, ctx->sg_encrypted_data,
 136		&ctx->sg_encrypted_num_elem,
 137		&ctx->sg_encrypted_size,
 138		target_size);
 139}
 140
 141static int alloc_encrypted_sg(struct sock *sk, int len)
 142{
 143	struct tls_context *tls_ctx = tls_get_ctx(sk);
 144	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 145	int rc = 0;
 146
 147	rc = sk_alloc_sg(sk, len,
 148			 ctx->sg_encrypted_data, 0,
 149			 &ctx->sg_encrypted_num_elem,
 150			 &ctx->sg_encrypted_size, 0);
 151
 152	return rc;
 153}
 154
 155static int alloc_plaintext_sg(struct sock *sk, int len)
 156{
 157	struct tls_context *tls_ctx = tls_get_ctx(sk);
 158	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 159	int rc = 0;
 160
 161	rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0,
 162			 &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size,
 163			 tls_ctx->pending_open_record_frags);
 164
 165	return rc;
 166}
 167
 168static void free_sg(struct sock *sk, struct scatterlist *sg,
 169		    int *sg_num_elem, unsigned int *sg_size)
 170{
 171	int i, n = *sg_num_elem;
 172
 173	for (i = 0; i < n; ++i) {
 174		sk_mem_uncharge(sk, sg[i].length);
 175		put_page(sg_page(&sg[i]));
 176	}
 177	*sg_num_elem = 0;
 178	*sg_size = 0;
 179}
 180
 181static void tls_free_both_sg(struct sock *sk)
 182{
 183	struct tls_context *tls_ctx = tls_get_ctx(sk);
 184	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 185
 186	free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
 187		&ctx->sg_encrypted_size);
 188
 189	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
 190		&ctx->sg_plaintext_size);
 191}
 192
 193static int tls_do_encryption(struct tls_context *tls_ctx,
 194			     struct tls_sw_context *ctx, size_t data_len,
 195			     gfp_t flags)
 196{
 197	unsigned int req_size = sizeof(struct aead_request) +
 198		crypto_aead_reqsize(ctx->aead_send);
 199	struct aead_request *aead_req;
 200	int rc;
 201
 202	aead_req = kzalloc(req_size, flags);
 203	if (!aead_req)
 204		return -ENOMEM;
 205
 206	ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
 207	ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;
 208
 209	aead_request_set_tfm(aead_req, ctx->aead_send);
 210	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
 211	aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
 212			       data_len, tls_ctx->tx.iv);
 213
 214	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
 215				  crypto_req_done, &ctx->async_wait);
 216
 217	rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait);
 218
 219	ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
 220	ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
 221
 222	kfree(aead_req);
 223	return rc;
 224}
 225
 226static int tls_push_record(struct sock *sk, int flags,
 227			   unsigned char record_type)
 228{
 229	struct tls_context *tls_ctx = tls_get_ctx(sk);
 230	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 231	int rc;
 232
 233	sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
 234	sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
 235
 236	tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
 237		     tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
 238		     record_type);
 239
 240	tls_fill_prepend(tls_ctx,
 241			 page_address(sg_page(&ctx->sg_encrypted_data[0])) +
 242			 ctx->sg_encrypted_data[0].offset,
 243			 ctx->sg_plaintext_size, record_type);
 244
 245	tls_ctx->pending_open_record_frags = 0;
 246	set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags);
 247
 248	rc = tls_do_encryption(tls_ctx, ctx, ctx->sg_plaintext_size,
 249			       sk->sk_allocation);
 250	if (rc < 0) {
 251		/* If we are called from write_space and
 252		 * we fail, we need to set this SOCK_NOSPACE
 253		 * to trigger another write_space in the future.
 254		 */
 255		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
 256		return rc;
 257	}
 258
 259	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
 260		&ctx->sg_plaintext_size);
 261
 262	ctx->sg_encrypted_num_elem = 0;
 263	ctx->sg_encrypted_size = 0;
 264
 265	/* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */
 266	rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
 267	if (rc < 0 && rc != -EAGAIN)
 268		tls_err_abort(sk, EBADMSG);
 269
 270	tls_advance_record_sn(sk, &tls_ctx->tx);
 271	return rc;
 272}
 273
 274static int tls_sw_push_pending_record(struct sock *sk, int flags)
 275{
 276	return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
 277}
 278
 279static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
 280			      int length, int *pages_used,
 281			      unsigned int *size_used,
 282			      struct scatterlist *to, int to_max_pages,
 283			      bool charge)
 284{
 285	struct page *pages[MAX_SKB_FRAGS];
 286
 287	size_t offset;
 288	ssize_t copied, use;
 289	int i = 0;
 290	unsigned int size = *size_used;
 291	int num_elem = *pages_used;
 292	int rc = 0;
 293	int maxpages;
 294
 295	while (length > 0) {
 296		i = 0;
 297		maxpages = to_max_pages - num_elem;
 298		if (maxpages == 0) {
 299			rc = -EFAULT;
 300			goto out;
 301		}
 302		copied = iov_iter_get_pages(from, pages,
 303					    length,
 304					    maxpages, &offset);
 305		if (copied <= 0) {
 306			rc = -EFAULT;
 307			goto out;
 308		}
 309
 310		iov_iter_advance(from, copied);
 311
 312		length -= copied;
 313		size += copied;
 314		while (copied) {
 315			use = min_t(int, copied, PAGE_SIZE - offset);
 316
 317			sg_set_page(&to[num_elem],
 318				    pages[i], use, offset);
 319			sg_unmark_end(&to[num_elem]);
 320			if (charge)
 321				sk_mem_charge(sk, use);
 322
 323			offset = 0;
 324			copied -= use;
 325
 326			++i;
 327			++num_elem;
 328		}
 329	}
 330
 331out:
 332	*size_used = size;
 333	*pages_used = num_elem;
 334
 335	return rc;
 336}
 337
 338static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
 339			     int bytes)
 340{
 341	struct tls_context *tls_ctx = tls_get_ctx(sk);
 342	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 343	struct scatterlist *sg = ctx->sg_plaintext_data;
 344	int copy, i, rc = 0;
 345
 346	for (i = tls_ctx->pending_open_record_frags;
 347	     i < ctx->sg_plaintext_num_elem; ++i) {
 348		copy = sg[i].length;
 349		if (copy_from_iter(
 350				page_address(sg_page(&sg[i])) + sg[i].offset,
 351				copy, from) != copy) {
 352			rc = -EFAULT;
 353			goto out;
 354		}
 355		bytes -= copy;
 356
 357		++tls_ctx->pending_open_record_frags;
 358
 359		if (!bytes)
 360			break;
 361	}
 362
 363out:
 364	return rc;
 365}
 366
 367int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 368{
 369	struct tls_context *tls_ctx = tls_get_ctx(sk);
 370	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 371	int ret = 0;
 372	int required_size;
 373	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
 374	bool eor = !(msg->msg_flags & MSG_MORE);
 375	size_t try_to_copy, copied = 0;
 376	unsigned char record_type = TLS_RECORD_TYPE_DATA;
 377	int record_room;
 378	bool full_record;
 379	int orig_size;
 380
 381	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
 382		return -ENOTSUPP;
 383
 384	lock_sock(sk);
 385
 386	if (tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo))
 387		goto send_end;
 388
 389	if (unlikely(msg->msg_controllen)) {
 390		ret = tls_proccess_cmsg(sk, msg, &record_type);
 391		if (ret)
 392			goto send_end;
 393	}
 394
 395	while (msg_data_left(msg)) {
 396		if (sk->sk_err) {
 397			ret = -sk->sk_err;
 398			goto send_end;
 399		}
 400
 401		orig_size = ctx->sg_plaintext_size;
 402		full_record = false;
 403		try_to_copy = msg_data_left(msg);
 404		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
 405		if (try_to_copy >= record_room) {
 406			try_to_copy = record_room;
 407			full_record = true;
 408		}
 409
 410		required_size = ctx->sg_plaintext_size + try_to_copy +
 411				tls_ctx->tx.overhead_size;
 412
 413		if (!sk_stream_memory_free(sk))
 414			goto wait_for_sndbuf;
 415alloc_encrypted:
 416		ret = alloc_encrypted_sg(sk, required_size);
 417		if (ret) {
 418			if (ret != -ENOSPC)
 419				goto wait_for_memory;
 420
 421			/* Adjust try_to_copy according to the amount that was
 422			 * actually allocated. The difference is due
 423			 * to max sg elements limit
 424			 */
 425			try_to_copy -= required_size - ctx->sg_encrypted_size;
 426			full_record = true;
 427		}
 428
 429		if (full_record || eor) {
 430			ret = zerocopy_from_iter(sk, &msg->msg_iter,
 431				try_to_copy, &ctx->sg_plaintext_num_elem,
 432				&ctx->sg_plaintext_size,
 433				ctx->sg_plaintext_data,
 434				ARRAY_SIZE(ctx->sg_plaintext_data),
 435				true);
 436			if (ret)
 437				goto fallback_to_reg_send;
 438
 439			copied += try_to_copy;
 440			ret = tls_push_record(sk, msg->msg_flags, record_type);
 441			if (!ret)
 442				continue;
 443			if (ret == -EAGAIN)
 444				goto send_end;
 445
 446			copied -= try_to_copy;
 447fallback_to_reg_send:
 448			iov_iter_revert(&msg->msg_iter,
 449					ctx->sg_plaintext_size - orig_size);
 450			trim_sg(sk, ctx->sg_plaintext_data,
 451				&ctx->sg_plaintext_num_elem,
 452				&ctx->sg_plaintext_size,
 453				orig_size);
 454		}
 455
 456		required_size = ctx->sg_plaintext_size + try_to_copy;
 457alloc_plaintext:
 458		ret = alloc_plaintext_sg(sk, required_size);
 459		if (ret) {
 460			if (ret != -ENOSPC)
 461				goto wait_for_memory;
 462
 463			/* Adjust try_to_copy according to the amount that was
 464			 * actually allocated. The difference is due
 465			 * to max sg elements limit
 466			 */
 467			try_to_copy -= required_size - ctx->sg_plaintext_size;
 468			full_record = true;
 469
 470			trim_sg(sk, ctx->sg_encrypted_data,
 471				&ctx->sg_encrypted_num_elem,
 472				&ctx->sg_encrypted_size,
 473				ctx->sg_plaintext_size +
 474				tls_ctx->tx.overhead_size);
 475		}
 476
 477		ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
 478		if (ret)
 479			goto trim_sgl;
 480
 481		copied += try_to_copy;
 482		if (full_record || eor) {
 483push_record:
 484			ret = tls_push_record(sk, msg->msg_flags, record_type);
 485			if (ret) {
 486				if (ret == -ENOMEM)
 487					goto wait_for_memory;
 488
 489				goto send_end;
 490			}
 491		}
 492
 493		continue;
 494
 495wait_for_sndbuf:
 496		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
 497wait_for_memory:
 498		ret = sk_stream_wait_memory(sk, &timeo);
 499		if (ret) {
 500trim_sgl:
 501			trim_both_sgl(sk, orig_size);
 502			goto send_end;
 503		}
 504
 505		if (tls_is_pending_closed_record(tls_ctx))
 506			goto push_record;
 507
 508		if (ctx->sg_encrypted_size < required_size)
 509			goto alloc_encrypted;
 510
 511		goto alloc_plaintext;
 512	}
 513
 514send_end:
 515	ret = sk_stream_error(sk, msg->msg_flags, ret);
 516
 517	release_sock(sk);
 518	return copied ? copied : ret;
 519}
 520
 521int tls_sw_sendpage(struct sock *sk, struct page *page,
 522		    int offset, size_t size, int flags)
 523{
 524	struct tls_context *tls_ctx = tls_get_ctx(sk);
 525	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 526	int ret = 0;
 527	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
 528	bool eor;
 529	size_t orig_size = size;
 530	unsigned char record_type = TLS_RECORD_TYPE_DATA;
 531	struct scatterlist *sg;
 532	bool full_record;
 533	int record_room;
 534
 535	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
 536		      MSG_SENDPAGE_NOTLAST))
 537		return -ENOTSUPP;
 538
 539	/* No MSG_EOR from splice, only look at MSG_MORE */
 540	eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
 541
 542	lock_sock(sk);
 543
 544	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
 545
 546	if (tls_complete_pending_work(sk, tls_ctx, flags, &timeo))
 547		goto sendpage_end;
 548
 549	/* Call the sk_stream functions to manage the sndbuf mem. */
 550	while (size > 0) {
 551		size_t copy, required_size;
 552
 553		if (sk->sk_err) {
 554			ret = -sk->sk_err;
 555			goto sendpage_end;
 556		}
 557
 558		full_record = false;
 559		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
 560		copy = size;
 561		if (copy >= record_room) {
 562			copy = record_room;
 563			full_record = true;
 564		}
 565		required_size = ctx->sg_plaintext_size + copy +
 566			      tls_ctx->tx.overhead_size;
 567
 568		if (!sk_stream_memory_free(sk))
 569			goto wait_for_sndbuf;
 570alloc_payload:
 571		ret = alloc_encrypted_sg(sk, required_size);
 572		if (ret) {
 573			if (ret != -ENOSPC)
 574				goto wait_for_memory;
 575
 576			/* Adjust copy according to the amount that was
 577			 * actually allocated. The difference is due
 578			 * to max sg elements limit
 579			 */
 580			copy -= required_size - ctx->sg_plaintext_size;
 581			full_record = true;
 582		}
 583
 584		get_page(page);
 585		sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem;
 586		sg_set_page(sg, page, copy, offset);
 587		sg_unmark_end(sg);
 588
 589		ctx->sg_plaintext_num_elem++;
 590
 591		sk_mem_charge(sk, copy);
 592		offset += copy;
 593		size -= copy;
 594		ctx->sg_plaintext_size += copy;
 595		tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem;
 596
 597		if (full_record || eor ||
 598		    ctx->sg_plaintext_num_elem ==
 599		    ARRAY_SIZE(ctx->sg_plaintext_data)) {
 600push_record:
 601			ret = tls_push_record(sk, flags, record_type);
 602			if (ret) {
 603				if (ret == -ENOMEM)
 604					goto wait_for_memory;
 605
 606				goto sendpage_end;
 607			}
 608		}
 609		continue;
 610wait_for_sndbuf:
 611		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
 612wait_for_memory:
 613		ret = sk_stream_wait_memory(sk, &timeo);
 614		if (ret) {
 615			trim_both_sgl(sk, ctx->sg_plaintext_size);
 616			goto sendpage_end;
 617		}
 618
 619		if (tls_is_pending_closed_record(tls_ctx))
 620			goto push_record;
 621
 622		goto alloc_payload;
 623	}
 624
 625sendpage_end:
 626	if (orig_size > size)
 627		ret = orig_size - size;
 628	else
 629		ret = sk_stream_error(sk, flags, ret);
 630
 631	release_sock(sk);
 632	return ret;
 633}
 634
 635static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
 636				     long timeo, int *err)
 637{
 638	struct tls_context *tls_ctx = tls_get_ctx(sk);
 639	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 640	struct sk_buff *skb;
 641	DEFINE_WAIT_FUNC(wait, woken_wake_function);
 642
 643	while (!(skb = ctx->recv_pkt)) {
 644		if (sk->sk_err) {
 645			*err = sock_error(sk);
 646			return NULL;
 647		}
 648
 649		if (sock_flag(sk, SOCK_DONE))
 650			return NULL;
 651
 652		if ((flags & MSG_DONTWAIT) || !timeo) {
 653			*err = -EAGAIN;
 654			return NULL;
 655		}
 656
 657		add_wait_queue(sk_sleep(sk), &wait);
 658		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 659		sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
 660		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 661		remove_wait_queue(sk_sleep(sk), &wait);
 662
 663		/* Handle signals */
 664		if (signal_pending(current)) {
 665			*err = sock_intr_errno(timeo);
 666			return NULL;
 667		}
 668	}
 669
 670	return skb;
 671}
 672
 673static int decrypt_skb(struct sock *sk, struct sk_buff *skb,
 674		       struct scatterlist *sgout)
 675{
 676	struct tls_context *tls_ctx = tls_get_ctx(sk);
 677	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 678	char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + MAX_IV_SIZE];
 679	struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2];
 680	struct scatterlist *sgin = &sgin_arr[0];
 681	struct strp_msg *rxm = strp_msg(skb);
 682	int ret, nsg = ARRAY_SIZE(sgin_arr);
 683	struct sk_buff *unused;
 684
 685	ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
 686			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
 687			    tls_ctx->rx.iv_size);
 688	if (ret < 0)
 689		return ret;
 690
 691	memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
 692	if (!sgout) {
 693		nsg = skb_cow_data(skb, 0, &unused) + 1;
 694		sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation);
 695		if (!sgout)
 696			sgout = sgin;
 697	}
 698
 699	sg_init_table(sgin, nsg);
 700	sg_set_buf(&sgin[0], ctx->rx_aad_ciphertext, TLS_AAD_SPACE_SIZE);
 701
 702	nsg = skb_to_sgvec(skb, &sgin[1],
 703			   rxm->offset + tls_ctx->rx.prepend_size,
 704			   rxm->full_len - tls_ctx->rx.prepend_size);
 705
 706	tls_make_aad(ctx->rx_aad_ciphertext,
 707		     rxm->full_len - tls_ctx->rx.overhead_size,
 708		     tls_ctx->rx.rec_seq,
 709		     tls_ctx->rx.rec_seq_size,
 710		     ctx->control);
 711
 712	ret = tls_do_decryption(sk, sgin, sgout, iv,
 713				rxm->full_len - tls_ctx->rx.overhead_size,
 714				skb, sk->sk_allocation);
 715
 716	if (sgin != &sgin_arr[0])
 717		kfree(sgin);
 718
 719	return ret;
 720}
 721
 722static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
 723			       unsigned int len)
 724{
 725	struct tls_context *tls_ctx = tls_get_ctx(sk);
 726	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 727	struct strp_msg *rxm = strp_msg(skb);
 728
 729	if (len < rxm->full_len) {
 730		rxm->offset += len;
 731		rxm->full_len -= len;
 732
 733		return false;
 734	}
 735
 736	/* Finished with message */
 737	ctx->recv_pkt = NULL;
 738	kfree_skb(skb);
 739	strp_unpause(&ctx->strp);
 740
 741	return true;
 742}
 743
 744int tls_sw_recvmsg(struct sock *sk,
 745		   struct msghdr *msg,
 746		   size_t len,
 747		   int nonblock,
 748		   int flags,
 749		   int *addr_len)
 750{
 751	struct tls_context *tls_ctx = tls_get_ctx(sk);
 752	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 753	unsigned char control;
 754	struct strp_msg *rxm;
 755	struct sk_buff *skb;
 756	ssize_t copied = 0;
 757	bool cmsg = false;
 758	int err = 0;
 759	long timeo;
 760
 761	flags |= nonblock;
 762
 763	if (unlikely(flags & MSG_ERRQUEUE))
 764		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
 765
 766	lock_sock(sk);
 767
 768	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
 769	do {
 770		bool zc = false;
 771		int chunk = 0;
 772
 773		skb = tls_wait_data(sk, flags, timeo, &err);
 774		if (!skb)
 775			goto recv_end;
 776
 777		rxm = strp_msg(skb);
 778		if (!cmsg) {
 779			int cerr;
 780
 781			cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
 782					sizeof(ctx->control), &ctx->control);
 783			cmsg = true;
 784			control = ctx->control;
 785			if (ctx->control != TLS_RECORD_TYPE_DATA) {
 786				if (cerr || msg->msg_flags & MSG_CTRUNC) {
 787					err = -EIO;
 788					goto recv_end;
 789				}
 790			}
 791		} else if (control != ctx->control) {
 792			goto recv_end;
 793		}
 794
 795		if (!ctx->decrypted) {
 796			int page_count;
 797			int to_copy;
 798
 799			page_count = iov_iter_npages(&msg->msg_iter,
 800						     MAX_SKB_FRAGS);
 801			to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
 802			if (to_copy <= len && page_count < MAX_SKB_FRAGS &&
 803			    likely(!(flags & MSG_PEEK)))  {
 804				struct scatterlist sgin[MAX_SKB_FRAGS + 1];
 805				int pages = 0;
 806
 807				zc = true;
 808				sg_init_table(sgin, MAX_SKB_FRAGS + 1);
 809				sg_set_buf(&sgin[0], ctx->rx_aad_plaintext,
 810					   TLS_AAD_SPACE_SIZE);
 811
 812				err = zerocopy_from_iter(sk, &msg->msg_iter,
 813							 to_copy, &pages,
 814							 &chunk, &sgin[1],
 815							 MAX_SKB_FRAGS,	false);
 816				if (err < 0)
 817					goto fallback_to_reg_recv;
 818
 819				err = decrypt_skb(sk, skb, sgin);
 820				for (; pages > 0; pages--)
 821					put_page(sg_page(&sgin[pages]));
 822				if (err < 0) {
 823					tls_err_abort(sk, EBADMSG);
 824					goto recv_end;
 825				}
 826			} else {
 827fallback_to_reg_recv:
 828				err = decrypt_skb(sk, skb, NULL);
 829				if (err < 0) {
 830					tls_err_abort(sk, EBADMSG);
 831					goto recv_end;
 832				}
 833			}
 834			ctx->decrypted = true;
 835		}
 836
 837		if (!zc) {
 838			chunk = min_t(unsigned int, rxm->full_len, len);
 839			err = skb_copy_datagram_msg(skb, rxm->offset, msg,
 840						    chunk);
 841			if (err < 0)
 842				goto recv_end;
 843		}
 844
 845		copied += chunk;
 846		len -= chunk;
 847		if (likely(!(flags & MSG_PEEK))) {
 848			u8 control = ctx->control;
 849
 850			if (tls_sw_advance_skb(sk, skb, chunk)) {
 851				/* Return full control message to
 852				 * userspace before trying to parse
 853				 * another message type
 854				 */
 855				msg->msg_flags |= MSG_EOR;
 856				if (control != TLS_RECORD_TYPE_DATA)
 857					goto recv_end;
 858			}
 859		}
 860	} while (len);
 861
 862recv_end:
 863	release_sock(sk);
 864	return copied ? : err;
 865}
 866
 867ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
 868			   struct pipe_inode_info *pipe,
 869			   size_t len, unsigned int flags)
 870{
 871	struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
 872	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 873	struct strp_msg *rxm = NULL;
 874	struct sock *sk = sock->sk;
 875	struct sk_buff *skb;
 876	ssize_t copied = 0;
 877	int err = 0;
 878	long timeo;
 879	int chunk;
 880
 881	lock_sock(sk);
 882
 883	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
 884
 885	skb = tls_wait_data(sk, flags, timeo, &err);
 886	if (!skb)
 887		goto splice_read_end;
 888
 889	/* splice does not support reading control messages */
 890	if (ctx->control != TLS_RECORD_TYPE_DATA) {
 891		err = -ENOTSUPP;
 892		goto splice_read_end;
 893	}
 894
 895	if (!ctx->decrypted) {
 896		err = decrypt_skb(sk, skb, NULL);
 897
 898		if (err < 0) {
 899			tls_err_abort(sk, EBADMSG);
 900			goto splice_read_end;
 901		}
 902		ctx->decrypted = true;
 903	}
 904	rxm = strp_msg(skb);
 905
 906	chunk = min_t(unsigned int, rxm->full_len, len);
 907	copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
 908	if (copied < 0)
 909		goto splice_read_end;
 910
 911	if (likely(!(flags & MSG_PEEK)))
 912		tls_sw_advance_skb(sk, skb, copied);
 913
 914splice_read_end:
 915	release_sock(sk);
 916	return copied ? : err;
 917}
 918
 919unsigned int tls_sw_poll(struct file *file, struct socket *sock,
 920			 struct poll_table_struct *wait)
 921{
 922	unsigned int ret;
 923	struct sock *sk = sock->sk;
 924	struct tls_context *tls_ctx = tls_get_ctx(sk);
 925	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 926
 927	/* Grab POLLOUT and POLLHUP from the underlying socket */
 928	ret = ctx->sk_poll(file, sock, wait);
 929
 930	/* Clear POLLIN bits, and set based on recv_pkt */
 931	ret &= ~(POLLIN | POLLRDNORM);
 932	if (ctx->recv_pkt)
 933		ret |= POLLIN | POLLRDNORM;
 934
 935	return ret;
 936}
 937
 938static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
 939{
 940	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
 941	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 942	char header[tls_ctx->rx.prepend_size];
 943	struct strp_msg *rxm = strp_msg(skb);
 944	size_t cipher_overhead;
 945	size_t data_len = 0;
 946	int ret;
 947
 948	/* Verify that we have a full TLS header, or wait for more data */
 949	if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
 950		return 0;
 951
 952	/* Linearize header to local buffer */
 953	ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
 954
 955	if (ret < 0)
 956		goto read_failure;
 957
 958	ctx->control = header[0];
 959
 960	data_len = ((header[4] & 0xFF) | (header[3] << 8));
 961
 962	cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
 963
 964	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
 965		ret = -EMSGSIZE;
 966		goto read_failure;
 967	}
 968	if (data_len < cipher_overhead) {
 969		ret = -EBADMSG;
 970		goto read_failure;
 971	}
 972
 973	if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.version) ||
 974	    header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.version)) {
 975		ret = -EINVAL;
 976		goto read_failure;
 977	}
 978
 979	return data_len + TLS_HEADER_SIZE;
 980
 981read_failure:
 982	tls_err_abort(strp->sk, ret);
 983
 984	return ret;
 985}
 986
 987static void tls_queue(struct strparser *strp, struct sk_buff *skb)
 988{
 989	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
 990	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 991	struct strp_msg *rxm;
 992
 993	rxm = strp_msg(skb);
 994
 995	ctx->decrypted = false;
 996
 997	ctx->recv_pkt = skb;
 998	strp_pause(strp);
 999
1000	strp->sk->sk_state_change(strp->sk);
1001}
1002
1003static void tls_data_ready(struct sock *sk)
1004{
1005	struct tls_context *tls_ctx = tls_get_ctx(sk);
1006	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
1007
1008	strp_data_ready(&ctx->strp);
1009}
1010
1011void tls_sw_free_resources(struct sock *sk)
1012{
1013	struct tls_context *tls_ctx = tls_get_ctx(sk);
1014	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
1015
1016	if (ctx->aead_send)
1017		crypto_free_aead(ctx->aead_send);
1018	if (ctx->aead_recv) {
1019		if (ctx->recv_pkt) {
1020			kfree_skb(ctx->recv_pkt);
1021			ctx->recv_pkt = NULL;
1022		}
1023		crypto_free_aead(ctx->aead_recv);
1024		strp_stop(&ctx->strp);
1025		write_lock_bh(&sk->sk_callback_lock);
1026		sk->sk_data_ready = ctx->saved_data_ready;
1027		write_unlock_bh(&sk->sk_callback_lock);
1028		release_sock(sk);
1029		strp_done(&ctx->strp);
1030		lock_sock(sk);
1031	}
1032
1033	tls_free_both_sg(sk);
1034
1035	kfree(ctx);
1036	kfree(tls_ctx);
1037}
1038
1039int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
1040{
1041	char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE];
1042	struct tls_crypto_info *crypto_info;
1043	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
1044	struct tls_sw_context *sw_ctx;
1045	struct cipher_context *cctx;
1046	struct crypto_aead **aead;
1047	struct strp_callbacks cb;
1048	u16 nonce_size, tag_size, iv_size, rec_seq_size;
1049	char *iv, *rec_seq;
1050	int rc = 0;
1051
1052	if (!ctx) {
1053		rc = -EINVAL;
1054		goto out;
1055	}
1056
1057	if (!ctx->priv_ctx) {
1058		sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL);
1059		if (!sw_ctx) {
1060			rc = -ENOMEM;
1061			goto out;
1062		}
1063		crypto_init_wait(&sw_ctx->async_wait);
1064	} else {
1065		sw_ctx = ctx->priv_ctx;
1066	}
1067
1068	ctx->priv_ctx = (struct tls_offload_context *)sw_ctx;
1069
1070	if (tx) {
1071		crypto_info = &ctx->crypto_send;
1072		cctx = &ctx->tx;
1073		aead = &sw_ctx->aead_send;
1074	} else {
1075		crypto_info = &ctx->crypto_recv;
1076		cctx = &ctx->rx;
1077		aead = &sw_ctx->aead_recv;
1078	}
1079
1080	switch (crypto_info->cipher_type) {
1081	case TLS_CIPHER_AES_GCM_128: {
1082		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1083		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
1084		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1085		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
1086		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
1087		rec_seq =
1088		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
1089		gcm_128_info =
1090			(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
1091		break;
1092	}
1093	default:
1094		rc = -EINVAL;
1095		goto free_priv;
1096	}
1097
1098	/* Sanity-check the IV size for stack allocations. */
1099	if (iv_size > MAX_IV_SIZE) {
1100		rc = -EINVAL;
1101		goto free_priv;
1102	}
1103
1104	cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
1105	cctx->tag_size = tag_size;
1106	cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
1107	cctx->iv_size = iv_size;
1108	cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
1109			   GFP_KERNEL);
1110	if (!cctx->iv) {
1111		rc = -ENOMEM;
1112		goto free_priv;
1113	}
1114	memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
1115	memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
1116	cctx->rec_seq_size = rec_seq_size;
1117	cctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
1118	if (!cctx->rec_seq) {
1119		rc = -ENOMEM;
1120		goto free_iv;
1121	}
1122	memcpy(cctx->rec_seq, rec_seq, rec_seq_size);
1123
1124	if (tx) {
1125		sg_init_table(sw_ctx->sg_encrypted_data,
1126			      ARRAY_SIZE(sw_ctx->sg_encrypted_data));
1127		sg_init_table(sw_ctx->sg_plaintext_data,
1128			      ARRAY_SIZE(sw_ctx->sg_plaintext_data));
1129
1130		sg_init_table(sw_ctx->sg_aead_in, 2);
1131		sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space,
1132			   sizeof(sw_ctx->aad_space));
1133		sg_unmark_end(&sw_ctx->sg_aead_in[1]);
1134		sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data);
1135		sg_init_table(sw_ctx->sg_aead_out, 2);
1136		sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space,
1137			   sizeof(sw_ctx->aad_space));
1138		sg_unmark_end(&sw_ctx->sg_aead_out[1]);
1139		sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data);
1140	}
1141
1142	if (!*aead) {
1143		*aead = crypto_alloc_aead("gcm(aes)", 0, 0);
1144		if (IS_ERR(*aead)) {
1145			rc = PTR_ERR(*aead);
1146			*aead = NULL;
1147			goto free_rec_seq;
1148		}
1149	}
1150
1151	ctx->push_pending_record = tls_sw_push_pending_record;
1152
1153	memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
1154
1155	rc = crypto_aead_setkey(*aead, keyval,
1156				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
1157	if (rc)
1158		goto free_aead;
1159
1160	rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
1161	if (rc)
1162		goto free_aead;
1163
1164	if (!tx) {
1165		/* Set up strparser */
1166		memset(&cb, 0, sizeof(cb));
1167		cb.rcv_msg = tls_queue;
1168		cb.parse_msg = tls_read_size;
1169
1170		strp_init(&sw_ctx->strp, sk, &cb);
1171
1172		write_lock_bh(&sk->sk_callback_lock);
1173		sw_ctx->saved_data_ready = sk->sk_data_ready;
1174		sk->sk_data_ready = tls_data_ready;
1175		write_unlock_bh(&sk->sk_callback_lock);
1176
1177		sw_ctx->sk_poll = sk->sk_socket->ops->poll;
1178
1179		strp_check_rcv(&sw_ctx->strp);
1180	}
1181
1182	goto out;
1183
1184free_aead:
1185	crypto_free_aead(*aead);
1186	*aead = NULL;
1187free_rec_seq:
1188	kfree(cctx->rec_seq);
1189	cctx->rec_seq = NULL;
1190free_iv:
1191	kfree(ctx->tx.iv);
1192	ctx->tx.iv = NULL;
1193free_priv:
1194	kfree(ctx->priv_ctx);
1195	ctx->priv_ctx = NULL;
1196out:
1197	return rc;
1198}