Linux Audio

Check our new training course

Loading...
Note: File does not exist in v3.15.
  1// SPDX-License-Identifier: GPL-2.0
  2/*
  3 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
  4 */
  5
  6#include "noise.h"
  7#include "device.h"
  8#include "peer.h"
  9#include "messages.h"
 10#include "queueing.h"
 11#include "peerlookup.h"
 12
 13#include <linux/rcupdate.h>
 14#include <linux/slab.h>
 15#include <linux/bitmap.h>
 16#include <linux/scatterlist.h>
 17#include <linux/highmem.h>
 18#include <crypto/algapi.h>
 19
 20/* This implements Noise_IKpsk2:
 21 *
 22 * <- s
 23 * ******
 24 * -> e, es, s, ss, {t}
 25 * <- e, ee, se, psk, {}
 26 */
 27
 28static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
 29static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
 30static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
 31static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
 32static atomic64_t keypair_counter = ATOMIC64_INIT(0);
 33
 34void __init wg_noise_init(void)
 35{
 36	struct blake2s_state blake;
 37
 38	blake2s(handshake_init_chaining_key, handshake_name, NULL,
 39		NOISE_HASH_LEN, sizeof(handshake_name), 0);
 40	blake2s_init(&blake, NOISE_HASH_LEN);
 41	blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
 42	blake2s_update(&blake, identifier_name, sizeof(identifier_name));
 43	blake2s_final(&blake, handshake_init_hash);
 44}
 45
 46/* Must hold peer->handshake.static_identity->lock */
 47void wg_noise_precompute_static_static(struct wg_peer *peer)
 48{
 49	down_write(&peer->handshake.lock);
 50	if (!peer->handshake.static_identity->has_identity ||
 51	    !curve25519(peer->handshake.precomputed_static_static,
 52			peer->handshake.static_identity->static_private,
 53			peer->handshake.remote_static))
 54		memset(peer->handshake.precomputed_static_static, 0,
 55		       NOISE_PUBLIC_KEY_LEN);
 56	up_write(&peer->handshake.lock);
 57}
 58
 59void wg_noise_handshake_init(struct noise_handshake *handshake,
 60			     struct noise_static_identity *static_identity,
 61			     const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
 62			     const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
 63			     struct wg_peer *peer)
 64{
 65	memset(handshake, 0, sizeof(*handshake));
 66	init_rwsem(&handshake->lock);
 67	handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
 68	handshake->entry.peer = peer;
 69	memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
 70	if (peer_preshared_key)
 71		memcpy(handshake->preshared_key, peer_preshared_key,
 72		       NOISE_SYMMETRIC_KEY_LEN);
 73	handshake->static_identity = static_identity;
 74	handshake->state = HANDSHAKE_ZEROED;
 75	wg_noise_precompute_static_static(peer);
 76}
 77
 78static void handshake_zero(struct noise_handshake *handshake)
 79{
 80	memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
 81	memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
 82	memset(&handshake->hash, 0, NOISE_HASH_LEN);
 83	memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
 84	handshake->remote_index = 0;
 85	handshake->state = HANDSHAKE_ZEROED;
 86}
 87
 88void wg_noise_handshake_clear(struct noise_handshake *handshake)
 89{
 90	down_write(&handshake->lock);
 91	wg_index_hashtable_remove(
 92			handshake->entry.peer->device->index_hashtable,
 93			&handshake->entry);
 94	handshake_zero(handshake);
 95	up_write(&handshake->lock);
 96}
 97
 98static struct noise_keypair *keypair_create(struct wg_peer *peer)
 99{
100	struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
101
102	if (unlikely(!keypair))
103		return NULL;
104	spin_lock_init(&keypair->receiving_counter.lock);
105	keypair->internal_id = atomic64_inc_return(&keypair_counter);
106	keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
107	keypair->entry.peer = peer;
108	kref_init(&keypair->refcount);
109	return keypair;
110}
111
112static void keypair_free_rcu(struct rcu_head *rcu)
113{
114	kfree_sensitive(container_of(rcu, struct noise_keypair, rcu));
115}
116
117static void keypair_free_kref(struct kref *kref)
118{
119	struct noise_keypair *keypair =
120		container_of(kref, struct noise_keypair, refcount);
121
122	net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
123			    keypair->entry.peer->device->dev->name,
124			    keypair->internal_id,
125			    keypair->entry.peer->internal_id);
126	wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
127				  &keypair->entry);
128	call_rcu(&keypair->rcu, keypair_free_rcu);
129}
130
131void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
132{
133	if (unlikely(!keypair))
134		return;
135	if (unlikely(unreference_now))
136		wg_index_hashtable_remove(
137			keypair->entry.peer->device->index_hashtable,
138			&keypair->entry);
139	kref_put(&keypair->refcount, keypair_free_kref);
140}
141
142struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
143{
144	RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
145		"Taking noise keypair reference without holding the RCU BH read lock");
146	if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
147		return NULL;
148	return keypair;
149}
150
151void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
152{
153	struct noise_keypair *old;
154
155	spin_lock_bh(&keypairs->keypair_update_lock);
156
157	/* We zero the next_keypair before zeroing the others, so that
158	 * wg_noise_received_with_keypair returns early before subsequent ones
159	 * are zeroed.
160	 */
161	old = rcu_dereference_protected(keypairs->next_keypair,
162		lockdep_is_held(&keypairs->keypair_update_lock));
163	RCU_INIT_POINTER(keypairs->next_keypair, NULL);
164	wg_noise_keypair_put(old, true);
165
166	old = rcu_dereference_protected(keypairs->previous_keypair,
167		lockdep_is_held(&keypairs->keypair_update_lock));
168	RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
169	wg_noise_keypair_put(old, true);
170
171	old = rcu_dereference_protected(keypairs->current_keypair,
172		lockdep_is_held(&keypairs->keypair_update_lock));
173	RCU_INIT_POINTER(keypairs->current_keypair, NULL);
174	wg_noise_keypair_put(old, true);
175
176	spin_unlock_bh(&keypairs->keypair_update_lock);
177}
178
179void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
180{
181	struct noise_keypair *keypair;
182
183	wg_noise_handshake_clear(&peer->handshake);
184	wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
185
186	spin_lock_bh(&peer->keypairs.keypair_update_lock);
187	keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
188			lockdep_is_held(&peer->keypairs.keypair_update_lock));
189	if (keypair)
190		keypair->sending.is_valid = false;
191	keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
192			lockdep_is_held(&peer->keypairs.keypair_update_lock));
193	if (keypair)
194		keypair->sending.is_valid = false;
195	spin_unlock_bh(&peer->keypairs.keypair_update_lock);
196}
197
198static void add_new_keypair(struct noise_keypairs *keypairs,
199			    struct noise_keypair *new_keypair)
200{
201	struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;
202
203	spin_lock_bh(&keypairs->keypair_update_lock);
204	previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
205		lockdep_is_held(&keypairs->keypair_update_lock));
206	next_keypair = rcu_dereference_protected(keypairs->next_keypair,
207		lockdep_is_held(&keypairs->keypair_update_lock));
208	current_keypair = rcu_dereference_protected(keypairs->current_keypair,
209		lockdep_is_held(&keypairs->keypair_update_lock));
210	if (new_keypair->i_am_the_initiator) {
211		/* If we're the initiator, it means we've sent a handshake, and
212		 * received a confirmation response, which means this new
213		 * keypair can now be used.
214		 */
215		if (next_keypair) {
216			/* If there already was a next keypair pending, we
217			 * demote it to be the previous keypair, and free the
218			 * existing current. Note that this means KCI can result
219			 * in this transition. It would perhaps be more sound to
220			 * always just get rid of the unused next keypair
221			 * instead of putting it in the previous slot, but this
222			 * might be a bit less robust. Something to think about
223			 * for the future.
224			 */
225			RCU_INIT_POINTER(keypairs->next_keypair, NULL);
226			rcu_assign_pointer(keypairs->previous_keypair,
227					   next_keypair);
228			wg_noise_keypair_put(current_keypair, true);
229		} else /* If there wasn't an existing next keypair, we replace
230			* the previous with the current one.
231			*/
232			rcu_assign_pointer(keypairs->previous_keypair,
233					   current_keypair);
234		/* At this point we can get rid of the old previous keypair, and
235		 * set up the new keypair.
236		 */
237		wg_noise_keypair_put(previous_keypair, true);
238		rcu_assign_pointer(keypairs->current_keypair, new_keypair);
239	} else {
240		/* If we're the responder, it means we can't use the new keypair
241		 * until we receive confirmation via the first data packet, so
242		 * we get rid of the existing previous one, the possibly
243		 * existing next one, and slide in the new next one.
244		 */
245		rcu_assign_pointer(keypairs->next_keypair, new_keypair);
246		wg_noise_keypair_put(next_keypair, true);
247		RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
248		wg_noise_keypair_put(previous_keypair, true);
249	}
250	spin_unlock_bh(&keypairs->keypair_update_lock);
251}
252
253bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
254				    struct noise_keypair *received_keypair)
255{
256	struct noise_keypair *old_keypair;
257	bool key_is_new;
258
259	/* We first check without taking the spinlock. */
260	key_is_new = received_keypair ==
261		     rcu_access_pointer(keypairs->next_keypair);
262	if (likely(!key_is_new))
263		return false;
264
265	spin_lock_bh(&keypairs->keypair_update_lock);
266	/* After locking, we double check that things didn't change from
267	 * beneath us.
268	 */
269	if (unlikely(received_keypair !=
270		    rcu_dereference_protected(keypairs->next_keypair,
271			    lockdep_is_held(&keypairs->keypair_update_lock)))) {
272		spin_unlock_bh(&keypairs->keypair_update_lock);
273		return false;
274	}
275
276	/* When we've finally received the confirmation, we slide the next
277	 * into the current, the current into the previous, and get rid of
278	 * the old previous.
279	 */
280	old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
281		lockdep_is_held(&keypairs->keypair_update_lock));
282	rcu_assign_pointer(keypairs->previous_keypair,
283		rcu_dereference_protected(keypairs->current_keypair,
284			lockdep_is_held(&keypairs->keypair_update_lock)));
285	wg_noise_keypair_put(old_keypair, true);
286	rcu_assign_pointer(keypairs->current_keypair, received_keypair);
287	RCU_INIT_POINTER(keypairs->next_keypair, NULL);
288
289	spin_unlock_bh(&keypairs->keypair_update_lock);
290	return true;
291}
292
293/* Must hold static_identity->lock */
294void wg_noise_set_static_identity_private_key(
295	struct noise_static_identity *static_identity,
296	const u8 private_key[NOISE_PUBLIC_KEY_LEN])
297{
298	memcpy(static_identity->static_private, private_key,
299	       NOISE_PUBLIC_KEY_LEN);
300	curve25519_clamp_secret(static_identity->static_private);
301	static_identity->has_identity = curve25519_generate_public(
302		static_identity->static_public, private_key);
303}
304
305/* This is Hugo Krawczyk's HKDF:
306 *  - https://eprint.iacr.org/2010/264.pdf
307 *  - https://tools.ietf.org/html/rfc5869
308 */
309static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
310		size_t first_len, size_t second_len, size_t third_len,
311		size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
312{
313	u8 output[BLAKE2S_HASH_SIZE + 1];
314	u8 secret[BLAKE2S_HASH_SIZE];
315
316	WARN_ON(IS_ENABLED(DEBUG) &&
317		(first_len > BLAKE2S_HASH_SIZE ||
318		 second_len > BLAKE2S_HASH_SIZE ||
319		 third_len > BLAKE2S_HASH_SIZE ||
320		 ((second_len || second_dst || third_len || third_dst) &&
321		  (!first_len || !first_dst)) ||
322		 ((third_len || third_dst) && (!second_len || !second_dst))));
323
324	/* Extract entropy from data into secret */
325	blake2s256_hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);
326
327	if (!first_dst || !first_len)
328		goto out;
329
330	/* Expand first key: key = secret, data = 0x1 */
331	output[0] = 1;
332	blake2s256_hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
333	memcpy(first_dst, output, first_len);
334
335	if (!second_dst || !second_len)
336		goto out;
337
338	/* Expand second key: key = secret, data = first-key || 0x2 */
339	output[BLAKE2S_HASH_SIZE] = 2;
340	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
341			BLAKE2S_HASH_SIZE);
342	memcpy(second_dst, output, second_len);
343
344	if (!third_dst || !third_len)
345		goto out;
346
347	/* Expand third key: key = secret, data = second-key || 0x3 */
348	output[BLAKE2S_HASH_SIZE] = 3;
349	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
350			BLAKE2S_HASH_SIZE);
351	memcpy(third_dst, output, third_len);
352
353out:
354	/* Clear sensitive data from stack */
355	memzero_explicit(secret, BLAKE2S_HASH_SIZE);
356	memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
357}
358
359static void derive_keys(struct noise_symmetric_key *first_dst,
360			struct noise_symmetric_key *second_dst,
361			const u8 chaining_key[NOISE_HASH_LEN])
362{
363	u64 birthdate = ktime_get_coarse_boottime_ns();
364	kdf(first_dst->key, second_dst->key, NULL, NULL,
365	    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
366	    chaining_key);
367	first_dst->birthdate = second_dst->birthdate = birthdate;
368	first_dst->is_valid = second_dst->is_valid = true;
369}
370
371static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
372				u8 key[NOISE_SYMMETRIC_KEY_LEN],
373				const u8 private[NOISE_PUBLIC_KEY_LEN],
374				const u8 public[NOISE_PUBLIC_KEY_LEN])
375{
376	u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
377
378	if (unlikely(!curve25519(dh_calculation, private, public)))
379		return false;
380	kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
381	    NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
382	memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
383	return true;
384}
385
386static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
387					    u8 key[NOISE_SYMMETRIC_KEY_LEN],
388					    const u8 precomputed[NOISE_PUBLIC_KEY_LEN])
389{
390	static u8 zero_point[NOISE_PUBLIC_KEY_LEN];
391	if (unlikely(!crypto_memneq(precomputed, zero_point, NOISE_PUBLIC_KEY_LEN)))
392		return false;
393	kdf(chaining_key, key, NULL, precomputed, NOISE_HASH_LEN,
394	    NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
395	    chaining_key);
396	return true;
397}
398
399static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
400{
401	struct blake2s_state blake;
402
403	blake2s_init(&blake, NOISE_HASH_LEN);
404	blake2s_update(&blake, hash, NOISE_HASH_LEN);
405	blake2s_update(&blake, src, src_len);
406	blake2s_final(&blake, hash);
407}
408
409static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
410		    u8 key[NOISE_SYMMETRIC_KEY_LEN],
411		    const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
412{
413	u8 temp_hash[NOISE_HASH_LEN];
414
415	kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
416	    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
417	mix_hash(hash, temp_hash, NOISE_HASH_LEN);
418	memzero_explicit(temp_hash, NOISE_HASH_LEN);
419}
420
421static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
422			   u8 hash[NOISE_HASH_LEN],
423			   const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
424{
425	memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
426	memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
427	mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
428}
429
430static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
431			    size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
432			    u8 hash[NOISE_HASH_LEN])
433{
434	chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
435				 NOISE_HASH_LEN,
436				 0 /* Always zero for Noise_IK */, key);
437	mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
438}
439
440static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
441			    size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
442			    u8 hash[NOISE_HASH_LEN])
443{
444	if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
445				      hash, NOISE_HASH_LEN,
446				      0 /* Always zero for Noise_IK */, key))
447		return false;
448	mix_hash(hash, src_ciphertext, src_len);
449	return true;
450}
451
452static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
453			      const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
454			      u8 chaining_key[NOISE_HASH_LEN],
455			      u8 hash[NOISE_HASH_LEN])
456{
457	if (ephemeral_dst != ephemeral_src)
458		memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
459	mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
460	kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
461	    NOISE_PUBLIC_KEY_LEN, chaining_key);
462}
463
464static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
465{
466	struct timespec64 now;
467
468	ktime_get_real_ts64(&now);
469
470	/* In order to prevent some sort of infoleak from precise timers, we
471	 * round down the nanoseconds part to the closest rounded-down power of
472	 * two to the maximum initiations per second allowed anyway by the
473	 * implementation.
474	 */
475	now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
476		rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
477
478	/* https://cr.yp.to/libtai/tai64.html */
479	*(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
480	*(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
481}
482
483bool
484wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
485				     struct noise_handshake *handshake)
486{
487	u8 timestamp[NOISE_TIMESTAMP_LEN];
488	u8 key[NOISE_SYMMETRIC_KEY_LEN];
489	bool ret = false;
490
491	/* We need to wait for crng _before_ taking any locks, since
492	 * curve25519_generate_secret uses get_random_bytes_wait.
493	 */
494	wait_for_random_bytes();
495
496	down_read(&handshake->static_identity->lock);
497	down_write(&handshake->lock);
498
499	if (unlikely(!handshake->static_identity->has_identity))
500		goto out;
501
502	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);
503
504	handshake_init(handshake->chaining_key, handshake->hash,
505		       handshake->remote_static);
506
507	/* e */
508	curve25519_generate_secret(handshake->ephemeral_private);
509	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
510					handshake->ephemeral_private))
511		goto out;
512	message_ephemeral(dst->unencrypted_ephemeral,
513			  dst->unencrypted_ephemeral, handshake->chaining_key,
514			  handshake->hash);
515
516	/* es */
517	if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
518		    handshake->remote_static))
519		goto out;
520
521	/* s */
522	message_encrypt(dst->encrypted_static,
523			handshake->static_identity->static_public,
524			NOISE_PUBLIC_KEY_LEN, key, handshake->hash);
525
526	/* ss */
527	if (!mix_precomputed_dh(handshake->chaining_key, key,
528				handshake->precomputed_static_static))
529		goto out;
530
531	/* {t} */
532	tai64n_now(timestamp);
533	message_encrypt(dst->encrypted_timestamp, timestamp,
534			NOISE_TIMESTAMP_LEN, key, handshake->hash);
535
536	dst->sender_index = wg_index_hashtable_insert(
537		handshake->entry.peer->device->index_hashtable,
538		&handshake->entry);
539
540	handshake->state = HANDSHAKE_CREATED_INITIATION;
541	ret = true;
542
543out:
544	up_write(&handshake->lock);
545	up_read(&handshake->static_identity->lock);
546	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
547	return ret;
548}
549
550struct wg_peer *
551wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
552				      struct wg_device *wg)
553{
554	struct wg_peer *peer = NULL, *ret_peer = NULL;
555	struct noise_handshake *handshake;
556	bool replay_attack, flood_attack;
557	u8 key[NOISE_SYMMETRIC_KEY_LEN];
558	u8 chaining_key[NOISE_HASH_LEN];
559	u8 hash[NOISE_HASH_LEN];
560	u8 s[NOISE_PUBLIC_KEY_LEN];
561	u8 e[NOISE_PUBLIC_KEY_LEN];
562	u8 t[NOISE_TIMESTAMP_LEN];
563	u64 initiation_consumption;
564
565	down_read(&wg->static_identity.lock);
566	if (unlikely(!wg->static_identity.has_identity))
567		goto out;
568
569	handshake_init(chaining_key, hash, wg->static_identity.static_public);
570
571	/* e */
572	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
573
574	/* es */
575	if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
576		goto out;
577
578	/* s */
579	if (!message_decrypt(s, src->encrypted_static,
580			     sizeof(src->encrypted_static), key, hash))
581		goto out;
582
583	/* Lookup which peer we're actually talking to */
584	peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
585	if (!peer)
586		goto out;
587	handshake = &peer->handshake;
588
589	/* ss */
590	if (!mix_precomputed_dh(chaining_key, key,
591				handshake->precomputed_static_static))
592	    goto out;
593
594	/* {t} */
595	if (!message_decrypt(t, src->encrypted_timestamp,
596			     sizeof(src->encrypted_timestamp), key, hash))
597		goto out;
598
599	down_read(&handshake->lock);
600	replay_attack = memcmp(t, handshake->latest_timestamp,
601			       NOISE_TIMESTAMP_LEN) <= 0;
602	flood_attack = (s64)handshake->last_initiation_consumption +
603			       NSEC_PER_SEC / INITIATIONS_PER_SECOND >
604		       (s64)ktime_get_coarse_boottime_ns();
605	up_read(&handshake->lock);
606	if (replay_attack || flood_attack)
607		goto out;
608
609	/* Success! Copy everything to peer */
610	down_write(&handshake->lock);
611	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
612	if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
613		memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
614	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
615	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
616	handshake->remote_index = src->sender_index;
617	initiation_consumption = ktime_get_coarse_boottime_ns();
618	if ((s64)(handshake->last_initiation_consumption - initiation_consumption) < 0)
619		handshake->last_initiation_consumption = initiation_consumption;
620	handshake->state = HANDSHAKE_CONSUMED_INITIATION;
621	up_write(&handshake->lock);
622	ret_peer = peer;
623
624out:
625	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
626	memzero_explicit(hash, NOISE_HASH_LEN);
627	memzero_explicit(chaining_key, NOISE_HASH_LEN);
628	up_read(&wg->static_identity.lock);
629	if (!ret_peer)
630		wg_peer_put(peer);
631	return ret_peer;
632}
633
634bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
635					struct noise_handshake *handshake)
636{
637	u8 key[NOISE_SYMMETRIC_KEY_LEN];
638	bool ret = false;
639
640	/* We need to wait for crng _before_ taking any locks, since
641	 * curve25519_generate_secret uses get_random_bytes_wait.
642	 */
643	wait_for_random_bytes();
644
645	down_read(&handshake->static_identity->lock);
646	down_write(&handshake->lock);
647
648	if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
649		goto out;
650
651	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
652	dst->receiver_index = handshake->remote_index;
653
654	/* e */
655	curve25519_generate_secret(handshake->ephemeral_private);
656	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
657					handshake->ephemeral_private))
658		goto out;
659	message_ephemeral(dst->unencrypted_ephemeral,
660			  dst->unencrypted_ephemeral, handshake->chaining_key,
661			  handshake->hash);
662
663	/* ee */
664	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
665		    handshake->remote_ephemeral))
666		goto out;
667
668	/* se */
669	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
670		    handshake->remote_static))
671		goto out;
672
673	/* psk */
674	mix_psk(handshake->chaining_key, handshake->hash, key,
675		handshake->preshared_key);
676
677	/* {} */
678	message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);
679
680	dst->sender_index = wg_index_hashtable_insert(
681		handshake->entry.peer->device->index_hashtable,
682		&handshake->entry);
683
684	handshake->state = HANDSHAKE_CREATED_RESPONSE;
685	ret = true;
686
687out:
688	up_write(&handshake->lock);
689	up_read(&handshake->static_identity->lock);
690	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
691	return ret;
692}
693
694struct wg_peer *
695wg_noise_handshake_consume_response(struct message_handshake_response *src,
696				    struct wg_device *wg)
697{
698	enum noise_handshake_state state = HANDSHAKE_ZEROED;
699	struct wg_peer *peer = NULL, *ret_peer = NULL;
700	struct noise_handshake *handshake;
701	u8 key[NOISE_SYMMETRIC_KEY_LEN];
702	u8 hash[NOISE_HASH_LEN];
703	u8 chaining_key[NOISE_HASH_LEN];
704	u8 e[NOISE_PUBLIC_KEY_LEN];
705	u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
706	u8 static_private[NOISE_PUBLIC_KEY_LEN];
707	u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN];
708
709	down_read(&wg->static_identity.lock);
710
711	if (unlikely(!wg->static_identity.has_identity))
712		goto out;
713
714	handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
715		wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
716		src->receiver_index, &peer);
717	if (unlikely(!handshake))
718		goto out;
719
720	down_read(&handshake->lock);
721	state = handshake->state;
722	memcpy(hash, handshake->hash, NOISE_HASH_LEN);
723	memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
724	memcpy(ephemeral_private, handshake->ephemeral_private,
725	       NOISE_PUBLIC_KEY_LEN);
726	memcpy(preshared_key, handshake->preshared_key,
727	       NOISE_SYMMETRIC_KEY_LEN);
728	up_read(&handshake->lock);
729
730	if (state != HANDSHAKE_CREATED_INITIATION)
731		goto fail;
732
733	/* e */
734	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
735
736	/* ee */
737	if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
738		goto fail;
739
740	/* se */
741	if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
742		goto fail;
743
744	/* psk */
745	mix_psk(chaining_key, hash, key, preshared_key);
746
747	/* {} */
748	if (!message_decrypt(NULL, src->encrypted_nothing,
749			     sizeof(src->encrypted_nothing), key, hash))
750		goto fail;
751
752	/* Success! Copy everything to peer */
753	down_write(&handshake->lock);
754	/* It's important to check that the state is still the same, while we
755	 * have an exclusive lock.
756	 */
757	if (handshake->state != state) {
758		up_write(&handshake->lock);
759		goto fail;
760	}
761	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
762	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
763	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
764	handshake->remote_index = src->sender_index;
765	handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
766	up_write(&handshake->lock);
767	ret_peer = peer;
768	goto out;
769
770fail:
771	wg_peer_put(peer);
772out:
773	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
774	memzero_explicit(hash, NOISE_HASH_LEN);
775	memzero_explicit(chaining_key, NOISE_HASH_LEN);
776	memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
777	memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
778	memzero_explicit(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
779	up_read(&wg->static_identity.lock);
780	return ret_peer;
781}
782
783bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
784				      struct noise_keypairs *keypairs)
785{
786	struct noise_keypair *new_keypair;
787	bool ret = false;
788
789	down_write(&handshake->lock);
790	if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
791	    handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
792		goto out;
793
794	new_keypair = keypair_create(handshake->entry.peer);
795	if (!new_keypair)
796		goto out;
797	new_keypair->i_am_the_initiator = handshake->state ==
798					  HANDSHAKE_CONSUMED_RESPONSE;
799	new_keypair->remote_index = handshake->remote_index;
800
801	if (new_keypair->i_am_the_initiator)
802		derive_keys(&new_keypair->sending, &new_keypair->receiving,
803			    handshake->chaining_key);
804	else
805		derive_keys(&new_keypair->receiving, &new_keypair->sending,
806			    handshake->chaining_key);
807
808	handshake_zero(handshake);
809	rcu_read_lock_bh();
810	if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
811					   handshake)->is_dead))) {
812		add_new_keypair(keypairs, new_keypair);
813		net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
814				    handshake->entry.peer->device->dev->name,
815				    new_keypair->internal_id,
816				    handshake->entry.peer->internal_id);
817		ret = wg_index_hashtable_replace(
818			handshake->entry.peer->device->index_hashtable,
819			&handshake->entry, &new_keypair->entry);
820	} else {
821		kfree_sensitive(new_keypair);
822	}
823	rcu_read_unlock_bh();
824
825out:
826	up_write(&handshake->lock);
827	return ret;
828}