hc
2023-12-09 b22da3d8526a935aa31e086e63f60ff3246cb61c
kernel/net/sctp/input.c
....@@ -1,3 +1,4 @@
1
+// SPDX-License-Identifier: GPL-2.0-or-later
12 /* SCTP kernel implementation
23 * Copyright (c) 1999-2000 Cisco, Inc.
34 * Copyright (c) 1999-2001 Motorola, Inc.
....@@ -9,22 +10,6 @@
910 * This file is part of the SCTP kernel implementation
1011 *
1112 * These functions handle all input from the IP layer into SCTP.
12
- *
13
- * This SCTP implementation is free software;
14
- * you can redistribute it and/or modify it under the terms of
15
- * the GNU General Public License as published by
16
- * the Free Software Foundation; either version 2, or (at your option)
17
- * any later version.
18
- *
19
- * This SCTP implementation is distributed in the hope that it
20
- * will be useful, but WITHOUT ANY WARRANTY; without even the implied
21
- * ************************
22
- * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
23
- * See the GNU General Public License for more details.
24
- *
25
- * You should have received a copy of the GNU General Public License
26
- * along with GNU CC; see the file COPYING. If not, see
27
- * <http://www.gnu.org/licenses/>.
2813 *
2914 * Please send any bug reports or fixes you make to the
3015 * email address(es):
....@@ -57,6 +42,7 @@
5742 #include <net/sctp/checksum.h>
5843 #include <net/net_namespace.h>
5944 #include <linux/rhashtable.h>
45
+#include <net/sock_reuseport.h>
6046
6147 /* Forward declarations for internal helpers. */
6248 static int sctp_rcv_ootb(struct sk_buff *);
....@@ -65,8 +51,10 @@
6551 const union sctp_addr *paddr,
6652 const union sctp_addr *laddr,
6753 struct sctp_transport **transportp);
68
-static struct sctp_endpoint *__sctp_rcv_lookup_endpoint(struct net *net,
69
- const union sctp_addr *laddr);
54
+static struct sctp_endpoint *__sctp_rcv_lookup_endpoint(
55
+ struct net *net, struct sk_buff *skb,
56
+ const union sctp_addr *laddr,
57
+ const union sctp_addr *daddr);
7058 static struct sctp_association *__sctp_lookup_association(
7159 struct net *net,
7260 const union sctp_addr *local,
....@@ -104,6 +92,7 @@
10492 struct sctp_chunk *chunk;
10593 union sctp_addr src;
10694 union sctp_addr dest;
95
+ int bound_dev_if;
10796 int family;
10897 struct sctp_af *af;
10998 struct net *net = dev_net(skb->dev);
....@@ -171,7 +160,7 @@
171160 asoc = __sctp_rcv_lookup(net, skb, &src, &dest, &transport);
172161
173162 if (!asoc)
174
- ep = __sctp_rcv_lookup_endpoint(net, &dest);
163
+ ep = __sctp_rcv_lookup_endpoint(net, skb, &dest, &src);
175164
176165 /* Retrieve the common input handling substructure. */
177166 rcvr = asoc ? &asoc->base : &ep->base;
....@@ -181,7 +170,8 @@
181170 * If a frame arrives on an interface and the receiving socket is
182171 * bound to another interface, via SO_BINDTODEVICE, treat it as OOTB
183172 */
184
- if (sk->sk_bound_dev_if && (sk->sk_bound_dev_if != af->skb_iif(skb))) {
173
+ bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
174
+ if (bound_dev_if && (bound_dev_if != af->skb_iif(skb))) {
185175 if (transport) {
186176 sctp_transport_put(transport);
187177 asoc = NULL;
....@@ -213,7 +203,7 @@
213203
214204 if (!xfrm_policy_check(sk, XFRM_POLICY_IN, skb, family))
215205 goto discard_release;
216
- nf_reset(skb);
206
+ nf_reset_ct(skb);
217207
218208 if (sk_filter(sk, skb))
219209 goto discard_release;
....@@ -334,7 +324,7 @@
334324 bh_lock_sock(sk);
335325
336326 if (sock_owned_by_user(sk) || !sctp_newsk_ready(sk)) {
337
- if (sk_add_backlog(sk, skb, sk->sk_rcvbuf))
327
+ if (sk_add_backlog(sk, skb, READ_ONCE(sk->sk_rcvbuf)))
338328 sctp_chunk_free(chunk);
339329 else
340330 backloged = 1;
....@@ -349,7 +339,7 @@
349339 return 0;
350340 } else {
351341 if (!sctp_newsk_ready(sk)) {
352
- if (!sk_add_backlog(sk, skb, sk->sk_rcvbuf))
342
+ if (!sk_add_backlog(sk, skb, READ_ONCE(sk->sk_rcvbuf)))
353343 return 0;
354344 sctp_chunk_free(chunk);
355345 } else {
....@@ -376,7 +366,7 @@
376366 struct sctp_ep_common *rcvr = chunk->rcvr;
377367 int ret;
378368
379
- ret = sk_add_backlog(sk, skb, sk->sk_rcvbuf);
369
+ ret = sk_add_backlog(sk, skb, READ_ONCE(sk->sk_rcvbuf));
380370 if (!ret) {
381371 /* Hold the assoc/ep while hanging on the backlog queue.
382372 * This way, we know structures we need will not disappear
....@@ -560,6 +550,7 @@
560550
561551 /* Common cleanup code for icmp/icmpv6 error handler. */
562552 void sctp_err_finish(struct sock *sk, struct sctp_transport *t)
553
+ __releases(&((__sk)->sk_lock.slock))
563554 {
564555 bh_unlock_sock(sk);
565556 sctp_transport_put(t);
....@@ -580,7 +571,7 @@
580571 * is probably better.
581572 *
582573 */
583
-void sctp_v4_err(struct sk_buff *skb, __u32 info)
574
+int sctp_v4_err(struct sk_buff *skb, __u32 info)
584575 {
585576 const struct iphdr *iph = (const struct iphdr *)skb->data;
586577 const int ihlen = iph->ihl * 4;
....@@ -605,7 +596,7 @@
605596 skb->transport_header = savesctp;
606597 if (!sk) {
607598 __ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
608
- return;
599
+ return -ENOENT;
609600 }
610601 /* Warning: The sock lock is held. Remember to call
611602 * sctp_err_finish!
....@@ -659,6 +650,7 @@
659650
660651 out_unlock:
661652 sctp_err_finish(sk, transport);
653
+ return 0;
662654 }
663655
664656 /*
....@@ -726,42 +718,86 @@
726718 }
727719
728720 /* Insert endpoint into the hash table. */
729
-static void __sctp_hash_endpoint(struct sctp_endpoint *ep)
721
+static int __sctp_hash_endpoint(struct sctp_endpoint *ep)
730722 {
731
- struct net *net = sock_net(ep->base.sk);
732
- struct sctp_ep_common *epb;
723
+ struct sock *sk = ep->base.sk;
724
+ struct net *net = sock_net(sk);
733725 struct sctp_hashbucket *head;
726
+ struct sctp_ep_common *epb;
734727
735728 epb = &ep->base;
736
-
737729 epb->hashent = sctp_ep_hashfn(net, epb->bind_addr.port);
738730 head = &sctp_ep_hashtable[epb->hashent];
731
+
732
+ if (sk->sk_reuseport) {
733
+ bool any = sctp_is_ep_boundall(sk);
734
+ struct sctp_ep_common *epb2;
735
+ struct list_head *list;
736
+ int cnt = 0, err = 1;
737
+
738
+ list_for_each(list, &ep->base.bind_addr.address_list)
739
+ cnt++;
740
+
741
+ sctp_for_each_hentry(epb2, &head->chain) {
742
+ struct sock *sk2 = epb2->sk;
743
+
744
+ if (!net_eq(sock_net(sk2), net) || sk2 == sk ||
745
+ !uid_eq(sock_i_uid(sk2), sock_i_uid(sk)) ||
746
+ !sk2->sk_reuseport)
747
+ continue;
748
+
749
+ err = sctp_bind_addrs_check(sctp_sk(sk2),
750
+ sctp_sk(sk), cnt);
751
+ if (!err) {
752
+ err = reuseport_add_sock(sk, sk2, any);
753
+ if (err)
754
+ return err;
755
+ break;
756
+ } else if (err < 0) {
757
+ return err;
758
+ }
759
+ }
760
+
761
+ if (err) {
762
+ err = reuseport_alloc(sk, any);
763
+ if (err)
764
+ return err;
765
+ }
766
+ }
739767
740768 write_lock(&head->lock);
741769 hlist_add_head(&epb->node, &head->chain);
742770 write_unlock(&head->lock);
771
+ return 0;
743772 }
744773
745774 /* Add an endpoint to the hash. Local BH-safe. */
746
-void sctp_hash_endpoint(struct sctp_endpoint *ep)
775
+int sctp_hash_endpoint(struct sctp_endpoint *ep)
747776 {
777
+ int err;
778
+
748779 local_bh_disable();
749
- __sctp_hash_endpoint(ep);
780
+ err = __sctp_hash_endpoint(ep);
750781 local_bh_enable();
782
+
783
+ return err;
751784 }
752785
753786 /* Remove endpoint from the hash table. */
754787 static void __sctp_unhash_endpoint(struct sctp_endpoint *ep)
755788 {
756
- struct net *net = sock_net(ep->base.sk);
789
+ struct sock *sk = ep->base.sk;
757790 struct sctp_hashbucket *head;
758791 struct sctp_ep_common *epb;
759792
760793 epb = &ep->base;
761794
762
- epb->hashent = sctp_ep_hashfn(net, epb->bind_addr.port);
795
+ epb->hashent = sctp_ep_hashfn(sock_net(sk), epb->bind_addr.port);
763796
764797 head = &sctp_ep_hashtable[epb->hashent];
798
+
799
+ if (rcu_access_pointer(sk->sk_reuseport_cb))
800
+ reuseport_detach_sock(sk);
765801
766802 write_lock(&head->lock);
767803 hlist_del_init(&epb->node);
....@@ -776,16 +812,35 @@
776812 local_bh_enable();
777813 }
778814
815
+static inline __u32 sctp_hashfn(const struct net *net, __be16 lport,
816
+ const union sctp_addr *paddr, __u32 seed)
817
+{
818
+ __u32 addr;
819
+
820
+ if (paddr->sa.sa_family == AF_INET6)
821
+ addr = jhash(&paddr->v6.sin6_addr, 16, seed);
822
+ else
823
+ addr = (__force __u32)paddr->v4.sin_addr.s_addr;
824
+
825
+ return jhash_3words(addr, ((__force __u32)paddr->v4.sin_port) << 16 |
826
+ (__force __u32)lport, net_hash_mix(net), seed);
827
+}
828
+
779829 /* Look up an endpoint. */
780
-static struct sctp_endpoint *__sctp_rcv_lookup_endpoint(struct net *net,
781
- const union sctp_addr *laddr)
830
+static struct sctp_endpoint *__sctp_rcv_lookup_endpoint(
831
+ struct net *net, struct sk_buff *skb,
832
+ const union sctp_addr *laddr,
833
+ const union sctp_addr *paddr)
782834 {
783835 struct sctp_hashbucket *head;
784836 struct sctp_ep_common *epb;
785837 struct sctp_endpoint *ep;
838
+ struct sock *sk;
839
+ __be16 lport;
786840 int hash;
787841
788
- hash = sctp_ep_hashfn(net, ntohs(laddr->v4.sin_port));
842
+ lport = laddr->v4.sin_port;
843
+ hash = sctp_ep_hashfn(net, ntohs(lport));
789844 head = &sctp_ep_hashtable[hash];
790845 read_lock(&head->lock);
791846 sctp_for_each_hentry(epb, &head->chain) {
....@@ -797,6 +852,15 @@
797852 ep = sctp_sk(net->sctp.ctl_sock)->ep;
798853
799854 hit:
855
+ sk = ep->base.sk;
856
+ if (sk->sk_reuseport) {
857
+ __u32 phash = sctp_hashfn(net, lport, paddr, 0);
858
+
859
+ sk = reuseport_select_sock(sk, phash, skb,
860
+ sizeof(struct sctphdr));
861
+ if (sk)
862
+ ep = sctp_sk(sk)->ep;
863
+ }
800864 sctp_endpoint_hold(ep);
801865 read_unlock(&head->lock);
802866 return ep;
....@@ -835,35 +899,17 @@
835899 static inline __u32 sctp_hash_obj(const void *data, u32 len, u32 seed)
836900 {
837901 const struct sctp_transport *t = data;
838
- const union sctp_addr *paddr = &t->ipaddr;
839
- const struct net *net = t->asoc->base.net;
840
- __be16 lport = htons(t->asoc->base.bind_addr.port);
841
- __u32 addr;
842902
843
- if (paddr->sa.sa_family == AF_INET6)
844
- addr = jhash(&paddr->v6.sin6_addr, 16, seed);
845
- else
846
- addr = (__force __u32)paddr->v4.sin_addr.s_addr;
847
-
848
- return jhash_3words(addr, ((__force __u32)paddr->v4.sin_port) << 16 |
849
- (__force __u32)lport, net_hash_mix(net), seed);
903
+ return sctp_hashfn(t->asoc->base.net,
904
+ htons(t->asoc->base.bind_addr.port),
905
+ &t->ipaddr, seed);
850906 }
851907
852908 static inline __u32 sctp_hash_key(const void *data, u32 len, u32 seed)
853909 {
854910 const struct sctp_hash_cmp_arg *x = data;
855
- const union sctp_addr *paddr = x->paddr;
856
- const struct net *net = x->net;
857
- __be16 lport = x->lport;
858
- __u32 addr;
859911
860
- if (paddr->sa.sa_family == AF_INET6)
861
- addr = jhash(&paddr->v6.sin6_addr, 16, seed);
862
- else
863
- addr = (__force __u32)paddr->v4.sin_addr.s_addr;
864
-
865
- return jhash_3words(addr, ((__force __u32)paddr->v4.sin_port) << 16 |
866
- (__force __u32)lport, net_hash_mix(net), seed);
912
+ return sctp_hashfn(x->net, x->lport, x->paddr, seed);
867913 }
868914
869915 static const struct rhashtable_params sctp_hash_params = {
....@@ -894,7 +940,7 @@
894940 if (t->asoc->temp)
895941 return 0;
896942
897
- arg.net = sock_net(t->asoc->base.sk);
943
+ arg.net = t->asoc->base.net;
898944 arg.paddr = &t->ipaddr;
899945 arg.lport = htons(t->asoc->base.bind_addr.port);
900946
....@@ -961,12 +1007,11 @@
9611007 const struct sctp_endpoint *ep,
9621008 const union sctp_addr *paddr)
9631009 {
964
- struct net *net = sock_net(ep->base.sk);
9651010 struct rhlist_head *tmp, *list;
9661011 struct sctp_transport *t;
9671012 struct sctp_hash_cmp_arg arg = {
9681013 .paddr = paddr,
969
- .net = net,
1014
+ .net = ep->base.net,
9701015 .lport = htons(ep->base.bind_addr.port),
9711016 };
9721017