From cf4ce59b3b70238352c7f1729f0f7223214828ad Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Fri, 20 Sep 2024 01:46:19 +0000
Subject: [PATCH] rtl88x2CE_WiFi_linux add concurrent mode

---
 kernel/net/sctp/input.c |  169 +++++++++++++++++++++++++++++++++++--------------------
 1 files changed, 107 insertions(+), 62 deletions(-)

diff --git a/kernel/net/sctp/input.c b/kernel/net/sctp/input.c
index 64dc292..8f3aab6 100644
--- a/kernel/net/sctp/input.c
+++ b/kernel/net/sctp/input.c
@@ -1,3 +1,4 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
 /* SCTP kernel implementation
  * Copyright (c) 1999-2000 Cisco, Inc.
  * Copyright (c) 1999-2001 Motorola, Inc.
@@ -9,22 +10,6 @@
  * This file is part of the SCTP kernel implementation
  *
  * These functions handle all input from the IP layer into SCTP.
- *
- * This SCTP implementation is free software;
- * you can redistribute it and/or modify it under the terms of
- * the GNU General Public License as published by
- * the Free Software Foundation; either version 2, or (at your option)
- * any later version.
- *
- * This SCTP implementation is distributed in the hope that it
- * will be useful, but WITHOUT ANY WARRANTY; without even the implied
- *                 ************************
- * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
- * See the GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with GNU CC; see the file COPYING.  If not, see
- * <http://www.gnu.org/licenses/>.
  *
  * Please send any bug reports or fixes you make to the
  * email address(es):
@@ -57,6 +42,7 @@
 #include <net/sctp/checksum.h>
 #include <net/net_namespace.h>
 #include <linux/rhashtable.h>
+#include <net/sock_reuseport.h>
 
 /* Forward declarations for internal helpers. */
 static int sctp_rcv_ootb(struct sk_buff *);
@@ -65,8 +51,10 @@
 				      const union sctp_addr *paddr,
 				      const union sctp_addr *laddr,
 				      struct sctp_transport **transportp);
