From 102a0743326a03cd1a1202ceda21e175b7d3575c Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Tue, 20 Feb 2024 01:20:52 +0000
Subject: [PATCH] add new system file

---
 kernel/net/sunrpc/svcauth_unix.c |   82 ++++++++++++++++++++++++++--------------
 1 files changed, 53 insertions(+), 29 deletions(-)

diff --git a/kernel/net/sunrpc/svcauth_unix.c b/kernel/net/sunrpc/svcauth_unix.c
index af7f28f..60754a2 100644
--- a/kernel/net/sunrpc/svcauth_unix.c
+++ b/kernel/net/sunrpc/svcauth_unix.c
@@ -1,3 +1,4 @@
+// SPDX-License-Identifier: GPL-2.0-only
 #include <linux/types.h>
 #include <linux/sched.h>
 #include <linux/module.h>
@@ -37,12 +38,18 @@
 extern struct auth_ops svcauth_null;
 extern struct auth_ops svcauth_unix;
 
-static void svcauth_unix_domain_release(struct auth_domain *dom)
+static void svcauth_unix_domain_release_rcu(struct rcu_head *head)
 {
+	struct auth_domain *dom = container_of(head, struct auth_domain, rcu_head);
 	struct unix_domain *ud = container_of(dom, struct unix_domain, h);
 
 	kfree(dom->name);
 	kfree(ud);
+}
+
+static void svcauth_unix_domain_release(struct auth_domain *dom)
+{
+	call_rcu(&dom->rcu_head, svcauth_unix_domain_release_rcu);
 }
 
 struct auth_domain *unix_domain_find(char *name)
@@ -50,7 +57,7 @@
 	struct auth_domain *rv;
 	struct unix_domain *new = NULL;
 
-	rv = auth_domain_lookup(name, NULL);
+	rv = auth_domain_find(name);
 	while(1) {
 		if (rv) {
 			if (new && rv != &new->h)
@@ -91,6 +98,7 @@
 	char			m_class[8]; /* e.g. "nfsd" */
 	struct in6_addr		m_addr;
 	struct unix_domain	*m_client;
+	struct rcu_head		m_rcu;
 };
 
 static void ip_map_put(struct kref *kref)
@@ -101,7 +109,7 @@
 	if (test_bit(CACHE_VALID, &item->flags) &&
 	    !test_bit(CACHE_NEGATIVE, &item->flags))
 		auth_domain_put(&im->m_client->h);
-	kfree(im);
+	kfree_rcu(im, m_rcu);
 }
 
 static inline int hash_ip6(const struct in6_addr *ip)
@@ -140,6 +148,11 @@
 		return NULL;
 }
 
+static int ip_map_upcall(struct cache_detail *cd, struct cache_head *h)
+{
+	return sunrpc_cache_pipe_upcall(cd, h);
+}
+
 static void ip_map_request(struct cache_detail *cd,
 				  struct cache_head *h,
 				  char **bpp, int *blen)
@@ -158,7 +171,7 @@
 }
 
 static struct ip_map *__ip_map_lookup(struct cache_detail *cd, char *class, struct in6_addr *addr);
-static int __ip_map_update(struct cache_detail *cd, struct ip_map *ipm, struct unix_domain *udom, time_t expiry);
+static int __ip_map_update(struct cache_detail *cd, struct ip_map *ipm, struct unix_domain *udom, time64_t expiry);
 
 static int ip_map_parse(struct cache_detail *cd,
 			  char *mesg, int mlen)
@@ -179,7 +192,7 @@
 
 	struct ip_map *ipmp;
 	struct auth_domain *dom;
-	time_t expiry;
+	time64_t expiry;
 
 	if (mesg[mlen-1] != '\n')
 		return -EINVAL;
@@ -280,9 +293,9 @@
 
 	strcpy(ip.m_class, class);
 	ip.m_addr = *addr;
-	ch = sunrpc_cache_lookup(cd, &ip.h,
-				 hash_str(class, IP_HASHBITS) ^
-				 hash_ip6(addr));
+	ch = sunrpc_cache_lookup_rcu(cd, &ip.h,
+				     hash_str(class, IP_HASHBITS) ^
+				     hash_ip6(addr));
 
 	if (ch)
 		return container_of(ch, struct ip_map, h);
@@ -300,7 +313,7 @@
 }
 
 static int __ip_map_update(struct cache_detail *cd, struct ip_map *ipm,
-		struct unix_domain *udom, time_t expiry)
+		struct unix_domain *udom, time64_t expiry)
 {
 	struct ip_map ip;
 	struct cache_head *ch;
@@ -317,15 +330,6 @@
 		return -ENOMEM;
 	cache_put(ch, cd);
 	return 0;
-}
-
-static inline int ip_map_update(struct net *net, struct ip_map *ipm,
-		struct unix_domain *udom, time_t expiry)
-{
-	struct sunrpc_net *sn;
-
-	sn = net_generic(net, sunrpc_net_id);
-	return __ip_map_update(sn->ip_map_cache, ipm, udom, expiry);
 }
 
 void svcauth_unix_purge(struct net *net)
@@ -412,6 +416,7 @@
 	struct cache_head	h;
 	kuid_t			uid;
 	struct group_info	*gi;
+	struct rcu_head		rcu;
 };
 
 static int unix_gid_hash(kuid_t uid)
@@ -419,14 +424,23 @@
 	return hash_long(from_kuid(&init_user_ns, uid), GID_HASHBITS);
 }
 
