From 102a0743326a03cd1a1202ceda21e175b7d3575c Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Tue, 20 Feb 2024 01:20:52 +0000
Subject: [PATCH] add new system file

---
 kernel/net/tls/tls_sw.c | 2567 +++++++++++++++++++++++++++++++++++++++++++---------------
 1 files changed, 1,883 insertions(+), 684 deletions(-)

diff --git a/kernel/net/tls/tls_sw.c b/kernel/net/tls/tls_sw.c
index 7d76124..50eae66 100644
--- a/kernel/net/tls/tls_sw.c
+++ b/kernel/net/tls/tls_sw.c
@@ -4,6 +4,7 @@
  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
+ * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
  *
  * This software is available to you under a choice of one of two
  * licenses.  You may choose to be licensed under the terms of the GNU
@@ -34,244 +35,1327 @@
  * SOFTWARE.
  */
 
+#include <linux/bug.h>
 #include <linux/sched/signal.h>
 #include <linux/module.h>
+#include <linux/splice.h>
 #include <crypto/aead.h>
 
 #include <net/strparser.h>
 #include <net/tls.h>
 
-#define MAX_IV_SIZE	TLS_CIPHER_AES_GCM_128_IV_SIZE
+noinline void tls_err_abort(struct sock *sk, int err)
+{
+	WARN_ON_ONCE(err >= 0);
+	/* sk->sk_err should contain a positive error code. */
+	sk->sk_err = -err;
+	sk->sk_error_report(sk);
+}
+
+static int __skb_nsg(struct sk_buff *skb, int offset, int len,
+                     unsigned int recursion_level)
+{
+        int start = skb_headlen(skb);
+        int i, chunk = start - offset;
+        struct sk_buff *frag_iter;
+        int elt = 0;
+
+        if (unlikely(recursion_level >= 24))
+                return -EMSGSIZE;
+
+        if (chunk > 0) {
+                if (chunk > len)
+                        chunk = len;
+                elt++;
+                len -= chunk;
+                if (len == 0)
+                        return elt;
+                offset += chunk;
+        }
+
+        for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
+                int end;
+
+                WARN_ON(start > offset + len);
+
+                end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
+                chunk = end - offset;
+                if (chunk > 0) {
+                        if (chunk > len)
+                                chunk = len;
+                        elt++;
+                        len -= chunk;
+                        if (len == 0)
+                                return elt;
+                        offset += chunk;
+                }
+                start = end;
+        }
+
+        if (unlikely(skb_has_frag_list(skb))) {
+                skb_walk_frags(skb, frag_iter) {
+                        int end, ret;
+
+                        WARN_ON(start > offset + len);
+
+                        end = start + frag_iter->len;
+                        chunk = end - offset;
+                        if (chunk > 0) {
+                                if (chunk > len)
+                                        chunk = len;
+                                ret = __skb_nsg(frag_iter, offset - start, chunk,
+                                                recursion_level + 1);
+                                if (unlikely(ret < 0))
+                                        return ret;
+                                elt += ret;
+                                len -= chunk;
+                                if (len == 0)
+                                        return elt;
+                                offset += chunk;
+                        }
+                        start = end;
+                }
+        }
+        BUG_ON(len);
+        return elt;
+}
+
+/* Return the number of scatterlist elements required to completely map the
+ * skb, or -EMSGSIZE if the recursion depth is exceeded.
+ */
+static int skb_nsg(struct sk_buff *skb, int offset, int len)
+{
+        return __skb_nsg(skb, offset, len, 0);
+}
+
+static int padding_length(struct tls_sw_context_rx *ctx,
+			  struct tls_prot_info *prot, struct sk_buff *skb)
+{
+	struct strp_msg *rxm = strp_msg(skb);
+	int sub = 0;
+
+	/* Determine zero-padding length */
+	if (prot->version == TLS_1_3_VERSION) {
+		char content_type = 0;
+		int err;
+		int back = 17;
+
+		while (content_type == 0) {
+			if (back > rxm->full_len - prot->prepend_size)
+				return -EBADMSG;
+			err = skb_copy_bits(skb,
+					    rxm->offset + rxm->full_len - back,
+					    &content_type, 1);
+			if (err)
+				return err;
+			if (content_type)
+				break;
+			sub++;
+			back++;
+		}
+		ctx->control = content_type;
+	}
+	return sub;
+}
+
+static void tls_decrypt_done(struct crypto_async_request *req, int err)
+{
+	struct aead_request *aead_req = (struct aead_request *)req;
+	struct scatterlist *sgout = aead_req->dst;
+	struct scatterlist *sgin = aead_req->src;
+	struct tls_sw_context_rx *ctx;
+	struct tls_context *tls_ctx;
+	struct tls_prot_info *prot;
+	struct scatterlist *sg;
+	struct sk_buff *skb;
+	unsigned int pages;
+	int pending;
+
+	skb = (struct sk_buff *)req->data;
+	tls_ctx = tls_get_ctx(skb->sk);
+	ctx = tls_sw_ctx_rx(tls_ctx);
+	prot = &tls_ctx->prot_info;
+
+	/* Propagate if there was an err */
+	if (err) {
+		if (err == -EBADMSG)
+			TLS_INC_STATS(sock_net(skb->sk),
+				      LINUX_MIB_TLSDECRYPTERROR);
+		ctx->async_wait.err = err;
+		tls_err_abort(skb->sk, err);
+	} else {
+		struct strp_msg *rxm = strp_msg(skb);
+		int pad;
+
+		pad = padding_length(ctx, prot, skb);
+		if (pad < 0) {
+			ctx->async_wait.err = pad;
+			tls_err_abort(skb->sk, pad);
+		} else {
+			rxm->full_len -= pad;
+			rxm->offset += prot->prepend_size;
+			rxm->full_len -= prot->overhead_size;
+		}
+	}
+
+	/* After using skb->sk to propagate sk through crypto async callback
+	 * we need to NULL it again.
+	 */
+	skb->sk = NULL;
+
+
+	/* Free the destination pages if skb was not decrypted inplace */
+	if (sgout != sgin) {
+		/* Skip the first S/G entry as it points to AAD */
+		for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
+			if (!sg)
+				break;
+			put_page(sg_page(sg));
+		}
+	}
+
+	kfree(aead_req);
+
+	spin_lock_bh(&ctx->decrypt_compl_lock);
+	pending = atomic_dec_return(&ctx->decrypt_pending);
+
+	if (!pending && ctx->async_notify)
+		complete(&ctx->async_wait.completion);
+	spin_unlock_bh(&ctx->decrypt_compl_lock);
+}
 
 static int tls_do_decryption(struct sock *sk,
+			     struct sk_buff *skb,
 			     struct scatterlist *sgin,
 			     struct scatterlist *sgout,
 			     char *iv_recv,
 			     size_t data_len,
-			     struct aead_request *aead_req)
+			     struct aead_request *aead_req,
+			     bool async)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	int ret;
 
 	aead_request_set_tfm(aead_req, ctx->aead_recv);
-	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
+	aead_request_set_ad(aead_req, prot->aad_size);
 	aead_request_set_crypt(aead_req, sgin, sgout,
-			       data_len + tls_ctx->rx.tag_size,
+			       data_len + prot->tag_size,
 			       (u8 *)iv_recv);
-	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
-				  crypto_req_done, &ctx->async_wait);
 
-	ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
+	if (async) {
+		/* Using skb->sk to push sk through to crypto async callback
+		 * handler. This allows propagating errors up to the socket
+		 * if needed. It _must_ be cleared in the async handler
+		 * before consume_skb is called. We _know_ skb->sk is NULL
+		 * because it is a clone from strparser.
+		 */
+		skb->sk = sk;
+		aead_request_set_callback(aead_req,
+					  CRYPTO_TFM_REQ_MAY_BACKLOG,
+					  tls_decrypt_done, skb);
+		atomic_inc(&ctx->decrypt_pending);
+	} else {
+		aead_request_set_callback(aead_req,
+					  CRYPTO_TFM_REQ_MAY_BACKLOG,
+					  crypto_req_done, &ctx->async_wait);
+	}
+
+	ret = crypto_aead_decrypt(aead_req);
+	if (ret == -EINPROGRESS) {
+		if (async)
+			return ret;
+
+		ret = crypto_wait_req(ret, &ctx->async_wait);
+	}
+
+	if (async)
+		atomic_dec(&ctx->decrypt_pending);
+
 	return ret;
 }
 
-static void trim_sg(struct sock *sk, struct scatterlist *sg,
-		    int *sg_num_elem, unsigned int *sg_size, int target_size)
-{
-	int i = *sg_num_elem - 1;
-	int trim = *sg_size - target_size;
-
-	if (trim <= 0) {
-		WARN_ON(trim < 0);
-		return;
-	}
-
-	*sg_size = target_size;
-	while (trim >= sg[i].length) {
-		trim -= sg[i].length;
-		sk_mem_uncharge(sk, sg[i].length);
-		put_page(sg_page(&sg[i]));
-		i--;
-
-		if (i < 0)
-			goto out;
-	}
-
-	sg[i].length -= trim;
-	sk_mem_uncharge(sk, trim);
-
-out:
-	*sg_num_elem = i + 1;
-}
-
-static void trim_both_sgl(struct sock *sk, int target_size)
+static void tls_trim_both_msgs(struct sock *sk, int target_size)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec = ctx->open_rec;
 
-	trim_sg(sk, ctx->sg_plaintext_data,
-		&ctx->sg_plaintext_num_elem,
-		&ctx->sg_plaintext_size,
-		target_size);
-
+	sk_msg_trim(sk, &rec->msg_plaintext, target_size);
 	if (target_size > 0)
-		target_size += tls_ctx->tx.overhead_size;
-
-	trim_sg(sk, ctx->sg_encrypted_data,
-		&ctx->sg_encrypted_num_elem,
-		&ctx->sg_encrypted_size,
-		target_size);
+		target_size += prot->overhead_size;
+	sk_msg_trim(sk, &rec->msg_encrypted, target_size);
 }
 
-static int alloc_encrypted_sg(struct sock *sk, int len)
+static int tls_alloc_encrypted_msg(struct sock *sk, int len)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	int rc = 0;
+	struct tls_rec *rec = ctx->open_rec;
+	struct sk_msg *msg_en = &rec->msg_encrypted;
 
-	rc = sk_alloc_sg(sk, len,
-			 ctx->sg_encrypted_data, 0,
-			 &ctx->sg_encrypted_num_elem,
-			 &ctx->sg_encrypted_size, 0);
-
-	if (rc == -ENOSPC)
-		ctx->sg_encrypted_num_elem = ARRAY_SIZE(ctx->sg_encrypted_data);
-
-	return rc;
+	return sk_msg_alloc(sk, msg_en, len, 0);
 }
 
