From 9999e48639b3cecb08ffb37358bcba3b48161b29 Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Fri, 10 May 2024 08:50:17 +0000
Subject: [PATCH] netlink: backport mainline af_netlink.c updates

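This backports a number of mainline changes to the netlink core
(kernel/net/netlink/af_netlink.c):

- use an SPDX license identifier and drop the GPL boilerplate text
- make struct listeners use a flexible array member and bound-check
  netlink_group_mask()
- annotate lockless accesses to portid, dst_portid, dst_group,
  sk_state and cb_running with READ_ONCE()/WRITE_ONCE()
- convert netlink_setsockopt() to sockptr_t and fold the getsockopt
  flag cases into a single path
- add the NETLINK_GET_STRICT_CHK socket option, export
  netlink_strict_get_check() and set cb->strict_check for dumps
- report extended ACK data on NLMSG_DONE via netlink_dump_done() and
  add NLMSGERR_ATTR_POLICY to error replies
- reset the network/mac headers in netlink_dump() and take
  nl_table_lock with IRQs disabled in netlink_set_err()
- register a BPF iterator target for netlink sockets
- drop the CONFIG_ROCKCHIP_THUNDER_BOOT initcall special case

As a usage illustration only (not part of this patch; the choice of
NETLINK_ROUTE and the error handling are arbitrary), userspace would
opt in to the new strict dump checking roughly like this:

	#include <stdio.h>
	#include <sys/socket.h>
	#include <linux/netlink.h>

	int main(void)
	{
		int fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
		int one = 1;

		if (fd < 0) {
			perror("socket");
			return 1;
		}

		/* Request strict validation of dump request headers and
		 * attributes on this socket; kernels without
		 * NETLINK_GET_STRICT_CHK support return ENOPROTOOPT here.
		 */
		if (setsockopt(fd, SOL_NETLINK, NETLINK_GET_STRICT_CHK,
			       &one, sizeof(one)) < 0)
			perror("setsockopt(NETLINK_GET_STRICT_CHK)");

		return 0;
	}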
---
 kernel/net/netlink/af_netlink.c |  429 +++++++++++++++++++++++++++++++++--------------------
 1 file changed, 267 insertions(+), 162 deletions(-)

diff --git a/kernel/net/netlink/af_netlink.c b/kernel/net/netlink/af_netlink.c
index 21ec067..9737c32 100644
--- a/kernel/net/netlink/af_netlink.c
+++ b/kernel/net/netlink/af_netlink.c
@@ -1,14 +1,10 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
 /*
  * NETLINK      Kernel-user communication protocol.
  *
  * 		Authors:	Alan Cox <alan@lxorguk.ukuu.org.uk>
  * 				Alexey Kuznetsov <kuznet@ms2.inr.ac.ru>
  * 				Patrick McHardy <kaber@trash.net>
- *
- *		This program is free software; you can redistribute it and/or
- *		modify it under the terms of the GNU General Public License
- *		as published by the Free Software Foundation; either version
- *		2 of the License, or (at your option) any later version.
  *
  * Tue Jun 26 14:36:48 MEST 2001 Herbert "herp" Rosmanith
  *                               added netlink_proto_exit
@@ -64,6 +60,7 @@
 #include <linux/genetlink.h>
 #include <linux/net_namespace.h>
 #include <linux/nospec.h>
+#include <linux/btf_ids.h>
 
 #include <net/net_namespace.h>
 #include <net/netns/generic.h>
@@ -75,7 +72,7 @@
 
 struct listeners {
 	struct rcu_head		rcu;
-	unsigned long		masks[0];
+	unsigned long		masks[];
 };
 
 /* state bits */
@@ -152,6 +149,8 @@
 
 static inline u32 netlink_group_mask(u32 group)
 {
+	if (group > 32)
+		return 0;
 	return group ? 1 << (group - 1) : 0;
 }
 
@@ -245,13 +244,8 @@
 	return 0;
 }
 
-static void __net_exit netlink_tap_exit_net(struct net *net)
-{
-}
-
 static struct pernet_operations netlink_tap_net_ops = {
 	.init = netlink_tap_init_net,
-	.exit = netlink_tap_exit_net,
 	.id   = &netlink_tap_net_id,
 	.size = sizeof(struct netlink_tap_net),
 };
@@ -361,7 +355,7 @@
 {
 	struct netlink_sock *nlk = nlk_sk(sk);
 
-	if (skb_queue_empty(&sk->sk_receive_queue))
+	if (skb_queue_empty_lockless(&sk->sk_receive_queue))
 		clear_bit(NETLINK_S_CONGESTED, &nlk->state);
 	if (!test_bit(NETLINK_S_CONGESTED, &nlk->state))
 		wake_up_interruptible(&nlk->wait);
@@ -576,12 +570,9 @@
 	if (nlk_sk(sk)->bound)
 		goto err;
 
-	err = -ENOMEM;
-	if (BITS_PER_LONG > 32 &&
-	    unlikely(atomic_read(&table->hash.nelems) >= UINT_MAX))
-		goto err;
+	/* portid can be read locklessly from netlink_getname(). */
+	WRITE_ONCE(nlk_sk(sk)->portid, portid);
 
-	nlk_sk(sk)->portid = portid;
 	sock_hold(sk);
 
 	err = __netlink_insert(table, sk);
@@ -866,7 +857,7 @@
  *
  * Test to see if the opener of the socket we received the message
  * from had when the netlink socket was created and the sender of the
- * message has has the capability @cap in the user namespace @user_ns.
+ * message has the capability @cap in the user namespace @user_ns.
  */
 bool __netlink_ns_capable(const struct netlink_skb_parms *nsp,
 			struct user_namespace *user_ns, int cap)
@@ -885,7 +876,7 @@
  *
  * Test to see if the opener of the socket we received the message
  * from had when the netlink socket was created and the sender of the
- * message has has the capability @cap in the user namespace @user_ns.
+ * message has the capability @cap in the user namespace @user_ns.
  */
 bool netlink_ns_capable(const struct sk_buff *skb,
 			struct user_namespace *user_ns, int cap)
@@ -901,7 +892,7 @@
  *
  * Test to see if the opener of the socket we received the message
  * from had when the netlink socket was created and the sender of the
- * message has has the capability @cap in all user namespaces.
+ * message has the capability @cap in all user namespaces.
  */
 bool netlink_capable(const struct sk_buff *skb, int cap)
 {
@@ -916,7 +907,7 @@
  *
  * Test to see if the opener of the socket we received the message
  * from had when the netlink socket was created and the sender of the
- * message has has the capability @cap over the network namespace of
+ * message has the capability @cap over the network namespace of
  * the socket we received the message from.
  */
 bool netlink_net_capable(const struct sk_buff *skb, int cap)
@@ -998,7 +989,7 @@
 	struct netlink_sock *nlk = nlk_sk(sk);
 	struct sockaddr_nl *nladdr = (struct sockaddr_nl *)addr;
 	int err = 0;
-	long unsigned int groups = nladdr->nl_groups;
+	unsigned long groups;
 	bool bound;
 
 	if (addr_len < sizeof(struct sockaddr_nl))
@@ -1006,6 +997,7 @@
 
 	if (nladdr->nl_family != AF_NETLINK)
 		return -EINVAL;
+	groups = nladdr->nl_groups;
 
 	/* Only superuser is allowed to listen multicasts */
 	if (groups) {
@@ -1016,9 +1008,7 @@
 			return err;
 	}
 
-	if (nlk->ngroups == 0)
-		groups = 0;
-	else if (nlk->ngroups < 8*sizeof(groups))
+	if (nlk->ngroups < BITS_PER_LONG)
 		groups &= (1UL << nlk->ngroups) - 1;
 
 	/* Paired with WRITE_ONCE() in netlink_insert() */
@@ -1091,9 +1081,11 @@
 		return -EINVAL;
 
 	if (addr->sa_family == AF_UNSPEC) {
-		sk->sk_state	= NETLINK_UNCONNECTED;
-		nlk->dst_portid	= 0;
-		nlk->dst_group  = 0;
+		/* paired with READ_ONCE() in netlink_getsockbyportid() */
+		WRITE_ONCE(sk->sk_state, NETLINK_UNCONNECTED);
+		/* dst_portid and dst_group can be read locklessly */
+		WRITE_ONCE(nlk->dst_portid, 0);
+		WRITE_ONCE(nlk->dst_group, 0);
 		return 0;
 	}
 	if (addr->sa_family != AF_NETLINK)
@@ -1114,9 +1106,11 @@
 		err = netlink_autobind(sock);
 
 	if (err == 0) {
-		sk->sk_state	= NETLINK_CONNECTED;
-		nlk->dst_portid = nladdr->nl_pid;
-		nlk->dst_group  = ffs(nladdr->nl_groups);
+		/* paired with READ_ONCE() in netlink_getsockbyportid() */
+		WRITE_ONCE(sk->sk_state, NETLINK_CONNECTED);
+		/* dst_portid and dst_group can be read locklessly */
+		WRITE_ONCE(nlk->dst_portid, nladdr->nl_pid);
+		WRITE_ONCE(nlk->dst_group, ffs(nladdr->nl_groups));
 	}
 
 	return err;
@@ -1133,10 +1127,12 @@
 	nladdr->nl_pad = 0;
 
 	if (peer) {
-		nladdr->nl_pid = nlk->dst_portid;
-		nladdr->nl_groups = netlink_group_mask(nlk->dst_group);
+		/* Paired with WRITE_ONCE() in netlink_connect() */
+		nladdr->nl_pid = READ_ONCE(nlk->dst_portid);
+		nladdr->nl_groups = netlink_group_mask(READ_ONCE(nlk->dst_group));
 	} else {
-		nladdr->nl_pid = nlk->portid;
+		/* Paired with WRITE_ONCE() in netlink_insert() */
+		nladdr->nl_pid = READ_ONCE(nlk->portid);
 		netlink_lock_table();
 		nladdr->nl_groups = nlk->groups ? nlk->groups[0] : 0;
 		netlink_unlock_table();
@@ -1163,8 +1159,9 @@
 
 	/* Don't bother queuing skb if kernel socket has no input function */
 	nlk = nlk_sk(sock);
-	if (sock->sk_state == NETLINK_CONNECTED &&
-	    nlk->dst_portid != nlk_sk(ssk)->portid) {
+	/* dst_portid and sk_state can be changed in netlink_connect() */
+	if (READ_ONCE(sock->sk_state) == NETLINK_CONNECTED &&
+	    READ_ONCE(nlk->dst_portid) != nlk_sk(ssk)->portid) {
 		sock_put(sock);
 		return ERR_PTR(-ECONNREFUSED);
 	}
@@ -1386,6 +1383,14 @@
 }
 EXPORT_SYMBOL_GPL(netlink_has_listeners);
 
+bool netlink_strict_get_check(struct sk_buff *skb)
+{
+	const struct netlink_sock *nlk = nlk_sk(NETLINK_CB(skb).sk);
+
+	return nlk->flags & NETLINK_F_STRICT_CHK;
+}
+EXPORT_SYMBOL_GPL(netlink_strict_get_check);
+
 static int netlink_broadcast_deliver(struct sock *sk, struct sk_buff *skb)
 {
 	struct netlink_sock *nlk = nlk_sk(sk);
@@ -1597,6 +1602,7 @@
 int netlink_set_err(struct sock *ssk, u32 portid, u32 group, int code)
 {
 	struct netlink_set_err_data info;
+	unsigned long flags;
 	struct sock *sk;
 	int ret = 0;
 
@@ -1606,12 +1612,12 @@
 	/* sk->sk_err wants a positive error value */
 	info.code = -code;
 
-	read_lock(&nl_table_lock);
+	read_lock_irqsave(&nl_table_lock, flags);
 
 	sk_for_each_bound(sk, &nl_table[ssk->sk_protocol].mc_list)
 		ret += do_one_set_err(sk, &info);
 
-	read_unlock(&nl_table_lock);
+	read_unlock_irqrestore(&nl_table_lock, flags);
 	return ret;
 }
 EXPORT_SYMBOL(netlink_set_err);
@@ -1634,7 +1640,7 @@
 }
 
 static int netlink_setsockopt(struct socket *sock, int level, int optname,
-			      char __user *optval, unsigned int optlen)
+			      sockptr_t optval, unsigned int optlen)
 {
 	struct sock *sk = sock->sk;
 	struct netlink_sock *nlk = nlk_sk(sk);
@@ -1645,7 +1651,7 @@
 		return -ENOPROTOOPT;
 
 	if (optlen >= sizeof(int) &&
-	    get_user(val, (unsigned int __user *)optval))
+	    copy_from_sockptr(&val, optval, sizeof(val)))
 		return -EFAULT;
 
 	switch (optname) {
@@ -1721,6 +1727,13 @@
 			nlk->flags &= ~NETLINK_F_EXT_ACK;
 		err = 0;
 		break;
+	case NETLINK_GET_STRICT_CHK:
+		if (val)
+			nlk->flags |= NETLINK_F_STRICT_CHK;
+		else
+			nlk->flags &= ~NETLINK_F_STRICT_CHK;
+		err = 0;
+		break;
 	default:
 		err = -ENOPROTOOPT;
 	}
@@ -1732,7 +1745,8 @@
 {
 	struct sock *sk = sock->sk;
 	struct netlink_sock *nlk = nlk_sk(sk);
-	int len, val, err;
+	unsigned int flag;
+	int len, val;
 
 	if (level != SOL_NETLINK)
 		return -ENOPROTOOPT;
@@ -1744,39 +1758,17 @@
 
 	switch (optname) {
 	case NETLINK_PKTINFO:
-		if (len < sizeof(int))
-			return -EINVAL;
-		len = sizeof(int);
-		val = nlk->flags & NETLINK_F_RECV_PKTINFO ? 1 : 0;
-		if (put_user(len, optlen) ||
-		    put_user(val, optval))
-			return -EFAULT;
-		err = 0;
+		flag = NETLINK_F_RECV_PKTINFO;
 		break;
 	case NETLINK_BROADCAST_ERROR:
-		if (len < sizeof(int))
-			return -EINVAL;
-		len = sizeof(int);
-		val = nlk->flags & NETLINK_F_BROADCAST_SEND_ERROR ? 1 : 0;
-		if (put_user(len, optlen) ||
-		    put_user(val, optval))
-			return -EFAULT;
-		err = 0;
+		flag = NETLINK_F_BROADCAST_SEND_ERROR;
 		break;
 	case NETLINK_NO_ENOBUFS:
-		if (len < sizeof(int))
-			return -EINVAL;
-		len = sizeof(int);
-		val = nlk->flags & NETLINK_F_RECV_NO_ENOBUFS ? 1 : 0;
-		if (put_user(len, optlen) ||
-		    put_user(val, optval))
-			return -EFAULT;
-		err = 0;
+		flag = NETLINK_F_RECV_NO_ENOBUFS;
 		break;
 	case NETLINK_LIST_MEMBERSHIPS: {
-		int pos, idx, shift;
+		int pos, idx, shift, err = 0;
 
-		err = 0;
 		netlink_lock_table();
 		for (pos = 0; pos * 8 < nlk->ngroups; pos += sizeof(u32)) {
 			if (len - pos < sizeof(u32))
@@ -1790,34 +1782,35 @@
 				break;
 			}
 		}
-		if (put_user(ALIGN(nlk->ngroups / 8, sizeof(u32)), optlen))
+		if (put_user(ALIGN(BITS_TO_BYTES(nlk->ngroups), sizeof(u32)), optlen))
 			err = -EFAULT;
 		netlink_unlock_table();
-		break;
+		return err;
 	}
 	case NETLINK_CAP_ACK:
-		if (len < sizeof(int))
-			return -EINVAL;
-		len = sizeof(int);
-		val = nlk->flags & NETLINK_F_CAP_ACK ? 1 : 0;
-		if (put_user(len, optlen) ||
-		    put_user(val, optval))
-			return -EFAULT;
-		err = 0;
+		flag = NETLINK_F_CAP_ACK;
 		break;
 	case NETLINK_EXT_ACK:
-		if (len < sizeof(int))
-			return -EINVAL;
-		len = sizeof(int);
-		val = nlk->flags & NETLINK_F_EXT_ACK ? 1 : 0;
-		if (put_user(len, optlen) || put_user(val, optval))
-			return -EFAULT;
-		err = 0;
+		flag = NETLINK_F_EXT_ACK;
+		break;
+	case NETLINK_GET_STRICT_CHK:
+		flag = NETLINK_F_STRICT_CHK;
 		break;
 	default:
-		err = -ENOPROTOOPT;
+		return -ENOPROTOOPT;
 	}
-	return err;
+
+	if (len < sizeof(int))
+		return -EINVAL;
+
+	len = sizeof(int);
+	val = nlk->flags & flag ? 1 : 0;
+
+	if (put_user(len, optlen) ||
+	    copy_to_user(optval, &val, len))
+		return -EFAULT;
+
+	return 0;
 }
 
 static void netlink_cmsg_recv_pktinfo(struct msghdr *msg, struct sk_buff *skb)
@@ -1850,7 +1843,7 @@
 	struct scm_cookie scm;
 	u32 netlink_skb_flags = 0;
 
-	if (msg->msg_flags&MSG_OOB)
+	if (msg->msg_flags & MSG_OOB)
 		return -EOPNOTSUPP;
 
 	if (len == 0) {
@@ -1876,8 +1869,9 @@
 			goto out;
 		netlink_skb_flags |= NETLINK_SKB_DST;
 	} else {
-		dst_portid = nlk->dst_portid;
-		dst_group = nlk->dst_group;
+		/* Paired with WRITE_ONCE() in netlink_connect() */
+		dst_portid = READ_ONCE(nlk->dst_portid);
+		dst_group = READ_ONCE(nlk->dst_group);
 	}
 
 	/* Paired with WRITE_ONCE() in netlink_insert() */
@@ -1919,7 +1913,7 @@
 		refcount_inc(&skb->users);
 		netlink_broadcast(sk, skb, dst_portid, dst_group, GFP_KERNEL);
 	}
-	err = netlink_unicast(sk, skb, dst_portid, msg->msg_flags&MSG_DONTWAIT);
+	err = netlink_unicast(sk, skb, dst_portid, msg->msg_flags & MSG_DONTWAIT);
 
 out:
 	scm_destroy(&scm);
@@ -1932,12 +1926,12 @@
 	struct scm_cookie scm;
 	struct sock *sk = sock->sk;
 	struct netlink_sock *nlk = nlk_sk(sk);
-	int noblock = flags&MSG_DONTWAIT;
+	int noblock = flags & MSG_DONTWAIT;
 	size_t copied;
 	struct sk_buff *skb, *data_skb;
 	int err, ret;
 
-	if (flags&MSG_OOB)
+	if (flags & MSG_OOB)
 		return -EOPNOTSUPP;
 
 	copied = 0;
@@ -1976,7 +1970,6 @@
 		copied = len;
 	}
 
-	skb_reset_transport_header(data_skb);
 	err = skb_copy_datagram_msg(data_skb, 0, msg, copied);
 
 	if (msg->msg_name) {
@@ -2000,7 +1993,7 @@
 
 	skb_free_datagram(sk, skb);
 
-	if (nlk->cb_running &&
+	if (READ_ONCE(nlk->cb_running) &&
 	    atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf / 2) {
 		ret = netlink_dump(sk);
 		if (ret) {
@@ -2189,12 +2182,35 @@
  * It would be better to create kernel thread.
  */
 
+static int netlink_dump_done(struct netlink_sock *nlk, struct sk_buff *skb,
+			     struct netlink_callback *cb,
+			     struct netlink_ext_ack *extack)
+{
+	struct nlmsghdr *nlh;
+
+	nlh = nlmsg_put_answer(skb, cb, NLMSG_DONE, sizeof(nlk->dump_done_errno),
+			       NLM_F_MULTI | cb->answer_flags);
+	if (WARN_ON(!nlh))
+		return -ENOBUFS;
+
+	nl_dump_check_consistent(cb, nlh);
+	memcpy(nlmsg_data(nlh), &nlk->dump_done_errno, sizeof(nlk->dump_done_errno));
+
+	if (extack->_msg && nlk->flags & NETLINK_F_EXT_ACK) {
+		nlh->nlmsg_flags |= NLM_F_ACK_TLVS;
+		if (!nla_put_string(skb, NLMSGERR_ATTR_MSG, extack->_msg))
+			nlmsg_end(skb, nlh);
+	}
+
+	return 0;
+}
+
 static int netlink_dump(struct sock *sk)
 {
 	struct netlink_sock *nlk = nlk_sk(sk);
+	struct netlink_ext_ack extack = {};
 	struct netlink_callback *cb;
 	struct sk_buff *skb = NULL;
-	struct nlmsghdr *nlh;
 	struct module *module;
 	int err = -ENOBUFS;
 	int alloc_min_size;
@@ -2241,10 +2257,20 @@
 	 * single netdev. The outcome is MSG_TRUNC error.
 	 */
 	skb_reserve(skb, skb_tailroom(skb) - alloc_size);
+
+	/* Make sure malicious BPF programs can not read uninitialized memory
+	 * from skb->head -> skb->data
+	 */
+	skb_reset_network_header(skb);
+	skb_reset_mac_header(skb);
+
 	netlink_skb_set_owner_r(skb, sk);
 
-	if (nlk->dump_done_errno > 0)
+	if (nlk->dump_done_errno > 0) {
+		cb->extack = &extack;
 		nlk->dump_done_errno = cb->dump(skb, cb);
+		cb->extack = NULL;
+	}
 
 	if (nlk->dump_done_errno > 0 ||
 	    skb_tailroom(skb) < nlmsg_total_size(sizeof(nlk->dump_done_errno))) {
@@ -2257,15 +2283,19 @@
 		return 0;
 	}
 
-	nlh = nlmsg_put_answer(skb, cb, NLMSG_DONE,
-			       sizeof(nlk->dump_done_errno), NLM_F_MULTI);
-	if (WARN_ON(!nlh))
+	if (netlink_dump_done(nlk, skb, cb, &extack))
 		goto errout_skb;
 
-	nl_dump_check_consistent(cb, nlh);
-
-	memcpy(nlmsg_data(nlh), &nlk->dump_done_errno,
-	       sizeof(nlk->dump_done_errno));
+#ifdef CONFIG_COMPAT_NETLINK_MESSAGES
+	/* frag_list skb's data is used for compat tasks
+	 * and the regular skb's data for normal (non-compat) tasks.
+	 * See netlink_recvmsg().
+	 */
+	if (unlikely(skb_shinfo(skb)->frag_list)) {
+		if (netlink_dump_done(nlk, skb_shinfo(skb)->frag_list, cb, &extack))
+			goto errout_skb;
+	}
+#endif
 
 	if (sk_filter(sk, skb))
 		kfree_skb(skb);
@@ -2275,7 +2305,7 @@
 	if (cb->done)
 		cb->done(cb);
 
-	nlk->cb_running = false;
+	WRITE_ONCE(nlk->cb_running, false);
 	module = cb->module;
 	skb = cb->skb;
 	mutex_unlock(nlk->cb_mutex);
@@ -2293,9 +2323,9 @@
 			 const struct nlmsghdr *nlh,
 			 struct netlink_dump_control *control)
 {
+	struct netlink_sock *nlk, *nlk2;
 	struct netlink_callback *cb;
 	struct sock *sk;
-	struct netlink_sock *nlk;
 	int ret;
 
 	refcount_inc(&skb->users);
@@ -2329,13 +2359,16 @@
 	cb->min_dump_alloc = control->min_dump_alloc;
 	cb->skb = skb;
 
+	nlk2 = nlk_sk(NETLINK_CB(skb).sk);
+	cb->strict_check = !!(nlk2->flags & NETLINK_F_STRICT_CHK);
+
 	if (control->start) {
 		ret = control->start(cb);
 		if (ret)
 			goto error_put;
 	}
 
-	nlk->cb_running = true;
+	WRITE_ONCE(nlk->cb_running, true);
 	nlk->dump_done_errno = INT_MAX;
 
 	mutex_unlock(nlk->cb_mutex);
@@ -2382,19 +2415,16 @@
 	if (nlk_has_extack && extack && extack->_msg)
 		tlvlen += nla_total_size(strlen(extack->_msg) + 1);
 
-	if (err) {
-		if (!(nlk->flags & NETLINK_F_CAP_ACK))
-			payload += nlmsg_len(nlh);
-		else
-			flags |= NLM_F_CAPPED;
-		if (nlk_has_extack && extack && extack->bad_attr)
-			tlvlen += nla_total_size(sizeof(u32));
-	} else {
+	if (err && !(nlk->flags & NETLINK_F_CAP_ACK))
+		payload += nlmsg_len(nlh);
+	else
 		flags |= NLM_F_CAPPED;
-
-		if (nlk_has_extack && extack && extack->cookie_len)
-			tlvlen += nla_total_size(extack->cookie_len);
-	}
+	if (err && nlk_has_extack && extack && extack->bad_attr)
+		tlvlen += nla_total_size(sizeof(u32));
+	if (nlk_has_extack && extack && extack->cookie_len)
+		tlvlen += nla_total_size(extack->cookie_len);
+	if (err && nlk_has_extack && extack && extack->policy)
+		tlvlen += netlink_policy_dump_attr_size_estimate(extack->policy);
 
 	if (tlvlen)
 		flags |= NLM_F_ACK_TLVS;
@@ -2417,20 +2447,19 @@
 			WARN_ON(nla_put_string(skb, NLMSGERR_ATTR_MSG,
 					       extack->_msg));
 		}
-		if (err) {
-			if (extack->bad_attr &&
-			    !WARN_ON((u8 *)extack->bad_attr < in_skb->data ||
-				     (u8 *)extack->bad_attr >= in_skb->data +
-							       in_skb->len))
-				WARN_ON(nla_put_u32(skb, NLMSGERR_ATTR_OFFS,
-						    (u8 *)extack->bad_attr -
-						    (u8 *)nlh));
-		} else {
-			if (extack->cookie_len)
-				WARN_ON(nla_put(skb, NLMSGERR_ATTR_COOKIE,
-						extack->cookie_len,
-						extack->cookie));
-		}
+		if (err && extack->bad_attr &&
+		    !WARN_ON((u8 *)extack->bad_attr < in_skb->data ||
+			     (u8 *)extack->bad_attr >= in_skb->data +
+						       in_skb->len))
+			WARN_ON(nla_put_u32(skb, NLMSGERR_ATTR_OFFS,
+					    (u8 *)extack->bad_attr -
+					    (u8 *)nlh));
+		if (extack->cookie_len)
+			WARN_ON(nla_put(skb, NLMSGERR_ATTR_COOKIE,
+					extack->cookie_len, extack->cookie));
+		if (extack->policy)
+			netlink_policy_dump_write_attr(skb, extack->policy,
+						       NLMSGERR_ATTR_POLICY);
 	}
 
 	nlmsg_end(skb, rep);
@@ -2532,20 +2561,10 @@
 	int link;
 };
 
-static int netlink_walk_start(struct nl_seq_iter *iter)
+static void netlink_walk_start(struct nl_seq_iter *iter)
 {
-	int err;
-
-	err = rhashtable_walk_init(&nl_table[iter->link].hash, &iter->hti,
-				   GFP_KERNEL);
-	if (err) {
-		iter->link = MAX_LINKS;
-		return err;
-	}
-
+	rhashtable_walk_enter(&nl_table[iter->link].hash, &iter->hti);
 	rhashtable_walk_start(&iter->hti);
-
-	return 0;
 }
 
 static void netlink_walk_stop(struct nl_seq_iter *iter)
@@ -2561,8 +2580,6 @@
 
 	do {
 		for (;;) {
-			int err;
-
 			nlk = rhashtable_walk_next(&iter->hti);
 
 			if (IS_ERR(nlk)) {
@@ -2579,9 +2596,7 @@
 			if (++iter->link >= MAX_LINKS)
 				return NULL;
 
-			err = netlink_walk_start(iter);
-			if (err)
-				return ERR_PTR(err);
+			netlink_walk_start(iter);
 		}
 	} while (sock_net(&nlk->sk) != seq_file_net(seq));
 
@@ -2589,17 +2604,15 @@
 }
 
 static void *netlink_seq_start(struct seq_file *seq, loff_t *posp)
+	__acquires(RCU)
 {
 	struct nl_seq_iter *iter = seq->private;
 	void *obj = SEQ_START_TOKEN;
 	loff_t pos;
-	int err;
 
 	iter->link = 0;
 
-	err = netlink_walk_start(iter);
-	if (err)
-		return ERR_PTR(err);
+	netlink_walk_start(iter);
 
 	for (pos = *posp; pos && obj && !IS_ERR(obj); pos--)
 		obj = __netlink_seq_next(seq);
@@ -2613,7 +2626,7 @@
 	return __netlink_seq_next(seq);
 }
 
-static void netlink_seq_stop(struct seq_file *seq, void *v)
+static void netlink_native_seq_stop(struct seq_file *seq, void *v)
 {
 	struct nl_seq_iter *iter = seq->private;
 
@@ -2624,7 +2637,7 @@
 }
 
 
-static int netlink_seq_show(struct seq_file *seq, void *v)
+static int netlink_native_seq_show(struct seq_file *seq, void *v)
 {
 	if (v == SEQ_START_TOKEN) {
 		seq_puts(seq,
@@ -2634,14 +2647,14 @@
 		struct sock *s = v;
 		struct netlink_sock *nlk = nlk_sk(s);
 
-		seq_printf(seq, "%pK %-3d %-10u %08x %-8d %-8d %-5d %-8d %-8d %-8lu\n",
+		seq_printf(seq, "%pK %-3d %-10u %08x %-8d %-8d %-5d %-8d %-8u %-8lu\n",
 			   s,
 			   s->sk_protocol,
 			   nlk->portid,
 			   nlk->groups ? (u32)nlk->groups[0] : 0,
 			   sk_rmem_alloc_get(s),
 			   sk_wmem_alloc_get(s),
-			   nlk->cb_running,
+			   READ_ONCE(nlk->cb_running),
 			   refcount_read(&s->sk_refcnt),
 			   atomic_read(&s->sk_drops),
 			   sock_i_ino(s)
@@ -2650,6 +2663,68 @@
 	}
 	return 0;
 }
+
+#ifdef CONFIG_BPF_SYSCALL
+struct bpf_iter__netlink {
+	__bpf_md_ptr(struct bpf_iter_meta *, meta);
+	__bpf_md_ptr(struct netlink_sock *, sk);
+};
+
+DEFINE_BPF_ITER_FUNC(netlink, struct bpf_iter_meta *meta, struct netlink_sock *sk)
+
+static int netlink_prog_seq_show(struct bpf_prog *prog,
+				  struct bpf_iter_meta *meta,
+				  void *v)
+{
+	struct bpf_iter__netlink ctx;
+
+	meta->seq_num--;  /* skip SEQ_START_TOKEN */
+	ctx.meta = meta;
+	ctx.sk = nlk_sk((struct sock *)v);
+	return bpf_iter_run_prog(prog, &ctx);
+}
+
+static int netlink_seq_show(struct seq_file *seq, void *v)
+{
+	struct bpf_iter_meta meta;
+	struct bpf_prog *prog;
+
+	meta.seq = seq;
+	prog = bpf_iter_get_info(&meta, false);
+	if (!prog)
+		return netlink_native_seq_show(seq, v);
+
+	if (v != SEQ_START_TOKEN)
+		return netlink_prog_seq_show(prog, &meta, v);
+
+	return 0;
+}
+
+static void netlink_seq_stop(struct seq_file *seq, void *v)
+{
+	struct bpf_iter_meta meta;
+	struct bpf_prog *prog;
+
+	if (!v) {
+		meta.seq = seq;
+		prog = bpf_iter_get_info(&meta, true);
+		if (prog)
+			(void)netlink_prog_seq_show(prog, &meta, v);
+	}
+
+	netlink_native_seq_stop(seq, v);
+}
+#else
+static int netlink_seq_show(struct seq_file *seq, void *v)
+{
+	return netlink_native_seq_show(seq, v);
+}
+
+static void netlink_seq_stop(struct seq_file *seq, void *v)
+{
+	netlink_native_seq_stop(seq, v);
+}
+#endif
 
 static const struct seq_operations netlink_seq_ops = {
 	.start  = netlink_seq_start,
@@ -2757,6 +2832,34 @@
 	.automatic_shrinking = true,
 };
 
+#if defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS)
+BTF_ID_LIST(btf_netlink_sock_id)
+BTF_ID(struct, netlink_sock)
+
+static const struct bpf_iter_seq_info netlink_seq_info = {
+	.seq_ops		= &netlink_seq_ops,
+	.init_seq_private	= bpf_iter_init_seq_net,
+	.fini_seq_private	= bpf_iter_fini_seq_net,
+	.seq_priv_size		= sizeof(struct nl_seq_iter),
+};
+
+static struct bpf_iter_reg netlink_reg_info = {
+	.target			= "netlink",
+	.ctx_arg_info_size	= 1,
+	.ctx_arg_info		= {
+		{ offsetof(struct bpf_iter__netlink, sk),
+		  PTR_TO_BTF_ID_OR_NULL },
+	},
+	.seq_info		= &netlink_seq_info,
+};
+
+static int __init bpf_iter_register(void)
+{
+	netlink_reg_info.ctx_arg_info[0].btf_id = *btf_netlink_sock_id;
+	return bpf_iter_reg_target(&netlink_reg_info);
+}
+#endif
+
 static int __init netlink_proto_init(void)
 {
 	int i;
@@ -2765,7 +2868,13 @@
 	if (err != 0)
 		goto out;
 
-	BUILD_BUG_ON(sizeof(struct netlink_skb_parms) > FIELD_SIZEOF(struct sk_buff, cb));
+#if defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS)
+	err = bpf_iter_register();
+	if (err)
+		goto out;
+#endif
+
+	BUILD_BUG_ON(sizeof(struct netlink_skb_parms) > sizeof_field(struct sk_buff, cb));
 
 	nl_table = kcalloc(MAX_LINKS, sizeof(*nl_table), GFP_KERNEL);
 	if (!nl_table)
@@ -2794,8 +2903,4 @@
 	panic("netlink_init: Cannot allocate nl_table\n");
 }
 
-#ifdef CONFIG_ROCKCHIP_THUNDER_BOOT
-core_initcall_sync(netlink_proto_init);
-#else
 core_initcall(netlink_proto_init);
-#endif

--
Gitblit v1.6.2