From 8ac6c7a54ed1b98d142dce24b11c6de6a1e239a5 Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Tue, 22 Oct 2024 10:36:11 +0000
Subject: [PATCH] Switch 4G dialing to QMI; quectel-CM must run in the background on the system

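With QMI dialing, the data connection is brought up by Quectel's
quectel-CM tool, which has to keep running as a background process on
the system.

As a rough sketch only (the binary path and APN are placeholders and
depend on the actual image; check quectel-CM's usage output for the
exact options supported by the build in use), the tool can be started
in the background along these lines:

    /usr/bin/quectel-CM -s <apn> &
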
---
 kernel/net/ipv6/udp.c |  548 +++++++++++++++++++++++++++++++++---------------------
 1 file changed, 334 insertions(+), 214 deletions(-)

diff --git a/kernel/net/ipv6/udp.c b/kernel/net/ipv6/udp.c
index 7d3caaf..5385037 100644
--- a/kernel/net/ipv6/udp.c
+++ b/kernel/net/ipv6/udp.c
@@ -1,3 +1,4 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
 /*
  *	UDP over IPv6
  *	Linux INET6 implementation
@@ -14,11 +15,6 @@
  *					a single port at the same time.
  *      Kazunori MIYAZAWA @USAGI:       change process style to use ip6_append_data
  *      YOSHIFUJI Hideaki @USAGI:	convert /proc/net/udp6 to seq_file.
- *
- *	This program is free software; you can redistribute it and/or
- *      modify it under the terms of the GNU General Public License
- *      as published by the Free Software Foundation; either version
- *      2 of the License, or (at your option) any later version.
  */
 
 #include <linux/errno.h>
@@ -36,6 +32,7 @@
 #include <linux/skbuff.h>
 #include <linux/slab.h>
 #include <linux/uaccess.h>
+#include <linux/indirect_call_wrapper.h>
 
 #include <net/addrconf.h>
 #include <net/ndisc.h>
@@ -45,6 +42,7 @@
 #include <net/raw.h>
 #include <net/tcp_states.h>
 #include <net/ip6_checksum.h>
+#include <net/ip6_tunnel.h>
 #include <net/xfrm.h>
 #include <net/inet_hashtables.h>
 #include <net/inet6_hashtables.h>
@@ -56,14 +54,17 @@
 #include <trace/events/skb.h>
 #include "udp_impl.h"
 
-static bool udp6_lib_exact_dif_match(struct net *net, struct sk_buff *skb)
+static void udpv6_destruct_sock(struct sock *sk)
 {
-#if defined(CONFIG_NET_L3_MASTER_DEV)
-	if (!net->ipv4.sysctl_udp_l3mdev_accept &&
-	    skb && ipv6_l3mdev_skb(IP6CB(skb)->flags))
-		return true;
-#endif
-	return false;
+	udp_destruct_common(sk);
+	inet6_sock_destruct(sk);
+}
+
+int udpv6_init_sock(struct sock *sk)
+{
+	skb_queue_head_init(&udp_sk(sk)->reader_queue);
+	sk->sk_destruct = udpv6_destruct_sock;
+	return 0;
 }
 
 static u32 udp6_ehashfn(const struct net *net,
@@ -86,7 +87,7 @@
 	fhash = __ipv6_addr_jhash(faddr, udp_ipv6_hash_secret);
 
 	return __inet6_ehashfn(lhash, lport, fhash, fport,
-			       udp_ipv6_hash_secret + net_hash_mix(net));
+			       udp6_ehash_secret + net_hash_mix(net));
 }
 
 int udp_v6_get_port(struct sock *sk, unsigned short snum)
@@ -101,7 +102,7 @@
 	return udp_lib_get_port(sk, snum, hash2_nulladdr);
 }
 
-static void udp_v6_rehash(struct sock *sk)
+void udp_v6_rehash(struct sock *sk)
 {
 	u16 new_hash = ipv6_portaddr_hash(sock_net(sk),
 					  &sk->sk_v6_rcv_saddr,
@@ -113,14 +114,18 @@
 static int compute_score(struct sock *sk, struct net *net,
 			 const struct in6_addr *saddr, __be16 sport,
 			 const struct in6_addr *daddr, unsigned short hnum,
-			 int dif, int sdif, bool exact_dif)
+			 int dif, int sdif)
 {
 	int score;
 	struct inet_sock *inet;
+	bool dev_match;
 
 	if (!net_eq(sock_net(sk), net) ||
 	    udp_sk(sk)->udp_port_hash != hnum ||
 	    sk->sk_family != PF_INET6)
+		return -1;
+
+	if (!ipv6_addr_equal(&sk->sk_v6_rcv_saddr, daddr))
 		return -1;
 
 	score = 0;
@@ -132,27 +137,17 @@
 		score++;
 	}
 
-	if (!ipv6_addr_any(&sk->sk_v6_rcv_saddr)) {
-		if (!ipv6_addr_equal(&sk->sk_v6_rcv_saddr, daddr))
-			return -1;
-		score++;
-	}
-
 	if (!ipv6_addr_any(&sk->sk_v6_daddr)) {
 		if (!ipv6_addr_equal(&sk->sk_v6_daddr, saddr))
 			return -1;
 		score++;
 	}
 
-	if (sk->sk_bound_dev_if || exact_dif) {
-		bool dev_match = (sk->sk_bound_dev_if == dif ||
-				  sk->sk_bound_dev_if == sdif);
-
-		if (!dev_match)
-			return -1;
-		if (sk->sk_bound_dev_if)
-			score++;
-	}
+	dev_match = udp_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif);
+	if (!dev_match)
+		return -1;
+	if (sk->sk_bound_dev_if)
+		score++;
 
 	if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id())
 		score++;
@@ -160,41 +155,85 @@
 	return score;
 }
 
+static struct sock *lookup_reuseport(struct net *net, struct sock *sk,
+				     struct sk_buff *skb,
+				     const struct in6_addr *saddr,
+				     __be16 sport,
+				     const struct in6_addr *daddr,
+				     unsigned int hnum)
+{
+	struct sock *reuse_sk = NULL;
+	u32 hash;
+
+	if (sk->sk_reuseport && sk->sk_state != TCP_ESTABLISHED) {
+		hash = udp6_ehashfn(net, daddr, hnum, saddr, sport);
+		reuse_sk = reuseport_select_sock(sk, hash, skb,
+						 sizeof(struct udphdr));
+	}
+	return reuse_sk;
+}
+
 /* called with rcu_read_lock() */
 static struct sock *udp6_lib_lookup2(struct net *net,
 		const struct in6_addr *saddr, __be16 sport,
 		const struct in6_addr *daddr, unsigned int hnum,
-		int dif, int sdif, bool exact_dif,
-		struct udp_hslot *hslot2, struct sk_buff *skb)
+		int dif, int sdif, struct udp_hslot *hslot2,
+		struct sk_buff *skb)
 {
-	struct sock *sk, *result, *reuseport_result;
+	struct sock *sk, *result;
 	int score, badness;
-	u32 hash = 0;
 
 	result = NULL;
 	badness = -1;
 	udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
 		score = compute_score(sk, net, saddr, sport,
-				      daddr, hnum, dif, sdif, exact_dif);
+				      daddr, hnum, dif, sdif);
 		if (score > badness) {
-			reuseport_result = NULL;
-
-			if (sk->sk_reuseport &&
-			    sk->sk_state != TCP_ESTABLISHED) {
-				hash = udp6_ehashfn(net, daddr, hnum,
-						    saddr, sport);
-
-				reuseport_result = reuseport_select_sock(sk, hash, skb,
-									 sizeof(struct udphdr));
-				if (reuseport_result && !reuseport_has_conns(sk, false))
-					return reuseport_result;
+			badness = score;
+			result = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum);
+			if (!result) {
+				result = sk;
+				continue;
 			}
 
-			result = reuseport_result ? : sk;
-			badness = score;
+			/* Fall back to scoring if group has connections */
+			if (!reuseport_has_conns(sk))
+				return result;
+
+			/* Reuseport logic returned an error, keep original score. */
+			if (IS_ERR(result))
+				continue;
+
+			badness = compute_score(sk, net, saddr, sport,
+						daddr, hnum, dif, sdif);
 		}
 	}
 	return result;
+}
+
+static inline struct sock *udp6_lookup_run_bpf(struct net *net,
+					       struct udp_table *udptable,
+					       struct sk_buff *skb,
+					       const struct in6_addr *saddr,
+					       __be16 sport,
+					       const struct in6_addr *daddr,
+					       u16 hnum)
+{
+	struct sock *sk, *reuse_sk;
+	bool no_reuseport;
+
+	if (udptable != &udp_table)
+		return NULL; /* only UDP is supported */
+
+	no_reuseport = bpf_sk_lookup_run_v6(net, IPPROTO_UDP,
+					    saddr, sport, daddr, hnum, &sk);
+	if (no_reuseport || IS_ERR_OR_NULL(sk))
+		return sk;
+
+	reuse_sk = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum);
+	if (reuse_sk)
+		sk = reuse_sk;
+	return sk;
 }
 
 /* rcu_read_lock() must be held */
@@ -204,66 +243,47 @@
 			       int dif, int sdif, struct udp_table *udptable,
 			       struct sk_buff *skb)
 {
-	struct sock *sk, *result;
 	unsigned short hnum = ntohs(dport);
-	unsigned int hash2, slot2, slot = udp_hashfn(net, hnum, udptable->mask);
-	struct udp_hslot *hslot2, *hslot = &udptable->hash[slot];
-	bool exact_dif = udp6_lib_exact_dif_match(net, skb);
-	int score, badness;
-	u32 hash = 0;
+	unsigned int hash2, slot2;
+	struct udp_hslot *hslot2;
+	struct sock *result, *sk;
 
-	if (hslot->count > 10) {
-		hash2 = ipv6_portaddr_hash(net, daddr, hnum);
-		slot2 = hash2 & udptable->mask;
-		hslot2 = &udptable->hash2[slot2];
-		if (hslot->count < hslot2->count)
-			goto begin;
+	hash2 = ipv6_portaddr_hash(net, daddr, hnum);
+	slot2 = hash2 & udptable->mask;
+	hslot2 = &udptable->hash2[slot2];
 
-		result = udp6_lib_lookup2(net, saddr, sport,
-					  daddr, hnum, dif, sdif, exact_dif,
-					  hslot2, skb);
-		if (!result) {
-			unsigned int old_slot2 = slot2;
-			hash2 = ipv6_portaddr_hash(net, &in6addr_any, hnum);
-			slot2 = hash2 & udptable->mask;
-			/* avoid searching the same slot again. */
-			if (unlikely(slot2 == old_slot2))
-				return result;
+	/* Lookup connected or non-wildcard sockets */
+	result = udp6_lib_lookup2(net, saddr, sport,
+				  daddr, hnum, dif, sdif,
+				  hslot2, skb);
+	if (!IS_ERR_OR_NULL(result) && result->sk_state == TCP_ESTABLISHED)
+		goto done;
 
-			hslot2 = &udptable->hash2[slot2];
-			if (hslot->count < hslot2->count)
-				goto begin;
-
-			result = udp6_lib_lookup2(net, saddr, sport,
-						  daddr, hnum, dif, sdif,
-						  exact_dif, hslot2,
-						  skb);
-		}
-		if (unlikely(IS_ERR(result)))
-			return NULL;
-		return result;
-	}
-begin:
-	result = NULL;
-	badness = -1;
-	sk_for_each_rcu(sk, &hslot->head) {
-		score = compute_score(sk, net, saddr, sport, daddr, hnum, dif,
-				      sdif, exact_dif);
-		if (score > badness) {
-			if (sk->sk_reuseport) {
-				hash = udp6_ehashfn(net, daddr, hnum,
-						    saddr, sport);
-				result = reuseport_select_sock(sk, hash, skb,
-							sizeof(struct udphdr));
-				if (unlikely(IS_ERR(result)))
-					return NULL;
-				if (result)
-					return result;
-			}
+	/* Lookup redirect from BPF */
+	if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
+		sk = udp6_lookup_run_bpf(net, udptable, skb,
+					 saddr, sport, daddr, hnum);
+		if (sk) {
 			result = sk;
-			badness = score;
+			goto done;
 		}
 	}
+
+	/* Got non-wildcard socket or error on first lookup */
+	if (result)
+		goto done;
+
+	/* Lookup wildcard sockets */
+	hash2 = ipv6_portaddr_hash(net, &in6addr_any, hnum);
+	slot2 = hash2 & udptable->mask;
+	hslot2 = &udptable->hash2[slot2];
+
+	result = udp6_lib_lookup2(net, saddr, sport,
+				  &in6addr_any, hnum, dif, sdif,
+				  hslot2, skb);
+done:
+	if (IS_ERR(result))
+		return NULL;
 	return result;
 }
 EXPORT_SYMBOL_GPL(__udp6_lib_lookup);
@@ -329,9 +349,9 @@
 	struct inet_sock *inet = inet_sk(sk);
 	struct sk_buff *skb;
 	unsigned int ulen, copied;
-	int peeked, peeking, off;
-	int err;
+	int off, err, peeking = flags & MSG_PEEK;
 	int is_udplite = IS_UDPLITE(sk);
+	struct udp_mib __percpu *mib;
 	bool checksum_valid = false;
 	int is_udp4;
 
@@ -342,9 +362,8 @@
 		return ipv6_recv_rxpmtu(sk, msg, len, addr_len);
 
 try_again:
-	peeking = flags & MSG_PEEK;
 	off = sk_peek_offset(sk, flags);
-	skb = __skb_recv_udp(sk, flags, noblock, &peeked, &off, &err);
+	skb = __skb_recv_udp(sk, flags, noblock, &off, &err);
 	if (!skb)
 		return err;
 
@@ -356,6 +375,7 @@
 		msg->msg_flags |= MSG_TRUNC;
 
 	is_udp4 = (skb->protocol == htons(ETH_P_IP));
+	mib = __UDPX_MIB(sk, is_udp4);
 
 	/*
 	 * If checksum is needed at all, try to do it while copying the
@@ -382,26 +402,15 @@
 			goto csum_copy_err;
 	}
 	if (unlikely(err)) {
-		if (!peeked) {
+		if (!peeking) {
 			atomic_inc(&sk->sk_drops);
-			if (is_udp4)
-				UDP_INC_STATS(sock_net(sk), UDP_MIB_INERRORS,
-					      is_udplite);
-			else
-				UDP6_INC_STATS(sock_net(sk), UDP_MIB_INERRORS,
-					       is_udplite);
+			SNMP_INC_STATS(mib, UDP_MIB_INERRORS);
 		}
 		kfree_skb(skb);
 		return err;
 	}
-	if (!peeked) {
-		if (is_udp4)
-			UDP_INC_STATS(sock_net(sk), UDP_MIB_INDATAGRAMS,
-				      is_udplite);
-		else
-			UDP6_INC_STATS(sock_net(sk), UDP_MIB_INDATAGRAMS,
-				       is_udplite);
-	}
+	if (!peeking)
+		SNMP_INC_STATS(mib, UDP_MIB_INDATAGRAMS);
 
 	sock_recv_ts_and_drops(msg, sk, skb);
 
@@ -429,6 +438,9 @@
 						(struct sockaddr *)sin6);
 	}
 
+	if (udp_sk(sk)->gro_enabled)
+		udp_cmsg_recv(msg, sk, skb);
+
 	if (np->rxopt.all)
 		ip6_datagram_recv_common_ctl(sk, msg, skb);
 
@@ -451,17 +463,8 @@
 csum_copy_err:
 	if (!__sk_queue_drop_skb(sk, &udp_sk(sk)->reader_queue, skb, flags,
 				 udp_skb_destructor)) {
-		if (is_udp4) {
-			UDP_INC_STATS(sock_net(sk),
-				      UDP_MIB_CSUMERRORS, is_udplite);
-			UDP_INC_STATS(sock_net(sk),
-				      UDP_MIB_INERRORS, is_udplite);
-		} else {
-			UDP6_INC_STATS(sock_net(sk),
-				       UDP_MIB_CSUMERRORS, is_udplite);
-			UDP6_INC_STATS(sock_net(sk),
-				       UDP_MIB_INERRORS, is_udplite);
-		}
+		SNMP_INC_STATS(mib, UDP_MIB_CSUMERRORS);
+		SNMP_INC_STATS(mib, UDP_MIB_INERRORS);
 	}
 	kfree_skb(skb);
 
@@ -471,26 +474,133 @@
 	goto try_again;
 }
 
-void __udp6_lib_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
-		    u8 type, u8 code, int offset, __be32 info,
-		    struct udp_table *udptable)
+DEFINE_STATIC_KEY_FALSE(udpv6_encap_needed_key);
+void udpv6_encap_enable(void)
+{
+	static_branch_inc(&udpv6_encap_needed_key);
+}
+EXPORT_SYMBOL(udpv6_encap_enable);
+
+/* Handler for tunnels with arbitrary destination ports: no socket lookup, go
+ * through error handlers in encapsulations looking for a match.
+ */
+static int __udp6_lib_err_encap_no_sk(struct sk_buff *skb,
+				      struct inet6_skb_parm *opt,
+				      u8 type, u8 code, int offset, __be32 info)
+{
+	int i;
+
+	for (i = 0; i < MAX_IPTUN_ENCAP_OPS; i++) {
+		int (*handler)(struct sk_buff *skb, struct inet6_skb_parm *opt,
+			       u8 type, u8 code, int offset, __be32 info);
+		const struct ip6_tnl_encap_ops *encap;
+
+		encap = rcu_dereference(ip6tun_encaps[i]);
+		if (!encap)
+			continue;
+		handler = encap->err_handler;
+		if (handler && !handler(skb, opt, type, code, offset, info))
+			return 0;
+	}
+
+	return -ENOENT;
+}
+
+/* Try to match ICMP errors to UDP tunnels by looking up a socket without
+ * reversing source and destination port: this will match tunnels that force the
+ * same destination port on both endpoints (e.g. VXLAN, GENEVE). Note that
+ * lwtunnels might actually break this assumption by being configured with
+ * different destination ports on endpoints, in this case we won't be able to
+ * trace ICMP messages back to them.
+ *
+ * If this doesn't match any socket, probe tunnels with arbitrary destination
+ * ports (e.g. FoU, GUE): there, the receiving socket is useless, as the port
+ * we've sent packets to won't necessarily match the local destination port.
+ *
+ * Then ask the tunnel implementation to match the error against a valid
+ * association.
+ *
+ * Return an error if we can't find a match, the socket if we need further
+ * processing, zero otherwise.
+ */
+static struct sock *__udp6_lib_err_encap(struct net *net,
+					 const struct ipv6hdr *hdr, int offset,
+					 struct udphdr *uh,
+					 struct udp_table *udptable,
+					 struct sk_buff *skb,
+					 struct inet6_skb_parm *opt,
+					 u8 type, u8 code, __be32 info)
+{
+	int network_offset, transport_offset;
+	struct sock *sk;
+
+	network_offset = skb_network_offset(skb);
+	transport_offset = skb_transport_offset(skb);
+
+	/* Network header needs to point to the outer IPv6 header inside ICMP */
+	skb_reset_network_header(skb);
+
+	/* Transport header needs to point to the UDP header */
+	skb_set_transport_header(skb, offset);
+
+	sk = __udp6_lib_lookup(net, &hdr->daddr, uh->source,
+			       &hdr->saddr, uh->dest,
+			       inet6_iif(skb), 0, udptable, skb);
+	if (sk) {
+		int (*lookup)(struct sock *sk, struct sk_buff *skb);
+		struct udp_sock *up = udp_sk(sk);
+
+		lookup = READ_ONCE(up->encap_err_lookup);
+		if (!lookup || lookup(sk, skb))
+			sk = NULL;
+	}
+
+	if (!sk) {
+		sk = ERR_PTR(__udp6_lib_err_encap_no_sk(skb, opt, type, code,
+							offset, info));
+	}
+
+	skb_set_transport_header(skb, transport_offset);
+	skb_set_network_header(skb, network_offset);
+
+	return sk;
+}
+
+int __udp6_lib_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
+		   u8 type, u8 code, int offset, __be32 info,
+		   struct udp_table *udptable)
 {
 	struct ipv6_pinfo *np;
 	const struct ipv6hdr *hdr = (const struct ipv6hdr *)skb->data;
 	const struct in6_addr *saddr = &hdr->saddr;
 	const struct in6_addr *daddr = &hdr->daddr;
 	struct udphdr *uh = (struct udphdr *)(skb->data+offset);
+	bool tunnel = false;
 	struct sock *sk;
 	int harderr;
 	int err;
 	struct net *net = dev_net(skb->dev);
 
 	sk = __udp6_lib_lookup(net, daddr, uh->dest, saddr, uh->source,
-			       inet6_iif(skb), 0, udptable, NULL);
+			       inet6_iif(skb), inet6_sdif(skb), udptable, NULL);
 	if (!sk) {
-		__ICMP6_INC_STATS(net, __in6_dev_get(skb->dev),
-				  ICMP6_MIB_INERRORS);
-		return;
+		/* No socket for error: try tunnels before discarding */
+		sk = ERR_PTR(-ENOENT);
+		if (static_branch_unlikely(&udpv6_encap_needed_key)) {
+			sk = __udp6_lib_err_encap(net, hdr, offset, uh,
+						  udptable, skb,
+						  opt, type, code, info);
+			if (!sk)
+				return 0;
+		}
+
+		if (IS_ERR(sk)) {
+			__ICMP6_INC_STATS(net, __in6_dev_get(skb->dev),
+					  ICMP6_MIB_INERRORS);
+			return PTR_ERR(sk);
+		}
+
+		tunnel = true;
 	}
 
 	harderr = icmpv6_err_convert(type, code, &err);
@@ -504,9 +614,18 @@
 			harderr = 1;
 	}
 	if (type == NDISC_REDIRECT) {
-		ip6_sk_redirect(skb, sk);
+		if (tunnel) {
+			ip6_redirect(skb, sock_net(sk), inet6_iif(skb),
+				     sk->sk_mark, sk->sk_uid);
+		} else {
+			ip6_sk_redirect(skb, sk);
+		}
 		goto out;
 	}
+
+	/* Tunnels don't have an application socket: don't pass errors back */
+	if (tunnel)
+		goto out;
 
 	if (!np->recverr) {
 		if (!harderr || sk->sk_state != TCP_ESTABLISHED)
@@ -518,7 +637,7 @@
 	sk->sk_err = err;
 	sk->sk_error_report(sk);
 out:
-	return;
+	return 0;
 }
 
 static int __udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
@@ -549,21 +668,14 @@
 	return 0;
 }
 
-static __inline__ void udpv6_err(struct sk_buff *skb,
-				 struct inet6_skb_parm *opt, u8 type,
-				 u8 code, int offset, __be32 info)
+static __inline__ int udpv6_err(struct sk_buff *skb,
+				struct inet6_skb_parm *opt, u8 type,
+				u8 code, int offset, __be32 info)
 {
-	__udp6_lib_err(skb, opt, type, code, offset, info, &udp_table);
+	return __udp6_lib_err(skb, opt, type, code, offset, info, &udp_table);
 }
 
-static DEFINE_STATIC_KEY_FALSE(udpv6_encap_needed_key);
-void udpv6_encap_enable(void)
-{
-	static_branch_enable(&udpv6_encap_needed_key);
-}
-EXPORT_SYMBOL(udpv6_encap_enable);
-
-static int udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
+static int udpv6_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb)
 {
 	struct udp_sock *up = udp_sk(sk);
 	int is_udplite = IS_UDPLITE(sk);
@@ -646,10 +758,31 @@
 	return -1;
 }
 
+static int udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
+{
+	struct sk_buff *next, *segs;
+	int ret;
+
+	if (likely(!udp_unexpected_gso(sk, skb)))
+		return udpv6_queue_rcv_one_skb(sk, skb);
+
+	__skb_push(skb, -skb_mac_offset(skb));
+	segs = udp_rcv_segment(sk, skb, false);
+	skb_list_walk_safe(segs, skb, next) {
+		__skb_pull(skb, skb_transport_offset(skb));
+
+		ret = udpv6_queue_rcv_one_skb(sk, skb);
+		if (ret > 0)
+			ip6_protocol_deliver_rcu(dev_net(skb->dev), skb, ret,
+						 true);
+	}
+	return 0;
+}
+
 static bool __udp_v6_is_mcast_sock(struct net *net, struct sock *sk,
 				   __be16 loc_port, const struct in6_addr *loc_addr,
 				   __be16 rmt_port, const struct in6_addr *rmt_addr,
-				   int dif, unsigned short hnum)
+				   int dif, int sdif, unsigned short hnum)
 {
 	struct inet_sock *inet = inet_sk(sk);
 
@@ -661,7 +794,7 @@
 	    (inet->inet_dport && inet->inet_dport != rmt_port) ||
 	    (!ipv6_addr_any(&sk->sk_v6_daddr) &&
 		    !ipv6_addr_equal(&sk->sk_v6_daddr, rmt_addr)) ||
-	    (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif) ||
+	    !udp_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif) ||
 	    (!ipv6_addr_any(&sk->sk_v6_rcv_saddr) &&
 		    !ipv6_addr_equal(&sk->sk_v6_rcv_saddr, loc_addr)))
 		return false;
@@ -695,6 +828,7 @@
 	unsigned int offset = offsetof(typeof(*sk), sk_node);
 	unsigned int hash2 = 0, hash2_any = 0, use_hash2 = (hslot->count > 10);
 	int dif = inet6_iif(skb);
+	int sdif = inet6_sdif(skb);
 	struct hlist_node *node;
 	struct sk_buff *nskb;
 
@@ -709,7 +843,8 @@
 
 	sk_for_each_entry_offset_rcu(sk, node, &hslot->head, offset) {
 		if (!__udp_v6_is_mcast_sock(net, sk, uh->dest, daddr,
-					    uh->source, saddr, dif, hnum))
+					    uh->source, saddr, dif, sdif,
+					    hnum))
 			continue;
 		/* If zero checksum and no_check is not on for
 		 * the socket then skip it.
@@ -769,8 +904,7 @@
 	int ret;
 
 	if (inet_get_convert_csum(sk) && uh->check && !IS_UDPLITE(sk))
-		skb_checksum_try_convert(skb, IPPROTO_UDP, uh->check,
-					 ip6_compute_pseudo);
+		skb_checksum_try_convert(skb, IPPROTO_UDP, ip6_compute_pseudo);
 
 	ret = udpv6_queue_rcv_skb(sk, skb);
 
@@ -787,6 +921,7 @@
 	struct net *net = dev_net(skb->dev);
 	struct udphdr *uh;
 	struct sock *sk;
+	bool refcounted;
 	u32 ulen = 0;
 
 	if (!pskb_may_pull(skb, sizeof(struct udphdr)))
@@ -823,21 +958,23 @@
 		goto csum_error;
 
 	/* Check if the socket is already available, e.g. due to early demux */
-	sk = skb_steal_sock(skb);
+	sk = skb_steal_sock(skb, &refcounted);
 	if (sk) {
 		struct dst_entry *dst = skb_dst(skb);
 		int ret;
 
-		if (unlikely(sk->sk_rx_dst != dst))
+		if (unlikely(rcu_dereference(sk->sk_rx_dst) != dst))
 			udp6_sk_rx_dst_set(sk, dst);
 
 		if (!uh->check && !udp_sk(sk)->no_check6_rx) {
-			sock_put(sk);
+			if (refcounted)
+				sock_put(sk);
 			goto report_csum_error;
 		}
 
 		ret = udp6_unicast_rcv_skb(sk, skb, uh);
-		sock_put(sk);
+		if (refcounted)
+			sock_put(sk);
 		return ret;
 	}
 
@@ -904,7 +1041,7 @@
 
 	udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
 		if (sk->sk_state == TCP_ESTABLISHED &&
-		    INET6_MATCH(sk, net, rmt_addr, loc_addr, ports, dif, sdif))
+		    inet6_match(net, sk, rmt_addr, loc_addr, ports, dif, sdif))
 			return sk;
 		/* Only check first socket in chain */
 		break;
@@ -912,7 +1049,7 @@
 	return NULL;
 }
 
-static void udp_v6_early_demux(struct sk_buff *skb)
+void udp_v6_early_demux(struct sk_buff *skb)
 {
 	struct net *net = dev_net(skb->dev);
 	const struct udphdr *uh;
@@ -940,7 +1077,7 @@
 
 	skb->sk = sk;
 	skb->destructor = sock_efree;
-	dst = READ_ONCE(sk->sk_rx_dst);
+	dst = rcu_dereference(sk->sk_rx_dst);
 
 	if (dst)
 		dst = dst_check(dst, inet6_sk(sk)->rx_dst_cookie);
@@ -953,7 +1090,7 @@
 	}
 }
 
-static __inline__ int udpv6_rcv(struct sk_buff *skb)
+INDIRECT_CALLABLE_SCOPE int udpv6_rcv(struct sk_buff *skb)
 {
 	return __udp6_lib_rcv(skb, &udp_table, IPPROTO_UDP);
 }
@@ -977,6 +1114,8 @@
 static int udpv6_pre_connect(struct sock *sk, struct sockaddr *uaddr,
 			     int addr_len)
 {
+	if (addr_len < offsetofend(struct sockaddr, sa_family))
+		return -EINVAL;
 	/* The following checks are replicated from __ip6_datagram_connect()
 	 * and intended to prevent BPF program called below from accessing
 	 * bytes that are out of the bound specified by user in addr_len.
@@ -998,6 +1137,9 @@
  *	@sk:	socket we are sending on
  *	@skb:	sk_buff containing the filled-in UDP header
  *		(checksum field must be zeroed out)
+ *	@saddr: source address
+ *	@daddr: destination address
+ *	@len:	length of packet
  */
 static void udp6_hwcsum_outgoing(struct sock *sk, struct sk_buff *skb,
 				 const struct in6_addr *saddr,
@@ -1177,6 +1319,7 @@
 	ipcm6_init(&ipc6);
 	ipc6.gso_size = READ_ONCE(up->gso_size);
 	ipc6.sockc.tsflags = sk->sk_tsflags;
+	ipc6.sockc.mark = sk->sk_mark;
 
 	/* destination address check */
 	if (sin6) {
@@ -1219,9 +1362,11 @@
 			msg->msg_name = &sin;
 			msg->msg_namelen = sizeof(sin);
 do_udp_sendmsg:
-			if (__ipv6_only_sock(sk))
-				return -ENETUNREACH;
-			return udp_sendmsg(sk, msg, len);
+			err = __ipv6_only_sock(sk) ?
+				-ENETUNREACH : udp_sendmsg(sk, msg, len);
+			msg->msg_name = sin6;
+			msg->msg_namelen = addr_len;
+			return err;
 		}
 	}
 
@@ -1266,7 +1411,7 @@
 			fl6.flowlabel = sin6->sin6_flowinfo&IPV6_FLOWINFO_MASK;
 			if (fl6.flowlabel&IPV6_FLOWLABEL_MASK) {
 				flowlabel = fl6_sock_lookup(sk, fl6.flowlabel);
-				if (!flowlabel)
+				if (IS_ERR(flowlabel))
 					return -EINVAL;
 			}
 		}
@@ -1299,7 +1444,6 @@
 	if (!fl6.flowi6_oif)
 		fl6.flowi6_oif = np->sticky_pktinfo.ipi6_ifindex;
 
-	fl6.flowi6_mark = sk->sk_mark;
 	fl6.flowi6_uid = sk->sk_uid;
 
 	if (msg->msg_controllen) {
@@ -1318,7 +1462,7 @@
 		}
 		if ((fl6.flowlabel&IPV6_FLOWLABEL_MASK) && !flowlabel) {
 			flowlabel = fl6_sock_lookup(sk, fl6.flowlabel);
-			if (!flowlabel)
+			if (IS_ERR(flowlabel))
 				return -EINVAL;
 		}
 		if (!(opt->opt_nflen|opt->opt_flen))
@@ -1335,6 +1479,7 @@
 	ipc6.opt = opt;
 
 	fl6.flowi6_proto = sk->sk_protocol;
+	fl6.flowi6_mark = ipc6.sockc.mark;
 	fl6.daddr = *daddr;
 	if (ipv6_addr_any(&fl6.saddr) && !ipv6_addr_any(&np->saddr))
 		fl6.saddr = np->saddr;
@@ -1376,7 +1521,7 @@
 	} else if (!fl6.flowi6_oif)
 		fl6.flowi6_oif = np->ucast_oif;
 
-	security_sk_classify_flow(sk, flowi6_to_flowi(&fl6));
+	security_sk_classify_flow(sk, flowi6_to_flowi_common(&fl6));
 
 	if (ipc6.tclass < 0)
 		ipc6.tclass = np->tclass;
@@ -1482,38 +1627,32 @@
 	udp_v6_flush_pending_frames(sk);
 	release_sock(sk);
 
-	if (static_branch_unlikely(&udpv6_encap_needed_key) && up->encap_type) {
-		void (*encap_destroy)(struct sock *sk);
-		encap_destroy = READ_ONCE(up->encap_destroy);
-		if (encap_destroy)
-			encap_destroy(sk);
+	if (static_branch_unlikely(&udpv6_encap_needed_key)) {
+		if (up->encap_type) {
+			void (*encap_destroy)(struct sock *sk);
+			encap_destroy = READ_ONCE(up->encap_destroy);
+			if (encap_destroy)
+				encap_destroy(sk);
+		}
+		if (up->encap_enabled) {
+			static_branch_dec(&udpv6_encap_needed_key);
+			udp_encap_disable();
+		}
 	}
-
-	inet6_destroy_sock(sk);
 }
 
 /*
  *	Socket option code for UDP
  */
-int udpv6_setsockopt(struct sock *sk, int level, int optname,
-		     char __user *optval, unsigned int optlen)
+int udpv6_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval,
+		     unsigned int optlen)
 {
 	if (level == SOL_UDP  ||  level == SOL_UDPLITE)
-		return udp_lib_setsockopt(sk, level, optname, optval, optlen,
+		return udp_lib_setsockopt(sk, level, optname,
+					  optval, optlen,
 					  udp_v6_push_pending_frames);
 	return ipv6_setsockopt(sk, level, optname, optval, optlen);
 }
-
-#ifdef CONFIG_COMPAT
-int compat_udpv6_setsockopt(struct sock *sk, int level, int optname,
-			    char __user *optval, unsigned int optlen)
-{
-	if (level == SOL_UDP  ||  level == SOL_UDPLITE)
-		return udp_lib_setsockopt(sk, level, optname, optval, optlen,
-					  udp_v6_push_pending_frames);
-	return compat_ipv6_setsockopt(sk, level, optname, optval, optlen);
-}
-#endif
 
 int udpv6_getsockopt(struct sock *sk, int level, int optname,
 		     char __user *optval, int __user *optlen)
@@ -1523,22 +1662,7 @@
 	return ipv6_getsockopt(sk, level, optname, optval, optlen);
 }
 
-#ifdef CONFIG_COMPAT
-int compat_udpv6_getsockopt(struct sock *sk, int level, int optname,
-			    char __user *optval, int __user *optlen)
-{
-	if (level == SOL_UDP  ||  level == SOL_UDPLITE)
-		return udp_lib_getsockopt(sk, level, optname, optval, optlen);
-	return compat_ipv6_getsockopt(sk, level, optname, optval, optlen);
-}
-#endif
-
-/* thinking of making this const? Don't.
- * early_demux can change based on sysctl.
- */
-static struct inet6_protocol udpv6_protocol = {
-	.early_demux	=	udp_v6_early_demux,
-	.early_demux_handler =  udp_v6_early_demux,
+static const struct inet6_protocol udpv6_protocol = {
 	.handler	=	udpv6_rcv,
 	.err_handler	=	udpv6_err,
 	.flags		=	INET6_PROTO_NOPOLICY|INET6_PROTO_FINAL,
@@ -1598,7 +1722,7 @@
 	.connect		= ip6_datagram_connect,
 	.disconnect		= udp_disconnect,
 	.ioctl			= udp_ioctl,
-	.init			= udp_init_sock,
+	.init			= udpv6_init_sock,
 	.destroy		= udpv6_destroy_sock,
 	.setsockopt		= udpv6_setsockopt,
 	.getsockopt		= udpv6_getsockopt,
@@ -1615,10 +1739,6 @@
 	.sysctl_rmem_offset     = offsetof(struct net, ipv4.sysctl_udp_rmem_min),
 	.obj_size		= sizeof(struct udp6_sock),
 	.h.udp_table		= &udp_table,
-#ifdef CONFIG_COMPAT
-	.compat_setsockopt	= compat_udpv6_setsockopt,
-	.compat_getsockopt	= compat_udpv6_getsockopt,
-#endif
 	.diag_destroy		= udp_abort,
 };
 

--
Gitblit v1.6.2