From 8d2a02b24d66aa359e83eebc1ed3c0f85367a1cb Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Thu, 16 May 2024 03:11:33 +0000
Subject: [PATCH] AX88772C_eeprom and ax8872c build together

---
 kernel/net/ipv4/inet_diag.c |  449 +++++++++++++++++++++++++++++++++++++------------------
 1 files changed, 301 insertions(+), 148 deletions(-)

diff --git a/kernel/net/ipv4/inet_diag.c b/kernel/net/ipv4/inet_diag.c
index 069d96f..fa9f1de 100644
--- a/kernel/net/ipv4/inet_diag.c
+++ b/kernel/net/ipv4/inet_diag.c
@@ -1,12 +1,8 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
 /*
  * inet_diag.c	Module for monitoring INET transport protocols sockets.
  *
  * Authors:	Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
- *
- *	This program is free software; you can redistribute it and/or
- *      modify it under the terms of the GNU General Public License
- *      as published by the Free Software Foundation; either version
- *      2 of the License, or (at your option) any later version.
  */
 
 #include <linux/kernel.h>
@@ -27,6 +23,7 @@
 #include <net/inet_hashtables.h>
 #include <net/inet_timewait_sock.h>
 #include <net/inet6_hashtables.h>
+#include <net/bpf_sk_storage.h>
 #include <net/netlink.h>
 
 #include <linux/inet.h>
@@ -46,12 +43,20 @@
 	u16 userlocks;
 	u32 ifindex;
 	u32 mark;
+#ifdef CONFIG_SOCK_CGROUP_DATA
+	u64 cgroup_id;
+#endif
 };
 
 static DEFINE_MUTEX(inet_diag_table_mutex);
 
 static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
 {
+	if (proto < 0 || proto >= IPPROTO_MAX) {
+		mutex_lock(&inet_diag_table_mutex);
+		return ERR_PTR(-ENOENT);
+	}
+
 	if (!inet_diag_table[proto])
 		sock_load_diag_module(AF_INET, proto);
 
@@ -120,6 +125,7 @@
 			     bool net_admin)
 {
 	const struct inet_sock *inet = inet_sk(sk);
+	struct inet_diag_sockopt inet_sockopt;
 
 	if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
 		goto errout;
@@ -165,8 +171,31 @@
 			goto errout;
 	}
 
+#ifdef CONFIG_SOCK_CGROUP_DATA
+	if (nla_put_u64_64bit(skb, INET_DIAG_CGROUP_ID,
+			      cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)),
+			      INET_DIAG_PAD))
+		goto errout;
+#endif
+
 	r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
 	r->idiag_inode = sock_i_ino(sk);
+
+	memset(&inet_sockopt, 0, sizeof(inet_sockopt));
+	inet_sockopt.recverr	= inet->recverr;
+	inet_sockopt.is_icsk	= inet->is_icsk;
+	inet_sockopt.freebind	= inet->freebind;
+	inet_sockopt.hdrincl	= inet->hdrincl;
+	inet_sockopt.mc_loop	= inet->mc_loop;
+	inet_sockopt.transparent = inet->transparent;
+	inet_sockopt.mc_all	= inet->mc_all;
+	inet_sockopt.nodefrag	= inet->nodefrag;
+	inet_sockopt.bind_address_no_port = inet->bind_address_no_port;
+	inet_sockopt.recverr_rfc4884 = inet->recverr_rfc4884;
+	inet_sockopt.defer_connect = inet->defer_connect;
+	if (nla_put(skb, INET_DIAG_SOCKOPT, sizeof(inet_sockopt),
+		    &inet_sockopt))
+		goto errout;
 
 	return 0;
 errout:
@@ -174,26 +203,54 @@
 }
 EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);
 
