From 04dd17822334871b23ea2862f7798fb0e0007777 Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Sat, 11 May 2024 08:53:19 +0000
Subject: [PATCH] change otg to host mode

---
 kernel/net/ipv4/esp4.c |  339 +++++++++++++++++++++++++++++++++++++++++++-------------
 1 files changed, 259 insertions(+), 80 deletions(-)

diff --git a/kernel/net/ipv4/esp4.c b/kernel/net/ipv4/esp4.c
index 0792a9e..a16d177 100644
--- a/kernel/net/ipv4/esp4.c
+++ b/kernel/net/ipv4/esp4.c
@@ -1,3 +1,4 @@
+// SPDX-License-Identifier: GPL-2.0-only
 #define pr_fmt(fmt) "IPsec: " fmt
 
 #include <crypto/aead.h>
@@ -17,6 +18,8 @@
 #include <net/icmp.h>
 #include <net/protocol.h>
 #include <net/udp.h>
+#include <net/tcp.h>
+#include <net/espintcp.h>
 
 #include <linux/highmem.h>
 
@@ -31,8 +34,6 @@
 };
 
 #define ESP_SKB_CB(__skb) ((struct esp_skb_cb *)&((__skb)->cb[0]))
-
-static u32 esp4_get_mtu(struct xfrm_state *x, int mtu);
 
 /*
  * Allocate an AEAD request structure with extra space for SG and IV.
@@ -118,6 +119,132 @@
 			put_page(sg_page(sg));
 }
 
+#ifdef CONFIG_INET_ESPINTCP
+struct esp_tcp_sk {
+	struct sock *sk;
+	struct rcu_head rcu;
+};
+
+static void esp_free_tcp_sk(struct rcu_head *head)
+{
+	struct esp_tcp_sk *esk = container_of(head, struct esp_tcp_sk, rcu);
+
+	sock_put(esk->sk);
+	kfree(esk);
+}
+
+static struct sock *esp_find_tcp_sk(struct xfrm_state *x)
+{
+	struct xfrm_encap_tmpl *encap = x->encap;
+	struct esp_tcp_sk *esk;
+	__be16 sport, dport;
+	struct sock *nsk;
+	struct sock *sk;
+
+	sk = rcu_dereference(x->encap_sk);
+	if (sk && sk->sk_state == TCP_ESTABLISHED)
+		return sk;
+
+	spin_lock_bh(&x->lock);
+	sport = encap->encap_sport;
+	dport = encap->encap_dport;
+	nsk = rcu_dereference_protected(x->encap_sk,
+					lockdep_is_held(&x->lock));
+	if (sk && sk == nsk) {
+		esk = kmalloc(sizeof(*esk), GFP_ATOMIC);
+		if (!esk) {
+			spin_unlock_bh(&x->lock);
+			return ERR_PTR(-ENOMEM);
+		}
+		RCU_INIT_POINTER(x->encap_sk, NULL);
+		esk->sk = sk;
+		call_rcu(&esk->rcu, esp_free_tcp_sk);
+	}
+	spin_unlock_bh(&x->lock);
+
+	sk = inet_lookup_established(xs_net(x), &tcp_hashinfo, x->id.daddr.a4,
+				     dport, x->props.saddr.a4, sport, 0);
+	if (!sk)
+		return ERR_PTR(-ENOENT);
+
+	if (!tcp_is_ulp_esp(sk)) {
+		sock_put(sk);
+		return ERR_PTR(-EINVAL);
+	}
+
+	spin_lock_bh(&x->lock);
+	nsk = rcu_dereference_protected(x->encap_sk,
+					lockdep_is_held(&x->lock));
+	if (encap->encap_sport != sport ||
+	    encap->encap_dport != dport) {
+		sock_put(sk);
+		sk = nsk ?: ERR_PTR(-EREMCHG);
+	} else if (sk == nsk) {
+		sock_put(sk);
+	} else {
+		rcu_assign_pointer(x->encap_sk, sk);
+	}
+	spin_unlock_bh(&x->lock);
+
+	return sk;
+}
+
+static int esp_output_tcp_finish(struct xfrm_state *x, struct sk_buff *skb)
+{
+	struct sock *sk;
+	int err;
+
+	rcu_read_lock();
+
+	sk = esp_find_tcp_sk(x);
+	err = PTR_ERR_OR_ZERO(sk);
+	if (err)
+		goto out;
+
+	bh_lock_sock(sk);
+	if (sock_owned_by_user(sk))
+		err = espintcp_queue_out(sk, skb);
+	else
+		err = espintcp_push_skb(sk, skb);
+	bh_unlock_sock(sk);
+
+out:
+	rcu_read_unlock();
+	return err;
+}
+
+static int esp_output_tcp_encap_cb(struct net *net, struct sock *sk,
+				   struct sk_buff *skb)
+{
+	struct dst_entry *dst = skb_dst(skb);
+	struct xfrm_state *x = dst->xfrm;
+
+	return esp_output_tcp_finish(x, skb);
+}
+
+static int esp_output_tail_tcp(struct xfrm_state *x, struct sk_buff *skb)
+{
+	int err;
+
+	local_bh_disable();
+	err = xfrm_trans_queue_net(xs_net(x), skb, esp_output_tcp_encap_cb);
+	local_bh_enable();
+
+	/* EINPROGRESS just happens to do the right thing.  It
+	 * actually means that the skb has been consumed and
+	 * isn't coming back.
+	 */
+	return err ?: -EINPROGRESS;
+}
+#else
+static int esp_output_tail_tcp(struct xfrm_state *x, struct sk_buff *skb)
+{
+	kfree_skb(skb);
+
+	return -EOPNOTSUPP;
+}
+#endif
+
 static void esp_output_done(struct crypto_async_request *base, int err)
 {
 	struct sk_buff *skb = base->data;
@@ -125,10 +252,13 @@
 	void *tmp;
 	struct xfrm_state *x;
 
-	if (xo && (xo->flags & XFRM_DEV_RESUME))
-		x = skb->sp->xvec[skb->sp->len - 1];
-	else
+	if (xo && (xo->flags & XFRM_DEV_RESUME)) {
+		struct sec_path *sp = skb_sec_path(skb);
+
+		x = sp->xvec[sp->len - 1];
+	} else {
 		x = skb_dst(skb)->xfrm;
+	}
 
 	tmp = ESP_SKB_CB(skb)->tmp;
 	esp_ssg_unref(x, tmp);
@@ -145,7 +275,11 @@
 		secpath_reset(skb);
 		xfrm_dev_resume(skb);
 	} else {
-		xfrm_output_resume(skb, err);
+		if (!err &&
+		    x->encap && x->encap->encap_type == TCP_ENCAP_ESPINTCP)
+			esp_output_tail_tcp(x, skb);
+		else
+			xfrm_output_resume(skb, err);
 	}
 }
 
