2023-12-11 d2ccde1c8e90d38cee87a1b0309ad2827f3fd30d
--- a/kernel/arch/arm/crypto/aes-neonbs-glue.c
+++ b/kernel/arch/arm/crypto/aes-neonbs-glue.c
@@ -1,18 +1,18 @@
+// SPDX-License-Identifier: GPL-2.0-only
 /*
  * Bit sliced AES using NEON instructions
  *
  * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
- *
- * This program is free software; you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 2 as
- * published by the Free Software Foundation.
  */

 #include <asm/neon.h>
+#include <asm/simd.h>
 #include <crypto/aes.h>
-#include <crypto/cbc.h>
+#include <crypto/ctr.h>
+#include <crypto/internal/cipher.h>
 #include <crypto/internal/simd.h>
 #include <crypto/internal/skcipher.h>
+#include <crypto/scatterwalk.h>
 #include <crypto/xts.h>
 #include <linux/module.h>

@@ -20,9 +20,11 @@
 MODULE_LICENSE("GPL v2");

 MODULE_ALIAS_CRYPTO("ecb(aes)");
-MODULE_ALIAS_CRYPTO("cbc(aes)");
+MODULE_ALIAS_CRYPTO("cbc(aes)-all");
 MODULE_ALIAS_CRYPTO("ctr(aes)");
 MODULE_ALIAS_CRYPTO("xts(aes)");
+
+MODULE_IMPORT_NS(CRYPTO_INTERNAL);

 asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

@@ -38,9 +40,9 @@
                                   int rounds, int blocks, u8 ctr[], u8 final[]);

 asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
-                                  int rounds, int blocks, u8 iv[]);
+                                  int rounds, int blocks, u8 iv[], int);
 asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
-                                  int rounds, int blocks, u8 iv[]);
+                                  int rounds, int blocks, u8 iv[], int);

 struct aesbs_ctx {
         int     rounds;
@@ -49,12 +51,18 @@

 struct aesbs_cbc_ctx {
         struct aesbs_ctx        key;
-        struct crypto_cipher    *enc_tfm;
+        struct crypto_skcipher  *enc_tfm;
 };

 struct aesbs_xts_ctx {
         struct aesbs_ctx        key;
+        struct crypto_cipher    *cts_tfm;
         struct crypto_cipher    *tweak_tfm;
+};
+
+struct aesbs_ctr_ctx {
+        struct aesbs_ctx        key;            /* must be first member */
+        struct crypto_aes_ctx   fallback;
 };

 static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
@@ -64,7 +72,7 @@
         struct crypto_aes_ctx rk;
         int err;

-        err = crypto_aes_expand_key(&rk, in_key, key_len);
+        err = aes_expandkey(&rk, in_key, key_len);
         if (err)
                 return err;

@@ -86,9 +94,8 @@
         struct skcipher_walk walk;
         int err;

-        err = skcipher_walk_virt(&walk, req, true);
+        err = skcipher_walk_virt(&walk, req, false);

-        kernel_neon_begin();
         while (walk.nbytes >= AES_BLOCK_SIZE) {
                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

@@ -96,12 +103,13 @@
                         blocks = round_down(blocks,
                                             walk.stride / AES_BLOCK_SIZE);

+                kernel_neon_begin();
                 fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
                    ctx->rounds, blocks);
+                kernel_neon_end();
                 err = skcipher_walk_done(&walk,
                                          walk.nbytes - blocks * AES_BLOCK_SIZE);
         }
-        kernel_neon_end();

         return err;
 }
@@ -123,7 +131,7 @@
         struct crypto_aes_ctx rk;
         int err;

-        err = crypto_aes_expand_key(&rk, in_key, key_len);
+        err = aes_expandkey(&rk, in_key, key_len);
         if (err)
                 return err;

@@ -132,20 +140,25 @@
         kernel_neon_begin();
         aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
         kernel_neon_end();
+        memzero_explicit(&rk, sizeof(rk));

