Linux Audio

Check our new training course

Loading...
v6.13.7
  1// SPDX-License-Identifier: GPL-2.0-only
  2/*
  3 * vsock_diag_test - vsock_diag.ko test suite
  4 *
  5 * Copyright (C) 2017 Red Hat, Inc.
  6 *
  7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
 
 
 
 
 
  8 */
  9
 10#include <getopt.h>
 11#include <stdio.h>
 
 12#include <stdlib.h>
 13#include <string.h>
 14#include <errno.h>
 15#include <unistd.h>
 
 
 16#include <sys/stat.h>
 17#include <sys/types.h>
 18#include <linux/list.h>
 19#include <linux/net.h>
 20#include <linux/netlink.h>
 21#include <linux/sock_diag.h>
 22#include <linux/vm_sockets_diag.h>
 23#include <netinet/tcp.h>
 24
 
 
 
 25#include "timeout.h"
 26#include "control.h"
 27#include "util.h"
 
 
 
 
 
 28
 29/* Per-socket status */
 30struct vsock_stat {
 31	struct list_head list;
 32	struct vsock_diag_msg msg;
 33};
 34
 35static const char *sock_type_str(int type)
 36{
 37	switch (type) {
 38	case SOCK_DGRAM:
 39		return "DGRAM";
 40	case SOCK_STREAM:
 41		return "STREAM";
 42	case SOCK_SEQPACKET:
 43		return "SEQPACKET";
 44	default:
 45		return "INVALID TYPE";
 46	}
 47}
 48
 49static const char *sock_state_str(int state)
 50{
 51	switch (state) {
 52	case TCP_CLOSE:
 53		return "UNCONNECTED";
 54	case TCP_SYN_SENT:
 55		return "CONNECTING";
 56	case TCP_ESTABLISHED:
 57		return "CONNECTED";
 58	case TCP_CLOSING:
 59		return "DISCONNECTING";
 60	case TCP_LISTEN:
 61		return "LISTEN";
 62	default:
 63		return "INVALID STATE";
 64	}
 65}
 66
 67static const char *sock_shutdown_str(int shutdown)
 68{
 69	switch (shutdown) {
 70	case 1:
 71		return "RCV_SHUTDOWN";
 72	case 2:
 73		return "SEND_SHUTDOWN";
 74	case 3:
 75		return "RCV_SHUTDOWN | SEND_SHUTDOWN";
 76	default:
 77		return "0";
 78	}
 79}
 80
 81static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
 82{
 83	if (cid == VMADDR_CID_ANY)
 84		fprintf(fp, "*:");
 85	else
 86		fprintf(fp, "%u:", cid);
 87
 88	if (port == VMADDR_PORT_ANY)
 89		fprintf(fp, "*");
 90	else
 91		fprintf(fp, "%u", port);
 92}
 93
 94static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
 95{
 96	print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
 97	fprintf(fp, " ");
 98	print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
 99	fprintf(fp, " %s %s %s %u\n",
100		sock_type_str(st->msg.vdiag_type),
101		sock_state_str(st->msg.vdiag_state),
102		sock_shutdown_str(st->msg.vdiag_shutdown),
103		st->msg.vdiag_ino);
104}
105
106static void print_vsock_stats(FILE *fp, struct list_head *head)
107{
108	struct vsock_stat *st;
109
110	list_for_each_entry(st, head, list)
111		print_vsock_stat(fp, st);
112}
113
114static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
115{
116	struct vsock_stat *st;
117	struct stat stat;
118
119	if (fstat(fd, &stat) < 0) {
120		perror("fstat");
121		exit(EXIT_FAILURE);
122	}
123
124	list_for_each_entry(st, head, list)
125		if (st->msg.vdiag_ino == stat.st_ino)
126			return st;
127
128	fprintf(stderr, "cannot find fd %d\n", fd);
129	exit(EXIT_FAILURE);
130}
131
132static void check_no_sockets(struct list_head *head)
133{
134	if (!list_empty(head)) {
135		fprintf(stderr, "expected no sockets\n");
136		print_vsock_stats(stderr, head);
137		exit(1);
138	}
139}
140
141static void check_num_sockets(struct list_head *head, int expected)
142{
143	struct list_head *node;
144	int n = 0;
145
146	list_for_each(node, head)
147		n++;
148
149	if (n != expected) {
150		fprintf(stderr, "expected %d sockets, found %d\n",
151			expected, n);
152		print_vsock_stats(stderr, head);
153		exit(EXIT_FAILURE);
154	}
155}
156
157static void check_socket_state(struct vsock_stat *st, __u8 state)
158{
159	if (st->msg.vdiag_state != state) {
160		fprintf(stderr, "expected socket state %#x, got %#x\n",
161			state, st->msg.vdiag_state);
162		exit(EXIT_FAILURE);
163	}
164}
165
166static void send_req(int fd)
167{
168	struct sockaddr_nl nladdr = {
169		.nl_family = AF_NETLINK,
170	};
171	struct {
172		struct nlmsghdr nlh;
173		struct vsock_diag_req vreq;
174	} req = {
175		.nlh = {
176			.nlmsg_len = sizeof(req),
177			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
178			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
179		},
180		.vreq = {
181			.sdiag_family = AF_VSOCK,
182			.vdiag_states = ~(__u32)0,
183		},
184	};
185	struct iovec iov = {
186		.iov_base = &req,
187		.iov_len = sizeof(req),
188	};
189	struct msghdr msg = {
190		.msg_name = &nladdr,
191		.msg_namelen = sizeof(nladdr),
192		.msg_iov = &iov,
193		.msg_iovlen = 1,
194	};
195
196	for (;;) {
197		if (sendmsg(fd, &msg, 0) < 0) {
198			if (errno == EINTR)
199				continue;
200
201			perror("sendmsg");
202			exit(EXIT_FAILURE);
203		}
204
205		return;
206	}
207}
208
209static ssize_t recv_resp(int fd, void *buf, size_t len)
210{
211	struct sockaddr_nl nladdr = {
212		.nl_family = AF_NETLINK,
213	};
214	struct iovec iov = {
215		.iov_base = buf,
216		.iov_len = len,
217	};
218	struct msghdr msg = {
219		.msg_name = &nladdr,
220		.msg_namelen = sizeof(nladdr),
221		.msg_iov = &iov,
222		.msg_iovlen = 1,
223	};
224	ssize_t ret;
225
226	do {
227		ret = recvmsg(fd, &msg, 0);
228	} while (ret < 0 && errno == EINTR);
229
230	if (ret < 0) {
231		perror("recvmsg");
232		exit(EXIT_FAILURE);
233	}
234
235	return ret;
236}
237
238static void add_vsock_stat(struct list_head *sockets,
239			   const struct vsock_diag_msg *resp)
240{
241	struct vsock_stat *st;
242
243	st = malloc(sizeof(*st));
244	if (!st) {
245		perror("malloc");
246		exit(EXIT_FAILURE);
247	}
248
249	st->msg = *resp;
250	list_add_tail(&st->list, sockets);
251}
252
253/*
254 * Read vsock stats into a list.
255 */
256static void read_vsock_stat(struct list_head *sockets)
257{
258	long buf[8192 / sizeof(long)];
259	int fd;
260
261	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
262	if (fd < 0) {
263		perror("socket");
264		exit(EXIT_FAILURE);
265	}
266
267	send_req(fd);
268
269	for (;;) {
270		const struct nlmsghdr *h;
271		ssize_t ret;
272
273		ret = recv_resp(fd, buf, sizeof(buf));
274		if (ret == 0)
275			goto done;
276		if (ret < sizeof(*h)) {
277			fprintf(stderr, "short read of %zd bytes\n", ret);
278			exit(EXIT_FAILURE);
279		}
280
281		h = (struct nlmsghdr *)buf;
282
283		while (NLMSG_OK(h, ret)) {
284			if (h->nlmsg_type == NLMSG_DONE)
285				goto done;
286
287			if (h->nlmsg_type == NLMSG_ERROR) {
288				const struct nlmsgerr *err = NLMSG_DATA(h);
289
290				if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
291					fprintf(stderr, "NLMSG_ERROR\n");
292				else {
293					errno = -err->error;
294					perror("NLMSG_ERROR");
295				}
296
297				exit(EXIT_FAILURE);
298			}
299
300			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
301				fprintf(stderr, "unexpected nlmsg_type %#x\n",
302					h->nlmsg_type);
303				exit(EXIT_FAILURE);
304			}
305			if (h->nlmsg_len <
306			    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
307				fprintf(stderr, "short vsock_diag_msg\n");
308				exit(EXIT_FAILURE);
309			}
310
311			add_vsock_stat(sockets, NLMSG_DATA(h));
312
313			h = NLMSG_NEXT(h, ret);
314		}
315	}
316
317done:
318	close(fd);
319}
320
321static void free_sock_stat(struct list_head *sockets)
322{
323	struct vsock_stat *st;
324	struct vsock_stat *next;
325
326	list_for_each_entry_safe(st, next, sockets, list)
327		free(st);
328}
329
330static void test_no_sockets(const struct test_opts *opts)
331{
332	LIST_HEAD(sockets);
333
334	read_vsock_stat(&sockets);
335
336	check_no_sockets(&sockets);
 
 
337}
338
339static void test_listen_socket_server(const struct test_opts *opts)
340{
341	union {
342		struct sockaddr sa;
343		struct sockaddr_vm svm;
344	} addr = {
345		.svm = {
346			.svm_family = AF_VSOCK,
347			.svm_port = opts->peer_port,
348			.svm_cid = VMADDR_CID_ANY,
349		},
350	};
351	LIST_HEAD(sockets);
352	struct vsock_stat *st;
353	int fd;
354
355	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
356
357	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
358		perror("bind");
359		exit(EXIT_FAILURE);
360	}
361
362	if (listen(fd, 1) < 0) {
363		perror("listen");
364		exit(EXIT_FAILURE);
365	}
366
367	read_vsock_stat(&sockets);
368
369	check_num_sockets(&sockets, 1);
370	st = find_vsock_stat(&sockets, fd);
371	check_socket_state(st, TCP_LISTEN);
372
373	close(fd);
374	free_sock_stat(&sockets);
375}
376
377static void test_connect_client(const struct test_opts *opts)
378{
 
 
 
 
 
 
 
 
 
 
379	int fd;
 
380	LIST_HEAD(sockets);
381	struct vsock_stat *st;
382
383	fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
384	if (fd < 0) {
 
 
 
 
 
 
 
 
 
 
385		perror("connect");
386		exit(EXIT_FAILURE);
387	}
388
389	read_vsock_stat(&sockets);
390
391	check_num_sockets(&sockets, 1);
392	st = find_vsock_stat(&sockets, fd);
393	check_socket_state(st, TCP_ESTABLISHED);
394
395	control_expectln("DONE");
396	control_writeln("DONE");
397
398	close(fd);
399	free_sock_stat(&sockets);
400}
401
402static void test_connect_server(const struct test_opts *opts)
403{
404	struct vsock_stat *st;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405	LIST_HEAD(sockets);
 
 
406	int client_fd;
407
408	client_fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409	if (client_fd < 0) {
410		perror("accept");
411		exit(EXIT_FAILURE);
412	}
 
 
 
 
 
 
 
 
 
 
413
414	read_vsock_stat(&sockets);
415
416	check_num_sockets(&sockets, 1);
 
417	st = find_vsock_stat(&sockets, client_fd);
418	check_socket_state(st, TCP_ESTABLISHED);
419
420	control_writeln("DONE");
421	control_expectln("DONE");
422
423	close(client_fd);
 
424	free_sock_stat(&sockets);
425}
426
427static struct test_case test_cases[] = {
 
 
 
 
428	{
429		.name = "No sockets",
430		.run_server = test_no_sockets,
431	},
432	{
433		.name = "Listen socket",
434		.run_server = test_listen_socket_server,
435	},
436	{
437		.name = "Connect",
438		.run_client = test_connect_client,
439		.run_server = test_connect_server,
440	},
441	{},
442};
443
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444static const char optstring[] = "";
445static const struct option longopts[] = {
446	{
447		.name = "control-host",
448		.has_arg = required_argument,
449		.val = 'H',
450	},
451	{
452		.name = "control-port",
453		.has_arg = required_argument,
454		.val = 'P',
455	},
456	{
457		.name = "mode",
458		.has_arg = required_argument,
459		.val = 'm',
460	},
461	{
462		.name = "peer-cid",
463		.has_arg = required_argument,
464		.val = 'p',
465	},
466	{
467		.name = "peer-port",
468		.has_arg = required_argument,
469		.val = 'q',
470	},
471	{
472		.name = "list",
473		.has_arg = no_argument,
474		.val = 'l',
475	},
476	{
477		.name = "skip",
478		.has_arg = required_argument,
479		.val = 's',
480	},
481	{
482		.name = "help",
483		.has_arg = no_argument,
484		.val = '?',
485	},
486	{},
487};
488
489static void usage(void)
490{
491	fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--peer-port=<port>] [--list] [--skip=<test_id>]\n"
492		"\n"
493		"  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
494		"  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
495		"\n"
496		"Run vsock_diag.ko tests.  Must be launched in both\n"
497		"guest and host.  One side must use --mode=client and\n"
498		"the other side must use --mode=server.\n"
499		"\n"
500		"A TCP control socket connection is used to coordinate tests\n"
501		"between the client and the server.  The server requires a\n"
502		"listen address and the client requires an address to\n"
503		"connect to.\n"
504		"\n"
505		"The CID of the other side must be given with --peer-cid=<cid>.\n"
506		"\n"
507		"Options:\n"
508		"  --help                 This help message\n"
509		"  --control-host <host>  Server IP address to connect to\n"
510		"  --control-port <port>  Server port to listen on/connect to\n"
511		"  --mode client|server   Server or client mode\n"
512		"  --peer-cid <cid>       CID of the other side\n"
513		"  --peer-port <port>     AF_VSOCK port used for the test [default: %d]\n"
514		"  --list                 List of tests that will be executed\n"
515		"  --skip <test_id>       Test ID to skip;\n"
516		"                         use multiple --skip options to skip more tests\n",
517		DEFAULT_PEER_PORT
518		);
519	exit(EXIT_FAILURE);
520}
521
522int main(int argc, char **argv)
523{
524	const char *control_host = NULL;
525	const char *control_port = NULL;
526	struct test_opts opts = {
527		.mode = TEST_MODE_UNSET,
528		.peer_cid = VMADDR_CID_ANY,
529		.peer_port = DEFAULT_PEER_PORT,
530	};
531
532	init_signals();
533
534	for (;;) {
535		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
536
537		if (opt == -1)
538			break;
539
540		switch (opt) {
541		case 'H':
542			control_host = optarg;
543			break;
544		case 'm':
545			if (strcmp(optarg, "client") == 0)
546				opts.mode = TEST_MODE_CLIENT;
547			else if (strcmp(optarg, "server") == 0)
548				opts.mode = TEST_MODE_SERVER;
549			else {
550				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
551				return EXIT_FAILURE;
552			}
553			break;
554		case 'p':
555			opts.peer_cid = parse_cid(optarg);
556			break;
557		case 'q':
558			opts.peer_port = parse_port(optarg);
559			break;
560		case 'P':
561			control_port = optarg;
562			break;
563		case 'l':
564			list_tests(test_cases);
565			break;
566		case 's':
567			skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
568				  optarg);
569			break;
570		case '?':
571		default:
572			usage();
573		}
574	}
575
576	if (!control_port)
577		usage();
578	if (opts.mode == TEST_MODE_UNSET)
579		usage();
580	if (opts.peer_cid == VMADDR_CID_ANY)
581		usage();
582
583	if (!control_host) {
584		if (opts.mode != TEST_MODE_SERVER)
585			usage();
586		control_host = "0.0.0.0";
587	}
588
589	control_init(control_host, control_port,
590		     opts.mode == TEST_MODE_SERVER);
591
592	run_tests(test_cases, &opts);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593
594	control_cleanup();
595	return EXIT_SUCCESS;
596}
v4.17
 
  1/*
  2 * vsock_diag_test - vsock_diag.ko test suite
  3 *
  4 * Copyright (C) 2017 Red Hat, Inc.
  5 *
  6 * Author: Stefan Hajnoczi <stefanha@redhat.com>
  7 *
  8 * This program is free software; you can redistribute it and/or
  9 * modify it under the terms of the GNU General Public License
 10 * as published by the Free Software Foundation; version 2
 11 * of the License.
 12 */
 13
 14#include <getopt.h>
 15#include <stdio.h>
 16#include <stdbool.h>
 17#include <stdlib.h>
 18#include <string.h>
 19#include <errno.h>
 20#include <unistd.h>
 21#include <signal.h>
 22#include <sys/socket.h>
 23#include <sys/stat.h>
 24#include <sys/types.h>
 25#include <linux/list.h>
 26#include <linux/net.h>
 27#include <linux/netlink.h>
 28#include <linux/sock_diag.h>
 
 29#include <netinet/tcp.h>
 30
 31#include "../../../include/uapi/linux/vm_sockets.h"
 32#include "../../../include/uapi/linux/vm_sockets_diag.h"
 33
 34#include "timeout.h"
 35#include "control.h"
 36
 37enum test_mode {
 38	TEST_MODE_UNSET,
 39	TEST_MODE_CLIENT,
 40	TEST_MODE_SERVER
 41};
 42
 43/* Per-socket status */
 44struct vsock_stat {
 45	struct list_head list;
 46	struct vsock_diag_msg msg;
 47};
 48
 49static const char *sock_type_str(int type)
 50{
 51	switch (type) {
 52	case SOCK_DGRAM:
 53		return "DGRAM";
 54	case SOCK_STREAM:
 55		return "STREAM";
 
 
 56	default:
 57		return "INVALID TYPE";
 58	}
 59}
 60
 61static const char *sock_state_str(int state)
 62{
 63	switch (state) {
 64	case TCP_CLOSE:
 65		return "UNCONNECTED";
 66	case TCP_SYN_SENT:
 67		return "CONNECTING";
 68	case TCP_ESTABLISHED:
 69		return "CONNECTED";
 70	case TCP_CLOSING:
 71		return "DISCONNECTING";
 72	case TCP_LISTEN:
 73		return "LISTEN";
 74	default:
 75		return "INVALID STATE";
 76	}
 77}
 78
 79static const char *sock_shutdown_str(int shutdown)
 80{
 81	switch (shutdown) {
 82	case 1:
 83		return "RCV_SHUTDOWN";
 84	case 2:
 85		return "SEND_SHUTDOWN";
 86	case 3:
 87		return "RCV_SHUTDOWN | SEND_SHUTDOWN";
 88	default:
 89		return "0";
 90	}
 91}
 92
 93static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
 94{
 95	if (cid == VMADDR_CID_ANY)
 96		fprintf(fp, "*:");
 97	else
 98		fprintf(fp, "%u:", cid);
 99
100	if (port == VMADDR_PORT_ANY)
101		fprintf(fp, "*");
102	else
103		fprintf(fp, "%u", port);
104}
105
106static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
107{
108	print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
109	fprintf(fp, " ");
110	print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
111	fprintf(fp, " %s %s %s %u\n",
112		sock_type_str(st->msg.vdiag_type),
113		sock_state_str(st->msg.vdiag_state),
114		sock_shutdown_str(st->msg.vdiag_shutdown),
115		st->msg.vdiag_ino);
116}
117
118static void print_vsock_stats(FILE *fp, struct list_head *head)
119{
120	struct vsock_stat *st;
121
122	list_for_each_entry(st, head, list)
123		print_vsock_stat(fp, st);
124}
125
126static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
127{
128	struct vsock_stat *st;
129	struct stat stat;
130
131	if (fstat(fd, &stat) < 0) {
132		perror("fstat");
133		exit(EXIT_FAILURE);
134	}
135
136	list_for_each_entry(st, head, list)
137		if (st->msg.vdiag_ino == stat.st_ino)
138			return st;
139
140	fprintf(stderr, "cannot find fd %d\n", fd);
141	exit(EXIT_FAILURE);
142}
143
144static void check_no_sockets(struct list_head *head)
145{
146	if (!list_empty(head)) {
147		fprintf(stderr, "expected no sockets\n");
148		print_vsock_stats(stderr, head);
149		exit(1);
150	}
151}
152
153static void check_num_sockets(struct list_head *head, int expected)
154{
155	struct list_head *node;
156	int n = 0;
157
158	list_for_each(node, head)
159		n++;
160
161	if (n != expected) {
162		fprintf(stderr, "expected %d sockets, found %d\n",
163			expected, n);
164		print_vsock_stats(stderr, head);
165		exit(EXIT_FAILURE);
166	}
167}
168
169static void check_socket_state(struct vsock_stat *st, __u8 state)
170{
171	if (st->msg.vdiag_state != state) {
172		fprintf(stderr, "expected socket state %#x, got %#x\n",
173			state, st->msg.vdiag_state);
174		exit(EXIT_FAILURE);
175	}
176}
177
178static void send_req(int fd)
179{
180	struct sockaddr_nl nladdr = {
181		.nl_family = AF_NETLINK,
182	};
183	struct {
184		struct nlmsghdr nlh;
185		struct vsock_diag_req vreq;
186	} req = {
187		.nlh = {
188			.nlmsg_len = sizeof(req),
189			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
190			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
191		},
192		.vreq = {
193			.sdiag_family = AF_VSOCK,
194			.vdiag_states = ~(__u32)0,
195		},
196	};
197	struct iovec iov = {
198		.iov_base = &req,
199		.iov_len = sizeof(req),
200	};
201	struct msghdr msg = {
202		.msg_name = &nladdr,
203		.msg_namelen = sizeof(nladdr),
204		.msg_iov = &iov,
205		.msg_iovlen = 1,
206	};
207
208	for (;;) {
209		if (sendmsg(fd, &msg, 0) < 0) {
210			if (errno == EINTR)
211				continue;
212
213			perror("sendmsg");
214			exit(EXIT_FAILURE);
215		}
216
217		return;
218	}
219}
220
221static ssize_t recv_resp(int fd, void *buf, size_t len)
222{
223	struct sockaddr_nl nladdr = {
224		.nl_family = AF_NETLINK,
225	};
226	struct iovec iov = {
227		.iov_base = buf,
228		.iov_len = len,
229	};
230	struct msghdr msg = {
231		.msg_name = &nladdr,
232		.msg_namelen = sizeof(nladdr),
233		.msg_iov = &iov,
234		.msg_iovlen = 1,
235	};
236	ssize_t ret;
237
238	do {
239		ret = recvmsg(fd, &msg, 0);
240	} while (ret < 0 && errno == EINTR);
241
242	if (ret < 0) {
243		perror("recvmsg");
244		exit(EXIT_FAILURE);
245	}
246
247	return ret;
248}
249
250static void add_vsock_stat(struct list_head *sockets,
251			   const struct vsock_diag_msg *resp)
252{
253	struct vsock_stat *st;
254
255	st = malloc(sizeof(*st));
256	if (!st) {
257		perror("malloc");
258		exit(EXIT_FAILURE);
259	}
260
261	st->msg = *resp;
262	list_add_tail(&st->list, sockets);
263}
264
265/*
266 * Read vsock stats into a list.
267 */
268static void read_vsock_stat(struct list_head *sockets)
269{
270	long buf[8192 / sizeof(long)];
271	int fd;
272
273	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
274	if (fd < 0) {
275		perror("socket");
276		exit(EXIT_FAILURE);
277	}
278
279	send_req(fd);
280
281	for (;;) {
282		const struct nlmsghdr *h;
283		ssize_t ret;
284
285		ret = recv_resp(fd, buf, sizeof(buf));
286		if (ret == 0)
287			goto done;
288		if (ret < sizeof(*h)) {
289			fprintf(stderr, "short read of %zd bytes\n", ret);
290			exit(EXIT_FAILURE);
291		}
292
293		h = (struct nlmsghdr *)buf;
294
295		while (NLMSG_OK(h, ret)) {
296			if (h->nlmsg_type == NLMSG_DONE)
297				goto done;
298
299			if (h->nlmsg_type == NLMSG_ERROR) {
300				const struct nlmsgerr *err = NLMSG_DATA(h);
301
302				if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
303					fprintf(stderr, "NLMSG_ERROR\n");
304				else {
305					errno = -err->error;
306					perror("NLMSG_ERROR");
307				}
308
309				exit(EXIT_FAILURE);
310			}
311
312			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
313				fprintf(stderr, "unexpected nlmsg_type %#x\n",
314					h->nlmsg_type);
315				exit(EXIT_FAILURE);
316			}
317			if (h->nlmsg_len <
318			    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
319				fprintf(stderr, "short vsock_diag_msg\n");
320				exit(EXIT_FAILURE);
321			}
322
323			add_vsock_stat(sockets, NLMSG_DATA(h));
324
325			h = NLMSG_NEXT(h, ret);
326		}
327	}
328
329done:
330	close(fd);
331}
332
333static void free_sock_stat(struct list_head *sockets)
334{
335	struct vsock_stat *st;
336	struct vsock_stat *next;
337
338	list_for_each_entry_safe(st, next, sockets, list)
339		free(st);
340}
341
342static void test_no_sockets(unsigned int peer_cid)
343{
344	LIST_HEAD(sockets);
345
346	read_vsock_stat(&sockets);
347
348	check_no_sockets(&sockets);
349
350	free_sock_stat(&sockets);
351}
352
353static void test_listen_socket_server(unsigned int peer_cid)
354{
355	union {
356		struct sockaddr sa;
357		struct sockaddr_vm svm;
358	} addr = {
359		.svm = {
360			.svm_family = AF_VSOCK,
361			.svm_port = 1234,
362			.svm_cid = VMADDR_CID_ANY,
363		},
364	};
365	LIST_HEAD(sockets);
366	struct vsock_stat *st;
367	int fd;
368
369	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
370
371	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
372		perror("bind");
373		exit(EXIT_FAILURE);
374	}
375
376	if (listen(fd, 1) < 0) {
377		perror("listen");
378		exit(EXIT_FAILURE);
379	}
380
381	read_vsock_stat(&sockets);
382
383	check_num_sockets(&sockets, 1);
384	st = find_vsock_stat(&sockets, fd);
385	check_socket_state(st, TCP_LISTEN);
386
387	close(fd);
388	free_sock_stat(&sockets);
389}
390
391static void test_connect_client(unsigned int peer_cid)
392{
393	union {
394		struct sockaddr sa;
395		struct sockaddr_vm svm;
396	} addr = {
397		.svm = {
398			.svm_family = AF_VSOCK,
399			.svm_port = 1234,
400			.svm_cid = peer_cid,
401		},
402	};
403	int fd;
404	int ret;
405	LIST_HEAD(sockets);
406	struct vsock_stat *st;
407
408	control_expectln("LISTENING");
409
410	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
411
412	timeout_begin(TIMEOUT);
413	do {
414		ret = connect(fd, &addr.sa, sizeof(addr.svm));
415		timeout_check("connect");
416	} while (ret < 0 && errno == EINTR);
417	timeout_end();
418
419	if (ret < 0) {
420		perror("connect");
421		exit(EXIT_FAILURE);
422	}
423
424	read_vsock_stat(&sockets);
425
426	check_num_sockets(&sockets, 1);
427	st = find_vsock_stat(&sockets, fd);
428	check_socket_state(st, TCP_ESTABLISHED);
429
430	control_expectln("DONE");
431	control_writeln("DONE");
432
433	close(fd);
434	free_sock_stat(&sockets);
435}
436
437static void test_connect_server(unsigned int peer_cid)
438{
439	union {
440		struct sockaddr sa;
441		struct sockaddr_vm svm;
442	} addr = {
443		.svm = {
444			.svm_family = AF_VSOCK,
445			.svm_port = 1234,
446			.svm_cid = VMADDR_CID_ANY,
447		},
448	};
449	union {
450		struct sockaddr sa;
451		struct sockaddr_vm svm;
452	} clientaddr;
453	socklen_t clientaddr_len = sizeof(clientaddr.svm);
454	LIST_HEAD(sockets);
455	struct vsock_stat *st;
456	int fd;
457	int client_fd;
458
459	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
460
461	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
462		perror("bind");
463		exit(EXIT_FAILURE);
464	}
465
466	if (listen(fd, 1) < 0) {
467		perror("listen");
468		exit(EXIT_FAILURE);
469	}
470
471	control_writeln("LISTENING");
472
473	timeout_begin(TIMEOUT);
474	do {
475		client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
476		timeout_check("accept");
477	} while (client_fd < 0 && errno == EINTR);
478	timeout_end();
479
480	if (client_fd < 0) {
481		perror("accept");
482		exit(EXIT_FAILURE);
483	}
484	if (clientaddr.sa.sa_family != AF_VSOCK) {
485		fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
486			clientaddr.sa.sa_family);
487		exit(EXIT_FAILURE);
488	}
489	if (clientaddr.svm.svm_cid != peer_cid) {
490		fprintf(stderr, "expected peer CID %u from accept(2), got %u\n",
491			peer_cid, clientaddr.svm.svm_cid);
492		exit(EXIT_FAILURE);
493	}
494
495	read_vsock_stat(&sockets);
496
497	check_num_sockets(&sockets, 2);
498	find_vsock_stat(&sockets, fd);
499	st = find_vsock_stat(&sockets, client_fd);
500	check_socket_state(st, TCP_ESTABLISHED);
501
502	control_writeln("DONE");
503	control_expectln("DONE");
504
505	close(client_fd);
506	close(fd);
507	free_sock_stat(&sockets);
508}
509
510static struct {
511	const char *name;
512	void (*run_client)(unsigned int peer_cid);
513	void (*run_server)(unsigned int peer_cid);
514} test_cases[] = {
515	{
516		.name = "No sockets",
517		.run_server = test_no_sockets,
518	},
519	{
520		.name = "Listen socket",
521		.run_server = test_listen_socket_server,
522	},
523	{
524		.name = "Connect",
525		.run_client = test_connect_client,
526		.run_server = test_connect_server,
527	},
528	{},
529};
530
531static void init_signals(void)
532{
533	struct sigaction act = {
534		.sa_handler = sigalrm,
535	};
536
537	sigaction(SIGALRM, &act, NULL);
538	signal(SIGPIPE, SIG_IGN);
539}
540
541static unsigned int parse_cid(const char *str)
542{
543	char *endptr = NULL;
544	unsigned long int n;
545
546	errno = 0;
547	n = strtoul(str, &endptr, 10);
548	if (errno || *endptr != '\0') {
549		fprintf(stderr, "malformed CID \"%s\"\n", str);
550		exit(EXIT_FAILURE);
551	}
552	return n;
553}
554
555static const char optstring[] = "";
556static const struct option longopts[] = {
557	{
558		.name = "control-host",
559		.has_arg = required_argument,
560		.val = 'H',
561	},
562	{
563		.name = "control-port",
564		.has_arg = required_argument,
565		.val = 'P',
566	},
567	{
568		.name = "mode",
569		.has_arg = required_argument,
570		.val = 'm',
571	},
572	{
573		.name = "peer-cid",
574		.has_arg = required_argument,
575		.val = 'p',
576	},
577	{
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
578		.name = "help",
579		.has_arg = no_argument,
580		.val = '?',
581	},
582	{},
583};
584
585static void usage(void)
586{
587	fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n"
588		"\n"
589		"  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
590		"  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
591		"\n"
592		"Run vsock_diag.ko tests.  Must be launched in both\n"
593		"guest and host.  One side must use --mode=client and\n"
594		"the other side must use --mode=server.\n"
595		"\n"
596		"A TCP control socket connection is used to coordinate tests\n"
597		"between the client and the server.  The server requires a\n"
598		"listen address and the client requires an address to\n"
599		"connect to.\n"
600		"\n"
601		"The CID of the other side must be given with --peer-cid=<cid>.\n");
 
 
 
 
 
 
 
 
 
 
 
 
 
