Linux Audio

Check our new training course

Loading...
Note: File does not exist in v4.10.11.
  1// SPDX-License-Identifier: GPL-2.0
  2/* Copyright (c) 2020, Tessares SA. */
  3/* Copyright (c) 2022, SUSE. */
  4
  5#include <linux/const.h>
  6#include <netinet/in.h>
  7#include <test_progs.h>
  8#include "cgroup_helpers.h"
  9#include "network_helpers.h"
 10#include "mptcp_sock.skel.h"
 11#include "mptcpify.skel.h"
 12
 13#define NS_TEST "mptcp_ns"
 14
 15#ifndef IPPROTO_MPTCP
 16#define IPPROTO_MPTCP 262
 17#endif
 18
 19#ifndef SOL_MPTCP
 20#define SOL_MPTCP 284
 21#endif
 22#ifndef MPTCP_INFO
 23#define MPTCP_INFO		1
 24#endif
 25#ifndef MPTCP_INFO_FLAG_FALLBACK
 26#define MPTCP_INFO_FLAG_FALLBACK		_BITUL(0)
 27#endif
 28#ifndef MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED
 29#define MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED	_BITUL(1)
 30#endif
 31
 32#ifndef TCP_CA_NAME_MAX
 33#define TCP_CA_NAME_MAX	16
 34#endif
 35
 36struct __mptcp_info {
 37	__u8	mptcpi_subflows;
 38	__u8	mptcpi_add_addr_signal;
 39	__u8	mptcpi_add_addr_accepted;
 40	__u8	mptcpi_subflows_max;
 41	__u8	mptcpi_add_addr_signal_max;
 42	__u8	mptcpi_add_addr_accepted_max;
 43	__u32	mptcpi_flags;
 44	__u32	mptcpi_token;
 45	__u64	mptcpi_write_seq;
 46	__u64	mptcpi_snd_una;
 47	__u64	mptcpi_rcv_nxt;
 48	__u8	mptcpi_local_addr_used;
 49	__u8	mptcpi_local_addr_max;
 50	__u8	mptcpi_csum_enabled;
 51	__u32	mptcpi_retransmits;
 52	__u64	mptcpi_bytes_retrans;
 53	__u64	mptcpi_bytes_sent;
 54	__u64	mptcpi_bytes_received;
 55	__u64	mptcpi_bytes_acked;
 56};
 57
 58struct mptcp_storage {
 59	__u32 invoked;
 60	__u32 is_mptcp;
 61	struct sock *sk;
 62	__u32 token;
 63	struct sock *first;
 64	char ca_name[TCP_CA_NAME_MAX];
 65};
 66
 67static struct nstoken *create_netns(void)
 68{
 69	SYS(fail, "ip netns add %s", NS_TEST);
 70	SYS(fail, "ip -net %s link set dev lo up", NS_TEST);
 71
 72	return open_netns(NS_TEST);
 73fail:
 74	return NULL;
 75}
 76
 77static void cleanup_netns(struct nstoken *nstoken)
 78{
 79	if (nstoken)
 80		close_netns(nstoken);
 81
 82	SYS_NOFAIL("ip netns del %s &> /dev/null", NS_TEST);
 83}
 84
 85static int verify_tsk(int map_fd, int client_fd)
 86{
 87	int err, cfd = client_fd;
 88	struct mptcp_storage val;
 89
 90	err = bpf_map_lookup_elem(map_fd, &cfd, &val);
 91	if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
 92		return err;
 93
 94	if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
 95		err++;
 96
 97	if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp"))
 98		err++;
 99
100	return err;
101}
102
103static void get_msk_ca_name(char ca_name[])
104{
105	size_t len;
106	int fd;
107
108	fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY);
109	if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control"))
110		return;
111
112	len = read(fd, ca_name, TCP_CA_NAME_MAX);
113	if (!ASSERT_GT(len, 0, "failed to read ca_name"))
114		goto err;
115
116	if (len > 0 && ca_name[len - 1] == '\n')
117		ca_name[len - 1] = '\0';
118
119err:
120	close(fd);
121}
122
123static int verify_msk(int map_fd, int client_fd, __u32 token)
124{
125	char ca_name[TCP_CA_NAME_MAX];
126	int err, cfd = client_fd;
127	struct mptcp_storage val;
128
129	if (!ASSERT_GT(token, 0, "invalid token"))
130		return -1;
131
132	get_msk_ca_name(ca_name);
133
134	err = bpf_map_lookup_elem(map_fd, &cfd, &val);
135	if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
136		return err;
137
138	if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
139		err++;
140
141	if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp"))
142		err++;
143
144	if (!ASSERT_EQ(val.token, token, "unexpected token"))
145		err++;
146
147	if (!ASSERT_EQ(val.first, val.sk, "unexpected first"))
148		err++;
149
150	if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name"))
151		err++;
152
153	return err;
154}
155
156static int run_test(int cgroup_fd, int server_fd, bool is_mptcp)
157{
158	int client_fd, prog_fd, map_fd, err;
159	struct mptcp_sock *sock_skel;
160
161	sock_skel = mptcp_sock__open_and_load();
162	if (!ASSERT_OK_PTR(sock_skel, "skel_open_load"))
163		return libbpf_get_error(sock_skel);
164
165	err = mptcp_sock__attach(sock_skel);
166	if (!ASSERT_OK(err, "skel_attach"))
167		goto out;
168
169	prog_fd = bpf_program__fd(sock_skel->progs._sockops);
170	map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map);
171	err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0);
172	if (!ASSERT_OK(err, "bpf_prog_attach"))
173		goto out;
174
175	client_fd = connect_to_fd(server_fd, 0);
176	if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
177		err = -EIO;
178		goto out;
179	}
180
181	err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) :
182			  verify_tsk(map_fd, client_fd);
183
184	close(client_fd);
185
186out:
187	mptcp_sock__destroy(sock_skel);
188	return err;
189}
190
191static void test_base(void)
192{
193	struct nstoken *nstoken = NULL;
194	int server_fd, cgroup_fd;
195
196	cgroup_fd = test__join_cgroup("/mptcp");
197	if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
198		return;
199
200	nstoken = create_netns();
201	if (!ASSERT_OK_PTR(nstoken, "create_netns"))
202		goto fail;
203
204	/* without MPTCP */
205	server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
206	if (!ASSERT_GE(server_fd, 0, "start_server"))
207		goto with_mptcp;
208
209	ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp");
210
211	close(server_fd);
212
213with_mptcp:
214	/* with MPTCP */
215	server_fd = start_mptcp_server(AF_INET, NULL, 0, 0);
216	if (!ASSERT_GE(server_fd, 0, "start_mptcp_server"))
217		goto fail;
218
219	ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp");
220
221	close(server_fd);
222
223fail:
224	cleanup_netns(nstoken);
225	close(cgroup_fd);
226}
227
228static void send_byte(int fd)
229{
230	char b = 0x55;
231
232	ASSERT_EQ(write(fd, &b, sizeof(b)), 1, "send single byte");
233}
234
235static int verify_mptcpify(int server_fd, int client_fd)
236{
237	struct __mptcp_info info;
238	socklen_t optlen;
239	int protocol;
240	int err = 0;
241
242	optlen = sizeof(protocol);
243	if (!ASSERT_OK(getsockopt(server_fd, SOL_SOCKET, SO_PROTOCOL, &protocol, &optlen),
244		       "getsockopt(SOL_PROTOCOL)"))
245		return -1;
246
247	if (!ASSERT_EQ(protocol, IPPROTO_MPTCP, "protocol isn't MPTCP"))
248		err++;
249
250	optlen = sizeof(info);
251	if (!ASSERT_OK(getsockopt(client_fd, SOL_MPTCP, MPTCP_INFO, &info, &optlen),
252		       "getsockopt(MPTCP_INFO)"))
253		return -1;
254
255	if (!ASSERT_GE(info.mptcpi_flags, 0, "unexpected mptcpi_flags"))
256		err++;
257	if (!ASSERT_FALSE(info.mptcpi_flags & MPTCP_INFO_FLAG_FALLBACK,
258			  "MPTCP fallback"))
259		err++;
260	if (!ASSERT_TRUE(info.mptcpi_flags & MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED,
261			 "no remote key received"))
262		err++;
263
264	return err;
265}
266
267static int run_mptcpify(int cgroup_fd)
268{
269	int server_fd, client_fd, err = 0;
270	struct mptcpify *mptcpify_skel;
271
272	mptcpify_skel = mptcpify__open_and_load();
273	if (!ASSERT_OK_PTR(mptcpify_skel, "skel_open_load"))
274		return libbpf_get_error(mptcpify_skel);
275
276	err = mptcpify__attach(mptcpify_skel);
277	if (!ASSERT_OK(err, "skel_attach"))
278		goto out;
279
280	/* without MPTCP */
281	server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
282	if (!ASSERT_GE(server_fd, 0, "start_server")) {
283		err = -EIO;
284		goto out;
285	}
286
287	client_fd = connect_to_fd(server_fd, 0);
288	if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
289		err = -EIO;
290		goto close_server;
291	}
292
293	send_byte(client_fd);
294
295	err = verify_mptcpify(server_fd, client_fd);
296
297	close(client_fd);
298close_server:
299	close(server_fd);
300out:
301	mptcpify__destroy(mptcpify_skel);
302	return err;
303}
304
305static void test_mptcpify(void)
306{
307	struct nstoken *nstoken = NULL;
308	int cgroup_fd;
309
310	cgroup_fd = test__join_cgroup("/mptcpify");
311	if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
312		return;
313
314	nstoken = create_netns();
315	if (!ASSERT_OK_PTR(nstoken, "create_netns"))
316		goto fail;
317
318	ASSERT_OK(run_mptcpify(cgroup_fd), "run_mptcpify");
319
320fail:
321	cleanup_netns(nstoken);
322	close(cgroup_fd);
323}
324
325void test_mptcp(void)
326{
327	if (test__start_subtest("base"))
328		test_base();
329	if (test__start_subtest("mptcpify"))
330		test_mptcpify();
331}