-        return crypto_cipher_setkey(ctx->enc_tfm, in_key, key_len);
-}
-
-static void cbc_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
-{
-        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
-
-        crypto_cipher_encrypt_one(ctx->enc_tfm, dst, src);
+        return crypto_skcipher_setkey(ctx->enc_tfm, in_key, key_len);
 }

 static int cbc_encrypt(struct skcipher_request *req)
 {
-        return crypto_cbc_encrypt_walk(req, cbc_encrypt_one);
+        struct skcipher_request *subreq = skcipher_request_ctx(req);
+        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
+
+        skcipher_request_set_tfm(subreq, ctx->enc_tfm);
+        skcipher_request_set_callback(subreq,
+                                      skcipher_request_flags(req),
+                                      NULL, NULL);
+        skcipher_request_set_crypt(subreq, req->src, req->dst,
+                                   req->cryptlen, req->iv);
+
+        return crypto_skcipher_encrypt(subreq);
 }

 static int cbc_decrypt(struct skcipher_request *req)
@@ -155,9 +168,8 @@
         struct skcipher_walk walk;
         int err;

-        err = skcipher_walk_virt(&walk, req, true);
+        err = skcipher_walk_virt(&walk, req, false);

-        kernel_neon_begin();
         while (walk.nbytes >= AES_BLOCK_SIZE) {
                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

@@ -165,31 +177,59 @@
                         blocks = round_down(blocks,
                                             walk.stride / AES_BLOCK_SIZE);

+                kernel_neon_begin();
                 aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   ctx->key.rk, ctx->key.rounds, blocks,
                                   walk.iv);
+                kernel_neon_end();
                 err = skcipher_walk_done(&walk,
                                          walk.nbytes - blocks * AES_BLOCK_SIZE);
         }
-        kernel_neon_end();

         return err;
 }

-static int cbc_init(struct crypto_tfm *tfm)
+static int cbc_init(struct crypto_skcipher *tfm)
 {
-        struct aesbs_cbc_ctx *ctx = crypto_tfm_ctx(tfm);
+        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
+        unsigned int reqsize;

-        ctx->enc_tfm = crypto_alloc_cipher("aes", 0, 0);
+        ctx->enc_tfm = crypto_alloc_skcipher("cbc(aes)", 0, CRYPTO_ALG_ASYNC |
+                                             CRYPTO_ALG_NEED_FALLBACK);
+        if (IS_ERR(ctx->enc_tfm))
+                return PTR_ERR(ctx->enc_tfm);

-        return PTR_ERR_OR_ZERO(ctx->enc_tfm);
+        reqsize = sizeof(struct skcipher_request);
+        reqsize += crypto_skcipher_reqsize(ctx->enc_tfm);
+        crypto_skcipher_set_reqsize(tfm, reqsize);
+
+        return 0;
 }

-static void cbc_exit(struct crypto_tfm *tfm)
+static void cbc_exit(struct crypto_skcipher *tfm)
 {
-        struct aesbs_cbc_ctx *ctx = crypto_tfm_ctx(tfm);
+        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

-        crypto_free_cipher(ctx->enc_tfm);
+        crypto_free_skcipher(ctx->enc_tfm);
+}
+
+static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
+                                 unsigned int key_len)
+{
+        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
+        int err;
+
+        err = aes_expandkey(&ctx->fallback, in_key, key_len);
+        if (err)
+                return err;
+
+        ctx->key.rounds = 6 + key_len / 4;
+
+        kernel_neon_begin();
+        aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
+        kernel_neon_end();
+
+        return 0;
 }

 static int ctr_encrypt(struct skcipher_request *req)
@@ -200,9 +240,8 @@
         u8 buf[AES_BLOCK_SIZE];
         int err;

-        err = skcipher_walk_virt(&walk, req, true);
+        err = skcipher_walk_virt(&walk, req, false);

-        kernel_neon_begin();
         while (walk.nbytes > 0) {
                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
                 u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;
@@ -213,8 +252,10 @@
                         final = NULL;
                 }

+                kernel_neon_begin();
                 aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   ctx->rk, ctx->rounds, blocks, walk.iv, final);
