Linux Audio

Check our new training course

Loading...
Note: File does not exist in v3.1.
  1// SPDX-License-Identifier: GPL-2.0-only
  2/*
  3 * vsock_diag_test - vsock_diag.ko test suite
  4 *
  5 * Copyright (C) 2017 Red Hat, Inc.
  6 *
  7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
  8 */
  9
 10#include <getopt.h>
 11#include <stdio.h>
 12#include <stdlib.h>
 13#include <string.h>
 14#include <errno.h>
 15#include <unistd.h>
 16#include <sys/stat.h>
 17#include <sys/types.h>
 18#include <linux/list.h>
 19#include <linux/net.h>
 20#include <linux/netlink.h>
 21#include <linux/sock_diag.h>
 22#include <linux/vm_sockets_diag.h>
 23#include <netinet/tcp.h>
 24
 25#include "timeout.h"
 26#include "control.h"
 27#include "util.h"
 28
 29/* Per-socket status */
 30struct vsock_stat {
 31	struct list_head list;
 32	struct vsock_diag_msg msg;
 33};
 34
 35static const char *sock_type_str(int type)
 36{
 37	switch (type) {
 38	case SOCK_DGRAM:
 39		return "DGRAM";
 40	case SOCK_STREAM:
 41		return "STREAM";
 42	default:
 43		return "INVALID TYPE";
 44	}
 45}
 46
 47static const char *sock_state_str(int state)
 48{
 49	switch (state) {
 50	case TCP_CLOSE:
 51		return "UNCONNECTED";
 52	case TCP_SYN_SENT:
 53		return "CONNECTING";
 54	case TCP_ESTABLISHED:
 55		return "CONNECTED";
 56	case TCP_CLOSING:
 57		return "DISCONNECTING";
 58	case TCP_LISTEN:
 59		return "LISTEN";
 60	default:
 61		return "INVALID STATE";
 62	}
 63}
 64
 65static const char *sock_shutdown_str(int shutdown)
 66{
 67	switch (shutdown) {
 68	case 1:
 69		return "RCV_SHUTDOWN";
 70	case 2:
 71		return "SEND_SHUTDOWN";
 72	case 3:
 73		return "RCV_SHUTDOWN | SEND_SHUTDOWN";
 74	default:
 75		return "0";
 76	}
 77}
 78
 79static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
 80{
 81	if (cid == VMADDR_CID_ANY)
 82		fprintf(fp, "*:");
 83	else
 84		fprintf(fp, "%u:", cid);
 85
 86	if (port == VMADDR_PORT_ANY)
 87		fprintf(fp, "*");
 88	else
 89		fprintf(fp, "%u", port);
 90}
 91
 92static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
 93{
 94	print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
 95	fprintf(fp, " ");
 96	print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
 97	fprintf(fp, " %s %s %s %u\n",
 98		sock_type_str(st->msg.vdiag_type),
 99		sock_state_str(st->msg.vdiag_state),
100		sock_shutdown_str(st->msg.vdiag_shutdown),
101		st->msg.vdiag_ino);
102}
103
104static void print_vsock_stats(FILE *fp, struct list_head *head)
105{
106	struct vsock_stat *st;
107
108	list_for_each_entry(st, head, list)
109		print_vsock_stat(fp, st);
110}
111
112static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
113{
114	struct vsock_stat *st;
115	struct stat stat;
116
117	if (fstat(fd, &stat) < 0) {
118		perror("fstat");
119		exit(EXIT_FAILURE);
120	}
121
122	list_for_each_entry(st, head, list)
123		if (st->msg.vdiag_ino == stat.st_ino)
124			return st;
125
126	fprintf(stderr, "cannot find fd %d\n", fd);
127	exit(EXIT_FAILURE);
128}
129
130static void check_no_sockets(struct list_head *head)
131{
132	if (!list_empty(head)) {
133		fprintf(stderr, "expected no sockets\n");
134		print_vsock_stats(stderr, head);
135		exit(1);
136	}
137}
138
139static void check_num_sockets(struct list_head *head, int expected)
140{
141	struct list_head *node;
142	int n = 0;
143
144	list_for_each(node, head)
145		n++;
146
147	if (n != expected) {
148		fprintf(stderr, "expected %d sockets, found %d\n",
149			expected, n);
150		print_vsock_stats(stderr, head);
151		exit(EXIT_FAILURE);
152	}
153}
154
155static void check_socket_state(struct vsock_stat *st, __u8 state)
156{
157	if (st->msg.vdiag_state != state) {
158		fprintf(stderr, "expected socket state %#x, got %#x\n",
159			state, st->msg.vdiag_state);
160		exit(EXIT_FAILURE);
161	}
162}
163
164static void send_req(int fd)
165{
166	struct sockaddr_nl nladdr = {
167		.nl_family = AF_NETLINK,
168	};
169	struct {
170		struct nlmsghdr nlh;
171		struct vsock_diag_req vreq;
172	} req = {
173		.nlh = {
174			.nlmsg_len = sizeof(req),
175			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
176			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
177		},
178		.vreq = {
179			.sdiag_family = AF_VSOCK,
180			.vdiag_states = ~(__u32)0,
181		},
182	};
183	struct iovec iov = {
184		.iov_base = &req,
185		.iov_len = sizeof(req),
186	};
187	struct msghdr msg = {
188		.msg_name = &nladdr,
189		.msg_namelen = sizeof(nladdr),
190		.msg_iov = &iov,
191		.msg_iovlen = 1,
192	};
193
194	for (;;) {
195		if (sendmsg(fd, &msg, 0) < 0) {
196			if (errno == EINTR)
197				continue;
198
199			perror("sendmsg");
200			exit(EXIT_FAILURE);
201		}
202
203		return;
204	}
205}
206
207static ssize_t recv_resp(int fd, void *buf, size_t len)
208{
209	struct sockaddr_nl nladdr = {
210		.nl_family = AF_NETLINK,
211	};
212	struct iovec iov = {
213		.iov_base = buf,
214		.iov_len = len,
215	};
216	struct msghdr msg = {
217		.msg_name = &nladdr,
218		.msg_namelen = sizeof(nladdr),
219		.msg_iov = &iov,
220		.msg_iovlen = 1,
221	};
222	ssize_t ret;
223
224	do {
225		ret = recvmsg(fd, &msg, 0);
226	} while (ret < 0 && errno == EINTR);
227
228	if (ret < 0) {
229		perror("recvmsg");
230		exit(EXIT_FAILURE);
231	}
232
233	return ret;
234}
235
236static void add_vsock_stat(struct list_head *sockets,
237			   const struct vsock_diag_msg *resp)
238{
239	struct vsock_stat *st;
240
241	st = malloc(sizeof(*st));
242	if (!st) {
243		perror("malloc");
244		exit(EXIT_FAILURE);
245	}
246
247	st->msg = *resp;
248	list_add_tail(&st->list, sockets);
249}
250
251/*
252 * Read vsock stats into a list.
253 */
254static void read_vsock_stat(struct list_head *sockets)
255{
256	long buf[8192 / sizeof(long)];
257	int fd;
258
259	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
260	if (fd < 0) {
261		perror("socket");
262		exit(EXIT_FAILURE);
263	}
264
265	send_req(fd);
266
267	for (;;) {
268		const struct nlmsghdr *h;
269		ssize_t ret;
270
271		ret = recv_resp(fd, buf, sizeof(buf));
272		if (ret == 0)
273			goto done;
274		if (ret < sizeof(*h)) {
275			fprintf(stderr, "short read of %zd bytes\n", ret);
276			exit(EXIT_FAILURE);
277		}
278
279		h = (struct nlmsghdr *)buf;
280
281		while (NLMSG_OK(h, ret)) {
282			if (h->nlmsg_type == NLMSG_DONE)
283				goto done;
284
285			if (h->nlmsg_type == NLMSG_ERROR) {
286				const struct nlmsgerr *err = NLMSG_DATA(h);
287
288				if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
289					fprintf(stderr, "NLMSG_ERROR\n");
290				else {
291					errno = -err->error;
292					perror("NLMSG_ERROR");
293				}
294
295				exit(EXIT_FAILURE);
296			}
297
298			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
299				fprintf(stderr, "unexpected nlmsg_type %#x\n",
300					h->nlmsg_type);
301				exit(EXIT_FAILURE);
302			}
303			if (h->nlmsg_len <
304			    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
305				fprintf(stderr, "short vsock_diag_msg\n");
306				exit(EXIT_FAILURE);
307			}
308
309			add_vsock_stat(sockets, NLMSG_DATA(h));
310
311			h = NLMSG_NEXT(h, ret);
312		}
313	}
314
315done:
316	close(fd);
317}
318
319static void free_sock_stat(struct list_head *sockets)
320{
321	struct vsock_stat *st;
322	struct vsock_stat *next;
323
324	list_for_each_entry_safe(st, next, sockets, list)
325		free(st);
326}
327
328static void test_no_sockets(const struct test_opts *opts)
329{
330	LIST_HEAD(sockets);
331
332	read_vsock_stat(&sockets);
333
334	check_no_sockets(&sockets);
335
336	free_sock_stat(&sockets);
337}
338
339static void test_listen_socket_server(const struct test_opts *opts)
340{
341	union {
342		struct sockaddr sa;
343		struct sockaddr_vm svm;
344	} addr = {
345		.svm = {
346			.svm_family = AF_VSOCK,
347			.svm_port = 1234,
348			.svm_cid = VMADDR_CID_ANY,
349		},
350	};
351	LIST_HEAD(sockets);
352	struct vsock_stat *st;
353	int fd;
354
355	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
356
357	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
358		perror("bind");
359		exit(EXIT_FAILURE);
360	}
361
362	if (listen(fd, 1) < 0) {
363		perror("listen");
364		exit(EXIT_FAILURE);
365	}
366
367	read_vsock_stat(&sockets);
368
369	check_num_sockets(&sockets, 1);
370	st = find_vsock_stat(&sockets, fd);
371	check_socket_state(st, TCP_LISTEN);
372
373	close(fd);
374	free_sock_stat(&sockets);
375}
376
377static void test_connect_client(const struct test_opts *opts)
378{
379	int fd;
380	LIST_HEAD(sockets);
381	struct vsock_stat *st;
382
383	fd = vsock_stream_connect(opts->peer_cid, 1234);
384	if (fd < 0) {
385		perror("connect");
386		exit(EXIT_FAILURE);
387	}
388
389	read_vsock_stat(&sockets);
390
391	check_num_sockets(&sockets, 1);
392	st = find_vsock_stat(&sockets, fd);
393	check_socket_state(st, TCP_ESTABLISHED);
394
395	control_expectln("DONE");
396	control_writeln("DONE");
397
398	close(fd);
399	free_sock_stat(&sockets);
400}
401
402static void test_connect_server(const struct test_opts *opts)
403{
404	struct vsock_stat *st;
405	LIST_HEAD(sockets);
406	int client_fd;
407
408	client_fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
409	if (client_fd < 0) {
410		perror("accept");
411		exit(EXIT_FAILURE);
412	}
413
414	read_vsock_stat(&sockets);
415
416	check_num_sockets(&sockets, 1);
417	st = find_vsock_stat(&sockets, client_fd);
418	check_socket_state(st, TCP_ESTABLISHED);
419
420	control_writeln("DONE");
421	control_expectln("DONE");
422
423	close(client_fd);
424	free_sock_stat(&sockets);
425}
426
427static struct test_case test_cases[] = {
428	{
429		.name = "No sockets",
430		.run_server = test_no_sockets,
431	},
432	{
433		.name = "Listen socket",
434		.run_server = test_listen_socket_server,
435	},
436	{
437		.name = "Connect",
438		.run_client = test_connect_client,
439		.run_server = test_connect_server,
440	},
441	{},
442};
443
444static const char optstring[] = "";
445static const struct option longopts[] = {
446	{
447		.name = "control-host",
448		.has_arg = required_argument,
449		.val = 'H',
450	},
451	{
452		.name = "control-port",
453		.has_arg = required_argument,
454		.val = 'P',
455	},
456	{
457		.name = "mode",
458		.has_arg = required_argument,
459		.val = 'm',
460	},
461	{
462		.name = "peer-cid",
463		.has_arg = required_argument,
464		.val = 'p',
465	},
466	{
467		.name = "list",
468		.has_arg = no_argument,
469		.val = 'l',
470	},
471	{
472		.name = "skip",
473		.has_arg = required_argument,
474		.val = 's',
475	},
476	{
477		.name = "help",
478		.has_arg = no_argument,
479		.val = '?',
480	},
481	{},
482};
483
484static void usage(void)
485{
486	fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
487		"\n"
488		"  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
489		"  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
490		"\n"
491		"Run vsock_diag.ko tests.  Must be launched in both\n"
492		"guest and host.  One side must use --mode=client and\n"
493		"the other side must use --mode=server.\n"
494		"\n"
495		"A TCP control socket connection is used to coordinate tests\n"
496		"between the client and the server.  The server requires a\n"
497		"listen address and the client requires an address to\n"
498		"connect to.\n"
499		"\n"
500		"The CID of the other side must be given with --peer-cid=<cid>.\n"
501		"\n"
502		"Options:\n"
503		"  --help                 This help message\n"
504		"  --control-host <host>  Server IP address to connect to\n"
505		"  --control-port <port>  Server port to listen on/connect to\n"
506		"  --mode client|server   Server or client mode\n"
507		"  --peer-cid <cid>       CID of the other side\n"
508		"  --list                 List of tests that will be executed\n"
509		"  --skip <test_id>       Test ID to skip;\n"
510		"                         use multiple --skip options to skip more tests\n"
511		);
512	exit(EXIT_FAILURE);
513}
514
515int main(int argc, char **argv)
516{
517	const char *control_host = NULL;
518	const char *control_port = NULL;
519	struct test_opts opts = {
520		.mode = TEST_MODE_UNSET,
521		.peer_cid = VMADDR_CID_ANY,
522	};
523
524	init_signals();
525
526	for (;;) {
527		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
528
529		if (opt == -1)
530			break;
531
532		switch (opt) {
533		case 'H':
534			control_host = optarg;
535			break;
536		case 'm':
537			if (strcmp(optarg, "client") == 0)
538				opts.mode = TEST_MODE_CLIENT;
539			else if (strcmp(optarg, "server") == 0)
540				opts.mode = TEST_MODE_SERVER;
541			else {
542				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
543				return EXIT_FAILURE;
544			}
545			break;
546		case 'p':
547			opts.peer_cid = parse_cid(optarg);
548			break;
549		case 'P':
550			control_port = optarg;
551			break;
552		case 'l':
553			list_tests(test_cases);
554			break;
555		case 's':
556			skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
557				  optarg);
558			break;
559		case '?':
560		default:
561			usage();
562		}
563	}
564
565	if (!control_port)
566		usage();
567	if (opts.mode == TEST_MODE_UNSET)
568		usage();
569	if (opts.peer_cid == VMADDR_CID_ANY)
570		usage();
571
572	if (!control_host) {
573		if (opts.mode != TEST_MODE_SERVER)
574			usage();
575		control_host = "0.0.0.0";
576	}
577
578	control_init(control_host, control_port,
579		     opts.mode == TEST_MODE_SERVER);
580
581	run_tests(test_cases, &opts);
582
583	control_cleanup();
584	return EXIT_SUCCESS;
585}