602	exit(EXIT_FAILURE);
603}
604
605int main(int argc, char **argv)
606{
607	const char *control_host = NULL;
608	const char *control_port = NULL;
609	int mode = TEST_MODE_UNSET;
610	unsigned int peer_cid = VMADDR_CID_ANY;
611	int i;
 
 
612
613	init_signals();
614
615	for (;;) {
616		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
617
618		if (opt == -1)
619			break;
620
621		switch (opt) {
622		case 'H':
623			control_host = optarg;
624			break;
625		case 'm':
626			if (strcmp(optarg, "client") == 0)
627				mode = TEST_MODE_CLIENT;
628			else if (strcmp(optarg, "server") == 0)
629				mode = TEST_MODE_SERVER;
630			else {
631				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
632				return EXIT_FAILURE;
633			}
634			break;
635		case 'p':
636			peer_cid = parse_cid(optarg);
 
 
 
637			break;
638		case 'P':
639			control_port = optarg;
640			break;
 
 
 
 
 
 
 
641		case '?':
642		default:
643			usage();
644		}
645	}
646
647	if (!control_port)
648		usage();
649	if (mode == TEST_MODE_UNSET)
650		usage();
651	if (peer_cid == VMADDR_CID_ANY)
652		usage();
653
654	if (!control_host) {
655		if (mode != TEST_MODE_SERVER)
656			usage();
657		control_host = "0.0.0.0";
658	}
659
660	control_init(control_host, control_port, mode == TEST_MODE_SERVER);
 
661
662	for (i = 0; test_cases[i].name; i++) {
663		void (*run)(unsigned int peer_cid);
664
665		printf("%s...", test_cases[i].name);
666		fflush(stdout);
667
668		if (mode == TEST_MODE_CLIENT)
669			run = test_cases[i].run_client;
670		else
671			run = test_cases[i].run_server;
672
673		if (run)
674			run(peer_cid);
675
676		printf("ok\n");
677	}
678
679	control_cleanup();
680	return EXIT_SUCCESS;
681}