Linux Audio

Check our new training course

Loading...
Note: File does not exist in v5.4.
  1// SPDX-License-Identifier: GPL-2.0
  2#include <net/tcp.h>
  3#include <net/strparser.h>
  4#include <net/xfrm.h>
  5#include <net/esp.h>
  6#include <net/espintcp.h>
  7#include <linux/skmsg.h>
  8#include <net/inet_common.h>
  9#if IS_ENABLED(CONFIG_IPV6)
 10#include <net/ipv6_stubs.h>
 11#endif
 12
 13static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
 14			  struct sock *sk)
 15{
 16	if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
 17	    !sk_rmem_schedule(sk, skb, skb->truesize)) {
 18		XFRM_INC_STATS(sock_net(sk), LINUX_MIB_XFRMINERROR);
 19		kfree_skb(skb);
 20		return;
 21	}
 22
 23	skb_set_owner_r(skb, sk);
 24
 25	memset(skb->cb, 0, sizeof(skb->cb));
 26	skb_queue_tail(&ctx->ike_queue, skb);
 27	ctx->saved_data_ready(sk);
 28}
 29
 30static void handle_esp(struct sk_buff *skb, struct sock *sk)
 31{
 32	struct tcp_skb_cb *tcp_cb = (struct tcp_skb_cb *)skb->cb;
 33
 34	skb_reset_transport_header(skb);
 35
 36	/* restore IP CB, we need at least IP6CB->nhoff */
 37	memmove(skb->cb, &tcp_cb->header, sizeof(tcp_cb->header));
 38
 39	rcu_read_lock();
 40	skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
 41	local_bh_disable();
 42#if IS_ENABLED(CONFIG_IPV6)
 43	if (sk->sk_family == AF_INET6)
 44		ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
 45	else
 46#endif
 47		xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
 48	local_bh_enable();
 49	rcu_read_unlock();
 50}
 51
 52static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
 53{
 54	struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
 55						strp);
 56	struct strp_msg *rxm = strp_msg(skb);
 57	int len = rxm->full_len - 2;
 58	u32 nonesp_marker;
 59	int err;
 60
 61	/* keepalive packet? */
 62	if (unlikely(len == 1)) {
 63		u8 data;
 64
 65		err = skb_copy_bits(skb, rxm->offset + 2, &data, 1);
 66		if (err < 0) {
 67			XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
 68			kfree_skb(skb);
 69			return;
 70		}
 71
 72		if (data == 0xff) {
 73			kfree_skb(skb);
 74			return;
 75		}
 76	}
 77
 78	/* drop other short messages */
 79	if (unlikely(len <= sizeof(nonesp_marker))) {
 80		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
 81		kfree_skb(skb);
 82		return;
 83	}
 84
 85	err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
 86			    sizeof(nonesp_marker));
 87	if (err < 0) {
 88		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
 89		kfree_skb(skb);
 90		return;
 91	}
 92
 93	/* remove header, leave non-ESP marker/SPI */
 94	if (!__pskb_pull(skb, rxm->offset + 2)) {
 95		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
 96		kfree_skb(skb);
 97		return;
 98	}
 99
