Linux Audio

Check our new training course

Loading...
Note: File does not exist in v3.1.
  1// SPDX-License-Identifier: GPL-2.0-only
  2/*
  3 * vsock_diag_test - vsock_diag.ko test suite
  4 *
  5 * Copyright (C) 2017 Red Hat, Inc.
  6 *
  7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
  8 */
  9
 10#include <getopt.h>
 11#include <stdio.h>
 12#include <stdbool.h>
 13#include <stdlib.h>
 14#include <string.h>
 15#include <errno.h>
 16#include <unistd.h>
 17#include <signal.h>
 18#include <sys/socket.h>
 19#include <sys/stat.h>
 20#include <sys/types.h>
 21#include <linux/list.h>
 22#include <linux/net.h>
 23#include <linux/netlink.h>
 24#include <linux/sock_diag.h>
 25#include <netinet/tcp.h>
 26
 27#include "../../../include/uapi/linux/vm_sockets.h"
 28#include "../../../include/uapi/linux/vm_sockets_diag.h"
 29
 30#include "timeout.h"
 31#include "control.h"
 32
 33enum test_mode {
 34	TEST_MODE_UNSET,
 35	TEST_MODE_CLIENT,
 36	TEST_MODE_SERVER
 37};
 38
 39/* Per-socket status */
 40struct vsock_stat {
 41	struct list_head list;
 42	struct vsock_diag_msg msg;
 43};
 44
 45static const char *sock_type_str(int type)
 46{
 47	switch (type) {
 48	case SOCK_DGRAM:
 49		return "DGRAM";
 50	case SOCK_STREAM:
 51		return "STREAM";
 52	default:
 53		return "INVALID TYPE";
 54	}
 55}
 56
 57static const char *sock_state_str(int state)
 58{
 59	switch (state) {
 60	case TCP_CLOSE:
 61		return "UNCONNECTED";
 62	case TCP_SYN_SENT:
 63		return "CONNECTING";
 64	case TCP_ESTABLISHED:
 65		return "CONNECTED";
 66	case TCP_CLOSING:
 67		return "DISCONNECTING";
 68	case TCP_LISTEN:
 69		return "LISTEN";
 70	default:
 71		return "INVALID STATE";
 72	}
 73}
 74
 75static const char *sock_shutdown_str(int shutdown)
 76{
 77	switch (shutdown) {
 78	case 1:
 79		return "RCV_SHUTDOWN";
 80	case 2:
 81		return "SEND_SHUTDOWN";
 82	case 3:
 83		return "RCV_SHUTDOWN | SEND_SHUTDOWN";
 84	default:
 85		return "0";
 86	}
 87}
 88
 89static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
 90{
 91	if (cid == VMADDR_CID_ANY)
 92		fprintf(fp, "*:");
 93	else
 94		fprintf(fp, "%u:", cid);
 95
 96	if (port == VMADDR_PORT_ANY)
 97		fprintf(fp, "*");
 98	else
 99		fprintf(fp, "%u", port);
100}
101
102static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
103{
104	print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
105	fprintf(fp, " ");
106	print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
107	fprintf(fp, " %s %s %s %u\n",
108		sock_type_str(st->msg.vdiag_type),
109		sock_state_str(st->msg.vdiag_state),
110		sock_shutdown_str(st->msg.vdiag_shutdown),
111		st->msg.vdiag_ino);
112}
113
114static void print_vsock_stats(FILE *fp, struct list_head *head)
115{
116	struct vsock_stat *st;
117
118	list_for_each_entry(st, head, list)
119		print_vsock_stat(fp, st);
120}
121
122static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
123{
124	struct vsock_stat *st;
125	struct stat stat;
126
127	if (fstat(fd, &stat) < 0) {
128		perror("fstat");
129		exit(EXIT_FAILURE);
130	}
131
132	list_for_each_entry(st, head, list)
133		if (st->msg.vdiag_ino == stat.st_ino)
134			return st;
135
136	fprintf(stderr, "cannot find fd %d\n", fd);
137	exit(EXIT_FAILURE);
138}
139
140static void check_no_sockets(struct list_head *head)
141{
142	if (!list_empty(head)) {
143		fprintf(stderr, "expected no sockets\n");
144		print_vsock_stats(stderr, head);
145		exit(1);
146	}
147}
148
149static void check_num_sockets(struct list_head *head, int expected)
150{
151	struct list_head *node;
152	int n = 0;
153
154	list_for_each(node, head)
155		n++;
156
157	if (n != expected) {
158		fprintf(stderr, "expected %d sockets, found %d\n",
159			expected, n);
160		print_vsock_stats(stderr, head);
161		exit(EXIT_FAILURE);
162	}
163}
164
165static void check_socket_state(struct vsock_stat *st, __u8 state)
166{
167	if (st->msg.vdiag_state != state) {
168		fprintf(stderr, "expected socket state %#x, got %#x\n",
169			state, st->msg.vdiag_state);
170		exit(EXIT_FAILURE);
171	}
172}
173
174static void send_req(int fd)
175{
176	struct sockaddr_nl nladdr = {
177		.nl_family = AF_NETLINK,
178	};
179	struct {
180		struct nlmsghdr nlh;
181		struct vsock_diag_req vreq;
182	} req = {
183		.nlh = {
184			.nlmsg_len = sizeof(req),
185			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
186			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
187		},
188		.vreq = {
189			.sdiag_family = AF_VSOCK,
190			.vdiag_states = ~(__u32)0,
191		},
192	};
193	struct iovec iov = {
194		.iov_base = &req,
195		.iov_len = sizeof(req),
196	};
197	struct msghdr msg = {
198		.msg_name = &nladdr,
199		.msg_namelen = sizeof(nladdr),
200		.msg_iov = &iov,
201		.msg_iovlen = 1,
202	};
203
204	for (;;) {
205		if (sendmsg(fd, &msg, 0) < 0) {
206			if (errno == EINTR)
207				continue;
208
209			perror("sendmsg");
210			exit(EXIT_FAILURE);
211		}
212
213		return;
214	}
215}
216
217static ssize_t recv_resp(int fd, void *buf, size_t len)
218{
219	struct sockaddr_nl nladdr = {
220		.nl_family = AF_NETLINK,
221	};
222	struct iovec iov = {
223		.iov_base = buf,
224		.iov_len = len,
225	};
226	struct msghdr msg = {
227		.msg_name = &nladdr,
228		.msg_namelen = sizeof(nladdr),
229		.msg_iov = &iov,
230		.msg_iovlen = 1,
231	};
232	ssize_t ret;
233
234	do {
235		ret = recvmsg(fd, &msg, 0);
236	} while (ret < 0 && errno == EINTR);
237
238	if (ret < 0) {
239		perror("recvmsg");
240		exit(EXIT_FAILURE);
241	}
242
243	return ret;
244}
245
246static void add_vsock_stat(struct list_head *sockets,
247			   const struct vsock_diag_msg *resp)
248{
249	struct vsock_stat *st;
250
251	st = malloc(sizeof(*st));
252	if (!st) {
253		perror("malloc");
254		exit(EXIT_FAILURE);
255	}
256
257	st->msg = *resp;
258	list_add_tail(&st->list, sockets);
259}
260
261/*
262 * Read vsock stats into a list.
263 */
264static void read_vsock_stat(struct list_head *sockets)
265{
266	long buf[8192 / sizeof(long)];
267	int fd;
268
269	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
270	if (fd < 0) {
271		perror("socket");
272		exit(EXIT_FAILURE);
273	}
274
275	send_req(fd);
276
277	for (;;) {
278		const struct nlmsghdr *h;
279		ssize_t ret;
280
281		ret = recv_resp(fd, buf, sizeof(buf));
282		if (ret == 0)
283			goto done;
284		if (ret < sizeof(*h)) {
285			fprintf(stderr, "short read of %zd bytes\n", ret);
286			exit(EXIT_FAILURE);
287		}
288
289		h = (struct nlmsghdr *)buf;
290
291		while (NLMSG_OK(h, ret)) {
292			if (h->nlmsg_type == NLMSG_DONE)
293				goto done;
294
295			if (h->nlmsg_type == NLMSG_ERROR) {
296				const struct nlmsgerr *err = NLMSG_DATA(h);
297
298				if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
299					fprintf(stderr, "NLMSG_ERROR\n");
300				else {
301					errno = -err->error;
302					perror("NLMSG_ERROR");
303				}
304
305				exit(EXIT_FAILURE);
306			}
307
308			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
309				fprintf(stderr, "unexpected nlmsg_type %#x\n",
310					h->nlmsg_type);
311				exit(EXIT_FAILURE);
312			}
313			if (h->nlmsg_len <
314			    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
315				fprintf(stderr, "short vsock_diag_msg\n");
316				exit(EXIT_FAILURE);
317			}
318
319			add_vsock_stat(sockets, NLMSG_DATA(h));
320
321			h = NLMSG_NEXT(h, ret);
322		}
323	}
324
325done:
326	close(fd);
327}
328
329static void free_sock_stat(struct list_head *sockets)
330{
331	struct vsock_stat *st;
332	struct vsock_stat *next;
333
334	list_for_each_entry_safe(st, next, sockets, list)
335		free(st);
336}
337
338static void test_no_sockets(unsigned int peer_cid)
339{
340	LIST_HEAD(sockets);
341
342	read_vsock_stat(&sockets);
343
344	check_no_sockets(&sockets);
345
346	free_sock_stat(&sockets);
347}
348
349static void test_listen_socket_server(unsigned int peer_cid)
350{
351	union {
352		struct sockaddr sa;
353		struct sockaddr_vm svm;
354	} addr = {
355		.svm = {
356			.svm_family = AF_VSOCK,
357			.svm_port = 1234,
358			.svm_cid = VMADDR_CID_ANY,
359		},
360	};
361	LIST_HEAD(sockets);
362	struct vsock_stat *st;
363	int fd;
364
365	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
366
367	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
368		perror("bind");
369		exit(EXIT_FAILURE);
370	}
371
372	if (listen(fd, 1) < 0) {
373		perror("listen");
374		exit(EXIT_FAILURE);
375	}
376
377	read_vsock_stat(&sockets);
378
379	check_num_sockets(&sockets, 1);
380	st = find_vsock_stat(&sockets, fd);
381	check_socket_state(st, TCP_LISTEN);
382
383	close(fd);
384	free_sock_stat(&sockets);
385}
386
387static void test_connect_client(unsigned int peer_cid)
388{
389	union {
390		struct sockaddr sa;
391		struct sockaddr_vm svm;
392	} addr = {
393		.svm = {
394			.svm_family = AF_VSOCK,
395			.svm_port = 1234,
396			.svm_cid = peer_cid,
397		},
398	};
399	int fd;
400	int ret;
401	LIST_HEAD(sockets);
402	struct vsock_stat *st;
403
404	control_expectln("LISTENING");
405
406	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
407
408	timeout_begin(TIMEOUT);
409	do {
410		ret = connect(fd, &addr.sa, sizeof(addr.svm));
411		timeout_check("connect");
412	} while (ret < 0 && errno == EINTR);
413	timeout_end();
414
415	if (ret < 0) {
416		perror("connect");
417		exit(EXIT_FAILURE);
418	}
419
420	read_vsock_stat(&sockets);
421
422	check_num_sockets(&sockets, 1);
423	st = find_vsock_stat(&sockets, fd);
424	check_socket_state(st, TCP_ESTABLISHED);
425
426	control_expectln("DONE");
427	control_writeln("DONE");
428
429	close(fd);
430	free_sock_stat(&sockets);
431}
432
433static void test_connect_server(unsigned int peer_cid)
434{
435	union {
436		struct sockaddr sa;
437		struct sockaddr_vm svm;
438	} addr = {
439		.svm = {
440			.svm_family = AF_VSOCK,
441			.svm_port = 1234,
442			.svm_cid = VMADDR_CID_ANY,
443		},
444	};
445	union {
446		struct sockaddr sa;
447		struct sockaddr_vm svm;
448	} clientaddr;
449	socklen_t clientaddr_len = sizeof(clientaddr.svm);
450	LIST_HEAD(sockets);
451	struct vsock_stat *st;
452	int fd;
453	int client_fd;
454
455	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
456
457	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
458		perror("bind");
459		exit(EXIT_FAILURE);
460	}
461
462	if (listen(fd, 1) < 0) {
463		perror("listen");
464		exit(EXIT_FAILURE);
465	}
466
467	control_writeln("LISTENING");
468
469	timeout_begin(TIMEOUT);
470	do {
471		client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
472		timeout_check("accept");
473	} while (client_fd < 0 && errno == EINTR);
474	timeout_end();
475
476	if (client_fd < 0) {
477		perror("accept");
478		exit(EXIT_FAILURE);
479	}
480	if (clientaddr.sa.sa_family != AF_VSOCK) {
481		fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
482			clientaddr.sa.sa_family);
483		exit(EXIT_FAILURE);
484	}
485	if (clientaddr.svm.svm_cid != peer_cid) {
486		fprintf(stderr, "expected peer CID %u from accept(2), got %u\n",
487			peer_cid, clientaddr.svm.svm_cid);
488		exit(EXIT_FAILURE);
489	}
490
491	read_vsock_stat(&sockets);
492
493	check_num_sockets(&sockets, 2);
494	find_vsock_stat(&sockets, fd);
495	st = find_vsock_stat(&sockets, client_fd);
496	check_socket_state(st, TCP_ESTABLISHED);
497
498	control_writeln("DONE");
499	control_expectln("DONE");
500
501	close(client_fd);
502	close(fd);
503	free_sock_stat(&sockets);
504}
505
506static struct {
507	const char *name;
508	void (*run_client)(unsigned int peer_cid);
509	void (*run_server)(unsigned int peer_cid);
510} test_cases[] = {
511	{
512		.name = "No sockets",
513		.run_server = test_no_sockets,
514	},
515	{
516		.name = "Listen socket",
517		.run_server = test_listen_socket_server,
518	},
519	{
520		.name = "Connect",
521		.run_client = test_connect_client,
522		.run_server = test_connect_server,
523	},
524	{},
525};
526
527static void init_signals(void)
528{
529	struct sigaction act = {
530		.sa_handler = sigalrm,
531	};
532
533	sigaction(SIGALRM, &act, NULL);
534	signal(SIGPIPE, SIG_IGN);
535}
536
537static unsigned int parse_cid(const char *str)
538{
539	char *endptr = NULL;
540	unsigned long int n;
541
542	errno = 0;
543	n = strtoul(str, &endptr, 10);
544	if (errno || *endptr != '\0') {
545		fprintf(stderr, "malformed CID \"%s\"\n", str);
546		exit(EXIT_FAILURE);
547	}
548	return n;
549}
550
551static const char optstring[] = "";
552static const struct option longopts[] = {
553	{
554		.name = "control-host",
555		.has_arg = required_argument,
556		.val = 'H',
557	},
558	{
559		.name = "control-port",
560		.has_arg = required_argument,
561		.val = 'P',
562	},
563	{
564		.name = "mode",
565		.has_arg = required_argument,
566		.val = 'm',
567	},
568	{
569		.name = "peer-cid",
570		.has_arg = required_argument,
571		.val = 'p',
572	},
573	{
574		.name = "help",
575		.has_arg = no_argument,
576		.val = '?',
577	},
578	{},
579};
580
581static void usage(void)
582{
583	fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n"
584		"\n"
585		"  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
586		"  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
587		"\n"
588		"Run vsock_diag.ko tests.  Must be launched in both\n"
589		"guest and host.  One side must use --mode=client and\n"
590		"the other side must use --mode=server.\n"
591		"\n"
592		"A TCP control socket connection is used to coordinate tests\n"
593		"between the client and the server.  The server requires a\n"
594		"listen address and the client requires an address to\n"
595		"connect to.\n"
596		"\n"
597		"The CID of the other side must be given with --peer-cid=<cid>.\n");
598	exit(EXIT_FAILURE);
599}
600
601int main(int argc, char **argv)
602{
603	const char *control_host = NULL;
604	const char *control_port = NULL;
605	int mode = TEST_MODE_UNSET;
606	unsigned int peer_cid = VMADDR_CID_ANY;
607	int i;
608
609	init_signals();
610
611	for (;;) {
612		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
613
614		if (opt == -1)
615			break;
616
617		switch (opt) {
618		case 'H':
619			control_host = optarg;
620			break;
621		case 'm':
622			if (strcmp(optarg, "client") == 0)
623				mode = TEST_MODE_CLIENT;
624			else if (strcmp(optarg, "server") == 0)
625				mode = TEST_MODE_SERVER;
626			else {
627				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
628				return EXIT_FAILURE;
629			}
630			break;
631		case 'p':
632			peer_cid = parse_cid(optarg);
633			break;
634		case 'P':
635			control_port = optarg;
636			break;
637		case '?':
638		default:
639			usage();
640		}
641	}
642
643	if (!control_port)
644		usage();
645	if (mode == TEST_MODE_UNSET)
646		usage();
647	if (peer_cid == VMADDR_CID_ANY)
648		usage();
649
650	if (!control_host) {
651		if (mode != TEST_MODE_SERVER)
652			usage();
653		control_host = "0.0.0.0";
654	}
655
656	control_init(control_host, control_port, mode == TEST_MODE_SERVER);
657
658	for (i = 0; test_cases[i].name; i++) {
659		void (*run)(unsigned int peer_cid);
660
661		printf("%s...", test_cases[i].name);
662		fflush(stdout);
663
664		if (mode == TEST_MODE_CLIENT)
665			run = test_cases[i].run_client;
666		else
667			run = test_cases[i].run_server;
668
669		if (run)
670			run(peer_cid);
671
672		printf("ok\n");
673	}
674
675	control_cleanup();
676	return EXIT_SUCCESS;
677}