-static int alloc_plaintext_sg(struct sock *sk, int len)
+static int tls_clone_plaintext_msg(struct sock *sk, int required)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec = ctx->open_rec;
+	struct sk_msg *msg_pl = &rec->msg_plaintext;
+	struct sk_msg *msg_en = &rec->msg_encrypted;
+	int skip, len;
+
+	/* We add page references worth len bytes from encrypted sg
+	 * at the end of plaintext sg. It is guaranteed that msg_en
+	 * has enough required room (ensured by caller).
+	 */
+	len = required - msg_pl->sg.size;
+
+	/* Skip initial bytes in msg_en's data to be able to use
+	 * same offset of both plain and encrypted data.
+	 */
+	skip = prot->prepend_size + msg_pl->sg.size;
+
+	return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
+}
+
+static struct tls_rec *tls_get_rec(struct sock *sk)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct sk_msg *msg_pl, *msg_en;
+	struct tls_rec *rec;
+	int mem_size;
+
+	mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
+
+	rec = kzalloc(mem_size, sk->sk_allocation);
+	if (!rec)
+		return NULL;
+
+	msg_pl = &rec->msg_plaintext;
+	msg_en = &rec->msg_encrypted;
+
+	sk_msg_init(msg_pl);
+	sk_msg_init(msg_en);
+
+	sg_init_table(rec->sg_aead_in, 2);
+	sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
+	sg_unmark_end(&rec->sg_aead_in[1]);
+
+	sg_init_table(rec->sg_aead_out, 2);
+	sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
+	sg_unmark_end(&rec->sg_aead_out[1]);
+
+	return rec;
+}
+
+static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
+{
+	sk_msg_free(sk, &rec->msg_encrypted);
+	sk_msg_free(sk, &rec->msg_plaintext);
+	kfree(rec);
+}
+
+static void tls_free_open_rec(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	int rc = 0;
+	struct tls_rec *rec = ctx->open_rec;
 
-	rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0,
-			 &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size,
-			 tls_ctx->pending_open_record_frags);
-
-	if (rc == -ENOSPC)
-		ctx->sg_plaintext_num_elem = ARRAY_SIZE(ctx->sg_plaintext_data);
-
-	return rc;
-}
-
-static void free_sg(struct sock *sk, struct scatterlist *sg,
-		    int *sg_num_elem, unsigned int *sg_size)
-{
-	int i, n = *sg_num_elem;
-
-	for (i = 0; i < n; ++i) {
-		sk_mem_uncharge(sk, sg[i].length);
-		put_page(sg_page(&sg[i]));
+	if (rec) {
+		tls_free_rec(sk, rec);
+		ctx->open_rec = NULL;
 	}
-	*sg_num_elem = 0;
-	*sg_size = 0;
 }
 
-static void tls_free_both_sg(struct sock *sk)
+int tls_tx_records(struct sock *sk, int flags)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec, *tmp;
+	struct sk_msg *msg_en;
+	int tx_flags, rc = 0;
 
-	free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
-		&ctx->sg_encrypted_size);
+	if (tls_is_partially_sent_record(tls_ctx)) {
+		rec = list_first_entry(&ctx->tx_list,
+				       struct tls_rec, list);
 
-	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
-		&ctx->sg_plaintext_size);
+		if (flags == -1)
+			tx_flags = rec->tx_flags;
+		else
+			tx_flags = flags;
+
+		rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
+		if (rc)
+			goto tx_err;
+
+		/* Full record has been transmitted.
+		 * Remove the head of tx_list
+		 */
+		list_del(&rec->list);
+		sk_msg_free(sk, &rec->msg_plaintext);
+		kfree(rec);
+	}
+
+	/* Tx all ready records */
+	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
+		if (READ_ONCE(rec->tx_ready)) {
+			if (flags == -1)
+				tx_flags = rec->tx_flags;
+			else
+				tx_flags = flags;
+
+			msg_en = &rec->msg_encrypted;
+			rc = tls_push_sg(sk, tls_ctx,
+					 &msg_en->sg.data[msg_en->sg.curr],
+					 0, tx_flags);
+			if (rc)
+				goto tx_err;
+
+			list_del(&rec->list);
+			sk_msg_free(sk, &rec->msg_plaintext);
+			kfree(rec);
+		} else {
+			break;
+		}
+	}
+
+tx_err:
+	if (rc < 0 && rc != -EAGAIN)
+		tls_err_abort(sk, -EBADMSG);
+
+	return rc;
 }
 
-static int tls_do_encryption(struct tls_context *tls_ctx,
+static void tls_encrypt_done(struct crypto_async_request *req, int err)
+{
+	struct aead_request *aead_req = (struct aead_request *)req;
+	struct sock *sk = req->data;
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct scatterlist *sge;
+	struct sk_msg *msg_en;
+	struct tls_rec *rec;
+	bool ready = false;
+	int pending;
+
+	rec = container_of(aead_req, struct tls_rec, aead_req);
+	msg_en = &rec->msg_encrypted;
+
+	sge = sk_msg_elem(msg_en, msg_en->sg.curr);
+	sge->offset -= prot->prepend_size;
+	sge->length += prot->prepend_size;
+
+	/* Check if error is previously set on socket */
+	if (err || sk->sk_err) {
+		rec = NULL;
+
+		/* If err is already set on socket, return the same code */
+		if (sk->sk_err) {
+			ctx->async_wait.err = -sk->sk_err;
+		} else {
+			ctx->async_wait.err = err;
+			tls_err_abort(sk, err);
+		}
+	}
+
+	if (rec) {
+		struct tls_rec *first_rec;
+
+		/* Mark the record as ready for transmission */
+		smp_store_mb(rec->tx_ready, true);
+
+		/* If received record is at head of tx_list, schedule tx */
+		first_rec = list_first_entry(&ctx->tx_list,
+					     struct tls_rec, list);
+		if (rec == first_rec)
+			ready = true;
+	}
+
+	spin_lock_bh(&ctx->encrypt_compl_lock);
+	pending = atomic_dec_return(&ctx->encrypt_pending);
+
+	if (!pending && ctx->async_notify)
+		complete(&ctx->async_wait.completion);
+	spin_unlock_bh(&ctx->encrypt_compl_lock);
+
+	if (!ready)
+		return;
+
+	/* Schedule the transmission */
+	if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
+		schedule_delayed_work(&ctx->tx_work.work, 1);
+}
+
+static int tls_do_encryption(struct sock *sk,
+			     struct tls_context *tls_ctx,
 			     struct tls_sw_context_tx *ctx,
 			     struct aead_request *aead_req,
-			     size_t data_len)
+			     size_t data_len, u32 start)
 {
-	int rc;
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct tls_rec *rec = ctx->open_rec;
+	struct sk_msg *msg_en = &rec->msg_encrypted;
+	struct scatterlist *sge = sk_msg_elem(msg_en, start);
+	int rc, iv_offset = 0;
 
-	ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
-	ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;
+	/* For CCM based ciphers, first byte of IV is a constant */
+	if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
+		rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
+		iv_offset = 1;
+	}
+
+	memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
+	       prot->iv_size + prot->salt_size);
+
+	xor_iv_with_seq(prot->version, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq);
+
+	sge->offset += prot->prepend_size;
+	sge->length -= prot->prepend_size;
+
+	msg_en->sg.curr = start;
 
 	aead_request_set_tfm(aead_req, ctx->aead_send);
-	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
-	aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
-			       data_len, tls_ctx->tx.iv);
+	aead_request_set_ad(aead_req, prot->aad_size);
+	aead_request_set_crypt(aead_req, rec->sg_aead_in,
+			       rec->sg_aead_out,
+			       data_len, rec->iv_data);
 
 	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
-				  crypto_req_done, &ctx->async_wait);
+				  tls_encrypt_done, sk);
 
-	rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait);
+	/* Add the record in tx_list */
+	list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
+	atomic_inc(&ctx->encrypt_pending);
 
-	ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
-	ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
+	rc = crypto_aead_encrypt(aead_req);
+	if (!rc || rc != -EINPROGRESS) {
+		atomic_dec(&ctx->encrypt_pending);
+		sge->offset -= prot->prepend_size;
+		sge->length += prot->prepend_size;
+	}
 
+	if (!rc) {
+		WRITE_ONCE(rec->tx_ready, true);
+	} else if (rc != -EINPROGRESS) {
+		list_del(&rec->list);
+		return rc;
+	}
+
+	/* Unhook the record from context if encryption is not failure */
+	ctx->open_rec = NULL;
+	tls_advance_record_sn(sk, prot, &tls_ctx->tx);
 	return rc;
+}
+
+static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
+				 struct tls_rec **to, struct sk_msg *msg_opl,
+				 struct sk_msg *msg_oen, u32 split_point,
+				 u32 tx_overhead_size, u32 *orig_end)
+{
+	u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
+	struct scatterlist *sge, *osge, *nsge;
+	u32 orig_size = msg_opl->sg.size;
+	struct scatterlist tmp = { };
+	struct sk_msg *msg_npl;
+	struct tls_rec *new;
+	int ret;
+
+	new = tls_get_rec(sk);
+	if (!new)
+		return -ENOMEM;
+	ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
+			   tx_overhead_size, 0);
+	if (ret < 0) {
+		tls_free_rec(sk, new);
+		return ret;
+	}
+
+	*orig_end = msg_opl->sg.end;
+	i = msg_opl->sg.start;
+	sge = sk_msg_elem(msg_opl, i);
+	while (apply && sge->length) {
+		if (sge->length > apply) {
+			u32 len = sge->length - apply;
+
+			get_page(sg_page(sge));
+			sg_set_page(&tmp, sg_page(sge), len,
+				    sge->offset + apply);
+			sge->length = apply;
+			bytes += apply;
+			apply = 0;
+		} else {
+			apply -= sge->length;
+			bytes += sge->length;
+		}
+
+		sk_msg_iter_var_next(i);
+		if (i == msg_opl->sg.end)
+			break;
+		sge = sk_msg_elem(msg_opl, i);
+	}
+
+	msg_opl->sg.end = i;
+	msg_opl->sg.curr = i;
+	msg_opl->sg.copybreak = 0;
+	msg_opl->apply_bytes = 0;
+	msg_opl->sg.size = bytes;
+
+	msg_npl = &new->msg_plaintext;
+	msg_npl->apply_bytes = apply;
+	msg_npl->sg.size = orig_size - bytes;
+
+	j = msg_npl->sg.start;
+	nsge = sk_msg_elem(msg_npl, j);
+	if (tmp.length) {
+		memcpy(nsge, &tmp, sizeof(*nsge));
+		sk_msg_iter_var_next(j);
+		nsge = sk_msg_elem(msg_npl, j);
+	}
+
+	osge = sk_msg_elem(msg_opl, i);
+	while (osge->length) {
+		memcpy(nsge, osge, sizeof(*nsge));
+		sg_unmark_end(nsge);
+		sk_msg_iter_var_next(i);
+		sk_msg_iter_var_next(j);
+		if (i == *orig_end)
+			break;
+		osge = sk_msg_elem(msg_opl, i);
+		nsge = sk_msg_elem(msg_npl, j);
+	}
+
+	msg_npl->sg.end = j;
+	msg_npl->sg.curr = j;
+	msg_npl->sg.copybreak = 0;
+
+	*to = new;
+	return 0;
+}
+
+static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
+				  struct tls_rec *from, u32 orig_end)
+{
+	struct sk_msg *msg_npl = &from->msg_plaintext;
+	struct sk_msg *msg_opl = &to->msg_plaintext;
+	struct scatterlist *osge, *nsge;
+	u32 i, j;
+
+	i = msg_opl->sg.end;
+	sk_msg_iter_var_prev(i);
+	j = msg_npl->sg.start;
+
+	osge = sk_msg_elem(msg_opl, i);
+	nsge = sk_msg_elem(msg_npl, j);
+
+	if (sg_page(osge) == sg_page(nsge) &&
+	    osge->offset + osge->length == nsge->offset) {
+		osge->length += nsge->length;
+		put_page(sg_page(nsge));
+	}
+
+	msg_opl->sg.end = orig_end;
+	msg_opl->sg.curr = orig_end;
+	msg_opl->sg.copybreak = 0;
+	msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
+	msg_opl->sg.size += msg_npl->sg.size;
+
+	sk_msg_free(sk, &to->msg_encrypted);
+	sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);
+
+	kfree(from);
 }
 
 static int tls_push_record(struct sock *sk, int flags,
 			   unsigned char record_type)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
