From 102a0743326a03cd1a1202ceda21e175b7d3575c Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Tue, 20 Feb 2024 01:20:52 +0000
Subject: [PATCH] add new system file

---
 kernel/net/core/net_namespace.c |  336 +++++++++++++++++++++++++++++++++++++++++++++-----------
 1 files changed, 270 insertions(+), 66 deletions(-)

diff --git a/kernel/net/core/net_namespace.c b/kernel/net/core/net_namespace.c
index 3368624..f1258af 100644
--- a/kernel/net/core/net_namespace.c
+++ b/kernel/net/core/net_namespace.c
@@ -1,3 +1,4 @@
+// SPDX-License-Identifier: GPL-2.0-only
 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
 
 #include <linux/workqueue.h>
@@ -18,6 +19,7 @@
 #include <linux/net_namespace.h>
 #include <linux/sched/task.h>
 #include <linux/uidgid.h>
+#include <linux/cookie.h>
 
 #include <net/sock.h>
 #include <net/netlink.h>
@@ -38,9 +40,16 @@
 DECLARE_RWSEM(net_rwsem);
 EXPORT_SYMBOL_GPL(net_rwsem);
 
+#ifdef CONFIG_KEYS
+static struct key_tag init_net_key_domain = { .usage = REFCOUNT_INIT(1) };
+#endif
+
 struct net init_net = {
 	.count		= REFCOUNT_INIT(1),
 	.dev_base_head	= LIST_HEAD_INIT(init_net.dev_base_head),
+#ifdef CONFIG_KEYS
+	.key_domain	= &init_net_key_domain,
+#endif
 };
 EXPORT_SYMBOL(init_net);
 
@@ -60,6 +69,8 @@
 #define INITIAL_NET_GEN_PTRS	13 /* +1 for len +2 for rcu_head */
 
 static unsigned int max_gen_ptrs = INITIAL_NET_GEN_PTRS;