+static int inet_diag_parse_attrs(const struct nlmsghdr *nlh, int hdrlen,
+				 struct nlattr **req_nlas)
+{
+	struct nlattr *nla;
+	int remaining;
+
+	nlmsg_for_each_attr(nla, nlh, hdrlen, remaining) {
+		int type = nla_type(nla);
+
+		if (type == INET_DIAG_REQ_PROTOCOL && nla_len(nla) != sizeof(u32))
+			return -EINVAL;
+
+		if (type < __INET_DIAG_REQ_MAX)
+			req_nlas[type] = nla;
+	}
+	return 0;
+}
+
+static int inet_diag_get_protocol(const struct inet_diag_req_v2 *req,
+				  const struct inet_diag_dump_data *data)
+{
+	if (data->req_nlas[INET_DIAG_REQ_PROTOCOL])
+		return nla_get_u32(data->req_nlas[INET_DIAG_REQ_PROTOCOL]);
+	return req->sdiag_protocol;
+}
+
+#define MAX_DUMP_ALLOC_SIZE (KMALLOC_MAX_SIZE - SKB_DATA_ALIGN(sizeof(struct skb_shared_info)))
+
 int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
-		      struct sk_buff *skb, const struct inet_diag_req_v2 *req,
-		      struct user_namespace *user_ns,
-		      u32 portid, u32 seq, u16 nlmsg_flags,
-		      const struct nlmsghdr *unlh,
-		      bool net_admin)
+		      struct sk_buff *skb, struct netlink_callback *cb,
+		      const struct inet_diag_req_v2 *req,
+		      u16 nlmsg_flags, bool net_admin)
 {
 	const struct tcp_congestion_ops *ca_ops;
 	const struct inet_diag_handler *handler;
+	struct inet_diag_dump_data *cb_data;
 	int ext = req->idiag_ext;
 	struct inet_diag_msg *r;
 	struct nlmsghdr  *nlh;
 	struct nlattr *attr;
 	void *info = NULL;
 
-	handler = inet_diag_table[req->sdiag_protocol];
+	cb_data = cb->data;
+	handler = inet_diag_table[inet_diag_get_protocol(req, cb_data)];
 	BUG_ON(!handler);
 
-	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
-			nlmsg_flags);
+	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
+			cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
 	if (!nlh)
 		return -EMSGSIZE;
 
@@ -204,14 +261,17 @@
 	r->idiag_state = sk->sk_state;
 	r->idiag_timer = 0;
 	r->idiag_retrans = 0;
+	r->idiag_expires = 0;
 
-	if (inet_diag_msg_attrs_fill(sk, skb, r, ext, user_ns, net_admin))
+	if (inet_diag_msg_attrs_fill(sk, skb, r, ext,
+				     sk_user_ns(NETLINK_CB(cb->skb).sk),
+				     net_admin))
 		goto errout;
 
 	if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
 		struct inet_diag_meminfo minfo = {
 			.idiag_rmem = sk_rmem_alloc_get(sk),
-			.idiag_wmem = sk->sk_wmem_queued,
+			.idiag_wmem = READ_ONCE(sk->sk_wmem_queued),
 			.idiag_fmem = sk->sk_forward_alloc,
 			.idiag_tmem = sk_wmem_alloc_get(sk),
 		};
@@ -244,20 +304,17 @@
 		r->idiag_timer = 1;
 		r->idiag_retrans = icsk->icsk_retransmits;
 		r->idiag_expires =
-			jiffies_to_msecs(icsk->icsk_timeout - jiffies);
+			jiffies_delta_to_msecs(icsk->icsk_timeout - jiffies);
 	} else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
 		r->idiag_timer = 4;
 		r->idiag_retrans = icsk->icsk_probes_out;
 		r->idiag_expires =
-			jiffies_to_msecs(icsk->icsk_timeout - jiffies);
+			jiffies_delta_to_msecs(icsk->icsk_timeout - jiffies);
 	} else if (timer_pending(&sk->sk_timer)) {
 		r->idiag_timer = 2;
 		r->idiag_retrans = icsk->icsk_probes_out;
 		r->idiag_expires =
-			jiffies_to_msecs(sk->sk_timer.expires - jiffies);
-	} else {
-		r->idiag_timer = 0;
-		r->idiag_expires = 0;
+			jiffies_delta_to_msecs(sk->sk_timer.expires - jiffies);
 	}
 
 	if ((ext & (1 << (INET_DIAG_INFO - 1))) && handler->idiag_info_size) {
@@ -302,6 +359,48 @@
 			goto errout;
 	}
 