+                kernel_neon_end();

                 if (final) {
                         u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
@@ -229,9 +270,31 @@
                 err = skcipher_walk_done(&walk,
                                          walk.nbytes - blocks * AES_BLOCK_SIZE);
         }
-        kernel_neon_end();

         return err;
+}
+
+static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
+{
+        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
+        unsigned long flags;
+
+        /*
+         * Temporarily disable interrupts to avoid races where
+         * cachelines are evicted when the CPU is interrupted
+         * to do something else.
+         */
+        local_irq_save(flags);
+        aes_encrypt(&ctx->fallback, dst, src);
+        local_irq_restore(flags);
+}
+
+static int ctr_encrypt_sync(struct skcipher_request *req)
+{
+        if (!crypto_simd_usable())
+                return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);
+
+        return ctr_encrypt(req);
 }

 static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
@@ -245,6 +308,9 @@
                 return err;

         key_len /= 2;
+        err = crypto_cipher_setkey(ctx->cts_tfm, in_key, key_len);
+        if (err)
+                return err;
         err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
         if (err)
                 return err;
@@ -252,30 +318,53 @@
         return aesbs_setkey(tfm, in_key, key_len);
 }

-static int xts_init(struct crypto_tfm *tfm)
+static int xts_init(struct crypto_skcipher *tfm)
 {
-        struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);
+        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
+
+        ctx->cts_tfm = crypto_alloc_cipher("aes", 0, 0);
+        if (IS_ERR(ctx->cts_tfm))
+                return PTR_ERR(ctx->cts_tfm);

         ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
+        if (IS_ERR(ctx->tweak_tfm))
+                crypto_free_cipher(ctx->cts_tfm);

         return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
 }

-static void xts_exit(struct crypto_tfm *tfm)
+static void xts_exit(struct crypto_skcipher *tfm)
 {
-        struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);
+        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

         crypto_free_cipher(ctx->tweak_tfm);
+        crypto_free_cipher(ctx->cts_tfm);
 }

-static int __xts_crypt(struct skcipher_request *req,
+static int __xts_crypt(struct skcipher_request *req, bool encrypt,
                        void (*fn)(u8 out[], u8 const in[], u8 const rk[],
-                                  int rounds, int blocks, u8 iv[]))
+                                  int rounds, int blocks, u8 iv[], int))
 {
         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
         struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
+        int tail = req->cryptlen % AES_BLOCK_SIZE;
+        struct skcipher_request subreq;
+        u8 buf[2 * AES_BLOCK_SIZE];
         struct skcipher_walk walk;
         int err;
+
+        if (req->cryptlen < AES_BLOCK_SIZE)
+                return -EINVAL;
+
+        if (unlikely(tail)) {
+                skcipher_request_set_tfm(&subreq, tfm);
+                skcipher_request_set_callback(&subreq,
+                                              skcipher_request_flags(req),
+                                              NULL, NULL);
+                skcipher_request_set_crypt(&subreq, req->src, req->dst,
+                                           req->cryptlen - tail, req->iv);
+                req = &subreq;
+        }

         err = skcipher_walk_virt(&walk, req, true);
         if (err)
@@ -283,32 +372,55 @@

         crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);

-        kernel_neon_begin();
         while (walk.nbytes >= AES_BLOCK_SIZE) {
                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
+                int reorder_last_tweak = !encrypt && tail > 0;

-                if (walk.nbytes < walk.total)
+                if (walk.nbytes < walk.total) {
                         blocks = round_down(blocks,
                                             walk.stride / AES_BLOCK_SIZE);
+                        reorder_last_tweak = 0;
+                }

+                kernel_neon_begin();
                 fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
-                   ctx->key.rounds, blocks, walk.iv);
+                   ctx->key.rounds, blocks, walk.iv, reorder_last_tweak);
+                kernel_neon_end();
                 err = skcipher_walk_done(&walk,
                                          walk.nbytes - blocks * AES_BLOCK_SIZE);
         }
-        kernel_neon_end();

