Linux Audio

Check our new training course

Loading...
Note: File does not exist in v5.9.
   1// SPDX-License-Identifier: GPL-2.0
   2/*
   3 * ipsec.c - Check xfrm on veth inside a net-ns.
   4 * Copyright (c) 2018 Dmitry Safonov
   5 */
   6
   7#define _GNU_SOURCE
   8
   9#include <arpa/inet.h>
  10#include <asm/types.h>
  11#include <errno.h>
  12#include <fcntl.h>
  13#include <limits.h>
  14#include <linux/limits.h>
  15#include <linux/netlink.h>
  16#include <linux/random.h>
  17#include <linux/rtnetlink.h>
  18#include <linux/veth.h>
  19#include <linux/xfrm.h>
  20#include <netinet/in.h>
  21#include <net/if.h>
  22#include <sched.h>
  23#include <stdbool.h>
  24#include <stdint.h>
  25#include <stdio.h>
  26#include <stdlib.h>
  27#include <string.h>
  28#include <sys/mman.h>
  29#include <sys/socket.h>
  30#include <sys/stat.h>
  31#include <sys/syscall.h>
  32#include <sys/types.h>
  33#include <sys/wait.h>
  34#include <time.h>
  35#include <unistd.h>
  36
  37#include "../kselftest.h"
  38
  39#define printk(fmt, ...)						\
  40	ksft_print_msg("%d[%u] " fmt "\n", getpid(), __LINE__, ##__VA_ARGS__)
  41
  42#define pr_err(fmt, ...)	printk(fmt ": %m", ##__VA_ARGS__)
  43
  44#define BUILD_BUG_ON(condition) ((void)sizeof(char[1 - 2*!!(condition)]))
  45
  46#define IPV4_STR_SZ	16	/* xxx.xxx.xxx.xxx is longest + \0 */
  47#define MAX_PAYLOAD	2048
  48#define XFRM_ALGO_KEY_BUF_SIZE	512
  49#define MAX_PROCESSES	(1 << 14) /* /16 mask divided by /30 subnets */
  50#define INADDR_A	((in_addr_t) 0x0a000000) /* 10.0.0.0 */
  51#define INADDR_B	((in_addr_t) 0xc0a80000) /* 192.168.0.0 */
  52
  53/* /30 mask for one veth connection */
  54#define PREFIX_LEN	30
  55#define child_ip(nr)	(4*nr + 1)
  56#define grchild_ip(nr)	(4*nr + 2)
  57
  58#define VETH_FMT	"ktst-%d"
  59#define VETH_LEN	12
  60
  61#define XFRM_ALGO_NR_KEYS 29
  62
  63static int nsfd_parent	= -1;
  64static int nsfd_childa	= -1;
  65static int nsfd_childb	= -1;
  66static long page_size;
  67
  68/*
  69 * ksft_cnt is static in kselftest, so isn't shared with children.
  70 * We have to send a test result back to parent and count there.
  71 * results_fd is a pipe with test feedback from children.
  72 */
  73static int results_fd[2];
  74
  75const unsigned int ping_delay_nsec	= 50 * 1000 * 1000;
  76const unsigned int ping_timeout		= 300;
  77const unsigned int ping_count		= 100;
  78const unsigned int ping_success		= 80;
  79
  80struct xfrm_key_entry {
  81	char algo_name[35];
  82	int key_len;
  83};
  84
  85struct xfrm_key_entry xfrm_key_entries[] = {
  86	{"digest_null", 0},
  87	{"ecb(cipher_null)", 0},
  88	{"cbc(des)", 64},
  89	{"hmac(md5)", 128},
  90	{"cmac(aes)", 128},
  91	{"xcbc(aes)", 128},
  92	{"cbc(cast5)", 128},
  93	{"cbc(serpent)", 128},
  94	{"hmac(sha1)", 160},
  95	{"hmac(rmd160)", 160},
  96	{"cbc(des3_ede)", 192},
  97	{"hmac(sha256)", 256},
  98	{"cbc(aes)", 256},
  99	{"cbc(camellia)", 256},
 100	{"cbc(twofish)", 256},
 101	{"rfc3686(ctr(aes))", 288},
 102	{"hmac(sha384)", 384},
 103	{"cbc(blowfish)", 448},
 104	{"hmac(sha512)", 512},
 105	{"rfc4106(gcm(aes))-128", 160},
 106	{"rfc4543(gcm(aes))-128", 160},
 107	{"rfc4309(ccm(aes))-128", 152},
 108	{"rfc4106(gcm(aes))-192", 224},
 109	{"rfc4543(gcm(aes))-192", 224},
 110	{"rfc4309(ccm(aes))-192", 216},
 111	{"rfc4106(gcm(aes))-256", 288},
 112	{"rfc4543(gcm(aes))-256", 288},
 113	{"rfc4309(ccm(aes))-256", 280},
 114	{"rfc7539(chacha20,poly1305)-128", 0}
 115};
 116
 117static void randomize_buffer(void *buf, size_t buflen)
 118{
 119	int *p = (int *)buf;
 120	size_t words = buflen / sizeof(int);
 121	size_t leftover = buflen % sizeof(int);
 122
 123	if (!buflen)
 124		return;
 125
 126	while (words--)
 127		*p++ = rand();
 128
 129	if (leftover) {
 130		int tmp = rand();
 131
 132		memcpy(buf + buflen - leftover, &tmp, leftover);
 133	}
 134
 135	return;
 136}
 137
 138static int unshare_open(void)
 139{
 140	const char *netns_path = "/proc/self/ns/net";
 141	int fd;
 142
 143	if (unshare(CLONE_NEWNET) != 0) {
 144		pr_err("unshare()");
 145		return -1;
 146	}
 147
 148	fd = open(netns_path, O_RDONLY);
 149	if (fd <= 0) {
 150		pr_err("open(%s)", netns_path);
 151		return -1;
 152	}
 153
 154	return fd;
 155}
 156
 157static int switch_ns(int fd)
 158{
 159	if (setns(fd, CLONE_NEWNET)) {
 160		pr_err("setns()");
 161		return -1;
 162	}
 163	return 0;
 164}
 165
 166/*
 167 * Running the test inside a new parent net namespace to bother less
 168 * about cleanup on error-path.
 169 */
 170static int init_namespaces(void)
 171{
 172	nsfd_parent = unshare_open();
 173	if (nsfd_parent <= 0)
 174		return -1;
 175
 176	nsfd_childa = unshare_open();
 177	if (nsfd_childa <= 0)
 178		return -1;
 179
 180	if (switch_ns(nsfd_parent))
 181		return -1;
 182
 183	nsfd_childb = unshare_open();
 184	if (nsfd_childb <= 0)
 185		return -1;
 186
 187	if (switch_ns(nsfd_parent))
 188		return -1;
 189	return 0;
 190}
 191
 192static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
 193{
 194	if (*sock > 0) {
 195		seq_nr++;
 196		return 0;
 197	}
 198
 199	*sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
 200	if (*sock <= 0) {
 201		pr_err("socket(AF_NETLINK)");
 202		return -1;
 203	}
 204
 205	randomize_buffer(seq_nr, sizeof(*seq_nr));
 206
 207	return 0;
 208}
 209
 210static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
 211{
 212	return (struct rtattr *)((char *)(nh) + RTA_ALIGN((nh)->nlmsg_len));
 213}
 214
 215static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
 216		unsigned short rta_type, const void *payload, size_t size)
 217{
 218	/* NLMSG_ALIGNTO == RTA_ALIGNTO, nlmsg_len already aligned */
 219	struct rtattr *attr = rtattr_hdr(nh);
 220	size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size);
 221
 222	if (req_sz < nl_size) {
 223		printk("req buf is too small: %zu < %zu", req_sz, nl_size);
 224		return -1;
 225	}
 226	nh->nlmsg_len = nl_size;
 227
 228	attr->rta_len = RTA_LENGTH(size);
 229	attr->rta_type = rta_type;
 230	if (payload)
 231		memcpy(RTA_DATA(attr), payload, size);
 232
 233	return 0;
 234}
 235
 236static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
 237		unsigned short rta_type, const void *payload, size_t size)
 238{
 239	struct rtattr *ret = rtattr_hdr(nh);
 240
 241	if (rtattr_pack(nh, req_sz, rta_type, payload, size))
 242		return 0;
 243
 244	return ret;
 245}
 246
 247static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
 248		unsigned short rta_type)
 249{
 250	return _rtattr_begin(nh, req_sz, rta_type, 0, 0);
 251}
 252
 253static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
 254{
 255	char *nlmsg_end = (char *)nh + nh->nlmsg_len;
 256
 257	attr->rta_len = nlmsg_end - (char *)attr;
 258}
 259
 260static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
 261		const char *peer, int ns)
 262{
 263	struct ifinfomsg pi;
 264	struct rtattr *peer_attr;
 265
 266	memset(&pi, 0, sizeof(pi));
 267	pi.ifi_family	= AF_UNSPEC;
 268	pi.ifi_change	= 0xFFFFFFFF;
 269
 270	peer_attr = _rtattr_begin(nh, req_sz, VETH_INFO_PEER, &pi, sizeof(pi));
 271	if (!peer_attr)
 272		return -1;
 273
 274	if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)))
 275		return -1;
 276
 277	if (rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
 278		return -1;
 279
 280	rtattr_end(nh, peer_attr);
 281
 282	return 0;
 283}
 284
 285static int netlink_check_answer(int sock)
 286{
 287	struct nlmsgerror {
 288		struct nlmsghdr hdr;
 289		int error;
 290		struct nlmsghdr orig_msg;
 291	} answer;
 292
 293	if (recv(sock, &answer, sizeof(answer), 0) < 0) {
 294		pr_err("recv()");
 295		return -1;
 296	} else if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
 297		printk("expected NLMSG_ERROR, got %d", (int)answer.hdr.nlmsg_type);
 298		return -1;
 299	} else if (answer.error) {
 300		printk("NLMSG_ERROR: %d: %s",
 301			answer.error, strerror(-answer.error));
 302		return answer.error;
 303	}
 304
 305	return 0;
 306}
 307
 308static int veth_add(int sock, uint32_t seq, const char *peera, int ns_a,
 309		const char *peerb, int ns_b)
 310{
 311	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
 312	struct {
 313		struct nlmsghdr		nh;
 314		struct ifinfomsg	info;
 315		char			attrbuf[MAX_PAYLOAD];
 316	} req;
 317	const char veth_type[] = "veth";
 318	struct rtattr *link_info, *info_data;
 319
 320	memset(&req, 0, sizeof(req));
 321	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
 322	req.nh.nlmsg_type	= RTM_NEWLINK;
 323	req.nh.nlmsg_flags	= flags;
 324	req.nh.nlmsg_seq	= seq;
 325	req.info.ifi_family	= AF_UNSPEC;
 326	req.info.ifi_change	= 0xFFFFFFFF;
 327
 328	if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, peera, strlen(peera)))
 329		return -1;
 330
 331	if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
 332		return -1;
 333
 334	link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
 335	if (!link_info)
 336		return -1;
 337
 338	if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type)))
 339		return -1;
 340
 341	info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
 342	if (!info_data)
 343		return -1;
 344
 345	if (veth_pack_peerb(&req.nh, sizeof(req), peerb, ns_b))
 346		return -1;
 347
 348	rtattr_end(&req.nh, info_data);
 349	rtattr_end(&req.nh, link_info);
 350
 351	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
 352		pr_err("send()");
 353		return -1;
 354	}
 355	return netlink_check_answer(sock);
 356}
 357
 358static int ip4_addr_set(int sock, uint32_t seq, const char *intf,
 359		struct in_addr addr, uint8_t prefix)
 360{
 361	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
 362	struct {
 363		struct nlmsghdr		nh;
 364		struct ifaddrmsg	info;
 365		char			attrbuf[MAX_PAYLOAD];
 366	} req;
 367
 368	memset(&req, 0, sizeof(req));
 369	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
 370	req.nh.nlmsg_type	= RTM_NEWADDR;
 371	req.nh.nlmsg_flags	= flags;
 372	req.nh.nlmsg_seq	= seq;
 373	req.info.ifa_family	= AF_INET;
 374	req.info.ifa_prefixlen	= prefix;
 375	req.info.ifa_index	= if_nametoindex(intf);
 376
 377#ifdef DEBUG
 378	{
 379		char addr_str[IPV4_STR_SZ] = {};
 380
 381		strncpy(addr_str, inet_ntoa(addr), IPV4_STR_SZ - 1);
 382
 383		printk("ip addr set %s", addr_str);
 384	}
 385#endif
 386
 387	if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, sizeof(addr)))
 388		return -1;
 389
 390	if (rtattr_pack(&req.nh, sizeof(req), IFA_ADDRESS, &addr, sizeof(addr)))
 391		return -1;
 392
 393	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
 394		pr_err("send()");
 395		return -1;
 396	}
 397	return netlink_check_answer(sock);
 398}
 399
 400static int link_set_up(int sock, uint32_t seq, const char *intf)
 401{
 402	struct {
 403		struct nlmsghdr		nh;
 404		struct ifinfomsg	info;
 405		char			attrbuf[MAX_PAYLOAD];
 406	} req;
 407
 408	memset(&req, 0, sizeof(req));
 409	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
 410	req.nh.nlmsg_type	= RTM_NEWLINK;
 411	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
 412	req.nh.nlmsg_seq	= seq;
 413	req.info.ifi_family	= AF_UNSPEC;
 414	req.info.ifi_change	= 0xFFFFFFFF;
 415	req.info.ifi_index	= if_nametoindex(intf);
 416	req.info.ifi_flags	= IFF_UP;
 417	req.info.ifi_change	= IFF_UP;
 418
 419	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
 420		pr_err("send()");
 421		return -1;
 422	}
 423	return netlink_check_answer(sock);
 424}
 425
 426static int ip4_route_set(int sock, uint32_t seq, const char *intf,
 427		struct in_addr src, struct in_addr dst)
 428{
 429	struct {
 430		struct nlmsghdr	nh;
 431		struct rtmsg	rt;
 432		char		attrbuf[MAX_PAYLOAD];
 433	} req;
 434	unsigned int index = if_nametoindex(intf);
 435
 436	memset(&req, 0, sizeof(req));
 437	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.rt));
 438	req.nh.nlmsg_type	= RTM_NEWROUTE;
 439	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
 440	req.nh.nlmsg_seq	= seq;
 441	req.rt.rtm_family	= AF_INET;
 442	req.rt.rtm_dst_len	= 32;
 443	req.rt.rtm_table	= RT_TABLE_MAIN;
 444	req.rt.rtm_protocol	= RTPROT_BOOT;
 445	req.rt.rtm_scope	= RT_SCOPE_LINK;
 446	req.rt.rtm_type		= RTN_UNICAST;
 447
 448	if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, sizeof(dst)))
 449		return -1;
 450
 451	if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, sizeof(src)))
 452		return -1;
 453
 454	if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index)))
 455		return -1;
 456
 457	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
 458		pr_err("send()");
 459		return -1;
 460	}
 461
 462	return netlink_check_answer(sock);
 463}
 464
 465static int tunnel_set_route(int route_sock, uint32_t *route_seq, char *veth,
 466		struct in_addr tunsrc, struct in_addr tundst)
 467{
 468	if (ip4_addr_set(route_sock, (*route_seq)++, "lo",
 469			tunsrc, PREFIX_LEN)) {
 470		printk("Failed to set ipv4 addr");
 471		return -1;
 472	}
 473
 474	if (ip4_route_set(route_sock, (*route_seq)++, veth, tunsrc, tundst)) {
 475		printk("Failed to set ipv4 route");
 476		return -1;
 477	}
 478
 479	return 0;
 480}
 481
 482static int init_child(int nsfd, char *veth, unsigned int src, unsigned int dst)
 483{
 484	struct in_addr intsrc = inet_makeaddr(INADDR_B, src);
 485	struct in_addr tunsrc = inet_makeaddr(INADDR_A, src);
 486	struct in_addr tundst = inet_makeaddr(INADDR_A, dst);
 487	int route_sock = -1, ret = -1;
 488	uint32_t route_seq;
 489
 490	if (switch_ns(nsfd))
 491		return -1;
 492
 493	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) {
 494		printk("Failed to open netlink route socket in child");
 495		return -1;
 496	}
 497
 498	if (ip4_addr_set(route_sock, route_seq++, veth, intsrc, PREFIX_LEN)) {
 499		printk("Failed to set ipv4 addr");
 500		goto err;
 501	}
 502
 503	if (link_set_up(route_sock, route_seq++, veth)) {
 504		printk("Failed to bring up %s", veth);
 505		goto err;
 506	}
 507
 508	if (tunnel_set_route(route_sock, &route_seq, veth, tunsrc, tundst)) {
 509		printk("Failed to add tunnel route on %s", veth);
 510		goto err;
 511	}
 512	ret = 0;
 513
 514err:
 515	close(route_sock);
 516	return ret;
 517}
 518
 519#define ALGO_LEN	64
 520enum desc_type {
 521	CREATE_TUNNEL	= 0,
 522	ALLOCATE_SPI,
 523	MONITOR_ACQUIRE,
 524	EXPIRE_STATE,
 525	EXPIRE_POLICY,
 526	SPDINFO_ATTRS,
 527};
 528const char *desc_name[] = {
 529	"create tunnel",
 530	"alloc spi",
 531	"monitor acquire",
 532	"expire state",
 533	"expire policy",
 534	"spdinfo attributes",
 535	""
 536};
 537struct xfrm_desc {
 538	enum desc_type	type;
 539	uint8_t		proto;
 540	char		a_algo[ALGO_LEN];
 541	char		e_algo[ALGO_LEN];
 542	char		c_algo[ALGO_LEN];
 543	char		ae_algo[ALGO_LEN];
 544	unsigned int	icv_len;
 545	/* unsigned key_len; */
 546};
 547
 548enum msg_type {
 549	MSG_ACK		= 0,
 550	MSG_EXIT,
 551	MSG_PING,
 552	MSG_XFRM_PREPARE,
 553	MSG_XFRM_ADD,
 554	MSG_XFRM_DEL,
 555	MSG_XFRM_CLEANUP,
 556};
 557
 558struct test_desc {
 559	enum msg_type type;
 560	union {
 561		struct {
 562			in_addr_t reply_ip;
 563			unsigned int port;
 564		} ping;
 565		struct xfrm_desc xfrm_desc;
 566	} body;
 567};
 568
 569struct test_result {
 570	struct xfrm_desc desc;
 571	unsigned int res;
 572};
 573
 574static void write_test_result(unsigned int res, struct xfrm_desc *d)
 575{
 576	struct test_result tr = {};
 577	ssize_t ret;
 578
 579	tr.desc = *d;
 580	tr.res = res;
 581
 582	ret = write(results_fd[1], &tr, sizeof(tr));
 583	if (ret != sizeof(tr))
 584		pr_err("Failed to write the result in pipe %zd", ret);
 585}
 586
 587static void write_msg(int fd, struct test_desc *msg, bool exit_of_fail)
 588{
 589	ssize_t bytes = write(fd, msg, sizeof(*msg));
 590
 591	/* Make sure that write/read is atomic to a pipe */
 592	BUILD_BUG_ON(sizeof(struct test_desc) > PIPE_BUF);
 593
 594	if (bytes < 0) {
 595		pr_err("write()");
 596		if (exit_of_fail)
 597			exit(KSFT_FAIL);
 598	}
 599	if (bytes != sizeof(*msg)) {
 600		pr_err("sent part of the message %zd/%zu", bytes, sizeof(*msg));
 601		if (exit_of_fail)
 602			exit(KSFT_FAIL);
 603	}
 604}
 605
 606static void read_msg(int fd, struct test_desc *msg, bool exit_of_fail)
 607{
 608	ssize_t bytes = read(fd, msg, sizeof(*msg));
 609
 610	if (bytes < 0) {
 611		pr_err("read()");
 612		if (exit_of_fail)
 613			exit(KSFT_FAIL);
 614	}
 615	if (bytes != sizeof(*msg)) {
 616		pr_err("got incomplete message %zd/%zu", bytes, sizeof(*msg));
 617		if (exit_of_fail)
 618			exit(KSFT_FAIL);
 619	}
 620}
 621
 622static int udp_ping_init(struct in_addr listen_ip, unsigned int u_timeout,
 623		unsigned int *server_port, int sock[2])
 624{
 625	struct sockaddr_in server;
 626	struct timeval t = { .tv_sec = 0, .tv_usec = u_timeout };
 627	socklen_t s_len = sizeof(server);
 628
 629	sock[0] = socket(AF_INET, SOCK_DGRAM, 0);
 630	if (sock[0] < 0) {
 631		pr_err("socket()");
 632		return -1;
 633	}
 634
 635	server.sin_family	= AF_INET;
 636	server.sin_port		= 0;
 637	memcpy(&server.sin_addr.s_addr, &listen_ip, sizeof(struct in_addr));
 638
 639	if (bind(sock[0], (struct sockaddr *)&server, s_len)) {
 640		pr_err("bind()");
 641		goto err_close_server;
 642	}
 643
 644	if (getsockname(sock[0], (struct sockaddr *)&server, &s_len)) {
 645		pr_err("getsockname()");
 646		goto err_close_server;
 647	}
 648
 649	*server_port = ntohs(server.sin_port);
 650
 651	if (setsockopt(sock[0], SOL_SOCKET, SO_RCVTIMEO, (const char *)&t, sizeof t)) {
 652		pr_err("setsockopt()");
 653		goto err_close_server;
 654	}
 655
 656	sock[1] = socket(AF_INET, SOCK_DGRAM, 0);
 657	if (sock[1] < 0) {
 658		pr_err("socket()");
 659		goto err_close_server;
 660	}
 661
 662	return 0;
 663
 664err_close_server:
 665	close(sock[0]);
 666	return -1;
 667}
 668
 669static int udp_ping_send(int sock[2], in_addr_t dest_ip, unsigned int port,
 670		char *buf, size_t buf_len)
 671{
 672	struct sockaddr_in server;
 673	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
 674	char *sock_buf[buf_len];
 675	ssize_t r_bytes, s_bytes;
 676
 677	server.sin_family	= AF_INET;
 678	server.sin_port		= htons(port);
 679	server.sin_addr.s_addr	= dest_ip;
 680
 681	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
 682	if (s_bytes < 0) {
 683		pr_err("sendto()");
 684		return -1;
 685	} else if (s_bytes != buf_len) {
 686		printk("send part of the message: %zd/%zu", s_bytes, sizeof(server));
 687		return -1;
 688	}
 689
 690	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
 691	if (r_bytes < 0) {
 692		if (errno != EAGAIN)
 693			pr_err("recv()");
 694		return -1;
 695	} else if (r_bytes == 0) { /* EOF */
 696		printk("EOF on reply to ping");
 697		return -1;
 698	} else if (r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
 699		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
 700		return -1;
 701	}
 702
 703	return 0;
 704}
 705
 706static int udp_ping_reply(int sock[2], in_addr_t dest_ip, unsigned int port,
 707		char *buf, size_t buf_len)
 708{
 709	struct sockaddr_in server;
 710	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
 711	char *sock_buf[buf_len];
 712	ssize_t r_bytes, s_bytes;
 713
 714	server.sin_family	= AF_INET;
 715	server.sin_port		= htons(port);
 716	server.sin_addr.s_addr	= dest_ip;
 717
 718	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
 719	if (r_bytes < 0) {
 720		if (errno != EAGAIN)
 721			pr_err("recv()");
 722		return -1;
 723	}
 724	if (r_bytes == 0) { /* EOF */
 725		printk("EOF on reply to ping");
 726		return -1;
 727	}
 728	if (r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
 729		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
 730		return -1;
 731	}
 732
 733	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
 734	if (s_bytes < 0) {
 735		pr_err("sendto()");
 736		return -1;
 737	} else if (s_bytes != buf_len) {
 738		printk("send part of the message: %zd/%zu", s_bytes, sizeof(server));
 739		return -1;
 740	}
 741
 742	return 0;
 743}
 744
 745typedef int (*ping_f)(int sock[2], in_addr_t dest_ip, unsigned int port,
 746		char *buf, size_t buf_len);
 747static int do_ping(int cmd_fd, char *buf, size_t buf_len, struct in_addr from,
 748		bool init_side, int d_port, in_addr_t to, ping_f func)
 749{
 750	struct test_desc msg;
 751	unsigned int s_port, i, ping_succeeded = 0;
 752	int ping_sock[2];
 753	char to_str[IPV4_STR_SZ] = {}, from_str[IPV4_STR_SZ] = {};
 754
 755	if (udp_ping_init(from, ping_timeout, &s_port, ping_sock)) {
 756		printk("Failed to init ping");
 757		return -1;
 758	}
 759
 760	memset(&msg, 0, sizeof(msg));
 761	msg.type		= MSG_PING;
 762	msg.body.ping.port	= s_port;
 763	memcpy(&msg.body.ping.reply_ip, &from, sizeof(from));
 764
 765	write_msg(cmd_fd, &msg, 0);
 766	if (init_side) {
 767		/* The other end sends ip to ping */
 768		read_msg(cmd_fd, &msg, 0);
 769		if (msg.type != MSG_PING)
 770			return -1;
 771		to = msg.body.ping.reply_ip;
 772		d_port = msg.body.ping.port;
 773	}
 774
 775	for (i = 0; i < ping_count ; i++) {
 776		struct timespec sleep_time = {
 777			.tv_sec = 0,
 778			.tv_nsec = ping_delay_nsec,
 779		};
 780
 781		ping_succeeded += !func(ping_sock, to, d_port, buf, page_size);
 782		nanosleep(&sleep_time, 0);
 783	}
 784
 785	close(ping_sock[0]);
 786	close(ping_sock[1]);
 787
 788	strncpy(to_str, inet_ntoa(*(struct in_addr *)&to), IPV4_STR_SZ - 1);
 789	strncpy(from_str, inet_ntoa(from), IPV4_STR_SZ - 1);
 790
 791	if (ping_succeeded < ping_success) {
 792		printk("ping (%s) %s->%s failed %u/%u times",
 793			init_side ? "send" : "reply", from_str, to_str,
 794			ping_count - ping_succeeded, ping_count);
 795		return -1;
 796	}
 797
 798#ifdef DEBUG
 799	printk("ping (%s) %s->%s succeeded %u/%u times",
 800		init_side ? "send" : "reply", from_str, to_str,
 801		ping_succeeded, ping_count);
 802#endif
 803
 804	return 0;
 805}
 806
 807static int xfrm_fill_key(char *name, char *buf,
 808		size_t buf_len, unsigned int *key_len)
 809{
 810	int i;
 811
 812	for (i = 0; i < XFRM_ALGO_NR_KEYS; i++) {
 813		if (strncmp(name, xfrm_key_entries[i].algo_name, ALGO_LEN) == 0)
 814			*key_len = xfrm_key_entries[i].key_len;
 815	}
 816
 817	if (*key_len > buf_len) {
 818		printk("Can't pack a key - too big for buffer");
 819		return -1;
 820	}
 821
 822	randomize_buffer(buf, *key_len);
 823
 824	return 0;
 825}
 826
 827static int xfrm_state_pack_algo(struct nlmsghdr *nh, size_t req_sz,
 828		struct xfrm_desc *desc)
 829{
 830	struct {
 831		union {
 832			struct xfrm_algo	alg;
 833			struct xfrm_algo_aead	aead;
 834			struct xfrm_algo_auth	auth;
 835		} u;
 836		char buf[XFRM_ALGO_KEY_BUF_SIZE];
 837	} alg = {};
 838	size_t alen, elen, clen, aelen;
 839	unsigned short type;
 840
 841	alen = strlen(desc->a_algo);
 842	elen = strlen(desc->e_algo);
 843	clen = strlen(desc->c_algo);
 844	aelen = strlen(desc->ae_algo);
 845
 846	/* Verify desc */
 847	switch (desc->proto) {
 848	case IPPROTO_AH:
 849		if (!alen || elen || clen || aelen) {
 850			printk("BUG: buggy ah desc");
 851			return -1;
 852		}
 853		strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN - 1);
 854		if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
 855				sizeof(alg.buf), &alg.u.alg.alg_key_len))
 856			return -1;
 857		type = XFRMA_ALG_AUTH;
 858		break;
 859	case IPPROTO_COMP:
 860		if (!clen || elen || alen || aelen) {
 861			printk("BUG: buggy comp desc");
 862			return -1;
 863		}
 864		strncpy(alg.u.alg.alg_name, desc->c_algo, ALGO_LEN - 1);
 865		if (xfrm_fill_key(desc->c_algo, alg.u.alg.alg_key,
 866				sizeof(alg.buf), &alg.u.alg.alg_key_len))
 867			return -1;
 868		type = XFRMA_ALG_COMP;
 869		break;
 870	case IPPROTO_ESP:
 871		if (!((alen && elen) ^ aelen) || clen) {
 872			printk("BUG: buggy esp desc");
 873			return -1;
 874		}
 875		if (aelen) {
 876			alg.u.aead.alg_icv_len = desc->icv_len;
 877			strncpy(alg.u.aead.alg_name, desc->ae_algo, ALGO_LEN - 1);
 878			if (xfrm_fill_key(desc->ae_algo, alg.u.aead.alg_key,
 879						sizeof(alg.buf), &alg.u.aead.alg_key_len))
 880				return -1;
 881			type = XFRMA_ALG_AEAD;
 882		} else {
 883
 884			strncpy(alg.u.alg.alg_name, desc->e_algo, ALGO_LEN - 1);
 885			type = XFRMA_ALG_CRYPT;
 886			if (xfrm_fill_key(desc->e_algo, alg.u.alg.alg_key,
 887						sizeof(alg.buf), &alg.u.alg.alg_key_len))
 888				return -1;
 889			if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
 890				return -1;
 891
 892			strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN);
 893			type = XFRMA_ALG_AUTH;
 894			if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
 895						sizeof(alg.buf), &alg.u.alg.alg_key_len))
 896				return -1;
 897		}
 898		break;
 899	default:
 900		printk("BUG: unknown proto in desc");
 901		return -1;
 902	}
 903
 904	if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
 905		return -1;
 906
 907	return 0;
 908}
 909
 910static inline uint32_t gen_spi(struct in_addr src)
 911{
 912	return htonl(inet_lnaof(src));
 913}
 914
 915static int xfrm_state_add(int xfrm_sock, uint32_t seq, uint32_t spi,
 916		struct in_addr src, struct in_addr dst,
 917		struct xfrm_desc *desc)
 918{
 919	struct {
 920		struct nlmsghdr		nh;
 921		struct xfrm_usersa_info	info;
 922		char			attrbuf[MAX_PAYLOAD];
 923	} req;
 924
 925	memset(&req, 0, sizeof(req));
 926	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
 927	req.nh.nlmsg_type	= XFRM_MSG_NEWSA;
 928	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
 929	req.nh.nlmsg_seq	= seq;
 930
 931	/* Fill selector. */
 932	memcpy(&req.info.sel.daddr, &dst, sizeof(dst));
 933	memcpy(&req.info.sel.saddr, &src, sizeof(src));
 934	req.info.sel.family		= AF_INET;
 935	req.info.sel.prefixlen_d	= PREFIX_LEN;
 936	req.info.sel.prefixlen_s	= PREFIX_LEN;
 937
 938	/* Fill id */
 939	memcpy(&req.info.id.daddr, &dst, sizeof(dst));
 940	/* Note: zero-spi cannot be deleted */
 941	req.info.id.spi = spi;
 942	req.info.id.proto	= desc->proto;
 943
 944	memcpy(&req.info.saddr, &src, sizeof(src));
 945
 946	/* Fill lifteme_cfg */
 947	req.info.lft.soft_byte_limit	= XFRM_INF;
 948	req.info.lft.hard_byte_limit	= XFRM_INF;
 949	req.info.lft.soft_packet_limit	= XFRM_INF;
 950	req.info.lft.hard_packet_limit	= XFRM_INF;
 951
 952	req.info.family		= AF_INET;
 953	req.info.mode		= XFRM_MODE_TUNNEL;
 954
 955	if (xfrm_state_pack_algo(&req.nh, sizeof(req), desc))
 956		return -1;
 957
 958	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
 959		pr_err("send()");
 960		return -1;
 961	}
 962
 963	return netlink_check_answer(xfrm_sock);
 964}
 965
 966static bool xfrm_usersa_found(struct xfrm_usersa_info *info, uint32_t spi,
 967		struct in_addr src, struct in_addr dst,
 968		struct xfrm_desc *desc)
 969{
 970	if (memcmp(&info->sel.daddr, &dst, sizeof(dst)))
 971		return false;
 972
 973	if (memcmp(&info->sel.saddr, &src, sizeof(src)))
 974		return false;
 975
 976	if (info->sel.family != AF_INET					||
 977			info->sel.prefixlen_d != PREFIX_LEN		||
 978			info->sel.prefixlen_s != PREFIX_LEN)
 979		return false;
 980
 981	if (info->id.spi != spi || info->id.proto != desc->proto)
 982		return false;
 983
 984	if (memcmp(&info->id.daddr, &dst, sizeof(dst)))
 985		return false;
 986
 987	if (memcmp(&info->saddr, &src, sizeof(src)))
 988		return false;
 989
 990	if (info->lft.soft_byte_limit != XFRM_INF			||
 991			info->lft.hard_byte_limit != XFRM_INF		||
 992			info->lft.soft_packet_limit != XFRM_INF		||
 993			info->lft.hard_packet_limit != XFRM_INF)
 994		return false;
 995
 996	if (info->family != AF_INET || info->mode != XFRM_MODE_TUNNEL)
 997		return false;
 998
 999	/* XXX: check xfrm algo, see xfrm_state_pack_algo(). */
1000
1001	return true;
1002}
1003
1004static int xfrm_state_check(int xfrm_sock, uint32_t seq, uint32_t spi,
1005		struct in_addr src, struct in_addr dst,
1006		struct xfrm_desc *desc)
1007{
1008	struct {
1009		struct nlmsghdr		nh;
1010		char			attrbuf[MAX_PAYLOAD];
1011	} req;
1012	struct {
1013		struct nlmsghdr		nh;
1014		union {
1015			struct xfrm_usersa_info	info;
1016			int error;
1017		};
1018		char			attrbuf[MAX_PAYLOAD];
1019	} answer;
1020	struct xfrm_address_filter filter = {};
1021	bool found = false;
1022
1023
1024	memset(&req, 0, sizeof(req));
1025	req.nh.nlmsg_len	= NLMSG_LENGTH(0);
1026	req.nh.nlmsg_type	= XFRM_MSG_GETSA;
1027	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_DUMP;
1028	req.nh.nlmsg_seq	= seq;
1029
1030	/*
1031	 * Add dump filter by source address as there may be other tunnels
1032	 * in this netns (if tests run in parallel).
1033	 */
1034	filter.family = AF_INET;
1035	filter.splen = 0x1f;	/* 0xffffffff mask see addr_match() */
1036	memcpy(&filter.saddr, &src, sizeof(src));
1037	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_ADDRESS_FILTER,
1038				&filter, sizeof(filter)))
1039		return -1;
1040
1041	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1042		pr_err("send()");
1043		return -1;
1044	}
1045
1046	while (1) {
1047		if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1048			pr_err("recv()");
1049			return -1;
1050		}
1051		if (answer.nh.nlmsg_type == NLMSG_ERROR) {
1052			printk("NLMSG_ERROR: %d: %s",
1053				answer.error, strerror(-answer.error));
1054			return -1;
1055		} else if (answer.nh.nlmsg_type == NLMSG_DONE) {
1056			if (found)
1057				return 0;
1058			printk("didn't find allocated xfrm state in dump");
1059			return -1;
1060		} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1061			if (xfrm_usersa_found(&answer.info, spi, src, dst, desc))
1062				found = true;
1063		}
1064	}
1065}
1066
1067static int xfrm_set(int xfrm_sock, uint32_t *seq,
1068		struct in_addr src, struct in_addr dst,
1069		struct in_addr tunsrc, struct in_addr tundst,
1070		struct xfrm_desc *desc)
1071{
1072	int err;
1073
1074	err = xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc);
1075	if (err) {
1076		printk("Failed to add xfrm state");
1077		return -1;
1078	}
1079
1080	err = xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), dst, src, desc);
1081	if (err) {
1082		printk("Failed to add xfrm state");
1083		return -1;
1084	}
1085
1086	/* Check dumps for XFRM_MSG_GETSA */
1087	err = xfrm_state_check(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc);
1088	err |= xfrm_state_check(xfrm_sock, (*seq)++, gen_spi(src), dst, src, desc);
1089	if (err) {
1090		printk("Failed to check xfrm state");
1091		return -1;
1092	}
1093
1094	return 0;
1095}
1096
1097static int xfrm_policy_add(int xfrm_sock, uint32_t seq, uint32_t spi,
1098		struct in_addr src, struct in_addr dst, uint8_t dir,
1099		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1100{
1101	struct {
1102		struct nlmsghdr			nh;
1103		struct xfrm_userpolicy_info	info;
1104		char				attrbuf[MAX_PAYLOAD];
1105	} req;
1106	struct xfrm_user_tmpl tmpl;
1107
1108	memset(&req, 0, sizeof(req));
1109	memset(&tmpl, 0, sizeof(tmpl));
1110	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
1111	req.nh.nlmsg_type	= XFRM_MSG_NEWPOLICY;
1112	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1113	req.nh.nlmsg_seq	= seq;
1114
1115	/* Fill selector. */
1116	memcpy(&req.info.sel.daddr, &dst, sizeof(tundst));
1117	memcpy(&req.info.sel.saddr, &src, sizeof(tunsrc));
1118	req.info.sel.family		= AF_INET;
1119	req.info.sel.prefixlen_d	= PREFIX_LEN;
1120	req.info.sel.prefixlen_s	= PREFIX_LEN;
1121
1122	/* Fill lifteme_cfg */
1123	req.info.lft.soft_byte_limit	= XFRM_INF;
1124	req.info.lft.hard_byte_limit	= XFRM_INF;
1125	req.info.lft.soft_packet_limit	= XFRM_INF;
1126	req.info.lft.hard_packet_limit	= XFRM_INF;
1127
1128	req.info.dir = dir;
1129
1130	/* Fill tmpl */
1131	memcpy(&tmpl.id.daddr, &dst, sizeof(dst));
1132	/* Note: zero-spi cannot be deleted */
1133	tmpl.id.spi = spi;
1134	tmpl.id.proto	= proto;
1135	tmpl.family	= AF_INET;
1136	memcpy(&tmpl.saddr, &src, sizeof(src));
1137	tmpl.mode	= XFRM_MODE_TUNNEL;
1138	tmpl.aalgos = (~(uint32_t)0);
1139	tmpl.ealgos = (~(uint32_t)0);
1140	tmpl.calgos = (~(uint32_t)0);
1141
1142	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &tmpl, sizeof(tmpl)))
1143		return -1;
1144
1145	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1146		pr_err("send()");
1147		return -1;
1148	}
1149
1150	return netlink_check_answer(xfrm_sock);
1151}
1152
1153static int xfrm_prepare(int xfrm_sock, uint32_t *seq,
1154		struct in_addr src, struct in_addr dst,
1155		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1156{
1157	if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
1158				XFRM_POLICY_OUT, tunsrc, tundst, proto)) {
1159		printk("Failed to add xfrm policy");
1160		return -1;
1161	}
1162
1163	if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), dst, src,
1164				XFRM_POLICY_IN, tunsrc, tundst, proto)) {
1165		printk("Failed to add xfrm policy");
1166		return -1;
1167	}
1168
1169	return 0;
1170}
1171
1172static int xfrm_policy_del(int xfrm_sock, uint32_t seq,
1173		struct in_addr src, struct in_addr dst, uint8_t dir,
1174		struct in_addr tunsrc, struct in_addr tundst)
1175{
1176	struct {
1177		struct nlmsghdr			nh;
1178		struct xfrm_userpolicy_id	id;
1179		char				attrbuf[MAX_PAYLOAD];
1180	} req;
1181
1182	memset(&req, 0, sizeof(req));
1183	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1184	req.nh.nlmsg_type	= XFRM_MSG_DELPOLICY;
1185	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1186	req.nh.nlmsg_seq	= seq;
1187
1188	/* Fill id */
1189	memcpy(&req.id.sel.daddr, &dst, sizeof(tundst));
1190	memcpy(&req.id.sel.saddr, &src, sizeof(tunsrc));
1191	req.id.sel.family		= AF_INET;
1192	req.id.sel.prefixlen_d		= PREFIX_LEN;
1193	req.id.sel.prefixlen_s		= PREFIX_LEN;
1194	req.id.dir = dir;
1195
1196	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1197		pr_err("send()");
1198		return -1;
1199	}
1200
1201	return netlink_check_answer(xfrm_sock);
1202}
1203
1204static int xfrm_cleanup(int xfrm_sock, uint32_t *seq,
1205		struct in_addr src, struct in_addr dst,
1206		struct in_addr tunsrc, struct in_addr tundst)
1207{
1208	if (xfrm_policy_del(xfrm_sock, (*seq)++, src, dst,
1209				XFRM_POLICY_OUT, tunsrc, tundst)) {
1210		printk("Failed to add xfrm policy");
1211		return -1;
1212	}
1213
1214	if (xfrm_policy_del(xfrm_sock, (*seq)++, dst, src,
1215				XFRM_POLICY_IN, tunsrc, tundst)) {
1216		printk("Failed to add xfrm policy");
1217		return -1;
1218	}
1219
1220	return 0;
1221}
1222
1223static int xfrm_state_del(int xfrm_sock, uint32_t seq, uint32_t spi,
1224		struct in_addr src, struct in_addr dst, uint8_t proto)
1225{
1226	struct {
1227		struct nlmsghdr		nh;
1228		struct xfrm_usersa_id	id;
1229		char			attrbuf[MAX_PAYLOAD];
1230	} req;
1231	xfrm_address_t saddr = {};
1232
1233	memset(&req, 0, sizeof(req));
1234	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1235	req.nh.nlmsg_type	= XFRM_MSG_DELSA;
1236	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1237	req.nh.nlmsg_seq	= seq;
1238
1239	memcpy(&req.id.daddr, &dst, sizeof(dst));
1240	req.id.family		= AF_INET;
1241	req.id.proto		= proto;
1242	/* Note: zero-spi cannot be deleted */
1243	req.id.spi = spi;
1244
1245	memcpy(&saddr, &src, sizeof(src));
1246	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SRCADDR, &saddr, sizeof(saddr)))
1247		return -1;
1248
1249	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1250		pr_err("send()");
1251		return -1;
1252	}
1253
1254	return netlink_check_answer(xfrm_sock);
1255}
1256
1257static int xfrm_delete(int xfrm_sock, uint32_t *seq,
1258		struct in_addr src, struct in_addr dst,
1259		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1260{
1261	if (xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), src, dst, proto)) {
1262		printk("Failed to remove xfrm state");
1263		return -1;
1264	}
1265
1266	if (xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), dst, src, proto)) {
1267		printk("Failed to remove xfrm state");
1268		return -1;
1269	}
1270
1271	return 0;
1272}
1273
1274static int xfrm_state_allocspi(int xfrm_sock, uint32_t *seq,
1275		uint32_t spi, uint8_t proto)
1276{
1277	struct {
1278		struct nlmsghdr			nh;
1279		struct xfrm_userspi_info	spi;
1280	} req;
1281	struct {
1282		struct nlmsghdr			nh;
1283		union {
1284			struct xfrm_usersa_info	info;
1285			int error;
1286		};
1287	} answer;
1288
1289	memset(&req, 0, sizeof(req));
1290	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.spi));
1291	req.nh.nlmsg_type	= XFRM_MSG_ALLOCSPI;
1292	req.nh.nlmsg_flags	= NLM_F_REQUEST;
1293	req.nh.nlmsg_seq	= (*seq)++;
1294
1295	req.spi.info.family	= AF_INET;
1296	req.spi.min		= spi;
1297	req.spi.max		= spi;
1298	req.spi.info.id.proto	= proto;
1299
1300	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1301		pr_err("send()");
1302		return KSFT_FAIL;
1303	}
1304
1305	if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1306		pr_err("recv()");
1307		return KSFT_FAIL;
1308	} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1309		uint32_t new_spi = htonl(answer.info.id.spi);
1310
1311		if (new_spi != spi) {
1312			printk("allocated spi is different from requested: %#x != %#x",
1313					new_spi, spi);
1314			return KSFT_FAIL;
1315		}
1316		return KSFT_PASS;
1317	} else if (answer.nh.nlmsg_type != NLMSG_ERROR) {
1318		printk("expected NLMSG_ERROR, got %d", (int)answer.nh.nlmsg_type);
1319		return KSFT_FAIL;
1320	}
1321
1322	printk("NLMSG_ERROR: %d: %s", answer.error, strerror(-answer.error));
1323	return (answer.error) ? KSFT_FAIL : KSFT_PASS;
1324}
1325
1326static int netlink_sock_bind(int *sock, uint32_t *seq, int proto, uint32_t groups)
1327{
1328	struct sockaddr_nl snl = {};
1329	socklen_t addr_len;
1330	int ret = -1;
1331
1332	snl.nl_family = AF_NETLINK;
1333	snl.nl_groups = groups;
1334
1335	if (netlink_sock(sock, seq, proto)) {
1336		printk("Failed to open xfrm netlink socket");
1337		return -1;
1338	}
1339
1340	if (bind(*sock, (struct sockaddr *)&snl, sizeof(snl)) < 0) {
1341		pr_err("bind()");
1342		goto out_close;
1343	}
1344
1345	addr_len = sizeof(snl);
1346	if (getsockname(*sock, (struct sockaddr *)&snl, &addr_len) < 0) {
1347		pr_err("getsockname()");
1348		goto out_close;
1349	}
1350	if (addr_len != sizeof(snl)) {
1351		printk("Wrong address length %d", addr_len);
1352		goto out_close;
1353	}
1354	if (snl.nl_family != AF_NETLINK) {
1355		printk("Wrong address family %d", snl.nl_family);
1356		goto out_close;
1357	}
1358	return 0;
1359
1360out_close:
1361	close(*sock);
1362	return ret;
1363}
1364
1365static int xfrm_monitor_acquire(int xfrm_sock, uint32_t *seq, unsigned int nr)
1366{
1367	struct {
1368		struct nlmsghdr nh;
1369		union {
1370			struct xfrm_user_acquire acq;
1371			int error;
1372		};
1373		char attrbuf[MAX_PAYLOAD];
1374	} req;
1375	struct xfrm_user_tmpl xfrm_tmpl = {};
1376	int xfrm_listen = -1, ret = KSFT_FAIL;
1377	uint32_t seq_listen;
1378
1379	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_ACQUIRE))
1380		return KSFT_FAIL;
1381
1382	memset(&req, 0, sizeof(req));
1383	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.acq));
1384	req.nh.nlmsg_type	= XFRM_MSG_ACQUIRE;
1385	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1386	req.nh.nlmsg_seq	= (*seq)++;
1387
1388	req.acq.policy.sel.family	= AF_INET;
1389	req.acq.aalgos	= 0xfeed;
1390	req.acq.ealgos	= 0xbaad;
1391	req.acq.calgos	= 0xbabe;
1392
1393	xfrm_tmpl.family = AF_INET;
1394	xfrm_tmpl.id.proto = IPPROTO_ESP;
1395	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &xfrm_tmpl, sizeof(xfrm_tmpl)))
1396		goto out_close;
1397
1398	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1399		pr_err("send()");
1400		goto out_close;
1401	}
1402
1403	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1404		pr_err("recv()");
1405		goto out_close;
1406	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1407		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1408		goto out_close;
1409	}
1410
1411	if (req.error) {
1412		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1413		ret = req.error;
1414		goto out_close;
1415	}
1416
1417	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1418		pr_err("recv()");
1419		goto out_close;
1420	}
1421
1422	if (req.acq.aalgos != 0xfeed || req.acq.ealgos != 0xbaad
1423			|| req.acq.calgos != 0xbabe) {
1424		printk("xfrm_user_acquire has changed  %x %x %x",
1425				req.acq.aalgos, req.acq.ealgos, req.acq.calgos);
1426		goto out_close;
1427	}
1428
1429	ret = KSFT_PASS;
1430out_close:
1431	close(xfrm_listen);
1432	return ret;
1433}
1434
1435static int xfrm_expire_state(int xfrm_sock, uint32_t *seq,
1436		unsigned int nr, struct xfrm_desc *desc)
1437{
1438	struct {
1439		struct nlmsghdr nh;
1440		union {
1441			struct xfrm_user_expire expire;
1442			int error;
1443		};
1444	} req;
1445	struct in_addr src, dst;
1446	int xfrm_listen = -1, ret = KSFT_FAIL;
1447	uint32_t seq_listen;
1448
1449	src = inet_makeaddr(INADDR_B, child_ip(nr));
1450	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1451
1452	if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc)) {
1453		printk("Failed to add xfrm state");
1454		return KSFT_FAIL;
1455	}
1456
1457	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1458		return KSFT_FAIL;
1459
1460	memset(&req, 0, sizeof(req));
1461	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
1462	req.nh.nlmsg_type	= XFRM_MSG_EXPIRE;
1463	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1464	req.nh.nlmsg_seq	= (*seq)++;
1465
1466	memcpy(&req.expire.state.id.daddr, &dst, sizeof(dst));
1467	req.expire.state.id.spi		= gen_spi(src);
1468	req.expire.state.id.proto	= desc->proto;
1469	req.expire.state.family		= AF_INET;
1470	req.expire.hard			= 0xff;
1471
1472	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1473		pr_err("send()");
1474		goto out_close;
1475	}
1476
1477	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1478		pr_err("recv()");
1479		goto out_close;
1480	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1481		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1482		goto out_close;
1483	}
1484
1485	if (req.error) {
1486		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1487		ret = req.error;
1488		goto out_close;
1489	}
1490
1491	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1492		pr_err("recv()");
1493		goto out_close;
1494	}
1495
1496	if (req.expire.hard != 0x1) {
1497		printk("expire.hard is not set: %x", req.expire.hard);
1498		goto out_close;
1499	}
1500
1501	ret = KSFT_PASS;
1502out_close:
1503	close(xfrm_listen);
1504	return ret;
1505}
1506
1507static int xfrm_expire_policy(int xfrm_sock, uint32_t *seq,
1508		unsigned int nr, struct xfrm_desc *desc)
1509{
1510	struct {
1511		struct nlmsghdr nh;
1512		union {
1513			struct xfrm_user_polexpire expire;
1514			int error;
1515		};
1516	} req;
1517	struct in_addr src, dst, tunsrc, tundst;
1518	int xfrm_listen = -1, ret = KSFT_FAIL;
1519	uint32_t seq_listen;
1520
1521	src = inet_makeaddr(INADDR_B, child_ip(nr));
1522	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1523	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1524	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1525
1526	if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
1527				XFRM_POLICY_OUT, tunsrc, tundst, desc->proto)) {
1528		printk("Failed to add xfrm policy");
1529		return KSFT_FAIL;
1530	}
1531
1532	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1533		return KSFT_FAIL;
1534
1535	memset(&req, 0, sizeof(req));
1536	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
1537	req.nh.nlmsg_type	= XFRM_MSG_POLEXPIRE;
1538	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1539	req.nh.nlmsg_seq	= (*seq)++;
1540
1541	/* Fill selector. */
1542	memcpy(&req.expire.pol.sel.daddr, &dst, sizeof(tundst));
1543	memcpy(&req.expire.pol.sel.saddr, &src, sizeof(tunsrc));
1544	req.expire.pol.sel.family	= AF_INET;
1545	req.expire.pol.sel.prefixlen_d	= PREFIX_LEN;
1546	req.expire.pol.sel.prefixlen_s	= PREFIX_LEN;
1547	req.expire.pol.dir		= XFRM_POLICY_OUT;
1548	req.expire.hard			= 0xff;
1549
1550	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1551		pr_err("send()");
1552		goto out_close;
1553	}
1554
1555	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1556		pr_err("recv()");
1557		goto out_close;
1558	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1559		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1560		goto out_close;
1561	}
1562
1563	if (req.error) {
1564		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1565		ret = req.error;
1566		goto out_close;
1567	}
1568
1569	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1570		pr_err("recv()");
1571		goto out_close;
1572	}
1573
1574	if (req.expire.hard != 0x1) {
1575		printk("expire.hard is not set: %x", req.expire.hard);
1576		goto out_close;
1577	}
1578
1579	ret = KSFT_PASS;
1580out_close:
1581	close(xfrm_listen);
1582	return ret;
1583}
1584
1585static int xfrm_spdinfo_set_thresh(int xfrm_sock, uint32_t *seq,
1586		unsigned thresh4_l, unsigned thresh4_r,
1587		unsigned thresh6_l, unsigned thresh6_r,
1588		bool add_bad_attr)
1589
1590{
1591	struct {
1592		struct nlmsghdr		nh;
1593		union {
1594			uint32_t	unused;
1595			int		error;
1596		};
1597		char			attrbuf[MAX_PAYLOAD];
1598	} req;
1599	struct xfrmu_spdhthresh thresh;
1600
1601	memset(&req, 0, sizeof(req));
1602	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
1603	req.nh.nlmsg_type	= XFRM_MSG_NEWSPDINFO;
1604	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1605	req.nh.nlmsg_seq	= (*seq)++;
1606
1607	thresh.lbits = thresh4_l;
1608	thresh.rbits = thresh4_r;
1609	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV4_HTHRESH, &thresh, sizeof(thresh)))
1610		return -1;
1611
1612	thresh.lbits = thresh6_l;
1613	thresh.rbits = thresh6_r;
1614	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV6_HTHRESH, &thresh, sizeof(thresh)))
1615		return -1;
1616
1617	if (add_bad_attr) {
1618		BUILD_BUG_ON(XFRMA_IF_ID <= XFRMA_SPD_MAX + 1);
1619		if (rtattr_pack(&req.nh, sizeof(req), XFRMA_IF_ID, NULL, 0)) {
1620			pr_err("adding attribute failed: no space");
1621			return -1;
1622		}
1623	}
1624
1625	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1626		pr_err("send()");
1627		return -1;
1628	}
1629
1630	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1631		pr_err("recv()");
1632		return -1;
1633	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1634		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1635		return -1;
1636	}
1637
1638	if (req.error) {
1639		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1640		return -1;
1641	}
1642
1643	return 0;
1644}
1645
1646static int xfrm_spdinfo_attrs(int xfrm_sock, uint32_t *seq)
1647{
1648	struct {
1649		struct nlmsghdr			nh;
1650		union {
1651			uint32_t	unused;
1652			int		error;
1653		};
1654		char			attrbuf[MAX_PAYLOAD];
1655	} req;
1656
1657	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 31, 120, 16, false)) {
1658		pr_err("Can't set SPD HTHRESH");
1659		return KSFT_FAIL;
1660	}
1661
1662	memset(&req, 0, sizeof(req));
1663
1664	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
1665	req.nh.nlmsg_type	= XFRM_MSG_GETSPDINFO;
1666	req.nh.nlmsg_flags	= NLM_F_REQUEST;
1667	req.nh.nlmsg_seq	= (*seq)++;
1668	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1669		pr_err("send()");
1670		return KSFT_FAIL;
1671	}
1672
1673	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1674		pr_err("recv()");
1675		return KSFT_FAIL;
1676	} else if (req.nh.nlmsg_type == XFRM_MSG_NEWSPDINFO) {
1677		size_t len = NLMSG_PAYLOAD(&req.nh, sizeof(req.unused));
1678		struct rtattr *attr = (void *)req.attrbuf;
1679		int got_thresh = 0;
1680
1681		for (; RTA_OK(attr, len); attr = RTA_NEXT(attr, len)) {
1682			if (attr->rta_type == XFRMA_SPD_IPV4_HTHRESH) {
1683				struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1684
1685				got_thresh++;
1686				if (t->lbits != 32 || t->rbits != 31) {
1687					pr_err("thresh differ: %u, %u",
1688							t->lbits, t->rbits);
1689					return KSFT_FAIL;
1690				}
1691			}
1692			if (attr->rta_type == XFRMA_SPD_IPV6_HTHRESH) {
1693				struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1694
1695				got_thresh++;
1696				if (t->lbits != 120 || t->rbits != 16) {
1697					pr_err("thresh differ: %u, %u",
1698							t->lbits, t->rbits);
1699					return KSFT_FAIL;
1700				}
1701			}
1702		}
1703		if (got_thresh != 2) {
1704			pr_err("only %d thresh returned by XFRM_MSG_GETSPDINFO", got_thresh);
1705			return KSFT_FAIL;
1706		}
1707	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1708		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1709		return KSFT_FAIL;
1710	} else {
1711		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1712		return -1;
1713	}
1714
1715	/* Restore the default */
1716	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, false)) {
1717		pr_err("Can't restore SPD HTHRESH");
1718		return KSFT_FAIL;
1719	}
1720
1721	/*
1722	 * At this moment xfrm uses nlmsg_parse_deprecated(), which
1723	 * implies NL_VALIDATE_LIBERAL - ignoring attributes with
1724	 * (type > maxtype). nla_parse_depricated_strict() would enforce
1725	 * it. Or even stricter nla_parse().
1726	 * Right now it's not expected to fail, but to be ignored.
1727	 */
1728	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, true))
1729		return KSFT_PASS;
1730
1731	return KSFT_PASS;
1732}
1733
1734static int child_serv(int xfrm_sock, uint32_t *seq,
1735		unsigned int nr, int cmd_fd, void *buf, struct xfrm_desc *desc)
1736{
1737	struct in_addr src, dst, tunsrc, tundst;
1738	struct test_desc msg;
1739	int ret = KSFT_FAIL;
1740
1741	src = inet_makeaddr(INADDR_B, child_ip(nr));
1742	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1743	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1744	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1745
1746	/* UDP pinging without xfrm */
1747	if (do_ping(cmd_fd, buf, page_size, src, true, 0, 0, udp_ping_send)) {
1748		printk("ping failed before setting xfrm");
1749		return KSFT_FAIL;
1750	}
1751
1752	memset(&msg, 0, sizeof(msg));
1753	msg.type = MSG_XFRM_PREPARE;
1754	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1755	write_msg(cmd_fd, &msg, 1);
1756
1757	if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1758		printk("failed to prepare xfrm");
1759		goto cleanup;
1760	}
1761
1762	memset(&msg, 0, sizeof(msg));
1763	msg.type = MSG_XFRM_ADD;
1764	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1765	write_msg(cmd_fd, &msg, 1);
1766	if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1767		printk("failed to set xfrm");
1768		goto delete;
1769	}
1770
1771	/* UDP pinging with xfrm tunnel */
1772	if (do_ping(cmd_fd, buf, page_size, tunsrc,
1773				true, 0, 0, udp_ping_send)) {
1774		printk("ping failed for xfrm");
1775		goto delete;
1776	}
1777
1778	ret = KSFT_PASS;
1779delete:
1780	/* xfrm delete */
1781	memset(&msg, 0, sizeof(msg));
1782	msg.type = MSG_XFRM_DEL;
1783	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1784	write_msg(cmd_fd, &msg, 1);
1785
1786	if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1787		printk("failed ping to remove xfrm");
1788		ret = KSFT_FAIL;
1789	}
1790
1791cleanup:
1792	memset(&msg, 0, sizeof(msg));
1793	msg.type = MSG_XFRM_CLEANUP;
1794	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1795	write_msg(cmd_fd, &msg, 1);
1796	if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1797		printk("failed ping to cleanup xfrm");
1798		ret = KSFT_FAIL;
1799	}
1800	return ret;
1801}
1802
1803static int child_f(unsigned int nr, int test_desc_fd, int cmd_fd, void *buf)
1804{
1805	struct xfrm_desc desc;
1806	struct test_desc msg;
1807	int xfrm_sock = -1;
1808	uint32_t seq;
1809
1810	if (switch_ns(nsfd_childa))
1811		exit(KSFT_FAIL);
1812
1813	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1814		printk("Failed to open xfrm netlink socket");
1815		exit(KSFT_FAIL);
1816	}
1817
1818	/* Check that seq sock is ready, just for sure. */
1819	memset(&msg, 0, sizeof(msg));
1820	msg.type = MSG_ACK;
1821	write_msg(cmd_fd, &msg, 1);
1822	read_msg(cmd_fd, &msg, 1);
1823	if (msg.type != MSG_ACK) {
1824		printk("Ack failed");
1825		exit(KSFT_FAIL);
1826	}
1827
1828	for (;;) {
1829		ssize_t received = read(test_desc_fd, &desc, sizeof(desc));
1830		int ret;
1831
1832		if (received == 0) /* EOF */
1833			break;
1834
1835		if (received != sizeof(desc)) {
1836			pr_err("read() returned %zd", received);
1837			exit(KSFT_FAIL);
1838		}
1839
1840		switch (desc.type) {
1841		case CREATE_TUNNEL:
1842			ret = child_serv(xfrm_sock, &seq, nr,
1843					 cmd_fd, buf, &desc);
1844			break;
1845		case ALLOCATE_SPI:
1846			ret = xfrm_state_allocspi(xfrm_sock, &seq,
1847						  -1, desc.proto);
1848			break;
1849		case MONITOR_ACQUIRE:
1850			ret = xfrm_monitor_acquire(xfrm_sock, &seq, nr);
1851			break;
1852		case EXPIRE_STATE:
1853			ret = xfrm_expire_state(xfrm_sock, &seq, nr, &desc);
1854			break;
1855		case EXPIRE_POLICY:
1856			ret = xfrm_expire_policy(xfrm_sock, &seq, nr, &desc);
1857			break;
1858		case SPDINFO_ATTRS:
1859			ret = xfrm_spdinfo_attrs(xfrm_sock, &seq);
1860			break;
1861		default:
1862			printk("Unknown desc type %d", desc.type);
1863			exit(KSFT_FAIL);
1864		}
1865		write_test_result(ret, &desc);
1866	}
1867
1868	close(xfrm_sock);
1869
1870	msg.type = MSG_EXIT;
1871	write_msg(cmd_fd, &msg, 1);
1872	exit(KSFT_PASS);
1873}
1874
1875static void grand_child_serv(unsigned int nr, int cmd_fd, void *buf,
1876		struct test_desc *msg, int xfrm_sock, uint32_t *seq)
1877{
1878	struct in_addr src, dst, tunsrc, tundst;
1879	bool tun_reply;
1880	struct xfrm_desc *desc = &msg->body.xfrm_desc;
1881
1882	src = inet_makeaddr(INADDR_B, grchild_ip(nr));
1883	dst = inet_makeaddr(INADDR_B, child_ip(nr));
1884	tunsrc = inet_makeaddr(INADDR_A, grchild_ip(nr));
1885	tundst = inet_makeaddr(INADDR_A, child_ip(nr));
1886
1887	switch (msg->type) {
1888	case MSG_EXIT:
1889		exit(KSFT_PASS);
1890	case MSG_ACK:
1891		write_msg(cmd_fd, msg, 1);
1892		break;
1893	case MSG_PING:
1894		tun_reply = memcmp(&dst, &msg->body.ping.reply_ip, sizeof(in_addr_t));
1895		/* UDP pinging without xfrm */
1896		if (do_ping(cmd_fd, buf, page_size, tun_reply ? tunsrc : src,
1897				false, msg->body.ping.port,
1898				msg->body.ping.reply_ip, udp_ping_reply)) {
1899			printk("ping failed before setting xfrm");
1900		}
1901		break;
1902	case MSG_XFRM_PREPARE:
1903		if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst,
1904					desc->proto)) {
1905			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1906			printk("failed to prepare xfrm");
1907		}
1908		break;
1909	case MSG_XFRM_ADD:
1910		if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1911			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1912			printk("failed to set xfrm");
1913		}
1914		break;
1915	case MSG_XFRM_DEL:
1916		if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst,
1917					desc->proto)) {
1918			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1919			printk("failed to remove xfrm");
1920		}
1921		break;
1922	case MSG_XFRM_CLEANUP:
1923		if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1924			printk("failed to cleanup xfrm");
1925		}
1926		break;
1927	default:
1928		printk("got unknown msg type %d", msg->type);
1929	}
1930}
1931
1932static int grand_child_f(unsigned int nr, int cmd_fd, void *buf)
1933{
1934	struct test_desc msg;
1935	int xfrm_sock = -1;
1936	uint32_t seq;
1937
1938	if (switch_ns(nsfd_childb))
1939		exit(KSFT_FAIL);
1940
1941	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1942		printk("Failed to open xfrm netlink socket");
1943		exit(KSFT_FAIL);
1944	}
1945
1946	do {
1947		read_msg(cmd_fd, &msg, 1);
1948		grand_child_serv(nr, cmd_fd, buf, &msg, xfrm_sock, &seq);
1949	} while (1);
1950
1951	close(xfrm_sock);
1952	exit(KSFT_FAIL);
1953}
1954
1955static int start_child(unsigned int nr, char *veth, int test_desc_fd[2])
1956{
1957	int cmd_sock[2];
1958	void *data_map;
1959	pid_t child;
1960
1961	if (init_child(nsfd_childa, veth, child_ip(nr), grchild_ip(nr)))
1962		return -1;
1963
1964	if (init_child(nsfd_childb, veth, grchild_ip(nr), child_ip(nr)))
1965		return -1;
1966
1967	child = fork();
1968	if (child < 0) {
1969		pr_err("fork()");
1970		return -1;
1971	} else if (child) {
1972		/* in parent - selftest */
1973		return switch_ns(nsfd_parent);
1974	}
1975
1976	if (close(test_desc_fd[1])) {
1977		pr_err("close()");
1978		return -1;
1979	}
1980
1981	/* child */
1982	data_map = mmap(0, page_size, PROT_READ | PROT_WRITE,
1983			MAP_SHARED | MAP_ANONYMOUS, -1, 0);
1984	if (data_map == MAP_FAILED) {
1985		pr_err("mmap()");
1986		return -1;
1987	}
1988
1989	randomize_buffer(data_map, page_size);
1990
1991	if (socketpair(PF_LOCAL, SOCK_SEQPACKET, 0, cmd_sock)) {
1992		pr_err("socketpair()");
1993		return -1;
1994	}
1995
1996	child = fork();
1997	if (child < 0) {
1998		pr_err("fork()");
1999		return -1;
2000	} else if (child) {
2001		if (close(cmd_sock[0])) {
2002			pr_err("close()");
2003			return -1;
2004		}
2005		return child_f(nr, test_desc_fd[0], cmd_sock[1], data_map);
2006	}
2007	if (close(cmd_sock[1])) {
2008		pr_err("close()");
2009		return -1;
2010	}
2011	return grand_child_f(nr, cmd_sock[0], data_map);
2012}
2013
2014static void exit_usage(char **argv)
2015{
2016	printk("Usage: %s [nr_process]", argv[0]);
2017	exit(KSFT_FAIL);
2018}
2019
2020static int __write_desc(int test_desc_fd, struct xfrm_desc *desc)
2021{
2022	ssize_t ret;
2023
2024	ret = write(test_desc_fd, desc, sizeof(*desc));
2025
2026	if (ret == sizeof(*desc))
2027		return 0;
2028
2029	pr_err("Writing test's desc failed %ld", ret);
2030
2031	return -1;
2032}
2033
2034static int write_desc(int proto, int test_desc_fd,
2035		char *a, char *e, char *c, char *ae)
2036{
2037	struct xfrm_desc desc = {};
2038
2039	desc.type = CREATE_TUNNEL;
2040	desc.proto = proto;
2041
2042	if (a)
2043		strncpy(desc.a_algo, a, ALGO_LEN - 1);
2044	if (e)
2045		strncpy(desc.e_algo, e, ALGO_LEN - 1);
2046	if (c)
2047		strncpy(desc.c_algo, c, ALGO_LEN - 1);
2048	if (ae)
2049		strncpy(desc.ae_algo, ae, ALGO_LEN - 1);
2050
2051	return __write_desc(test_desc_fd, &desc);
2052}
2053
2054int proto_list[] = { IPPROTO_AH, IPPROTO_COMP, IPPROTO_ESP };
2055char *ah_list[] = {
2056	"digest_null", "hmac(md5)", "hmac(sha1)", "hmac(sha256)",
2057	"hmac(sha384)", "hmac(sha512)", "hmac(rmd160)",
2058	"xcbc(aes)", "cmac(aes)"
2059};
2060char *comp_list[] = {
2061	"deflate",
2062#if 0
2063	/* No compression backend realization */
2064	"lzs", "lzjh"
2065#endif
2066};
2067char *e_list[] = {
2068	"ecb(cipher_null)", "cbc(des)", "cbc(des3_ede)", "cbc(cast5)",
2069	"cbc(blowfish)", "cbc(aes)", "cbc(serpent)", "cbc(camellia)",
2070	"cbc(twofish)", "rfc3686(ctr(aes))"
2071};
2072char *ae_list[] = {
2073#if 0
2074	/* not implemented */
2075	"rfc4106(gcm(aes))", "rfc4309(ccm(aes))", "rfc4543(gcm(aes))",
2076	"rfc7539esp(chacha20,poly1305)"
2077#endif
2078};
2079
2080const unsigned int proto_plan = ARRAY_SIZE(ah_list) + ARRAY_SIZE(comp_list) \
2081				+ (ARRAY_SIZE(ah_list) * ARRAY_SIZE(e_list)) \
2082				+ ARRAY_SIZE(ae_list);
2083
2084static int write_proto_plan(int fd, int proto)
2085{
2086	unsigned int i;
2087
2088	switch (proto) {
2089	case IPPROTO_AH:
2090		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2091			if (write_desc(proto, fd, ah_list[i], 0, 0, 0))
2092				return -1;
2093		}
2094		break;
2095	case IPPROTO_COMP:
2096		for (i = 0; i < ARRAY_SIZE(comp_list); i++) {
2097			if (write_desc(proto, fd, 0, 0, comp_list[i], 0))
2098				return -1;
2099		}
2100		break;
2101	case IPPROTO_ESP:
2102		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2103			int j;
2104
2105			for (j = 0; j < ARRAY_SIZE(e_list); j++) {
2106				if (write_desc(proto, fd, ah_list[i],
2107							e_list[j], 0, 0))
2108					return -1;
2109			}
2110		}
2111		for (i = 0; i < ARRAY_SIZE(ae_list); i++) {
2112			if (write_desc(proto, fd, 0, 0, 0, ae_list[i]))
2113				return -1;
2114		}
2115		break;
2116	default:
2117		printk("BUG: Specified unknown proto %d", proto);
2118		return -1;
2119	}
2120
2121	return 0;
2122}
2123
2124/*
2125 * Some structures in xfrm uapi header differ in size between
2126 * 64-bit and 32-bit ABI:
2127 *
2128 *             32-bit UABI               |            64-bit UABI
2129 *  -------------------------------------|-------------------------------------
2130 *   sizeof(xfrm_usersa_info)     = 220  |  sizeof(xfrm_usersa_info)     = 224
2131 *   sizeof(xfrm_userpolicy_info) = 164  |  sizeof(xfrm_userpolicy_info) = 168
2132 *   sizeof(xfrm_userspi_info)    = 228  |  sizeof(xfrm_userspi_info)    = 232
2133 *   sizeof(xfrm_user_acquire)    = 276  |  sizeof(xfrm_user_acquire)    = 280
2134 *   sizeof(xfrm_user_expire)     = 224  |  sizeof(xfrm_user_expire)     = 232
2135 *   sizeof(xfrm_user_polexpire)  = 168  |  sizeof(xfrm_user_polexpire)  = 176
2136 *
2137 * Check the affected by the UABI difference structures.
2138 * Also, check translation for xfrm_set_spdinfo: it has it's own attributes
2139 * which needs to be correctly copied, but not translated.
2140 */
2141const unsigned int compat_plan = 5;
2142static int write_compat_struct_tests(int test_desc_fd)
2143{
2144	struct xfrm_desc desc = {};
2145
2146	desc.type = ALLOCATE_SPI;
2147	desc.proto = IPPROTO_AH;
2148	strncpy(desc.a_algo, ah_list[0], ALGO_LEN - 1);
2149
2150	if (__write_desc(test_desc_fd, &desc))
2151		return -1;
2152
2153	desc.type = MONITOR_ACQUIRE;
2154	if (__write_desc(test_desc_fd, &desc))
2155		return -1;
2156
2157	desc.type = EXPIRE_STATE;
2158	if (__write_desc(test_desc_fd, &desc))
2159		return -1;
2160
2161	desc.type = EXPIRE_POLICY;
2162	if (__write_desc(test_desc_fd, &desc))
2163		return -1;
2164
2165	desc.type = SPDINFO_ATTRS;
2166	if (__write_desc(test_desc_fd, &desc))
2167		return -1;
2168
2169	return 0;
2170}
2171
2172static int write_test_plan(int test_desc_fd)
2173{
2174	unsigned int i;
2175	pid_t child;
2176
2177	child = fork();
2178	if (child < 0) {
2179		pr_err("fork()");
2180		return -1;
2181	}
2182	if (child) {
2183		if (close(test_desc_fd))
2184			printk("close(): %m");
2185		return 0;
2186	}
2187
2188	if (write_compat_struct_tests(test_desc_fd))
2189		exit(KSFT_FAIL);
2190
2191	for (i = 0; i < ARRAY_SIZE(proto_list); i++) {
2192		if (write_proto_plan(test_desc_fd, proto_list[i]))
2193			exit(KSFT_FAIL);
2194	}
2195
2196	exit(KSFT_PASS);
2197}
2198
2199static int children_cleanup(void)
2200{
2201	unsigned ret = KSFT_PASS;
2202
2203	while (1) {
2204		int status;
2205		pid_t p = wait(&status);
2206
2207		if ((p < 0) && errno == ECHILD)
2208			break;
2209
2210		if (p < 0) {
2211			pr_err("wait()");
2212			return KSFT_FAIL;
2213		}
2214
2215		if (!WIFEXITED(status)) {
2216			ret = KSFT_FAIL;
2217			continue;
2218		}
2219
2220		if (WEXITSTATUS(status) == KSFT_FAIL)
2221			ret = KSFT_FAIL;
2222	}
2223
2224	return ret;
2225}
2226
2227typedef void (*print_res)(const char *, ...);
2228
2229static int check_results(void)
2230{
2231	struct test_result tr = {};
2232	struct xfrm_desc *d = &tr.desc;
2233	int ret = KSFT_PASS;
2234
2235	while (1) {
2236		ssize_t received = read(results_fd[0], &tr, sizeof(tr));
2237		print_res result;
2238
2239		if (received == 0) /* EOF */
2240			break;
2241
2242		if (received != sizeof(tr)) {
2243			pr_err("read() returned %zd", received);
2244			return KSFT_FAIL;
2245		}
2246
2247		switch (tr.res) {
2248		case KSFT_PASS:
2249			result = ksft_test_result_pass;
2250			break;
2251		case KSFT_FAIL:
2252		default:
2253			result = ksft_test_result_fail;
2254			ret = KSFT_FAIL;
2255		}
2256
2257		result(" %s: [%u, '%s', '%s', '%s', '%s', %u]\n",
2258		       desc_name[d->type], (unsigned int)d->proto, d->a_algo,
2259		       d->e_algo, d->c_algo, d->ae_algo, d->icv_len);
2260	}
2261
2262	return ret;
2263}
2264
2265int main(int argc, char **argv)
2266{
2267	long nr_process = 1;
2268	int route_sock = -1, ret = KSFT_SKIP;
2269	int test_desc_fd[2];
2270	uint32_t route_seq;
2271	unsigned int i;
2272
2273	if (argc > 2)
2274		exit_usage(argv);
2275
2276	if (argc > 1) {
2277		char *endptr;
2278
2279		errno = 0;
2280		nr_process = strtol(argv[1], &endptr, 10);
2281		if ((errno == ERANGE && (nr_process == LONG_MAX || nr_process == LONG_MIN))
2282				|| (errno != 0 && nr_process == 0)
2283				|| (endptr == argv[1]) || (*endptr != '\0')) {
2284			printk("Failed to parse [nr_process]");
2285			exit_usage(argv);
2286		}
2287
2288		if (nr_process > MAX_PROCESSES || nr_process < 1) {
2289			printk("nr_process should be between [1; %u]",
2290					MAX_PROCESSES);
2291			exit_usage(argv);
2292		}
2293	}
2294
2295	srand(time(NULL));
2296	page_size = sysconf(_SC_PAGESIZE);
2297	if (page_size < 1)
2298		ksft_exit_skip("sysconf(): %m\n");
2299
2300	if (pipe2(test_desc_fd, O_DIRECT) < 0)
2301		ksft_exit_skip("pipe(): %m\n");
2302
2303	if (pipe2(results_fd, O_DIRECT) < 0)
2304		ksft_exit_skip("pipe(): %m\n");
2305
2306	if (init_namespaces())
2307		ksft_exit_skip("Failed to create namespaces\n");
2308
2309	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
2310		ksft_exit_skip("Failed to open netlink route socket\n");
2311
2312	for (i = 0; i < nr_process; i++) {
2313		char veth[VETH_LEN];
2314
2315		snprintf(veth, VETH_LEN, VETH_FMT, i);
2316
2317		if (veth_add(route_sock, route_seq++, veth, nsfd_childa, veth, nsfd_childb)) {
2318			close(route_sock);
2319			ksft_exit_fail_msg("Failed to create veth device");
2320		}
2321
2322		if (start_child(i, veth, test_desc_fd)) {
2323			close(route_sock);
2324			ksft_exit_fail_msg("Child %u failed to start", i);
2325		}
2326	}
2327
2328	if (close(route_sock) || close(test_desc_fd[0]) || close(results_fd[1]))
2329		ksft_exit_fail_msg("close(): %m");
2330
2331	ksft_set_plan(proto_plan + compat_plan);
2332
2333	if (write_test_plan(test_desc_fd[1]))
2334		ksft_exit_fail_msg("Failed to write test plan to pipe");
2335
2336	ret = check_results();
2337
2338	if (children_cleanup() == KSFT_FAIL)
2339		exit(KSFT_FAIL);
2340
2341	exit(ret);
2342}