2023-12-11 d2ccde1c8e90d38cee87a1b0309ad2827f3fd30d
kernel/arch/arm64/crypto/aes-neonbs-glue.c
@@ -1,22 +1,19 @@
+// SPDX-License-Identifier: GPL-2.0-only
 /*
  * Bit sliced AES using NEON instructions
  *
  * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
- *
- * This program is free software; you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 2 as
- * published by the Free Software Foundation.
  */
 
 #include <asm/neon.h>
 #include <asm/simd.h>
 #include <crypto/aes.h>
+#include <crypto/ctr.h>
 #include <crypto/internal/simd.h>
 #include <crypto/internal/skcipher.h>
+#include <crypto/scatterwalk.h>
 #include <crypto/xts.h>
 #include <linux/module.h>
-
-#include "aes-ctr-fallback.h"
 
 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
 MODULE_LICENSE("GPL v2");
@@ -49,6 +46,12 @@
                                      int rounds, int blocks);
 asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                      int rounds, int blocks, u8 iv[]);
+asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
+                                     u32 const rk1[], int rounds, int bytes,
+                                     u32 const rk2[], u8 iv[], int first);
+asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
+                                     u32 const rk1[], int rounds, int bytes,
+                                     u32 const rk2[], u8 iv[], int first);
 
 struct aesbs_ctx {
         u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32];
@@ -60,14 +63,10 @@
         u32     enc[AES_MAX_KEYLENGTH_U32];
 };
 
-struct aesbs_ctr_ctx {
-        struct aesbs_ctx        key;            /* must be first member */
-        struct crypto_aes_ctx   fallback;
-};
-
 struct aesbs_xts_ctx {
         struct aesbs_ctx        key;
         u32                     twkey[AES_MAX_KEYLENGTH_U32];
+        struct crypto_aes_ctx   cts;
 };
 
 static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
@@ -77,7 +76,7 @@
         struct crypto_aes_ctx rk;
         int err;
 
-        err = crypto_aes_expand_key(&rk, in_key, key_len);
+        err = aes_expandkey(&rk, in_key, key_len);
         if (err)
                 return err;
 
@@ -136,7 +135,7 @@
         struct crypto_aes_ctx rk;
         int err;
 
-        err = crypto_aes_expand_key(&rk, in_key, key_len);
+        err = aes_expandkey(&rk, in_key, key_len);
         if (err)
                 return err;
 
@@ -147,6 +146,7 @@
         kernel_neon_begin();
         aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
         kernel_neon_end();
+        memzero_explicit(&rk, sizeof(rk));
 
         return 0;
 }
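
Note on the hunk above: the expanded key schedule lives in a stack-local
struct crypto_aes_ctx, and a plain memset() on an object that is about to
go out of scope may be elided by the compiler; memzero_explicit() is the
kernel primitive that guarantees the wipe. A minimal sketch of the pattern,
with example_setkey() and derive_bitsliced() as hypothetical stand-ins for
the driver's setkey and aesbs_convert_key() steps:

    static int example_setkey(struct aesbs_ctx *ctx, const u8 *in_key,
                              unsigned int key_len)
    {
            struct crypto_aes_ctx rk;       /* round keys on the stack */
            int err;

            err = aes_expandkey(&rk, in_key, key_len);
            if (err)
                    return err;

            derive_bitsliced(ctx, &rk);     /* hypothetical stand-in */

            /* guaranteed wipe; a bare memset() could be optimized away */
            memzero_explicit(&rk, sizeof(rk));
            return 0;
    }
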
@@ -202,25 +202,6 @@
         return err;
 }
 
-static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
-                                 unsigned int key_len)
-{
-        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
-        int err;
-
-        err = crypto_aes_expand_key(&ctx->fallback, in_key, key_len);
-        if (err)
-                return err;
-
-        ctx->key.rounds = 6 + key_len / 4;
-
-        kernel_neon_begin();
-        aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
-        kernel_neon_end();
-
-        return 0;
-}
-
 static int ctr_encrypt(struct skcipher_request *req)
 {
         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
@@ -274,7 +255,11 @@
                 return err;
 
         key_len /= 2;
-        err = crypto_aes_expand_key(&rk, in_key + key_len, key_len);
+        err = aes_expandkey(&ctx->cts, in_key, key_len);
+        if (err)
+                return err;
+
+        err = aes_expandkey(&rk, in_key + key_len, key_len);
         if (err)
                 return err;
 
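
Note on the hunk above: an XTS key is two equal-sized AES keys, concatenated.
The first half, now also expanded into ctx->cts, keys the block cipher over
the data and is kept in regular form because the new ciphertext-stealing
tail needs it; the second half is the tweak key kept in ctx->twkey, which
turns the IV into the initial tweak by a single AES encryption (the
neon_aes_ecb_encrypt(walk.iv, ...) call later in this patch). A reference
sketch of that derivation using the generic aes_encrypt() library call:

    /* tweak = AES_twkey(iv); XTS then masks each data block with the
     * running tweak, multiplying it by x in GF(2^128) per block. */
    static void xts_first_tweak(const struct crypto_aes_ctx *twkey,
                                u8 tweak[AES_BLOCK_SIZE],
                                const u8 iv[AES_BLOCK_SIZE])
    {
            aes_encrypt(twkey, tweak, iv);
    }
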
@@ -283,69 +268,128 @@
         return aesbs_setkey(tfm, in_key, key_len);
 }
 
-static int ctr_encrypt_sync(struct skcipher_request *req)
-{
-        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
-        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
-
-        if (!may_use_simd())
-                return aes_ctr_encrypt_fallback(&ctx->fallback, req);
-
-        return ctr_encrypt(req);
-}
-
-static int __xts_crypt(struct skcipher_request *req,
+static int __xts_crypt(struct skcipher_request *req, bool encrypt,
                        void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                   int rounds, int blocks, u8 iv[]))
 {
         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
         struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
+        int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
+        struct scatterlist sg_src[2], sg_dst[2];
+        struct skcipher_request subreq;
+        struct scatterlist *src, *dst;
         struct skcipher_walk walk;
-        int err;
+        int nbytes, err;
+        int first = 1;
+        u8 *out, *in;
+
+        if (req->cryptlen < AES_BLOCK_SIZE)
+                return -EINVAL;
+
+        /* ensure that the cts tail is covered by a single step */
+        if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
+                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
+                                              AES_BLOCK_SIZE) - 2;
+
+                skcipher_request_set_tfm(&subreq, tfm);
+                skcipher_request_set_callback(&subreq,
+                                              skcipher_request_flags(req),
+                                              NULL, NULL);
+                skcipher_request_set_crypt(&subreq, req->src, req->dst,
+                                           xts_blocks * AES_BLOCK_SIZE,
+                                           req->iv);
+                req = &subreq;
+        } else {
+                tail = 0;
+        }
 
         err = skcipher_walk_virt(&walk, req, false);
         if (err)
                 return err;
 
-        kernel_neon_begin();
-        neon_aes_ecb_encrypt(walk.iv, walk.iv, ctx->twkey, ctx->key.rounds, 1);
-        kernel_neon_end();
-
         while (walk.nbytes >= AES_BLOCK_SIZE) {
                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
 
-                if (walk.nbytes < walk.total)
+                if (walk.nbytes < walk.total || walk.nbytes % AES_BLOCK_SIZE)
                         blocks = round_down(blocks,
                                             walk.stride / AES_BLOCK_SIZE);
 
+                out = walk.dst.virt.addr;
+                in = walk.src.virt.addr;
+                nbytes = walk.nbytes;
+
                 kernel_neon_begin();
-                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
-                   ctx->key.rounds, blocks, walk.iv);
+                if (likely(blocks > 6)) { /* plain NEON is faster otherwise */
+                        if (first)
+                                neon_aes_ecb_encrypt(walk.iv, walk.iv,
+                                                     ctx->twkey,
+                                                     ctx->key.rounds, 1);
+                        first = 0;
+
+                        fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
+                           walk.iv);
+
+                        out += blocks * AES_BLOCK_SIZE;
+                        in += blocks * AES_BLOCK_SIZE;
+                        nbytes -= blocks * AES_BLOCK_SIZE;
+                }
+
+                if (walk.nbytes == walk.total && nbytes > 0)
+                        goto xts_tail;
+
                 kernel_neon_end();
-                err = skcipher_walk_done(&walk,
-                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
+                err = skcipher_walk_done(&walk, nbytes);
         }
-        return err;
+
+        if (err || likely(!tail))
+                return err;
+
+        /* handle ciphertext stealing */
+        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
+        if (req->dst != req->src)
+                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+
+        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
+                                   req->iv);
+
+        err = skcipher_walk_virt(&walk, req, false);
+        if (err)
+                return err;
+
+        out = walk.dst.virt.addr;
+        in = walk.src.virt.addr;
+        nbytes = walk.nbytes;
+
+        kernel_neon_begin();
+xts_tail:
+        if (encrypt)
+                neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
+                                     nbytes, ctx->twkey, walk.iv, first ?: 2);
+        else
+                neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
+                                     nbytes, ctx->twkey, walk.iv, first ?: 2);
+        kernel_neon_end();
+
+        return skcipher_walk_done(&walk, 0);
 }
 
 static int xts_encrypt(struct skcipher_request *req)
 {
-        return __xts_crypt(req, aesbs_xts_encrypt);
+        return __xts_crypt(req, true, aesbs_xts_encrypt);
 }
 
 static int xts_decrypt(struct skcipher_request *req)
 {
-        return __xts_crypt(req, aesbs_xts_decrypt);
+        return __xts_crypt(req, false, aesbs_xts_decrypt);
 }
 
 static struct skcipher_alg aes_algs[] = { {
-        .base.cra_name          = "__ecb(aes)",
-        .base.cra_driver_name   = "__ecb-aes-neonbs",
+        .base.cra_name          = "ecb(aes)",
+        .base.cra_driver_name   = "ecb-aes-neonbs",
         .base.cra_priority      = 250,
         .base.cra_blocksize     = AES_BLOCK_SIZE,
         .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
         .base.cra_module        = THIS_MODULE,
-        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
 
         .min_keysize            = AES_MIN_KEY_SIZE,
         .max_keysize            = AES_MAX_KEY_SIZE,
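
Note on the hunk above: the walk advances in 8-block (128-byte) steps, so a
partial final block is usually delivered in the same step as the full block
before it, and the in-loop goto xts_tail path can process both together.
Only when req->cryptlen % 128 falls strictly between 0 and 16 could the
partial block end up in a step of its own; the subrequest then trims the
bulk pass to all but the last two blocks, so that a dedicated tail pass of
AES_BLOCK_SIZE + tail bytes covers the stolen ciphertext in a single step.
A standalone sketch of that split (split_for_cts() is a hypothetical
helper, not part of the driver):

    #define WALK_SIZE (8 * AES_BLOCK_SIZE)      /* this driver's walksize */

    /* returns the byte count of the bulk pass; the remainder goes
     * through the ciphertext-stealing tail pass */
    static unsigned int split_for_cts(unsigned int cryptlen)
    {
            unsigned int tail = cryptlen % WALK_SIZE;

            if (tail == 0 || tail >= AES_BLOCK_SIZE)
                    return cryptlen;    /* tail already lands in one step */

            return (DIV_ROUND_UP(cryptlen, AES_BLOCK_SIZE) - 2) *
                   AES_BLOCK_SIZE;
    }

For example, cryptlen = 132 gives a 112-byte bulk pass and a 20-byte tail
pass: one full block plus the 4 trailing bytes.
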
@@ -354,13 +398,12 @@
         .encrypt                = ecb_encrypt,
         .decrypt                = ecb_decrypt,
 }, {
-        .base.cra_name          = "__cbc(aes)",
-        .base.cra_driver_name   = "__cbc-aes-neonbs",
+        .base.cra_name          = "cbc(aes)",
+        .base.cra_driver_name   = "cbc-aes-neonbs",
         .base.cra_priority      = 250,
         .base.cra_blocksize     = AES_BLOCK_SIZE,
         .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctx),
         .base.cra_module        = THIS_MODULE,
-        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
 
         .min_keysize            = AES_MIN_KEY_SIZE,
         .max_keysize            = AES_MAX_KEY_SIZE,
@@ -370,13 +413,12 @@
         .encrypt                = cbc_encrypt,
         .decrypt                = cbc_decrypt,
 }, {
-        .base.cra_name          = "__ctr(aes)",
-        .base.cra_driver_name   = "__ctr-aes-neonbs",
+        .base.cra_name          = "ctr(aes)",
+        .base.cra_driver_name   = "ctr-aes-neonbs",
         .base.cra_priority      = 250,
         .base.cra_blocksize     = 1,
         .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
         .base.cra_module        = THIS_MODULE,
-        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
 
         .min_keysize            = AES_MIN_KEY_SIZE,
         .max_keysize            = AES_MAX_KEY_SIZE,
@@ -387,29 +429,12 @@
         .encrypt                = ctr_encrypt,
         .decrypt                = ctr_encrypt,
 }, {
-        .base.cra_name          = "ctr(aes)",
-        .base.cra_driver_name   = "ctr-aes-neonbs",
-        .base.cra_priority      = 250 - 1,
-        .base.cra_blocksize     = 1,
-        .base.cra_ctxsize       = sizeof(struct aesbs_ctr_ctx),
-        .base.cra_module        = THIS_MODULE,
-
-        .min_keysize            = AES_MIN_KEY_SIZE,
-        .max_keysize            = AES_MAX_KEY_SIZE,
-        .chunksize              = AES_BLOCK_SIZE,
-        .walksize               = 8 * AES_BLOCK_SIZE,
-        .ivsize                 = AES_BLOCK_SIZE,
-        .setkey                 = aesbs_ctr_setkey_sync,
-        .encrypt                = ctr_encrypt_sync,
-        .decrypt                = ctr_encrypt_sync,
-}, {
-        .base.cra_name          = "__xts(aes)",
-        .base.cra_driver_name   = "__xts-aes-neonbs",
+        .base.cra_name          = "xts(aes)",
+        .base.cra_driver_name   = "xts-aes-neonbs",
         .base.cra_priority      = 250,
         .base.cra_blocksize     = AES_BLOCK_SIZE,
         .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
         .base.cra_module        = THIS_MODULE,
-        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
 
         .min_keysize            = 2 * AES_MIN_KEY_SIZE,
         .max_keysize            = 2 * AES_MAX_KEY_SIZE,
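
Note on the hunks above: with the "__" name prefixes and CRYPTO_ALG_INTERNAL
flags gone (and the simd wrapper registration removed below), these
algorithms can be requested by name through the ordinary skcipher API. A
minimal, hypothetical caller, with the src/dst scatterlists, len, iv, and a
32-byte key assumed to be prepared elsewhere:

    static int example_xts_encrypt(struct scatterlist *src,
                                   struct scatterlist *dst,
                                   unsigned int len, u8 *iv, const u8 *key)
    {
            struct crypto_skcipher *tfm;
            struct skcipher_request *req;
            int err;

            /* mask out async implementations to keep the call synchronous */
            tfm = crypto_alloc_skcipher("xts(aes)", 0, CRYPTO_ALG_ASYNC);
            if (IS_ERR(tfm))
                    return PTR_ERR(tfm);

            err = crypto_skcipher_setkey(tfm, key, 2 * AES_KEYSIZE_128);
            if (err)
                    goto out_tfm;

            req = skcipher_request_alloc(tfm, GFP_KERNEL);
            if (!req) {
                    err = -ENOMEM;
                    goto out_tfm;
            }

            skcipher_request_set_callback(req, 0, NULL, NULL);
            skcipher_request_set_crypt(req, src, dst, len, iv);
            err = crypto_skcipher_encrypt(req);

            skcipher_request_free(req);
    out_tfm:
            crypto_free_skcipher(tfm);
            return err;
    }
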
@@ -420,54 +445,17 @@
         .decrypt                = xts_decrypt,
 } };
 
-static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
-
 static void aes_exit(void)
 {
-        int i;
-
-        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
-                if (aes_simd_algs[i])
-                        simd_skcipher_free(aes_simd_algs[i]);
-
         crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
 }
 
 static int __init aes_init(void)
 {
-        struct simd_skcipher_alg *simd;
-        const char *basename;
-        const char *algname;
-        const char *drvname;
-        int err;
-        int i;
-
-        if (!(elf_hwcap & HWCAP_ASIMD))
+        if (!cpu_have_named_feature(ASIMD))
                 return -ENODEV;
 
-        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
-        if (err)
-                return err;
-
-        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
-                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
-                        continue;
-
-                algname = aes_algs[i].base.cra_name + 2;
-                drvname = aes_algs[i].base.cra_driver_name + 2;
-                basename = aes_algs[i].base.cra_driver_name;
-                simd = simd_skcipher_create_compat(algname, drvname, basename);
-                err = PTR_ERR(simd);
-                if (IS_ERR(simd))
-                        goto unregister_simds;
-
-                aes_simd_algs[i] = simd;
-        }
-        return 0;
-
-unregister_simds:
-        aes_exit();
-        return err;
+        return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
 }
 
 module_init(aes_init);