+	/* Keep it at the end for potential retry with a larger skb,
+	 * or else do best-effort fitting, which is only done for the
+	 * first_nlmsg.
+	 */
+	if (cb_data->bpf_stg_diag) {
+		bool first_nlmsg = ((unsigned char *)nlh == skb->data);
+		unsigned int prev_min_dump_alloc;
+		unsigned int total_nla_size = 0;
+		unsigned int msg_len;
+		int err;
+
+		msg_len = skb_tail_pointer(skb) - (unsigned char *)nlh;
+		err = bpf_sk_storage_diag_put(cb_data->bpf_stg_diag, sk, skb,
+					      INET_DIAG_SK_BPF_STORAGES,
+					      &total_nla_size);
+
+		if (!err)
+			goto out;
+
+		total_nla_size += msg_len;
+		prev_min_dump_alloc = cb->min_dump_alloc;
+		if (total_nla_size > prev_min_dump_alloc)
+			cb->min_dump_alloc = min_t(u32, total_nla_size,
+						   MAX_DUMP_ALLOC_SIZE);
+
+		if (!first_nlmsg)
+			goto errout;
+
+		if (cb->min_dump_alloc > prev_min_dump_alloc)
+			/* Retry with pskb_expand_head() with
+			 * __GFP_DIRECT_RECLAIM
+			 */
+			goto errout;
+
+		WARN_ON_ONCE(total_nla_size <= prev_min_dump_alloc);
+
+		/* Send what we have for this sk
+		 * and move on to the next sk in the following
+		 * dump()
+		 */
+	}
+
 out:
 	nlmsg_end(skb, nlh);
 	return 0;
@@ -312,46 +411,32 @@
 }
 EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
 
-static int inet_csk_diag_fill(struct sock *sk,
-			      struct sk_buff *skb,
-			      const struct inet_diag_req_v2 *req,
-			      struct user_namespace *user_ns,
-			      u32 portid, u32 seq, u16 nlmsg_flags,
-			      const struct nlmsghdr *unlh,
-			      bool net_admin)
-{
-	return inet_sk_diag_fill(sk, inet_csk(sk), skb, req, user_ns,
-				 portid, seq, nlmsg_flags, unlh, net_admin);
-}
-
 static int inet_twsk_diag_fill(struct sock *sk,
 			       struct sk_buff *skb,
-			       u32 portid, u32 seq, u16 nlmsg_flags,
-			       const struct nlmsghdr *unlh)
+			       struct netlink_callback *cb,
+			       u16 nlmsg_flags)
 {
 	struct inet_timewait_sock *tw = inet_twsk(sk);
 	struct inet_diag_msg *r;
 	struct nlmsghdr *nlh;
 	long tmo;
 
-	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
-			nlmsg_flags);
+	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid,
+			cb->nlh->nlmsg_seq, cb->nlh->nlmsg_type,
+			sizeof(*r), nlmsg_flags);
 	if (!nlh)
 		return -EMSGSIZE;
 
 	r = nlmsg_data(nlh);
 	BUG_ON(tw->tw_state != TCP_TIME_WAIT);
 
-	tmo = tw->tw_timer.expires - jiffies;
-	if (tmo < 0)
-		tmo = 0;
-
 	inet_diag_msg_common_fill(r, sk);
 	r->idiag_retrans      = 0;
 
 	r->idiag_state	      = tw->tw_substate;
 	r->idiag_timer	      = 3;
-	r->idiag_expires      = jiffies_to_msecs(tmo);
+	tmo = tw->tw_timer.expires - jiffies;
+	r->idiag_expires      = jiffies_delta_to_msecs(tmo);
 	r->idiag_rqueue	      = 0;
 	r->idiag_wqueue	      = 0;
 	r->idiag_uid	      = 0;
@@ -362,16 +447,16 @@
 }
 
 static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