-static struct sctp_endpoint *__sctp_rcv_lookup_endpoint(struct net *net,
-						const union sctp_addr *laddr);
+static struct sctp_endpoint *__sctp_rcv_lookup_endpoint(
+					struct net *net, struct sk_buff *skb,
+					const union sctp_addr *laddr,
+					const union sctp_addr *daddr);
 static struct sctp_association *__sctp_lookup_association(
 					struct net *net,
 					const union sctp_addr *local,
@@ -104,6 +92,7 @@
 	struct sctp_chunk *chunk;
 	union sctp_addr src;
 	union sctp_addr dest;
+	int bound_dev_if;
 	int family;
 	struct sctp_af *af;
 	struct net *net = dev_net(skb->dev);
@@ -171,7 +160,7 @@
 	asoc = __sctp_rcv_lookup(net, skb, &src, &dest, &transport);
 
 	if (!asoc)
-		ep = __sctp_rcv_lookup_endpoint(net, &dest);
+		ep = __sctp_rcv_lookup_endpoint(net, skb, &dest, &src);
 
 	/* Retrieve the common input handling substructure. */
 	rcvr = asoc ? &asoc->base : &ep->base;
@@ -181,7 +170,8 @@
 	 * If a frame arrives on an interface and the receiving socket is
 	 * bound to another interface, via SO_BINDTODEVICE, treat it as OOTB
 	 */
-	if (sk->sk_bound_dev_if && (sk->sk_bound_dev_if != af->skb_iif(skb))) {
+	bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
+	if (bound_dev_if && (bound_dev_if != af->skb_iif(skb))) {
 		if (transport) {
 			sctp_transport_put(transport);
 			asoc = NULL;
@@ -213,7 +203,7 @@
 
 	if (!xfrm_policy_check(sk, XFRM_POLICY_IN, skb, family))
 		goto discard_release;
-	nf_reset(skb);
+	nf_reset_ct(skb);
 
 	if (sk_filter(sk, skb))
 		goto discard_release;
@@ -334,7 +324,7 @@
 		bh_lock_sock(sk);
 
 		if (sock_owned_by_user(sk) || !sctp_newsk_ready(sk)) {
-			if (sk_add_backlog(sk, skb, sk->sk_rcvbuf))
+			if (sk_add_backlog(sk, skb, READ_ONCE(sk->sk_rcvbuf)))
 				sctp_chunk_free(chunk);
 			else
 				backloged = 1;
@@ -349,7 +339,7 @@
 			return 0;
 	} else {
 		if (!sctp_newsk_ready(sk)) {
-			if (!sk_add_backlog(sk, skb, sk->sk_rcvbuf))
+			if (!sk_add_backlog(sk, skb, READ_ONCE(sk->sk_rcvbuf)))
 				return 0;
 			sctp_chunk_free(chunk);
 		} else {
@@ -376,7 +366,7 @@
 	struct sctp_ep_common *rcvr = chunk->rcvr;
 	int ret;
 
-	ret = sk_add_backlog(sk, skb, sk->sk_rcvbuf);
+	ret = sk_add_backlog(sk, skb, READ_ONCE(sk->sk_rcvbuf));
 	if (!ret) {
 		/* Hold the assoc/ep while hanging on the backlog queue.
 		 * This way, we know structures we need will not disappear
@@ -560,6 +550,7 @@
 
 /* Common cleanup code for icmp/icmpv6 error handler. */
 void sctp_err_finish(struct sock *sk, struct sctp_transport *t)
+	__releases(&((__sk)->sk_lock.slock))
 {
 	bh_unlock_sock(sk);
 	sctp_transport_put(t);
@@ -580,7 +571,7 @@
  * is probably better.
  *
  */
-void sctp_v4_err(struct sk_buff *skb, __u32 info)
+int sctp_v4_err(struct sk_buff *skb, __u32 info)
 {
 	const struct iphdr *iph = (const struct iphdr *)skb->data;
 	const int ihlen = iph->ihl * 4;
@@ -605,7 +596,7 @@
 	skb->transport_header = savesctp;
 	if (!sk) {
 		__ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
-		return;
+		return -ENOENT;
 	}
 	/* Warning:  The sock lock is held.  Remember to call
 	 * sctp_err_finish!
@@ -659,6 +650,7 @@
 
 out_unlock:
 	sctp_err_finish(sk, transport);
+	return 0;
 }
 
 /*
@@ -726,42 +718,86 @@
 }
 
 /* Insert endpoint into the hash table.  */
-static void __sctp_hash_endpoint(struct sctp_endpoint *ep)
+static int __sctp_hash_endpoint(struct sctp_endpoint *ep)
 {
-	struct net *net = sock_net(ep->base.sk);
-	struct sctp_ep_common *epb;
+	struct sock *sk = ep->base.sk;
+	struct net *net = sock_net(sk);
 	struct sctp_hashbucket *head;
+	struct sctp_ep_common *epb;
 
 	epb = &ep->base;
-
 	epb->hashent = sctp_ep_hashfn(net, epb->bind_addr.port);
 	head = &sctp_ep_hashtable[epb->hashent];
+
+	if (sk->sk_reuseport) {
+		bool any = sctp_is_ep_boundall(sk);
+		struct sctp_ep_common *epb2;
+		struct list_head *list;
+		int cnt = 0, err = 1;
+
+		list_for_each(list, &ep->base.bind_addr.address_list)
+			cnt++;
+
+		sctp_for_each_hentry(epb2, &head->chain) {
+			struct sock *sk2 = epb2->sk;
+
+			if (!net_eq(sock_net(sk2), net) || sk2 == sk ||
+			    !uid_eq(sock_i_uid(sk2), sock_i_uid(sk)) ||
+			    !sk2->sk_reuseport)
+				continue;
+
+			err = sctp_bind_addrs_check(sctp_sk(sk2),
+						    sctp_sk(sk), cnt);
+			if (!err) {
+				err = reuseport_add_sock(sk, sk2, any);
+				if (err)
+					return err;
+				break;
+			} else if (err < 0) {
+				return err;
+			}
+		}
+
+		if (err) {
+			err = reuseport_alloc(sk, any);
+			if (err)
+				return err;
+		}
+	}
 
 	write_lock(&head->lock);
 	hlist_add_head(&epb->node, &head->chain);
 	write_unlock(&head->lock);
+	return 0;
 }
 
 /* Add an endpoint to the hash. Local BH-safe. */
-void sctp_hash_endpoint(struct sctp_endpoint *ep)
+int sctp_hash_endpoint(struct sctp_endpoint *ep)
 {
+	int err;
+
 	local_bh_disable();
-	__sctp_hash_endpoint(ep);
+	err = __sctp_hash_endpoint(ep);
 	local_bh_enable();
+
+	return err;
 }
 
 /* Remove endpoint from the hash table.  */
 static void __sctp_unhash_endpoint(struct sctp_endpoint *ep)
 {
-	struct net *net = sock_net(ep->base.sk);
+	struct sock *sk = ep->base.sk;
 	struct sctp_hashbucket *head;
 	struct sctp_ep_common *epb;
 
 	epb = &ep->base;
 
-	epb->hashent = sctp_ep_hashfn(net, epb->bind_addr.port);
+	epb->hashent = sctp_ep_hashfn(sock_net(sk), epb->bind_addr.port);
 
 	head = &sctp_ep_hashtable[epb->hashent];
+
+	if (rcu_access_pointer(sk->sk_reuseport_cb))
+		reuseport_detach_sock(sk);
 
 	write_lock(&head->lock);
 	hlist_del_init(&epb->node);
@@ -776,16 +812,35 @@
 	local_bh_enable();
 }
 
+static inline __u32 sctp_hashfn(const struct net *net, __be16 lport,
+				const union sctp_addr *paddr, __u32 seed)
+{
+	__u32 addr;
+
+	if (paddr->sa.sa_family == AF_INET6)
+		addr = jhash(&paddr->v6.sin6_addr, 16, seed);
+	else
+		addr = (__force __u32)paddr->v4.sin_addr.s_addr;
+
+	return  jhash_3words(addr, ((__force __u32)paddr->v4.sin_port) << 16 |
+			     (__force __u32)lport, net_hash_mix(net), seed);
+}
+
 /* Look up an endpoint. */
-static struct sctp_endpoint *__sctp_rcv_lookup_endpoint(struct net *net,
-						const union sctp_addr *laddr)
+static struct sctp_endpoint *__sctp_rcv_lookup_endpoint(
+					struct net *net, struct sk_buff *skb,
+					const union sctp_addr *laddr,
+					const union sctp_addr *paddr)
 {
 	struct sctp_hashbucket *head;
 	struct sctp_ep_common *epb;
 	struct sctp_endpoint *ep;
+	struct sock *sk;
+	__be16 lport;
 	int hash;
 
-	hash = sctp_ep_hashfn(net, ntohs(laddr->v4.sin_port));
+	lport = laddr->v4.sin_port;
+	hash = sctp_ep_hashfn(net, ntohs(lport));
 	head = &sctp_ep_hashtable[hash];
 	read_lock(&head->lock);
 	sctp_for_each_hentry(epb, &head->chain) {
@@ -797,6 +852,15 @@
 	ep = sctp_sk(net->sctp.ctl_sock)->ep;
 
 hit:
+	sk = ep->base.sk;
+	if (sk->sk_reuseport) {
+		__u32 phash = sctp_hashfn(net, lport, paddr, 0);
+
+		sk = reuseport_select_sock(sk, phash, skb,
+					   sizeof(struct sctphdr));
+		if (sk)
+			ep = sctp_sk(sk)->ep;
+	}
 	sctp_endpoint_hold(ep);
 	read_unlock(&head->lock);
 	return ep;
@@ -835,35 +899,17 @@
 static inline __u32 sctp_hash_obj(const void *data, u32 len, u32 seed)
 {
 	const struct sctp_transport *t = data;
-	const union sctp_addr *paddr = &t->ipaddr;
-	const struct net *net = t->asoc->base.net;
-	__be16 lport = htons(t->asoc->base.bind_addr.port);
-	__u32 addr;
 
-	if (paddr->sa.sa_family == AF_INET6)
-		addr = jhash(&paddr->v6.sin6_addr, 16, seed);
-	else
-		addr = (__force __u32)paddr->v4.sin_addr.s_addr;
-
-	return  jhash_3words(addr, ((__force __u32)paddr->v4.sin_port) << 16 |
-			     (__force __u32)lport, net_hash_mix(net), seed);
+	return sctp_hashfn(t->asoc->base.net,
+			   htons(t->asoc->base.bind_addr.port),
+			   &t->ipaddr, seed);
 }
 
 static inline __u32 sctp_hash_key(const void *data, u32 len, u32 seed)
 {
 	const struct sctp_hash_cmp_arg *x = data;
-	const union sctp_addr *paddr = x->paddr;
-	const struct net *net = x->net;
-	__be16 lport = x->lport;
-	__u32 addr;
 
-	if (paddr->sa.sa_family == AF_INET6)
-		addr = jhash(&paddr->v6.sin6_addr, 16, seed);
-	else
-		addr = (__force __u32)paddr->v4.sin_addr.s_addr;
-
-	return  jhash_3words(addr, ((__force __u32)paddr->v4.sin_port) << 16 |
-			     (__force __u32)lport, net_hash_mix(net), seed);
+	return sctp_hashfn(x->net, x->lport, x->paddr, seed);
 }
 
 static const struct rhashtable_params sctp_hash_params = {
@@ -894,7 +940,7 @@
 	if (t->asoc->temp)
 		return 0;
 
-	arg.net   = sock_net(t->asoc->base.sk);
+	arg.net   = t->asoc->base.net;
 	arg.paddr = &t->ipaddr;
 	arg.lport = htons(t->asoc->base.bind_addr.port);
 
@@ -961,12 +1007,11 @@
 				const struct sctp_endpoint *ep,
 				const union sctp_addr *paddr)
 {
-	struct net *net = sock_net(ep->base.sk);
 	struct rhlist_head *tmp, *list;
 	struct sctp_transport *t;
 	struct sctp_hash_cmp_arg arg = {
 		.paddr = paddr,
-		.net   = net,
+		.net   = ep->base.net,
 		.lport = htons(ep->base.bind_addr.port),
 	};
 

--
Gitblit v1.6.2