+	u32 i, split_point, orig_end;
+	struct sk_msg *msg_pl, *msg_en;
 	struct aead_request *req;
+	bool split;
 	int rc;
 
-	req = aead_request_alloc(ctx->aead_send, sk->sk_allocation);
-	if (!req)
-		return -ENOMEM;
+	if (!rec)
+		return 0;
 
-	sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
-	sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
+	msg_pl = &rec->msg_plaintext;
+	msg_en = &rec->msg_encrypted;
 
-	tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
-		     tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
-		     record_type);
-
-	tls_fill_prepend(tls_ctx,
-			 page_address(sg_page(&ctx->sg_encrypted_data[0])) +
-			 ctx->sg_encrypted_data[0].offset,
-			 ctx->sg_plaintext_size, record_type);
-
-	tls_ctx->pending_open_record_frags = 0;
-	set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags);
-
-	rc = tls_do_encryption(tls_ctx, ctx, req, ctx->sg_plaintext_size);
-	if (rc < 0) {
-		/* If we are called from write_space and
-		 * we fail, we need to set this SOCK_NOSPACE
-		 * to trigger another write_space in the future.
+	split_point = msg_pl->apply_bytes;
+	split = split_point && split_point < msg_pl->sg.size;
+	if (unlikely((!split &&
+		      msg_pl->sg.size +
+		      prot->overhead_size > msg_en->sg.size) ||
+		     (split &&
+		      split_point +
+		      prot->overhead_size > msg_en->sg.size))) {
+		split = true;
+		split_point = msg_en->sg.size;
+	}
+	if (split) {
+		rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
+					   split_point, prot->overhead_size,
+					   &orig_end);
+		if (rc < 0)
+			return rc;
+		/* This can happen if above tls_split_open_record allocates
+		 * a single large encryption buffer instead of two smaller
+		 * ones. In this case adjust pointers and continue without
+		 * split.
 		 */
-		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
-		goto out_req;
+		if (!msg_pl->sg.size) {
+			tls_merge_open_record(sk, rec, tmp, orig_end);
+			msg_pl = &rec->msg_plaintext;
+			msg_en = &rec->msg_encrypted;
+			split = false;
+		}
+		sk_msg_trim(sk, msg_en, msg_pl->sg.size +
+			    prot->overhead_size);
 	}
 
-	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
-		&ctx->sg_plaintext_size);
+	rec->tx_flags = flags;
+	req = &rec->aead_req;
 
-	ctx->sg_encrypted_num_elem = 0;
-	ctx->sg_encrypted_size = 0;
+	i = msg_pl->sg.end;
+	sk_msg_iter_var_prev(i);
 
-	/* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */
-	rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
-	if (rc < 0 && rc != -EAGAIN)
-		tls_err_abort(sk, EBADMSG);
+	rec->content_type = record_type;
+	if (prot->version == TLS_1_3_VERSION) {
+		/* Add content type to end of message.  No padding added */
+		sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
+		sg_mark_end(&rec->sg_content_type);
+		sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1,
+			 &rec->sg_content_type);
+	} else {
+		sg_mark_end(sk_msg_elem(msg_pl, i));
+	}
 
-	tls_advance_record_sn(sk, &tls_ctx->tx);
-out_req:
-	aead_request_free(req);
-	return rc;
+	if (msg_pl->sg.end < msg_pl->sg.start) {
+		sg_chain(&msg_pl->sg.data[msg_pl->sg.start],
+			 MAX_SKB_FRAGS - msg_pl->sg.start + 1,
+			 msg_pl->sg.data);
+	}
+
+	i = msg_pl->sg.start;
+	sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);
+
+	i = msg_en->sg.end;
+	sk_msg_iter_var_prev(i);
+	sg_mark_end(sk_msg_elem(msg_en, i));
+
+	i = msg_en->sg.start;
+	sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);
+
+	tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
+		     tls_ctx->tx.rec_seq, prot->rec_seq_size,
+		     record_type, prot->version);
+
+	tls_fill_prepend(tls_ctx,
+			 page_address(sg_page(&msg_en->sg.data[i])) +
+			 msg_en->sg.data[i].offset,
+			 msg_pl->sg.size + prot->tail_size,
+			 record_type, prot->version);
+
+	tls_ctx->pending_open_record_frags = false;
+
+	rc = tls_do_encryption(sk, tls_ctx, ctx, req,
+			       msg_pl->sg.size + prot->tail_size, i);
+	if (rc < 0) {
+		if (rc != -EINPROGRESS) {
+			tls_err_abort(sk, -EBADMSG);
+			if (split) {
+				tls_ctx->pending_open_record_frags = true;
+				tls_merge_open_record(sk, rec, tmp, orig_end);
+			}
+		}
+		ctx->async_capable = 1;
+		return rc;
+	} else if (split) {
+		msg_pl = &tmp->msg_plaintext;
+		msg_en = &tmp->msg_encrypted;
+		sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
+		tls_ctx->pending_open_record_frags = true;
+		ctx->open_rec = tmp;
+	}
+
+	return tls_tx_records(sk, flags);
+}
+
+static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
+			       bool full_record, u8 record_type,
+			       ssize_t *copied, int flags)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct sk_msg msg_redir = { };
+	struct sk_psock *psock;
+	struct sock *sk_redir;
+	struct tls_rec *rec;
+	bool enospc, policy;
+	int err = 0, send;
+	u32 delta = 0;
+
+	policy = !(flags & MSG_SENDPAGE_NOPOLICY);
+	psock = sk_psock_get(sk);
+	if (!psock || !policy) {
+		err = tls_push_record(sk, flags, record_type);
+		if (err && err != -EINPROGRESS && sk->sk_err == EBADMSG) {
+			*copied -= sk_msg_free(sk, msg);
+			tls_free_open_rec(sk);
+			err = -sk->sk_err;
+		}
+		if (psock)
+			sk_psock_put(sk, psock);
+		return err;
+	}
+more_data:
+	enospc = sk_msg_full(msg);
+	if (psock->eval == __SK_NONE) {
+		delta = msg->sg.size;
+		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
+		delta -= msg->sg.size;
+	}
+	if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
+	    !enospc && !full_record) {
+		err = -ENOSPC;
+		goto out_err;
+	}
+	msg->cork_bytes = 0;
+	send = msg->sg.size;
+	if (msg->apply_bytes && msg->apply_bytes < send)
+		send = msg->apply_bytes;
+
+	switch (psock->eval) {
+	case __SK_PASS:
+		err = tls_push_record(sk, flags, record_type);
+		if (err && err != -EINPROGRESS && sk->sk_err == EBADMSG) {
+			*copied -= sk_msg_free(sk, msg);
+			tls_free_open_rec(sk);
+			err = -sk->sk_err;
+			goto out_err;
+		}
+		break;
+	case __SK_REDIRECT:
+		sk_redir = psock->sk_redir;
+		memcpy(&msg_redir, msg, sizeof(*msg));
+		if (msg->apply_bytes < send)
+			msg->apply_bytes = 0;
+		else
+			msg->apply_bytes -= send;
+		sk_msg_return_zero(sk, msg, send);
+		msg->sg.size -= send;
+		release_sock(sk);
+		err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
+		lock_sock(sk);
+		if (err < 0) {
+			*copied -= sk_msg_free_nocharge(sk, &msg_redir);
+			msg->sg.size = 0;
+		}
+		if (msg->sg.size == 0)
+			tls_free_open_rec(sk);
+		break;
+	case __SK_DROP:
+	default:
+		sk_msg_free_partial(sk, msg, send);
+		if (msg->apply_bytes < send)
+			msg->apply_bytes = 0;
+		else
+			msg->apply_bytes -= send;
+		if (msg->sg.size == 0)
+			tls_free_open_rec(sk);
+		*copied -= (send + delta);
+		err = -EACCES;
+	}
+
+	if (likely(!err)) {
+		bool reset_eval = !ctx->open_rec;
+
+		rec = ctx->open_rec;
+		if (rec) {
+			msg = &rec->msg_plaintext;
+			if (!msg->apply_bytes)
+				reset_eval = true;
+		}
+		if (reset_eval) {
+			psock->eval = __SK_NONE;
+			if (psock->sk_redir) {
+				sock_put(psock->sk_redir);
+				psock->sk_redir = NULL;
+			}
+		}
+		if (rec)
+			goto more_data;
+	}
+ out_err:
+	sk_psock_put(sk, psock);
+	return err;
 }
 
 static int tls_sw_push_pending_record(struct sock *sk, int flags)
 {
-	return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec = ctx->open_rec;
+	struct sk_msg *msg_pl;
+	size_t copied;
+
+	if (!rec)
+		return 0;
+
+	msg_pl = &rec->msg_plaintext;
+	copied = msg_pl->sg.size;
+	if (!copied)
+		return 0;
+
+	return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
+				   &copied, flags);
 }
 