-			      u32 portid, u32 seq, u16 nlmsg_flags,
-			      const struct nlmsghdr *unlh, bool net_admin)
+			      struct netlink_callback *cb,
+			      u16 nlmsg_flags, bool net_admin)
 {
 	struct request_sock *reqsk = inet_reqsk(sk);
 	struct inet_diag_msg *r;
 	struct nlmsghdr *nlh;
 	long tmo;
 
-	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
-			nlmsg_flags);
+	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
+			cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
 	if (!nlh)
 		return -EMSGSIZE;
 
@@ -385,7 +470,7 @@
 		     offsetof(struct sock, sk_cookie));
 
 	tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies;
-	r->idiag_expires = (tmo >= 0) ? jiffies_to_msecs(tmo) : 0;
+	r->idiag_expires = jiffies_delta_to_msecs(tmo);
 	r->idiag_rqueue	= 0;
 	r->idiag_wqueue	= 0;
 	r->idiag_uid	= 0;
@@ -402,21 +487,18 @@
 }
 
 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
+			struct netlink_callback *cb,
 			const struct inet_diag_req_v2 *r,
-			struct user_namespace *user_ns,
-			u32 portid, u32 seq, u16 nlmsg_flags,
-			const struct nlmsghdr *unlh, bool net_admin)
+			u16 nlmsg_flags, bool net_admin)
 {
 	if (sk->sk_state == TCP_TIME_WAIT)
-		return inet_twsk_diag_fill(sk, skb, portid, seq,
-					   nlmsg_flags, unlh);
+		return inet_twsk_diag_fill(sk, skb, cb, nlmsg_flags);
 
 	if (sk->sk_state == TCP_NEW_SYN_RECV)
-		return inet_req_diag_fill(sk, skb, portid, seq,
-					  nlmsg_flags, unlh, net_admin);
+		return inet_req_diag_fill(sk, skb, cb, nlmsg_flags, net_admin);
 
-	return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq,
-				  nlmsg_flags, unlh, net_admin);
+	return inet_sk_diag_fill(sk, inet_csk(sk), skb, cb, r, nlmsg_flags,
+				 net_admin);
 }
 
 struct sock *inet_diag_find_one_icsk(struct net *net,
@@ -464,10 +546,10 @@
 EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk);
 
 int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
-			    struct sk_buff *in_skb,
-			    const struct nlmsghdr *nlh,
+			    struct netlink_callback *cb,
 			    const struct inet_diag_req_v2 *req)
 {
+	struct sk_buff *in_skb = cb->skb;
 	bool net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN);
 	struct net *net = sock_net(in_skb->sk);
 	struct sk_buff *rep;
@@ -484,10 +566,7 @@
 		goto out;
 	}
 
-	err = sk_diag_fill(sk, rep, req,
-			   sk_user_ns(NETLINK_CB(in_skb).sk),
-			   NETLINK_CB(in_skb).portid,
-			   nlh->nlmsg_seq, 0, nlh, net_admin);
+	err = sk_diag_fill(sk, rep, cb, req, 0, net_admin);
 	if (err < 0) {
 		WARN_ON(err == -EMSGSIZE);
 		nlmsg_free(rep);
@@ -508,20 +587,35 @@
 
 static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
 			       const struct nlmsghdr *nlh,
+			       int hdrlen,
 			       const struct inet_diag_req_v2 *req)
 {
 	const struct inet_diag_handler *handler;
-	int err;
+	struct inet_diag_dump_data dump_data;
+	int err, protocol;
 
-	handler = inet_diag_lock_handler(req->sdiag_protocol);
-	if (IS_ERR(handler))
+	memset(&dump_data, 0, sizeof(dump_data));
+	err = inet_diag_parse_attrs(nlh, hdrlen, dump_data.req_nlas);
+	if (err)
+		return err;
+
+	protocol = inet_diag_get_protocol(req, &dump_data);
+
+	handler = inet_diag_lock_handler(protocol);
+	if (IS_ERR(handler)) {
 		err = PTR_ERR(handler);
-	else if (cmd == SOCK_DIAG_BY_FAMILY)
-		err = handler->dump_one(in_skb, nlh, req);
-	else if (cmd == SOCK_DESTROY && handler->destroy)
+	} else if (cmd == SOCK_DIAG_BY_FAMILY) {
+		struct netlink_callback cb = {
+			.nlh = nlh,
+			.skb = in_skb,
+			.data = &dump_data,
+		};
+		err = handler->dump_one(&cb, req);
+	} else if (cmd == SOCK_DESTROY && handler->destroy) {
 		err = handler->destroy(in_skb, req);
-	else
+	} else {
 		err = -EOPNOTSUPP;
+	}
 	inet_diag_unlock_handler(handler);
 
 	return err;
@@ -647,6 +741,16 @@
 				yes = 0;
 			break;
 		}