+
+DEFINE_COOKIE(net_cookie);
 
 static struct net_generic *net_alloc_generic(void)
 {
@@ -112,6 +123,7 @@
 
 static int ops_init(const struct pernet_operations *ops, struct net *net)
 {
+	struct net_generic *ng;
 	int err = -ENOMEM;
 	void *data = NULL;
 
@@ -130,6 +142,12 @@
 	if (!err)
 		return 0;
 
+	if (ops->id && ops->size) {
+		ng = rcu_dereference_protected(net->gen,
+					       lockdep_is_held(&pernet_ops_rwsem));
+		ng->ptr[*ops->id] = NULL;
+	}
+
 cleanup:
 	kfree(data);
 
@@ -141,6 +159,17 @@
 {
 	if (ops->id && ops->size) {
 		kfree(net_generic(net, *ops->id));
+	}
+}
+
+static void ops_pre_exit_list(const struct pernet_operations *ops,
+			      struct list_head *net_exit_list)
+{
+	struct net *net;
+
+	if (ops->pre_exit) {
+		list_for_each_entry(net, net_exit_list, exit_list)
+			ops->pre_exit(net);
 	}
 }
 
@@ -194,16 +223,10 @@
 	return 0;
 }
 
-/* Must be called from RCU-critical section or with nsid_lock held. If
- * a new id is assigned, the bool alloc is set to true, thus the
- * caller knows that the new id must be notified via rtnl.
- */
-static int __peernet2id_alloc(struct net *net, struct net *peer, bool *alloc)
+/* Must be called from RCU-critical section or with nsid_lock held */
+static int __peernet2id(const struct net *net, struct net *peer)
 {
 	int id = idr_for_each(&net->netns_ids, net_eq_idr, peer);
-	bool alloc_it = *alloc;
-
-	*alloc = false;
 
 	/* Magic value for id 0. */
 	if (id == NET_ID_ZERO)
@@ -211,55 +234,53 @@
 	if (id > 0)
 		return id;
 
-	if (alloc_it) {
-		id = alloc_netid(net, peer, -1);
-		*alloc = true;
-		return id >= 0 ? id : NETNSA_NSID_NOT_ASSIGNED;
-	}
-
 	return NETNSA_NSID_NOT_ASSIGNED;
 }
 
-/* Must be called from RCU-critical section or with nsid_lock held */
-static int __peernet2id(struct net *net, struct net *peer)
-{
-	bool no = false;
-
-	return __peernet2id_alloc(net, peer, &no);
-}
-
-static void rtnl_net_notifyid(struct net *net, int cmd, int id, gfp_t gfp);
+static void rtnl_net_notifyid(struct net *net, int cmd, int id, u32 portid,
+			      struct nlmsghdr *nlh, gfp_t gfp);
 /* This function returns the id of a peer netns. If no id is assigned, one will
  * be allocated and returned.
  */
 int peernet2id_alloc(struct net *net, struct net *peer, gfp_t gfp)
 {
-	bool alloc = false, alive = false;
 	int id;
 
 	if (refcount_read(&net->count) == 0)
 		return NETNSA_NSID_NOT_ASSIGNED;
+
 	spin_lock_bh(&net->nsid_lock);
-	/*
-	 * When peer is obtained from RCU lists, we may race with
+	id = __peernet2id(net, peer);
+	if (id >= 0) {
+		spin_unlock_bh(&net->nsid_lock);
+		return id;
+	}
+
+	/* When peer is obtained from RCU lists, we may race with
 	 * its cleanup. Check whether it's alive, and this guarantees
 	 * we never hash a peer back to net->netns_ids, after it has
 	 * just been idr_remove()'d from there in cleanup_net().
 	 */
-	if (maybe_get_net(peer))
-		alive = alloc = true;
-	id = __peernet2id_alloc(net, peer, &alloc);
+	if (!maybe_get_net(peer)) {
+		spin_unlock_bh(&net->nsid_lock);
+		return NETNSA_NSID_NOT_ASSIGNED;
+	}
+
+	id = alloc_netid(net, peer, -1);
 	spin_unlock_bh(&net->nsid_lock);
-	if (alloc && id >= 0)
-		rtnl_net_notifyid(net, RTM_NEWNSID, id, gfp);
-	if (alive)
-		put_net(peer);
+
+	put_net(peer);
+	if (id < 0)
+		return NETNSA_NSID_NOT_ASSIGNED;
+
+	rtnl_net_notifyid(net, RTM_NEWNSID, id, 0, NULL, gfp);
+
 	return id;
 }
 EXPORT_SYMBOL_GPL(peernet2id_alloc);
 
 /* This function returns, if assigned, the id of a peer netns. */
-int peernet2id(struct net *net, struct net *peer)
+int peernet2id(const struct net *net, struct net *peer)
 {
 	int id;
 
@@ -274,12 +295,12 @@
 /* This function returns true is the peer netns has an id assigned into the
  * current netns.
  */
-bool peernet_has_id(struct net *net, struct net *peer)
+bool peernet_has_id(const struct net *net, struct net *peer)
 {
 	return peernet2id(net, peer) >= 0;
 }
 
-struct net *get_net_ns_by_id(struct net *net, int id)
+struct net *get_net_ns_by_id(const struct net *net, int id)
 {
 	struct net *peer;
 
@@ -308,6 +329,9 @@
 	refcount_set(&net->count, 1);
 	refcount_set(&net->passive, 1);
 	get_random_bytes(&net->hash_mix, sizeof(u32));
+	preempt_disable();
+	atomic64_set(&net->net_cookie, gen_cookie_next(&net_cookie));
+	preempt_enable();
 	net->dev_base_seq = 1;
 	net->user_ns = user_ns;
 	idr_init(&net->netns_ids);
@@ -331,6 +355,12 @@
 	 */
 	list_add(&net->exit_list, &net_exit_list);
 	saved_ops = ops;
+	list_for_each_entry_continue_reverse(ops, &pernet_list, list)
+		ops_pre_exit_list(ops, &net_exit_list);
+
+	synchronize_rcu();
+
+	ops = saved_ops;
 	list_for_each_entry_continue_reverse(ops, &pernet_list, list)
 		ops_exit_list(ops, &net_exit_list);
 
@@ -389,10 +419,22 @@
 	if (!net)
 		goto out_free;
 
+#ifdef CONFIG_KEYS
+	net->key_domain = kzalloc(sizeof(struct key_tag), GFP_KERNEL);
+	if (!net->key_domain)
+		goto out_free_2;
+	refcount_set(&net->key_domain->usage, 1);
+#endif
+
 	rcu_assign_pointer(net->gen, ng);
 out:
 	return net;
 
+#ifdef CONFIG_KEYS
+out_free_2:
+	kmem_cache_free(net_cachep, net);
+	net = NULL;
+#endif
 out_free:
 	kfree(ng);
 	goto out;
@@ -444,6 +486,9 @@
 
 	if (rv < 0) {
 put_userns:
+#ifdef CONFIG_KEYS
+		key_remove_domain(net->key_domain);
+#endif
 		put_user_ns(user_ns);
 		net_drop_ns(net);
 dec_ucounts:
@@ -498,7 +543,7 @@
 			idr_remove(&tmp->netns_ids, id);
 		spin_unlock_bh(&tmp->nsid_lock);
 		if (id >= 0)
-			rtnl_net_notifyid(tmp, RTM_DELNSID, id,
+			rtnl_net_notifyid(tmp, RTM_DELNSID, id, 0, NULL,
 					  GFP_KERNEL);
 		if (tmp == last)
 			break;
@@ -544,10 +589,15 @@
 		list_add_tail(&net->exit_list, &net_exit_list);
 	}
 
+	/* Run all of the network namespace pre_exit methods */
+	list_for_each_entry_reverse(ops, &pernet_list, list)
+		ops_pre_exit_list(ops, &net_exit_list);
+
 	/*
 	 * Another CPU might be rcu-iterating the list, wait for it.
 	 * This needs to be before calling the exit() notifiers, so
 	 * the rcu_barrier() below isn't sufficient alone.
+	 * Also the pre_exit() and exit() methods need this barrier.
 	 */
 	synchronize_rcu();
 
@@ -570,6 +620,9 @@
 	list_for_each_entry_safe(net, tmp, &net_exit_list, exit_list) {
 		list_del_init(&net->exit_list);
 		dec_net_namespaces(net->ucounts);
+#ifdef CONFIG_KEYS
+		key_remove_domain(net->key_domain);
+#endif
 		put_user_ns(net->user_ns);
 		net_drop_ns(net);
 	}
@@ -686,6 +739,7 @@
 	[NETNSA_NSID]		= { .type = NLA_S32 },
 	[NETNSA_PID]		= { .type = NLA_U32 },
 	[NETNSA_FD]		= { .type = NLA_U32 },
+	[NETNSA_TARGET_NSID]	= { .type = NLA_S32 },
 };
 
 static int rtnl_net_newid(struct sk_buff *skb, struct nlmsghdr *nlh,
@@ -697,8 +751,8 @@
 	struct net *peer;
 	int nsid, err;
 
-	err = nlmsg_parse(nlh, sizeof(struct rtgenmsg), tb, NETNSA_MAX,
-			  rtnl_net_policy, extack);
+	err = nlmsg_parse_deprecated(nlh, sizeof(struct rtgenmsg), tb,
+				     NETNSA_MAX, rtnl_net_policy, extack);
 	if (err < 0)
 		return err;
 	if (!tb[NETNSA_NSID]) {
@@ -736,7 +790,8 @@
 	err = alloc_netid(net, peer, nsid);
 	spin_unlock_bh(&net->nsid_lock);
 	if (err >= 0) {
-		rtnl_net_notifyid(net, RTM_NEWNSID, err, GFP_KERNEL);
+		rtnl_net_notifyid(net, RTM_NEWNSID, err, NETLINK_CB(skb).portid,
+				  nlh, GFP_KERNEL);
 		err = 0;
 	} else if (err == -ENOSPC && nsid >= 0) {
 		err = -EEXIST;
@@ -752,23 +807,38 @@
 {
 	return NLMSG_ALIGN(sizeof(struct rtgenmsg))
 	       + nla_total_size(sizeof(s32)) /* NETNSA_NSID */
+	       + nla_total_size(sizeof(s32)) /* NETNSA_CURRENT_NSID */
 	       ;
 }
 
-static int rtnl_net_fill(struct sk_buff *skb, u32 portid, u32 seq, int flags,
-			 int cmd, struct net *net, int nsid)
+struct net_fill_args {
+	u32 portid;
+	u32 seq;
+	int flags;
+	int cmd;
+	int nsid;
+	bool add_ref;
+	int ref_nsid;
+};
+
+static int rtnl_net_fill(struct sk_buff *skb, struct net_fill_args *args)
 {
 	struct nlmsghdr *nlh;
 	struct rtgenmsg *rth;
 
-	nlh = nlmsg_put(skb, portid, seq, cmd, sizeof(*rth), flags);
+	nlh = nlmsg_put(skb, args->portid, args->seq, args->cmd, sizeof(*rth),
+			args->flags);
 	if (!nlh)
 		return -EMSGSIZE;
 
 	rth = nlmsg_data(nlh);
 	rth->rtgen_family = AF_UNSPEC;
 
-	if (nla_put_s32(skb, NETNSA_NSID, nsid))
+	if (nla_put_s32(skb, NETNSA_NSID, args->nsid))
+		goto nla_put_failure;
+
+	if (args->add_ref &&
+	    nla_put_s32(skb, NETNSA_CURRENT_NSID, args->ref_nsid))
 		goto nla_put_failure;
 
 	nlmsg_end(skb, nlh);
@@ -779,18 +849,59 @@
 	return -EMSGSIZE;
 }
 
+static int rtnl_net_valid_getid_req(struct sk_buff *skb,
+				    const struct nlmsghdr *nlh,
+				    struct nlattr **tb,
+				    struct netlink_ext_ack *extack)
+{
+	int i, err;
+
+	if (!netlink_strict_get_check(skb))
+		return nlmsg_parse_deprecated(nlh, sizeof(struct rtgenmsg),
+					      tb, NETNSA_MAX, rtnl_net_policy,
+					      extack);
+
+	err = nlmsg_parse_deprecated_strict(nlh, sizeof(struct rtgenmsg), tb,
+					    NETNSA_MAX, rtnl_net_policy,
+					    extack);
+	if (err)
+		return err;
+
+	for (i = 0; i <= NETNSA_MAX; i++) {
+		if (!tb[i])
+			continue;
+
+		switch (i) {
+		case NETNSA_PID:
+		case NETNSA_FD:
+		case NETNSA_NSID:
+		case NETNSA_TARGET_NSID:
+			break;
+		default:
+			NL_SET_ERR_MSG(extack, "Unsupported attribute in peer netns getid request");
+			return -EINVAL;
+		}
+	}
+
+	return 0;
+}
+
 static int rtnl_net_getid(struct sk_buff *skb, struct nlmsghdr *nlh,
 			  struct netlink_ext_ack *extack)
 {
 	struct net *net = sock_net(skb->sk);
 	struct nlattr *tb[NETNSA_MAX + 1];
+	struct net_fill_args fillargs = {
+		.portid = NETLINK_CB(skb).portid,
+		.seq = nlh->nlmsg_seq,
+		.cmd = RTM_NEWNSID,
+	};
+	struct net *peer, *target = net;
 	struct nlattr *nla;
 	struct sk_buff *msg;
-	struct net *peer;
-	int err, id;
+	int err;
 
-	err = nlmsg_parse(nlh, sizeof(struct rtgenmsg), tb, NETNSA_MAX,
-			  rtnl_net_policy, extack);
+	err = rtnl_net_valid_getid_req(skb, nlh, tb, extack);
 	if (err < 0)
 		return err;
 	if (tb[NETNSA_PID]) {
@@ -799,6 +910,11 @@
 	} else if (tb[NETNSA_FD]) {
 		peer = get_net_ns_by_fd(nla_get_u32(tb[NETNSA_FD]));
 		nla = tb[NETNSA_FD];
+	} else if (tb[NETNSA_NSID]) {
+		peer = get_net_ns_by_id(net, nla_get_s32(tb[NETNSA_NSID]));
+		if (!peer)
+			peer = ERR_PTR(-ENOENT);
+		nla = tb[NETNSA_NSID];
 	} else {
 		NL_SET_ERR_MSG(extack, "Peer netns reference is missing");
 		return -EINVAL;
@@ -810,15 +926,29 @@
 		return PTR_ERR(peer);
 	}
 
+	if (tb[NETNSA_TARGET_NSID]) {
+		int id = nla_get_s32(tb[NETNSA_TARGET_NSID]);
+
+		target = rtnl_get_net_ns_capable(NETLINK_CB(skb).sk, id);
+		if (IS_ERR(target)) {
+			NL_SET_BAD_ATTR(extack, tb[NETNSA_TARGET_NSID]);
+			NL_SET_ERR_MSG(extack,
+				       "Target netns reference is invalid");
+			err = PTR_ERR(target);
+			goto out;
+		}
+		fillargs.add_ref = true;
+		fillargs.ref_nsid = peernet2id(net, peer);
+	}
+
 	msg = nlmsg_new(rtnl_net_get_size(), GFP_KERNEL);
 	if (!msg) {
 		err = -ENOMEM;
 		goto out;
 	}
 
-	id = peernet2id(net, peer);
-	err = rtnl_net_fill(msg, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
-			    RTM_NEWNSID, net, id);
+	fillargs.nsid = peernet2id(target, peer);
+	err = rtnl_net_fill(msg, &fillargs);
 	if (err < 0)
 		goto err_out;
 
@@ -828,14 +958,17 @@
 err_out:
 	nlmsg_free(msg);
 out:
+	if (fillargs.add_ref)
+		put_net(target);
 	put_net(peer);
 	return err;
 }
 
 struct rtnl_net_dump_cb {
-	struct net *net;
+	struct net *tgt_net;
+	struct net *ref_net;
 	struct sk_buff *skb;
-	struct netlink_callback *cb;
+	struct net_fill_args fillargs;
 	int idx;
 	int s_idx;
 };
@@ -849,9 +982,10 @@
 	if (net_cb->idx < net_cb->s_idx)
 		goto cont;
 
-	ret = rtnl_net_fill(net_cb->skb, NETLINK_CB(net_cb->cb->skb).portid,
-			    net_cb->cb->nlh->nlmsg_seq, NLM_F_MULTI,
-			    RTM_NEWNSID, net_cb->net, id);
+	net_cb->fillargs.nsid = id;
+	if (net_cb->fillargs.add_ref)
+		net_cb->fillargs.ref_nsid = __peernet2id(net_cb->ref_net, peer);
+	ret = rtnl_net_fill(net_cb->skb, &net_cb->fillargs);
 	if (ret < 0)
 		return ret;
 
@@ -860,27 +994,90 @@
 	return 0;
 }
 
+static int rtnl_valid_dump_net_req(const struct nlmsghdr *nlh, struct sock *sk,
+				   struct rtnl_net_dump_cb *net_cb,
+				   struct netlink_callback *cb)
+{
+	struct netlink_ext_ack *extack = cb->extack;
+	struct nlattr *tb[NETNSA_MAX + 1];
+	int err, i;
+
+	err = nlmsg_parse_deprecated_strict(nlh, sizeof(struct rtgenmsg), tb,
+					    NETNSA_MAX, rtnl_net_policy,
+					    extack);
+	if (err < 0)
+		return err;
+
+	for (i = 0; i <= NETNSA_MAX; i++) {
+		if (!tb[i])
+			continue;
+
+		if (i == NETNSA_TARGET_NSID) {
+			struct net *net;
+
+			net = rtnl_get_net_ns_capable(sk, nla_get_s32(tb[i]));
+			if (IS_ERR(net)) {
+				NL_SET_BAD_ATTR(extack, tb[i]);
+				NL_SET_ERR_MSG(extack,
+					       "Invalid target network namespace id");
+				return PTR_ERR(net);
+			}
+			net_cb->fillargs.add_ref = true;
+			net_cb->ref_net = net_cb->tgt_net;
+			net_cb->tgt_net = net;
+		} else {
+			NL_SET_BAD_ATTR(extack, tb[i]);
+			NL_SET_ERR_MSG(extack,
+				       "Unsupported attribute in dump request");
+			return -EINVAL;
+		}
+	}
+
+	return 0;
+}
+
 static int rtnl_net_dumpid(struct sk_buff *skb, struct netlink_callback *cb)
 {
-	struct net *net = sock_net(skb->sk);
 	struct rtnl_net_dump_cb net_cb = {
-		.net = net,
+		.tgt_net = sock_net(skb->sk),
 		.skb = skb,
-		.cb = cb,
+		.fillargs = {
+			.portid = NETLINK_CB(cb->skb).portid,
+			.seq = cb->nlh->nlmsg_seq,
+			.flags = NLM_F_MULTI,
+			.cmd = RTM_NEWNSID,
+		},
 		.idx = 0,
 		.s_idx = cb->args[0],
 	};
+	int err = 0;
+
+	if (cb->strict_check) {
+		err = rtnl_valid_dump_net_req(cb->nlh, skb->sk, &net_cb, cb);
+		if (err < 0)
+			goto end;
+	}
 
 	rcu_read_lock();
-	idr_for_each(&net->netns_ids, rtnl_net_dumpid_one, &net_cb);
+	idr_for_each(&net_cb.tgt_net->netns_ids, rtnl_net_dumpid_one, &net_cb);
 	rcu_read_unlock();
 
 	cb->args[0] = net_cb.idx;
-	return skb->len;
+end:
+	if (net_cb.fillargs.add_ref)
+		put_net(net_cb.tgt_net);
+	return err < 0 ? err : skb->len;
 }
 
-static void rtnl_net_notifyid(struct net *net, int cmd, int id, gfp_t gfp)
+static void rtnl_net_notifyid(struct net *net, int cmd, int id, u32 portid,
+			      struct nlmsghdr *nlh, gfp_t gfp)
 {
+	struct net_fill_args fillargs = {
+		.portid = portid,
+		.seq = nlh ? nlh->nlmsg_seq : 0,
+		.cmd = cmd,
+		.nsid = id,
+	};
 	struct sk_buff *msg;
 	int err = -ENOMEM;
 
@@ -888,11 +1085,11 @@
 	if (!msg)
 		goto out;
 
-	err = rtnl_net_fill(msg, 0, 0, 0, cmd, net, id);
+	err = rtnl_net_fill(msg, &fillargs);
 	if (err < 0)
 		goto err_out;
 
-	rtnl_notify(msg, net, 0, RTNLGRP_NSID, NULL, gfp);
+	rtnl_notify(msg, net, portid, RTNLGRP_NSID, nlh, gfp);
 	return;
 
 err_out:
@@ -967,6 +1164,8 @@
 out_undo:
 	/* If I have an error cleanup all namespaces I initialized */
 	list_del(&ops->list);
+	ops_pre_exit_list(ops, &net_exit_list);
+	synchronize_rcu();
 	ops_exit_list(ops, &net_exit_list);
 	ops_free_list(ops, &net_exit_list);
 	return error;
@@ -981,6 +1180,8 @@
 	/* See comment in __register_pernet_operations() */
 	for_each_net(net)
 		list_add_tail(&net->exit_list, &net_exit_list);
+	ops_pre_exit_list(ops, &net_exit_list);
+	synchronize_rcu();
 	ops_exit_list(ops, &net_exit_list);
 	ops_free_list(ops, &net_exit_list);
 }
@@ -1005,6 +1206,8 @@
 	} else {
 		LIST_HEAD(net_exit_list);
 		list_add(&init_net.exit_list, &net_exit_list);
+		ops_pre_exit_list(ops, &net_exit_list);
+		synchronize_rcu();
 		ops_exit_list(ops, &net_exit_list);
 		ops_free_list(ops, &net_exit_list);
 	}
@@ -1166,12 +1369,13 @@
 	put_net(to_net_ns(ns));
 }
 
-static int netns_install(struct nsproxy *nsproxy, struct ns_common *ns)
+static int netns_install(struct nsset *nsset, struct ns_common *ns)
 {
+	struct nsproxy *nsproxy = nsset->nsproxy;
 	struct net *net = to_net_ns(ns);
 
 	if (!ns_capable(net->user_ns, CAP_SYS_ADMIN) ||
-	    !ns_capable(current_user_ns(), CAP_SYS_ADMIN))
+	    !ns_capable(nsset->cred->user_ns, CAP_SYS_ADMIN))
 		return -EPERM;
 
 	put_net(nsproxy->net_ns);

--
Gitblit v1.6.2