-static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
-			      int length, int *pages_used,
-			      unsigned int *size_used,
-			      struct scatterlist *to, int to_max_pages,
-			      bool charge)
+int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 {
-	struct page *pages[MAX_SKB_FRAGS];
+	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	bool async_capable = ctx->async_capable;
+	unsigned char record_type = TLS_RECORD_TYPE_DATA;
+	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
+	bool eor = !(msg->msg_flags & MSG_MORE);
+	size_t try_to_copy;
+	ssize_t copied = 0;
+	struct sk_msg *msg_pl, *msg_en;
+	struct tls_rec *rec;
+	int required_size;
+	int num_async = 0;
+	bool full_record;
+	int record_room;
+	int num_zc = 0;
+	int orig_size;
+	int ret = 0;
+	int pending;
 
-	size_t offset;
-	ssize_t copied, use;
-	int i = 0;
+	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
+			       MSG_CMSG_COMPAT))
+		return -EOPNOTSUPP;
+
+	ret = mutex_lock_interruptible(&tls_ctx->tx_lock);
+	if (ret)
+		return ret;
+	lock_sock(sk);
+
+	if (unlikely(msg->msg_controllen)) {
+		ret = tls_proccess_cmsg(sk, msg, &record_type);
+		if (ret) {
+			if (ret == -EINPROGRESS)
+				num_async++;
+			else if (ret != -EAGAIN)
+				goto send_end;
+		}
+	}
+
+	while (msg_data_left(msg)) {
+		if (sk->sk_err) {
+			ret = -sk->sk_err;
+			goto send_end;
+		}
+
+		if (ctx->open_rec)
+			rec = ctx->open_rec;
+		else
+			rec = ctx->open_rec = tls_get_rec(sk);
+		if (!rec) {
+			ret = -ENOMEM;
+			goto send_end;
+		}
+
+		msg_pl = &rec->msg_plaintext;
+		msg_en = &rec->msg_encrypted;
+
+		orig_size = msg_pl->sg.size;
+		full_record = false;
+		try_to_copy = msg_data_left(msg);
+		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
+		if (try_to_copy >= record_room) {
+			try_to_copy = record_room;
+			full_record = true;
+		}
+
+		required_size = msg_pl->sg.size + try_to_copy +
+				prot->overhead_size;
+
+		if (!sk_stream_memory_free(sk))
+			goto wait_for_sndbuf;
+
+alloc_encrypted:
+		ret = tls_alloc_encrypted_msg(sk, required_size);
+		if (ret) {
+			if (ret != -ENOSPC)
+				goto wait_for_memory;
+
+			/* Adjust try_to_copy according to the amount that was
+			 * actually allocated. The difference is due
+			 * to max sg elements limit
+			 */
+			try_to_copy -= required_size - msg_en->sg.size;
+			full_record = true;
+		}
+
+		if (!is_kvec && (full_record || eor) && !async_capable) {
+			u32 first = msg_pl->sg.end;
+
+			ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
+							msg_pl, try_to_copy);
+			if (ret)
+				goto fallback_to_reg_send;
+
+			num_zc++;
+			copied += try_to_copy;
+
+			sk_msg_sg_copy_set(msg_pl, first);
+			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+						  record_type, &copied,
+						  msg->msg_flags);
+			if (ret) {
+				if (ret == -EINPROGRESS)
+					num_async++;
+				else if (ret == -ENOMEM)
+					goto wait_for_memory;
+				else if (ctx->open_rec && ret == -ENOSPC)
+					goto rollback_iter;
+				else if (ret != -EAGAIN)
+					goto send_end;
+			}
+			continue;
+rollback_iter:
+			copied -= try_to_copy;
+			sk_msg_sg_copy_clear(msg_pl, first);
+			iov_iter_revert(&msg->msg_iter,
+					msg_pl->sg.size - orig_size);
+fallback_to_reg_send:
+			sk_msg_trim(sk, msg_pl, orig_size);
+		}
+
+		required_size = msg_pl->sg.size + try_to_copy;
+
+		ret = tls_clone_plaintext_msg(sk, required_size);
+		if (ret) {
+			if (ret != -ENOSPC)
+				goto send_end;
+
+			/* Adjust try_to_copy according to the amount that was
+			 * actually allocated. The difference is due
+			 * to max sg elements limit
+			 */
+			try_to_copy -= required_size - msg_pl->sg.size;
+			full_record = true;
+			sk_msg_trim(sk, msg_en,
+				    msg_pl->sg.size + prot->overhead_size);
+		}
+
+		if (try_to_copy) {
+			ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter,
+						       msg_pl, try_to_copy);
+			if (ret < 0)
+				goto trim_sgl;
+		}
+
+		/* Open records defined only if successfully copied, otherwise
+		 * we would trim the sg but not reset the open record frags.
+		 */
+		tls_ctx->pending_open_record_frags = true;
+		copied += try_to_copy;
+		if (full_record || eor) {
+			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+						  record_type, &copied,
+						  msg->msg_flags);
+			if (ret) {
+				if (ret == -EINPROGRESS)
+					num_async++;
+				else if (ret == -ENOMEM)
+					goto wait_for_memory;
+				else if (ret != -EAGAIN) {
+					if (ret == -ENOSPC)
+						ret = 0;
+					goto send_end;
+				}
+			}
+		}
+
+		continue;
+
+wait_for_sndbuf:
+		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+wait_for_memory:
+		ret = sk_stream_wait_memory(sk, &timeo);
+		if (ret) {
+trim_sgl:
+			if (ctx->open_rec)
+				tls_trim_both_msgs(sk, orig_size);
+			goto send_end;
+		}
+
+		if (ctx->open_rec && msg_en->sg.size < required_size)
+			goto alloc_encrypted;
+	}
+
+	if (!num_async) {
+		goto send_end;
+	} else if (num_zc) {
+		/* Wait for pending encryptions to get completed */
+		spin_lock_bh(&ctx->encrypt_compl_lock);
+		ctx->async_notify = true;
+
+		pending = atomic_read(&ctx->encrypt_pending);
+		spin_unlock_bh(&ctx->encrypt_compl_lock);
+		if (pending)
+			crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+		else
+			reinit_completion(&ctx->async_wait.completion);
+
+		/* There can be no concurrent accesses, since we have no
+		 * pending encrypt operations
+		 */
+		WRITE_ONCE(ctx->async_notify, false);
+
+		if (ctx->async_wait.err) {
+			ret = ctx->async_wait.err;
+			copied = 0;
+		}
+	}
+
+	/* Transmit if any encryptions have completed */
+	if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
+		cancel_delayed_work(&ctx->tx_work.work);
+		tls_tx_records(sk, msg->msg_flags);
+	}
+
+send_end:
+	ret = sk_stream_error(sk, msg->msg_flags, ret);
+
+	release_sock(sk);
+	mutex_unlock(&tls_ctx->tx_lock);
+	return copied > 0 ? copied : ret;
+}
+
+static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
+			      int offset, size_t size, int flags)
+{
+	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	unsigned char record_type = TLS_RECORD_TYPE_DATA;
+	struct sk_msg *msg_pl;
+	struct tls_rec *rec;
+	int num_async = 0;
+	ssize_t copied = 0;
+	bool full_record;
+	int record_room;
+	int ret = 0;
+	bool eor;
+
+	eor = !(flags & MSG_SENDPAGE_NOTLAST);
+	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
+
+	/* Call the sk_stream functions to manage the sndbuf mem. */
+	while (size > 0) {
+		size_t copy, required_size;
+
+		if (sk->sk_err) {
+			ret = -sk->sk_err;
+			goto sendpage_end;
+		}
+
+		if (ctx->open_rec)
+			rec = ctx->open_rec;
+		else
+			rec = ctx->open_rec = tls_get_rec(sk);
+		if (!rec) {
+			ret = -ENOMEM;
+			goto sendpage_end;
+		}
+
+		msg_pl = &rec->msg_plaintext;
+
+		full_record = false;
+		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
+		copy = size;
+		if (copy >= record_room) {
+			copy = record_room;
+			full_record = true;
+		}
+
+		required_size = msg_pl->sg.size + copy + prot->overhead_size;
+
+		if (!sk_stream_memory_free(sk))
+			goto wait_for_sndbuf;
+alloc_payload:
+		ret = tls_alloc_encrypted_msg(sk, required_size);
+		if (ret) {
+			if (ret != -ENOSPC)
+				goto wait_for_memory;
+
+			/* Adjust copy according to the amount that was
+			 * actually allocated. The difference is due
+			 * to max sg elements limit
+			 */
+			copy -= required_size - msg_pl->sg.size;
+			full_record = true;
+		}
+
+		sk_msg_page_add(msg_pl, page, copy, offset);
+		sk_mem_charge(sk, copy);
+
+		offset += copy;
+		size -= copy;
+		copied += copy;
+
+		tls_ctx->pending_open_record_frags = true;
+		if (full_record || eor || sk_msg_full(msg_pl)) {
+			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+						  record_type, &copied, flags);
+			if (ret) {
+				if (ret == -EINPROGRESS)
+					num_async++;
+				else if (ret == -ENOMEM)
+					goto wait_for_memory;
+				else if (ret != -EAGAIN) {
+					if (ret == -ENOSPC)
+						ret = 0;
+					goto sendpage_end;
+				}
+			}
+		}
+		continue;
+wait_for_sndbuf:
+		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+wait_for_memory:
+		ret = sk_stream_wait_memory(sk, &timeo);
+		if (ret) {
+			if (ctx->open_rec)
+				tls_trim_both_msgs(sk, msg_pl->sg.size);
+			goto sendpage_end;
+		}
+
+		if (ctx->open_rec)
+			goto alloc_payload;
+	}
+
+	if (num_async) {
+		/* Transmit if any encryptions have completed */
+		if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
+			cancel_delayed_work(&ctx->tx_work.work);
+			tls_tx_records(sk, flags);
+		}
+	}
+sendpage_end:
+	ret = sk_stream_error(sk, flags, ret);
+	return copied > 0 ? copied : ret;
+}
+
+int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
+			   int offset, size_t size, int flags)
+{
+	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
+		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY |
+		      MSG_NO_SHARED_FRAGS))
+		return -EOPNOTSUPP;
+
+	return tls_sw_do_sendpage(sk, page, offset, size, flags);
+}
+
+int tls_sw_sendpage(struct sock *sk, struct page *page,
+		    int offset, size_t size, int flags)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	int ret;
+
+	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
+		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
+		return -EOPNOTSUPP;
+
+	ret = mutex_lock_interruptible(&tls_ctx->tx_lock);
+	if (ret)
+		return ret;
+	lock_sock(sk);
+	ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
+	release_sock(sk);
+	mutex_unlock(&tls_ctx->tx_lock);
+	return ret;
+}
+
+static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
+				     bool nonblock, long timeo, int *err)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct sk_buff *skb;
+	DEFINE_WAIT_FUNC(wait, woken_wake_function);
+
+	while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
+		if (sk->sk_err) {
+			*err = sock_error(sk);
+			return NULL;
+		}
+
+		if (!skb_queue_empty(&sk->sk_receive_queue)) {
+			__strp_unpause(&ctx->strp);
+			if (ctx->recv_pkt)
+				return ctx->recv_pkt;
+		}
+
+		if (sk->sk_shutdown & RCV_SHUTDOWN)
+			return NULL;
+
+		if (sock_flag(sk, SOCK_DONE))
+			return NULL;
+
+		if (nonblock || !timeo) {
+			*err = -EAGAIN;
+			return NULL;
+		}
+
+		add_wait_queue(sk_sleep(sk), &wait);
+		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+		sk_wait_event(sk, &timeo,
+			      ctx->recv_pkt != skb ||
+			      !sk_psock_queue_empty(psock),
+			      &wait);
+		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+		remove_wait_queue(sk_sleep(sk), &wait);
+
+		/* Handle signals */
+		if (signal_pending(current)) {
+			*err = sock_intr_errno(timeo);
+			return NULL;
+		}
+	}
+
+	return skb;
+}
+
+static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
+			       int length, int *pages_used,
+			       unsigned int *size_used,
+			       struct scatterlist *to,
+			       int to_max_pages)
+{
+	int rc = 0, i = 0, num_elem = *pages_used, maxpages;
+	struct page *pages[MAX_SKB_FRAGS];
 	unsigned int size = *size_used;
-	int num_elem = *pages_used;
-	int rc = 0;
-	int maxpages;
+	ssize_t copied, use;
+	size_t offset;
 
 	while (length > 0) {
 		i = 0;
@@ -298,17 +1382,15 @@
 			sg_set_page(&to[num_elem],
 				    pages[i], use, offset);
 			sg_unmark_end(&to[num_elem]);
-			if (charge)
-				sk_mem_charge(sk, use);
+			/* We do not uncharge memory from this API */
 
 			offset = 0;
 			copied -= use;
 
-			++i;
-			++num_elem;
+			i++;
+			num_elem++;
 		}
 	}
