Linux Audio

Check our new training course

Loading...
v6.13.7
  1// SPDX-License-Identifier: GPL-2.0-only
  2/*
  3 * aes-ce-glue.c - wrapper code for ARMv8 AES
  4 *
  5 * Copyright (C) 2015 Linaro Ltd <ard.biesheuvel@linaro.org>
  6 */
  7
  8#include <asm/hwcap.h>
  9#include <asm/neon.h>
 10#include <asm/simd.h>
 11#include <linux/unaligned.h>
 12#include <crypto/aes.h>
 13#include <crypto/ctr.h>
 14#include <crypto/internal/simd.h>
 15#include <crypto/internal/skcipher.h>
 16#include <crypto/scatterwalk.h>
 17#include <linux/cpufeature.h>
 18#include <linux/module.h>
 19#include <crypto/xts.h>
 20
 21MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
 22MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
 23MODULE_LICENSE("GPL v2");
 24
 25/* defined in aes-ce-core.S */
 26asmlinkage u32 ce_aes_sub(u32 input);
 27asmlinkage void ce_aes_invert(void *dst, void *src);
 28
 29asmlinkage void ce_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
 30				   int rounds, int blocks);
 31asmlinkage void ce_aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
 32				   int rounds, int blocks);
 33
 34asmlinkage void ce_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
 35				   int rounds, int blocks, u8 iv[]);
 36asmlinkage void ce_aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
 37				   int rounds, int blocks, u8 iv[]);
 38asmlinkage void ce_aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
 39				   int rounds, int bytes, u8 const iv[]);
 40asmlinkage void ce_aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
 41				   int rounds, int bytes, u8 const iv[]);
 42
 43asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
 44				   int rounds, int blocks, u8 ctr[]);
 45
 46asmlinkage void ce_aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
 47				   int rounds, int bytes, u8 iv[],
 48				   u32 const rk2[], int first);
 49asmlinkage void ce_aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
 50				   int rounds, int bytes, u8 iv[],
 51				   u32 const rk2[], int first);
 52
 53struct aes_block {
 54	u8 b[AES_BLOCK_SIZE];
 55};
 56
 57static int num_rounds(struct crypto_aes_ctx *ctx)
 58{
 59	/*
 60	 * # of rounds specified by AES:
 61	 * 128 bit key		10 rounds
 62	 * 192 bit key		12 rounds
 63	 * 256 bit key		14 rounds
 64	 * => n byte key	=> 6 + (n/4) rounds
 65	 */
 66	return 6 + ctx->key_length / 4;
 67}
 68
 69static int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
 70			    unsigned int key_len)
 71{
 72	/*
 73	 * The AES key schedule round constants
 74	 */
 75	static u8 const rcon[] = {
 76		0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36,
 77	};
 78
 79	u32 kwords = key_len / sizeof(u32);
 80	struct aes_block *key_enc, *key_dec;
 81	int i, j;
 82
 83	if (key_len != AES_KEYSIZE_128 &&
 84	    key_len != AES_KEYSIZE_192 &&
 85	    key_len != AES_KEYSIZE_256)
 86		return -EINVAL;
 87
 88	ctx->key_length = key_len;
 89	for (i = 0; i < kwords; i++)
 90		ctx->key_enc[i] = get_unaligned_le32(in_key + i * sizeof(u32));
 91
 92	kernel_neon_begin();
 93	for (i = 0; i < sizeof(rcon); i++) {
 94		u32 *rki = ctx->key_enc + (i * kwords);
 95		u32 *rko = rki + kwords;
 96
 97		rko[0] = ror32(ce_aes_sub(rki[kwords - 1]), 8);
 98		rko[0] = rko[0] ^ rki[0] ^ rcon[i];
 99		rko[1] = rko[0] ^ rki[1];
100		rko[2] = rko[1] ^ rki[2];
101		rko[3] = rko[2] ^ rki[3];
102
103		if (key_len == AES_KEYSIZE_192) {
104			if (i >= 7)
105				break;
106			rko[4] = rko[3] ^ rki[4];
107			rko[5] = rko[4] ^ rki[5];
108		} else if (key_len == AES_KEYSIZE_256) {
109			if (i >= 6)
110				break;
111			rko[4] = ce_aes_sub(rko[3]) ^ rki[4];
112			rko[5] = rko[4] ^ rki[5];
113			rko[6] = rko[5] ^ rki[6];
114			rko[7] = rko[6] ^ rki[7];
115		}
116	}
117
118	/*
119	 * Generate the decryption keys for the Equivalent Inverse Cipher.
120	 * This involves reversing the order of the round keys, and applying
121	 * the Inverse Mix Columns transformation on all but the first and
122	 * the last one.
123	 */
124	key_enc = (struct aes_block *)ctx->key_enc;
125	key_dec = (struct aes_block *)ctx->key_dec;
126	j = num_rounds(ctx);
127
128	key_dec[0] = key_enc[j];
129	for (i = 1, j--; j > 0; i++, j--)
130		ce_aes_invert(key_dec + i, key_enc + j);
131	key_dec[i] = key_enc[0];
132
133	kernel_neon_end();
134	return 0;
135}
136
137static int ce_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
138			 unsigned int key_len)
139{
140	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
141
142	return ce_aes_expandkey(ctx, in_key, key_len);
143}
144
145struct crypto_aes_xts_ctx {
146	struct crypto_aes_ctx key1;
147	struct crypto_aes_ctx __aligned(8) key2;
148};
149
150static int xts_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
151		       unsigned int key_len)
152{
153	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
154	int ret;
155
156	ret = xts_verify_key(tfm, in_key, key_len);
157	if (ret)
158		return ret;
159
160	ret = ce_aes_expandkey(&ctx->key1, in_key, key_len / 2);
161	if (!ret)
162		ret = ce_aes_expandkey(&ctx->key2, &in_key[key_len / 2],
163				       key_len / 2);
164	return ret;
165}
166
167static int ecb_encrypt(struct skcipher_request *req)
168{
169	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
170	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
171	struct skcipher_walk walk;
172	unsigned int blocks;
173	int err;
174
175	err = skcipher_walk_virt(&walk, req, false);
176
177	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
178		kernel_neon_begin();
179		ce_aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
180				   ctx->key_enc, num_rounds(ctx), blocks);
181		kernel_neon_end();
182		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
183	}
184	return err;
185}
186
187static int ecb_decrypt(struct skcipher_request *req)
188{
189	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
190	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
191	struct skcipher_walk walk;
192	unsigned int blocks;
193	int err;
194
195	err = skcipher_walk_virt(&walk, req, false);
196
197	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
198		kernel_neon_begin();
199		ce_aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
200				   ctx->key_dec, num_rounds(ctx), blocks);
201		kernel_neon_end();
202		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
203	}
204	return err;
205}
206
207static int cbc_encrypt_walk(struct skcipher_request *req,
208			    struct skcipher_walk *walk)
209{
210	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
211	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
212	unsigned int blocks;
213	int err = 0;
214
215	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
216		kernel_neon_begin();
217		ce_aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
218				   ctx->key_enc, num_rounds(ctx), blocks,
219				   walk->iv);
220		kernel_neon_end();
221		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
222	}
223	return err;
224}
225
226static int cbc_encrypt(struct skcipher_request *req)
227{
228	struct skcipher_walk walk;
229	int err;
230
231	err = skcipher_walk_virt(&walk, req, false);
232	if (err)
233		return err;
234	return cbc_encrypt_walk(req, &walk);
235}
236
237static int cbc_decrypt_walk(struct skcipher_request *req,
238			    struct skcipher_walk *walk)
239{
240	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
241	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
242	unsigned int blocks;
243	int err = 0;
244
245	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
246		kernel_neon_begin();
247		ce_aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
248				   ctx->key_dec, num_rounds(ctx), blocks,
249				   walk->iv);
250		kernel_neon_end();
251		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
252	}
253	return err;
254}
255
256static int cbc_decrypt(struct skcipher_request *req)
257{
258	struct skcipher_walk walk;
259	int err;
260
261	err = skcipher_walk_virt(&walk, req, false);
262	if (err)
263		return err;
264	return cbc_decrypt_walk(req, &walk);
265}
266
267static int cts_cbc_encrypt(struct skcipher_request *req)
268{
269	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
270	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
271	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
272	struct scatterlist *src = req->src, *dst = req->dst;
273	struct scatterlist sg_src[2], sg_dst[2];
274	struct skcipher_request subreq;
275	struct skcipher_walk walk;
276	int err;
277
278	skcipher_request_set_tfm(&subreq, tfm);
279	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
280				      NULL, NULL);
281
282	if (req->cryptlen <= AES_BLOCK_SIZE) {
283		if (req->cryptlen < AES_BLOCK_SIZE)
284			return -EINVAL;
285		cbc_blocks = 1;
286	}
287
288	if (cbc_blocks > 0) {
289		skcipher_request_set_crypt(&subreq, req->src, req->dst,
290					   cbc_blocks * AES_BLOCK_SIZE,
291					   req->iv);
292
293		err = skcipher_walk_virt(&walk, &subreq, false) ?:
294		      cbc_encrypt_walk(&subreq, &walk);
295		if (err)
296			return err;
297
298		if (req->cryptlen == AES_BLOCK_SIZE)
299			return 0;
300
301		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
302		if (req->dst != req->src)
303			dst = scatterwalk_ffwd(sg_dst, req->dst,
304					       subreq.cryptlen);
305	}
306
307	/* handle ciphertext stealing */
308	skcipher_request_set_crypt(&subreq, src, dst,
309				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
310				   req->iv);
311
312	err = skcipher_walk_virt(&walk, &subreq, false);
313	if (err)
314		return err;
315
316	kernel_neon_begin();
317	ce_aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
318			       ctx->key_enc, num_rounds(ctx), walk.nbytes,
319			       walk.iv);
320	kernel_neon_end();
321
322	return skcipher_walk_done(&walk, 0);
323}
324
325static int cts_cbc_decrypt(struct skcipher_request *req)
326{
327	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
328	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
329	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
330	struct scatterlist *src = req->src, *dst = req->dst;
331	struct scatterlist sg_src[2], sg_dst[2];
332	struct skcipher_request subreq;
333	struct skcipher_walk walk;
334	int err;
335
336	skcipher_request_set_tfm(&subreq, tfm);
337	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
338				      NULL, NULL);
339
340	if (req->cryptlen <= AES_BLOCK_SIZE) {
341		if (req->cryptlen < AES_BLOCK_SIZE)
342			return -EINVAL;
343		cbc_blocks = 1;
344	}
345
346	if (cbc_blocks > 0) {
347		skcipher_request_set_crypt(&subreq, req->src, req->dst,
348					   cbc_blocks * AES_BLOCK_SIZE,
349					   req->iv);
350
351		err = skcipher_walk_virt(&walk, &subreq, false) ?:
352		      cbc_decrypt_walk(&subreq, &walk);
353		if (err)
354			return err;
355
356		if (req->cryptlen == AES_BLOCK_SIZE)
357			return 0;
358
359		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
360		if (req->dst != req->src)
361			dst = scatterwalk_ffwd(sg_dst, req->dst,
362					       subreq.cryptlen);
363	}
364
365	/* handle ciphertext stealing */
366	skcipher_request_set_crypt(&subreq, src, dst,
367				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
368				   req->iv);
369
370	err = skcipher_walk_virt(&walk, &subreq, false);
371	if (err)
372		return err;
373
374	kernel_neon_begin();
375	ce_aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
376			       ctx->key_dec, num_rounds(ctx), walk.nbytes,
377			       walk.iv);
378	kernel_neon_end();
379
380	return skcipher_walk_done(&walk, 0);
381}
382
383static int ctr_encrypt(struct skcipher_request *req)
384{
385	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
386	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
387	struct skcipher_walk walk;
388	int err, blocks;
389
390	err = skcipher_walk_virt(&walk, req, false);
391
392	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
393		kernel_neon_begin();
394		ce_aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
395				   ctx->key_enc, num_rounds(ctx), blocks,
396				   walk.iv);
397		kernel_neon_end();
398		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
399	}
400	if (walk.nbytes) {
401		u8 __aligned(8) tail[AES_BLOCK_SIZE];
402		unsigned int nbytes = walk.nbytes;
403		u8 *tdst = walk.dst.virt.addr;
404		u8 *tsrc = walk.src.virt.addr;
405
406		/*
407		 * Tell aes_ctr_encrypt() to process a tail block.
408		 */
409		blocks = -1;
410
411		kernel_neon_begin();
412		ce_aes_ctr_encrypt(tail, NULL, ctx->key_enc, num_rounds(ctx),
413				   blocks, walk.iv);
414		kernel_neon_end();
415		crypto_xor_cpy(tdst, tsrc, tail, nbytes);
416		err = skcipher_walk_done(&walk, 0);
417	}
418	return err;
419}
420
421static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
422{
423	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
424	unsigned long flags;
425
426	/*
427	 * Temporarily disable interrupts to avoid races where
428	 * cachelines are evicted when the CPU is interrupted
429	 * to do something else.
430	 */
431	local_irq_save(flags);
432	aes_encrypt(ctx, dst, src);
433	local_irq_restore(flags);
434}
435
436static int ctr_encrypt_sync(struct skcipher_request *req)
437{
438	if (!crypto_simd_usable())
439		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);
440
441	return ctr_encrypt(req);
442}
443
444static int xts_encrypt(struct skcipher_request *req)
445{
446	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
447	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
448	int err, first, rounds = num_rounds(&ctx->key1);
449	int tail = req->cryptlen % AES_BLOCK_SIZE;
450	struct scatterlist sg_src[2], sg_dst[2];
451	struct skcipher_request subreq;
452	struct scatterlist *src, *dst;
453	struct skcipher_walk walk;
454
455	if (req->cryptlen < AES_BLOCK_SIZE)
456		return -EINVAL;
457
458	err = skcipher_walk_virt(&walk, req, false);
459
460	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
461		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
462					      AES_BLOCK_SIZE) - 2;
463
464		skcipher_walk_abort(&walk);
465
466		skcipher_request_set_tfm(&subreq, tfm);
467		skcipher_request_set_callback(&subreq,
468					      skcipher_request_flags(req),
469					      NULL, NULL);
470		skcipher_request_set_crypt(&subreq, req->src, req->dst,
471					   xts_blocks * AES_BLOCK_SIZE,
472					   req->iv);
473		req = &subreq;
474		err = skcipher_walk_virt(&walk, req, false);
475	} else {
476		tail = 0;
477	}
478
479	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
480		int nbytes = walk.nbytes;
481
482		if (walk.nbytes < walk.total)
483			nbytes &= ~(AES_BLOCK_SIZE - 1);
484
485		kernel_neon_begin();
486		ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
487				   ctx->key1.key_enc, rounds, nbytes, walk.iv,
488				   ctx->key2.key_enc, first);
489		kernel_neon_end();
490		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
491	}
492
493	if (err || likely(!tail))
494		return err;
495
496	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
497	if (req->dst != req->src)
498		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
499
500	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
501				   req->iv);
502
503	err = skcipher_walk_virt(&walk, req, false);
504	if (err)
505		return err;
506
507	kernel_neon_begin();
508	ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
509			   ctx->key1.key_enc, rounds, walk.nbytes, walk.iv,
510			   ctx->key2.key_enc, first);
511	kernel_neon_end();
512
513	return skcipher_walk_done(&walk, 0);
514}
515
516static int xts_decrypt(struct skcipher_request *req)
517{
518	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
519	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
520	int err, first, rounds = num_rounds(&ctx->key1);
521	int tail = req->cryptlen % AES_BLOCK_SIZE;
522	struct scatterlist sg_src[2], sg_dst[2];
523	struct skcipher_request subreq;
524	struct scatterlist *src, *dst;
525	struct skcipher_walk walk;
526
527	if (req->cryptlen < AES_BLOCK_SIZE)
528		return -EINVAL;
529
530	err = skcipher_walk_virt(&walk, req, false);
531
532	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
533		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
534					      AES_BLOCK_SIZE) - 2;
535
536		skcipher_walk_abort(&walk);
537
538		skcipher_request_set_tfm(&subreq, tfm);
539		skcipher_request_set_callback(&subreq,
540					      skcipher_request_flags(req),
541					      NULL, NULL);
542		skcipher_request_set_crypt(&subreq, req->src, req->dst,
543					   xts_blocks * AES_BLOCK_SIZE,
544					   req->iv);
545		req = &subreq;
546		err = skcipher_walk_virt(&walk, req, false);
547	} else {
548		tail = 0;
549	}
550
551	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
552		int nbytes = walk.nbytes;
553
554		if (walk.nbytes < walk.total)
555			nbytes &= ~(AES_BLOCK_SIZE - 1);
556
557		kernel_neon_begin();
558		ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
559				   ctx->key1.key_dec, rounds, nbytes, walk.iv,
560				   ctx->key2.key_enc, first);
561		kernel_neon_end();
562		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
563	}
564
565	if (err || likely(!tail))
566		return err;
567
568	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
569	if (req->dst != req->src)
570		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
571
572	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
573				   req->iv);
574
575	err = skcipher_walk_virt(&walk, req, false);
576	if (err)
577		return err;
578
579	kernel_neon_begin();
580	ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
581			   ctx->key1.key_dec, rounds, walk.nbytes, walk.iv,
582			   ctx->key2.key_enc, first);
583	kernel_neon_end();
584
585	return skcipher_walk_done(&walk, 0);
586}
587
588static struct skcipher_alg aes_algs[] = { {
589	.base.cra_name		= "__ecb(aes)",
590	.base.cra_driver_name	= "__ecb-aes-ce",
591	.base.cra_priority	= 300,
592	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
593	.base.cra_blocksize	= AES_BLOCK_SIZE,
594	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
595	.base.cra_module	= THIS_MODULE,
596
597	.min_keysize		= AES_MIN_KEY_SIZE,
598	.max_keysize		= AES_MAX_KEY_SIZE,
599	.setkey			= ce_aes_setkey,
600	.encrypt		= ecb_encrypt,
601	.decrypt		= ecb_decrypt,
602}, {
603	.base.cra_name		= "__cbc(aes)",
604	.base.cra_driver_name	= "__cbc-aes-ce",
605	.base.cra_priority	= 300,
606	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
607	.base.cra_blocksize	= AES_BLOCK_SIZE,
608	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
609	.base.cra_module	= THIS_MODULE,
610
611	.min_keysize		= AES_MIN_KEY_SIZE,
612	.max_keysize		= AES_MAX_KEY_SIZE,
613	.ivsize			= AES_BLOCK_SIZE,
614	.setkey			= ce_aes_setkey,
615	.encrypt		= cbc_encrypt,
616	.decrypt		= cbc_decrypt,
617}, {
618	.base.cra_name		= "__cts(cbc(aes))",
619	.base.cra_driver_name	= "__cts-cbc-aes-ce",
620	.base.cra_priority	= 300,
621	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
622	.base.cra_blocksize	= AES_BLOCK_SIZE,
623	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
624	.base.cra_module	= THIS_MODULE,
625
626	.min_keysize		= AES_MIN_KEY_SIZE,
627	.max_keysize		= AES_MAX_KEY_SIZE,
628	.ivsize			= AES_BLOCK_SIZE,
629	.walksize		= 2 * AES_BLOCK_SIZE,
630	.setkey			= ce_aes_setkey,
631	.encrypt		= cts_cbc_encrypt,
632	.decrypt		= cts_cbc_decrypt,
633}, {
634	.base.cra_name		= "__ctr(aes)",
635	.base.cra_driver_name	= "__ctr-aes-ce",
636	.base.cra_priority	= 300,
637	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
638	.base.cra_blocksize	= 1,
639	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
640	.base.cra_module	= THIS_MODULE,
641
642	.min_keysize		= AES_MIN_KEY_SIZE,
643	.max_keysize		= AES_MAX_KEY_SIZE,
644	.ivsize			= AES_BLOCK_SIZE,
645	.chunksize		= AES_BLOCK_SIZE,
646	.setkey			= ce_aes_setkey,
647	.encrypt		= ctr_encrypt,
648	.decrypt		= ctr_encrypt,
649}, {
650	.base.cra_name		= "ctr(aes)",
651	.base.cra_driver_name	= "ctr-aes-ce-sync",
652	.base.cra_priority	= 300 - 1,
653	.base.cra_blocksize	= 1,
654	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
655	.base.cra_module	= THIS_MODULE,
656
657	.min_keysize		= AES_MIN_KEY_SIZE,
658	.max_keysize		= AES_MAX_KEY_SIZE,
659	.ivsize			= AES_BLOCK_SIZE,
660	.chunksize		= AES_BLOCK_SIZE,
661	.setkey			= ce_aes_setkey,
662	.encrypt		= ctr_encrypt_sync,
663	.decrypt		= ctr_encrypt_sync,
664}, {
665	.base.cra_name		= "__xts(aes)",
666	.base.cra_driver_name	= "__xts-aes-ce",
667	.base.cra_priority	= 300,
668	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
669	.base.cra_blocksize	= AES_BLOCK_SIZE,
670	.base.cra_ctxsize	= sizeof(struct crypto_aes_xts_ctx),
671	.base.cra_module	= THIS_MODULE,
672
673	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
674	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
675	.ivsize			= AES_BLOCK_SIZE,
676	.walksize		= 2 * AES_BLOCK_SIZE,
677	.setkey			= xts_set_key,
678	.encrypt		= xts_encrypt,
679	.decrypt		= xts_decrypt,
680} };
681
682static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
683
684static void aes_exit(void)
685{
686	int i;
687
688	for (i = 0; i < ARRAY_SIZE(aes_simd_algs) && aes_simd_algs[i]; i++)
689		simd_skcipher_free(aes_simd_algs[i]);
690
691	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
692}
693
694static int __init aes_init(void)
695{
696	struct simd_skcipher_alg *simd;
697	const char *basename;
698	const char *algname;
699	const char *drvname;
700	int err;
701	int i;
702
703	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
704	if (err)
705		return err;
706
707	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
708		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
709			continue;
710
711		algname = aes_algs[i].base.cra_name + 2;
712		drvname = aes_algs[i].base.cra_driver_name + 2;
713		basename = aes_algs[i].base.cra_driver_name;
714		simd = simd_skcipher_create_compat(aes_algs + i, algname, drvname, basename);
715		err = PTR_ERR(simd);
716		if (IS_ERR(simd))
717			goto unregister_simds;
718
719		aes_simd_algs[i] = simd;
720	}
721
722	return 0;
723
724unregister_simds:
725	aes_exit();
726	return err;
727}
728
729module_cpu_feature_match(AES, aes_init);
730module_exit(aes_exit);
v5.14.15
  1// SPDX-License-Identifier: GPL-2.0-only
  2/*
  3 * aes-ce-glue.c - wrapper code for ARMv8 AES
  4 *
  5 * Copyright (C) 2015 Linaro Ltd <ard.biesheuvel@linaro.org>
  6 */
  7
  8#include <asm/hwcap.h>
  9#include <asm/neon.h>
 10#include <asm/simd.h>
 11#include <asm/unaligned.h>
 12#include <crypto/aes.h>
 13#include <crypto/ctr.h>
 14#include <crypto/internal/simd.h>
 15#include <crypto/internal/skcipher.h>
 16#include <crypto/scatterwalk.h>
 17#include <linux/cpufeature.h>
 18#include <linux/module.h>
 19#include <crypto/xts.h>
 20
 21MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
 22MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
 23MODULE_LICENSE("GPL v2");
 24
 25/* defined in aes-ce-core.S */
 26asmlinkage u32 ce_aes_sub(u32 input);
 27asmlinkage void ce_aes_invert(void *dst, void *src);
 28
 29asmlinkage void ce_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
 30				   int rounds, int blocks);
 31asmlinkage void ce_aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
 32				   int rounds, int blocks);
 33
 34asmlinkage void ce_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
 35				   int rounds, int blocks, u8 iv[]);
 36asmlinkage void ce_aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
 37				   int rounds, int blocks, u8 iv[]);
 38asmlinkage void ce_aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
 39				   int rounds, int bytes, u8 const iv[]);
 40asmlinkage void ce_aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
 41				   int rounds, int bytes, u8 const iv[]);
 42
 43asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
 44				   int rounds, int blocks, u8 ctr[]);
 45
 46asmlinkage void ce_aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
 47				   int rounds, int bytes, u8 iv[],
 48				   u32 const rk2[], int first);
 49asmlinkage void ce_aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
 50				   int rounds, int bytes, u8 iv[],
 51				   u32 const rk2[], int first);
 52
 53struct aes_block {
 54	u8 b[AES_BLOCK_SIZE];
 55};
 56
 57static int num_rounds(struct crypto_aes_ctx *ctx)
 58{
 59	/*
 60	 * # of rounds specified by AES:
 61	 * 128 bit key		10 rounds
 62	 * 192 bit key		12 rounds
 63	 * 256 bit key		14 rounds
 64	 * => n byte key	=> 6 + (n/4) rounds
 65	 */
 66	return 6 + ctx->key_length / 4;
 67}
 68
 69static int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
 70			    unsigned int key_len)
 71{
 72	/*
 73	 * The AES key schedule round constants
 74	 */
 75	static u8 const rcon[] = {
 76		0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36,
 77	};
 78
 79	u32 kwords = key_len / sizeof(u32);
 80	struct aes_block *key_enc, *key_dec;
 81	int i, j;
 82
 83	if (key_len != AES_KEYSIZE_128 &&
 84	    key_len != AES_KEYSIZE_192 &&
 85	    key_len != AES_KEYSIZE_256)
 86		return -EINVAL;
 87
 88	ctx->key_length = key_len;
 89	for (i = 0; i < kwords; i++)
 90		ctx->key_enc[i] = get_unaligned_le32(in_key + i * sizeof(u32));
 91
 92	kernel_neon_begin();
 93	for (i = 0; i < sizeof(rcon); i++) {
 94		u32 *rki = ctx->key_enc + (i * kwords);
 95		u32 *rko = rki + kwords;
 96
 97		rko[0] = ror32(ce_aes_sub(rki[kwords - 1]), 8);
 98		rko[0] = rko[0] ^ rki[0] ^ rcon[i];
 99		rko[1] = rko[0] ^ rki[1];
100		rko[2] = rko[1] ^ rki[2];
101		rko[3] = rko[2] ^ rki[3];
102
103		if (key_len == AES_KEYSIZE_192) {
104			if (i >= 7)
105				break;
106			rko[4] = rko[3] ^ rki[4];
107			rko[5] = rko[4] ^ rki[5];
108		} else if (key_len == AES_KEYSIZE_256) {
109			if (i >= 6)
110				break;
111			rko[4] = ce_aes_sub(rko[3]) ^ rki[4];
112			rko[5] = rko[4] ^ rki[5];
113			rko[6] = rko[5] ^ rki[6];
114			rko[7] = rko[6] ^ rki[7];
115		}
116	}
117
118	/*
119	 * Generate the decryption keys for the Equivalent Inverse Cipher.
120	 * This involves reversing the order of the round keys, and applying
121	 * the Inverse Mix Columns transformation on all but the first and
122	 * the last one.
123	 */
124	key_enc = (struct aes_block *)ctx->key_enc;
125	key_dec = (struct aes_block *)ctx->key_dec;
126	j = num_rounds(ctx);
127
128	key_dec[0] = key_enc[j];
129	for (i = 1, j--; j > 0; i++, j--)
130		ce_aes_invert(key_dec + i, key_enc + j);
131	key_dec[i] = key_enc[0];
132
133	kernel_neon_end();
134	return 0;
135}
136
137static int ce_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
138			 unsigned int key_len)
139{
140	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
141
142	return ce_aes_expandkey(ctx, in_key, key_len);
143}
144
145struct crypto_aes_xts_ctx {
146	struct crypto_aes_ctx key1;
147	struct crypto_aes_ctx __aligned(8) key2;
148};
149
150static int xts_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
151		       unsigned int key_len)
152{
153	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
154	int ret;
155
156	ret = xts_verify_key(tfm, in_key, key_len);
157	if (ret)
158		return ret;
159
160	ret = ce_aes_expandkey(&ctx->key1, in_key, key_len / 2);
161	if (!ret)
162		ret = ce_aes_expandkey(&ctx->key2, &in_key[key_len / 2],
163				       key_len / 2);
164	return ret;
165}
166
167static int ecb_encrypt(struct skcipher_request *req)
168{
169	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
170	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
171	struct skcipher_walk walk;
172	unsigned int blocks;
173	int err;
174
175	err = skcipher_walk_virt(&walk, req, false);
176
177	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
178		kernel_neon_begin();
179		ce_aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
180				   ctx->key_enc, num_rounds(ctx), blocks);
181		kernel_neon_end();
182		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
183	}
184	return err;
185}
186
187static int ecb_decrypt(struct skcipher_request *req)
188{
189	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
190	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
191	struct skcipher_walk walk;
192	unsigned int blocks;
193	int err;
194
195	err = skcipher_walk_virt(&walk, req, false);
196
197	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
198		kernel_neon_begin();
199		ce_aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
200				   ctx->key_dec, num_rounds(ctx), blocks);
201		kernel_neon_end();
202		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
203	}
204	return err;
205}
206
207static int cbc_encrypt_walk(struct skcipher_request *req,
208			    struct skcipher_walk *walk)
209{
210	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
211	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
212	unsigned int blocks;
213	int err = 0;
214
215	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
216		kernel_neon_begin();
217		ce_aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
218				   ctx->key_enc, num_rounds(ctx), blocks,
219				   walk->iv);
220		kernel_neon_end();
221		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
222	}
223	return err;
224}
225
226static int cbc_encrypt(struct skcipher_request *req)
227{
228	struct skcipher_walk walk;
229	int err;
230
231	err = skcipher_walk_virt(&walk, req, false);
232	if (err)
233		return err;
234	return cbc_encrypt_walk(req, &walk);
235}
236
237static int cbc_decrypt_walk(struct skcipher_request *req,
238			    struct skcipher_walk *walk)
239{
240	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
241	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
242	unsigned int blocks;
243	int err = 0;
244
245	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
246		kernel_neon_begin();
247		ce_aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
248				   ctx->key_dec, num_rounds(ctx), blocks,
249				   walk->iv);
250		kernel_neon_end();
251		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
252	}
253	return err;
254}
255
256static int cbc_decrypt(struct skcipher_request *req)
257{
258	struct skcipher_walk walk;
259	int err;
260
261	err = skcipher_walk_virt(&walk, req, false);
262	if (err)
263		return err;
264	return cbc_decrypt_walk(req, &walk);
265}
266
267static int cts_cbc_encrypt(struct skcipher_request *req)
268{
269	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
270	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
271	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
272	struct scatterlist *src = req->src, *dst = req->dst;
273	struct scatterlist sg_src[2], sg_dst[2];
274	struct skcipher_request subreq;
275	struct skcipher_walk walk;
276	int err;
277
278	skcipher_request_set_tfm(&subreq, tfm);
279	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
280				      NULL, NULL);
281
282	if (req->cryptlen <= AES_BLOCK_SIZE) {
283		if (req->cryptlen < AES_BLOCK_SIZE)
284			return -EINVAL;
285		cbc_blocks = 1;
286	}
287
288	if (cbc_blocks > 0) {
289		skcipher_request_set_crypt(&subreq, req->src, req->dst,
290					   cbc_blocks * AES_BLOCK_SIZE,
291					   req->iv);
292
293		err = skcipher_walk_virt(&walk, &subreq, false) ?:
294		      cbc_encrypt_walk(&subreq, &walk);
295		if (err)
296			return err;
297
298		if (req->cryptlen == AES_BLOCK_SIZE)
299			return 0;
300
301		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
302		if (req->dst != req->src)
303			dst = scatterwalk_ffwd(sg_dst, req->dst,
304					       subreq.cryptlen);
305	}
306
307	/* handle ciphertext stealing */
308	skcipher_request_set_crypt(&subreq, src, dst,
309				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
310				   req->iv);
311
312	err = skcipher_walk_virt(&walk, &subreq, false);
313	if (err)
314		return err;
315
316	kernel_neon_begin();
317	ce_aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
318			       ctx->key_enc, num_rounds(ctx), walk.nbytes,
319			       walk.iv);
320	kernel_neon_end();
321
322	return skcipher_walk_done(&walk, 0);
323}
324
325static int cts_cbc_decrypt(struct skcipher_request *req)
326{
327	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
328	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
329	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
330	struct scatterlist *src = req->src, *dst = req->dst;
331	struct scatterlist sg_src[2], sg_dst[2];
332	struct skcipher_request subreq;
333	struct skcipher_walk walk;
334	int err;
335
336	skcipher_request_set_tfm(&subreq, tfm);
337	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
338				      NULL, NULL);
339
340	if (req->cryptlen <= AES_BLOCK_SIZE) {
341		if (req->cryptlen < AES_BLOCK_SIZE)
342			return -EINVAL;
343		cbc_blocks = 1;
344	}
345
346	if (cbc_blocks > 0) {
347		skcipher_request_set_crypt(&subreq, req->src, req->dst,
348					   cbc_blocks * AES_BLOCK_SIZE,
349					   req->iv);
350
351		err = skcipher_walk_virt(&walk, &subreq, false) ?:
352		      cbc_decrypt_walk(&subreq, &walk);
353		if (err)
354			return err;
355
356		if (req->cryptlen == AES_BLOCK_SIZE)
357			return 0;
358
359		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
360		if (req->dst != req->src)
361			dst = scatterwalk_ffwd(sg_dst, req->dst,
362					       subreq.cryptlen);
363	}
364
365	/* handle ciphertext stealing */
366	skcipher_request_set_crypt(&subreq, src, dst,
367				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
368				   req->iv);
369
370	err = skcipher_walk_virt(&walk, &subreq, false);
371	if (err)
372		return err;
373
374	kernel_neon_begin();
375	ce_aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
376			       ctx->key_dec, num_rounds(ctx), walk.nbytes,
377			       walk.iv);
378	kernel_neon_end();
379
380	return skcipher_walk_done(&walk, 0);
381}
382
383static int ctr_encrypt(struct skcipher_request *req)
384{
385	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
386	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
387	struct skcipher_walk walk;
388	int err, blocks;
389
390	err = skcipher_walk_virt(&walk, req, false);
391
392	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
393		kernel_neon_begin();
394		ce_aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
395				   ctx->key_enc, num_rounds(ctx), blocks,
396				   walk.iv);
397		kernel_neon_end();
398		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
399	}
400	if (walk.nbytes) {
401		u8 __aligned(8) tail[AES_BLOCK_SIZE];
402		unsigned int nbytes = walk.nbytes;
403		u8 *tdst = walk.dst.virt.addr;
404		u8 *tsrc = walk.src.virt.addr;
405
406		/*
407		 * Tell aes_ctr_encrypt() to process a tail block.
408		 */
409		blocks = -1;
410
411		kernel_neon_begin();
412		ce_aes_ctr_encrypt(tail, NULL, ctx->key_enc, num_rounds(ctx),
413				   blocks, walk.iv);
414		kernel_neon_end();
415		crypto_xor_cpy(tdst, tsrc, tail, nbytes);
416		err = skcipher_walk_done(&walk, 0);
417	}
418	return err;
419}
420
421static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
422{
423	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
424	unsigned long flags;
425
426	/*
427	 * Temporarily disable interrupts to avoid races where
428	 * cachelines are evicted when the CPU is interrupted
429	 * to do something else.
430	 */
431	local_irq_save(flags);
432	aes_encrypt(ctx, dst, src);
433	local_irq_restore(flags);
434}
435
436static int ctr_encrypt_sync(struct skcipher_request *req)
437{
438	if (!crypto_simd_usable())
439		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);
440
441	return ctr_encrypt(req);
442}
443
444static int xts_encrypt(struct skcipher_request *req)
445{
446	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
447	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
448	int err, first, rounds = num_rounds(&ctx->key1);
449	int tail = req->cryptlen % AES_BLOCK_SIZE;
450	struct scatterlist sg_src[2], sg_dst[2];
451	struct skcipher_request subreq;
452	struct scatterlist *src, *dst;
453	struct skcipher_walk walk;
454
455	if (req->cryptlen < AES_BLOCK_SIZE)
456		return -EINVAL;
457
458	err = skcipher_walk_virt(&walk, req, false);
459
460	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
461		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
462					      AES_BLOCK_SIZE) - 2;
463
464		skcipher_walk_abort(&walk);
465
466		skcipher_request_set_tfm(&subreq, tfm);
467		skcipher_request_set_callback(&subreq,
468					      skcipher_request_flags(req),
469					      NULL, NULL);
470		skcipher_request_set_crypt(&subreq, req->src, req->dst,
471					   xts_blocks * AES_BLOCK_SIZE,
472					   req->iv);
473		req = &subreq;
474		err = skcipher_walk_virt(&walk, req, false);
475	} else {
476		tail = 0;
477	}
478
479	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
480		int nbytes = walk.nbytes;
481
482		if (walk.nbytes < walk.total)
483			nbytes &= ~(AES_BLOCK_SIZE - 1);
484
485		kernel_neon_begin();
486		ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
487				   ctx->key1.key_enc, rounds, nbytes, walk.iv,
488				   ctx->key2.key_enc, first);
489		kernel_neon_end();
490		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
491	}
492
493	if (err || likely(!tail))
494		return err;
495
496	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
497	if (req->dst != req->src)
498		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
499
500	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
501				   req->iv);
502
503	err = skcipher_walk_virt(&walk, req, false);
504	if (err)
505		return err;
506
507	kernel_neon_begin();
508	ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
509			   ctx->key1.key_enc, rounds, walk.nbytes, walk.iv,
510			   ctx->key2.key_enc, first);
511	kernel_neon_end();
512
513	return skcipher_walk_done(&walk, 0);
514}
515
516static int xts_decrypt(struct skcipher_request *req)
517{
518	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
519	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
520	int err, first, rounds = num_rounds(&ctx->key1);
521	int tail = req->cryptlen % AES_BLOCK_SIZE;
522	struct scatterlist sg_src[2], sg_dst[2];
523	struct skcipher_request subreq;
524	struct scatterlist *src, *dst;
525	struct skcipher_walk walk;
526
527	if (req->cryptlen < AES_BLOCK_SIZE)
528		return -EINVAL;
529
530	err = skcipher_walk_virt(&walk, req, false);
531
532	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
533		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
534					      AES_BLOCK_SIZE) - 2;
535
536		skcipher_walk_abort(&walk);
537
538		skcipher_request_set_tfm(&subreq, tfm);
539		skcipher_request_set_callback(&subreq,
540					      skcipher_request_flags(req),
541					      NULL, NULL);
542		skcipher_request_set_crypt(&subreq, req->src, req->dst,
543					   xts_blocks * AES_BLOCK_SIZE,
544					   req->iv);
545		req = &subreq;
546		err = skcipher_walk_virt(&walk, req, false);
547	} else {
548		tail = 0;
549	}
550
551	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
552		int nbytes = walk.nbytes;
553
554		if (walk.nbytes < walk.total)
555			nbytes &= ~(AES_BLOCK_SIZE - 1);
556
557		kernel_neon_begin();
558		ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
559				   ctx->key1.key_dec, rounds, nbytes, walk.iv,
560				   ctx->key2.key_enc, first);
561		kernel_neon_end();
562		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
563	}
564
565	if (err || likely(!tail))
566		return err;
567
568	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
569	if (req->dst != req->src)
570		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
571
572	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
573				   req->iv);
574
575	err = skcipher_walk_virt(&walk, req, false);
576	if (err)
577		return err;
578
579	kernel_neon_begin();
580	ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
581			   ctx->key1.key_dec, rounds, walk.nbytes, walk.iv,
582			   ctx->key2.key_enc, first);
583	kernel_neon_end();
584
585	return skcipher_walk_done(&walk, 0);
586}
587
588static struct skcipher_alg aes_algs[] = { {
589	.base.cra_name		= "__ecb(aes)",
590	.base.cra_driver_name	= "__ecb-aes-ce",
591	.base.cra_priority	= 300,
592	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
593	.base.cra_blocksize	= AES_BLOCK_SIZE,
594	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
595	.base.cra_module	= THIS_MODULE,
596
597	.min_keysize		= AES_MIN_KEY_SIZE,
598	.max_keysize		= AES_MAX_KEY_SIZE,
599	.setkey			= ce_aes_setkey,
600	.encrypt		= ecb_encrypt,
601	.decrypt		= ecb_decrypt,
602}, {
603	.base.cra_name		= "__cbc(aes)",
604	.base.cra_driver_name	= "__cbc-aes-ce",
605	.base.cra_priority	= 300,
606	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
607	.base.cra_blocksize	= AES_BLOCK_SIZE,
608	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
609	.base.cra_module	= THIS_MODULE,
610
611	.min_keysize		= AES_MIN_KEY_SIZE,
612	.max_keysize		= AES_MAX_KEY_SIZE,
613	.ivsize			= AES_BLOCK_SIZE,
614	.setkey			= ce_aes_setkey,
615	.encrypt		= cbc_encrypt,
616	.decrypt		= cbc_decrypt,
617}, {
618	.base.cra_name		= "__cts(cbc(aes))",
619	.base.cra_driver_name	= "__cts-cbc-aes-ce",
620	.base.cra_priority	= 300,
621	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
622	.base.cra_blocksize	= AES_BLOCK_SIZE,
623	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
624	.base.cra_module	= THIS_MODULE,
625
626	.min_keysize		= AES_MIN_KEY_SIZE,
627	.max_keysize		= AES_MAX_KEY_SIZE,
628	.ivsize			= AES_BLOCK_SIZE,
629	.walksize		= 2 * AES_BLOCK_SIZE,
630	.setkey			= ce_aes_setkey,
631	.encrypt		= cts_cbc_encrypt,
632	.decrypt		= cts_cbc_decrypt,
633}, {
634	.base.cra_name		= "__ctr(aes)",
635	.base.cra_driver_name	= "__ctr-aes-ce",
636	.base.cra_priority	= 300,
637	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
638	.base.cra_blocksize	= 1,
639	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
640	.base.cra_module	= THIS_MODULE,
641
642	.min_keysize		= AES_MIN_KEY_SIZE,
643	.max_keysize		= AES_MAX_KEY_SIZE,
644	.ivsize			= AES_BLOCK_SIZE,
645	.chunksize		= AES_BLOCK_SIZE,
646	.setkey			= ce_aes_setkey,
647	.encrypt		= ctr_encrypt,
648	.decrypt		= ctr_encrypt,
649}, {
650	.base.cra_name		= "ctr(aes)",
651	.base.cra_driver_name	= "ctr-aes-ce-sync",
652	.base.cra_priority	= 300 - 1,
653	.base.cra_blocksize	= 1,
654	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
655	.base.cra_module	= THIS_MODULE,
656
657	.min_keysize		= AES_MIN_KEY_SIZE,
658	.max_keysize		= AES_MAX_KEY_SIZE,
659	.ivsize			= AES_BLOCK_SIZE,
660	.chunksize		= AES_BLOCK_SIZE,
661	.setkey			= ce_aes_setkey,
662	.encrypt		= ctr_encrypt_sync,
663	.decrypt		= ctr_encrypt_sync,
664}, {
665	.base.cra_name		= "__xts(aes)",
666	.base.cra_driver_name	= "__xts-aes-ce",
667	.base.cra_priority	= 300,
668	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
669	.base.cra_blocksize	= AES_BLOCK_SIZE,
670	.base.cra_ctxsize	= sizeof(struct crypto_aes_xts_ctx),
671	.base.cra_module	= THIS_MODULE,
672
673	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
674	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
675	.ivsize			= AES_BLOCK_SIZE,
676	.walksize		= 2 * AES_BLOCK_SIZE,
677	.setkey			= xts_set_key,
678	.encrypt		= xts_encrypt,
679	.decrypt		= xts_decrypt,
680} };
681
682static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
683
684static void aes_exit(void)
685{
686	int i;
687
688	for (i = 0; i < ARRAY_SIZE(aes_simd_algs) && aes_simd_algs[i]; i++)
689		simd_skcipher_free(aes_simd_algs[i]);
690
691	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
692}
693
694static int __init aes_init(void)
695{
696	struct simd_skcipher_alg *simd;
697	const char *basename;
698	const char *algname;
699	const char *drvname;
700	int err;
701	int i;
702
703	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
704	if (err)
705		return err;
706
707	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
708		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
709			continue;
710
711		algname = aes_algs[i].base.cra_name + 2;
712		drvname = aes_algs[i].base.cra_driver_name + 2;
713		basename = aes_algs[i].base.cra_driver_name;
714		simd = simd_skcipher_create_compat(algname, drvname, basename);
715		err = PTR_ERR(simd);
716		if (IS_ERR(simd))
717			goto unregister_simds;
718
719		aes_simd_algs[i] = simd;
720	}
721
722	return 0;
723
724unregister_simds:
725	aes_exit();
726	return err;
727}
728
729module_cpu_feature_match(AES, aes_init);
730module_exit(aes_exit);