Loading...
1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * vsock_diag_test - vsock_diag.ko test suite
4 *
5 * Copyright (C) 2017 Red Hat, Inc.
6 *
7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
8 */
9
10#include <getopt.h>
11#include <stdio.h>
12#include <stdlib.h>
13#include <string.h>
14#include <errno.h>
15#include <unistd.h>
16#include <sys/stat.h>
17#include <sys/types.h>
18#include <linux/list.h>
19#include <linux/net.h>
20#include <linux/netlink.h>
21#include <linux/sock_diag.h>
22#include <linux/vm_sockets_diag.h>
23#include <netinet/tcp.h>
24
25#include "timeout.h"
26#include "control.h"
27#include "util.h"
28
29/* Per-socket status */
30struct vsock_stat {
31 struct list_head list;
32 struct vsock_diag_msg msg;
33};
34
35static const char *sock_type_str(int type)
36{
37 switch (type) {
38 case SOCK_DGRAM:
39 return "DGRAM";
40 case SOCK_STREAM:
41 return "STREAM";
42 case SOCK_SEQPACKET:
43 return "SEQPACKET";
44 default:
45 return "INVALID TYPE";
46 }
47}
48
49static const char *sock_state_str(int state)
50{
51 switch (state) {
52 case TCP_CLOSE:
53 return "UNCONNECTED";
54 case TCP_SYN_SENT:
55 return "CONNECTING";
56 case TCP_ESTABLISHED:
57 return "CONNECTED";
58 case TCP_CLOSING:
59 return "DISCONNECTING";
60 case TCP_LISTEN:
61 return "LISTEN";
62 default:
63 return "INVALID STATE";
64 }
65}
66
67static const char *sock_shutdown_str(int shutdown)
68{
69 switch (shutdown) {
70 case 1:
71 return "RCV_SHUTDOWN";
72 case 2:
73 return "SEND_SHUTDOWN";
74 case 3:
75 return "RCV_SHUTDOWN | SEND_SHUTDOWN";
76 default:
77 return "0";
78 }
79}
80
81static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
82{
83 if (cid == VMADDR_CID_ANY)
84 fprintf(fp, "*:");
85 else
86 fprintf(fp, "%u:", cid);
87
88 if (port == VMADDR_PORT_ANY)
89 fprintf(fp, "*");
90 else
91 fprintf(fp, "%u", port);
92}
93
94static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
95{
96 print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
97 fprintf(fp, " ");
98 print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
99 fprintf(fp, " %s %s %s %u\n",
100 sock_type_str(st->msg.vdiag_type),
101 sock_state_str(st->msg.vdiag_state),
102 sock_shutdown_str(st->msg.vdiag_shutdown),
103 st->msg.vdiag_ino);
104}
105
106static void print_vsock_stats(FILE *fp, struct list_head *head)
107{
108 struct vsock_stat *st;
109
110 list_for_each_entry(st, head, list)
111 print_vsock_stat(fp, st);
112}
113
114static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
115{
116 struct vsock_stat *st;
117 struct stat stat;
118
119 if (fstat(fd, &stat) < 0) {
120 perror("fstat");
121 exit(EXIT_FAILURE);
122 }
123
124 list_for_each_entry(st, head, list)
125 if (st->msg.vdiag_ino == stat.st_ino)
126 return st;
127
128 fprintf(stderr, "cannot find fd %d\n", fd);
129 exit(EXIT_FAILURE);
130}
131
132static void check_no_sockets(struct list_head *head)
133{
134 if (!list_empty(head)) {
135 fprintf(stderr, "expected no sockets\n");
136 print_vsock_stats(stderr, head);
137 exit(1);
138 }
139}
140
141static void check_num_sockets(struct list_head *head, int expected)
142{
143 struct list_head *node;
144 int n = 0;
145
146 list_for_each(node, head)
147 n++;
148
149 if (n != expected) {
150 fprintf(stderr, "expected %d sockets, found %d\n",
151 expected, n);
152 print_vsock_stats(stderr, head);
153 exit(EXIT_FAILURE);
154 }
155}
156
157static void check_socket_state(struct vsock_stat *st, __u8 state)
158{
159 if (st->msg.vdiag_state != state) {
160 fprintf(stderr, "expected socket state %#x, got %#x\n",
161 state, st->msg.vdiag_state);
162 exit(EXIT_FAILURE);
163 }
164}
165
166static void send_req(int fd)
167{
168 struct sockaddr_nl nladdr = {
169 .nl_family = AF_NETLINK,
170 };
171 struct {
172 struct nlmsghdr nlh;
173 struct vsock_diag_req vreq;
174 } req = {
175 .nlh = {
176 .nlmsg_len = sizeof(req),
177 .nlmsg_type = SOCK_DIAG_BY_FAMILY,
178 .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
179 },
180 .vreq = {
181 .sdiag_family = AF_VSOCK,
182 .vdiag_states = ~(__u32)0,
183 },
184 };
185 struct iovec iov = {
186 .iov_base = &req,
187 .iov_len = sizeof(req),
188 };
189 struct msghdr msg = {
190 .msg_name = &nladdr,
191 .msg_namelen = sizeof(nladdr),
192 .msg_iov = &iov,
193 .msg_iovlen = 1,
194 };
195
196 for (;;) {
197 if (sendmsg(fd, &msg, 0) < 0) {
198 if (errno == EINTR)
199 continue;
200
201 perror("sendmsg");
202 exit(EXIT_FAILURE);
203 }
204
205 return;
206 }
207}
208
209static ssize_t recv_resp(int fd, void *buf, size_t len)
210{
211 struct sockaddr_nl nladdr = {
212 .nl_family = AF_NETLINK,
213 };
214 struct iovec iov = {
215 .iov_base = buf,
216 .iov_len = len,
217 };
218 struct msghdr msg = {
219 .msg_name = &nladdr,
220 .msg_namelen = sizeof(nladdr),
221 .msg_iov = &iov,
222 .msg_iovlen = 1,
223 };
224 ssize_t ret;
225
226 do {
227 ret = recvmsg(fd, &msg, 0);
228 } while (ret < 0 && errno == EINTR);
229
230 if (ret < 0) {
231 perror("recvmsg");
232 exit(EXIT_FAILURE);
233 }
234
235 return ret;
236}
237
238static void add_vsock_stat(struct list_head *sockets,
239 const struct vsock_diag_msg *resp)
240{
241 struct vsock_stat *st;
242
243 st = malloc(sizeof(*st));
244 if (!st) {
245 perror("malloc");
246 exit(EXIT_FAILURE);
247 }
248
249 st->msg = *resp;
250 list_add_tail(&st->list, sockets);
251}
252
253/*
254 * Read vsock stats into a list.
255 */
256static void read_vsock_stat(struct list_head *sockets)
257{
258 long buf[8192 / sizeof(long)];
259 int fd;
260
261 fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
262 if (fd < 0) {
263 perror("socket");
264 exit(EXIT_FAILURE);
265 }
266
267 send_req(fd);
268
269 for (;;) {
270 const struct nlmsghdr *h;
271 ssize_t ret;
272
273 ret = recv_resp(fd, buf, sizeof(buf));
274 if (ret == 0)
275 goto done;
276 if (ret < sizeof(*h)) {
277 fprintf(stderr, "short read of %zd bytes\n", ret);
278 exit(EXIT_FAILURE);
279 }
280
281 h = (struct nlmsghdr *)buf;
282
283 while (NLMSG_OK(h, ret)) {
284 if (h->nlmsg_type == NLMSG_DONE)
285 goto done;
286
287 if (h->nlmsg_type == NLMSG_ERROR) {
288 const struct nlmsgerr *err = NLMSG_DATA(h);
289
290 if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
291 fprintf(stderr, "NLMSG_ERROR\n");
292 else {
293 errno = -err->error;
294 perror("NLMSG_ERROR");
295 }
296
297 exit(EXIT_FAILURE);
298 }
299
300 if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
301 fprintf(stderr, "unexpected nlmsg_type %#x\n",
302 h->nlmsg_type);
303 exit(EXIT_FAILURE);
304 }
305 if (h->nlmsg_len <
306 NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
307 fprintf(stderr, "short vsock_diag_msg\n");
308 exit(EXIT_FAILURE);
309 }
310
311 add_vsock_stat(sockets, NLMSG_DATA(h));
312
313 h = NLMSG_NEXT(h, ret);
314 }
315 }
316
317done:
318 close(fd);
319}
320
321static void free_sock_stat(struct list_head *sockets)
322{
323 struct vsock_stat *st;
324 struct vsock_stat *next;
325
326 list_for_each_entry_safe(st, next, sockets, list)
327 free(st);
328}
329
330static void test_no_sockets(const struct test_opts *opts)
331{
332 LIST_HEAD(sockets);
333
334 read_vsock_stat(&sockets);
335
336 check_no_sockets(&sockets);
337}
338
339static void test_listen_socket_server(const struct test_opts *opts)
340{
341 union {
342 struct sockaddr sa;
343 struct sockaddr_vm svm;
344 } addr = {
345 .svm = {
346 .svm_family = AF_VSOCK,
347 .svm_port = opts->peer_port,
348 .svm_cid = VMADDR_CID_ANY,
349 },
350 };
351 LIST_HEAD(sockets);
352 struct vsock_stat *st;
353 int fd;
354
355 fd = socket(AF_VSOCK, SOCK_STREAM, 0);
356
357 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
358 perror("bind");
359 exit(EXIT_FAILURE);
360 }
361
362 if (listen(fd, 1) < 0) {
363 perror("listen");
364 exit(EXIT_FAILURE);
365 }
366
367 read_vsock_stat(&sockets);
368
369 check_num_sockets(&sockets, 1);
370 st = find_vsock_stat(&sockets, fd);
371 check_socket_state(st, TCP_LISTEN);
372
373 close(fd);
374 free_sock_stat(&sockets);
375}
376
377static void test_connect_client(const struct test_opts *opts)
378{
379 int fd;
380 LIST_HEAD(sockets);
381 struct vsock_stat *st;
382
383 fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
384 if (fd < 0) {
385 perror("connect");
386 exit(EXIT_FAILURE);
387 }
388
389 read_vsock_stat(&sockets);
390
391 check_num_sockets(&sockets, 1);
392 st = find_vsock_stat(&sockets, fd);
393 check_socket_state(st, TCP_ESTABLISHED);
394
395 control_expectln("DONE");
396 control_writeln("DONE");
397
398 close(fd);
399 free_sock_stat(&sockets);
400}
401
402static void test_connect_server(const struct test_opts *opts)
403{
404 struct vsock_stat *st;
405 LIST_HEAD(sockets);
406 int client_fd;
407
408 client_fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
409 if (client_fd < 0) {
410 perror("accept");
411 exit(EXIT_FAILURE);
412 }
413
414 read_vsock_stat(&sockets);
415
416 check_num_sockets(&sockets, 1);
417 st = find_vsock_stat(&sockets, client_fd);
418 check_socket_state(st, TCP_ESTABLISHED);
419
420 control_writeln("DONE");
421 control_expectln("DONE");
422
423 close(client_fd);
424 free_sock_stat(&sockets);
425}
426
427static struct test_case test_cases[] = {
428 {
429 .name = "No sockets",
430 .run_server = test_no_sockets,
431 },
432 {
433 .name = "Listen socket",
434 .run_server = test_listen_socket_server,
435 },
436 {
437 .name = "Connect",
438 .run_client = test_connect_client,
439 .run_server = test_connect_server,
440 },
441 {},
442};
443
444static const char optstring[] = "";
445static const struct option longopts[] = {
446 {
447 .name = "control-host",
448 .has_arg = required_argument,
449 .val = 'H',
450 },
451 {
452 .name = "control-port",
453 .has_arg = required_argument,
454 .val = 'P',
455 },
456 {
457 .name = "mode",
458 .has_arg = required_argument,
459 .val = 'm',
460 },
461 {
462 .name = "peer-cid",
463 .has_arg = required_argument,
464 .val = 'p',
465 },
466 {
467 .name = "peer-port",
468 .has_arg = required_argument,
469 .val = 'q',
470 },
471 {
472 .name = "list",
473 .has_arg = no_argument,
474 .val = 'l',
475 },
476 {
477 .name = "skip",
478 .has_arg = required_argument,
479 .val = 's',
480 },
481 {
482 .name = "help",
483 .has_arg = no_argument,
484 .val = '?',
485 },
486 {},
487};
488
489static void usage(void)
490{
491 fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--peer-port=<port>] [--list] [--skip=<test_id>]\n"
492 "\n"
493 " Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
494 " Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
495 "\n"
496 "Run vsock_diag.ko tests. Must be launched in both\n"
497 "guest and host. One side must use --mode=client and\n"
498 "the other side must use --mode=server.\n"
499 "\n"
500 "A TCP control socket connection is used to coordinate tests\n"
501 "between the client and the server. The server requires a\n"
502 "listen address and the client requires an address to\n"
503 "connect to.\n"
504 "\n"
505 "The CID of the other side must be given with --peer-cid=<cid>.\n"
506 "\n"
507 "Options:\n"
508 " --help This help message\n"
509 " --control-host <host> Server IP address to connect to\n"
510 " --control-port <port> Server port to listen on/connect to\n"
511 " --mode client|server Server or client mode\n"
512 " --peer-cid <cid> CID of the other side\n"
513 " --peer-port <port> AF_VSOCK port used for the test [default: %d]\n"
514 " --list List of tests that will be executed\n"
515 " --skip <test_id> Test ID to skip;\n"
516 " use multiple --skip options to skip more tests\n",
517 DEFAULT_PEER_PORT
518 );
519 exit(EXIT_FAILURE);
520}
521
522int main(int argc, char **argv)
523{
524 const char *control_host = NULL;
525 const char *control_port = NULL;
526 struct test_opts opts = {
527 .mode = TEST_MODE_UNSET,
528 .peer_cid = VMADDR_CID_ANY,
529 .peer_port = DEFAULT_PEER_PORT,
530 };
531
532 init_signals();
533
534 for (;;) {
535 int opt = getopt_long(argc, argv, optstring, longopts, NULL);
536
537 if (opt == -1)
538 break;
539
540 switch (opt) {
541 case 'H':
542 control_host = optarg;
543 break;
544 case 'm':
545 if (strcmp(optarg, "client") == 0)
546 opts.mode = TEST_MODE_CLIENT;
547 else if (strcmp(optarg, "server") == 0)
548 opts.mode = TEST_MODE_SERVER;
549 else {
550 fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
551 return EXIT_FAILURE;
552 }
553 break;
554 case 'p':
555 opts.peer_cid = parse_cid(optarg);
556 break;
557 case 'q':
558 opts.peer_port = parse_port(optarg);
559 break;
560 case 'P':
561 control_port = optarg;
562 break;
563 case 'l':
564 list_tests(test_cases);
565 break;
566 case 's':
567 skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
568 optarg);
569 break;
570 case '?':
571 default:
572 usage();
573 }
574 }
575
576 if (!control_port)
577 usage();
578 if (opts.mode == TEST_MODE_UNSET)
579 usage();
580 if (opts.peer_cid == VMADDR_CID_ANY)
581 usage();
582
583 if (!control_host) {
584 if (opts.mode != TEST_MODE_SERVER)
585 usage();
586 control_host = "0.0.0.0";
587 }
588
589 control_init(control_host, control_port,
590 opts.mode == TEST_MODE_SERVER);
591
592 run_tests(test_cases, &opts);
593
594 control_cleanup();
595 return EXIT_SUCCESS;
596}
1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * vsock_diag_test - vsock_diag.ko test suite
4 *
5 * Copyright (C) 2017 Red Hat, Inc.
6 *
7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
8 */
9
10#include <getopt.h>
11#include <stdio.h>
12#include <stdlib.h>
13#include <string.h>
14#include <errno.h>
15#include <unistd.h>
16#include <sys/stat.h>
17#include <sys/types.h>
18#include <linux/list.h>
19#include <linux/net.h>
20#include <linux/netlink.h>
21#include <linux/sock_diag.h>
22#include <linux/vm_sockets_diag.h>
23#include <netinet/tcp.h>
24
25#include "timeout.h"
26#include "control.h"
27#include "util.h"
28
29/* Per-socket status */
30struct vsock_stat {
31 struct list_head list;
32 struct vsock_diag_msg msg;
33};
34
35static const char *sock_type_str(int type)
36{
37 switch (type) {
38 case SOCK_DGRAM:
39 return "DGRAM";
40 case SOCK_STREAM:
41 return "STREAM";
42 default:
43 return "INVALID TYPE";
44 }
45}
46
47static const char *sock_state_str(int state)
48{
49 switch (state) {
50 case TCP_CLOSE:
51 return "UNCONNECTED";
52 case TCP_SYN_SENT:
53 return "CONNECTING";
54 case TCP_ESTABLISHED:
55 return "CONNECTED";
56 case TCP_CLOSING:
57 return "DISCONNECTING";
58 case TCP_LISTEN:
59 return "LISTEN";
60 default:
61 return "INVALID STATE";
62 }
63}
64
65static const char *sock_shutdown_str(int shutdown)
66{
67 switch (shutdown) {
68 case 1:
69 return "RCV_SHUTDOWN";
70 case 2:
71 return "SEND_SHUTDOWN";
72 case 3:
73 return "RCV_SHUTDOWN | SEND_SHUTDOWN";
74 default:
75 return "0";
76 }
77}
78
79static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
80{
81 if (cid == VMADDR_CID_ANY)
82 fprintf(fp, "*:");
83 else
84 fprintf(fp, "%u:", cid);
85
86 if (port == VMADDR_PORT_ANY)
87 fprintf(fp, "*");
88 else
89 fprintf(fp, "%u", port);
90}
91
92static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
93{
94 print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
95 fprintf(fp, " ");
96 print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
97 fprintf(fp, " %s %s %s %u\n",
98 sock_type_str(st->msg.vdiag_type),
99 sock_state_str(st->msg.vdiag_state),
100 sock_shutdown_str(st->msg.vdiag_shutdown),
101 st->msg.vdiag_ino);
102}
103
104static void print_vsock_stats(FILE *fp, struct list_head *head)
105{
106 struct vsock_stat *st;
107
108 list_for_each_entry(st, head, list)
109 print_vsock_stat(fp, st);
110}
111
112static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
113{
114 struct vsock_stat *st;
115 struct stat stat;
116
117 if (fstat(fd, &stat) < 0) {
118 perror("fstat");
119 exit(EXIT_FAILURE);
120 }
121
122 list_for_each_entry(st, head, list)
123 if (st->msg.vdiag_ino == stat.st_ino)
124 return st;
125
126 fprintf(stderr, "cannot find fd %d\n", fd);
127 exit(EXIT_FAILURE);
128}
129
130static void check_no_sockets(struct list_head *head)
131{
132 if (!list_empty(head)) {
133 fprintf(stderr, "expected no sockets\n");
134 print_vsock_stats(stderr, head);
135 exit(1);
136 }
137}
138
139static void check_num_sockets(struct list_head *head, int expected)
140{
141 struct list_head *node;
142 int n = 0;
143
144 list_for_each(node, head)
145 n++;
146
147 if (n != expected) {
148 fprintf(stderr, "expected %d sockets, found %d\n",
149 expected, n);
150 print_vsock_stats(stderr, head);
151 exit(EXIT_FAILURE);
152 }
153}
154
155static void check_socket_state(struct vsock_stat *st, __u8 state)
156{
157 if (st->msg.vdiag_state != state) {
158 fprintf(stderr, "expected socket state %#x, got %#x\n",
159 state, st->msg.vdiag_state);
160 exit(EXIT_FAILURE);
161 }
162}
163
164static void send_req(int fd)
165{
166 struct sockaddr_nl nladdr = {
167 .nl_family = AF_NETLINK,
168 };
169 struct {
170 struct nlmsghdr nlh;
171 struct vsock_diag_req vreq;
172 } req = {
173 .nlh = {
174 .nlmsg_len = sizeof(req),
175 .nlmsg_type = SOCK_DIAG_BY_FAMILY,
176 .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
177 },
178 .vreq = {
179 .sdiag_family = AF_VSOCK,
180 .vdiag_states = ~(__u32)0,
181 },
182 };
183 struct iovec iov = {
184 .iov_base = &req,
185 .iov_len = sizeof(req),
186 };
187 struct msghdr msg = {
188 .msg_name = &nladdr,
189 .msg_namelen = sizeof(nladdr),
190 .msg_iov = &iov,
191 .msg_iovlen = 1,
192 };
193
194 for (;;) {
195 if (sendmsg(fd, &msg, 0) < 0) {
196 if (errno == EINTR)
197 continue;
198
199 perror("sendmsg");
200 exit(EXIT_FAILURE);
201 }
202
203 return;
204 }
205}
206
207static ssize_t recv_resp(int fd, void *buf, size_t len)
208{
209 struct sockaddr_nl nladdr = {
210 .nl_family = AF_NETLINK,
211 };
212 struct iovec iov = {
213 .iov_base = buf,
214 .iov_len = len,
215 };
216 struct msghdr msg = {
217 .msg_name = &nladdr,
218 .msg_namelen = sizeof(nladdr),
219 .msg_iov = &iov,
220 .msg_iovlen = 1,
221 };
222 ssize_t ret;
223
224 do {
225 ret = recvmsg(fd, &msg, 0);
226 } while (ret < 0 && errno == EINTR);
227
228 if (ret < 0) {
229 perror("recvmsg");
230 exit(EXIT_FAILURE);
231 }
232
233 return ret;
234}
235
236static void add_vsock_stat(struct list_head *sockets,
237 const struct vsock_diag_msg *resp)
238{
239 struct vsock_stat *st;
240
241 st = malloc(sizeof(*st));
242 if (!st) {
243 perror("malloc");
244 exit(EXIT_FAILURE);
245 }
246
247 st->msg = *resp;
248 list_add_tail(&st->list, sockets);
249}
250
251/*
252 * Read vsock stats into a list.
253 */
254static void read_vsock_stat(struct list_head *sockets)
255{
256 long buf[8192 / sizeof(long)];
257 int fd;
258
259 fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
260 if (fd < 0) {
261 perror("socket");
262 exit(EXIT_FAILURE);
263 }
264
265 send_req(fd);
266
267 for (;;) {
268 const struct nlmsghdr *h;
269 ssize_t ret;
270
271 ret = recv_resp(fd, buf, sizeof(buf));
272 if (ret == 0)
273 goto done;
274 if (ret < sizeof(*h)) {
275 fprintf(stderr, "short read of %zd bytes\n", ret);
276 exit(EXIT_FAILURE);
277 }
278
279 h = (struct nlmsghdr *)buf;
280
281 while (NLMSG_OK(h, ret)) {
282 if (h->nlmsg_type == NLMSG_DONE)
283 goto done;
284
285 if (h->nlmsg_type == NLMSG_ERROR) {
286 const struct nlmsgerr *err = NLMSG_DATA(h);
287
288 if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
289 fprintf(stderr, "NLMSG_ERROR\n");
290 else {
291 errno = -err->error;
292 perror("NLMSG_ERROR");
293 }
294
295 exit(EXIT_FAILURE);
296 }
297
298 if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
299 fprintf(stderr, "unexpected nlmsg_type %#x\n",
300 h->nlmsg_type);
301 exit(EXIT_FAILURE);
302 }
303 if (h->nlmsg_len <
304 NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
305 fprintf(stderr, "short vsock_diag_msg\n");
306 exit(EXIT_FAILURE);
307 }
308
309 add_vsock_stat(sockets, NLMSG_DATA(h));
310
311 h = NLMSG_NEXT(h, ret);
312 }
313 }
314
315done:
316 close(fd);
317}
318
319static void free_sock_stat(struct list_head *sockets)
320{
321 struct vsock_stat *st;
322 struct vsock_stat *next;
323
324 list_for_each_entry_safe(st, next, sockets, list)
325 free(st);
326}
327
328static void test_no_sockets(const struct test_opts *opts)
329{
330 LIST_HEAD(sockets);
331
332 read_vsock_stat(&sockets);
333
334 check_no_sockets(&sockets);
335}
336
337static void test_listen_socket_server(const struct test_opts *opts)
338{
339 union {
340 struct sockaddr sa;
341 struct sockaddr_vm svm;
342 } addr = {
343 .svm = {
344 .svm_family = AF_VSOCK,
345 .svm_port = 1234,
346 .svm_cid = VMADDR_CID_ANY,
347 },
348 };
349 LIST_HEAD(sockets);
350 struct vsock_stat *st;
351 int fd;
352
353 fd = socket(AF_VSOCK, SOCK_STREAM, 0);
354
355 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
356 perror("bind");
357 exit(EXIT_FAILURE);
358 }
359
360 if (listen(fd, 1) < 0) {
361 perror("listen");
362 exit(EXIT_FAILURE);
363 }
364
365 read_vsock_stat(&sockets);
366
367 check_num_sockets(&sockets, 1);
368 st = find_vsock_stat(&sockets, fd);
369 check_socket_state(st, TCP_LISTEN);
370
371 close(fd);
372 free_sock_stat(&sockets);
373}
374
375static void test_connect_client(const struct test_opts *opts)
376{
377 int fd;
378 LIST_HEAD(sockets);
379 struct vsock_stat *st;
380
381 fd = vsock_stream_connect(opts->peer_cid, 1234);
382 if (fd < 0) {
383 perror("connect");
384 exit(EXIT_FAILURE);
385 }
386
387 read_vsock_stat(&sockets);
388
389 check_num_sockets(&sockets, 1);
390 st = find_vsock_stat(&sockets, fd);
391 check_socket_state(st, TCP_ESTABLISHED);
392
393 control_expectln("DONE");
394 control_writeln("DONE");
395
396 close(fd);
397 free_sock_stat(&sockets);
398}
399
400static void test_connect_server(const struct test_opts *opts)
401{
402 struct vsock_stat *st;
403 LIST_HEAD(sockets);
404 int client_fd;
405
406 client_fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
407 if (client_fd < 0) {
408 perror("accept");
409 exit(EXIT_FAILURE);
410 }
411
412 read_vsock_stat(&sockets);
413
414 check_num_sockets(&sockets, 1);
415 st = find_vsock_stat(&sockets, client_fd);
416 check_socket_state(st, TCP_ESTABLISHED);
417
418 control_writeln("DONE");
419 control_expectln("DONE");
420
421 close(client_fd);
422 free_sock_stat(&sockets);
423}
424
425static struct test_case test_cases[] = {
426 {
427 .name = "No sockets",
428 .run_server = test_no_sockets,
429 },
430 {
431 .name = "Listen socket",
432 .run_server = test_listen_socket_server,
433 },
434 {
435 .name = "Connect",
436 .run_client = test_connect_client,
437 .run_server = test_connect_server,
438 },
439 {},
440};
441
442static const char optstring[] = "";
443static const struct option longopts[] = {
444 {
445 .name = "control-host",
446 .has_arg = required_argument,
447 .val = 'H',
448 },
449 {
450 .name = "control-port",
451 .has_arg = required_argument,
452 .val = 'P',
453 },
454 {
455 .name = "mode",
456 .has_arg = required_argument,
457 .val = 'm',
458 },
459 {
460 .name = "peer-cid",
461 .has_arg = required_argument,
462 .val = 'p',
463 },
464 {
465 .name = "list",
466 .has_arg = no_argument,
467 .val = 'l',
468 },
469 {
470 .name = "skip",
471 .has_arg = required_argument,
472 .val = 's',
473 },
474 {
475 .name = "help",
476 .has_arg = no_argument,
477 .val = '?',
478 },
479 {},
480};
481
482static void usage(void)
483{
484 fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
485 "\n"
486 " Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
487 " Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
488 "\n"
489 "Run vsock_diag.ko tests. Must be launched in both\n"
490 "guest and host. One side must use --mode=client and\n"
491 "the other side must use --mode=server.\n"
492 "\n"
493 "A TCP control socket connection is used to coordinate tests\n"
494 "between the client and the server. The server requires a\n"
495 "listen address and the client requires an address to\n"
496 "connect to.\n"
497 "\n"
498 "The CID of the other side must be given with --peer-cid=<cid>.\n"
499 "\n"
500 "Options:\n"
501 " --help This help message\n"
502 " --control-host <host> Server IP address to connect to\n"
503 " --control-port <port> Server port to listen on/connect to\n"
504 " --mode client|server Server or client mode\n"
505 " --peer-cid <cid> CID of the other side\n"
506 " --list List of tests that will be executed\n"
507 " --skip <test_id> Test ID to skip;\n"
508 " use multiple --skip options to skip more tests\n"
509 );
510 exit(EXIT_FAILURE);
511}
512
513int main(int argc, char **argv)
514{
515 const char *control_host = NULL;
516 const char *control_port = NULL;
517 struct test_opts opts = {
518 .mode = TEST_MODE_UNSET,
519 .peer_cid = VMADDR_CID_ANY,
520 };
521
522 init_signals();
523
524 for (;;) {
525 int opt = getopt_long(argc, argv, optstring, longopts, NULL);
526
527 if (opt == -1)
528 break;
529
530 switch (opt) {
531 case 'H':
532 control_host = optarg;
533 break;
534 case 'm':
535 if (strcmp(optarg, "client") == 0)
536 opts.mode = TEST_MODE_CLIENT;
537 else if (strcmp(optarg, "server") == 0)
538 opts.mode = TEST_MODE_SERVER;
539 else {
540 fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
541 return EXIT_FAILURE;
542 }
543 break;
544 case 'p':
545 opts.peer_cid = parse_cid(optarg);
546 break;
547 case 'P':
548 control_port = optarg;
549 break;
550 case 'l':
551 list_tests(test_cases);
552 break;
553 case 's':
554 skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
555 optarg);
556 break;
557 case '?':
558 default:
559 usage();
560 }
561 }
562
563 if (!control_port)
564 usage();
565 if (opts.mode == TEST_MODE_UNSET)
566 usage();
567 if (opts.peer_cid == VMADDR_CID_ANY)
568 usage();
569
570 if (!control_host) {
571 if (opts.mode != TEST_MODE_SERVER)
572 usage();
573 control_host = "0.0.0.0";
574 }
575
576 control_init(control_host, control_port,
577 opts.mode == TEST_MODE_SERVER);
578
579 run_tests(test_cases, &opts);
580
581 control_cleanup();
582 return EXIT_SUCCESS;
583}