-
 	/* Mark the end in the last sg entry if newly added */
 	if (num_elem > *pages_used)
 		sg_mark_end(&to[num_elem - 1]);
@@ -319,348 +1401,6 @@
 	*pages_used = num_elem;
 
 	return rc;
-}
-
-static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
-			     int bytes)
-{
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	struct scatterlist *sg = ctx->sg_plaintext_data;
-	int copy, i, rc = 0;
-
-	for (i = tls_ctx->pending_open_record_frags;
-	     i < ctx->sg_plaintext_num_elem; ++i) {
-		copy = sg[i].length;
-		if (copy_from_iter(
-				page_address(sg_page(&sg[i])) + sg[i].offset,
-				copy, from) != copy) {
-			rc = -EFAULT;
-			goto out;
-		}
-		bytes -= copy;
-
-		++tls_ctx->pending_open_record_frags;
-
-		if (!bytes)
-			break;
-	}
-
-out:
-	return rc;
-}
-
-int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
-{
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	int ret;
-	int required_size;
-	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
-	bool eor = !(msg->msg_flags & MSG_MORE);
-	size_t try_to_copy, copied = 0;
-	unsigned char record_type = TLS_RECORD_TYPE_DATA;
-	int record_room;
-	bool full_record;
-	int orig_size;
-	bool is_kvec = msg->msg_iter.type & ITER_KVEC;
-
-	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
-		return -ENOTSUPP;
-
-	lock_sock(sk);
-
-	ret = tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo);
-	if (ret)
-		goto send_end;
-
-	if (unlikely(msg->msg_controllen)) {
-		ret = tls_proccess_cmsg(sk, msg, &record_type);
-		if (ret)
-			goto send_end;
-	}
-
-	while (msg_data_left(msg)) {
-		if (sk->sk_err) {
-			ret = -sk->sk_err;
-			goto send_end;
-		}
-
-		orig_size = ctx->sg_plaintext_size;
-		full_record = false;
-		try_to_copy = msg_data_left(msg);
-		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
-		if (try_to_copy >= record_room) {
-			try_to_copy = record_room;
-			full_record = true;
-		}
-
-		required_size = ctx->sg_plaintext_size + try_to_copy +
-				tls_ctx->tx.overhead_size;
-
-		if (!sk_stream_memory_free(sk))
-			goto wait_for_sndbuf;
-alloc_encrypted:
-		ret = alloc_encrypted_sg(sk, required_size);
-		if (ret) {
-			if (ret != -ENOSPC)
-				goto wait_for_memory;
-
-			/* Adjust try_to_copy according to the amount that was
-			 * actually allocated. The difference is due
-			 * to max sg elements limit
-			 */
-			try_to_copy -= required_size - ctx->sg_encrypted_size;
-			full_record = true;
-		}
-		if (!is_kvec && (full_record || eor)) {
-			ret = zerocopy_from_iter(sk, &msg->msg_iter,
-				try_to_copy, &ctx->sg_plaintext_num_elem,
-				&ctx->sg_plaintext_size,
-				ctx->sg_plaintext_data,
-				ARRAY_SIZE(ctx->sg_plaintext_data),
-				true);
-			if (ret)
-				goto fallback_to_reg_send;
-
-			copied += try_to_copy;
-			ret = tls_push_record(sk, msg->msg_flags, record_type);
-			if (ret)
-				goto send_end;
-			continue;
-
-fallback_to_reg_send:
-			trim_sg(sk, ctx->sg_plaintext_data,
-				&ctx->sg_plaintext_num_elem,
-				&ctx->sg_plaintext_size,
-				orig_size);
-		}
-
-		required_size = ctx->sg_plaintext_size + try_to_copy;
-alloc_plaintext:
-		ret = alloc_plaintext_sg(sk, required_size);
-		if (ret) {
-			if (ret != -ENOSPC)
-				goto wait_for_memory;
-
-			/* Adjust try_to_copy according to the amount that was
-			 * actually allocated. The difference is due
-			 * to max sg elements limit
-			 */
-			try_to_copy -= required_size - ctx->sg_plaintext_size;
-			full_record = true;
-
-			trim_sg(sk, ctx->sg_encrypted_data,
-				&ctx->sg_encrypted_num_elem,
-				&ctx->sg_encrypted_size,
-				ctx->sg_plaintext_size +
-				tls_ctx->tx.overhead_size);
-		}
-
-		ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
-		if (ret)
-			goto trim_sgl;
-
-		copied += try_to_copy;
-		if (full_record || eor) {
-push_record:
-			ret = tls_push_record(sk, msg->msg_flags, record_type);
-			if (ret) {
-				if (ret == -ENOMEM)
-					goto wait_for_memory;
-
-				goto send_end;
-			}
-		}
-
-		continue;
-
-wait_for_sndbuf:
-		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
-wait_for_memory:
-		ret = sk_stream_wait_memory(sk, &timeo);
-		if (ret) {
-trim_sgl:
-			trim_both_sgl(sk, orig_size);
-			goto send_end;
-		}
-
-		if (tls_is_pending_closed_record(tls_ctx))
-			goto push_record;
-
-		if (ctx->sg_encrypted_size < required_size)
-			goto alloc_encrypted;
-
-		goto alloc_plaintext;
-	}
-
-send_end:
-	ret = sk_stream_error(sk, msg->msg_flags, ret);
-
-	release_sock(sk);
-	return copied ? copied : ret;
-}
-
-int tls_sw_sendpage(struct sock *sk, struct page *page,
-		    int offset, size_t size, int flags)
-{
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	int ret;
-	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
-	bool eor;
-	size_t orig_size = size;
-	unsigned char record_type = TLS_RECORD_TYPE_DATA;
-	struct scatterlist *sg;
-	bool full_record;
-	int record_room;
-
-	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
-		      MSG_SENDPAGE_NOTLAST))
-		return -ENOTSUPP;
-
-	/* No MSG_EOR from splice, only look at MSG_MORE */
-	eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
-
-	lock_sock(sk);
-
-	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
-
-	ret = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
-	if (ret)
-		goto sendpage_end;
-
-	/* Call the sk_stream functions to manage the sndbuf mem. */
-	while (size > 0) {
-		size_t copy, required_size;
-
-		if (sk->sk_err) {
-			ret = -sk->sk_err;
-			goto sendpage_end;
-		}
-
-		full_record = false;
-		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
-		copy = size;
-		if (copy >= record_room) {
-			copy = record_room;
-			full_record = true;
-		}
-		required_size = ctx->sg_plaintext_size + copy +
-			      tls_ctx->tx.overhead_size;
-
-		if (!sk_stream_memory_free(sk))
-			goto wait_for_sndbuf;
-alloc_payload:
-		ret = alloc_encrypted_sg(sk, required_size);
-		if (ret) {
-			if (ret != -ENOSPC)
-				goto wait_for_memory;
-
-			/* Adjust copy according to the amount that was
-			 * actually allocated. The difference is due
-			 * to max sg elements limit
-			 */
-			copy -= required_size - ctx->sg_plaintext_size;
-			full_record = true;
-		}
-
-		get_page(page);
-		sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem;
-		sg_set_page(sg, page, copy, offset);
-		sg_unmark_end(sg);
-
-		ctx->sg_plaintext_num_elem++;
-
-		sk_mem_charge(sk, copy);
-		offset += copy;
-		size -= copy;
-		ctx->sg_plaintext_size += copy;
-		tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem;
-
-		if (full_record || eor ||
-		    ctx->sg_plaintext_num_elem ==
-		    ARRAY_SIZE(ctx->sg_plaintext_data)) {
-push_record:
-			ret = tls_push_record(sk, flags, record_type);
-			if (ret) {
-				if (ret == -ENOMEM)
-					goto wait_for_memory;
-
-				goto sendpage_end;
-			}
-		}
-		continue;
-wait_for_sndbuf:
-		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
-wait_for_memory:
-		ret = sk_stream_wait_memory(sk, &timeo);
-		if (ret) {
-			trim_both_sgl(sk, ctx->sg_plaintext_size);
-			goto sendpage_end;
-		}
-
-		if (tls_is_pending_closed_record(tls_ctx))
-			goto push_record;
-
-		goto alloc_payload;
-	}
-
-sendpage_end:
-	if (orig_size > size)
-		ret = orig_size - size;
-	else
-		ret = sk_stream_error(sk, flags, ret);
-
-	release_sock(sk);
-	return ret;
-}
-
-static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
-				     long timeo, int *err)
-{
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-	struct sk_buff *skb;
-	DEFINE_WAIT_FUNC(wait, woken_wake_function);
-
-	while (!(skb = ctx->recv_pkt)) {
-		if (sk->sk_err) {
-			*err = sock_error(sk);
-			return NULL;
-		}
-
-		if (!skb_queue_empty(&sk->sk_receive_queue)) {
-			__strp_unpause(&ctx->strp);
-			if (ctx->recv_pkt)
-				return ctx->recv_pkt;
-		}
-
-		if (sk->sk_shutdown & RCV_SHUTDOWN)
-			return NULL;
-
-		if (sock_flag(sk, SOCK_DONE))
-			return NULL;
-
-		if ((flags & MSG_DONTWAIT) || !timeo) {
-			*err = -EAGAIN;
-			return NULL;
-		}
-
-		add_wait_queue(sk_sleep(sk), &wait);
-		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-		sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
-		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-		remove_wait_queue(sk_sleep(sk), &wait);
-
-		/* Handle signals */
-		if (signal_pending(current)) {
-			*err = sock_intr_errno(timeo);
-			return NULL;
-		}
-	}
-
-	return skb;
 }
 
 /* This function decrypts the input skb into either out_iov or in out_sg
@@ -674,10 +1414,11 @@
 static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 			    struct iov_iter *out_iov,
 			    struct scatterlist *out_sg,
-			    int *chunk, bool *zc)
+			    int *chunk, bool *zc, bool async)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct strp_msg *rxm = strp_msg(skb);
 	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
 	struct aead_request *aead_req;
@@ -685,19 +1426,23 @@
 	u8 *aad, *iv, *mem = NULL;
 	struct scatterlist *sgin = NULL;
 	struct scatterlist *sgout = NULL;
-	const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;
+	const int data_len = rxm->full_len - prot->overhead_size +
+			     prot->tail_size;
+	int iv_offset = 0;
 
 	if (*zc && (out_iov || out_sg)) {
 		if (out_iov)
 			n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
 		else
 			n_sgout = sg_nents(out_sg);
+		n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
+				 rxm->full_len - prot->prepend_size);
 	} else {
 		n_sgout = 0;
 		*zc = false;
+		n_sgin = skb_cow_data(skb, 0, &unused);
 	}
 
-	n_sgin = skb_cow_data(skb, 0, &unused);
 	if (n_sgin < 1)
 		return -EBADMSG;
 
@@ -708,7 +1453,7 @@
 
 	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
 	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
-	mem_size = mem_size + TLS_AAD_SPACE_SIZE;
+	mem_size = mem_size + prot->aad_size;
 	mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
 
 	/* Allocate a single block of memory which contains
@@ -724,29 +1469,42 @@
 	sgin = (struct scatterlist *)(mem + aead_size);
 	sgout = sgin + n_sgin;
 	aad = (u8 *)(sgout + n_sgout);
-	iv = aad + TLS_AAD_SPACE_SIZE;
+	iv = aad + prot->aad_size;
+
+	/* For CCM based ciphers, first byte of nonce+iv is always '2' */
+	if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
+		iv[0] = 2;
+		iv_offset = 1;
+	}
 
 	/* Prepare IV */
 	err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
