Loading...
1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * vsock_diag_test - vsock_diag.ko test suite
4 *
5 * Copyright (C) 2017 Red Hat, Inc.
6 *
7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
8 */
9
10#include <getopt.h>
11#include <stdio.h>
12#include <stdlib.h>
13#include <string.h>
14#include <errno.h>
15#include <unistd.h>
16#include <sys/stat.h>
17#include <sys/types.h>
18#include <linux/list.h>
19#include <linux/net.h>
20#include <linux/netlink.h>
21#include <linux/sock_diag.h>
22#include <linux/vm_sockets_diag.h>
23#include <netinet/tcp.h>
24
25#include "timeout.h"
26#include "control.h"
27#include "util.h"
28
29/* Per-socket status */
30struct vsock_stat {
31 struct list_head list;
32 struct vsock_diag_msg msg;
33};
34
35static const char *sock_type_str(int type)
36{
37 switch (type) {
38 case SOCK_DGRAM:
39 return "DGRAM";
40 case SOCK_STREAM:
41 return "STREAM";
42 case SOCK_SEQPACKET:
43 return "SEQPACKET";
44 default:
45 return "INVALID TYPE";
46 }
47}
48
49static const char *sock_state_str(int state)
50{
51 switch (state) {
52 case TCP_CLOSE:
53 return "UNCONNECTED";
54 case TCP_SYN_SENT:
55 return "CONNECTING";
56 case TCP_ESTABLISHED:
57 return "CONNECTED";
58 case TCP_CLOSING:
59 return "DISCONNECTING";
60 case TCP_LISTEN:
61 return "LISTEN";
62 default:
63 return "INVALID STATE";
64 }
65}
66
67static const char *sock_shutdown_str(int shutdown)
68{
69 switch (shutdown) {
70 case 1:
71 return "RCV_SHUTDOWN";
72 case 2:
73 return "SEND_SHUTDOWN";
74 case 3:
75 return "RCV_SHUTDOWN | SEND_SHUTDOWN";
76 default:
77 return "0";
78 }
79}
80
81static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
82{
83 if (cid == VMADDR_CID_ANY)
84 fprintf(fp, "*:");
85 else
86 fprintf(fp, "%u:", cid);
87
88 if (port == VMADDR_PORT_ANY)
89 fprintf(fp, "*");
90 else
91 fprintf(fp, "%u", port);
92}
93
94static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
95{
96 print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
97 fprintf(fp, " ");
98 print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
99 fprintf(fp, " %s %s %s %u\n",
100 sock_type_str(st->msg.vdiag_type),
101 sock_state_str(st->msg.vdiag_state),
102 sock_shutdown_str(st->msg.vdiag_shutdown),
103 st->msg.vdiag_ino);
104}
105
106static void print_vsock_stats(FILE *fp, struct list_head *head)
107{
108 struct vsock_stat *st;
109
110 list_for_each_entry(st, head, list)
111 print_vsock_stat(fp, st);
112}
113
114static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
115{
116 struct vsock_stat *st;
117 struct stat stat;
118
119 if (fstat(fd, &stat) < 0) {
120 perror("fstat");
121 exit(EXIT_FAILURE);
122 }
123
124 list_for_each_entry(st, head, list)
125 if (st->msg.vdiag_ino == stat.st_ino)
126 return st;
127
128 fprintf(stderr, "cannot find fd %d\n", fd);
129 exit(EXIT_FAILURE);
130}
131
132static void check_no_sockets(struct list_head *head)
133{
134 if (!list_empty(head)) {
135 fprintf(stderr, "expected no sockets\n");
136 print_vsock_stats(stderr, head);
137 exit(1);
138 }
139}
140
141static void check_num_sockets(struct list_head *head, int expected)
142{
143 struct list_head *node;
144 int n = 0;
145
146 list_for_each(node, head)
147 n++;
148
149 if (n != expected) {
150 fprintf(stderr, "expected %d sockets, found %d\n",
151 expected, n);
152 print_vsock_stats(stderr, head);
153 exit(EXIT_FAILURE);
154 }
155}
156
157static void check_socket_state(struct vsock_stat *st, __u8 state)
158{
159 if (st->msg.vdiag_state != state) {
160 fprintf(stderr, "expected socket state %#x, got %#x\n",
161 state, st->msg.vdiag_state);
162 exit(EXIT_FAILURE);
163 }
164}
165
166static void send_req(int fd)
167{
168 struct sockaddr_nl nladdr = {
169 .nl_family = AF_NETLINK,
170 };
171 struct {
172 struct nlmsghdr nlh;
173 struct vsock_diag_req vreq;
174 } req = {
175 .nlh = {
176 .nlmsg_len = sizeof(req),
177 .nlmsg_type = SOCK_DIAG_BY_FAMILY,
178 .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
179 },
180 .vreq = {
181 .sdiag_family = AF_VSOCK,
182 .vdiag_states = ~(__u32)0,
183 },
184 };
185 struct iovec iov = {
186 .iov_base = &req,
187 .iov_len = sizeof(req),
188 };
189 struct msghdr msg = {
190 .msg_name = &nladdr,
191 .msg_namelen = sizeof(nladdr),
192 .msg_iov = &iov,
193 .msg_iovlen = 1,
194 };
195
196 for (;;) {
197 if (sendmsg(fd, &msg, 0) < 0) {
198 if (errno == EINTR)
199 continue;
200
201 perror("sendmsg");
202 exit(EXIT_FAILURE);
203 }
204
205 return;
206 }
207}
208
209static ssize_t recv_resp(int fd, void *buf, size_t len)
210{
211 struct sockaddr_nl nladdr = {
212 .nl_family = AF_NETLINK,
213 };
214 struct iovec iov = {
215 .iov_base = buf,
216 .iov_len = len,
217 };
218 struct msghdr msg = {
219 .msg_name = &nladdr,
220 .msg_namelen = sizeof(nladdr),
221 .msg_iov = &iov,
222 .msg_iovlen = 1,
223 };
224 ssize_t ret;
225
226 do {
227 ret = recvmsg(fd, &msg, 0);
228 } while (ret < 0 && errno == EINTR);
229
230 if (ret < 0) {
231 perror("recvmsg");
232 exit(EXIT_FAILURE);
233 }
234
235 return ret;
236}
237
238static void add_vsock_stat(struct list_head *sockets,
239 const struct vsock_diag_msg *resp)
240{
241 struct vsock_stat *st;
242
243 st = malloc(sizeof(*st));
244 if (!st) {
245 perror("malloc");
246 exit(EXIT_FAILURE);
247 }
248
249 st->msg = *resp;
250 list_add_tail(&st->list, sockets);
251}
252
253/*
254 * Read vsock stats into a list.
255 */
256static void read_vsock_stat(struct list_head *sockets)
257{
258 long buf[8192 / sizeof(long)];
259 int fd;
260
261 fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
262 if (fd < 0) {
263 perror("socket");
264 exit(EXIT_FAILURE);
265 }
266
267 send_req(fd);
268
269 for (;;) {
270 const struct nlmsghdr *h;
271 ssize_t ret;
272
273 ret = recv_resp(fd, buf, sizeof(buf));
274 if (ret == 0)
275 goto done;
276 if (ret < sizeof(*h)) {
277 fprintf(stderr, "short read of %zd bytes\n", ret);
278 exit(EXIT_FAILURE);
279 }
280
281 h = (struct nlmsghdr *)buf;
282
283 while (NLMSG_OK(h, ret)) {
284 if (h->nlmsg_type == NLMSG_DONE)
285 goto done;
286
287 if (h->nlmsg_type == NLMSG_ERROR) {
288 const struct nlmsgerr *err = NLMSG_DATA(h);
289
290 if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
291 fprintf(stderr, "NLMSG_ERROR\n");
292 else {
293 errno = -err->error;
294 perror("NLMSG_ERROR");
295 }
296
297 exit(EXIT_FAILURE);
298 }
299
300 if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
301 fprintf(stderr, "unexpected nlmsg_type %#x\n",
302 h->nlmsg_type);
303 exit(EXIT_FAILURE);
304 }
305 if (h->nlmsg_len <
306 NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
307 fprintf(stderr, "short vsock_diag_msg\n");
308 exit(EXIT_FAILURE);
309 }
310
311 add_vsock_stat(sockets, NLMSG_DATA(h));
312
313 h = NLMSG_NEXT(h, ret);
314 }
315 }
316
317done:
318 close(fd);
319}
320
321static void free_sock_stat(struct list_head *sockets)
322{
323 struct vsock_stat *st;
324 struct vsock_stat *next;
325
326 list_for_each_entry_safe(st, next, sockets, list)
327 free(st);
328}
329
330static void test_no_sockets(const struct test_opts *opts)
331{
332 LIST_HEAD(sockets);
333
334 read_vsock_stat(&sockets);
335
336 check_no_sockets(&sockets);
337}
338
339static void test_listen_socket_server(const struct test_opts *opts)
340{
341 union {
342 struct sockaddr sa;
343 struct sockaddr_vm svm;
344 } addr = {
345 .svm = {
346 .svm_family = AF_VSOCK,
347 .svm_port = opts->peer_port,
348 .svm_cid = VMADDR_CID_ANY,
349 },
350 };
351 LIST_HEAD(sockets);
352 struct vsock_stat *st;
353 int fd;
354
355 fd = socket(AF_VSOCK, SOCK_STREAM, 0);
356
357 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
358 perror("bind");
359 exit(EXIT_FAILURE);
360 }
361
362 if (listen(fd, 1) < 0) {
363 perror("listen");
364 exit(EXIT_FAILURE);
365 }
366
367 read_vsock_stat(&sockets);
368
369 check_num_sockets(&sockets, 1);
370 st = find_vsock_stat(&sockets, fd);
371 check_socket_state(st, TCP_LISTEN);
372
373 close(fd);
374 free_sock_stat(&sockets);
375}
376
377static void test_connect_client(const struct test_opts *opts)
378{
379 int fd;
380 LIST_HEAD(sockets);
381 struct vsock_stat *st;
382
383 fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
384 if (fd < 0) {
385 perror("connect");
386 exit(EXIT_FAILURE);
387 }
388
389 read_vsock_stat(&sockets);
390
391 check_num_sockets(&sockets, 1);
392 st = find_vsock_stat(&sockets, fd);
393 check_socket_state(st, TCP_ESTABLISHED);
394
395 control_expectln("DONE");
396 control_writeln("DONE");
397
398 close(fd);
399 free_sock_stat(&sockets);
400}
401
402static void test_connect_server(const struct test_opts *opts)
403{
404 struct vsock_stat *st;
405 LIST_HEAD(sockets);
406 int client_fd;
407
408 client_fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
409 if (client_fd < 0) {
410 perror("accept");
411 exit(EXIT_FAILURE);
412 }
413
414 read_vsock_stat(&sockets);
415
416 check_num_sockets(&sockets, 1);
417 st = find_vsock_stat(&sockets, client_fd);
418 check_socket_state(st, TCP_ESTABLISHED);
419
420 control_writeln("DONE");
421 control_expectln("DONE");
422
423 close(client_fd);
424 free_sock_stat(&sockets);
425}
426
427static struct test_case test_cases[] = {
428 {
429 .name = "No sockets",
430 .run_server = test_no_sockets,
431 },
432 {
433 .name = "Listen socket",
434 .run_server = test_listen_socket_server,
435 },
436 {
437 .name = "Connect",
438 .run_client = test_connect_client,
439 .run_server = test_connect_server,
440 },
441 {},
442};
443
444static const char optstring[] = "";
445static const struct option longopts[] = {
446 {
447 .name = "control-host",
448 .has_arg = required_argument,
449 .val = 'H',
450 },
451 {
452 .name = "control-port",
453 .has_arg = required_argument,
454 .val = 'P',
455 },
456 {
457 .name = "mode",
458 .has_arg = required_argument,
459 .val = 'm',
460 },
461 {
462 .name = "peer-cid",
463 .has_arg = required_argument,
464 .val = 'p',
465 },
466 {
467 .name = "peer-port",
468 .has_arg = required_argument,
469 .val = 'q',
470 },
471 {
472 .name = "list",
473 .has_arg = no_argument,
474 .val = 'l',
475 },
476 {
477 .name = "skip",
478 .has_arg = required_argument,
479 .val = 's',
480 },
481 {
482 .name = "help",
483 .has_arg = no_argument,
484 .val = '?',
485 },
486 {},
487};
488
489static void usage(void)
490{
491 fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--peer-port=<port>] [--list] [--skip=<test_id>]\n"
492 "\n"
493 " Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
494 " Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
495 "\n"
496 "Run vsock_diag.ko tests. Must be launched in both\n"
497 "guest and host. One side must use --mode=client and\n"
498 "the other side must use --mode=server.\n"
499 "\n"
500 "A TCP control socket connection is used to coordinate tests\n"
501 "between the client and the server. The server requires a\n"
502 "listen address and the client requires an address to\n"
503 "connect to.\n"
504 "\n"
505 "The CID of the other side must be given with --peer-cid=<cid>.\n"
506 "\n"
507 "Options:\n"
508 " --help This help message\n"
509 " --control-host <host> Server IP address to connect to\n"
510 " --control-port <port> Server port to listen on/connect to\n"
511 " --mode client|server Server or client mode\n"
512 " --peer-cid <cid> CID of the other side\n"
513 " --peer-port <port> AF_VSOCK port used for the test [default: %d]\n"
514 " --list List of tests that will be executed\n"
515 " --skip <test_id> Test ID to skip;\n"
516 " use multiple --skip options to skip more tests\n",
517 DEFAULT_PEER_PORT
518 );
519 exit(EXIT_FAILURE);
520}
521
522int main(int argc, char **argv)
523{
524 const char *control_host = NULL;
525 const char *control_port = NULL;
526 struct test_opts opts = {
527 .mode = TEST_MODE_UNSET,
528 .peer_cid = VMADDR_CID_ANY,
529 .peer_port = DEFAULT_PEER_PORT,
530 };
531
532 init_signals();
533
534 for (;;) {
535 int opt = getopt_long(argc, argv, optstring, longopts, NULL);
536
537 if (opt == -1)
538 break;
539
540 switch (opt) {
541 case 'H':
542 control_host = optarg;
543 break;
544 case 'm':
545 if (strcmp(optarg, "client") == 0)
546 opts.mode = TEST_MODE_CLIENT;
547 else if (strcmp(optarg, "server") == 0)
548 opts.mode = TEST_MODE_SERVER;
549 else {
550 fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
551 return EXIT_FAILURE;
552 }
553 break;
554 case 'p':
555 opts.peer_cid = parse_cid(optarg);
556 break;
557 case 'q':
558 opts.peer_port = parse_port(optarg);
559 break;
560 case 'P':
561 control_port = optarg;
562 break;
563 case 'l':
564 list_tests(test_cases);
565 break;
566 case 's':
567 skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
568 optarg);
569 break;
570 case '?':
571 default:
572 usage();
573 }
574 }
575
576 if (!control_port)
577 usage();
578 if (opts.mode == TEST_MODE_UNSET)
579 usage();
580 if (opts.peer_cid == VMADDR_CID_ANY)
581 usage();
582
583 if (!control_host) {
584 if (opts.mode != TEST_MODE_SERVER)
585 usage();
586 control_host = "0.0.0.0";
587 }
588
589 control_init(control_host, control_port,
590 opts.mode == TEST_MODE_SERVER);
591
592 run_tests(test_cases, &opts);
593
594 control_cleanup();
595 return EXIT_SUCCESS;
596}
1/*
2 * vsock_diag_test - vsock_diag.ko test suite
3 *
4 * Copyright (C) 2017 Red Hat, Inc.
5 *
6 * Author: Stefan Hajnoczi <stefanha@redhat.com>
7 *
8 * This program is free software; you can redistribute it and/or
9 * modify it under the terms of the GNU General Public License
10 * as published by the Free Software Foundation; version 2
11 * of the License.
12 */
13
14#include <getopt.h>
15#include <stdio.h>
16#include <stdbool.h>
17#include <stdlib.h>
18#include <string.h>
19#include <errno.h>
20#include <unistd.h>
21#include <signal.h>
22#include <sys/socket.h>
23#include <sys/stat.h>
24#include <sys/types.h>
25#include <linux/list.h>
26#include <linux/net.h>
27#include <linux/netlink.h>
28#include <linux/sock_diag.h>
29#include <netinet/tcp.h>
30
31#include "../../../include/uapi/linux/vm_sockets.h"
32#include "../../../include/uapi/linux/vm_sockets_diag.h"
33
34#include "timeout.h"
35#include "control.h"
36
37enum test_mode {
38 TEST_MODE_UNSET,
39 TEST_MODE_CLIENT,
40 TEST_MODE_SERVER
41};
42
43/* Per-socket status */
44struct vsock_stat {
45 struct list_head list;
46 struct vsock_diag_msg msg;
47};
48
49static const char *sock_type_str(int type)
50{
51 switch (type) {
52 case SOCK_DGRAM:
53 return "DGRAM";
54 case SOCK_STREAM:
55 return "STREAM";
56 default:
57 return "INVALID TYPE";
58 }
59}
60
61static const char *sock_state_str(int state)
62{
63 switch (state) {
64 case TCP_CLOSE:
65 return "UNCONNECTED";
66 case TCP_SYN_SENT:
67 return "CONNECTING";
68 case TCP_ESTABLISHED:
69 return "CONNECTED";
70 case TCP_CLOSING:
71 return "DISCONNECTING";
72 case TCP_LISTEN:
73 return "LISTEN";
74 default:
75 return "INVALID STATE";
76 }
77}
78
79static const char *sock_shutdown_str(int shutdown)
80{
81 switch (shutdown) {
82 case 1:
83 return "RCV_SHUTDOWN";
84 case 2:
85 return "SEND_SHUTDOWN";
86 case 3:
87 return "RCV_SHUTDOWN | SEND_SHUTDOWN";
88 default:
89 return "0";
90 }
91}
92
93static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
94{
95 if (cid == VMADDR_CID_ANY)
96 fprintf(fp, "*:");
97 else
98 fprintf(fp, "%u:", cid);
99
100 if (port == VMADDR_PORT_ANY)
101 fprintf(fp, "*");
102 else
103 fprintf(fp, "%u", port);
104}
105
106static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
107{
108 print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
109 fprintf(fp, " ");
110 print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
111 fprintf(fp, " %s %s %s %u\n",
112 sock_type_str(st->msg.vdiag_type),
113 sock_state_str(st->msg.vdiag_state),
114 sock_shutdown_str(st->msg.vdiag_shutdown),
115 st->msg.vdiag_ino);
116}
117
118static void print_vsock_stats(FILE *fp, struct list_head *head)
119{
120 struct vsock_stat *st;
121
122 list_for_each_entry(st, head, list)
123 print_vsock_stat(fp, st);
124}
125
126static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
127{
128 struct vsock_stat *st;
129 struct stat stat;
130
131 if (fstat(fd, &stat) < 0) {
132 perror("fstat");
133 exit(EXIT_FAILURE);
134 }
135
136 list_for_each_entry(st, head, list)
137 if (st->msg.vdiag_ino == stat.st_ino)
138 return st;
139
140 fprintf(stderr, "cannot find fd %d\n", fd);
141 exit(EXIT_FAILURE);
142}
143
144static void check_no_sockets(struct list_head *head)
145{
146 if (!list_empty(head)) {
147 fprintf(stderr, "expected no sockets\n");
148 print_vsock_stats(stderr, head);
149 exit(1);
150 }
151}
152
153static void check_num_sockets(struct list_head *head, int expected)
154{
155 struct list_head *node;
156 int n = 0;
157
158 list_for_each(node, head)
159 n++;
160
161 if (n != expected) {
162 fprintf(stderr, "expected %d sockets, found %d\n",
163 expected, n);
164 print_vsock_stats(stderr, head);
165 exit(EXIT_FAILURE);
166 }
167}
168
169static void check_socket_state(struct vsock_stat *st, __u8 state)
170{
171 if (st->msg.vdiag_state != state) {
172 fprintf(stderr, "expected socket state %#x, got %#x\n",
173 state, st->msg.vdiag_state);
174 exit(EXIT_FAILURE);
175 }
176}
177
178static void send_req(int fd)
179{
180 struct sockaddr_nl nladdr = {
181 .nl_family = AF_NETLINK,
182 };
183 struct {
184 struct nlmsghdr nlh;
185 struct vsock_diag_req vreq;
186 } req = {
187 .nlh = {
188 .nlmsg_len = sizeof(req),
189 .nlmsg_type = SOCK_DIAG_BY_FAMILY,
190 .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
191 },
192 .vreq = {
193 .sdiag_family = AF_VSOCK,
194 .vdiag_states = ~(__u32)0,
195 },
196 };
197 struct iovec iov = {
198 .iov_base = &req,
199 .iov_len = sizeof(req),
200 };
201 struct msghdr msg = {
202 .msg_name = &nladdr,
203 .msg_namelen = sizeof(nladdr),
204 .msg_iov = &iov,
205 .msg_iovlen = 1,
206 };
207
208 for (;;) {
209 if (sendmsg(fd, &msg, 0) < 0) {
210 if (errno == EINTR)
211 continue;
212
213 perror("sendmsg");
214 exit(EXIT_FAILURE);
215 }
216
217 return;
218 }
219}
220
221static ssize_t recv_resp(int fd, void *buf, size_t len)
222{
223 struct sockaddr_nl nladdr = {
224 .nl_family = AF_NETLINK,
225 };
226 struct iovec iov = {
227 .iov_base = buf,
228 .iov_len = len,
229 };
230 struct msghdr msg = {
231 .msg_name = &nladdr,
232 .msg_namelen = sizeof(nladdr),
233 .msg_iov = &iov,
234 .msg_iovlen = 1,
235 };
236 ssize_t ret;
237
238 do {
239 ret = recvmsg(fd, &msg, 0);
240 } while (ret < 0 && errno == EINTR);
241
242 if (ret < 0) {
243 perror("recvmsg");
244 exit(EXIT_FAILURE);
245 }
246
247 return ret;
248}
249
250static void add_vsock_stat(struct list_head *sockets,
251 const struct vsock_diag_msg *resp)
252{
253 struct vsock_stat *st;
254
255 st = malloc(sizeof(*st));
256 if (!st) {
257 perror("malloc");
258 exit(EXIT_FAILURE);
259 }
260
261 st->msg = *resp;
262 list_add_tail(&st->list, sockets);
263}
264
265/*
266 * Read vsock stats into a list.
267 */
268static void read_vsock_stat(struct list_head *sockets)
269{
270 long buf[8192 / sizeof(long)];
271 int fd;
272
273 fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
274 if (fd < 0) {
275 perror("socket");
276 exit(EXIT_FAILURE);
277 }
278
279 send_req(fd);
280
281 for (;;) {
282 const struct nlmsghdr *h;
283 ssize_t ret;
284
285 ret = recv_resp(fd, buf, sizeof(buf));
286 if (ret == 0)
287 goto done;
288 if (ret < sizeof(*h)) {
289 fprintf(stderr, "short read of %zd bytes\n", ret);
290 exit(EXIT_FAILURE);
291 }
292
293 h = (struct nlmsghdr *)buf;
294
295 while (NLMSG_OK(h, ret)) {
296 if (h->nlmsg_type == NLMSG_DONE)
297 goto done;
298
299 if (h->nlmsg_type == NLMSG_ERROR) {
300 const struct nlmsgerr *err = NLMSG_DATA(h);
301
302 if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
303 fprintf(stderr, "NLMSG_ERROR\n");
304 else {
305 errno = -err->error;
306 perror("NLMSG_ERROR");
307 }
308
309 exit(EXIT_FAILURE);
310 }
311
312 if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
313 fprintf(stderr, "unexpected nlmsg_type %#x\n",
314 h->nlmsg_type);
315 exit(EXIT_FAILURE);
316 }
317 if (h->nlmsg_len <
318 NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
319 fprintf(stderr, "short vsock_diag_msg\n");
320 exit(EXIT_FAILURE);
321 }
322
323 add_vsock_stat(sockets, NLMSG_DATA(h));
324
325 h = NLMSG_NEXT(h, ret);
326 }
327 }
328
329done:
330 close(fd);
331}
332
333static void free_sock_stat(struct list_head *sockets)
334{
335 struct vsock_stat *st;
336 struct vsock_stat *next;
337
338 list_for_each_entry_safe(st, next, sockets, list)
339 free(st);
340}
341
342static void test_no_sockets(unsigned int peer_cid)
343{
344 LIST_HEAD(sockets);
345
346 read_vsock_stat(&sockets);
347
348 check_no_sockets(&sockets);
349
350 free_sock_stat(&sockets);
351}
352
353static void test_listen_socket_server(unsigned int peer_cid)
354{
355 union {
356 struct sockaddr sa;
357 struct sockaddr_vm svm;
358 } addr = {
359 .svm = {
360 .svm_family = AF_VSOCK,
361 .svm_port = 1234,
362 .svm_cid = VMADDR_CID_ANY,
363 },
364 };
365 LIST_HEAD(sockets);
366 struct vsock_stat *st;
367 int fd;
368
369 fd = socket(AF_VSOCK, SOCK_STREAM, 0);
370
371 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
372 perror("bind");
373 exit(EXIT_FAILURE);
374 }
375
376 if (listen(fd, 1) < 0) {
377 perror("listen");
378 exit(EXIT_FAILURE);
379 }
380
381 read_vsock_stat(&sockets);
382
383 check_num_sockets(&sockets, 1);
384 st = find_vsock_stat(&sockets, fd);
385 check_socket_state(st, TCP_LISTEN);
386
387 close(fd);
388 free_sock_stat(&sockets);
389}
390
391static void test_connect_client(unsigned int peer_cid)
392{
393 union {
394 struct sockaddr sa;
395 struct sockaddr_vm svm;
396 } addr = {
397 .svm = {
398 .svm_family = AF_VSOCK,
399 .svm_port = 1234,
400 .svm_cid = peer_cid,
401 },
402 };
403 int fd;
404 int ret;
405 LIST_HEAD(sockets);
406 struct vsock_stat *st;
407
408 control_expectln("LISTENING");
409
410 fd = socket(AF_VSOCK, SOCK_STREAM, 0);
411
412 timeout_begin(TIMEOUT);
413 do {
414 ret = connect(fd, &addr.sa, sizeof(addr.svm));
415 timeout_check("connect");
416 } while (ret < 0 && errno == EINTR);
417 timeout_end();
418
419 if (ret < 0) {
420 perror("connect");
421 exit(EXIT_FAILURE);
422 }
423
424 read_vsock_stat(&sockets);
425
426 check_num_sockets(&sockets, 1);
427 st = find_vsock_stat(&sockets, fd);
428 check_socket_state(st, TCP_ESTABLISHED);
429
430 control_expectln("DONE");
431 control_writeln("DONE");
432
433 close(fd);
434 free_sock_stat(&sockets);
435}
436
437static void test_connect_server(unsigned int peer_cid)
438{
439 union {
440 struct sockaddr sa;
441 struct sockaddr_vm svm;
442 } addr = {
443 .svm = {
444 .svm_family = AF_VSOCK,
445 .svm_port = 1234,
446 .svm_cid = VMADDR_CID_ANY,
447 },
448 };
449 union {
450 struct sockaddr sa;
451 struct sockaddr_vm svm;
452 } clientaddr;
453 socklen_t clientaddr_len = sizeof(clientaddr.svm);
454 LIST_HEAD(sockets);
455 struct vsock_stat *st;
456 int fd;
457 int client_fd;
458
459 fd = socket(AF_VSOCK, SOCK_STREAM, 0);
460
461 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
462 perror("bind");
463 exit(EXIT_FAILURE);
464 }
465
466 if (listen(fd, 1) < 0) {
467 perror("listen");
468 exit(EXIT_FAILURE);
469 }
470
471 control_writeln("LISTENING");
472
473 timeout_begin(TIMEOUT);
474 do {
475 client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
476 timeout_check("accept");
477 } while (client_fd < 0 && errno == EINTR);
478 timeout_end();
479
480 if (client_fd < 0) {
481 perror("accept");
482 exit(EXIT_FAILURE);
483 }
484 if (clientaddr.sa.sa_family != AF_VSOCK) {
485 fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
486 clientaddr.sa.sa_family);
487 exit(EXIT_FAILURE);
488 }
489 if (clientaddr.svm.svm_cid != peer_cid) {
490 fprintf(stderr, "expected peer CID %u from accept(2), got %u\n",
491 peer_cid, clientaddr.svm.svm_cid);
492 exit(EXIT_FAILURE);
493 }
494
495 read_vsock_stat(&sockets);
496
497 check_num_sockets(&sockets, 2);
498 find_vsock_stat(&sockets, fd);
499 st = find_vsock_stat(&sockets, client_fd);
500 check_socket_state(st, TCP_ESTABLISHED);
501
502 control_writeln("DONE");
503 control_expectln("DONE");
504
505 close(client_fd);
506 close(fd);
507 free_sock_stat(&sockets);
508}
509
510static struct {
511 const char *name;
512 void (*run_client)(unsigned int peer_cid);
513 void (*run_server)(unsigned int peer_cid);
514} test_cases[] = {
515 {
516 .name = "No sockets",
517 .run_server = test_no_sockets,
518 },
519 {
520 .name = "Listen socket",
521 .run_server = test_listen_socket_server,
522 },
523 {
524 .name = "Connect",
525 .run_client = test_connect_client,
526 .run_server = test_connect_server,
527 },
528 {},
529};
530
531static void init_signals(void)
532{
533 struct sigaction act = {
534 .sa_handler = sigalrm,
535 };
536
537 sigaction(SIGALRM, &act, NULL);
538 signal(SIGPIPE, SIG_IGN);
539}
540
541static unsigned int parse_cid(const char *str)
542{
543 char *endptr = NULL;
544 unsigned long int n;
545
546 errno = 0;
547 n = strtoul(str, &endptr, 10);
548 if (errno || *endptr != '\0') {
549 fprintf(stderr, "malformed CID \"%s\"\n", str);
550 exit(EXIT_FAILURE);
551 }
552 return n;
553}
554
555static const char optstring[] = "";
556static const struct option longopts[] = {
557 {
558 .name = "control-host",
559 .has_arg = required_argument,
560 .val = 'H',
561 },
562 {
563 .name = "control-port",
564 .has_arg = required_argument,
565 .val = 'P',
566 },
567 {
568 .name = "mode",
569 .has_arg = required_argument,
570 .val = 'm',
571 },
572 {
573 .name = "peer-cid",
574 .has_arg = required_argument,
575 .val = 'p',
576 },
577 {
578 .name = "help",
579 .has_arg = no_argument,
580 .val = '?',
581 },
582 {},
583};
584
585static void usage(void)
586{
587 fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n"
588 "\n"
589 " Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
590 " Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
591 "\n"
592 "Run vsock_diag.ko tests. Must be launched in both\n"
593 "guest and host. One side must use --mode=client and\n"
594 "the other side must use --mode=server.\n"
595 "\n"
596 "A TCP control socket connection is used to coordinate tests\n"
597 "between the client and the server. The server requires a\n"
598 "listen address and the client requires an address to\n"
599 "connect to.\n"
600 "\n"
601 "The CID of the other side must be given with --peer-cid=<cid>.\n");
602 exit(EXIT_FAILURE);
603}
604
605int main(int argc, char **argv)
606{
607 const char *control_host = NULL;
608 const char *control_port = NULL;
609 int mode = TEST_MODE_UNSET;
610 unsigned int peer_cid = VMADDR_CID_ANY;
611 int i;
612
613 init_signals();
614
615 for (;;) {
616 int opt = getopt_long(argc, argv, optstring, longopts, NULL);
617
618 if (opt == -1)
619 break;
620
621 switch (opt) {
622 case 'H':
623 control_host = optarg;
624 break;
625 case 'm':
626 if (strcmp(optarg, "client") == 0)
627 mode = TEST_MODE_CLIENT;
628 else if (strcmp(optarg, "server") == 0)
629 mode = TEST_MODE_SERVER;
630 else {
631 fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
632 return EXIT_FAILURE;
633 }
634 break;
635 case 'p':
636 peer_cid = parse_cid(optarg);
637 break;
638 case 'P':
639 control_port = optarg;
640 break;
641 case '?':
642 default:
643 usage();
644 }
645 }
646
647 if (!control_port)
648 usage();
649 if (mode == TEST_MODE_UNSET)
650 usage();
651 if (peer_cid == VMADDR_CID_ANY)
652 usage();
653
654 if (!control_host) {
655 if (mode != TEST_MODE_SERVER)
656 usage();
657 control_host = "0.0.0.0";
658 }
659
660 control_init(control_host, control_port, mode == TEST_MODE_SERVER);
661
662 for (i = 0; test_cases[i].name; i++) {
663 void (*run)(unsigned int peer_cid);
664
665 printf("%s...", test_cases[i].name);
666 fflush(stdout);
667
668 if (mode == TEST_MODE_CLIENT)
669 run = test_cases[i].run_client;
670 else
671 run = test_cases[i].run_server;
672
673 if (run)
674 run(peer_cid);
675
676 printf("ok\n");
677 }
678
679 control_cleanup();
680 return EXIT_SUCCESS;
681}