+#ifdef CONFIG_SOCK_CGROUP_DATA
+		case INET_DIAG_BC_CGROUP_COND: {
+			u64 cgroup_id;
+
+			cgroup_id = get_unaligned((const u64 *)(op + 1));
+			if (cgroup_id != entry->cgroup_id)
+				yes = 0;
+			break;
+		}
+#endif
 		}
 
 		if (yes) {
@@ -697,6 +801,10 @@
 		entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark;
 	else
 		entry.mark = 0;
+#ifdef CONFIG_SOCK_CGROUP_DATA
+	entry.cgroup_id = sk_fullsock(sk) ?
+		cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)) : 0;
+#endif
 
 	return inet_diag_bc_run(bc, &entry);
 }
@@ -786,6 +894,15 @@
 	return len >= *min_len;
 }
 
+#ifdef CONFIG_SOCK_CGROUP_DATA
+static bool valid_cgroupcond(const struct inet_diag_bc_op *op, int len,
+			     int *min_len)
+{
+	*min_len += sizeof(u64);
+	return len >= *min_len;
+}
+#endif
+
 static int inet_diag_bc_audit(const struct nlattr *attr,
 			      const struct sk_buff *skb)
 {
@@ -828,6 +945,12 @@
 			if (!valid_markcond(bc, len, &min_len))
 				return -EINVAL;
 			break;
+#ifdef CONFIG_SOCK_CGROUP_DATA
+		case INET_DIAG_BC_CGROUP_COND:
+			if (!valid_cgroupcond(bc, len, &min_len))
+				return -EINVAL;
+			break;
+#endif
 		case INET_DIAG_BC_AUTO:
 		case INET_DIAG_BC_JMP:
 		case INET_DIAG_BC_NOP:
@@ -850,23 +973,6 @@
 		len -= op->yes;
 	}
 	return len == 0 ? 0 : -EINVAL;
-}
-
-static int inet_csk_diag_dump(struct sock *sk,
-			      struct sk_buff *skb,
-			      struct netlink_callback *cb,
-			      const struct inet_diag_req_v2 *r,
-			      const struct nlattr *bc,
-			      bool net_admin)
-{
-	if (!inet_diag_bc_sk(bc, sk))
-		return 0;
-
-	return inet_csk_diag_fill(sk, skb, r,
-				  sk_user_ns(NETLINK_CB(cb->skb).sk),
-				  NETLINK_CB(cb->skb).portid,
-				  cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh,
-				  net_admin);
 }
 
 static void twsk_build_assert(void)
@@ -897,14 +1003,17 @@
 
 void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
 			 struct netlink_callback *cb,
-			 const struct inet_diag_req_v2 *r, struct nlattr *bc)
+			 const struct inet_diag_req_v2 *r)
 {
 	bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN);
+	struct inet_diag_dump_data *cb_data = cb->data;
 	struct net *net = sock_net(skb->sk);
 	u32 idiag_states = r->idiag_states;
 	int i, num, s_i, s_num;
+	struct nlattr *bc;
 	struct sock *sk;
 
+	bc = cb_data->inet_diag_nla_bc;
 	if (idiag_states & TCPF_SYN_RECV)
 		idiag_states |= TCPF_NEW_SYN_RECV;
 	s_i = cb->args[1];
@@ -940,8 +1049,12 @@
 				    r->id.idiag_sport)
 					goto next_listen;
 