-			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
-			    tls_ctx->rx.iv_size);
+			    iv + iv_offset + prot->salt_size,
+			    prot->iv_size);
 	if (err < 0) {
 		kfree(mem);
 		return err;
 	}
-	memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+	if (prot->version == TLS_1_3_VERSION)
+		memcpy(iv + iv_offset, tls_ctx->rx.iv,
+		       prot->iv_size + prot->salt_size);
+	else
+		memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
+
+	xor_iv_with_seq(prot->version, iv + iv_offset, tls_ctx->rx.rec_seq);
 
 	/* Prepare AAD */
-	tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
-		     tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
-		     ctx->control);
+	tls_make_aad(aad, rxm->full_len - prot->overhead_size +
+		     prot->tail_size,
+		     tls_ctx->rx.rec_seq, prot->rec_seq_size,
+		     ctx->control, prot->version);
 
 	/* Prepare sgin */
 	sg_init_table(sgin, n_sgin);
-	sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
+	sg_set_buf(&sgin[0], aad, prot->aad_size);
 	err = skb_to_sgvec(skb, &sgin[1],
-			   rxm->offset + tls_ctx->rx.prepend_size,
-			   rxm->full_len - tls_ctx->rx.prepend_size);
+			   rxm->offset + prot->prepend_size,
+			   rxm->full_len - prot->prepend_size);
 	if (err < 0) {
 		kfree(mem);
 		return err;
@@ -755,12 +1513,12 @@
 	if (n_sgout) {
 		if (out_iov) {
 			sg_init_table(sgout, n_sgout);
-			sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);
+			sg_set_buf(&sgout[0], aad, prot->aad_size);
 
 			*chunk = 0;
-			err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
-						 chunk, &sgout[1],
-						 (n_sgout - 1), false);
+			err = tls_setup_from_iter(sk, out_iov, data_len,
+						  &pages, chunk, &sgout[1],
+						  (n_sgout - 1));
 			if (err < 0)
 				goto fallback_to_reg_recv;
 		} else if (out_sg) {
@@ -772,12 +1530,15 @@
 fallback_to_reg_recv:
 		sgout = sgin;
 		pages = 0;
-		*chunk = 0;
+		*chunk = data_len;
 		*zc = false;
 	}
 
 	/* Prepare and submit AEAD request */
-	err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req);
+	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
+				data_len, aead_req, async);
+	if (err == -EINPROGRESS)
+		return err;
 
 	/* Release the pages in case iov was mapped to pages */
 	for (; pages > 0; pages--)
@@ -788,31 +1549,52 @@
 }
 
 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
-			      struct iov_iter *dest, int *chunk, bool *zc)
+			      struct iov_iter *dest, int *chunk, bool *zc,
+			      bool async)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct strp_msg *rxm = strp_msg(skb);
-	int err = 0;
+	int pad, err = 0;
 
-#ifdef CONFIG_TLS_DEVICE
-	err = tls_device_decrypted(sk, skb);
-	if (err < 0)
-		return err;
-#endif
 	if (!ctx->decrypted) {
-		err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
-		if (err < 0)
-			return err;
+		if (tls_ctx->rx_conf == TLS_HW) {
+			err = tls_device_decrypted(sk, tls_ctx, skb, rxm);
+			if (err < 0)
+				return err;
+		}
+
+		/* Still not decrypted after tls_device */
+		if (!ctx->decrypted) {
+			err = decrypt_internal(sk, skb, dest, NULL, chunk, zc,
+					       async);
+			if (err < 0) {
+				if (err == -EINPROGRESS)
+					tls_advance_record_sn(sk, prot,
+							      &tls_ctx->rx);
+				else if (err == -EBADMSG)
+					TLS_INC_STATS(sock_net(sk),
+						      LINUX_MIB_TLSDECRYPTERROR);
+				return err;
+			}
+		} else {
+			*zc = false;
+		}
+
+		pad = padding_length(ctx, prot, skb);
+		if (pad < 0)
+			return pad;
+
+		rxm->full_len -= pad;
+		rxm->offset += prot->prepend_size;
+		rxm->full_len -= prot->overhead_size;
+		tls_advance_record_sn(sk, prot, &tls_ctx->rx);
+		ctx->decrypted = 1;
+		ctx->saved_data_ready(sk);
 	} else {
 		*zc = false;
 	}
-
-	rxm->offset += tls_ctx->rx.prepend_size;
-	rxm->full_len -= tls_ctx->rx.overhead_size;
-	tls_advance_record_sn(sk, &tls_ctx->rx);
-	ctx->decrypted = true;
-	ctx->saved_data_ready(sk);
 
 	return err;
 }
@@ -823,7 +1605,7 @@
 	bool zc = true;
 	int chunk;
 
-	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
+	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
 }
 
 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
@@ -831,21 +1613,132 @@
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-	struct strp_msg *rxm = strp_msg(skb);
 
-	if (len < rxm->full_len) {
-		rxm->offset += len;
-		rxm->full_len -= len;
+	if (skb) {
+		struct strp_msg *rxm = strp_msg(skb);
 
-		return false;
+		if (len < rxm->full_len) {
+			rxm->offset += len;
+			rxm->full_len -= len;
+			return false;
+		}
+		consume_skb(skb);
 	}
 
 	/* Finished with message */
 	ctx->recv_pkt = NULL;
-	kfree_skb(skb);
 	__strp_unpause(&ctx->strp);
 
 	return true;
+}
+
+/* This function traverses the rx_list in tls receive context to copies the
+ * decrypted records into the buffer provided by caller zero copy is not
+ * true. Further, the records are removed from the rx_list if it is not a peek
+ * case and the record has been consumed completely.
+ */
+static int process_rx_list(struct tls_sw_context_rx *ctx,
+			   struct msghdr *msg,
+			   u8 *control,
+			   bool *cmsg,
+			   size_t skip,
+			   size_t len,
+			   bool zc,
+			   bool is_peek)
+{
+	struct sk_buff *skb = skb_peek(&ctx->rx_list);
+	u8 ctrl = *control;
+	u8 msgc = *cmsg;
+	struct tls_msg *tlm;
+	ssize_t copied = 0;
+
+	/* Set the record type in 'control' if caller didn't pass it */
+	if (!ctrl && skb) {
+		tlm = tls_msg(skb);
+		ctrl = tlm->control;
+	}
+
+	while (skip && skb) {
+		struct strp_msg *rxm = strp_msg(skb);
+		tlm = tls_msg(skb);
+
+		/* Cannot process a record of different type */
+		if (ctrl != tlm->control)
+			return 0;
+
+		if (skip < rxm->full_len)
+			break;
+
+		skip = skip - rxm->full_len;
+		skb = skb_peek_next(skb, &ctx->rx_list);
+	}
+
+	while (len && skb) {
+		struct sk_buff *next_skb;
+		struct strp_msg *rxm = strp_msg(skb);
+		int chunk = min_t(unsigned int, rxm->full_len - skip, len);
+
+		tlm = tls_msg(skb);
+
+		/* Cannot process a record of different type */
+		if (ctrl != tlm->control)
+			return 0;
+
+		/* Set record type if not already done. For a non-data record,
+		 * do not proceed if record type could not be copied.
+		 */
+		if (!msgc) {
+			int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
+					    sizeof(ctrl), &ctrl);
+			msgc = true;
+			if (ctrl != TLS_RECORD_TYPE_DATA) {
+				if (cerr || msg->msg_flags & MSG_CTRUNC)
+					return -EIO;
+
+				*cmsg = msgc;
+			}
+		}
+
+		if (!zc || (rxm->full_len - skip) > len) {
+			int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
+						    msg, chunk);
+			if (err < 0)
+				return err;
+		}
+
+		len = len - chunk;
+		copied = copied + chunk;
+
+		/* Consume the data from record if it is non-peek case*/
+		if (!is_peek) {
+			rxm->offset = rxm->offset + chunk;
+			rxm->full_len = rxm->full_len - chunk;
+
+			/* Return if there is unconsumed data in the record */
+			if (rxm->full_len - skip)
+				break;
+		}
+
+		/* The remaining skip-bytes must lie in 1st record in rx_list.
+		 * So from the 2nd record, 'skip' should be 0.
+		 */
+		skip = 0;
+
+		if (msg)
+			msg->msg_flags |= MSG_EOR;
+
+		next_skb = skb_peek_next(skb, &ctx->rx_list);
+
+		if (!is_peek) {
+			skb_unlink(skb, &ctx->rx_list);
+			consume_skb(skb);
+		}
+
+		skb = next_skb;
+	}
+
+	*control = ctrl;
+	return copied;
 }
 
 int tls_sw_recvmsg(struct sock *sk,
@@ -857,104 +1750,241 @@
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-	unsigned char control;
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct sk_psock *psock;
+	unsigned char control = 0;
+	ssize_t decrypted = 0;
 	struct strp_msg *rxm;
+	struct tls_msg *tlm;
 	struct sk_buff *skb;
 	ssize_t copied = 0;
 	bool cmsg = false;
 	int target, err = 0;
 	long timeo;
-	bool is_kvec = msg->msg_iter.type & ITER_KVEC;
+	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
+	bool is_peek = flags & MSG_PEEK;
+	bool bpf_strp_enabled;
+	int num_async = 0;
+	int pending;
 
 	flags |= nonblock;
 
 	if (unlikely(flags & MSG_ERRQUEUE))
 		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
 
+	psock = sk_psock_get(sk);
 	lock_sock(sk);
+	bpf_strp_enabled = sk_psock_strp_enabled(psock);
+
+	/* Process pending decrypted records. It must be non-zero-copy */
+	err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false,
+			      is_peek);
+	if (err < 0) {
+		tls_err_abort(sk, err);
+		goto end;
+	} else {
+		copied = err;
+	}
+
+	if (len <= copied)
+		goto recv_end;
 
 	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
+	len = len - copied;
 	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
