Linux Audio

Check our new training course

Loading...
Note: File does not exist in v5.4.
  1/* SPDX-License-Identifier: GPL-2.0 */
  2
  3#ifndef __SOCKET_HELPERS__
  4#define __SOCKET_HELPERS__
  5
  6#include <linux/vm_sockets.h>
  7
  8/* include/linux/net.h */
  9#define SOCK_TYPE_MASK 0xf
 10
 11#define IO_TIMEOUT_SEC 30
 12#define MAX_STRERR_LEN 256
 13
 14/* workaround for older vm_sockets.h */
 15#ifndef VMADDR_CID_LOCAL
 16#define VMADDR_CID_LOCAL 1
 17#endif
 18
 19/* include/linux/cleanup.h */
 20#define __get_and_null(p, nullvalue)                                           \
 21	({                                                                     \
 22		__auto_type __ptr = &(p);                                      \
 23		__auto_type __val = *__ptr;                                    \
 24		*__ptr = nullvalue;                                            \
 25		__val;                                                         \
 26	})
 27
 28#define take_fd(fd) __get_and_null(fd, -EBADF)
 29
 30/* Wrappers that fail the test on error and report it. */
 31
 32#define _FAIL(errnum, fmt...)                                                  \
 33	({                                                                     \
 34		error_at_line(0, (errnum), __func__, __LINE__, fmt);           \
 35		CHECK_FAIL(true);                                              \
 36	})
 37#define FAIL(fmt...) _FAIL(0, fmt)
 38#define FAIL_ERRNO(fmt...) _FAIL(errno, fmt)
 39#define FAIL_LIBBPF(err, msg)                                                  \
 40	({                                                                     \
 41		char __buf[MAX_STRERR_LEN];                                    \
 42		libbpf_strerror((err), __buf, sizeof(__buf));                  \
 43		FAIL("%s: %s", (msg), __buf);                                  \
 44	})
 45
 46
 47#define xaccept_nonblock(fd, addr, len)                                        \
 48	({                                                                     \
 49		int __ret =                                                    \
 50			accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC);   \
 51		if (__ret == -1)                                               \
 52			FAIL_ERRNO("accept");                                  \
 53		__ret;                                                         \
 54	})
 55
 56#define xbind(fd, addr, len)                                                   \
 57	({                                                                     \
 58		int __ret = bind((fd), (addr), (len));                         \
 59		if (__ret == -1)                                               \
 60			FAIL_ERRNO("bind");                                    \
 61		__ret;                                                         \
 62	})
 63
 64#define xclose(fd)                                                             \
 65	({                                                                     \
 66		int __ret = close((fd));                                       \
 67		if (__ret == -1)                                               \
 68			FAIL_ERRNO("close");                                   \
 69		__ret;                                                         \
 70	})
 71
 72#define xconnect(fd, addr, len)                                                \
 73	({                                                                     \
 74		int __ret = connect((fd), (addr), (len));                      \
 75		if (__ret == -1)                                               \
 76			FAIL_ERRNO("connect");                                 \
 77		__ret;                                                         \
 78	})
 79
 80#define xgetsockname(fd, addr, len)                                            \
 81	({                                                                     \
 82		int __ret = getsockname((fd), (addr), (len));                  \
 83		if (__ret == -1)                                               \
 84			FAIL_ERRNO("getsockname");                             \
 85		__ret;                                                         \
 86	})
 87
 88#define xgetsockopt(fd, level, name, val, len)                                 \
 89	({                                                                     \
 90		int __ret = getsockopt((fd), (level), (name), (val), (len));   \
 91		if (__ret == -1)                                               \
 92			FAIL_ERRNO("getsockopt(" #name ")");                   \
 93		__ret;                                                         \
 94	})
 95
 96#define xlisten(fd, backlog)                                                   \
 97	({                                                                     \
 98		int __ret = listen((fd), (backlog));                           \
 99		if (__ret == -1)                                               \
100			FAIL_ERRNO("listen");                                  \
101		__ret;                                                         \
102	})
103
104#define xsetsockopt(fd, level, name, val, len)                                 \
105	({                                                                     \
106		int __ret = setsockopt((fd), (level), (name), (val), (len));   \
107		if (__ret == -1)                                               \
108			FAIL_ERRNO("setsockopt(" #name ")");                   \
109		__ret;                                                         \
110	})
111
112#define xsend(fd, buf, len, flags)                                             \
113	({                                                                     \
114		ssize_t __ret = send((fd), (buf), (len), (flags));             \
115		if (__ret == -1)                                               \
116			FAIL_ERRNO("send");                                    \
117		__ret;                                                         \
118	})
119
120#define xrecv_nonblock(fd, buf, len, flags)                                    \
121	({                                                                     \
122		ssize_t __ret = recv_timeout((fd), (buf), (len), (flags),      \
123					     IO_TIMEOUT_SEC);                  \
124		if (__ret == -1)                                               \
125			FAIL_ERRNO("recv");                                    \
126		__ret;                                                         \
127	})
128
129#define xsocket(family, sotype, flags)                                         \
130	({                                                                     \
131		int __ret = socket(family, sotype, flags);                     \
132		if (__ret == -1)                                               \
133			FAIL_ERRNO("socket");                                  \
134		__ret;                                                         \
135	})
136
137static inline void close_fd(int *fd)
138{
139	if (*fd >= 0)
140		xclose(*fd);
141}
142
143#define __close_fd __attribute__((cleanup(close_fd)))
144
145static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
146{
147	return (struct sockaddr *)ss;
148}
149
150static inline void init_addr_loopback4(struct sockaddr_storage *ss,
151				       socklen_t *len)
152{
153	struct sockaddr_in *addr4 = memset(ss, 0, sizeof(*ss));
154
155	addr4->sin_family = AF_INET;
156	addr4->sin_port = 0;
157	addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
158	*len = sizeof(*addr4);
159}
160
161static inline void init_addr_loopback6(struct sockaddr_storage *ss,
162				       socklen_t *len)
163{
164	struct sockaddr_in6 *addr6 = memset(ss, 0, sizeof(*ss));
165
166	addr6->sin6_family = AF_INET6;
167	addr6->sin6_port = 0;
168	addr6->sin6_addr = in6addr_loopback;
169	*len = sizeof(*addr6);
170}
171
172static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss,
173					    socklen_t *len)
174{
175	struct sockaddr_vm *addr = memset(ss, 0, sizeof(*ss));
176
177	addr->svm_family = AF_VSOCK;
178	addr->svm_port = VMADDR_PORT_ANY;
179	addr->svm_cid = VMADDR_CID_LOCAL;
180	*len = sizeof(*addr);
181}
182
183static inline void init_addr_loopback(int family, struct sockaddr_storage *ss,
184				      socklen_t *len)
185{
186	switch (family) {
187	case AF_INET:
188		init_addr_loopback4(ss, len);
189		return;
190	case AF_INET6:
191		init_addr_loopback6(ss, len);
192		return;
193	case AF_VSOCK:
194		init_addr_loopback_vsock(ss, len);
195		return;
196	default:
197		FAIL("unsupported address family %d", family);
198	}
199}
200
201static inline int enable_reuseport(int s, int progfd)
202{
203	int err, one = 1;
204
205	err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
206	if (err)
207		return -1;
208	err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
209			  sizeof(progfd));
210	if (err)
211		return -1;
212
213	return 0;
214}
215
216static inline int socket_loopback_reuseport(int family, int sotype, int progfd)
217{
218	struct sockaddr_storage addr;
219	socklen_t len = 0;
220	int err, s;
221
222	init_addr_loopback(family, &addr, &len);
223
224	s = xsocket(family, sotype, 0);
225	if (s == -1)
226		return -1;
227
228	if (progfd >= 0)
229		enable_reuseport(s, progfd);
230
231	err = xbind(s, sockaddr(&addr), len);
232	if (err)
233		goto close;
234
235	if (sotype & SOCK_DGRAM)
236		return s;
237
238	err = xlisten(s, SOMAXCONN);
239	if (err)
240		goto close;
241
242	return s;
243close:
244	xclose(s);
245	return -1;
246}
247
248static inline int socket_loopback(int family, int sotype)
249{
250	return socket_loopback_reuseport(family, sotype, -1);
251}
252
253static inline int poll_connect(int fd, unsigned int timeout_sec)
254{
255	struct timeval timeout = { .tv_sec = timeout_sec };
256	fd_set wfds;
257	int r, eval;
258	socklen_t esize = sizeof(eval);
259
260	FD_ZERO(&wfds);
261	FD_SET(fd, &wfds);
262
263	r = select(fd + 1, NULL, &wfds, NULL, &timeout);
264	if (r == 0)
265		errno = ETIME;
266	if (r != 1)
267		return -1;
268
269	if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0)
270		return -1;
271	if (eval != 0) {
272		errno = eval;
273		return -1;
274	}
275
276	return 0;
277}
278
279static inline int poll_read(int fd, unsigned int timeout_sec)
280{
281	struct timeval timeout = { .tv_sec = timeout_sec };
282	fd_set rfds;
283	int r;
284
285	FD_ZERO(&rfds);
286	FD_SET(fd, &rfds);
287
288	r = select(fd + 1, &rfds, NULL, NULL, &timeout);
289	if (r == 0)
290		errno = ETIME;
291
292	return r == 1 ? 0 : -1;
293}
294
295static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len,
296				 unsigned int timeout_sec)
297{
298	if (poll_read(fd, timeout_sec))
299		return -1;
300
301	return accept(fd, addr, len);
302}
303
304static inline int recv_timeout(int fd, void *buf, size_t len, int flags,
305			       unsigned int timeout_sec)
306{
307	if (poll_read(fd, timeout_sec))
308		return -1;
309
310	return recv(fd, buf, len, flags);
311}
312
313
314static inline int create_pair(int family, int sotype, int *p0, int *p1)
315{
316	__close_fd int s, c = -1, p = -1;
317	struct sockaddr_storage addr;
318	socklen_t len = sizeof(addr);
319	int err;
320
321	s = socket_loopback(family, sotype);
322	if (s < 0)
323		return s;
324
325	err = xgetsockname(s, sockaddr(&addr), &len);
326	if (err)
327		return err;
328
329	c = xsocket(family, sotype, 0);
330	if (c < 0)
331		return c;
332
333	err = connect(c, sockaddr(&addr), len);
334	if (err) {
335		if (errno != EINPROGRESS) {
336			FAIL_ERRNO("connect");
337			return err;
338		}
339
340		err = poll_connect(c, IO_TIMEOUT_SEC);
341		if (err) {
342			FAIL_ERRNO("poll_connect");
343			return err;
344		}
345	}
346
347	switch (sotype & SOCK_TYPE_MASK) {
348	case SOCK_DGRAM:
349		err = xgetsockname(c, sockaddr(&addr), &len);
350		if (err)
351			return err;
352
353		err = xconnect(s, sockaddr(&addr), len);
354		if (err)
355			return err;
356
357		*p0 = take_fd(s);
358		break;
359	case SOCK_STREAM:
360	case SOCK_SEQPACKET:
361		p = xaccept_nonblock(s, NULL, NULL);
362		if (p < 0)
363			return p;
364
365		*p0 = take_fd(p);
366		break;
367	default:
368		FAIL("Unsupported socket type %#x", sotype);
369		return -EOPNOTSUPP;
370	}
371
372	*p1 = take_fd(c);
373	return 0;
374}
375
376static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,
377				      int *p0, int *p1)
378{
379	int err;
380
381	err = create_pair(family, sotype, c0, p0);
382	if (err)
383		return err;
384
385	err = create_pair(family, sotype, c1, p1);
386	if (err) {
387		close(*c0);
388		close(*p0);
389	}
390
391	return err;
392}
393
394#endif // __SOCKET_HELPERS__