-				if (inet_csk_diag_dump(sk, skb, cb, r,
-						       bc, net_admin) < 0) {
+				if (!inet_diag_bc_sk(bc, sk))
+					goto next_listen;
+
+				if (inet_sk_diag_fill(sk, inet_csk(sk), skb,
+						      cb, r, NLM_F_MULTI,
+						      net_admin) < 0) {
 					spin_unlock(&ilb->lock);
 					goto done;
 				}
@@ -1019,11 +1132,8 @@
 		res = 0;
 		for (idx = 0; idx < accum; idx++) {
 			if (res >= 0) {
-				res = sk_diag_fill(sk_arr[idx], skb, r,
-					   sk_user_ns(NETLINK_CB(cb->skb).sk),
-					   NETLINK_CB(cb->skb).portid,
-					   cb->nlh->nlmsg_seq, NLM_F_MULTI,
-					   cb->nlh, net_admin);
+				res = sk_diag_fill(sk_arr[idx], skb, cb, r,
+						   NLM_F_MULTI, net_admin);
 				if (res < 0)
 					num = num_arr[idx];
 			}
@@ -1047,31 +1157,101 @@
 EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
 
 static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
-			    const struct inet_diag_req_v2 *r,
-			    struct nlattr *bc)
+			    const struct inet_diag_req_v2 *r)
 {
+	struct inet_diag_dump_data *cb_data = cb->data;
 	const struct inet_diag_handler *handler;
-	int err = 0;
+	u32 prev_min_dump_alloc;
+	int protocol, err = 0;
 
-	handler = inet_diag_lock_handler(r->sdiag_protocol);
+	protocol = inet_diag_get_protocol(r, cb_data);
+
+again:
+	prev_min_dump_alloc = cb->min_dump_alloc;
+	handler = inet_diag_lock_handler(protocol);
 	if (!IS_ERR(handler))
-		handler->dump(skb, cb, r, bc);
+		handler->dump(skb, cb, r);
 	else
 		err = PTR_ERR(handler);
 	inet_diag_unlock_handler(handler);
+
+	/* The skb is not large enough to fit one sk info and
+	 * inet_sk_diag_fill() has requested for a larger skb.
+	 */
+	if (!skb->len && cb->min_dump_alloc > prev_min_dump_alloc) {
+		err = pskb_expand_head(skb, 0, cb->min_dump_alloc, GFP_KERNEL);
+		if (!err)
+			goto again;
+	}
 
 	return err ? : skb->len;
 }
 
 static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
 {
-	int hdrlen = sizeof(struct inet_diag_req_v2);
-	struct nlattr *bc = NULL;
+	return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh));
+}
 
-	if (nlmsg_attrlen(cb->nlh, hdrlen))
-		bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
+static int __inet_diag_dump_start(struct netlink_callback *cb, int hdrlen)
+{
+	const struct nlmsghdr *nlh = cb->nlh;
+	struct inet_diag_dump_data *cb_data;
+	struct sk_buff *skb = cb->skb;
+	struct nlattr *nla;
+	int err;
 
-	return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh), bc);
+	cb_data = kzalloc(sizeof(*cb_data), GFP_KERNEL);
+	if (!cb_data)
+		return -ENOMEM;
+
+	err = inet_diag_parse_attrs(nlh, hdrlen, cb_data->req_nlas);
+	if (err) {
+		kfree(cb_data);
+		return err;
+	}
+	nla = cb_data->inet_diag_nla_bc;
+	if (nla) {
+		err = inet_diag_bc_audit(nla, skb);
+		if (err) {
+			kfree(cb_data);
+			return err;
+		}
+	}
+
+	nla = cb_data->inet_diag_nla_bpf_stgs;
+	if (nla) {
+		struct bpf_sk_storage_diag *bpf_stg_diag;
+
+		bpf_stg_diag = bpf_sk_storage_diag_alloc(nla);
+		if (IS_ERR(bpf_stg_diag)) {
+			kfree(cb_data);
+			return PTR_ERR(bpf_stg_diag);
+		}
+		cb_data->bpf_stg_diag = bpf_stg_diag;
+	}
+
+	cb->data = cb_data;
+	return 0;
+}
+
+static int inet_diag_dump_start(struct netlink_callback *cb)
+{
+	return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req_v2));
+}
+
+static int inet_diag_dump_start_compat(struct netlink_callback *cb)
+{
+	return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req));
+}
+
+static int inet_diag_dump_done(struct netlink_callback *cb)
+{
+	struct inet_diag_dump_data *cb_data = cb->data;
+
+	bpf_sk_storage_diag_free(cb_data->bpf_stg_diag);
+	kfree(cb->data);
+
+	return 0;
 }
 
 static int inet_diag_type2proto(int type)
