Linux Audio

Check our new training course

Loading...
v6.13.7
  1// SPDX-License-Identifier: GPL-2.0
  2/* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */
  3
  4#include <linux/skmsg.h>
  5#include <net/sock.h>
  6#include <net/udp.h>
  7#include <net/inet_common.h>
  8
  9#include "udp_impl.h"
 10
 11static struct proto *udpv6_prot_saved __read_mostly;
 12
 13static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 14			  int flags, int *addr_len)
 15{
 16#if IS_ENABLED(CONFIG_IPV6)
 17	if (sk->sk_family == AF_INET6)
 18		return udpv6_prot_saved->recvmsg(sk, msg, len, flags, addr_len);
 
 19#endif
 20	return udp_prot.recvmsg(sk, msg, len, flags, addr_len);
 21}
 22
 23static bool udp_sk_has_data(struct sock *sk)
 24{
 25	return !skb_queue_empty(&udp_sk(sk)->reader_queue) ||
 26	       !skb_queue_empty(&sk->sk_receive_queue);
 27}
 28
 29static bool psock_has_data(struct sk_psock *psock)
 30{
 31	return !skb_queue_empty(&psock->ingress_skb) ||
 32	       !sk_psock_queue_empty(psock);
 33}
 34
 35#define udp_msg_has_data(__sk, __psock)	\
 36		({ udp_sk_has_data(__sk) || psock_has_data(__psock); })
 37
 38static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
 39			     long timeo)
 40{
 41	DEFINE_WAIT_FUNC(wait, woken_wake_function);
 42	int ret = 0;
 43
 44	if (sk->sk_shutdown & RCV_SHUTDOWN)
 45		return 1;
 46
 47	if (!timeo)
 48		return ret;
 49
 50	add_wait_queue(sk_sleep(sk), &wait);
 51	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 52	ret = udp_msg_has_data(sk, psock);
 53	if (!ret) {
 54		wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
 55		ret = udp_msg_has_data(sk, psock);
 56	}
 57	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 58	remove_wait_queue(sk_sleep(sk), &wait);
 59	return ret;
 60}
 61
 62static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 63			   int flags, int *addr_len)
 64{
 65	struct sk_psock *psock;
 66	int copied, ret;
 67
 68	if (unlikely(flags & MSG_ERRQUEUE))
 69		return inet_recv_error(sk, msg, len, addr_len);
 70
 71	if (!len)
 72		return 0;
 73
 74	psock = sk_psock_get(sk);
 75	if (unlikely(!psock))
 76		return sk_udp_recvmsg(sk, msg, len, flags, addr_len);
 77
 78	if (!psock_has_data(psock)) {
 79		ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len);
 80		goto out;
 81	}
 82
 83msg_bytes_ready:
 84	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
 85	if (!copied) {
 86		long timeo;
 87		int data;
 88
 89		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
 90		data = udp_msg_wait_data(sk, psock, timeo);
 91		if (data) {
 92			if (psock_has_data(psock))
 93				goto msg_bytes_ready;
 94			ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len);
 95			goto out;
 96		}
 97		copied = -EAGAIN;
 98	}
 99	ret = copied;
