From 04dd17822334871b23ea2862f7798fb0e0007777 Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Sat, 11 May 2024 08:53:19 +0000
Subject: [PATCH] udp: backport upstream lookup, tunnel ICMP error handling,
 and BPF iterator changes
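
Backport a batch of upstream UDP core changes to net/ipv4/udp.c:

- add an SPDX license identifier and drop the GPL boilerplate; switch
  from <linux/bootmem.h> to <linux/memblock.h>
- rewrite receive socket lookup: compute_score() now demands an exact
  rcv_saddr and device match, reuseport selection moves into
  lookup_reuseport(), and __udp4_lib_lookup() searches only the
  port+address hash (hash2) with a wildcard-slot fallback
- dispatch BPF sk_lookup programs via udp4_lookup_run_bpf()
- match ICMP errors against UDP tunnels (__udp4_lib_err_encap()) and
  make __udp4_lib_err()/udp_err() return an error code
- make udp_encap_needed_key reference-counted and add
  udp_encap_disable()
- preserve only the secpath skb extension when marking received skbs
  stateless (udp_try_make_stateless())
- split unexpected GSO batches in udp_queue_rcv_skb() and add the
  UDP_GRO socket option
- convert udp_lib_setsockopt() to sockptr_t and drop the compat_*
  sockopt wrappers
- rehash on disconnect when the port is locked but the address is not
- allow AF_UNSPEC in the /proc seq walker and add a BPF iterator over
  UDP sockets
- add Android restricted vendor hooks in sendmsg/recvmsg

With UDP_GRO in place, userspace opts in per socket through a plain
setsockopt().  A minimal sketch; the fd and enable_udp_gro() names are
illustrative, and the fallback define mirrors include/uapi/linux/udp.h
for toolchains whose headers predate the option:

	#include <netinet/in.h>
	#include <sys/socket.h>

	#ifndef UDP_GRO
	#define UDP_GRO 104	/* from include/uapi/linux/udp.h */
	#endif

	static int enable_udp_gro(int fd)
	{
		int on = 1;

		/* SOL_UDP == IPPROTO_UDP; handled by udp_lib_setsockopt() */
		return setsockopt(fd, IPPROTO_UDP, UDP_GRO, &on, sizeof(on));
	}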

---
 kernel/net/ipv4/udp.c |  723 +++++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 513 insertions(+), 210 deletions(-)
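
The new tunnel ICMP error path looks the tunnel socket up without
reversing source and destination port, then calls
udp_sk(sk)->encap_err_lookup, which must return 0 when the packet
embedded in the ICMP error belongs to the tunnel.  A minimal sketch of
the expected driver-side wiring, assuming the udp_tunnel_sock_cfg /
setup_udp_tunnel_sock() helpers in <net/udp_tunnel.h> carry an
encap_err_lookup member at this kernel revision; all my_tunnel_* names
are hypothetical:

	#include <net/udp_tunnel.h>

	static int my_tunnel_encap_rcv(struct sock *sk, struct sk_buff *skb)
	{
		consume_skb(skb);	/* a real driver decapsulates here */
		return 0;		/* 0: packet consumed */
	}

	static int my_tunnel_err_lookup(struct sock *sk, struct sk_buff *skb)
	{
		/* Match skb (network/transport headers point at the packet
		 * embedded in the ICMP error, courtesy of
		 * __udp4_lib_err_encap()) against known peers.
		 */
		return -ENOENT;		/* nonzero: not ours */
	}

	static void my_tunnel_setup(struct net *net, struct socket *sock)
	{
		struct udp_tunnel_sock_cfg cfg = {
			/* any nonzero encap type enables the encap path */
			.encap_type	  = UDP_ENCAP_L2TPINUDP,
			.encap_rcv	  = my_tunnel_encap_rcv,
			.encap_err_lookup = my_tunnel_err_lookup,
		};

		setup_udp_tunnel_sock(net, sock, &cfg);
	}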

diff --git a/kernel/net/ipv4/udp.c b/kernel/net/ipv4/udp.c
index b7acb6a..22862b0 100644
--- a/kernel/net/ipv4/udp.c
+++ b/kernel/net/ipv4/udp.c
@@ -1,3 +1,4 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
 /*
  * INET		An implementation of the TCP/IP protocol suite for the LINUX
  *		operating system.  INET is implemented using the  BSD Socket
@@ -69,19 +70,13 @@
  *					a single port at the same time.
  *	Derek Atkins <derek@ihtfp.com>: Add Encapulation Support
  *	James Chapman		:	Add L2TP encapsulation type.
- *
- *
- *		This program is free software; you can redistribute it and/or
- *		modify it under the terms of the GNU General Public License
- *		as published by the Free Software Foundation; either version
- *		2 of the License, or (at your option) any later version.
  */
 
 #define pr_fmt(fmt) "UDP: " fmt
 
 #include <linux/uaccess.h>
 #include <asm/ioctls.h>
-#include <linux/bootmem.h>
+#include <linux/memblock.h>
 #include <linux/highmem.h>
 #include <linux/swap.h>
 #include <linux/types.h>
@@ -105,16 +100,23 @@
 #include <net/net_namespace.h>
 #include <net/icmp.h>
 #include <net/inet_hashtables.h>
+#include <net/ip_tunnels.h>
 #include <net/route.h>
 #include <net/checksum.h>
 #include <net/xfrm.h>
 #include <trace/events/udp.h>
 #include <linux/static_key.h>
+#include <linux/btf_ids.h>
 #include <trace/events/skb.h>
 #include <net/busy_poll.h>
 #include "udp_impl.h"
 #include <net/sock_reuseport.h>
 #include <net/addrconf.h>
+#include <net/udp_tunnel.h>
+#if IS_ENABLED(CONFIG_IPV6)
+#include <net/ipv6_stubs.h>
+#endif
+#include <trace/hooks/ipv4.h>
 
 struct udp_table udp_table __read_mostly;
 EXPORT_SYMBOL(udp_table);
@@ -127,17 +129,6 @@
 
 #define MAX_UDP_PORTS 65536
 #define PORTS_PER_CHAIN (MAX_UDP_PORTS / UDP_HTABLE_SIZE_MIN)
-
-/* IPCB reference means this can not be used from early demux */
-static bool udp_lib_exact_dif_match(struct net *net, struct sk_buff *skb)
-{
-#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
-	if (!net->ipv4.sysctl_udp_l3mdev_accept &&
-	    skb && ipv4_l3mdev_skb(IPCB(skb)->flags))
-		return true;
-#endif
-	return false;
-}
 
 static int udp_lib_lport_inuse(struct net *net, __u16 num,
 			       const struct udp_hslot *hslot,
@@ -367,25 +358,23 @@
 static int compute_score(struct sock *sk, struct net *net,
 			 __be32 saddr, __be16 sport,
 			 __be32 daddr, unsigned short hnum,
-			 int dif, int sdif, bool exact_dif)
+			 int dif, int sdif)
 {
 	int score;
 	struct inet_sock *inet;
+	bool dev_match;
 
 	if (!net_eq(sock_net(sk), net) ||
 	    udp_sk(sk)->udp_port_hash != hnum ||
 	    ipv6_only_sock(sk))
 		return -1;
 
+	if (sk->sk_rcv_saddr != daddr)
+		return -1;
+
 	score = (sk->sk_family == PF_INET) ? 2 : 1;
+
 	inet = inet_sk(sk);
-
-	if (inet->inet_rcv_saddr) {
-		if (inet->inet_rcv_saddr != daddr)
-			return -1;
-		score += 4;
-	}
-
 	if (inet->inet_daddr) {
 		if (inet->inet_daddr != saddr)
 			return -1;
@@ -398,15 +387,12 @@
 		score += 4;
 	}
 
-	if (sk->sk_bound_dev_if || exact_dif) {
-		bool dev_match = (sk->sk_bound_dev_if == dif ||
-				  sk->sk_bound_dev_if == sdif);
-
-		if (!dev_match)
-			return -1;
-		if (sk->sk_bound_dev_if)
-			score += 4;
-	}
+	dev_match = udp_sk_bound_dev_eq(net, sk->sk_bound_dev_if,
+					dif, sdif);
+	if (!dev_match)
+		return -1;
+	if (sk->sk_bound_dev_if)
+		score += 4;
 
 	if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id())
 		score++;
@@ -425,41 +411,83 @@
 			      udp_ehash_secret + net_hash_mix(net));
 }
 
+static struct sock *lookup_reuseport(struct net *net, struct sock *sk,
+				     struct sk_buff *skb,
+				     __be32 saddr, __be16 sport,
+				     __be32 daddr, unsigned short hnum)
+{
+	struct sock *reuse_sk = NULL;
+	u32 hash;
+
+	if (sk->sk_reuseport && sk->sk_state != TCP_ESTABLISHED) {
+		hash = udp_ehashfn(net, daddr, hnum, saddr, sport);
+		reuse_sk = reuseport_select_sock(sk, hash, skb,
+						 sizeof(struct udphdr));
+	}
+	return reuse_sk;
+}
+
 /* called with rcu_read_lock() */
 static struct sock *udp4_lib_lookup2(struct net *net,
 				     __be32 saddr, __be16 sport,
 				     __be32 daddr, unsigned int hnum,
-				     int dif, int sdif, bool exact_dif,
+				     int dif, int sdif,
 				     struct udp_hslot *hslot2,
 				     struct sk_buff *skb)
 {
-	struct sock *sk, *result, *reuseport_result;
+	struct sock *sk, *result;
 	int score, badness;
-	u32 hash = 0;
 
 	result = NULL;
 	badness = 0;
 	udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
 		score = compute_score(sk, net, saddr, sport,
-				      daddr, hnum, dif, sdif, exact_dif);
+				      daddr, hnum, dif, sdif);
 		if (score > badness) {
-			reuseport_result = NULL;
-
-			if (sk->sk_reuseport &&
-			    sk->sk_state != TCP_ESTABLISHED) {
-				hash = udp_ehashfn(net, daddr, hnum,
-						   saddr, sport);
-				reuseport_result = reuseport_select_sock(sk, hash, skb,
-									 sizeof(struct udphdr));
-				if (reuseport_result && !reuseport_has_conns(sk, false))
-					return reuseport_result;
+			badness = score;
+			result = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum);
+			if (!result) {
+				result = sk;
+				continue;
 			}
 
-			result = reuseport_result ? : sk;
-			badness = score;
+			/* Fall back to scoring if group has connections */
+			if (!reuseport_has_conns(sk))
+				return result;
+
+			/* Reuseport logic returned an error, keep original score. */
+			if (IS_ERR(result))
+				continue;
+
+			badness = compute_score(result, net, saddr, sport,
+						daddr, hnum, dif, sdif);
+
 		}
 	}
 	return result;
+}
+
+static struct sock *udp4_lookup_run_bpf(struct net *net,
+					struct udp_table *udptable,
+					struct sk_buff *skb,
+					__be32 saddr, __be16 sport,
+					__be32 daddr, u16 hnum)
+{
+	struct sock *sk, *reuse_sk;
+	bool no_reuseport;
+
+	if (udptable != &udp_table)
+		return NULL; /* only UDP is supported */
+
+	no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_UDP,
+					    saddr, sport, daddr, hnum, &sk);
+	if (no_reuseport || IS_ERR_OR_NULL(sk))
+		return sk;
+
+	reuse_sk = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum);
+	if (reuse_sk)
+		sk = reuse_sk;
+	return sk;
 }
 
 /* UDP is nearly always wildcards out the wazoo, it makes no sense to try
@@ -469,65 +497,47 @@
 		__be16 sport, __be32 daddr, __be16 dport, int dif,
 		int sdif, struct udp_table *udptable, struct sk_buff *skb)
 {
-	struct sock *sk, *result;
 	unsigned short hnum = ntohs(dport);
-	unsigned int hash2, slot2, slot = udp_hashfn(net, hnum, udptable->mask);
-	struct udp_hslot *hslot2, *hslot = &udptable->hash[slot];
-	bool exact_dif = udp_lib_exact_dif_match(net, skb);
-	int score, badness;
-	u32 hash = 0;
+	unsigned int hash2, slot2;
+	struct udp_hslot *hslot2;
+	struct sock *result, *sk;
 
-	if (hslot->count > 10) {
-		hash2 = ipv4_portaddr_hash(net, daddr, hnum);
-		slot2 = hash2 & udptable->mask;
-		hslot2 = &udptable->hash2[slot2];
-		if (hslot->count < hslot2->count)
-			goto begin;
+	hash2 = ipv4_portaddr_hash(net, daddr, hnum);
+	slot2 = hash2 & udptable->mask;
+	hslot2 = &udptable->hash2[slot2];
 
-		result = udp4_lib_lookup2(net, saddr, sport,
-					  daddr, hnum, dif, sdif,
-					  exact_dif, hslot2, skb);
-		if (!result) {
-			unsigned int old_slot2 = slot2;
-			hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum);
-			slot2 = hash2 & udptable->mask;
-			/* avoid searching the same slot again. */
-			if (unlikely(slot2 == old_slot2))
-				return result;
+	/* Lookup connected or non-wildcard socket */
+	result = udp4_lib_lookup2(net, saddr, sport,
+				  daddr, hnum, dif, sdif,
+				  hslot2, skb);
+	if (!IS_ERR_OR_NULL(result) && result->sk_state == TCP_ESTABLISHED)
+		goto done;
 
-			hslot2 = &udptable->hash2[slot2];
-			if (hslot->count < hslot2->count)
-				goto begin;
-
-			result = udp4_lib_lookup2(net, saddr, sport,
-						  daddr, hnum, dif, sdif,
-						  exact_dif, hslot2, skb);
-		}
-		if (unlikely(IS_ERR(result)))
-			return NULL;
-		return result;
-	}
-begin:
-	result = NULL;
-	badness = 0;
-	sk_for_each_rcu(sk, &hslot->head) {
-		score = compute_score(sk, net, saddr, sport,
-				      daddr, hnum, dif, sdif, exact_dif);
-		if (score > badness) {
-			if (sk->sk_reuseport) {
-				hash = udp_ehashfn(net, daddr, hnum,
-						   saddr, sport);
-				result = reuseport_select_sock(sk, hash, skb,
-							sizeof(struct udphdr));
-				if (unlikely(IS_ERR(result)))
-					return NULL;
-				if (result)
-					return result;
-			}
+	/* Lookup redirect from BPF */
+	if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
+		sk = udp4_lookup_run_bpf(net, udptable, skb,
+					 saddr, sport, daddr, hnum);
+		if (sk) {
 			result = sk;
-			badness = score;
+			goto done;
 		}
 	}
+
+	/* Got non-wildcard socket or error on first lookup */
+	if (result)
+		goto done;
+
+	/* Lookup wildcard sockets */
+	hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum);
+	slot2 = hash2 & udptable->mask;
+	hslot2 = &udptable->hash2[slot2];
+
+	result = udp4_lib_lookup2(net, saddr, sport,
+				  htonl(INADDR_ANY), hnum, dif, sdif,
+				  hslot2, skb);
+done:
+	if (IS_ERR(result))
+		return NULL;
 	return result;
 }
 EXPORT_SYMBOL_GPL(__udp4_lib_lookup);
@@ -585,12 +595,102 @@
 	    (inet->inet_dport != rmt_port && inet->inet_dport) ||
 	    (inet->inet_rcv_saddr && inet->inet_rcv_saddr != loc_addr) ||
 	    ipv6_only_sock(sk) ||
-	    (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif &&
-	     sk->sk_bound_dev_if != sdif))
+	    !udp_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif))
 		return false;
 	if (!ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif, sdif))
 		return false;
 	return true;
+}
+
+DEFINE_STATIC_KEY_FALSE(udp_encap_needed_key);
+void udp_encap_enable(void)
+{
+	static_branch_inc(&udp_encap_needed_key);
+}
+EXPORT_SYMBOL(udp_encap_enable);
+
+void udp_encap_disable(void)
+{
+	static_branch_dec(&udp_encap_needed_key);
+}
+EXPORT_SYMBOL(udp_encap_disable);
+
+/* Handler for tunnels with arbitrary destination ports: no socket lookup, go
+ * through error handlers in encapsulations looking for a match.
+ */
+static int __udp4_lib_err_encap_no_sk(struct sk_buff *skb, u32 info)
+{
+	int i;
+
+	for (i = 0; i < MAX_IPTUN_ENCAP_OPS; i++) {
+		int (*handler)(struct sk_buff *skb, u32 info);
+		const struct ip_tunnel_encap_ops *encap;
+
+		encap = rcu_dereference(iptun_encaps[i]);
+		if (!encap)
+			continue;
+		handler = encap->err_handler;
+		if (handler && !handler(skb, info))
+			return 0;
+	}
+
+	return -ENOENT;
+}
+
+/* Try to match ICMP errors to UDP tunnels by looking up a socket without
+ * reversing source and destination port: this will match tunnels that force the
+ * same destination port on both endpoints (e.g. VXLAN, GENEVE). Note that
+ * lwtunnels might actually break this assumption by being configured with
+ * different destination ports on endpoints, in this case we won't be able to
+ * trace ICMP messages back to them.
+ *
+ * If this doesn't match any socket, probe tunnels with arbitrary destination
+ * ports (e.g. FoU, GUE): there, the receiving socket is useless, as the port
+ * we've sent packets to won't necessarily match the local destination port.
+ *
+ * Then ask the tunnel implementation to match the error against a valid
+ * association.
+ *
+ * Return an error if we can't find a match, the socket if we need further
+ * processing, zero otherwise.
+ */
+static struct sock *__udp4_lib_err_encap(struct net *net,
+					 const struct iphdr *iph,
+					 struct udphdr *uh,
+					 struct udp_table *udptable,
+					 struct sk_buff *skb, u32 info)
+{
+	int network_offset, transport_offset;
+	struct sock *sk;
+
+	network_offset = skb_network_offset(skb);
+	transport_offset = skb_transport_offset(skb);
+
+	/* Network header needs to point to the outer IPv4 header inside ICMP */
+	skb_reset_network_header(skb);
+
+	/* Transport header needs to point to the UDP header */
+	skb_set_transport_header(skb, iph->ihl << 2);
+
+	sk = __udp4_lib_lookup(net, iph->daddr, uh->source,
+			       iph->saddr, uh->dest, skb->dev->ifindex, 0,
+			       udptable, NULL);
+	if (sk) {
+		int (*lookup)(struct sock *sk, struct sk_buff *skb);
+		struct udp_sock *up = udp_sk(sk);
+
+		lookup = READ_ONCE(up->encap_err_lookup);
+		if (!lookup || lookup(sk, skb))
+			sk = NULL;
+	}
+
+	if (!sk)
+		sk = ERR_PTR(__udp4_lib_err_encap_no_sk(skb, info));
+
+	skb_set_transport_header(skb, transport_offset);
+	skb_set_network_header(skb, network_offset);
+
+	return sk;
 }
 
 /*
@@ -604,24 +704,38 @@
  * to find the appropriate port.
  */
 
-void __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable)
+int __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable)
 {
 	struct inet_sock *inet;
 	const struct iphdr *iph = (const struct iphdr *)skb->data;
 	struct udphdr *uh = (struct udphdr *)(skb->data+(iph->ihl<<2));
 	const int type = icmp_hdr(skb)->type;
 	const int code = icmp_hdr(skb)->code;
+	bool tunnel = false;
 	struct sock *sk;
 	int harderr;
 	int err;
 	struct net *net = dev_net(skb->dev);
 
 	sk = __udp4_lib_lookup(net, iph->daddr, uh->dest,
-			       iph->saddr, uh->source, skb->dev->ifindex, 0,
-			       udptable, NULL);
+			       iph->saddr, uh->source, skb->dev->ifindex,
+			       inet_sdif(skb), udptable, NULL);
 	if (!sk) {
-		__ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
-		return;	/* No socket for error */
+		/* No socket for error: try tunnels before discarding */
+		sk = ERR_PTR(-ENOENT);
+		if (static_branch_unlikely(&udp_encap_needed_key)) {
+			sk = __udp4_lib_err_encap(net, iph, uh, udptable, skb,
+						  info);
+			if (!sk)
+				return 0;
+		}
+
+		if (IS_ERR(sk)) {
+			__ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
+			return PTR_ERR(sk);
+		}
+
+		tunnel = true;
 	}
 
 	err = 0;
@@ -664,6 +778,10 @@
 	 *      RFC1122: OK.  Passes ICMP errors back to application, as per
 	 *	4.1.3.3.
 	 */
+	if (tunnel) {
+		/* ...not for tunnels though: we don't have a sending socket */
+		goto out;
+	}
 	if (!inet->recverr) {
 		if (!harderr || sk->sk_state != TCP_ESTABLISHED)
 			goto out;
@@ -673,12 +791,12 @@
 	sk->sk_err = err;
 	sk->sk_error_report(sk);
 out:
-	return;
+	return 0;
 }
 
-void udp_err(struct sk_buff *skb, u32 info)
+int udp_err(struct sk_buff *skb, u32 info)
 {
-	__udp4_lib_err(skb, info, &udp_table);
+	return __udp4_lib_err(skb, info, &udp_table);
 }
 
 /*
@@ -949,6 +1067,7 @@
 
 	if (msg->msg_flags & MSG_OOB) /* Mirror BSD error message compatibility */
 		return -EOPNOTSUPP;
+	trace_android_rvh_udp_sendmsg(sk);
 
 	getfrag = is_udplite ? udplite_getfrag : ip_generic_getfrag;
 
@@ -1061,7 +1180,7 @@
 	}
 
 	if (ipv4_is_multicast(daddr)) {
-		if (!ipc.oif)
+		if (!ipc.oif || netif_index_is_l3_master(sock_net(sk), ipc.oif))
 			ipc.oif = inet->mc_index;
 		if (!saddr)
 			saddr = inet->mc_addr;
@@ -1070,7 +1189,7 @@
 		ipc.oif = inet->uc_index;
 	} else if (ipv4_is_lbcast(daddr) && inet->uc_index) {
 		/* oif is set, packet is to local broadcast and
-		 * and uc_index is set. oif is most likely set
+		 * uc_index is set. oif is most likely set
 		 * by sk_bound_dev_if. If uc_index != oif check if the
 		 * oif is an L3 master and uc_index is an L3 slave.
 		 * If so, we want to allow the send using the uc_index.
@@ -1091,13 +1210,13 @@
 
 		fl4 = &fl4_stack;
 
-		flowi4_init_output(fl4, ipc.oif, sk->sk_mark, tos,
+		flowi4_init_output(fl4, ipc.oif, ipc.sockc.mark, tos,
 				   RT_SCOPE_UNIVERSE, sk->sk_protocol,
 				   flow_flags,
 				   faddr, saddr, dport, inet->inet_sport,
 				   sk->sk_uid);
 
-		security_sk_classify_flow(sk, flowi4_to_flowi(fl4));
+		security_sk_classify_flow(sk, flowi4_to_flowi_common(fl4));
 		rt = ip_route_output_flow(net, fl4, sk);
 		if (IS_ERR(rt)) {
 			err = PTR_ERR(rt);
@@ -1254,6 +1373,27 @@
 
 #define UDP_SKB_IS_STATELESS 0x80000000
 
+/* all head states (dst, sk, nf conntrack) except skb extensions are
+ * cleared by udp_rcv().
+ *
+ * We need to preserve secpath, if present, to eventually process
+ * IP_CMSG_PASSSEC at recvmsg() time.
+ *
+ * Other extensions can be cleared.
+ */
+static bool udp_try_make_stateless(struct sk_buff *skb)
+{
+	if (!skb_has_extensions(skb))
+		return true;
+
+	if (!secpath_exists(skb)) {
+		skb_ext_reset(skb);
+		return true;
+	}
+
+	return false;
+}
+
 static void udp_set_dev_scratch(struct sk_buff *skb)
 {
 	struct udp_dev_scratch *scratch = udp_skb_scratch(skb);
@@ -1265,11 +1405,7 @@
 	scratch->csum_unnecessary = !!skb_csum_unnecessary(skb);
 	scratch->is_linear = !skb_is_nonlinear(skb);
 #endif
-	/* all head states execept sp (dst, sk, nf) are always cleared by
-	 * udp_rcv() and we need to preserve secpath, if present, to eventually
-	 * process IP_CMSG_PASSSEC at recvmsg() time
-	 */
-	if (likely(!skb_sec_path(skb)))
+	if (udp_try_make_stateless(skb))
 		scratch->_tsize_state |= UDP_SKB_IS_STATELESS;
 }
 
@@ -1458,7 +1594,7 @@
 }
 EXPORT_SYMBOL_GPL(__udp_enqueue_schedule_skb);
 
-void udp_destruct_sock(struct sock *sk)
+void udp_destruct_common(struct sock *sk)
 {
 	/* reclaim completely the forward allocated memory */
 	struct udp_sock *up = udp_sk(sk);
@@ -1471,10 +1607,14 @@
 		kfree_skb(skb);
 	}
 	udp_rmem_release(sk, total, 0, true);
+}
+EXPORT_SYMBOL_GPL(udp_destruct_common);
 
+static void udp_destruct_sock(struct sock *sk)
+{
+	udp_destruct_common(sk);
 	inet_sock_destruct(sk);
 }
-EXPORT_SYMBOL_GPL(udp_destruct_sock);
 
 int udp_init_sock(struct sock *sk)
 {
@@ -1482,7 +1622,6 @@
 	sk->sk_destruct = udp_destruct_sock;
 	return 0;
 }
-EXPORT_SYMBOL_GPL(udp_init_sock);
 
 void skb_consume_udp(struct sock *sk, struct sk_buff *skb, int len)
 {
@@ -1590,7 +1729,7 @@
 EXPORT_SYMBOL(udp_ioctl);
 
 struct sk_buff *__skb_recv_udp(struct sock *sk, unsigned int flags,
-			       int noblock, int *peeked, int *off, int *err)
+			       int noblock, int *off, int *err)
 {
 	struct sk_buff_head *sk_queue = &sk->sk_receive_queue;
 	struct sk_buff_head *queue;
@@ -1609,14 +1748,13 @@
 			break;
 
 		error = -EAGAIN;
-		*peeked = 0;
 		do {
 			spin_lock_bh(&queue->lock);
-			skb = __skb_try_recv_from_queue(sk, queue, flags,
-							udp_skb_destructor,
-							peeked, off, err,
-							&last);
+			skb = __skb_try_recv_from_queue(sk, queue, flags, off,
+							err, &last);
 			if (skb) {
+				if (!(flags & MSG_PEEK))
+					udp_skb_destructor(sk, skb);
 				spin_unlock_bh(&queue->lock);
 				return skb;
 			}
@@ -1634,10 +1772,10 @@
 			spin_lock(&sk_queue->lock);
 			skb_queue_splice_tail_init(sk_queue, queue);
 
-			skb = __skb_try_recv_from_queue(sk, queue, flags,
-							udp_skb_dtor_locked,
-							peeked, off, err,
-							&last);
+			skb = __skb_try_recv_from_queue(sk, queue, flags, off,
+							err, &last);
+			if (skb && !(flags & MSG_PEEK))
+				udp_skb_dtor_locked(sk, skb);
 			spin_unlock(&sk_queue->lock);
 			spin_unlock_bh(&queue->lock);
 			if (skb)
@@ -1652,7 +1790,8 @@
 
 		/* sk_queue is empty, reader_queue may contain peeked packets */
 	} while (timeo &&
-		 !__skb_wait_for_more_packets(sk, &error, &timeo,
+		 !__skb_wait_for_more_packets(sk, &sk->sk_receive_queue,
+					      &error, &timeo,
 					      (struct sk_buff *)sk_queue));
 
 	*err = error;
@@ -1672,8 +1811,7 @@
 	DECLARE_SOCKADDR(struct sockaddr_in *, sin, msg->msg_name);
 	struct sk_buff *skb;
 	unsigned int ulen, copied;
-	int peeked, peeking, off;
-	int err;
+	int off, err, peeking = flags & MSG_PEEK;
 	int is_udplite = IS_UDPLITE(sk);
 	bool checksum_valid = false;
 
@@ -1681,11 +1819,11 @@
 		return ip_recv_error(sk, msg, len, addr_len);
 
 try_again:
-	peeking = flags & MSG_PEEK;
 	off = sk_peek_offset(sk, flags);
-	skb = __skb_recv_udp(sk, flags, noblock, &peeked, &off, &err);
+	skb = __skb_recv_udp(sk, flags, noblock, &off, &err);
 	if (!skb)
 		return err;
+	trace_android_rvh_udp_recvmsg(sk);
 
 	ulen = udp_skb_len(skb);
 	copied = len;
@@ -1721,7 +1859,7 @@
 	}
 
 	if (unlikely(err)) {
-		if (!peeked) {
+		if (!peeking) {
 			atomic_inc(&sk->sk_drops);
 			UDP_INC_STATS(sock_net(sk),
 				      UDP_MIB_INERRORS, is_udplite);
@@ -1730,7 +1868,7 @@
 		return err;
 	}
 
-	if (!peeked)
+	if (!peeking)
 		UDP_INC_STATS(sock_net(sk),
 			      UDP_MIB_INDATAGRAMS, is_udplite);
 
@@ -1748,6 +1886,10 @@
 			BPF_CGROUP_RUN_PROG_UDP4_RECVMSG_LOCK(sk,
 							(struct sockaddr *)sin);
 	}
+
+	if (udp_sk(sk)->gro_enabled)
+		udp_cmsg_recv(msg, sk, skb);
+
 	if (inet->cmsg_flags)
 		ip_cmsg_recv_offset(msg, sk, skb, sizeof(struct udphdr), off);
 
@@ -1797,8 +1939,12 @@
 	inet->inet_dport = 0;
 	sock_rps_reset_rxhash(sk);
 	sk->sk_bound_dev_if = 0;
-	if (!(sk->sk_userlocks & SOCK_BINDADDR_LOCK))
+	if (!(sk->sk_userlocks & SOCK_BINDADDR_LOCK)) {
 		inet_reset_saddr(sk);
+		if (sk->sk_prot->rehash &&
+		    (sk->sk_userlocks & SOCK_BINDPORT_LOCK))
+			sk->sk_prot->rehash(sk);
+	}
 
 	if (!(sk->sk_userlocks & SOCK_BINDPORT_LOCK)) {
 		sk->sk_prot->unhash(sk);
@@ -1887,7 +2033,7 @@
 }
 EXPORT_SYMBOL(udp_lib_rehash);
 
-static void udp_v4_rehash(struct sock *sk)
+void udp_v4_rehash(struct sock *sk)
 {
 	u16 new_hash = ipv4_portaddr_hash(sock_net(sk),
 					  inet_sk(sk)->inet_rcv_saddr,
@@ -1924,13 +2070,6 @@
 	return 0;
 }
 
-static DEFINE_STATIC_KEY_FALSE(udp_encap_needed_key);
-void udp_encap_enable(void)
-{
-	static_branch_enable(&udp_encap_needed_key);
-}
-EXPORT_SYMBOL(udp_encap_enable);
-
 /* returns:
  *  -1: error
  *   0: success
@@ -1939,7 +2078,7 @@
  * Note that in the success and error cases, the skb is assumed to
  * have either been requeued or freed.
  */
-static int udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
+static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb)
 {
 	struct udp_sock *up = udp_sk(sk);
 	int is_udplite = IS_UDPLITE(sk);
@@ -1949,7 +2088,7 @@
 	 */
 	if (!xfrm4_policy_check(sk, XFRM_POLICY_IN, skb))
 		goto drop;
-	nf_reset(skb);
+	nf_reset_ct(skb);
 
 	if (static_branch_unlikely(&udp_encap_needed_key) && up->encap_type) {
 		int (*encap_rcv)(struct sock *sk, struct sk_buff *skb);
@@ -2042,6 +2181,26 @@
 	return -1;
 }
 
+static int udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
+{
+	struct sk_buff *next, *segs;
+	int ret;
+
+	if (likely(!udp_unexpected_gso(sk, skb)))
+		return udp_queue_rcv_one_skb(sk, skb);
+
+	BUILD_BUG_ON(sizeof(struct udp_skb_cb) > SKB_GSO_CB_OFFSET);
+	__skb_push(skb, -skb_mac_offset(skb));
+	segs = udp_rcv_segment(sk, skb, true);
+	skb_list_walk_safe(segs, skb, next) {
+		__skb_pull(skb, skb_transport_offset(skb));
+		ret = udp_queue_rcv_one_skb(sk, skb);
+		if (ret > 0)
+			ip_protocol_deliver_rcu(dev_net(skb->dev), skb, ret);
+	}
+	return 0;
+}
+
 /* For TCP sockets, sk_rx_dst is protected by socket lock
  * For UDP, we use xchg() to guard against concurrent changes.
  */
@@ -2050,7 +2209,7 @@
 	struct dst_entry *old;
 
 	if (dst_hold_safe(dst)) {
-		old = xchg(&sk->sk_rx_dst, dst);
+		old = xchg((__force struct dst_entry **)&sk->sk_rx_dst, dst);
 		dst_release(old);
 		return old != dst;
 	}
@@ -2130,7 +2289,7 @@
 
 /* Initialize UDP checksum. If exited with zero value (success),
  * CHECKSUM_UNNECESSARY means, that no more checks are required.
- * Otherwise, csum completion requires chacksumming packet body,
+ * Otherwise, csum completion requires checksumming packet body,
  * including udp header and folding it to skb->csum.
  */
 static inline int udp4_csum_init(struct sk_buff *skb, struct udphdr *uh,
@@ -2184,8 +2343,7 @@
 	int ret;
 
 	if (inet_get_convert_csum(sk) && uh->check && !IS_UDPLITE(sk))
-		skb_checksum_try_convert(skb, IPPROTO_UDP, uh->check,
-					 inet_compute_pseudo);
+		skb_checksum_try_convert(skb, IPPROTO_UDP, inet_compute_pseudo);
 
 	ret = udp_queue_rcv_skb(sk, skb);
 
@@ -2210,6 +2368,7 @@
 	struct rtable *rt = skb_rtable(skb);
 	__be32 saddr, daddr;
 	struct net *net = dev_net(skb->dev);
+	bool refcounted;
 
 	/*
 	 *  Validate the packet.
@@ -2235,16 +2394,17 @@
 	if (udp4_csum_init(skb, uh, proto))
 		goto csum_error;
 
-	sk = skb_steal_sock(skb);
+	sk = skb_steal_sock(skb, &refcounted);
 	if (sk) {
 		struct dst_entry *dst = skb_dst(skb);
 		int ret;
 
-		if (unlikely(sk->sk_rx_dst != dst))
+		if (unlikely(rcu_dereference(sk->sk_rx_dst) != dst))
 			udp_sk_rx_dst_set(sk, dst);
 
 		ret = udp_unicast_rcv_skb(sk, skb, uh);
-		sock_put(sk);
+		if (refcounted)
+			sock_put(sk);
 		return ret;
 	}
 
@@ -2258,7 +2418,7 @@
 
 	if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb))
 		goto drop;
-	nf_reset(skb);
+	nf_reset_ct(skb);
 
 	/* No socket. Drop packet silently, if checksum is wrong */
 	if (udp_lib_checksum_complete(skb))
@@ -2346,8 +2506,7 @@
 	struct sock *sk;
 
 	udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
-		if (INET_MATCH(sk, net, acookie, rmt_addr,
-			       loc_addr, ports, dif, sdif))
+		if (INET_MATCH(net, sk, acookie, ports, dif, sdif))
 			return sk;
 		/* Only check first socket in chain */
 		break;
@@ -2398,7 +2557,7 @@
 
 	skb->sk = sk;
 	skb->destructor = sock_efree;
-	dst = READ_ONCE(sk->sk_rx_dst);
+	dst = rcu_dereference(sk->sk_rx_dst);
 
 	if (dst)
 		dst = dst_check(dst, 0);
@@ -2437,11 +2596,15 @@
 	sock_set_flag(sk, SOCK_DEAD);
 	udp_flush_pending_frames(sk);
 	unlock_sock_fast(sk, slow);
-	if (static_branch_unlikely(&udp_encap_needed_key) && up->encap_type) {
-		void (*encap_destroy)(struct sock *sk);
-		encap_destroy = READ_ONCE(up->encap_destroy);
-		if (encap_destroy)
-			encap_destroy(sk);
+	if (static_branch_unlikely(&udp_encap_needed_key)) {
+		if (up->encap_type) {
+			void (*encap_destroy)(struct sock *sk);
+			encap_destroy = READ_ONCE(up->encap_destroy);
+			if (encap_destroy)
+				encap_destroy(sk);
+		}
+		if (up->encap_enabled)
+			static_branch_dec(&udp_encap_needed_key);
 	}
 }
 
@@ -2449,7 +2612,7 @@
  *	Socket option code for UDP
  */
 int udp_lib_setsockopt(struct sock *sk, int level, int optname,
-		       char __user *optval, unsigned int optlen,
+		       sockptr_t optval, unsigned int optlen,
 		       int (*push_pending_frames)(struct sock *))
 {
 	struct udp_sock *up = udp_sk(sk);
@@ -2460,7 +2623,7 @@
 	if (optlen < sizeof(int))
 		return -EINVAL;
 
-	if (get_user(val, (int __user *)optval))
+	if (copy_from_sockptr(&val, optval, sizeof(val)))
 		return -EFAULT;
 
 	valbool = val ? 1 : 0;
@@ -2480,13 +2643,22 @@
 	case UDP_ENCAP:
 		switch (val) {
 		case 0:
+#ifdef CONFIG_XFRM
 		case UDP_ENCAP_ESPINUDP:
 		case UDP_ENCAP_ESPINUDP_NON_IKE:
-			up->encap_rcv = xfrm4_udp_encap_rcv;
-			/* FALLTHROUGH */
+#if IS_ENABLED(CONFIG_IPV6)
+			if (sk->sk_family == AF_INET6)
+				up->encap_rcv = ipv6_stub->xfrm6_udp_encap_rcv;
+			else
+#endif
+				up->encap_rcv = xfrm4_udp_encap_rcv;
+#endif
+			fallthrough;
 		case UDP_ENCAP_L2TPINUDP:
 			up->encap_type = val;
-			udp_encap_enable();
+			lock_sock(sk);
+			udp_tunnel_encap_enable(sk->sk_socket);
+			release_sock(sk);
 			break;
 		default:
 			err = -ENOPROTOOPT;
@@ -2506,6 +2678,17 @@
 		if (val < 0 || val > USHRT_MAX)
 			return -EINVAL;
 		WRITE_ONCE(up->gso_size, val);
+		break;
+
+	case UDP_GRO:
+		lock_sock(sk);
+
+		/* when enabling GRO, accept the related GSO packet type */
+		if (valbool)
+			udp_tunnel_encap_enable(sk->sk_socket);
+		up->gro_enabled = valbool;
+		up->accept_udp_l4 = valbool;
+		release_sock(sk);
 		break;
 
 	/*
@@ -2547,25 +2730,15 @@
 }
 EXPORT_SYMBOL(udp_lib_setsockopt);
 
-int udp_setsockopt(struct sock *sk, int level, int optname,
-		   char __user *optval, unsigned int optlen)
+int udp_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval,
+		   unsigned int optlen)
 {
 	if (level == SOL_UDP  ||  level == SOL_UDPLITE)
-		return udp_lib_setsockopt(sk, level, optname, optval, optlen,
+		return udp_lib_setsockopt(sk, level, optname,
+					  optval, optlen,
 					  udp_push_pending_frames);
 	return ip_setsockopt(sk, level, optname, optval, optlen);
 }
-
-#ifdef CONFIG_COMPAT
-int compat_udp_setsockopt(struct sock *sk, int level, int optname,
-			  char __user *optval, unsigned int optlen)
-{
-	if (level == SOL_UDP  ||  level == SOL_UDPLITE)
-		return udp_lib_setsockopt(sk, level, optname, optval, optlen,
-					  udp_push_pending_frames);
-	return compat_ip_setsockopt(sk, level, optname, optval, optlen);
-}
-#endif
 
 int udp_lib_getsockopt(struct sock *sk, int level, int optname,
 		       char __user *optval, int __user *optlen)
@@ -2602,6 +2775,10 @@
 		val = READ_ONCE(up->gso_size);
 		break;
 
+	case UDP_GRO:
+		val = up->gro_enabled;
+		break;
+
 	/* The following two cannot be changed on UDP sockets, the return is
 	 * always 0 (which corresponds to the full checksum coverage of UDP). */
 	case UDPLITE_SEND_CSCOV:
@@ -2632,20 +2809,11 @@
 	return ip_getsockopt(sk, level, optname, optval, optlen);
 }
 
-#ifdef CONFIG_COMPAT
-int compat_udp_getsockopt(struct sock *sk, int level, int optname,
-				 char __user *optval, int __user *optlen)
-{
-	if (level == SOL_UDP  ||  level == SOL_UDPLITE)
-		return udp_lib_getsockopt(sk, level, optname, optval, optlen);
-	return compat_ip_getsockopt(sk, level, optname, optval, optlen);
-}
-#endif
 /**
  * 	udp_poll - wait for a UDP event.
- *	@file - file struct
- *	@sock - socket
- *	@wait - poll table
+ *	@file: - file struct
+ *	@sock: - socket
+ *	@wait: - poll table
  *
  *	This is same as datagram poll, except for the special case of
  *	blocking sockets. If application is using a blocking fd
@@ -2719,10 +2887,6 @@
 	.sysctl_rmem_offset	= offsetof(struct net, ipv4.sysctl_udp_rmem_min),
 	.obj_size		= sizeof(struct udp_sock),
 	.h.udp_table		= &udp_table,
-#ifdef CONFIG_COMPAT
-	.compat_setsockopt	= compat_udp_setsockopt,
-	.compat_getsockopt	= compat_udp_getsockopt,
-#endif
 	.diag_destroy		= udp_abort,
 };
 EXPORT_SYMBOL(udp_prot);
@@ -2733,9 +2897,14 @@
 static struct sock *udp_get_first(struct seq_file *seq, int start)
 {
 	struct sock *sk;
-	struct udp_seq_afinfo *afinfo = PDE_DATA(file_inode(seq->file));
+	struct udp_seq_afinfo *afinfo;
 	struct udp_iter_state *state = seq->private;
 	struct net *net = seq_file_net(seq);
+
+	if (state->bpf_seq_afinfo)
+		afinfo = state->bpf_seq_afinfo;
+	else
+		afinfo = PDE_DATA(file_inode(seq->file));
 
 	for (state->bucket = start; state->bucket <= afinfo->udp_table->mask;
 	     ++state->bucket) {
@@ -2748,7 +2917,8 @@
 		sk_for_each(sk, &hslot->head) {
 			if (!net_eq(sock_net(sk), net))
 				continue;
-			if (sk->sk_family == afinfo->family)
+			if (afinfo->family == AF_UNSPEC ||
+			    sk->sk_family == afinfo->family)
 				goto found;
 		}
 		spin_unlock_bh(&hslot->lock);
@@ -2760,13 +2930,20 @@
 
 static struct sock *udp_get_next(struct seq_file *seq, struct sock *sk)
 {
-	struct udp_seq_afinfo *afinfo = PDE_DATA(file_inode(seq->file));
+	struct udp_seq_afinfo *afinfo;
 	struct udp_iter_state *state = seq->private;
 	struct net *net = seq_file_net(seq);
 
+	if (state->bpf_seq_afinfo)
+		afinfo = state->bpf_seq_afinfo;
+	else
+		afinfo = PDE_DATA(file_inode(seq->file));
+
 	do {
 		sk = sk_next(sk);
-	} while (sk && (!net_eq(sock_net(sk), net) || sk->sk_family != afinfo->family));
+	} while (sk && (!net_eq(sock_net(sk), net) ||
+			(afinfo->family != AF_UNSPEC &&
+			 sk->sk_family != afinfo->family)));
 
 	if (!sk) {
 		if (state->bucket <= afinfo->udp_table->mask)
@@ -2811,8 +2988,13 @@
 
 void udp_seq_stop(struct seq_file *seq, void *v)
 {
-	struct udp_seq_afinfo *afinfo = PDE_DATA(file_inode(seq->file));
+	struct udp_seq_afinfo *afinfo;
 	struct udp_iter_state *state = seq->private;
+
+	if (state->bpf_seq_afinfo)
+		afinfo = state->bpf_seq_afinfo;
+	else
+		afinfo = PDE_DATA(file_inode(seq->file));
 
 	if (state->bucket <= afinfo->udp_table->mask)
 		spin_unlock_bh(&afinfo->udp_table->hash[state->bucket].lock);
@@ -2830,7 +3012,7 @@
 	__u16 srcp	  = ntohs(inet->inet_sport);
 
 	seq_printf(f, "%5d: %08X:%04X %08X:%04X"
-		" %02X %08X:%08X %02X:%08lX %08X %5u %8d %lu %d %pK %d",
+		" %02X %08X:%08X %02X:%08lX %08X %5u %8d %lu %d %pK %u",
 		bucket, src, srcp, dest, destp, sp->sk_state,
 		sk_wmem_alloc_get(sp),
 		udp_rqueue_get(sp),
@@ -2856,6 +3038,67 @@
 	seq_pad(seq, '\n');
 	return 0;
 }
+
+#ifdef CONFIG_BPF_SYSCALL
+struct bpf_iter__udp {
+	__bpf_md_ptr(struct bpf_iter_meta *, meta);
+	__bpf_md_ptr(struct udp_sock *, udp_sk);
+	uid_t uid __aligned(8);
+	int bucket __aligned(8);
+};
+
+static int udp_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta,
+			     struct udp_sock *udp_sk, uid_t uid, int bucket)
+{
+	struct bpf_iter__udp ctx;
+
+	meta->seq_num--;  /* skip SEQ_START_TOKEN */
+	ctx.meta = meta;
+	ctx.udp_sk = udp_sk;
+	ctx.uid = uid;
+	ctx.bucket = bucket;
+	return bpf_iter_run_prog(prog, &ctx);
+}
+
+static int bpf_iter_udp_seq_show(struct seq_file *seq, void *v)
+{
+	struct udp_iter_state *state = seq->private;
+	struct bpf_iter_meta meta;
+	struct bpf_prog *prog;
+	struct sock *sk = v;
+	uid_t uid;
+
+	if (v == SEQ_START_TOKEN)
+		return 0;
+
+	uid = from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk));
+	meta.seq = seq;
+	prog = bpf_iter_get_info(&meta, false);
+	return udp_prog_seq_show(prog, &meta, v, uid, state->bucket);
+}
+
+static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v)
+{
+	struct bpf_iter_meta meta;
+	struct bpf_prog *prog;
+
+	if (!v) {
+		meta.seq = seq;
+		prog = bpf_iter_get_info(&meta, true);
+		if (prog)
+			(void)udp_prog_seq_show(prog, &meta, v, 0, 0);
+	}
+
+	udp_seq_stop(seq, v);
+}
+
+static const struct seq_operations bpf_iter_udp_seq_ops = {
+	.start		= udp_seq_start,
+	.next		= udp_seq_next,
+	.stop		= bpf_iter_udp_seq_stop,
+	.show		= bpf_iter_udp_seq_show,
+};
+#endif
 
 const struct seq_operations udp_seq_ops = {
 	.start		= udp_seq_start,
@@ -2974,6 +3217,62 @@
 	.init	= udp_sysctl_init,
 };
 
+#if defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS)
+DEFINE_BPF_ITER_FUNC(udp, struct bpf_iter_meta *meta,
+		     struct udp_sock *udp_sk, uid_t uid, int bucket)
+
+static int bpf_iter_init_udp(void *priv_data, struct bpf_iter_aux_info *aux)
+{
+	struct udp_iter_state *st = priv_data;
+	struct udp_seq_afinfo *afinfo;
+	int ret;
+
+	afinfo = kmalloc(sizeof(*afinfo), GFP_USER | __GFP_NOWARN);
+	if (!afinfo)
+		return -ENOMEM;
+
+	afinfo->family = AF_UNSPEC;
+	afinfo->udp_table = &udp_table;
+	st->bpf_seq_afinfo = afinfo;
+	ret = bpf_iter_init_seq_net(priv_data, aux);
+	if (ret)
+		kfree(afinfo);
+	return ret;
+}
+
+static void bpf_iter_fini_udp(void *priv_data)
+{
+	struct udp_iter_state *st = priv_data;
+
+	kfree(st->bpf_seq_afinfo);
+	bpf_iter_fini_seq_net(priv_data);
+}
+
+static const struct bpf_iter_seq_info udp_seq_info = {
+	.seq_ops		= &bpf_iter_udp_seq_ops,
+	.init_seq_private	= bpf_iter_init_udp,
+	.fini_seq_private	= bpf_iter_fini_udp,
+	.seq_priv_size		= sizeof(struct udp_iter_state),
+};
+
+static struct bpf_iter_reg udp_reg_info = {
+	.target			= "udp",
+	.ctx_arg_info_size	= 1,
+	.ctx_arg_info		= {
+		{ offsetof(struct bpf_iter__udp, udp_sk),
+		  PTR_TO_BTF_ID_OR_NULL },
+	},
+	.seq_info		= &udp_seq_info,
+};
+
+static void __init bpf_iter_register(void)
+{
+	udp_reg_info.ctx_arg_info[0].btf_id = btf_sock_ids[BTF_SOCK_TYPE_UDP];
+	if (bpf_iter_reg_target(&udp_reg_info))
+		pr_warn("Warning: could not register bpf iterator udp\n");
+}
+#endif
+
 void __init udp_init(void)
 {
 	unsigned long limit;
@@ -2999,4 +3298,8 @@
 
 	if (register_pernet_subsys(&udp_sysctl_ops))
 		panic("UDP: failed to init sysctl parameters.\n");
+
+#if defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS)
+	bpf_iter_register();
+#endif
 }

--
Gitblit v1.6.2