@@ -1090,9 +1270,7 @@
 				 struct netlink_callback *cb)
 {
 	struct inet_diag_req *rc = nlmsg_data(cb->nlh);
-	int hdrlen = sizeof(struct inet_diag_req);
 	struct inet_diag_req_v2 req;
-	struct nlattr *bc = NULL;
 
 	req.sdiag_family = AF_UNSPEC; /* compatibility */
 	req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
@@ -1100,10 +1278,7 @@
 	req.idiag_states = rc->idiag_states;
 	req.id = rc->id;
 
-	if (nlmsg_attrlen(cb->nlh, hdrlen))
-		bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
-
-	return __inet_diag_dump(skb, cb, &req, bc);
+	return __inet_diag_dump(skb, cb, &req);
 }
 
 static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
@@ -1118,7 +1293,8 @@
 	req.idiag_states = rc->idiag_states;
 	req.id = rc->id;
 
-	return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh, &req);
+	return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh,
+				   sizeof(struct inet_diag_req), &req);
 }
 
 static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
@@ -1131,22 +1307,12 @@
 		return -EINVAL;
 
 	if (nlh->nlmsg_flags & NLM_F_DUMP) {
-		if (nlmsg_attrlen(nlh, hdrlen)) {
-			struct nlattr *attr;
-			int err;
-
-			attr = nlmsg_find_attr(nlh, hdrlen,
-					       INET_DIAG_REQ_BYTECODE);
-			err = inet_diag_bc_audit(attr, skb);
-			if (err)
-				return err;
-		}
-		{
-			struct netlink_dump_control c = {
-				.dump = inet_diag_dump_compat,
-			};
-			return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
-		}
+		struct netlink_dump_control c = {
+			.start = inet_diag_dump_start_compat,
+			.done = inet_diag_dump_done,
+			.dump = inet_diag_dump_compat,
+		};
+		return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
 	}
 
 	return inet_diag_get_exact_compat(skb, nlh);
@@ -1162,25 +1328,16 @@
 
 	if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY &&
 	    h->nlmsg_flags & NLM_F_DUMP) {
-		if (nlmsg_attrlen(h, hdrlen)) {
-			struct nlattr *attr;
-			int err;
-
-			attr = nlmsg_find_attr(h, hdrlen,
-					       INET_DIAG_REQ_BYTECODE);
-			err = inet_diag_bc_audit(attr, skb);
-			if (err)
-				return err;
-		}
-		{
-			struct netlink_dump_control c = {
-				.dump = inet_diag_dump,
-			};
-			return netlink_dump_start(net->diag_nlsk, skb, h, &c);
-		}
+		struct netlink_dump_control c = {
+			.start = inet_diag_dump_start,
+			.done = inet_diag_dump_done,
+			.dump = inet_diag_dump,
+		};
+		return netlink_dump_start(net->diag_nlsk, skb, h, &c);
 	}
 
-	return inet_diag_cmd_exact(h->nlmsg_type, skb, h, nlmsg_data(h));
+	return inet_diag_cmd_exact(h->nlmsg_type, skb, h, hdrlen,
+				   nlmsg_data(h));
 }
 
 static
@@ -1315,11 +1472,7 @@
 	kfree(inet_diag_table);
 }
 
-#ifdef CONFIG_ROCKCHIP_THUNDER_BOOT
-rootfs_initcall(inet_diag_init);
-#else
 module_init(inet_diag_init);
-#endif
 module_exit(inet_diag_exit);
 MODULE_LICENSE("GPL");
 MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);

--
Gitblit v1.6.2