-	do {
-		bool zc = false;
-		int chunk = 0;
 
-		skb = tls_wait_data(sk, flags, timeo, &err);
-		if (!skb)
+	while (len && (decrypted + copied < target || ctx->recv_pkt)) {
+		bool retain_skb = false;
+		bool zc = false;
+		int to_decrypt;
+		int chunk = 0;
+		bool async_capable;
+		bool async = false;
+
+		skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err);
+		if (!skb) {
+			if (psock) {
+				int ret = __tcp_bpf_recvmsg(sk, psock,
+							    msg, len, flags);
+
+				if (ret > 0) {
+					decrypted += ret;
+					len -= ret;
+					continue;
+				}
+			}
 			goto recv_end;
+		} else {
+			tlm = tls_msg(skb);
+			if (prot->version == TLS_1_3_VERSION)
+				tlm->control = 0;
+			else
+				tlm->control = ctx->control;
+		}
 
 		rxm = strp_msg(skb);
+
+		to_decrypt = rxm->full_len - prot->overhead_size;
+
+		if (to_decrypt <= len && !is_kvec && !is_peek &&
+		    ctx->control == TLS_RECORD_TYPE_DATA &&
+		    prot->version != TLS_1_3_VERSION &&
+		    !bpf_strp_enabled)
+			zc = true;
+
+		/* Do not use async mode if record is non-data */
+		if (ctx->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
+			async_capable = ctx->async_capable;
+		else
+			async_capable = false;
+
+		err = decrypt_skb_update(sk, skb, &msg->msg_iter,
+					 &chunk, &zc, async_capable);
+		if (err < 0 && err != -EINPROGRESS) {
+			tls_err_abort(sk, -EBADMSG);
+			goto recv_end;
+		}
+
+		if (err == -EINPROGRESS) {
+			async = true;
+			num_async++;
+		} else if (prot->version == TLS_1_3_VERSION) {
+			tlm->control = ctx->control;
+		}
+
+		/* If the type of records being processed is not known yet,
+		 * set it to record type just dequeued. If it is already known,
+		 * but does not match the record type just dequeued, go to end.
+		 * We always get record type here since for tls1.2, record type
+		 * is known just after record is dequeued from stream parser.
+		 * For tls1.3, we disable async.
+		 */
+
+		if (!control)
+			control = tlm->control;
+		else if (control != tlm->control)
+			goto recv_end;
+
 		if (!cmsg) {
 			int cerr;
 
 			cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
-					sizeof(ctx->control), &ctx->control);
+					sizeof(control), &control);
 			cmsg = true;
-			control = ctx->control;
-			if (ctx->control != TLS_RECORD_TYPE_DATA) {
+			if (control != TLS_RECORD_TYPE_DATA) {
 				if (cerr || msg->msg_flags & MSG_CTRUNC) {
 					err = -EIO;
 					goto recv_end;
 				}
 			}
-		} else if (control != ctx->control) {
-			goto recv_end;
 		}
 
-		if (!ctx->decrypted) {
-			int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
-
-			if (!is_kvec && to_copy <= len &&
-			    likely(!(flags & MSG_PEEK)))
-				zc = true;
-
-			err = decrypt_skb_update(sk, skb, &msg->msg_iter,
-						 &chunk, &zc);
-			if (err < 0) {
-				tls_err_abort(sk, EBADMSG);
-				goto recv_end;
-			}
-			ctx->decrypted = true;
-		}
+		if (async)
+			goto pick_next_record;
 
 		if (!zc) {
-			chunk = min_t(unsigned int, rxm->full_len, len);
-			err = skb_copy_datagram_msg(skb, rxm->offset, msg,
-						    chunk);
+			if (bpf_strp_enabled) {
+				err = sk_psock_tls_strp_read(psock, skb);
+				if (err != __SK_PASS) {
+					rxm->offset = rxm->offset + rxm->full_len;
+					rxm->full_len = 0;
+					if (err == __SK_DROP)
+						consume_skb(skb);
+					ctx->recv_pkt = NULL;
+					__strp_unpause(&ctx->strp);
+					continue;
+				}
+			}
+
+			if (rxm->full_len > len) {
+				retain_skb = true;
+				chunk = len;
+			} else {
+				chunk = rxm->full_len;
+			}
+
+			err = skb_copy_datagram_msg(skb, rxm->offset,
+						    msg, chunk);
 			if (err < 0)
 				goto recv_end;
-		}
 
-		copied += chunk;
-		len -= chunk;
-		if (likely(!(flags & MSG_PEEK))) {
-			u8 control = ctx->control;
-
-			if (tls_sw_advance_skb(sk, skb, chunk)) {
-				/* Return full control message to
-				 * userspace before trying to parse
-				 * another message type
-				 */
-				msg->msg_flags |= MSG_EOR;
-				if (control != TLS_RECORD_TYPE_DATA)
-					goto recv_end;
+			if (!is_peek) {
+				rxm->offset = rxm->offset + chunk;
+				rxm->full_len = rxm->full_len - chunk;
 			}
-		} else {
-			/* MSG_PEEK right now cannot look beyond current skb
-			 * from strparser, meaning we cannot advance skb here
-			 * and thus unpause strparser since we'd loose original
-			 * one.
-			 */
-			break;
 		}
 
-		/* If we have a new message from strparser, continue now. */
-		if (copied >= target && !ctx->recv_pkt)
+pick_next_record:
+		if (chunk > len)
+			chunk = len;
+
+		decrypted += chunk;
+		len -= chunk;
+
+		/* For async or peek case, queue the current skb */
+		if (async || is_peek || retain_skb) {
+			skb_queue_tail(&ctx->rx_list, skb);
+			skb = NULL;
+		}
+
+		if (tls_sw_advance_skb(sk, skb, chunk)) {
+			/* Return full control message to
+			 * userspace before trying to parse
+			 * another message type
+			 */
+			msg->msg_flags |= MSG_EOR;
+			if (control != TLS_RECORD_TYPE_DATA)
+				goto recv_end;
+		} else {
 			break;
-	} while (len);
+		}
+	}
 
 recv_end:
+	if (num_async) {
+		/* Wait for all previously submitted records to be decrypted */
+		spin_lock_bh(&ctx->decrypt_compl_lock);
+		ctx->async_notify = true;
+		pending = atomic_read(&ctx->decrypt_pending);
+		spin_unlock_bh(&ctx->decrypt_compl_lock);
+		if (pending) {
+			err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+			if (err) {
+				/* one of async decrypt failed */
+				tls_err_abort(sk, err);
+				copied = 0;
+				decrypted = 0;
+				goto end;
+			}
+		} else {
+			reinit_completion(&ctx->async_wait.completion);
+		}
+
+		/* There can be no concurrent accesses, since we have no
+		 * pending decrypt operations
+		 */
+		WRITE_ONCE(ctx->async_notify, false);
+
+		/* Drain records from the rx_list & copy if required */
+		if (is_peek || is_kvec)
+			err = process_rx_list(ctx, msg, &control, &cmsg, copied,
+					      decrypted, false, is_peek);
+		else
+			err = process_rx_list(ctx, msg, &control, &cmsg, 0,
+					      decrypted, true, is_peek);
+		if (err < 0) {
+			tls_err_abort(sk, err);
+			copied = 0;
+			goto end;
+		}
+	}
+
+	copied += decrypted;
+
+end:
 	release_sock(sk);
+	if (psock)
+		sk_psock_put(sk, psock);
 	return copied ? : err;
 }
 
@@ -975,27 +2005,24 @@
 
 	lock_sock(sk);
 
-	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+	timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK);
 
-	skb = tls_wait_data(sk, flags, timeo, &err);
+	skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo, &err);
 	if (!skb)
 		goto splice_read_end;
 
-	/* splice does not support reading control messages */
-	if (ctx->control != TLS_RECORD_TYPE_DATA) {
-		err = -ENOTSUPP;
+	err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
+	if (err < 0) {
+		tls_err_abort(sk, -EBADMSG);
 		goto splice_read_end;
 	}
 
-	if (!ctx->decrypted) {
-		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
-
-		if (err < 0) {
-			tls_err_abort(sk, EBADMSG);
-			goto splice_read_end;
-		}
-		ctx->decrypted = true;
+	/* splice does not support reading control messages */
+	if (ctx->control != TLS_RECORD_TYPE_DATA) {
+		err = -EINVAL;
+		goto splice_read_end;
 	}
+
 	rxm = strp_msg(skb);
 
 	chunk = min_t(unsigned int, rxm->full_len, len);
@@ -1011,29 +2038,28 @@
 	return copied ? : err;
 }
 
-unsigned int tls_sw_poll(struct file *file, struct socket *sock,
-			 struct poll_table_struct *wait)
+bool tls_sw_stream_read(const struct sock *sk)
 {
-	unsigned int ret;
-	struct sock *sk = sock->sk;
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	bool ingress_empty = true;
+	struct sk_psock *psock;
 
-	/* Grab POLLOUT and POLLHUP from the underlying socket */
-	ret = ctx->sk_poll(file, sock, wait);
+	rcu_read_lock();
+	psock = sk_psock(sk);
+	if (psock)
+		ingress_empty = list_empty(&psock->ingress_msg);
+	rcu_read_unlock();
 
-	/* Clear POLLIN bits, and set based on recv_pkt */
-	ret &= ~(POLLIN | POLLRDNORM);
-	if (ctx->recv_pkt)
-		ret |= POLLIN | POLLRDNORM;
-
-	return ret;
+	return !ingress_empty || ctx->recv_pkt ||
+		!skb_queue_empty(&ctx->rx_list);
 }
 
 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
 	struct strp_msg *rxm = strp_msg(skb);
 	size_t cipher_overhead;
@@ -1041,17 +2067,17 @@
 	int ret;
 
 	/* Verify that we have a full TLS header, or wait for more data */
-	if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
+	if (rxm->offset + prot->prepend_size > skb->len)
 		return 0;
 
 	/* Sanity-check size of on-stack buffer. */
-	if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
+	if (WARN_ON(prot->prepend_size > sizeof(header))) {
 		ret = -EINVAL;
 		goto read_failure;
 	}
 
 	/* Linearize header to local buffer */
-	ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
+	ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
 
 	if (ret < 0)
 		goto read_failure;
@@ -1060,9 +2086,12 @@
 
 	data_len = ((header[4] & 0xFF) | (header[3] << 8));
 
-	cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
+	cipher_overhead = prot->tag_size;
+	if (prot->version != TLS_1_3_VERSION)
+		cipher_overhead += prot->iv_size;
 
-	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
+	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
+	    prot->tail_size) {
 		ret = -EMSGSIZE;
 		goto read_failure;
 	}
@@ -1071,16 +2100,15 @@
 		goto read_failure;
 	}
 
-	if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.info.version) ||
-	    header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.info.version)) {
+	/* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */
+	if (header[1] != TLS_1_2_VERSION_MINOR ||
+	    header[2] != TLS_1_2_VERSION_MAJOR) {
 		ret = -EINVAL;
 		goto read_failure;
 	}
 
-#ifdef CONFIG_TLS_DEVICE
-	handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
-			     *(u64*)tls_ctx->rx.rec_seq);
-#endif
+	tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
+				     TCP_SKB_CB(skb)->seq + rxm->offset);
 	return data_len + TLS_HEADER_SIZE;
 
 read_failure:
@@ -1094,7 +2122,7 @@
 	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 
-	ctx->decrypted = false;
+	ctx->decrypted = 0;
 
 	ctx->recv_pkt = skb;
 	strp_pause(strp);
@@ -1106,17 +2134,71 @@
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct sk_psock *psock;
 
 	strp_data_ready(&ctx->strp);
+
+	psock = sk_psock_get(sk);
+	if (psock) {
+		if (!list_empty(&psock->ingress_msg))
+			ctx->saved_data_ready(sk);
+		sk_psock_put(sk, psock);
+	}
 }
 
