Linux Audio

Check our new training course

Loading...
v6.13.7
  1// SPDX-License-Identifier: GPL-2.0
  2/* Copyright Amazon.com Inc. or its affiliates. */
  3
  4#include "vmlinux.h"
  5
  6#include <bpf/bpf_helpers.h>
  7#include <bpf/bpf_endian.h>
  8#include "bpf_tracing_net.h"
  9#include "bpf_kfuncs.h"
 10#include "test_siphash.h"
 11#include "test_tcp_custom_syncookie.h"
 12#include "bpf_misc.h"
 13
 14#define MAX_PACKET_OFF 0xffff
 15
 16/* Hash is calculated for each client and split into ISN and TS.
 17 *
 18 *       MSB                                   LSB
 19 * ISN:  | 31 ... 8 | 7 6 |   5 |    4 | 3 2 1 0 |
 20 *       |   Hash_1 | MSS | ECN | SACK |  WScale |
 21 *
 22 * TS:   | 31 ... 8 |          7 ... 0           |
 23 *       |   Random |           Hash_2           |
 24 */
 25#define COOKIE_BITS	8
 26#define COOKIE_MASK	(((__u32)1 << COOKIE_BITS) - 1)
 27
 28enum {
 29	/* 0xf is invalid thus means that SYN did not have WScale. */
 30	BPF_SYNCOOKIE_WSCALE_MASK	= (1 << 4) - 1,
 31	BPF_SYNCOOKIE_SACK		= (1 << 4),
 32	BPF_SYNCOOKIE_ECN		= (1 << 5),
 33};
 34
 35#define MSS_LOCAL_IPV4	65495
 36#define MSS_LOCAL_IPV6	65476
 37
 38const __u16 msstab4[] = {
 39	536,
 40	1300,
 41	1460,
 42	MSS_LOCAL_IPV4,
 43};
 44
 45const __u16 msstab6[] = {
 46	1280 - 60, /* IPV6_MIN_MTU - 60 */
 47	1480 - 60,
 48	9000 - 60,
 49	MSS_LOCAL_IPV6,
 50};
 51
 52static siphash_key_t test_key_siphash = {
 53	{ 0x0706050403020100ULL, 0x0f0e0d0c0b0a0908ULL }
 54};
 55
 56struct tcp_syncookie {
 57	struct __sk_buff *skb;
 58	void *data;
 59	void *data_end;
 60	struct ethhdr *eth;
 61	struct iphdr *ipv4;
 62	struct ipv6hdr *ipv6;
 63	struct tcphdr *tcp;
 64	__be32 *ptr32;
 65	struct bpf_tcp_req_attrs attrs;
 66	u32 off;
 67	u32 cookie;
 68	u64 first;
 69};
 70
 71bool handled_syn, handled_ack;
 72
 73static int tcp_load_headers(struct tcp_syncookie *ctx)
 74{
 75	ctx->data = (void *)(long)ctx->skb->data;
 76	ctx->data_end = (void *)(long)ctx->skb->data_end;
 77	ctx->eth = (struct ethhdr *)(long)ctx->skb->data;
 78
 79	if (ctx->eth + 1 > ctx->data_end)
 80		goto err;
 81
 82	switch (bpf_ntohs(ctx->eth->h_proto)) {
 83	case ETH_P_IP:
 84		ctx->ipv4 = (struct iphdr *)(ctx->eth + 1);
 85
 86		if (ctx->ipv4 + 1 > ctx->data_end)
 87			goto err;
 88
 89		if (ctx->ipv4->ihl != sizeof(*ctx->ipv4) / 4)
 90			goto err;
 91
 92		if (ctx->ipv4->version != 4)
 93			goto err;
 94
 95		if (ctx->ipv4->protocol != IPPROTO_TCP)
 96			goto err;
 97
 98		ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1);
 99		break;
100	case ETH_P_IPV6:
101		ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1);
102
103		if (ctx->ipv6 + 1 > ctx->data_end)
104			goto err;
105
106		if (ctx->ipv6->version != 6)
107			goto err;
108
109		if (ctx->ipv6->nexthdr != NEXTHDR_TCP)
110			goto err;
111
112		ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1);
113		break;
114	default:
115		goto err;
116	}
117
118	if (ctx->tcp + 1 > ctx->data_end)
119		goto err;
120
121	return 0;
122err:
123	return -1;
124}
125
126static int tcp_reload_headers(struct tcp_syncookie *ctx)
127{
128	/* Without volatile,
129	 * R3 32-bit pointer arithmetic prohibited
130	 */
131	volatile u64 data_len = ctx->skb->data_end - ctx->skb->data;
132
133	if (ctx->tcp->doff < sizeof(*ctx->tcp) / 4)
134		goto err;
135
136	/* Needed to calculate csum and parse TCP options. */
137	if (bpf_skb_change_tail(ctx->skb, data_len + 60 - ctx->tcp->doff * 4, 0))
138		goto err;
139
140	ctx->data = (void *)(long)ctx->skb->data;
141	ctx->data_end = (void *)(long)ctx->skb->data_end;
142	ctx->eth = (struct ethhdr *)(long)ctx->skb->data;
143	if (ctx->ipv4) {
144		ctx->ipv4 = (struct iphdr *)(ctx->eth + 1);
145		ctx->ipv6 = NULL;
146		ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1);
147	} else {
148		ctx->ipv4 = NULL;
149		ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1);
150		ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1);
151	}
152
153	if ((void *)ctx->tcp + 60 > ctx->data_end)
154		goto err;
155
156	return 0;
157err:
158	return -1;
159}
160
161static __sum16 tcp_v4_csum(struct tcp_syncookie *ctx, __wsum csum)
162{
163	return csum_tcpudp_magic(ctx->ipv4->saddr, ctx->ipv4->daddr,
164				 ctx->tcp->doff * 4, IPPROTO_TCP, csum);
165}
166
167static __sum16 tcp_v6_csum(struct tcp_syncookie *ctx, __wsum csum)
168{
169	return csum_ipv6_magic(&ctx->ipv6->saddr, &ctx->ipv6->daddr,
170			       ctx->tcp->doff * 4, IPPROTO_TCP, csum);
171}
172
173static int tcp_validate_header(struct tcp_syncookie *ctx)
174{
175	s64 csum;
176
177	if (tcp_reload_headers(ctx))
178		goto err;
179
180	csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0);
181	if (csum < 0)
182		goto err;
183
184	if (ctx->ipv4) {
185		/* check tcp_v4_csum(csum) is 0 if not on lo. */
186
187		csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, ctx->ipv4->ihl * 4, 0);
188		if (csum < 0)
189			goto err;
190
191		if (csum_fold(csum) != 0)
192			goto err;
193	} else if (ctx->ipv6) {
194		/* check tcp_v6_csum(csum) is 0 if not on lo. */
195	}
196
197	return 0;
198err:
199	return -1;
200}
201
202static __always_inline void *next(struct tcp_syncookie *ctx, __u32 sz)
203{
204	__u64 off = ctx->off;
205	__u8 *data;
206
207	/* Verifier forbids access to packet when offset exceeds MAX_PACKET_OFF */
208	if (off > MAX_PACKET_OFF - sz)
209		return NULL;
210
211	data = ctx->data + off;
212	barrier_var(data);
213	if (data + sz >= ctx->data_end)
214		return NULL;
215
216	ctx->off += sz;
217	return data;
218}
219
220static int tcp_parse_option(__u32 index, struct tcp_syncookie *ctx)
221{
222	__u8 *opcode, *opsize, *wscale;
223	__u32 *tsval, *tsecr;
224	__u16 *mss;
225	__u32 off;
226
227	off = ctx->off;
228	opcode = next(ctx, 1);
229	if (!opcode)
230		goto stop;
231
232	if (*opcode == TCPOPT_EOL)
233		goto stop;
234
235	if (*opcode == TCPOPT_NOP)
236		goto next;
237
238	opsize = next(ctx, 1);
239	if (!opsize)
240		goto stop;
241
242	if (*opsize < 2)
243		goto stop;
244
245	switch (*opcode) {
246	case TCPOPT_MSS:
247		mss = next(ctx, 2);
248		if (*opsize == TCPOLEN_MSS && ctx->tcp->syn && mss)
249			ctx->attrs.mss = get_unaligned_be16(mss);
250		break;
251	case TCPOPT_WINDOW:
252		wscale = next(ctx, 1);
253		if (*opsize == TCPOLEN_WINDOW && ctx->tcp->syn && wscale) {
254			ctx->attrs.wscale_ok = 1;
255			ctx->attrs.snd_wscale = *wscale;
256		}
257		break;
258	case TCPOPT_TIMESTAMP:
259		tsval = next(ctx, 4);
260		tsecr = next(ctx, 4);
261		if (*opsize == TCPOLEN_TIMESTAMP && tsval && tsecr) {
262			ctx->attrs.rcv_tsval = get_unaligned_be32(tsval);
263			ctx->attrs.rcv_tsecr = get_unaligned_be32(tsecr);
264
265			if (ctx->tcp->syn && ctx->attrs.rcv_tsecr)
266				ctx->attrs.tstamp_ok = 0;
267			else
268				ctx->attrs.tstamp_ok = 1;
269		}
270		break;
271	case TCPOPT_SACK_PERM:
272		if (*opsize == TCPOLEN_SACK_PERM && ctx->tcp->syn)
273			ctx->attrs.sack_ok = 1;
274		break;
275	}
276
277	ctx->off = off + *opsize;
278next:
279	return 0;
280stop:
281	return 1;
282}
283
284static void tcp_parse_options(struct tcp_syncookie *ctx)
285{
286	ctx->off = (__u8 *)(ctx->tcp + 1) - (__u8 *)ctx->data,
287
288	bpf_loop(40, tcp_parse_option, ctx, 0);
289}
290
291static int tcp_validate_sysctl(struct tcp_syncookie *ctx)
292{
293	if ((ctx->ipv4 && ctx->attrs.mss != MSS_LOCAL_IPV4) ||
294	    (ctx->ipv6 && ctx->attrs.mss != MSS_LOCAL_IPV6))
295		goto err;
296
297	if (!ctx->attrs.wscale_ok || ctx->attrs.snd_wscale != 7)
298		goto err;
299
300	if (!ctx->attrs.tstamp_ok)
301		goto err;
302
303	if (!ctx->attrs.sack_ok)
304		goto err;
305
306	if (!ctx->tcp->ece || !ctx->tcp->cwr)
307		goto err;
308
309	return 0;
310err:
311	return -1;
312}
313
314static void tcp_prepare_cookie(struct tcp_syncookie *ctx)
315{
316	u32 seq = bpf_ntohl(ctx->tcp->seq);
317	u64 first = 0, second;
318	int mssind = 0;
319	u32 hash;
320
321	if (ctx->ipv4) {
322		for (mssind = ARRAY_SIZE(msstab4) - 1; mssind; mssind--)
323			if (ctx->attrs.mss >= msstab4[mssind])
324				break;
325
326		ctx->attrs.mss = msstab4[mssind];
327
328		first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr;
329	} else if (ctx->ipv6) {
330		for (mssind = ARRAY_SIZE(msstab6) - 1; mssind; mssind--)
331			if (ctx->attrs.mss >= msstab6[mssind])
332				break;
333
334		ctx->attrs.mss = msstab6[mssind];
335
336		first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 |
337			ctx->ipv6->daddr.in6_u.u6_addr32[0];
338	}
339
340	second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest;
341	hash = siphash_2u64(first, second, &test_key_siphash);
342
343	if (ctx->attrs.tstamp_ok) {
344		ctx->attrs.rcv_tsecr = bpf_get_prandom_u32();
345		ctx->attrs.rcv_tsecr &= ~COOKIE_MASK;
346		ctx->attrs.rcv_tsecr |= hash & COOKIE_MASK;
347	}
348
349	hash &= ~COOKIE_MASK;
350	hash |= mssind << 6;
351
352	if (ctx->attrs.wscale_ok)
353		hash |= ctx->attrs.snd_wscale & BPF_SYNCOOKIE_WSCALE_MASK;
354
355	if (ctx->attrs.sack_ok)
356		hash |= BPF_SYNCOOKIE_SACK;
357
358	if (ctx->attrs.tstamp_ok && ctx->tcp->ece && ctx->tcp->cwr)
359		hash |= BPF_SYNCOOKIE_ECN;
360
361	ctx->cookie = hash;
362}
363
364static void tcp_write_options(struct tcp_syncookie *ctx)
365{
366	ctx->ptr32 = (__be32 *)(ctx->tcp + 1);
367
368	*ctx->ptr32++ = bpf_htonl(TCPOPT_MSS << 24 | TCPOLEN_MSS << 16 |
369				  ctx->attrs.mss);
370
371	if (ctx->attrs.wscale_ok)
372		*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
373					  TCPOPT_WINDOW << 16 |
374					  TCPOLEN_WINDOW << 8 |
375					  ctx->attrs.snd_wscale);
376
377	if (ctx->attrs.tstamp_ok) {
378		if (ctx->attrs.sack_ok)
379			*ctx->ptr32++ = bpf_htonl(TCPOPT_SACK_PERM << 24 |
380						  TCPOLEN_SACK_PERM << 16 |
381						  TCPOPT_TIMESTAMP << 8 |
382						  TCPOLEN_TIMESTAMP);
383		else
384			*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
385						  TCPOPT_NOP << 16 |
386						  TCPOPT_TIMESTAMP << 8 |
387						  TCPOLEN_TIMESTAMP);
388
389		*ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsecr);
390		*ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsval);
391	} else if (ctx->attrs.sack_ok) {
392		*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
393					  TCPOPT_NOP << 16 |
394					  TCPOPT_SACK_PERM << 8 |
395					  TCPOLEN_SACK_PERM);
396	}
397}
398
399static int tcp_handle_syn(struct tcp_syncookie *ctx)
400{
401	s64 csum;
402
403	if (tcp_validate_header(ctx))
404		goto err;
405
406	tcp_parse_options(ctx);
407
408	if (tcp_validate_sysctl(ctx))
409		goto err;
410
411	tcp_prepare_cookie(ctx);
412	tcp_write_options(ctx);
413
414	swap(ctx->tcp->source, ctx->tcp->dest);
415	ctx->tcp->check = 0;
416	ctx->tcp->ack_seq = bpf_htonl(bpf_ntohl(ctx->tcp->seq) + 1);
417	ctx->tcp->seq = bpf_htonl(ctx->cookie);
418	ctx->tcp->doff = ((long)ctx->ptr32 - (long)ctx->tcp) >> 2;
419	ctx->tcp->ack = 1;
420	if (!ctx->attrs.tstamp_ok || !ctx->tcp->ece || !ctx->tcp->cwr)
421		ctx->tcp->ece = 0;
422	ctx->tcp->cwr = 0;
423
424	csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0);
425	if (csum < 0)
426		goto err;
427
428	if (ctx->ipv4) {
429		swap(ctx->ipv4->saddr, ctx->ipv4->daddr);
430		ctx->tcp->check = tcp_v4_csum(ctx, csum);
431
432		ctx->ipv4->check = 0;
433		ctx->ipv4->tos = 0;
434		ctx->ipv4->tot_len = bpf_htons((long)ctx->ptr32 - (long)ctx->ipv4);
435		ctx->ipv4->id = 0;
436		ctx->ipv4->ttl = 64;
437
438		csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, sizeof(*ctx->ipv4), 0);
439		if (csum < 0)
440			goto err;
441
442		ctx->ipv4->check = csum_fold(csum);
443	} else if (ctx->ipv6) {
444		swap(ctx->ipv6->saddr, ctx->ipv6->daddr);
445		ctx->tcp->check = tcp_v6_csum(ctx, csum);
446
447		*(__be32 *)ctx->ipv6 = bpf_htonl(0x60000000);
448		ctx->ipv6->payload_len = bpf_htons((long)ctx->ptr32 - (long)ctx->tcp);
449		ctx->ipv6->hop_limit = 64;
450	}
451
452	swap_array(ctx->eth->h_source, ctx->eth->h_dest);
453
454	if (bpf_skb_change_tail(ctx->skb, (long)ctx->ptr32 - (long)ctx->eth, 0))
455		goto err;
456
457	return bpf_redirect(ctx->skb->ifindex, 0);
458err:
459	return TC_ACT_SHOT;
460}
461
462static int tcp_validate_cookie(struct tcp_syncookie *ctx)
463{
464	u32 cookie = bpf_ntohl(ctx->tcp->ack_seq) - 1;
465	u32 seq = bpf_ntohl(ctx->tcp->seq) - 1;
466	u64 first = 0, second;
467	int mssind;
468	u32 hash;
469
470	if (ctx->ipv4)
471		first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr;
472	else if (ctx->ipv6)
473		first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 |
474			ctx->ipv6->daddr.in6_u.u6_addr32[0];
475
476	second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest;
477	hash = siphash_2u64(first, second, &test_key_siphash);
478
479	if (ctx->attrs.tstamp_ok)
480		hash -= ctx->attrs.rcv_tsecr & COOKIE_MASK;
481	else
482		hash &= ~COOKIE_MASK;
483
484	hash -= cookie & ~COOKIE_MASK;
485	if (hash)
486		goto err;
487
488	mssind = (cookie & (3 << 6)) >> 6;
489	if (ctx->ipv4)
 
 
 
490		ctx->attrs.mss = msstab4[mssind];
491	else
 
 
 
492		ctx->attrs.mss = msstab6[mssind];
 
493
494	ctx->attrs.snd_wscale = cookie & BPF_SYNCOOKIE_WSCALE_MASK;
495	ctx->attrs.rcv_wscale = ctx->attrs.snd_wscale;
496	ctx->attrs.wscale_ok = ctx->attrs.snd_wscale == BPF_SYNCOOKIE_WSCALE_MASK;
497	ctx->attrs.sack_ok = cookie & BPF_SYNCOOKIE_SACK;
498	ctx->attrs.ecn_ok = cookie & BPF_SYNCOOKIE_ECN;
499
500	return 0;
501err:
502	return -1;
503}
504
505static int tcp_handle_ack(struct tcp_syncookie *ctx)
506{
507	struct bpf_sock_tuple tuple;
508	struct bpf_sock *skc;
509	int ret = TC_ACT_OK;
510	struct sock *sk;
511	u32 tuple_size;
512
513	if (ctx->ipv4) {
514		tuple.ipv4.saddr = ctx->ipv4->saddr;
515		tuple.ipv4.daddr = ctx->ipv4->daddr;
516		tuple.ipv4.sport = ctx->tcp->source;
517		tuple.ipv4.dport = ctx->tcp->dest;
518		tuple_size = sizeof(tuple.ipv4);
519	} else if (ctx->ipv6) {
520		__builtin_memcpy(tuple.ipv6.saddr, &ctx->ipv6->saddr, sizeof(tuple.ipv6.saddr));
521		__builtin_memcpy(tuple.ipv6.daddr, &ctx->ipv6->daddr, sizeof(tuple.ipv6.daddr));
522		tuple.ipv6.sport = ctx->tcp->source;
523		tuple.ipv6.dport = ctx->tcp->dest;
524		tuple_size = sizeof(tuple.ipv6);
525	} else {
526		goto out;
527	}
528
529	skc = bpf_skc_lookup_tcp(ctx->skb, &tuple, tuple_size, -1, 0);
530	if (!skc)
531		goto out;
532
533	if (skc->state != TCP_LISTEN)
534		goto release;
535
536	sk = (struct sock *)bpf_skc_to_tcp_sock(skc);
537	if (!sk)
538		goto err;
539
540	if (tcp_validate_header(ctx))
541		goto err;
542
543	tcp_parse_options(ctx);
544
545	if (tcp_validate_cookie(ctx))
546		goto err;
547
548	ret = bpf_sk_assign_tcp_reqsk(ctx->skb, sk, &ctx->attrs, sizeof(ctx->attrs));
549	if (ret < 0)
550		goto err;
551
552release:
553	bpf_sk_release(skc);
554out:
555	return ret;
556
557err:
558	ret = TC_ACT_SHOT;
559	goto release;
560}
561
562SEC("tc")
563int tcp_custom_syncookie(struct __sk_buff *skb)
564{
565	struct tcp_syncookie ctx = {
566		.skb = skb,
567	};
568
569	if (tcp_load_headers(&ctx))
570		return TC_ACT_OK;
571
572	if (ctx.tcp->rst)
573		return TC_ACT_OK;
574
575	if (ctx.tcp->syn) {
576		if (ctx.tcp->ack)
577			return TC_ACT_OK;
578
579		handled_syn = true;
580
581		return tcp_handle_syn(&ctx);
582	}
583
584	handled_ack = true;
585
586	return tcp_handle_ack(&ctx);
587}
588
589char _license[] SEC("license") = "GPL";
v6.9.4
  1// SPDX-License-Identifier: GPL-2.0
  2/* Copyright Amazon.com Inc. or its affiliates. */
  3
  4#include "vmlinux.h"
  5
  6#include <bpf/bpf_helpers.h>
  7#include <bpf/bpf_endian.h>
  8#include "bpf_tracing_net.h"
  9#include "bpf_kfuncs.h"
 10#include "test_siphash.h"
 11#include "test_tcp_custom_syncookie.h"
 
 12
 13#define MAX_PACKET_OFF 0xffff
 14
 15/* Hash is calculated for each client and split into ISN and TS.
 16 *
 17 *       MSB                                   LSB
 18 * ISN:  | 31 ... 8 | 7 6 |   5 |    4 | 3 2 1 0 |
 19 *       |   Hash_1 | MSS | ECN | SACK |  WScale |
 20 *
 21 * TS:   | 31 ... 8 |          7 ... 0           |
 22 *       |   Random |           Hash_2           |
 23 */
 24#define COOKIE_BITS	8
 25#define COOKIE_MASK	(((__u32)1 << COOKIE_BITS) - 1)
 26
 27enum {
 28	/* 0xf is invalid thus means that SYN did not have WScale. */
 29	BPF_SYNCOOKIE_WSCALE_MASK	= (1 << 4) - 1,
 30	BPF_SYNCOOKIE_SACK		= (1 << 4),
 31	BPF_SYNCOOKIE_ECN		= (1 << 5),
 32};
 33
 34#define MSS_LOCAL_IPV4	65495
 35#define MSS_LOCAL_IPV6	65476
 36
 37const __u16 msstab4[] = {
 38	536,
 39	1300,
 40	1460,
 41	MSS_LOCAL_IPV4,
 42};
 43
 44const __u16 msstab6[] = {
 45	1280 - 60, /* IPV6_MIN_MTU - 60 */
 46	1480 - 60,
 47	9000 - 60,
 48	MSS_LOCAL_IPV6,
 49};
 50
 51static siphash_key_t test_key_siphash = {
 52	{ 0x0706050403020100ULL, 0x0f0e0d0c0b0a0908ULL }
 53};
 54
 55struct tcp_syncookie {
 56	struct __sk_buff *skb;
 57	void *data;
 58	void *data_end;
 59	struct ethhdr *eth;
 60	struct iphdr *ipv4;
 61	struct ipv6hdr *ipv6;
 62	struct tcphdr *tcp;
 63	__be32 *ptr32;
 64	struct bpf_tcp_req_attrs attrs;
 65	u32 off;
 66	u32 cookie;
 67	u64 first;
 68};
 69
 70bool handled_syn, handled_ack;
 71
 72static int tcp_load_headers(struct tcp_syncookie *ctx)
 73{
 74	ctx->data = (void *)(long)ctx->skb->data;
 75	ctx->data_end = (void *)(long)ctx->skb->data_end;
 76	ctx->eth = (struct ethhdr *)(long)ctx->skb->data;
 77
 78	if (ctx->eth + 1 > ctx->data_end)
 79		goto err;
 80
 81	switch (bpf_ntohs(ctx->eth->h_proto)) {
 82	case ETH_P_IP:
 83		ctx->ipv4 = (struct iphdr *)(ctx->eth + 1);
 84
 85		if (ctx->ipv4 + 1 > ctx->data_end)
 86			goto err;
 87
 88		if (ctx->ipv4->ihl != sizeof(*ctx->ipv4) / 4)
 89			goto err;
 90
 91		if (ctx->ipv4->version != 4)
 92			goto err;
 93
 94		if (ctx->ipv4->protocol != IPPROTO_TCP)
 95			goto err;
 96
 97		ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1);
 98		break;
 99	case ETH_P_IPV6:
100		ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1);
101
102		if (ctx->ipv6 + 1 > ctx->data_end)
103			goto err;
104
105		if (ctx->ipv6->version != 6)
106			goto err;
107
108		if (ctx->ipv6->nexthdr != NEXTHDR_TCP)
109			goto err;
110
111		ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1);
112		break;
113	default:
114		goto err;
115	}
116
117	if (ctx->tcp + 1 > ctx->data_end)
118		goto err;
119
120	return 0;
121err:
122	return -1;
123}
124
125static int tcp_reload_headers(struct tcp_syncookie *ctx)
126{
127	/* Without volatile,
128	 * R3 32-bit pointer arithmetic prohibited
129	 */
130	volatile u64 data_len = ctx->skb->data_end - ctx->skb->data;
131
132	if (ctx->tcp->doff < sizeof(*ctx->tcp) / 4)
133		goto err;
134
135	/* Needed to calculate csum and parse TCP options. */
136	if (bpf_skb_change_tail(ctx->skb, data_len + 60 - ctx->tcp->doff * 4, 0))
137		goto err;
138
139	ctx->data = (void *)(long)ctx->skb->data;
140	ctx->data_end = (void *)(long)ctx->skb->data_end;
141	ctx->eth = (struct ethhdr *)(long)ctx->skb->data;
142	if (ctx->ipv4) {
143		ctx->ipv4 = (struct iphdr *)(ctx->eth + 1);
144		ctx->ipv6 = NULL;
145		ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1);
146	} else {
147		ctx->ipv4 = NULL;
148		ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1);
149		ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1);
150	}
151
152	if ((void *)ctx->tcp + 60 > ctx->data_end)
153		goto err;
154
155	return 0;
156err:
157	return -1;
158}
159
160static __sum16 tcp_v4_csum(struct tcp_syncookie *ctx, __wsum csum)
161{
162	return csum_tcpudp_magic(ctx->ipv4->saddr, ctx->ipv4->daddr,
163				 ctx->tcp->doff * 4, IPPROTO_TCP, csum);
164}
165
166static __sum16 tcp_v6_csum(struct tcp_syncookie *ctx, __wsum csum)
167{
168	return csum_ipv6_magic(&ctx->ipv6->saddr, &ctx->ipv6->daddr,
169			       ctx->tcp->doff * 4, IPPROTO_TCP, csum);
170}
171
172static int tcp_validate_header(struct tcp_syncookie *ctx)
173{
174	s64 csum;
175
176	if (tcp_reload_headers(ctx))
177		goto err;
178
179	csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0);
180	if (csum < 0)
181		goto err;
182
183	if (ctx->ipv4) {
184		/* check tcp_v4_csum(csum) is 0 if not on lo. */
185
186		csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, ctx->ipv4->ihl * 4, 0);
187		if (csum < 0)
188			goto err;
189
190		if (csum_fold(csum) != 0)
191			goto err;
192	} else if (ctx->ipv6) {
193		/* check tcp_v6_csum(csum) is 0 if not on lo. */
194	}
195
196	return 0;
197err:
198	return -1;
199}
200
201static __always_inline void *next(struct tcp_syncookie *ctx, __u32 sz)
202{
203	__u64 off = ctx->off;
204	__u8 *data;
205
206	/* Verifier forbids access to packet when offset exceeds MAX_PACKET_OFF */
207	if (off > MAX_PACKET_OFF - sz)
208		return NULL;
209
210	data = ctx->data + off;
211	barrier_var(data);
212	if (data + sz >= ctx->data_end)
213		return NULL;
214
215	ctx->off += sz;
216	return data;
217}
218
219static int tcp_parse_option(__u32 index, struct tcp_syncookie *ctx)
220{
221	__u8 *opcode, *opsize, *wscale;
222	__u32 *tsval, *tsecr;
223	__u16 *mss;
224	__u32 off;
225
226	off = ctx->off;
227	opcode = next(ctx, 1);
228	if (!opcode)
229		goto stop;
230
231	if (*opcode == TCPOPT_EOL)
232		goto stop;
233
234	if (*opcode == TCPOPT_NOP)
235		goto next;
236
237	opsize = next(ctx, 1);
238	if (!opsize)
239		goto stop;
240
241	if (*opsize < 2)
242		goto stop;
243
244	switch (*opcode) {
245	case TCPOPT_MSS:
246		mss = next(ctx, 2);
247		if (*opsize == TCPOLEN_MSS && ctx->tcp->syn && mss)
248			ctx->attrs.mss = get_unaligned_be16(mss);
249		break;
250	case TCPOPT_WINDOW:
251		wscale = next(ctx, 1);
252		if (*opsize == TCPOLEN_WINDOW && ctx->tcp->syn && wscale) {
253			ctx->attrs.wscale_ok = 1;
254			ctx->attrs.snd_wscale = *wscale;
255		}
256		break;
257	case TCPOPT_TIMESTAMP:
258		tsval = next(ctx, 4);
259		tsecr = next(ctx, 4);
260		if (*opsize == TCPOLEN_TIMESTAMP && tsval && tsecr) {
261			ctx->attrs.rcv_tsval = get_unaligned_be32(tsval);
262			ctx->attrs.rcv_tsecr = get_unaligned_be32(tsecr);
263
264			if (ctx->tcp->syn && ctx->attrs.rcv_tsecr)
265				ctx->attrs.tstamp_ok = 0;
266			else
267				ctx->attrs.tstamp_ok = 1;
268		}
269		break;
270	case TCPOPT_SACK_PERM:
271		if (*opsize == TCPOLEN_SACK_PERM && ctx->tcp->syn)
272			ctx->attrs.sack_ok = 1;
273		break;
274	}
275
276	ctx->off = off + *opsize;
277next:
278	return 0;
279stop:
280	return 1;
281}
282
283static void tcp_parse_options(struct tcp_syncookie *ctx)
284{
285	ctx->off = (__u8 *)(ctx->tcp + 1) - (__u8 *)ctx->data,
286
287	bpf_loop(40, tcp_parse_option, ctx, 0);
288}
289
290static int tcp_validate_sysctl(struct tcp_syncookie *ctx)
291{
292	if ((ctx->ipv4 && ctx->attrs.mss != MSS_LOCAL_IPV4) ||
293	    (ctx->ipv6 && ctx->attrs.mss != MSS_LOCAL_IPV6))
294		goto err;
295
296	if (!ctx->attrs.wscale_ok || ctx->attrs.snd_wscale != 7)
297		goto err;
298
299	if (!ctx->attrs.tstamp_ok)
300		goto err;
301
302	if (!ctx->attrs.sack_ok)
303		goto err;
304
305	if (!ctx->tcp->ece || !ctx->tcp->cwr)
306		goto err;
307
308	return 0;
309err:
310	return -1;
311}
312
313static void tcp_prepare_cookie(struct tcp_syncookie *ctx)
314{
315	u32 seq = bpf_ntohl(ctx->tcp->seq);
316	u64 first = 0, second;
317	int mssind = 0;
318	u32 hash;
319
320	if (ctx->ipv4) {
321		for (mssind = ARRAY_SIZE(msstab4) - 1; mssind; mssind--)
322			if (ctx->attrs.mss >= msstab4[mssind])
323				break;
324
325		ctx->attrs.mss = msstab4[mssind];
326
327		first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr;
328	} else if (ctx->ipv6) {
329		for (mssind = ARRAY_SIZE(msstab6) - 1; mssind; mssind--)
330			if (ctx->attrs.mss >= msstab6[mssind])
331				break;
332
333		ctx->attrs.mss = msstab6[mssind];
334
335		first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 |
336			ctx->ipv6->daddr.in6_u.u6_addr32[0];
337	}
338
339	second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest;
340	hash = siphash_2u64(first, second, &test_key_siphash);
341
342	if (ctx->attrs.tstamp_ok) {
343		ctx->attrs.rcv_tsecr = bpf_get_prandom_u32();
344		ctx->attrs.rcv_tsecr &= ~COOKIE_MASK;
345		ctx->attrs.rcv_tsecr |= hash & COOKIE_MASK;
346	}
347
348	hash &= ~COOKIE_MASK;
349	hash |= mssind << 6;
350
351	if (ctx->attrs.wscale_ok)
352		hash |= ctx->attrs.snd_wscale & BPF_SYNCOOKIE_WSCALE_MASK;
353
354	if (ctx->attrs.sack_ok)
355		hash |= BPF_SYNCOOKIE_SACK;
356
357	if (ctx->attrs.tstamp_ok && ctx->tcp->ece && ctx->tcp->cwr)
358		hash |= BPF_SYNCOOKIE_ECN;
359
360	ctx->cookie = hash;
361}
362
363static void tcp_write_options(struct tcp_syncookie *ctx)
364{
365	ctx->ptr32 = (__be32 *)(ctx->tcp + 1);
366
367	*ctx->ptr32++ = bpf_htonl(TCPOPT_MSS << 24 | TCPOLEN_MSS << 16 |
368				  ctx->attrs.mss);
369
370	if (ctx->attrs.wscale_ok)
371		*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
372					  TCPOPT_WINDOW << 16 |
373					  TCPOLEN_WINDOW << 8 |
374					  ctx->attrs.snd_wscale);
375
376	if (ctx->attrs.tstamp_ok) {
377		if (ctx->attrs.sack_ok)
378			*ctx->ptr32++ = bpf_htonl(TCPOPT_SACK_PERM << 24 |
379						  TCPOLEN_SACK_PERM << 16 |
380						  TCPOPT_TIMESTAMP << 8 |
381						  TCPOLEN_TIMESTAMP);
382		else
383			*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
384						  TCPOPT_NOP << 16 |
385						  TCPOPT_TIMESTAMP << 8 |
386						  TCPOLEN_TIMESTAMP);
387
388		*ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsecr);
389		*ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsval);
390	} else if (ctx->attrs.sack_ok) {
391		*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
392					  TCPOPT_NOP << 16 |
393					  TCPOPT_SACK_PERM << 8 |
394					  TCPOLEN_SACK_PERM);
395	}
396}
397
398static int tcp_handle_syn(struct tcp_syncookie *ctx)
399{
400	s64 csum;
401
402	if (tcp_validate_header(ctx))
403		goto err;
404
405	tcp_parse_options(ctx);
406
407	if (tcp_validate_sysctl(ctx))
408		goto err;
409
410	tcp_prepare_cookie(ctx);
411	tcp_write_options(ctx);
412
413	swap(ctx->tcp->source, ctx->tcp->dest);
414	ctx->tcp->check = 0;
415	ctx->tcp->ack_seq = bpf_htonl(bpf_ntohl(ctx->tcp->seq) + 1);
416	ctx->tcp->seq = bpf_htonl(ctx->cookie);
417	ctx->tcp->doff = ((long)ctx->ptr32 - (long)ctx->tcp) >> 2;
418	ctx->tcp->ack = 1;
419	if (!ctx->attrs.tstamp_ok || !ctx->tcp->ece || !ctx->tcp->cwr)
420		ctx->tcp->ece = 0;
421	ctx->tcp->cwr = 0;
422
423	csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0);
424	if (csum < 0)
425		goto err;
426
427	if (ctx->ipv4) {
428		swap(ctx->ipv4->saddr, ctx->ipv4->daddr);
429		ctx->tcp->check = tcp_v4_csum(ctx, csum);
430
431		ctx->ipv4->check = 0;
432		ctx->ipv4->tos = 0;
433		ctx->ipv4->tot_len = bpf_htons((long)ctx->ptr32 - (long)ctx->ipv4);
434		ctx->ipv4->id = 0;
435		ctx->ipv4->ttl = 64;
436
437		csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, sizeof(*ctx->ipv4), 0);
438		if (csum < 0)
439			goto err;
440
441		ctx->ipv4->check = csum_fold(csum);
442	} else if (ctx->ipv6) {
443		swap(ctx->ipv6->saddr, ctx->ipv6->daddr);
444		ctx->tcp->check = tcp_v6_csum(ctx, csum);
445
446		*(__be32 *)ctx->ipv6 = bpf_htonl(0x60000000);
447		ctx->ipv6->payload_len = bpf_htons((long)ctx->ptr32 - (long)ctx->tcp);
448		ctx->ipv6->hop_limit = 64;
449	}
450
451	swap_array(ctx->eth->h_source, ctx->eth->h_dest);
452
453	if (bpf_skb_change_tail(ctx->skb, (long)ctx->ptr32 - (long)ctx->eth, 0))
454		goto err;
455
456	return bpf_redirect(ctx->skb->ifindex, 0);
457err:
458	return TC_ACT_SHOT;
459}
460
461static int tcp_validate_cookie(struct tcp_syncookie *ctx)
462{
463	u32 cookie = bpf_ntohl(ctx->tcp->ack_seq) - 1;
464	u32 seq = bpf_ntohl(ctx->tcp->seq) - 1;
465	u64 first = 0, second;
466	int mssind;
467	u32 hash;
468
469	if (ctx->ipv4)
470		first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr;
471	else if (ctx->ipv6)
472		first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 |
473			ctx->ipv6->daddr.in6_u.u6_addr32[0];
474
475	second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest;
476	hash = siphash_2u64(first, second, &test_key_siphash);
477
478	if (ctx->attrs.tstamp_ok)
479		hash -= ctx->attrs.rcv_tsecr & COOKIE_MASK;
480	else
481		hash &= ~COOKIE_MASK;
482
483	hash -= cookie & ~COOKIE_MASK;
484	if (hash)
485		goto err;
486
487	mssind = (cookie & (3 << 6)) >> 6;
488	if (ctx->ipv4) {
489		if (mssind > ARRAY_SIZE(msstab4))
490			goto err;
491
492		ctx->attrs.mss = msstab4[mssind];
493	} else {
494		if (mssind > ARRAY_SIZE(msstab6))
495			goto err;
496
497		ctx->attrs.mss = msstab6[mssind];
498	}
499
500	ctx->attrs.snd_wscale = cookie & BPF_SYNCOOKIE_WSCALE_MASK;
501	ctx->attrs.rcv_wscale = ctx->attrs.snd_wscale;
502	ctx->attrs.wscale_ok = ctx->attrs.snd_wscale == BPF_SYNCOOKIE_WSCALE_MASK;
503	ctx->attrs.sack_ok = cookie & BPF_SYNCOOKIE_SACK;
504	ctx->attrs.ecn_ok = cookie & BPF_SYNCOOKIE_ECN;
505
506	return 0;
507err:
508	return -1;
509}
510
511static int tcp_handle_ack(struct tcp_syncookie *ctx)
512{
513	struct bpf_sock_tuple tuple;
514	struct bpf_sock *skc;
515	int ret = TC_ACT_OK;
516	struct sock *sk;
517	u32 tuple_size;
518
519	if (ctx->ipv4) {
520		tuple.ipv4.saddr = ctx->ipv4->saddr;
521		tuple.ipv4.daddr = ctx->ipv4->daddr;
522		tuple.ipv4.sport = ctx->tcp->source;
523		tuple.ipv4.dport = ctx->tcp->dest;
524		tuple_size = sizeof(tuple.ipv4);
525	} else if (ctx->ipv6) {
526		__builtin_memcpy(tuple.ipv6.saddr, &ctx->ipv6->saddr, sizeof(tuple.ipv6.saddr));
527		__builtin_memcpy(tuple.ipv6.daddr, &ctx->ipv6->daddr, sizeof(tuple.ipv6.daddr));
528		tuple.ipv6.sport = ctx->tcp->source;
529		tuple.ipv6.dport = ctx->tcp->dest;
530		tuple_size = sizeof(tuple.ipv6);
531	} else {
532		goto out;
533	}
534
535	skc = bpf_skc_lookup_tcp(ctx->skb, &tuple, tuple_size, -1, 0);
536	if (!skc)
537		goto out;
538
539	if (skc->state != TCP_LISTEN)
540		goto release;
541
542	sk = (struct sock *)bpf_skc_to_tcp_sock(skc);
543	if (!sk)
544		goto err;
545
546	if (tcp_validate_header(ctx))
547		goto err;
548
549	tcp_parse_options(ctx);
550
551	if (tcp_validate_cookie(ctx))
552		goto err;
553
554	ret = bpf_sk_assign_tcp_reqsk(ctx->skb, sk, &ctx->attrs, sizeof(ctx->attrs));
555	if (ret < 0)
556		goto err;
557
558release:
559	bpf_sk_release(skc);
560out:
561	return ret;
562
563err:
564	ret = TC_ACT_SHOT;
565	goto release;
566}
567
568SEC("tc")
569int tcp_custom_syncookie(struct __sk_buff *skb)
570{
571	struct tcp_syncookie ctx = {
572		.skb = skb,
573	};
574
575	if (tcp_load_headers(&ctx))
576		return TC_ACT_OK;
577
578	if (ctx.tcp->rst)
579		return TC_ACT_OK;
580
581	if (ctx.tcp->syn) {
582		if (ctx.tcp->ack)
583			return TC_ACT_OK;
584
585		handled_syn = true;
586
587		return tcp_handle_syn(&ctx);
588	}
589
590	handled_ack = true;
591
592	return tcp_handle_ack(&ctx);
593}
594
595char _license[] SEC("license") = "GPL";