@@ -207,31 +341,79 @@
 	esp_output_done(base, err);
 }
 
-static void esp_output_fill_trailer(u8 *tail, int tfclen, int plen, __u8 proto)
+static struct ip_esp_hdr *esp_output_udp_encap(struct sk_buff *skb,
+					       int encap_type,
+					       struct esp_info *esp,
+					       __be16 sport,
+					       __be16 dport)
 {
-	/* Fill padding... */
-	if (tfclen) {
-		memset(tail, 0, tfclen);
-		tail += tfclen;
-	}
-	do {
-		int i;
-		for (i = 0; i < plen - 2; i++)
-			tail[i] = i + 1;
-	} while (0);
-	tail[plen - 2] = plen - 2;
-	tail[plen - 1] = proto;
-}
-
-static int esp_output_udp_encap(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *esp)
-{
-	int encap_type;
 	struct udphdr *uh;
 	__be32 *udpdata32;
-	__be16 sport, dport;
-	struct xfrm_encap_tmpl *encap = x->encap;
-	struct ip_esp_hdr *esph = esp->esph;
 	unsigned int len;
+
+	len = skb->len + esp->tailen - skb_transport_offset(skb);
+	if (len + sizeof(struct iphdr) > IP_MAX_MTU)
+		return ERR_PTR(-EMSGSIZE);
+
+	uh = (struct udphdr *)esp->esph;
+	uh->source = sport;
+	uh->dest = dport;
+	uh->len = htons(len);
+	uh->check = 0;
+
+	*skb_mac_header(skb) = IPPROTO_UDP;
+
+	if (encap_type == UDP_ENCAP_ESPINUDP_NON_IKE) {
+		udpdata32 = (__be32 *)(uh + 1);
+		udpdata32[0] = udpdata32[1] = 0;
+		return (struct ip_esp_hdr *)(udpdata32 + 2);
+	}
+
+	return (struct ip_esp_hdr *)(uh + 1);
+}
+
+#ifdef CONFIG_INET_ESPINTCP
+static struct ip_esp_hdr *esp_output_tcp_encap(struct xfrm_state *x,
+						    struct sk_buff *skb,
+						    struct esp_info *esp)
+{
+	__be16 *lenp = (void *)esp->esph;
+	struct ip_esp_hdr *esph;
+	unsigned int len;
+	struct sock *sk;
+
+	len = skb->len + esp->tailen - skb_transport_offset(skb);
+	if (len > IP_MAX_MTU)
+		return ERR_PTR(-EMSGSIZE);
+
+	rcu_read_lock();
+	sk = esp_find_tcp_sk(x);
+	rcu_read_unlock();
+
+	if (IS_ERR(sk))
+		return ERR_CAST(sk);
+
+	*lenp = htons(len);
+	esph = (struct ip_esp_hdr *)(lenp + 1);
+
+	return esph;
+}
+#else
+static struct ip_esp_hdr *esp_output_tcp_encap(struct xfrm_state *x,
+						    struct sk_buff *skb,
+						    struct esp_info *esp)
+{
+	return ERR_PTR(-EOPNOTSUPP);
+}
+#endif
+
+static int esp_output_encap(struct xfrm_state *x, struct sk_buff *skb,
+			    struct esp_info *esp)
+{
+	struct xfrm_encap_tmpl *encap = x->encap;
+	struct ip_esp_hdr *esph;
+	__be16 sport, dport;
+	int encap_type;
 
 	spin_lock_bh(&x->lock);
 	sport = encap->encap_sport;
@@ -239,29 +421,20 @@
 	encap_type = encap->encap_type;
 	spin_unlock_bh(&x->lock);
 
-	len = skb->len + esp->tailen - skb_transport_offset(skb);
-	if (len + sizeof(struct iphdr) >= IP_MAX_MTU)
-		return -EMSGSIZE;
-
-	uh = (struct udphdr *)esph;
-	uh->source = sport;
-	uh->dest = dport;
-	uh->len = htons(len);
-	uh->check = 0;
-
 	switch (encap_type) {
 	default:
 	case UDP_ENCAP_ESPINUDP:
-		esph = (struct ip_esp_hdr *)(uh + 1);
-		break;
 	case UDP_ENCAP_ESPINUDP_NON_IKE:
-		udpdata32 = (__be32 *)(uh + 1);
-		udpdata32[0] = udpdata32[1] = 0;
-		esph = (struct ip_esp_hdr *)(udpdata32 + 2);
+		esph = esp_output_udp_encap(skb, encap_type, esp, sport, dport);
+		break;
+	case TCP_ENCAP_ESPINTCP:
+		esph = esp_output_tcp_encap(x, skb, esp);
 		break;
 	}
 
-	*skb_mac_header(skb) = IPPROTO_UDP;
+	if (IS_ERR(esph))
+		return PTR_ERR(esph);
+
 	esp->esph = esph;
 
 	return 0;
@@ -276,13 +449,17 @@
 	struct sk_buff *trailer;
 	int tailen = esp->tailen;
 
-	/* this is non-NULL only with UDP Encapsulation */
+	/* this is non-NULL only with TCP/UDP Encapsulation */
 	if (x->encap) {
-		int err = esp_output_udp_encap(x, skb, esp);
+		int err = esp_output_encap(x, skb, esp);
 
 		if (err < 0)
 			return err;
 	}
+
+	if (ALIGN(tailen, L1_CACHE_BYTES) > PAGE_SIZE ||
+	    ALIGN(skb->data_len, L1_CACHE_BYTES) > PAGE_SIZE)
+		goto cow;
 
 	if (!skb_cloned(skb)) {
 		if (tailen <= skb_tailroom(skb)) {
@@ -467,6 +644,9 @@
 	if (sg != dsg)
 		esp_ssg_unref(x, tmp);
 
+	if (!err && x->encap && x->encap->encap_type == TCP_ENCAP_ESPINTCP)
+		err = esp_output_tail_tcp(x, skb);
+
 error_free:
 	kfree(tmp);
 error:
@@ -497,7 +677,7 @@
 		struct xfrm_dst *dst = (struct xfrm_dst *)skb_dst(skb);
 		u32 padto;
 
-		padto = min(x->tfcpad, esp4_get_mtu(x, dst->child_mtu_cached));
+		padto = min(x->tfcpad, xfrm_state_mtu(x, dst->child_mtu_cached));
 		if (skb->len < padto)
 			esp.tfclen = padto - skb->len;
 	}
@@ -593,7 +773,23 @@
 
 	if (x->encap) {
 		struct xfrm_encap_tmpl *encap = x->encap;
+		struct tcphdr *th = (void *)(skb_network_header(skb) + ihl);
 		struct udphdr *uh = (void *)(skb_network_header(skb) + ihl);
+		__be16 source;
+
+		switch (x->encap->encap_type) {
+		case TCP_ENCAP_ESPINTCP:
+			source = th->source;
+			break;
+		case UDP_ENCAP_ESPINUDP:
+		case UDP_ENCAP_ESPINUDP_NON_IKE:
+			source = uh->source;
+			break;
+		default:
+			WARN_ON_ONCE(1);
+			err = -EINVAL;
+			goto out;
+		}
 
 		/*
 		 * 1) if the NAT-T peer's IP or port changed then
@@ -602,11 +798,11 @@
 		 *    SRC ports.
 		 */
 		if (iph->saddr != x->props.saddr.a4 ||
-		    uh->source != encap->encap_sport) {
+		    source != encap->encap_sport) {
 			xfrm_address_t ipaddr;
 
 			ipaddr.a4 = iph->saddr;
-			km_new_mapping(x, &ipaddr, uh->source);
+			km_new_mapping(x, &ipaddr, source);
 
 			/* XXX: perhaps add an extra
 			 * policy check here, to see
@@ -688,12 +884,11 @@
  */
 static int esp_input(struct xfrm_state *x, struct sk_buff *skb)
 {
-	struct ip_esp_hdr *esph;
 	struct crypto_aead *aead = x->data;
 	struct aead_request *req;
 	struct sk_buff *trailer;
 	int ivlen = crypto_aead_ivsize(aead);
-	int elen = skb->len - sizeof(*esph) - ivlen;
+	int elen = skb->len - sizeof(struct ip_esp_hdr) - ivlen;
 	int nfrags;
 	int assoclen;
 	int seqhilen;
@@ -703,13 +898,13 @@
 	struct scatterlist *sg;
 	int err = -EINVAL;
 
-	if (!pskb_may_pull(skb, sizeof(*esph) + ivlen))
+	if (!pskb_may_pull(skb, sizeof(struct ip_esp_hdr) + ivlen))
 		goto out;
 
 	if (elen <= 0)
 		goto out;
 
-	assoclen = sizeof(*esph);
+	assoclen = sizeof(struct ip_esp_hdr);
 	seqhilen = 0;
 
 	if (x->props.flags & XFRM_STATE_ESN) {
@@ -780,28 +975,6 @@
 	return err;
 }
 
-static u32 esp4_get_mtu(struct xfrm_state *x, int mtu)
-{
-	struct crypto_aead *aead = x->data;
-	u32 blksize = ALIGN(crypto_aead_blocksize(aead), 4);
-	unsigned int net_adj;
-
-	switch (x->props.mode) {
-	case XFRM_MODE_TRANSPORT:
-	case XFRM_MODE_BEET:
-		net_adj = sizeof(struct iphdr);
-		break;
-	case XFRM_MODE_TUNNEL:
-		net_adj = 0;
-		break;
-	default:
-		BUG();
-	}
-
-	return ((mtu - x->props.header_len - crypto_aead_authsize(aead) -
-		 net_adj) & ~(blksize - 1)) + net_adj - 2;
-}
-
 static int esp4_err(struct sk_buff *skb, u32 info)
 {
 	struct net *net = dev_net(skb->dev);
@@ -825,9 +998,9 @@
 		return 0;
 
 	if (icmp_hdr(skb)->type == ICMP_DEST_UNREACH)
-		ipv4_update_pmtu(skb, net, info, 0, 0, IPPROTO_ESP, 0);
+		ipv4_update_pmtu(skb, net, info, 0, IPPROTO_ESP);
 	else
-		ipv4_redirect(skb, net, 0, 0, IPPROTO_ESP, 0);
+		ipv4_redirect(skb, net, 0, IPPROTO_ESP);
 	xfrm_state_put(x);
 
 	return 0;
@@ -961,7 +1134,7 @@
 	err = crypto_aead_setkey(aead, key, keylen);
 
 free_key:
-	kfree(key);
+	kfree_sensitive(key);
 
 error:
 	return err;
@@ -1004,6 +1177,14 @@
 		case UDP_ENCAP_ESPINUDP_NON_IKE:
 			x->props.header_len += sizeof(struct udphdr) + 2 * sizeof(u32);
 			break;
+#ifdef CONFIG_INET_ESPINTCP
+		case TCP_ENCAP_ESPINTCP:
+			/* only the length field, TCP encap is done by
+			 * the socket
+			 */
+			x->props.header_len += 2;
+			break;
+#endif
 		}
 	}
 
@@ -1027,7 +1208,6 @@
 	.flags		= XFRM_TYPE_REPLAY_PROT,
 	.init_state	= esp_init_state,
 	.destructor	= esp_destroy,
-	.get_mtu	= esp4_get_mtu,
 	.input		= esp_input,
 	.output		= esp_output,
 };
@@ -1058,8 +1238,7 @@
 {
 	if (xfrm4_protocol_deregister(&esp4_protocol, IPPROTO_ESP) < 0)
 		pr_info("%s: can't remove protocol\n", __func__);
-	if (xfrm_unregister_type(&esp_type, AF_INET) < 0)
-		pr_info("%s: can't remove xfrm type\n", __func__);
+	xfrm_unregister_type(&esp_type, AF_INET);
 }
 
 module_init(esp4_init);

--
Gitblit v1.6.2