-void tls_sw_free_resources_tx(struct sock *sk)
+void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+
+	set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
+	set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
+	cancel_delayed_work_sync(&ctx->tx_work.work);
+}
+
+void tls_sw_release_resources_tx(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec, *tmp;
+	int pending;
+
+	/* Wait for any pending async encryptions to complete */
+	spin_lock_bh(&ctx->encrypt_compl_lock);
+	ctx->async_notify = true;
+	pending = atomic_read(&ctx->encrypt_pending);
+	spin_unlock_bh(&ctx->encrypt_compl_lock);
+
+	if (pending)
+		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+
+	tls_tx_records(sk, -1);
+
+	/* Free up un-sent records in tx_list. First, free
+	 * the partially sent record if any at head of tx_list.
+	 */
+	if (tls_ctx->partially_sent_record) {
+		tls_free_partial_record(sk, tls_ctx);
+		rec = list_first_entry(&ctx->tx_list,
+				       struct tls_rec, list);
+		list_del(&rec->list);
+		sk_msg_free(sk, &rec->msg_plaintext);
+		kfree(rec);
+	}
+
+	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
+		list_del(&rec->list);
+		sk_msg_free(sk, &rec->msg_encrypted);
+		sk_msg_free(sk, &rec->msg_plaintext);
+		kfree(rec);
+	}
 
 	crypto_free_aead(ctx->aead_send);
-	tls_free_both_sg(sk);
+	tls_free_open_rec(sk);
+}
+
+void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 
 	kfree(ctx);
 }
@@ -1132,38 +2214,116 @@
 	if (ctx->aead_recv) {
 		kfree_skb(ctx->recv_pkt);
 		ctx->recv_pkt = NULL;
+		skb_queue_purge(&ctx->rx_list);
 		crypto_free_aead(ctx->aead_recv);
 		strp_stop(&ctx->strp);
-		write_lock_bh(&sk->sk_callback_lock);
-		sk->sk_data_ready = ctx->saved_data_ready;
-		write_unlock_bh(&sk->sk_callback_lock);
-		release_sock(sk);
-		strp_done(&ctx->strp);
-		lock_sock(sk);
+		/* If tls_sw_strparser_arm() was not called (cleanup paths)
+		 * we still want to strp_stop(), but sk->sk_data_ready was
+		 * never swapped.
+		 */
+		if (ctx->saved_data_ready) {
+			write_lock_bh(&sk->sk_callback_lock);
+			sk->sk_data_ready = ctx->saved_data_ready;
+			write_unlock_bh(&sk->sk_callback_lock);
+		}
 	}
+}
+
+void tls_sw_strparser_done(struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+
+	strp_done(&ctx->strp);
+}
+
+void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+
+	kfree(ctx);
 }
 
 void tls_sw_free_resources_rx(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 
 	tls_sw_release_resources_rx(sk);
+	tls_sw_free_ctx_rx(tls_ctx);
+}
 
-	kfree(ctx);
+/* The work handler to transmitt the encrypted records in tx_list */
+static void tx_work_handler(struct work_struct *work)
+{
+	struct delayed_work *delayed_work = to_delayed_work(work);
+	struct tx_work *tx_work = container_of(delayed_work,
+					       struct tx_work, work);
+	struct sock *sk = tx_work->sk;
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_tx *ctx;
+
+	if (unlikely(!tls_ctx))
+		return;
+
+	ctx = tls_sw_ctx_tx(tls_ctx);
+	if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
+		return;
+
+	if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
+		return;
+
+	if (mutex_trylock(&tls_ctx->tx_lock)) {
+		lock_sock(sk);
+		tls_tx_records(sk, -1);
+		release_sock(sk);
+		mutex_unlock(&tls_ctx->tx_lock);
+	} else if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
+		/* Someone is holding the tx_lock, they will likely run Tx
+		 * and cancel the work on their way out of the lock section.
+		 * Schedule a long delay just in case.
+		 */
+		schedule_delayed_work(&ctx->tx_work.work, msecs_to_jiffies(10));
+	}
+}
+
+void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
+{
+	struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
+
+	/* Schedule the transmission if tx list is ready */
+	if (is_tx_ready(tx_ctx) &&
+	    !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
+		schedule_delayed_work(&tx_ctx->tx_work.work, 0);
+}
+
+void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
+
+	write_lock_bh(&sk->sk_callback_lock);
+	rx_ctx->saved_data_ready = sk->sk_data_ready;
+	sk->sk_data_ready = tls_data_ready;
+	write_unlock_bh(&sk->sk_callback_lock);
+
+	strp_check_rcv(&rx_ctx->strp);
 }
 
 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 {
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_crypto_info *crypto_info;
 	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
+	struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
+	struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
 	struct tls_sw_context_tx *sw_ctx_tx = NULL;
 	struct tls_sw_context_rx *sw_ctx_rx = NULL;
 	struct cipher_context *cctx;
 	struct crypto_aead **aead;
 	struct strp_callbacks cb;
-	u16 nonce_size, tag_size, iv_size, rec_seq_size;
-	char *iv, *rec_seq;
+	u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
+	struct crypto_tfm *tfm;
+	char *iv, *rec_seq, *key, *salt, *cipher_name;
+	size_t keysize;
 	int rc = 0;
 
 	if (!ctx) {
@@ -1199,13 +2359,19 @@
 
 	if (tx) {
 		crypto_init_wait(&sw_ctx_tx->async_wait);
+		spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
 		crypto_info = &ctx->crypto_send.info;
 		cctx = &ctx->tx;
 		aead = &sw_ctx_tx->aead_send;
+		INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
+		INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
+		sw_ctx_tx->tx_work.sk = sk;
 	} else {
 		crypto_init_wait(&sw_ctx_rx->async_wait);
+		spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
 		crypto_info = &ctx->crypto_recv.info;
 		cctx = &ctx->rx;
+		skb_queue_head_init(&sw_ctx_rx->rx_list);
 		aead = &sw_ctx_rx->aead_recv;
 	}
 
@@ -1220,6 +2386,45 @@
 		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
 		gcm_128_info =
 			(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
+		keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
+		key = gcm_128_info->key;
+		salt = gcm_128_info->salt;
+		salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
+		cipher_name = "gcm(aes)";
+		break;
+	}
+	case TLS_CIPHER_AES_GCM_256: {
+		nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
+		tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
+		iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
+		iv = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->iv;
+		rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
+		rec_seq =
+		 ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->rec_seq;
+		gcm_256_info =
+			(struct tls12_crypto_info_aes_gcm_256 *)crypto_info;
+		keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
+		key = gcm_256_info->key;
+		salt = gcm_256_info->salt;
+		salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
+		cipher_name = "gcm(aes)";
+		break;
+	}
+	case TLS_CIPHER_AES_CCM_128: {
+		nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
+		tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
+		iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
+		iv = ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->iv;
+		rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
+		rec_seq =
+		((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->rec_seq;
+		ccm_128_info =
+		(struct tls12_crypto_info_aes_ccm_128 *)crypto_info;
+		keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
+		key = ccm_128_info->key;
+		salt = ccm_128_info->salt;
+		salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
+		cipher_name = "ccm(aes)";
 		break;
 	}
 	default:
@@ -1227,53 +2432,47 @@
 		goto free_priv;
 	}
 
-	/* Sanity-check the IV size for stack allocations. */
-	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) {
+	/* Sanity-check the sizes for stack allocations. */
+	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
+	    rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
 		rc = -EINVAL;
 		goto free_priv;
 	}
 
-	cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
-	cctx->tag_size = tag_size;
-	cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
-	cctx->iv_size = iv_size;
-	cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
-			   GFP_KERNEL);
+	if (crypto_info->version == TLS_1_3_VERSION) {
+		nonce_size = 0;
+		prot->aad_size = TLS_HEADER_SIZE;
+		prot->tail_size = 1;
+	} else {
+		prot->aad_size = TLS_AAD_SPACE_SIZE;
+		prot->tail_size = 0;
+	}
+
+	prot->version = crypto_info->version;
+	prot->cipher_type = crypto_info->cipher_type;
+	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
+	prot->tag_size = tag_size;
+	prot->overhead_size = prot->prepend_size +
+			      prot->tag_size + prot->tail_size;
+	prot->iv_size = iv_size;
+	prot->salt_size = salt_size;
+	cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
 	if (!cctx->iv) {
 		rc = -ENOMEM;
 		goto free_priv;
 	}
-	memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
-	memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
-	cctx->rec_seq_size = rec_seq_size;
+	/* Note: 128 & 256 bit salt are the same size */
+	prot->rec_seq_size = rec_seq_size;
+	memcpy(cctx->iv, salt, salt_size);
+	memcpy(cctx->iv + salt_size, iv, iv_size);
 	cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
 	if (!cctx->rec_seq) {
 		rc = -ENOMEM;
 		goto free_iv;
 	}
 
-	if (sw_ctx_tx) {
-		sg_init_table(sw_ctx_tx->sg_encrypted_data,
-			      ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data));
-		sg_init_table(sw_ctx_tx->sg_plaintext_data,
-			      ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data));
-
-		sg_init_table(sw_ctx_tx->sg_aead_in, 2);
-		sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space,
-			   sizeof(sw_ctx_tx->aad_space));
-		sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]);
-		sg_chain(sw_ctx_tx->sg_aead_in, 2,
-			 sw_ctx_tx->sg_plaintext_data);
-		sg_init_table(sw_ctx_tx->sg_aead_out, 2);
-		sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space,
-			   sizeof(sw_ctx_tx->aad_space));
-		sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]);
-		sg_chain(sw_ctx_tx->sg_aead_out, 2,
-			 sw_ctx_tx->sg_encrypted_data);
-	}
-
 	if (!*aead) {
-		*aead = crypto_alloc_aead("gcm(aes)", 0, 0);
+		*aead = crypto_alloc_aead(cipher_name, 0, 0);
 		if (IS_ERR(*aead)) {
 			rc = PTR_ERR(*aead);
 			*aead = NULL;
@@ -1283,31 +2482,31 @@
 
 	ctx->push_pending_record = tls_sw_push_pending_record;
 
-	rc = crypto_aead_setkey(*aead, gcm_128_info->key,
-				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+	rc = crypto_aead_setkey(*aead, key, keysize);
+
 	if (rc)
 		goto free_aead;
 
-	rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
+	rc = crypto_aead_setauthsize(*aead, prot->tag_size);
 	if (rc)
 		goto free_aead;
 
 	if (sw_ctx_rx) {
+		tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
+
+		if (crypto_info->version == TLS_1_3_VERSION)
+			sw_ctx_rx->async_capable = 0;
+		else
+			sw_ctx_rx->async_capable =
+				!!(tfm->__crt_alg->cra_flags &
+				   CRYPTO_ALG_ASYNC);
+
 		/* Set up strparser */
 		memset(&cb, 0, sizeof(cb));
 		cb.rcv_msg = tls_queue;
 		cb.parse_msg = tls_read_size;
 
 		strp_init(&sw_ctx_rx->strp, sk, &cb);
-
-		write_lock_bh(&sk->sk_callback_lock);
-		sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
-		sk->sk_data_ready = tls_data_ready;
-		write_unlock_bh(&sk->sk_callback_lock);
-
-		sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;
-
-		strp_check_rcv(&sw_ctx_rx->strp);
 	}
 
 	goto out;

--
Gitblit v1.6.2