100	if (pskb_trim(skb, rxm->full_len - 2) != 0) {
101		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
102		kfree_skb(skb);
103		return;
104	}
105
106	if (nonesp_marker == 0)
107		handle_nonesp(ctx, skb, strp->sk);
108	else
109		handle_esp(skb, strp->sk);
110}
111
112static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
113{
114	struct strp_msg *rxm = strp_msg(skb);
115	__be16 blen;
116	u16 len;
117	int err;
118
119	if (skb->len < rxm->offset + 2)
120		return 0;
121
122	err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
123	if (err < 0)
124		return err;
125
126	len = be16_to_cpu(blen);
127	if (len < 2)
128		return -EINVAL;
129
130	return len;
131}
132
133static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
134			    int nonblock, int flags, int *addr_len)
135{
136	struct espintcp_ctx *ctx = espintcp_getctx(sk);
137	struct sk_buff *skb;
138	int err = 0;
139	int copied;
140	int off = 0;
141
142	flags |= nonblock ? MSG_DONTWAIT : 0;
143
144	skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
145	if (!skb) {
146		if (err == -EAGAIN && sk->sk_shutdown & RCV_SHUTDOWN)
147			return 0;
148		return err;
149	}
150
151	copied = len;
152	if (copied > skb->len)
153		copied = skb->len;
154	else if (copied < skb->len)
155		msg->msg_flags |= MSG_TRUNC;
156
157	err = skb_copy_datagram_msg(skb, 0, msg, copied);
158	if (unlikely(err)) {
159		kfree_skb(skb);
160		return err;
161	}
162
163	if (flags & MSG_TRUNC)
164		copied = skb->len;
165	kfree_skb(skb);
166	return copied;
167}
168
169int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
170{
171	struct espintcp_ctx *ctx = espintcp_getctx(sk);
172
173	if (skb_queue_len(&ctx->out_queue) >= netdev_max_backlog)
174		return -ENOBUFS;
175
176	__skb_queue_tail(&ctx->out_queue, skb);
177
178	return 0;
179}
180EXPORT_SYMBOL_GPL(espintcp_queue_out);
181
182/* espintcp length field is 2B and length includes the length field's size */
183#define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)
184
185static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
186				   int flags)
187{
188	do {
189		int ret;
190
191		ret = skb_send_sock_locked(sk, emsg->skb,
192					   emsg->offset, emsg->len);
193		if (ret < 0)
194			return ret;
195
196		emsg->len -= ret;
197		emsg->offset += ret;
198	} while (emsg->len > 0);
199
200	kfree_skb(emsg->skb);
201	memset(emsg, 0, sizeof(*emsg));
202
203	return 0;
204}
205
206static int espintcp_sendskmsg_locked(struct sock *sk,
207				     struct espintcp_msg *emsg, int flags)
208{
209	struct sk_msg *skmsg = &emsg->skmsg;
210	struct scatterlist *sg;
211	int done = 0;
212	int ret;
213
214	flags |= MSG_SENDPAGE_NOTLAST;
215	sg = &skmsg->sg.data[skmsg->sg.start];
216	do {
217		size_t size = sg->length - emsg->offset;
218		int offset = sg->offset + emsg->offset;
219		struct page *p;
220
221		emsg->offset = 0;
222
223		if (sg_is_last(sg))
224			flags &= ~MSG_SENDPAGE_NOTLAST;
225
226		p = sg_page(sg);
227retry:
228		ret = do_tcp_sendpages(sk, p, offset, size, flags);
229		if (ret < 0) {
230			emsg->offset = offset - sg->offset;
231			skmsg->sg.start += done;
232			return ret;
233		}
234
235		if (ret != size) {
236			offset += ret;
237			size -= ret;
238			goto retry;
239		}
240
241		done++;
242		put_page(p);
243		sk_mem_uncharge(sk, sg->length);
244		sg = sg_next(sg);
245	} while (sg);
246
247	memset(emsg, 0, sizeof(*emsg));
248
249	return 0;
250}
251
252static int espintcp_push_msgs(struct sock *sk, int flags)
253{
254	struct espintcp_ctx *ctx = espintcp_getctx(sk);
255	struct espintcp_msg *emsg = &ctx->partial;
256	int err;
257
258	if (!emsg->len)
259		return 0;
260
261	if (ctx->tx_running)
262		return -EAGAIN;
263	ctx->tx_running = 1;
264
265	if (emsg->skb)
266		err = espintcp_sendskb_locked(sk, emsg, flags);
267	else
268		err = espintcp_sendskmsg_locked(sk, emsg, flags);
269	if (err == -EAGAIN) {
270		ctx->tx_running = 0;
271		return flags & MSG_DONTWAIT ? -EAGAIN : 0;
272	}
273	if (!err)
274		memset(emsg, 0, sizeof(*emsg));
275
276	ctx->tx_running = 0;
277
278	return err;
279}
280
281int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
282{
283	struct espintcp_ctx *ctx = espintcp_getctx(sk);
284	struct espintcp_msg *emsg = &ctx->partial;
285	unsigned int len;
286	int offset;
287
288	if (sk->sk_state != TCP_ESTABLISHED) {
289		kfree_skb(skb);
290		return -ECONNRESET;
291	}
292
293	offset = skb_transport_offset(skb);
294	len = skb->len - offset;
295
296	espintcp_push_msgs(sk, 0);
297
298	if (emsg->len) {
299		kfree_skb(skb);
300		return -ENOBUFS;
301	}
302
303	skb_set_owner_w(skb, sk);
304
305	emsg->offset = offset;
306	emsg->len = len;
307	emsg->skb = skb;
308
309	espintcp_push_msgs(sk, 0);
310
311	return 0;
312}
313EXPORT_SYMBOL_GPL(espintcp_push_skb);
314
315static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
316{
317	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
318	struct espintcp_ctx *ctx = espintcp_getctx(sk);
319	struct espintcp_msg *emsg = &ctx->partial;
320	struct iov_iter pfx_iter;
321	struct kvec pfx_iov = {};
322	size_t msglen = size + 2;
323	char buf[2] = {0};
324	int err, end;
325
326	if (msg->msg_flags & ~MSG_DONTWAIT)
327		return -EOPNOTSUPP;
328
329	if (size > MAX_ESPINTCP_MSG)
330		return -EMSGSIZE;
331
332	if (msg->msg_controllen)
333		return -EOPNOTSUPP;
334
335	lock_sock(sk);
336
337	err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
338	if (err < 0) {
339		if (err != -EAGAIN || !(msg->msg_flags & MSG_DONTWAIT))
340			err = -ENOBUFS;
341		goto unlock;
342	}
343
344	sk_msg_init(&emsg->skmsg);
345	while (1) {
346		/* only -ENOMEM is possible since we don't coalesce */
347		err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
348		if (!err)
349			break;
350
351		err = sk_stream_wait_memory(sk, &timeo);
352		if (err)
353			goto fail;
354	}
355
356	*((__be16 *)buf) = cpu_to_be16(msglen);
357	pfx_iov.iov_base = buf;
358	pfx_iov.iov_len = sizeof(buf);
359	iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len);
360
361	err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
362				       pfx_iov.iov_len);
363	if (err < 0)
364		goto fail;
365
366	err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
367	if (err < 0)
368		goto fail;
369
370	end = emsg->skmsg.sg.end;
371	emsg->len = size;
372	sk_msg_iter_var_prev(end);
373	sg_mark_end(sk_msg_elem(&emsg->skmsg, end));
374
375	tcp_rate_check_app_limited(sk);
376
377	err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
378	/* this message could be partially sent, keep it */
379
380	release_sock(sk);
381
382	return size;
383
384fail:
385	sk_msg_free(sk, &emsg->skmsg);
386	memset(emsg, 0, sizeof(*emsg));
387unlock:
388	release_sock(sk);
389	return err;
390}
391
392static struct proto espintcp_prot __ro_after_init;
393static struct proto_ops espintcp_ops __ro_after_init;
394static struct proto espintcp6_prot;
395static struct proto_ops espintcp6_ops;
396static DEFINE_MUTEX(tcpv6_prot_mutex);
397
398static void espintcp_data_ready(struct sock *sk)
399{
400	struct espintcp_ctx *ctx = espintcp_getctx(sk);
401
402	strp_data_ready(&ctx->strp);
403}
404
405static void espintcp_tx_work(struct work_struct *work)
406{
407	struct espintcp_ctx *ctx = container_of(work,
408						struct espintcp_ctx, work);
409	struct sock *sk = ctx->strp.sk;
410
411	lock_sock(sk);
412	if (!ctx->tx_running)
413		espintcp_push_msgs(sk, 0);
414	release_sock(sk);
415}
416
417static void espintcp_write_space(struct sock *sk)
418{
419	struct espintcp_ctx *ctx = espintcp_getctx(sk);
420
421	schedule_work(&ctx->work);
422	ctx->saved_write_space(sk);
423}
424
425static void espintcp_destruct(struct sock *sk)
426{
427	struct espintcp_ctx *ctx = espintcp_getctx(sk);
428
429	ctx->saved_destruct(sk);
430	kfree(ctx);
431}
432
433bool tcp_is_ulp_esp(struct sock *sk)
434{
435	return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot;
436}
437EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);
438
439static void build_protos(struct proto *espintcp_prot,
440			 struct proto_ops *espintcp_ops,
441			 const struct proto *orig_prot,
442			 const struct proto_ops *orig_ops);
443static int espintcp_init_sk(struct sock *sk)
444{
445	struct inet_connection_sock *icsk = inet_csk(sk);
446	struct strp_callbacks cb = {
447		.rcv_msg = espintcp_rcv,
448		.parse_msg = espintcp_parse,
449	};
450	struct espintcp_ctx *ctx;
451	int err;
452
453	/* sockmap is not compatible with espintcp */
454	if (sk->sk_user_data)
455		return -EBUSY;
456
457	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
458	if (!ctx)
459		return -ENOMEM;
460
461	err = strp_init(&ctx->strp, sk, &cb);
462	if (err)
463		goto free;
464
465	__sk_dst_reset(sk);
466
467	strp_check_rcv(&ctx->strp);
468	skb_queue_head_init(&ctx->ike_queue);
469	skb_queue_head_init(&ctx->out_queue);
470
471	if (sk->sk_family == AF_INET) {
472		sk->sk_prot = &espintcp_prot;
473		sk->sk_socket->ops = &espintcp_ops;
474	} else {
475		mutex_lock(&tcpv6_prot_mutex);
476		if (!espintcp6_prot.recvmsg)
477			build_protos(&espintcp6_prot, &espintcp6_ops, sk->sk_prot, sk->sk_socket->ops);
478		mutex_unlock(&tcpv6_prot_mutex);
479
480		sk->sk_prot = &espintcp6_prot;
481		sk->sk_socket->ops = &espintcp6_ops;
482	}
483	ctx->saved_data_ready = sk->sk_data_ready;
484	ctx->saved_write_space = sk->sk_write_space;
485	ctx->saved_destruct = sk->sk_destruct;
486	sk->sk_data_ready = espintcp_data_ready;
487	sk->sk_write_space = espintcp_write_space;
488	sk->sk_destruct = espintcp_destruct;
489	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
490	INIT_WORK(&ctx->work, espintcp_tx_work);
491
492	/* avoid using task_frag */
493	sk->sk_allocation = GFP_ATOMIC;
494
495	return 0;
496
497free:
498	kfree(ctx);
499	return err;
500}
501
502static void espintcp_release(struct sock *sk)
503{
504	struct espintcp_ctx *ctx = espintcp_getctx(sk);
505	struct sk_buff_head queue;
506	struct sk_buff *skb;
507
508	__skb_queue_head_init(&queue);
509	skb_queue_splice_init(&ctx->out_queue, &queue);
510
511	while ((skb = __skb_dequeue(&queue)))
512		espintcp_push_skb(sk, skb);
513
514	tcp_release_cb(sk);
515}
516
517static void espintcp_close(struct sock *sk, long timeout)
518{
519	struct espintcp_ctx *ctx = espintcp_getctx(sk);
520	struct espintcp_msg *emsg = &ctx->partial;
521
522	strp_stop(&ctx->strp);
523
524	sk->sk_prot = &tcp_prot;
525	barrier();
526
527	cancel_work_sync(&ctx->work);
528	strp_done(&ctx->strp);
529
530	skb_queue_purge(&ctx->out_queue);
531	skb_queue_purge(&ctx->ike_queue);
532
533	if (emsg->len) {
534		if (emsg->skb)
535			kfree_skb(emsg->skb);
536		else
537			sk_msg_free(sk, &emsg->skmsg);
538	}
539
540	tcp_close(sk, timeout);
541}
542
543static __poll_t espintcp_poll(struct file *file, struct socket *sock,
544			      poll_table *wait)
545{
546	__poll_t mask = datagram_poll(file, sock, wait);
547	struct sock *sk = sock->sk;
548	struct espintcp_ctx *ctx = espintcp_getctx(sk);
549
550	if (!skb_queue_empty(&ctx->ike_queue))
551		mask |= EPOLLIN | EPOLLRDNORM;
552
553	return mask;
554}
555
556static void build_protos(struct proto *espintcp_prot,
557			 struct proto_ops *espintcp_ops,
558			 const struct proto *orig_prot,
559			 const struct proto_ops *orig_ops)
560{
561	memcpy(espintcp_prot, orig_prot, sizeof(struct proto));
562	memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops));
563	espintcp_prot->sendmsg = espintcp_sendmsg;
564	espintcp_prot->recvmsg = espintcp_recvmsg;
565	espintcp_prot->close = espintcp_close;
566	espintcp_prot->release_cb = espintcp_release;
567	espintcp_ops->poll = espintcp_poll;
568}
569
570static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
571	.name = "espintcp",
572	.owner = THIS_MODULE,
573	.init = espintcp_init_sk,
574};
575
576void __init espintcp_init(void)
577{
578	build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot, &inet_stream_ops);
579
580	tcp_register_ulp(&espintcp_ulp);
581}