-        return err;
+        if (err || likely(!tail))
+                return err;
+
+        /* handle ciphertext stealing */
+        scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
+                                 AES_BLOCK_SIZE, 0);
+        memcpy(buf + AES_BLOCK_SIZE, buf, tail);
+        scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);
+
+        crypto_xor(buf, req->iv, AES_BLOCK_SIZE);
+
+        if (encrypt)
+                crypto_cipher_encrypt_one(ctx->cts_tfm, buf, buf);
+        else
+                crypto_cipher_decrypt_one(ctx->cts_tfm, buf, buf);
+
+        crypto_xor(buf, req->iv, AES_BLOCK_SIZE);
+
+        scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
+                                 AES_BLOCK_SIZE + tail, 1);
+        return 0;
 }

 static int xts_encrypt(struct skcipher_request *req)
 {
-        return __xts_crypt(req, aesbs_xts_encrypt);
+        return __xts_crypt(req, true, aesbs_xts_encrypt);
 }

 static int xts_decrypt(struct skcipher_request *req)
 {
-        return __xts_crypt(req, aesbs_xts_decrypt);
+        return __xts_crypt(req, false, aesbs_xts_decrypt);
 }

 static struct skcipher_alg aes_algs[] = { {
@@ -333,9 +445,8 @@
         .base.cra_blocksize     = AES_BLOCK_SIZE,
         .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctx),
         .base.cra_module        = THIS_MODULE,
-        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
-        .base.cra_init          = cbc_init,
-        .base.cra_exit          = cbc_exit,
+        .base.cra_flags         = CRYPTO_ALG_INTERNAL |
+                                  CRYPTO_ALG_NEED_FALLBACK,

         .min_keysize            = AES_MIN_KEY_SIZE,
         .max_keysize            = AES_MAX_KEY_SIZE,
@@ -344,6 +455,8 @@
         .setkey                 = aesbs_cbc_setkey,
         .encrypt                = cbc_encrypt,
         .decrypt                = cbc_decrypt,
+        .init                   = cbc_init,
+        .exit                   = cbc_exit,
 }, {
         .base.cra_name          = "__ctr(aes)",
         .base.cra_driver_name   = "__ctr-aes-neonbs",
@@ -362,6 +475,22 @@
         .encrypt                = ctr_encrypt,
         .decrypt                = ctr_encrypt,
 }, {
+        .base.cra_name          = "ctr(aes)",
+        .base.cra_driver_name   = "ctr-aes-neonbs-sync",
+        .base.cra_priority      = 250 - 1,
+        .base.cra_blocksize     = 1,
+        .base.cra_ctxsize       = sizeof(struct aesbs_ctr_ctx),
+        .base.cra_module        = THIS_MODULE,
+
+        .min_keysize            = AES_MIN_KEY_SIZE,
+        .max_keysize            = AES_MAX_KEY_SIZE,
+        .chunksize              = AES_BLOCK_SIZE,
+        .walksize               = 8 * AES_BLOCK_SIZE,
+        .ivsize                 = AES_BLOCK_SIZE,
+        .setkey                 = aesbs_ctr_setkey_sync,
+        .encrypt                = ctr_encrypt_sync,
+        .decrypt                = ctr_encrypt_sync,
+}, {
         .base.cra_name          = "__xts(aes)",
         .base.cra_driver_name   = "__xts-aes-neonbs",
         .base.cra_priority      = 250,
@@ -369,8 +498,6 @@
         .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
         .base.cra_module        = THIS_MODULE,
         .base.cra_flags         = CRYPTO_ALG_INTERNAL,
-        .base.cra_init          = xts_init,
-        .base.cra_exit          = xts_exit,

         .min_keysize            = 2 * AES_MIN_KEY_SIZE,
         .max_keysize            = 2 * AES_MAX_KEY_SIZE,
@@ -379,6 +506,8 @@
         .setkey                 = aesbs_xts_setkey,
         .encrypt                = xts_encrypt,
         .decrypt                = xts_decrypt,
+        .init                   = xts_init,
+        .exit                   = xts_exit,
 } };

 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];