-static void unix_gid_put(struct kref *kref)
+static void unix_gid_free(struct rcu_head *rcu)
 {
-	struct cache_head *item = container_of(kref, struct cache_head, ref);
-	struct unix_gid *ug = container_of(item, struct unix_gid, h);
+	struct unix_gid *ug = container_of(rcu, struct unix_gid, rcu);
+	struct cache_head *item = &ug->h;
+
 	if (test_bit(CACHE_VALID, &item->flags) &&
 	    !test_bit(CACHE_NEGATIVE, &item->flags))
 		put_group_info(ug->gi);
 	kfree(ug);
+}
+
+static void unix_gid_put(struct kref *kref)
+{
+	struct cache_head *item = container_of(kref, struct cache_head, ref);
+	struct unix_gid *ug = container_of(item, struct unix_gid, h);
+
+	call_rcu(&ug->rcu, unix_gid_free);
 }
 
 static int unix_gid_match(struct cache_head *corig, struct cache_head *cnew)
@@ -458,6 +472,11 @@
 		return NULL;
 }
 
+static int unix_gid_upcall(struct cache_detail *cd, struct cache_head *h)
+{
+	return sunrpc_cache_pipe_upcall_timeout(cd, h);
+}
+
 static void unix_gid_request(struct cache_detail *cd,
 			     struct cache_head *h,
 			     char **bpp, int *blen)
@@ -482,7 +501,7 @@
 	int rv;
 	int i;
 	int err;
-	time_t expiry;
+	time64_t expiry;
 	struct unix_gid ug, *ugp;
 
 	if (mesg[mlen - 1] != '\n')
@@ -492,7 +511,7 @@
 	rv = get_int(&mesg, &id);
 	if (rv)
 		return -EINVAL;
-	uid = make_kuid(&init_user_ns, id);
+	uid = make_kuid(current_user_ns(), id);
 	ug.uid = uid;
 
 	expiry = get_expiry(&mesg);
@@ -514,7 +533,7 @@
 		err = -EINVAL;
 		if (rv)
 			goto out;
-		kgid = make_kgid(&init_user_ns, gid);
+		kgid = make_kgid(current_user_ns(), gid);
 		if (!gid_valid(kgid))
 			goto out;
 		ug.gi->gid[i] = kgid;
@@ -547,7 +566,7 @@
 			 struct cache_detail *cd,
 			 struct cache_head *h)
 {
-	struct user_namespace *user_ns = &init_user_ns;
+	struct user_namespace *user_ns = m->file->f_cred->user_ns;
 	struct unix_gid *ug;
 	int i;
 	int glen;
@@ -575,6 +594,7 @@
 	.hash_size	= GID_HASHMAX,
 	.name		= "auth.unix.gid",
 	.cache_put	= unix_gid_put,
+	.cache_upcall	= unix_gid_upcall,
 	.cache_request	= unix_gid_request,
 	.cache_parse	= unix_gid_parse,
 	.cache_show	= unix_gid_show,
@@ -619,7 +639,7 @@
 	struct cache_head *ch;
 
 	ug.uid = uid;
-	ch = sunrpc_cache_lookup(cd, &ug.h, unix_gid_hash(uid));
+	ch = sunrpc_cache_lookup_rcu(cd, &ug.h, unix_gid_hash(uid));
 	if (ch)
 		return container_of(ch, struct unix_gid, h);
 	else
@@ -788,6 +808,7 @@
 	struct kvec	*argv = &rqstp->rq_arg.head[0];
 	struct kvec	*resv = &rqstp->rq_res.head[0];
 	struct svc_cred	*cred = &rqstp->rq_cred;
+	struct user_namespace *userns;
 	u32		slen, i;
 	int		len   = argv->iov_len;
 
@@ -808,8 +829,10 @@
 	 * (export-specific) anonymous id by nfsd_setuser.
 	 * Supplementary gid's will be left alone.
 	 */
-	cred->cr_uid = make_kuid(&init_user_ns, svc_getnl(argv)); /* uid */
-	cred->cr_gid = make_kgid(&init_user_ns, svc_getnl(argv)); /* gid */
+	userns = (rqstp->rq_xprt && rqstp->rq_xprt->xpt_cred) ?
+		rqstp->rq_xprt->xpt_cred->user_ns : &init_user_ns;
+	cred->cr_uid = make_kuid(userns, svc_getnl(argv)); /* uid */
+	cred->cr_gid = make_kgid(userns, svc_getnl(argv)); /* gid */
 	slen = svc_getnl(argv);			/* gids length */
 	if (slen > UNX_NGROUPS || (len -= (slen + 2)*4) < 0)
 		goto badcred;
@@ -817,7 +840,7 @@
 	if (cred->cr_group_info == NULL)
 		return SVC_CLOSE;
 	for (i = 0; i < slen; i++) {
-		kgid_t kgid = make_kgid(&init_user_ns, svc_getnl(argv));
+		kgid_t kgid = make_kgid(userns, svc_getnl(argv));
 		cred->cr_group_info->gid[i] = kgid;
 	}
 	groups_sort(cred->cr_group_info);
@@ -869,6 +892,7 @@
 	.hash_size	= IP_HASHMAX,
 	.name		= "auth.unix.ip",
 	.cache_put	= ip_map_put,
+	.cache_upcall	= ip_map_upcall,
 	.cache_request	= ip_map_request,
 	.cache_parse	= ip_map_parse,
 	.cache_show	= ip_map_show,

--
Gitblit v1.6.2