100out:
101	sk_psock_put(sk, psock);
102	return ret;
103}
104
105enum {
106	UDP_BPF_IPV4,
107	UDP_BPF_IPV6,
108	UDP_BPF_NUM_PROTS,
109};
110
111static DEFINE_SPINLOCK(udpv6_prot_lock);
112static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS];
113
114static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
115{
116	*prot        = *base;
 
117	prot->close  = sock_map_close;
118	prot->recvmsg = udp_bpf_recvmsg;
119	prot->sock_is_readable = sk_msg_is_readable;
120}
121
122static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
123{
124	if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) {
125		spin_lock_bh(&udpv6_prot_lock);
126		if (likely(ops != udpv6_prot_saved)) {
127			udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops);
128			smp_store_release(&udpv6_prot_saved, ops);
129		}
130		spin_unlock_bh(&udpv6_prot_lock);
131	}
132}
133
134static int __init udp_bpf_v4_build_proto(void)
135{
136	udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot);
137	return 0;
138}
139late_initcall(udp_bpf_v4_build_proto);
140
141int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
142{
143	int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
144
145	if (restore) {
146		sk->sk_write_space = psock->saved_write_space;
147		sock_replace_proto(sk, psock->sk_proto);
148		return 0;
149	}
150
151	if (sk->sk_family == AF_INET6)
152		udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
153
154	sock_replace_proto(sk, &udp_bpf_prots[family]);
155	return 0;
156}
157EXPORT_SYMBOL_GPL(udp_bpf_update_proto);
v5.14.15
  1// SPDX-License-Identifier: GPL-2.0
  2/* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */
  3
  4#include <linux/skmsg.h>
  5#include <net/sock.h>
  6#include <net/udp.h>
  7#include <net/inet_common.h>
  8
  9#include "udp_impl.h"
 10
 11static struct proto *udpv6_prot_saved __read_mostly;
 12
 13static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 14			  int noblock, int flags, int *addr_len)
 15{
 16#if IS_ENABLED(CONFIG_IPV6)
 17	if (sk->sk_family == AF_INET6)
 18		return udpv6_prot_saved->recvmsg(sk, msg, len, noblock, flags,
 19						 addr_len);
 20#endif
 21	return udp_prot.recvmsg(sk, msg, len, noblock, flags, addr_len);
 22}
 23
 24static bool udp_sk_has_data(struct sock *sk)
 25{
 26	return !skb_queue_empty(&udp_sk(sk)->reader_queue) ||
 27	       !skb_queue_empty(&sk->sk_receive_queue);
 28}
 29
 30static bool psock_has_data(struct sk_psock *psock)
 31{
 32	return !skb_queue_empty(&psock->ingress_skb) ||
 33	       !sk_psock_queue_empty(psock);
 34}
 35
 36#define udp_msg_has_data(__sk, __psock)	\
 37		({ udp_sk_has_data(__sk) || psock_has_data(__psock); })
 38
 39static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
 40			     long timeo)
 41{
 42	DEFINE_WAIT_FUNC(wait, woken_wake_function);
 43	int ret = 0;
 44
 45	if (sk->sk_shutdown & RCV_SHUTDOWN)
 46		return 1;
 47
 48	if (!timeo)
 49		return ret;
 50
 51	add_wait_queue(sk_sleep(sk), &wait);
 52	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 53	ret = udp_msg_has_data(sk, psock);
 54	if (!ret) {
 55		wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
 56		ret = udp_msg_has_data(sk, psock);
 57	}
 58	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 59	remove_wait_queue(sk_sleep(sk), &wait);
 60	return ret;
 61}
 62
 63static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 64			   int nonblock, int flags, int *addr_len)
 65{
 66	struct sk_psock *psock;
 67	int copied, ret;
 68
 69	if (unlikely(flags & MSG_ERRQUEUE))
 70		return inet_recv_error(sk, msg, len, addr_len);
 71
 
 
 
 72	psock = sk_psock_get(sk);
 73	if (unlikely(!psock))
 74		return sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 75
 76	if (!psock_has_data(psock)) {
 77		ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 78		goto out;
 79	}
 80
 81msg_bytes_ready:
 82	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
 83	if (!copied) {
 84		long timeo;
 85		int data;
 86
 87		timeo = sock_rcvtimeo(sk, nonblock);
 88		data = udp_msg_wait_data(sk, psock, timeo);
 89		if (data) {
 90			if (psock_has_data(psock))
 91				goto msg_bytes_ready;
 92			ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 93			goto out;
 94		}
 95		copied = -EAGAIN;
 96	}
 97	ret = copied;
 98out:
 99	sk_psock_put(sk, psock);
100	return ret;
101}
102
103enum {
104	UDP_BPF_IPV4,
105	UDP_BPF_IPV6,
106	UDP_BPF_NUM_PROTS,
107};
108
109static DEFINE_SPINLOCK(udpv6_prot_lock);
110static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS];
111
112static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
113{
114	*prot        = *base;
115	prot->unhash = sock_map_unhash;
116	prot->close  = sock_map_close;
117	prot->recvmsg = udp_bpf_recvmsg;
 
118}
119
120static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
121{
122	if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) {
123		spin_lock_bh(&udpv6_prot_lock);
124		if (likely(ops != udpv6_prot_saved)) {
125			udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops);
126			smp_store_release(&udpv6_prot_saved, ops);
127		}
128		spin_unlock_bh(&udpv6_prot_lock);
129	}
130}
131
132static int __init udp_bpf_v4_build_proto(void)
133{
134	udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot);
135	return 0;
136}
137late_initcall(udp_bpf_v4_build_proto);
138
139int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
140{
141	int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
142
143	if (restore) {
144		sk->sk_write_space = psock->saved_write_space;
145		WRITE_ONCE(sk->sk_prot, psock->sk_proto);
146		return 0;
147	}
148
149	if (sk->sk_family == AF_INET6)
150		udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
151
152	WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
153	return 0;
154}
155EXPORT_SYMBOL_GPL(udp_bpf_update_proto);