From 6778948f9de86c3cfaf36725a7c87dcff9ba247f Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Mon, 11 Dec 2023 08:20:59 +0000
Subject: [PATCH] kernel_5.10 no rt: update arm/crypto aes-ce-glue

Update kernel/arch/arm/crypto/aes-ce-glue.c:

- replace the GPL license notice with an SPDX license identifier
- pass round keys to the asm helpers as u32[] rather than u8[], and
  expand keys with get_unaligned_le32() so the CONFIG_CPU_BIG_ENDIAN
  special case in the key schedule can be dropped
- return the error from ce_aes_expandkey() directly instead of setting
  the removed CRYPTO_TFM_RES_BAD_KEY_LEN flag
- walk requests non-atomically and wrap only each processed chunk in
  kernel_neon_begin()/kernel_neon_end(), instead of holding the NEON
  unit across the whole walk
- implement ciphertext stealing: a cts(cbc(aes)) skcipher, and support
  for XTS requests whose length is not a multiple of the block size
- add a synchronous ctr(aes) ("ctr-aes-ce-sync") that falls back to the
  scalar cipher when SIMD is not usable
- register SIMD wrappers only for the CRYPTO_ALG_INTERNAL algorithms
---
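Note: the request-splitting arithmetic used by the new cts_cbc_encrypt()
and cts_cbc_decrypt() paths below can be modelled in user space. The
following sketch is illustrative only and not part of the patch (the
split() helper and the sample lengths are made up for the example); it
mirrors the cbc_blocks computation: all but the final two blocks go
through plain CBC, and the remainder (16 to 32 bytes) is handed to the
ce_aes_cbc_cts_* routines.

#include <stdio.h>

#define AES_BLOCK_SIZE	16
#define DIV_ROUND_UP(n, d)	(((n) + (d) - 1) / (d))

/* mirror of the cbc_blocks computation in cts_cbc_{en,de}crypt() */
static void split(unsigned int cryptlen)
{
	int cbc_blocks = DIV_ROUND_UP(cryptlen, AES_BLOCK_SIZE) - 2;

	if (cryptlen < AES_BLOCK_SIZE) {
		printf("%3u bytes: rejected (-EINVAL)\n", cryptlen);
		return;
	}
	if (cryptlen <= AES_BLOCK_SIZE)		/* exactly one block */
		cbc_blocks = 1;

	if (cryptlen == (unsigned int)cbc_blocks * AES_BLOCK_SIZE)
		printf("%3u bytes: %d plain CBC block(s), no CTS tail\n",
		       cryptlen, cbc_blocks);
	else
		printf("%3u bytes: %d plain CBC block(s) + %u-byte CTS tail\n",
		       cryptlen, cbc_blocks,
		       cryptlen - cbc_blocks * AES_BLOCK_SIZE);
}

int main(void)
{
	unsigned int lens[] = { 15, 16, 17, 32, 33, 41, 64 };
	unsigned int i;

	for (i = 0; i < sizeof(lens) / sizeof(lens[0]); i++)
		split(lens[i]);
	return 0;
}

For example, a 41-byte request becomes one plain CBC block plus a
25-byte CTS tail, while aligned requests longer than one block always
leave their final two blocks to the CTS routine.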
 kernel/arch/arm/crypto/aes-ce-glue.c |  577 ++++++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 428 insertions(+), 149 deletions(-)
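
A minimal user-space smoke test for the new cts(cbc(aes)) skcipher via
AF_ALG. This sketch is illustrative and not part of the patch: it
assumes CONFIG_CRYPTO_USER_API_SKCIPHER is enabled, uses an all-zero
key and IV, and elides error handling; testmgr/tcrypt remain the
authoritative tests.

#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <linux/if_alg.h>

#ifndef SOL_ALG
#define SOL_ALG 279
#endif

int main(void)
{
	struct sockaddr_alg sa = {
		.salg_family = AF_ALG,
		.salg_type   = "skcipher",
		.salg_name   = "cts(cbc(aes))",
	};
	unsigned char key[16] = { 0 };	/* illustrative all-zero AES-128 key */
	unsigned char iv[16] = { 0 };
	unsigned char pt[25] = "ciphertext stealing test";	/* 25 bytes: exercises the CTS tail */
	unsigned char ct[sizeof(pt)];
	char cbuf[CMSG_SPACE(sizeof(__u32)) +
		  CMSG_SPACE(sizeof(struct af_alg_iv) + sizeof(iv))] = { 0 };
	struct iovec iov = { .iov_base = pt, .iov_len = sizeof(pt) };
	struct msghdr msg = {
		.msg_control    = cbuf,
		.msg_controllen = sizeof(cbuf),
		.msg_iov        = &iov,
		.msg_iovlen     = 1,
	};
	struct cmsghdr *cmsg;
	struct af_alg_iv *alg_iv;
	int tfmfd, opfd, i;

	tfmfd = socket(AF_ALG, SOCK_SEQPACKET, 0);
	bind(tfmfd, (struct sockaddr *)&sa, sizeof(sa));
	setsockopt(tfmfd, SOL_ALG, ALG_SET_KEY, key, sizeof(key));
	opfd = accept(tfmfd, NULL, 0);

	/* first cmsg selects the operation, second carries the IV */
	cmsg = CMSG_FIRSTHDR(&msg);
	cmsg->cmsg_level = SOL_ALG;
	cmsg->cmsg_type  = ALG_SET_OP;
	cmsg->cmsg_len   = CMSG_LEN(sizeof(__u32));
	*(__u32 *)CMSG_DATA(cmsg) = ALG_OP_ENCRYPT;

	cmsg = CMSG_NXTHDR(&msg, cmsg);
	cmsg->cmsg_level = SOL_ALG;
	cmsg->cmsg_type  = ALG_SET_IV;
	cmsg->cmsg_len   = CMSG_LEN(sizeof(struct af_alg_iv) + sizeof(iv));
	alg_iv = (struct af_alg_iv *)CMSG_DATA(cmsg);
	alg_iv->ivlen = sizeof(iv);
	memcpy(alg_iv->iv, iv, sizeof(iv));

	sendmsg(opfd, &msg, 0);
	read(opfd, ct, sizeof(ct));	/* ciphertext is as long as the input */

	for (i = 0; i < (int)sizeof(ct); i++)
		printf("%02x", ct[i]);
	printf("\n");

	close(opfd);
	close(tfmfd);
	return 0;
}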

diff --git a/kernel/arch/arm/crypto/aes-ce-glue.c b/kernel/arch/arm/crypto/aes-ce-glue.c
index d0a9cec..b668c97 100644
--- a/kernel/arch/arm/crypto/aes-ce-glue.c
+++ b/kernel/arch/arm/crypto/aes-ce-glue.c
@@ -1,19 +1,19 @@
+// SPDX-License-Identifier: GPL-2.0-only
 /*
  * aes-ce-glue.c - wrapper code for ARMv8 AES
  *
  * Copyright (C) 2015 Linaro Ltd <ard.biesheuvel@linaro.org>
- *
- * This program is free software; you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 2 as
- * published by the Free Software Foundation.
  */
 
 #include <asm/hwcap.h>
 #include <asm/neon.h>
-#include <asm/hwcap.h>
+#include <asm/simd.h>
+#include <asm/unaligned.h>
 #include <crypto/aes.h>
+#include <crypto/ctr.h>
 #include <crypto/internal/simd.h>
 #include <crypto/internal/skcipher.h>
+#include <crypto/scatterwalk.h>
 #include <linux/cpufeature.h>
 #include <linux/module.h>
 #include <crypto/xts.h>
@@ -26,25 +26,29 @@
 asmlinkage u32 ce_aes_sub(u32 input);
 asmlinkage void ce_aes_invert(void *dst, void *src);
 
-asmlinkage void ce_aes_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void ce_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
 				   int rounds, int blocks);
-asmlinkage void ce_aes_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void ce_aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
 				   int rounds, int blocks);
 
-asmlinkage void ce_aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void ce_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
 				   int rounds, int blocks, u8 iv[]);
-asmlinkage void ce_aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void ce_aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
 				   int rounds, int blocks, u8 iv[]);
+asmlinkage void ce_aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
+				   int rounds, int bytes, u8 const iv[]);
+asmlinkage void ce_aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
+				   int rounds, int bytes, u8 const iv[]);
 
-asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
 				   int rounds, int blocks, u8 ctr[]);
 
-asmlinkage void ce_aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[],
-				   int rounds, int blocks, u8 iv[],
-				   u8 const rk2[], int first);
-asmlinkage void ce_aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[],
-				   int rounds, int blocks, u8 iv[],
-				   u8 const rk2[], int first);
+asmlinkage void ce_aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
+				   int rounds, int bytes, u8 iv[],
+				   u32 const rk2[], int first);
+asmlinkage void ce_aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
+				   int rounds, int bytes, u8 iv[],
+				   u32 const rk2[], int first);
 
 struct aes_block {
 	u8 b[AES_BLOCK_SIZE];
@@ -81,21 +85,17 @@
 	    key_len != AES_KEYSIZE_256)
 		return -EINVAL;
 
-	memcpy(ctx->key_enc, in_key, key_len);
 	ctx->key_length = key_len;
+	for (i = 0; i < kwords; i++)
+		ctx->key_enc[i] = get_unaligned_le32(in_key + i * sizeof(u32));
 
 	kernel_neon_begin();
 	for (i = 0; i < sizeof(rcon); i++) {
 		u32 *rki = ctx->key_enc + (i * kwords);
 		u32 *rko = rki + kwords;
 
-#ifndef CONFIG_CPU_BIG_ENDIAN
 		rko[0] = ror32(ce_aes_sub(rki[kwords - 1]), 8);
 		rko[0] = rko[0] ^ rki[0] ^ rcon[i];
-#else
-		rko[0] = rol32(ce_aes_sub(rki[kwords - 1]), 8);
-		rko[0] = rko[0] ^ rki[0] ^ (rcon[i] << 24);
-#endif
 		rko[1] = rko[0] ^ rki[1];
 		rko[2] = rko[1] ^ rki[2];
 		rko[3] = rko[2] ^ rki[3];
@@ -138,14 +138,8 @@
 			 unsigned int key_len)
 {
 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
-	int ret;
 
-	ret = ce_aes_expandkey(ctx, in_key, key_len);
-	if (!ret)
-		return 0;
-
-	crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
-	return -EINVAL;
+	return ce_aes_expandkey(ctx, in_key, key_len);
 }
 
 struct crypto_aes_xts_ctx {
@@ -167,11 +161,7 @@
 	if (!ret)
 		ret = ce_aes_expandkey(&ctx->key2, &in_key[key_len / 2],
 				       key_len / 2);
-	if (!ret)
-		return 0;
-
-	crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
-	return -EINVAL;
+	return ret;
 }
 
 static int ecb_encrypt(struct skcipher_request *req)
@@ -182,15 +172,15 @@
 	unsigned int blocks;
 	int err;
 
-	err = skcipher_walk_virt(&walk, req, true);
+	err = skcipher_walk_virt(&walk, req, false);
 
-	kernel_neon_begin();
 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+		kernel_neon_begin();
 		ce_aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				   (u8 *)ctx->key_enc, num_rounds(ctx), blocks);
+				   ctx->key_enc, num_rounds(ctx), blocks);
+		kernel_neon_end();
 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 	}
-	kernel_neon_end();
 	return err;
 }
 
@@ -202,58 +192,192 @@
 	unsigned int blocks;
 	int err;
 
-	err = skcipher_walk_virt(&walk, req, true);
+	err = skcipher_walk_virt(&walk, req, false);
 
-	kernel_neon_begin();
 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+		kernel_neon_begin();
 		ce_aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				   (u8 *)ctx->key_dec, num_rounds(ctx), blocks);
+				   ctx->key_dec, num_rounds(ctx), blocks);
+		kernel_neon_end();
 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 	}
-	kernel_neon_end();
+	return err;
+}
+
+static int cbc_encrypt_walk(struct skcipher_request *req,
+			    struct skcipher_walk *walk)
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	unsigned int blocks;
+	int err = 0;
+
+	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
+		kernel_neon_begin();
+		ce_aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
+				   ctx->key_enc, num_rounds(ctx), blocks,
+				   walk->iv);
+		kernel_neon_end();
+		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
+	}
 	return err;
 }
 
 static int cbc_encrypt(struct skcipher_request *req)
 {
-	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
-	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 	struct skcipher_walk walk;
-	unsigned int blocks;
 	int err;
 
-	err = skcipher_walk_virt(&walk, req, true);
+	err = skcipher_walk_virt(&walk, req, false);
+	if (err)
+		return err;
+	return cbc_encrypt_walk(req, &walk);
+}
 
-	kernel_neon_begin();
-	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
-		ce_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				   (u8 *)ctx->key_enc, num_rounds(ctx), blocks,
-				   walk.iv);
-		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
+static int cbc_decrypt_walk(struct skcipher_request *req,
+			    struct skcipher_walk *walk)
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	unsigned int blocks;
+	int err = 0;
+
+	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
+		kernel_neon_begin();
+		ce_aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
+				   ctx->key_dec, num_rounds(ctx), blocks,
+				   walk->iv);
+		kernel_neon_end();
+		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
 	}
-	kernel_neon_end();
 	return err;
 }
 
 static int cbc_decrypt(struct skcipher_request *req)
 {
-	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
-	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 	struct skcipher_walk walk;
-	unsigned int blocks;
 	int err;
 
-	err = skcipher_walk_virt(&walk, req, true);
+	err = skcipher_walk_virt(&walk, req, false);
+	if (err)
+		return err;
+	return cbc_decrypt_walk(req, &walk);
+}
+
+static int cts_cbc_encrypt(struct skcipher_request *req)
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+	struct scatterlist *src = req->src, *dst = req->dst;
+	struct scatterlist sg_src[2], sg_dst[2];
+	struct skcipher_request subreq;
+	struct skcipher_walk walk;
+	int err;
+
+	skcipher_request_set_tfm(&subreq, tfm);
+	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
+				      NULL, NULL);
+
+	if (req->cryptlen <= AES_BLOCK_SIZE) {
+		if (req->cryptlen < AES_BLOCK_SIZE)
+			return -EINVAL;
+		cbc_blocks = 1;
+	}
+
+	if (cbc_blocks > 0) {
+		skcipher_request_set_crypt(&subreq, req->src, req->dst,
+					   cbc_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+
+		err = skcipher_walk_virt(&walk, &subreq, false) ?:
+		      cbc_encrypt_walk(&subreq, &walk);
+		if (err)
+			return err;
+
+		if (req->cryptlen == AES_BLOCK_SIZE)
+			return 0;
+
+		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
+		if (req->dst != req->src)
+			dst = scatterwalk_ffwd(sg_dst, req->dst,
+					       subreq.cryptlen);
+	}
+
+	/* handle ciphertext stealing */
+	skcipher_request_set_crypt(&subreq, src, dst,
+				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, &subreq, false);
+	if (err)
+		return err;
 
 	kernel_neon_begin();
-	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
-		ce_aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				   (u8 *)ctx->key_dec, num_rounds(ctx), blocks,
-				   walk.iv);
-		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
-	}
+	ce_aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			       ctx->key_enc, num_rounds(ctx), walk.nbytes,
+			       walk.iv);
 	kernel_neon_end();
-	return err;
+
+	return skcipher_walk_done(&walk, 0);
+}
+
+static int cts_cbc_decrypt(struct skcipher_request *req)
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+	struct scatterlist *src = req->src, *dst = req->dst;
+	struct scatterlist sg_src[2], sg_dst[2];
+	struct skcipher_request subreq;
+	struct skcipher_walk walk;
+	int err;
+
+	skcipher_request_set_tfm(&subreq, tfm);
+	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
+				      NULL, NULL);
+
+	if (req->cryptlen <= AES_BLOCK_SIZE) {
+		if (req->cryptlen < AES_BLOCK_SIZE)
+			return -EINVAL;
+		cbc_blocks = 1;
+	}
+
+	if (cbc_blocks > 0) {
+		skcipher_request_set_crypt(&subreq, req->src, req->dst,
+					   cbc_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+
+		err = skcipher_walk_virt(&walk, &subreq, false) ?:
+		      cbc_decrypt_walk(&subreq, &walk);
+		if (err)
+			return err;
+
+		if (req->cryptlen == AES_BLOCK_SIZE)
+			return 0;
+
+		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
+		if (req->dst != req->src)
+			dst = scatterwalk_ffwd(sg_dst, req->dst,
+					       subreq.cryptlen);
+	}
+
+	/* handle ciphertext stealing */
+	skcipher_request_set_crypt(&subreq, src, dst,
+				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, &subreq, false);
+	if (err)
+		return err;
+
+	kernel_neon_begin();
+	ce_aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			       ctx->key_dec, num_rounds(ctx), walk.nbytes,
+			       walk.iv);
+	kernel_neon_end();
+
+	return skcipher_walk_done(&walk, 0);
 }
 
 static int ctr_encrypt(struct skcipher_request *req)
@@ -263,13 +387,14 @@
 	struct skcipher_walk walk;
 	int err, blocks;
 
-	err = skcipher_walk_virt(&walk, req, true);
+	err = skcipher_walk_virt(&walk, req, false);
 
-	kernel_neon_begin();
 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+		kernel_neon_begin();
 		ce_aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				   (u8 *)ctx->key_enc, num_rounds(ctx), blocks,
+				   ctx->key_enc, num_rounds(ctx), blocks,
 				   walk.iv);
+		kernel_neon_end();
 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 	}
 	if (walk.nbytes) {
@@ -283,14 +408,37 @@
 		 */
 		blocks = -1;
 
-		ce_aes_ctr_encrypt(tail, NULL, (u8 *)ctx->key_enc,
-				   num_rounds(ctx), blocks, walk.iv);
+		kernel_neon_begin();
+		ce_aes_ctr_encrypt(tail, NULL, ctx->key_enc, num_rounds(ctx),
+				   blocks, walk.iv);
+		kernel_neon_end();
 		crypto_xor_cpy(tdst, tsrc, tail, nbytes);
 		err = skcipher_walk_done(&walk, 0);
 	}
-	kernel_neon_end();
-
 	return err;
+}
+
+static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
+{
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	unsigned long flags;
+
+	/*
+	 * Temporarily disable interrupts to avoid races where
+	 * cachelines are evicted when the CPU is interrupted
+	 * to do something else.
+	 */
+	local_irq_save(flags);
+	aes_encrypt(ctx, dst, src);
+	local_irq_restore(flags);
+}
+
+static int ctr_encrypt_sync(struct skcipher_request *req)
+{
+	if (!crypto_simd_usable())
+		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);
+
+	return ctr_encrypt(req);
 }
 
 static int xts_encrypt(struct skcipher_request *req)
@@ -298,21 +446,71 @@
 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 	int err, first, rounds = num_rounds(&ctx->key1);
+	int tail = req->cryptlen % AES_BLOCK_SIZE;
+	struct scatterlist sg_src[2], sg_dst[2];
+	struct skcipher_request subreq;
+	struct scatterlist *src, *dst;
 	struct skcipher_walk walk;
-	unsigned int blocks;
 
-	err = skcipher_walk_virt(&walk, req, true);
+	if (req->cryptlen < AES_BLOCK_SIZE)
+		return -EINVAL;
+
+	err = skcipher_walk_virt(&walk, req, false);
+
+	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
+		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
+					      AES_BLOCK_SIZE) - 2;
+
+		skcipher_walk_abort(&walk);
+
+		skcipher_request_set_tfm(&subreq, tfm);
+		skcipher_request_set_callback(&subreq,
+					      skcipher_request_flags(req),
+					      NULL, NULL);
+		skcipher_request_set_crypt(&subreq, req->src, req->dst,
+					   xts_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+		req = &subreq;
+		err = skcipher_walk_virt(&walk, req, false);
+	} else {
+		tail = 0;
+	}
+
+	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
+		int nbytes = walk.nbytes;
+
+		if (walk.nbytes < walk.total)
+			nbytes &= ~(AES_BLOCK_SIZE - 1);
+
+		kernel_neon_begin();
+		ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+				   ctx->key1.key_enc, rounds, nbytes, walk.iv,
+				   ctx->key2.key_enc, first);
+		kernel_neon_end();
+		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
+	}
+
+	if (err || likely(!tail))
+		return err;
+
+	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
+	if (req->dst != req->src)
+		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+
+	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, req, false);
+	if (err)
+		return err;
 
 	kernel_neon_begin();
-	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
-		ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				   (u8 *)ctx->key1.key_enc, rounds, blocks,
-				   walk.iv, (u8 *)ctx->key2.key_enc, first);
-		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
-	}
+	ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			   ctx->key1.key_enc, rounds, walk.nbytes, walk.iv,
+			   ctx->key2.key_enc, first);
 	kernel_neon_end();
 
-	return err;
+	return skcipher_walk_done(&walk, 0);
 }
 
 static int xts_decrypt(struct skcipher_request *req)
@@ -320,87 +518,165 @@
 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 	int err, first, rounds = num_rounds(&ctx->key1);
+	int tail = req->cryptlen % AES_BLOCK_SIZE;
+	struct scatterlist sg_src[2], sg_dst[2];
+	struct skcipher_request subreq;
+	struct scatterlist *src, *dst;
 	struct skcipher_walk walk;
-	unsigned int blocks;
 
-	err = skcipher_walk_virt(&walk, req, true);
+	if (req->cryptlen < AES_BLOCK_SIZE)
+		return -EINVAL;
+
+	err = skcipher_walk_virt(&walk, req, false);
+
+	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
+		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
+					      AES_BLOCK_SIZE) - 2;
+
+		skcipher_walk_abort(&walk);
+
+		skcipher_request_set_tfm(&subreq, tfm);
+		skcipher_request_set_callback(&subreq,
+					      skcipher_request_flags(req),
+					      NULL, NULL);
+		skcipher_request_set_crypt(&subreq, req->src, req->dst,
+					   xts_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+		req = &subreq;
+		err = skcipher_walk_virt(&walk, req, false);
+	} else {
+		tail = 0;
+	}
+
+	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
+		int nbytes = walk.nbytes;
+
+		if (walk.nbytes < walk.total)
+			nbytes &= ~(AES_BLOCK_SIZE - 1);
+
+		kernel_neon_begin();
+		ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+				   ctx->key1.key_dec, rounds, nbytes, walk.iv,
+				   ctx->key2.key_enc, first);
+		kernel_neon_end();
+		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
+	}
+
+	if (err || likely(!tail))
+		return err;
+
+	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
+	if (req->dst != req->src)
+		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+
+	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, req, false);
+	if (err)
+		return err;
 
 	kernel_neon_begin();
-	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
-		ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				   (u8 *)ctx->key1.key_dec, rounds, blocks,
-				   walk.iv, (u8 *)ctx->key2.key_enc, first);
-		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
-	}
+	ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			   ctx->key1.key_dec, rounds, walk.nbytes, walk.iv,
+			   ctx->key2.key_enc, first);
 	kernel_neon_end();
 
-	return err;
+	return skcipher_walk_done(&walk, 0);
 }
 
 static struct skcipher_alg aes_algs[] = { {
-	.base = {
-		.cra_name		= "__ecb(aes)",
-		.cra_driver_name	= "__ecb-aes-ce",
-		.cra_priority		= 300,
-		.cra_flags		= CRYPTO_ALG_INTERNAL,
-		.cra_blocksize		= AES_BLOCK_SIZE,
-		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
-		.cra_module		= THIS_MODULE,
-	},
-	.min_keysize	= AES_MIN_KEY_SIZE,
-	.max_keysize	= AES_MAX_KEY_SIZE,
-	.setkey		= ce_aes_setkey,
-	.encrypt	= ecb_encrypt,
-	.decrypt	= ecb_decrypt,
+	.base.cra_name		= "__ecb(aes)",
+	.base.cra_driver_name	= "__ecb-aes-ce",
+	.base.cra_priority	= 300,
+	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
+	.base.cra_blocksize	= AES_BLOCK_SIZE,
+	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
+	.base.cra_module	= THIS_MODULE,
+
+	.min_keysize		= AES_MIN_KEY_SIZE,
+	.max_keysize		= AES_MAX_KEY_SIZE,
+	.setkey			= ce_aes_setkey,
+	.encrypt		= ecb_encrypt,
+	.decrypt		= ecb_decrypt,
 }, {
-	.base = {
-		.cra_name		= "__cbc(aes)",
-		.cra_driver_name	= "__cbc-aes-ce",
-		.cra_priority		= 300,
-		.cra_flags		= CRYPTO_ALG_INTERNAL,
-		.cra_blocksize		= AES_BLOCK_SIZE,
-		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
-		.cra_module		= THIS_MODULE,
-	},
-	.min_keysize	= AES_MIN_KEY_SIZE,
-	.max_keysize	= AES_MAX_KEY_SIZE,
-	.ivsize		= AES_BLOCK_SIZE,
-	.setkey		= ce_aes_setkey,
-	.encrypt	= cbc_encrypt,
-	.decrypt	= cbc_decrypt,
+	.base.cra_name		= "__cbc(aes)",
+	.base.cra_driver_name	= "__cbc-aes-ce",
+	.base.cra_priority	= 300,
+	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
+	.base.cra_blocksize	= AES_BLOCK_SIZE,
+	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
+	.base.cra_module	= THIS_MODULE,
+
+	.min_keysize		= AES_MIN_KEY_SIZE,
+	.max_keysize		= AES_MAX_KEY_SIZE,
+	.ivsize			= AES_BLOCK_SIZE,
+	.setkey			= ce_aes_setkey,
+	.encrypt		= cbc_encrypt,
+	.decrypt		= cbc_decrypt,
 }, {
-	.base = {
-		.cra_name		= "__ctr(aes)",
-		.cra_driver_name	= "__ctr-aes-ce",
-		.cra_priority		= 300,
-		.cra_flags		= CRYPTO_ALG_INTERNAL,
-		.cra_blocksize		= 1,
-		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
-		.cra_module		= THIS_MODULE,
-	},
-	.min_keysize	= AES_MIN_KEY_SIZE,
-	.max_keysize	= AES_MAX_KEY_SIZE,
-	.ivsize		= AES_BLOCK_SIZE,
-	.chunksize	= AES_BLOCK_SIZE,
-	.setkey		= ce_aes_setkey,
-	.encrypt	= ctr_encrypt,
-	.decrypt	= ctr_encrypt,
+	.base.cra_name		= "__cts(cbc(aes))",
+	.base.cra_driver_name	= "__cts-cbc-aes-ce",
+	.base.cra_priority	= 300,
+	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
+	.base.cra_blocksize	= AES_BLOCK_SIZE,
+	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
+	.base.cra_module	= THIS_MODULE,
+
+	.min_keysize		= AES_MIN_KEY_SIZE,
+	.max_keysize		= AES_MAX_KEY_SIZE,
+	.ivsize			= AES_BLOCK_SIZE,
+	.walksize		= 2 * AES_BLOCK_SIZE,
+	.setkey			= ce_aes_setkey,
+	.encrypt		= cts_cbc_encrypt,
+	.decrypt		= cts_cbc_decrypt,
 }, {
-	.base = {
-		.cra_name		= "__xts(aes)",
-		.cra_driver_name	= "__xts-aes-ce",
-		.cra_priority		= 300,
-		.cra_flags		= CRYPTO_ALG_INTERNAL,
-		.cra_blocksize		= AES_BLOCK_SIZE,
-		.cra_ctxsize		= sizeof(struct crypto_aes_xts_ctx),
-		.cra_module		= THIS_MODULE,
-	},
-	.min_keysize	= 2 * AES_MIN_KEY_SIZE,
-	.max_keysize	= 2 * AES_MAX_KEY_SIZE,
-	.ivsize		= AES_BLOCK_SIZE,
-	.setkey		= xts_set_key,
-	.encrypt	= xts_encrypt,
-	.decrypt	= xts_decrypt,
+	.base.cra_name		= "__ctr(aes)",
+	.base.cra_driver_name	= "__ctr-aes-ce",
+	.base.cra_priority	= 300,
+	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
+	.base.cra_blocksize	= 1,
+	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
+	.base.cra_module	= THIS_MODULE,
+
+	.min_keysize		= AES_MIN_KEY_SIZE,
+	.max_keysize		= AES_MAX_KEY_SIZE,
+	.ivsize			= AES_BLOCK_SIZE,
+	.chunksize		= AES_BLOCK_SIZE,
+	.setkey			= ce_aes_setkey,
+	.encrypt		= ctr_encrypt,
+	.decrypt		= ctr_encrypt,
+}, {
+	.base.cra_name		= "ctr(aes)",
+	.base.cra_driver_name	= "ctr-aes-ce-sync",
+	.base.cra_priority	= 300 - 1,
+	.base.cra_blocksize	= 1,
+	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
+	.base.cra_module	= THIS_MODULE,
+
+	.min_keysize		= AES_MIN_KEY_SIZE,
+	.max_keysize		= AES_MAX_KEY_SIZE,
+	.ivsize			= AES_BLOCK_SIZE,
+	.chunksize		= AES_BLOCK_SIZE,
+	.setkey			= ce_aes_setkey,
+	.encrypt		= ctr_encrypt_sync,
+	.decrypt		= ctr_encrypt_sync,
+}, {
+	.base.cra_name		= "__xts(aes)",
+	.base.cra_driver_name	= "__xts-aes-ce",
+	.base.cra_priority	= 300,
+	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
+	.base.cra_blocksize	= AES_BLOCK_SIZE,
+	.base.cra_ctxsize	= sizeof(struct crypto_aes_xts_ctx),
+	.base.cra_module	= THIS_MODULE,
+
+	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
+	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
+	.ivsize			= AES_BLOCK_SIZE,
+	.walksize		= 2 * AES_BLOCK_SIZE,
+	.setkey			= xts_set_key,
+	.encrypt		= xts_encrypt,
+	.decrypt		= xts_decrypt,
 } };
 
 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
@@ -429,6 +705,9 @@
 		return err;
 
 	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
+		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
+			continue;
+
 		algname = aes_algs[i].base.cra_name + 2;
 		drvname = aes_algs[i].base.cra_driver_name + 2;
 		basename = aes_algs[i].base.cra_driver_name;

--
Gitblit v1.6.2