Loading...
1// SPDX-License-Identifier: GPL-2.0
2#include <sys/un.h>
3
4#include "test_progs.h"
5
6#include "connect_unix_prog.skel.h"
7#include "sendmsg_unix_prog.skel.h"
8#include "recvmsg_unix_prog.skel.h"
9#include "getsockname_unix_prog.skel.h"
10#include "getpeername_unix_prog.skel.h"
11#include "network_helpers.h"
12
13#define SERVUN_ADDRESS "bpf_cgroup_unix_test"
14#define SERVUN_REWRITE_ADDRESS "bpf_cgroup_unix_test_rewrite"
15#define SRCUN_ADDRESS "bpf_cgroup_unix_test_src"
16
17enum sock_addr_test_type {
18 SOCK_ADDR_TEST_BIND,
19 SOCK_ADDR_TEST_CONNECT,
20 SOCK_ADDR_TEST_SENDMSG,
21 SOCK_ADDR_TEST_RECVMSG,
22 SOCK_ADDR_TEST_GETSOCKNAME,
23 SOCK_ADDR_TEST_GETPEERNAME,
24};
25
26typedef void *(*load_fn)(int cgroup_fd);
27typedef void (*destroy_fn)(void *skel);
28
29struct sock_addr_test {
30 enum sock_addr_test_type type;
31 const char *name;
32 /* BPF prog properties */
33 load_fn loadfn;
34 destroy_fn destroyfn;
35 /* Socket properties */
36 int socket_family;
37 int socket_type;
38 /* IP:port pairs for BPF prog to override */
39 const char *requested_addr;
40 unsigned short requested_port;
41 const char *expected_addr;
42 unsigned short expected_port;
43 const char *expected_src_addr;
44};
45
46static void *connect_unix_prog_load(int cgroup_fd)
47{
48 struct connect_unix_prog *skel;
49
50 skel = connect_unix_prog__open_and_load();
51 if (!ASSERT_OK_PTR(skel, "skel_open"))
52 goto cleanup;
53
54 skel->links.connect_unix_prog = bpf_program__attach_cgroup(
55 skel->progs.connect_unix_prog, cgroup_fd);
56 if (!ASSERT_OK_PTR(skel->links.connect_unix_prog, "prog_attach"))
57 goto cleanup;
58
59 return skel;
60cleanup:
61 connect_unix_prog__destroy(skel);
62 return NULL;
63}
64
65static void connect_unix_prog_destroy(void *skel)
66{
67 connect_unix_prog__destroy(skel);
68}
69
70static void *sendmsg_unix_prog_load(int cgroup_fd)
71{
72 struct sendmsg_unix_prog *skel;
73
74 skel = sendmsg_unix_prog__open_and_load();
75 if (!ASSERT_OK_PTR(skel, "skel_open"))
76 goto cleanup;
77
78 skel->links.sendmsg_unix_prog = bpf_program__attach_cgroup(
79 skel->progs.sendmsg_unix_prog, cgroup_fd);
80 if (!ASSERT_OK_PTR(skel->links.sendmsg_unix_prog, "prog_attach"))
81 goto cleanup;
82
83 return skel;
84cleanup:
85 sendmsg_unix_prog__destroy(skel);
86 return NULL;
87}
88
89static void sendmsg_unix_prog_destroy(void *skel)
90{
91 sendmsg_unix_prog__destroy(skel);
92}
93
94static void *recvmsg_unix_prog_load(int cgroup_fd)
95{
96 struct recvmsg_unix_prog *skel;
97
98 skel = recvmsg_unix_prog__open_and_load();
99 if (!ASSERT_OK_PTR(skel, "skel_open"))
100 goto cleanup;
101
102 skel->links.recvmsg_unix_prog = bpf_program__attach_cgroup(
103 skel->progs.recvmsg_unix_prog, cgroup_fd);
104 if (!ASSERT_OK_PTR(skel->links.recvmsg_unix_prog, "prog_attach"))
105 goto cleanup;
106
107 return skel;
108cleanup:
109 recvmsg_unix_prog__destroy(skel);
110 return NULL;
111}
112
113static void recvmsg_unix_prog_destroy(void *skel)
114{
115 recvmsg_unix_prog__destroy(skel);
116}
117
118static void *getsockname_unix_prog_load(int cgroup_fd)
119{
120 struct getsockname_unix_prog *skel;
121
122 skel = getsockname_unix_prog__open_and_load();
123 if (!ASSERT_OK_PTR(skel, "skel_open"))
124 goto cleanup;
125
126 skel->links.getsockname_unix_prog = bpf_program__attach_cgroup(
127 skel->progs.getsockname_unix_prog, cgroup_fd);
128 if (!ASSERT_OK_PTR(skel->links.getsockname_unix_prog, "prog_attach"))
129 goto cleanup;
130
131 return skel;
132cleanup:
133 getsockname_unix_prog__destroy(skel);
134 return NULL;
135}
136
137static void getsockname_unix_prog_destroy(void *skel)
138{
139 getsockname_unix_prog__destroy(skel);
140}
141
142static void *getpeername_unix_prog_load(int cgroup_fd)
143{
144 struct getpeername_unix_prog *skel;
145
146 skel = getpeername_unix_prog__open_and_load();
147 if (!ASSERT_OK_PTR(skel, "skel_open"))
148 goto cleanup;
149
150 skel->links.getpeername_unix_prog = bpf_program__attach_cgroup(
151 skel->progs.getpeername_unix_prog, cgroup_fd);
152 if (!ASSERT_OK_PTR(skel->links.getpeername_unix_prog, "prog_attach"))
153 goto cleanup;
154
155 return skel;
156cleanup:
157 getpeername_unix_prog__destroy(skel);
158 return NULL;
159}
160
161static void getpeername_unix_prog_destroy(void *skel)
162{
163 getpeername_unix_prog__destroy(skel);
164}
165
166static struct sock_addr_test tests[] = {
167 {
168 SOCK_ADDR_TEST_CONNECT,
169 "connect_unix",
170 connect_unix_prog_load,
171 connect_unix_prog_destroy,
172 AF_UNIX,
173 SOCK_STREAM,
174 SERVUN_ADDRESS,
175 0,
176 SERVUN_REWRITE_ADDRESS,
177 0,
178 NULL,
179 },
180 {
181 SOCK_ADDR_TEST_SENDMSG,
182 "sendmsg_unix",
183 sendmsg_unix_prog_load,
184 sendmsg_unix_prog_destroy,
185 AF_UNIX,
186 SOCK_DGRAM,
187 SERVUN_ADDRESS,
188 0,
189 SERVUN_REWRITE_ADDRESS,
190 0,
191 NULL,
192 },
193 {
194 SOCK_ADDR_TEST_RECVMSG,
195 "recvmsg_unix-dgram",
196 recvmsg_unix_prog_load,
197 recvmsg_unix_prog_destroy,
198 AF_UNIX,
199 SOCK_DGRAM,
200 SERVUN_REWRITE_ADDRESS,
201 0,
202 SERVUN_REWRITE_ADDRESS,
203 0,
204 SERVUN_ADDRESS,
205 },
206 {
207 SOCK_ADDR_TEST_RECVMSG,
208 "recvmsg_unix-stream",
209 recvmsg_unix_prog_load,
210 recvmsg_unix_prog_destroy,
211 AF_UNIX,
212 SOCK_STREAM,
213 SERVUN_REWRITE_ADDRESS,
214 0,
215 SERVUN_REWRITE_ADDRESS,
216 0,
217 SERVUN_ADDRESS,
218 },
219 {
220 SOCK_ADDR_TEST_GETSOCKNAME,
221 "getsockname_unix",
222 getsockname_unix_prog_load,
223 getsockname_unix_prog_destroy,
224 AF_UNIX,
225 SOCK_STREAM,
226 SERVUN_ADDRESS,
227 0,
228 SERVUN_REWRITE_ADDRESS,
229 0,
230 NULL,
231 },
232 {
233 SOCK_ADDR_TEST_GETPEERNAME,
234 "getpeername_unix",
235 getpeername_unix_prog_load,
236 getpeername_unix_prog_destroy,
237 AF_UNIX,
238 SOCK_STREAM,
239 SERVUN_ADDRESS,
240 0,
241 SERVUN_REWRITE_ADDRESS,
242 0,
243 NULL,
244 },
245};
246
247typedef int (*info_fn)(int, struct sockaddr *, socklen_t *);
248
249static int cmp_addr(const struct sockaddr_storage *addr1, socklen_t addr1_len,
250 const struct sockaddr_storage *addr2, socklen_t addr2_len,
251 bool cmp_port)
252{
253 const struct sockaddr_in *four1, *four2;
254 const struct sockaddr_in6 *six1, *six2;
255 const struct sockaddr_un *un1, *un2;
256
257 if (addr1->ss_family != addr2->ss_family)
258 return -1;
259
260 if (addr1_len != addr2_len)
261 return -1;
262
263 if (addr1->ss_family == AF_INET) {
264 four1 = (const struct sockaddr_in *)addr1;
265 four2 = (const struct sockaddr_in *)addr2;
266 return !((four1->sin_port == four2->sin_port || !cmp_port) &&
267 four1->sin_addr.s_addr == four2->sin_addr.s_addr);
268 } else if (addr1->ss_family == AF_INET6) {
269 six1 = (const struct sockaddr_in6 *)addr1;
270 six2 = (const struct sockaddr_in6 *)addr2;
271 return !((six1->sin6_port == six2->sin6_port || !cmp_port) &&
272 !memcmp(&six1->sin6_addr, &six2->sin6_addr,
273 sizeof(struct in6_addr)));
274 } else if (addr1->ss_family == AF_UNIX) {
275 un1 = (const struct sockaddr_un *)addr1;
276 un2 = (const struct sockaddr_un *)addr2;
277 return memcmp(un1, un2, addr1_len);
278 }
279
280 return -1;
281}
282
283static int cmp_sock_addr(info_fn fn, int sock1,
284 const struct sockaddr_storage *addr2,
285 socklen_t addr2_len, bool cmp_port)
286{
287 struct sockaddr_storage addr1;
288 socklen_t len1 = sizeof(addr1);
289
290 memset(&addr1, 0, len1);
291 if (fn(sock1, (struct sockaddr *)&addr1, (socklen_t *)&len1) != 0)
292 return -1;
293
294 return cmp_addr(&addr1, len1, addr2, addr2_len, cmp_port);
295}
296
297static int cmp_local_addr(int sock1, const struct sockaddr_storage *addr2,
298 socklen_t addr2_len, bool cmp_port)
299{
300 return cmp_sock_addr(getsockname, sock1, addr2, addr2_len, cmp_port);
301}
302
303static int cmp_peer_addr(int sock1, const struct sockaddr_storage *addr2,
304 socklen_t addr2_len, bool cmp_port)
305{
306 return cmp_sock_addr(getpeername, sock1, addr2, addr2_len, cmp_port);
307}
308
309static void test_bind(struct sock_addr_test *test)
310{
311 struct sockaddr_storage expected_addr;
312 socklen_t expected_addr_len = sizeof(struct sockaddr_storage);
313 int serv = -1, client = -1, err;
314
315 serv = start_server(test->socket_family, test->socket_type,
316 test->requested_addr, test->requested_port, 0);
317 if (!ASSERT_GE(serv, 0, "start_server"))
318 goto cleanup;
319
320 err = make_sockaddr(test->socket_family,
321 test->expected_addr, test->expected_port,
322 &expected_addr, &expected_addr_len);
323 if (!ASSERT_EQ(err, 0, "make_sockaddr"))
324 goto cleanup;
325
326 err = cmp_local_addr(serv, &expected_addr, expected_addr_len, true);
327 if (!ASSERT_EQ(err, 0, "cmp_local_addr"))
328 goto cleanup;
329
330 /* Try to connect to server just in case */
331 client = connect_to_addr(&expected_addr, expected_addr_len, test->socket_type);
332 if (!ASSERT_GE(client, 0, "connect_to_addr"))
333 goto cleanup;
334
335cleanup:
336 if (client != -1)
337 close(client);
338 if (serv != -1)
339 close(serv);
340}
341
342static void test_connect(struct sock_addr_test *test)
343{
344 struct sockaddr_storage addr, expected_addr, expected_src_addr;
345 socklen_t addr_len = sizeof(struct sockaddr_storage),
346 expected_addr_len = sizeof(struct sockaddr_storage),
347 expected_src_addr_len = sizeof(struct sockaddr_storage);
348 int serv = -1, client = -1, err;
349
350 serv = start_server(test->socket_family, test->socket_type,
351 test->expected_addr, test->expected_port, 0);
352 if (!ASSERT_GE(serv, 0, "start_server"))
353 goto cleanup;
354
355 err = make_sockaddr(test->socket_family, test->requested_addr, test->requested_port,
356 &addr, &addr_len);
357 if (!ASSERT_EQ(err, 0, "make_sockaddr"))
358 goto cleanup;
359
360 client = connect_to_addr(&addr, addr_len, test->socket_type);
361 if (!ASSERT_GE(client, 0, "connect_to_addr"))
362 goto cleanup;
363
364 err = make_sockaddr(test->socket_family, test->expected_addr, test->expected_port,
365 &expected_addr, &expected_addr_len);
366 if (!ASSERT_EQ(err, 0, "make_sockaddr"))
367 goto cleanup;
368
369 if (test->expected_src_addr) {
370 err = make_sockaddr(test->socket_family, test->expected_src_addr, 0,
371 &expected_src_addr, &expected_src_addr_len);
372 if (!ASSERT_EQ(err, 0, "make_sockaddr"))
373 goto cleanup;
374 }
375
376 err = cmp_peer_addr(client, &expected_addr, expected_addr_len, true);
377 if (!ASSERT_EQ(err, 0, "cmp_peer_addr"))
378 goto cleanup;
379
380 if (test->expected_src_addr) {
381 err = cmp_local_addr(client, &expected_src_addr, expected_src_addr_len, false);
382 if (!ASSERT_EQ(err, 0, "cmp_local_addr"))
383 goto cleanup;
384 }
385cleanup:
386 if (client != -1)
387 close(client);
388 if (serv != -1)
389 close(serv);
390}
391
392static void test_xmsg(struct sock_addr_test *test)
393{
394 struct sockaddr_storage addr, src_addr;
395 socklen_t addr_len = sizeof(struct sockaddr_storage),
396 src_addr_len = sizeof(struct sockaddr_storage);
397 struct msghdr hdr;
398 struct iovec iov;
399 char data = 'a';
400 int serv = -1, client = -1, err;
401
402 /* Unlike the other tests, here we test that we can rewrite the src addr
403 * with a recvmsg() hook.
404 */
405
406 serv = start_server(test->socket_family, test->socket_type,
407 test->expected_addr, test->expected_port, 0);
408 if (!ASSERT_GE(serv, 0, "start_server"))
409 goto cleanup;
410
411 client = socket(test->socket_family, test->socket_type, 0);
412 if (!ASSERT_GE(client, 0, "socket"))
413 goto cleanup;
414
415 /* AF_UNIX sockets have to be bound to something to trigger the recvmsg bpf program. */
416 if (test->socket_family == AF_UNIX) {
417 err = make_sockaddr(AF_UNIX, SRCUN_ADDRESS, 0, &src_addr, &src_addr_len);
418 if (!ASSERT_EQ(err, 0, "make_sockaddr"))
419 goto cleanup;
420
421 err = bind(client, (const struct sockaddr *) &src_addr, src_addr_len);
422 if (!ASSERT_OK(err, "bind"))
423 goto cleanup;
424 }
425
426 err = make_sockaddr(test->socket_family, test->requested_addr, test->requested_port,
427 &addr, &addr_len);
428 if (!ASSERT_EQ(err, 0, "make_sockaddr"))
429 goto cleanup;
430
431 if (test->socket_type == SOCK_DGRAM) {
432 memset(&iov, 0, sizeof(iov));
433 iov.iov_base = &data;
434 iov.iov_len = sizeof(data);
435
436 memset(&hdr, 0, sizeof(hdr));
437 hdr.msg_name = (void *)&addr;
438 hdr.msg_namelen = addr_len;
439 hdr.msg_iov = &iov;
440 hdr.msg_iovlen = 1;
441
442 err = sendmsg(client, &hdr, 0);
443 if (!ASSERT_EQ(err, sizeof(data), "sendmsg"))
444 goto cleanup;
445 } else {
446 /* Testing with connection-oriented sockets is only valid for
447 * recvmsg() tests.
448 */
449 if (!ASSERT_EQ(test->type, SOCK_ADDR_TEST_RECVMSG, "recvmsg"))
450 goto cleanup;
451
452 err = connect(client, (const struct sockaddr *)&addr, addr_len);
453 if (!ASSERT_OK(err, "connect"))
454 goto cleanup;
455
456 err = send(client, &data, sizeof(data), 0);
457 if (!ASSERT_EQ(err, sizeof(data), "send"))
458 goto cleanup;
459
460 err = listen(serv, 0);
461 if (!ASSERT_OK(err, "listen"))
462 goto cleanup;
463
464 err = accept(serv, NULL, NULL);
465 if (!ASSERT_GE(err, 0, "accept"))
466 goto cleanup;
467
468 close(serv);
469 serv = err;
470 }
471
472 addr_len = src_addr_len = sizeof(struct sockaddr_storage);
473
474 err = recvfrom(serv, &data, sizeof(data), 0, (struct sockaddr *) &src_addr, &src_addr_len);
475 if (!ASSERT_EQ(err, sizeof(data), "recvfrom"))
476 goto cleanup;
477
478 ASSERT_EQ(data, 'a', "data mismatch");
479
480 if (test->expected_src_addr) {
481 err = make_sockaddr(test->socket_family, test->expected_src_addr, 0,
482 &addr, &addr_len);
483 if (!ASSERT_EQ(err, 0, "make_sockaddr"))
484 goto cleanup;
485
486 err = cmp_addr(&src_addr, src_addr_len, &addr, addr_len, false);
487 if (!ASSERT_EQ(err, 0, "cmp_addr"))
488 goto cleanup;
489 }
490
491cleanup:
492 if (client != -1)
493 close(client);
494 if (serv != -1)
495 close(serv);
496}
497
498static void test_getsockname(struct sock_addr_test *test)
499{
500 struct sockaddr_storage expected_addr;
501 socklen_t expected_addr_len = sizeof(struct sockaddr_storage);
502 int serv = -1, err;
503
504 serv = start_server(test->socket_family, test->socket_type,
505 test->requested_addr, test->requested_port, 0);
506 if (!ASSERT_GE(serv, 0, "start_server"))
507 goto cleanup;
508
509 err = make_sockaddr(test->socket_family,
510 test->expected_addr, test->expected_port,
511 &expected_addr, &expected_addr_len);
512 if (!ASSERT_EQ(err, 0, "make_sockaddr"))
513 goto cleanup;
514
515 err = cmp_local_addr(serv, &expected_addr, expected_addr_len, true);
516 if (!ASSERT_EQ(err, 0, "cmp_local_addr"))
517 goto cleanup;
518
519cleanup:
520 if (serv != -1)
521 close(serv);
522}
523
524static void test_getpeername(struct sock_addr_test *test)
525{
526 struct sockaddr_storage addr, expected_addr;
527 socklen_t addr_len = sizeof(struct sockaddr_storage),
528 expected_addr_len = sizeof(struct sockaddr_storage);
529 int serv = -1, client = -1, err;
530
531 serv = start_server(test->socket_family, test->socket_type,
532 test->requested_addr, test->requested_port, 0);
533 if (!ASSERT_GE(serv, 0, "start_server"))
534 goto cleanup;
535
536 err = make_sockaddr(test->socket_family, test->requested_addr, test->requested_port,
537 &addr, &addr_len);
538 if (!ASSERT_EQ(err, 0, "make_sockaddr"))
539 goto cleanup;
540
541 client = connect_to_addr(&addr, addr_len, test->socket_type);
542 if (!ASSERT_GE(client, 0, "connect_to_addr"))
543 goto cleanup;
544
545 err = make_sockaddr(test->socket_family, test->expected_addr, test->expected_port,
546 &expected_addr, &expected_addr_len);
547 if (!ASSERT_EQ(err, 0, "make_sockaddr"))
548 goto cleanup;
549
550 err = cmp_peer_addr(client, &expected_addr, expected_addr_len, true);
551 if (!ASSERT_EQ(err, 0, "cmp_peer_addr"))
552 goto cleanup;
553
554cleanup:
555 if (client != -1)
556 close(client);
557 if (serv != -1)
558 close(serv);
559}
560
561void test_sock_addr(void)
562{
563 int cgroup_fd = -1;
564 void *skel;
565
566 cgroup_fd = test__join_cgroup("/sock_addr");
567 if (!ASSERT_GE(cgroup_fd, 0, "join_cgroup"))
568 goto cleanup;
569
570 for (size_t i = 0; i < ARRAY_SIZE(tests); ++i) {
571 struct sock_addr_test *test = &tests[i];
572
573 if (!test__start_subtest(test->name))
574 continue;
575
576 skel = test->loadfn(cgroup_fd);
577 if (!skel)
578 continue;
579
580 switch (test->type) {
581 /* Not exercised yet but we leave this code here for when the
582 * INET and INET6 sockaddr tests are migrated to this file in
583 * the future.
584 */
585 case SOCK_ADDR_TEST_BIND:
586 test_bind(test);
587 break;
588 case SOCK_ADDR_TEST_CONNECT:
589 test_connect(test);
590 break;
591 case SOCK_ADDR_TEST_SENDMSG:
592 case SOCK_ADDR_TEST_RECVMSG:
593 test_xmsg(test);
594 break;
595 case SOCK_ADDR_TEST_GETSOCKNAME:
596 test_getsockname(test);
597 break;
598 case SOCK_ADDR_TEST_GETPEERNAME:
599 test_getpeername(test);
600 break;
601 default:
602 ASSERT_TRUE(false, "Unknown sock addr test type");
603 break;
604 }
605
606 test->destroyfn(skel);
607 }
608
609cleanup:
610 if (cgroup_fd >= 0)
611 close(cgroup_fd);
612}