From 1543e317f1da31b75942316931e8f491a8920811 Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Thu, 04 Jan 2024 10:08:02 +0000
Subject: [PATCH] disable FB

---
 kernel/drivers/infiniband/core/netlink.c |  170 +++++++++++++++++++++++++++++++-------------------------
 1 files changed, 95 insertions(+), 75 deletions(-)

diff --git a/kernel/drivers/infiniband/core/netlink.c b/kernel/drivers/infiniband/core/netlink.c
index 3ccaae1..8cd31ef 100644
--- a/kernel/drivers/infiniband/core/netlink.c
+++ b/kernel/drivers/infiniband/core/netlink.c
@@ -36,27 +36,31 @@
 #include <linux/export.h>
 #include <net/netlink.h>
 #include <net/net_namespace.h>
+#include <net/netns/generic.h>
 #include <net/sock.h>
 #include <rdma/rdma_netlink.h>
 #include <linux/module.h>
 #include "core_priv.h"
 
-static DEFINE_MUTEX(rdma_nl_mutex);
-static struct sock *nls;
 static struct {
-	const struct rdma_nl_cbs   *cb_table;
+	const struct rdma_nl_cbs *cb_table;
+	/* Synchronizes between ongoing netlink commands and netlink client
+	 * unregistration.
+	 */
+	struct rw_semaphore sem;
 } rdma_nl_types[RDMA_NL_NUM_CLIENTS];
 
-int rdma_nl_chk_listeners(unsigned int group)
+bool rdma_nl_chk_listeners(unsigned int group)
 {
-	return (netlink_has_listeners(nls, group)) ? 0 : -1;
+	struct rdma_dev_net *rnet = rdma_net_to_dev_net(&init_net);
+
+	return netlink_has_listeners(rnet->nl_sock, group);
 }
 EXPORT_SYMBOL(rdma_nl_chk_listeners);
 
 static bool is_nl_msg_valid(unsigned int type, unsigned int op)
 {
 	static const unsigned int max_num_ops[RDMA_NL_NUM_CLIENTS] = {
-		[RDMA_NL_RDMA_CM] = RDMA_NL_RDMA_CM_NUM_OPS,
 		[RDMA_NL_IWCM] = RDMA_NL_IWPM_NUM_OPS,
 		[RDMA_NL_LS] = RDMA_NL_LS_NUM_OPS,
 		[RDMA_NL_NLDEV] = RDMA_NLDEV_NUM_OPS,
@@ -74,62 +78,53 @@
 	return (op < max_num_ops[type]) ? true : false;
 }
 
-static bool is_nl_valid(unsigned int type, unsigned int op)
+static const struct rdma_nl_cbs *
+get_cb_table(const struct sk_buff *skb, unsigned int type, unsigned int op)
 {
 	const struct rdma_nl_cbs *cb_table;
 
-	if (!is_nl_msg_valid(type, op))
-		return false;
+	/*
+	 * Currently only NLDEV client is supporting netlink commands in
+	 * non init_net net namespace.
+	 */
+	if (sock_net(skb->sk) != &init_net && type != RDMA_NL_NLDEV)
+		return NULL;
 
-	if (!rdma_nl_types[type].cb_table) {
-		mutex_unlock(&rdma_nl_mutex);
+	cb_table = READ_ONCE(rdma_nl_types[type].cb_table);
+	if (!cb_table) {
+		/*
+		 * Didn't get valid reference of the table, attempt module
+		 * load once.
+		 */
+		up_read(&rdma_nl_types[type].sem);
+
 		request_module("rdma-netlink-subsys-%d", type);
-		mutex_lock(&rdma_nl_mutex);
+
+		down_read(&rdma_nl_types[type].sem);
+		cb_table = READ_ONCE(rdma_nl_types[type].cb_table);
 	}
-
-	cb_table = rdma_nl_types[type].cb_table;
-
 	if (!cb_table || (!cb_table[op].dump && !cb_table[op].doit))
-		return false;
-	return true;
+		return NULL;
+	return cb_table;
 }
 
 void rdma_nl_register(unsigned int index,
 		      const struct rdma_nl_cbs cb_table[])
 {
-	mutex_lock(&rdma_nl_mutex);
-	if (!is_nl_msg_valid(index, 0)) {
-		/*
-		 * All clients are not interesting in success/failure of
-		 * this call. They want to see the print to error log and
-		 * continue their initialization. Print warning for them,
-		 * because it is programmer's error to be here.
-		 */
-		mutex_unlock(&rdma_nl_mutex);
-		WARN(true,
-		     "The not-valid %u index was supplied to RDMA netlink\n",
-		     index);
+	if (WARN_ON(!is_nl_msg_valid(index, 0)) ||
+	    WARN_ON(READ_ONCE(rdma_nl_types[index].cb_table)))
 		return;
-	}
 
-	if (rdma_nl_types[index].cb_table) {
-		mutex_unlock(&rdma_nl_mutex);
-		WARN(true,
-		     "The %u index is already registered in RDMA netlink\n",
-		     index);
-		return;
-	}
-
-	rdma_nl_types[index].cb_table = cb_table;
-	mutex_unlock(&rdma_nl_mutex);
+	/* Pairs with the READ_ONCE in is_nl_valid() */
+	smp_store_release(&rdma_nl_types[index].cb_table, cb_table);
 }
 EXPORT_SYMBOL(rdma_nl_register);
 
 void rdma_nl_unregister(unsigned int index)
 {
-	mutex_lock(&rdma_nl_mutex);
+	down_write(&rdma_nl_types[index].sem);
 	rdma_nl_types[index].cb_table = NULL;
-	mutex_unlock(&rdma_nl_mutex);
+	up_write(&rdma_nl_types[index].sem);
 }
 EXPORT_SYMBOL(rdma_nl_unregister);
 
@@ -161,15 +156,21 @@
 	unsigned int index = RDMA_NL_GET_CLIENT(type);
 	unsigned int op = RDMA_NL_GET_OP(type);
 	const struct rdma_nl_cbs *cb_table;
+	int err = -EINVAL;
 
-	if (!is_nl_valid(index, op))
+	if (!is_nl_msg_valid(index, op))
 		return -EINVAL;
 
-	cb_table = rdma_nl_types[index].cb_table;
+	down_read(&rdma_nl_types[index].sem);
+	cb_table = get_cb_table(skb, index, op);
+	if (!cb_table)
+		goto done;
 
 	if ((cb_table[op].flags & RDMA_NL_ADMIN_PERM) &&
-	    !netlink_capable(skb, CAP_NET_ADMIN))
-		return -EPERM;
+	    !netlink_capable(skb, CAP_NET_ADMIN)) {
+		err = -EPERM;
+		goto done;
+	}
 
 	/*
 	 * LS responses overload the 0x100 (NLM_F_ROOT) flag.  Don't
@@ -177,24 +178,24 @@
 	 */
 	if (index == RDMA_NL_LS) {
 		if (cb_table[op].doit)
-			return cb_table[op].doit(skb, nlh, extack);
-		return -EINVAL;
+			err = cb_table[op].doit(skb, nlh, extack);
+		goto done;
 	}
 	/* FIXME: Convert IWCM to properly handle doit callbacks */
-	if ((nlh->nlmsg_flags & NLM_F_DUMP) || index == RDMA_NL_RDMA_CM ||
-	    index == RDMA_NL_IWCM) {
+	if ((nlh->nlmsg_flags & NLM_F_DUMP) || index == RDMA_NL_IWCM) {
 		struct netlink_dump_control c = {
 			.dump = cb_table[op].dump,
 		};
 		if (c.dump)
-			return netlink_dump_start(nls, skb, nlh, &c);
-		return -EINVAL;
+			err = netlink_dump_start(skb->sk, skb, nlh, &c);
+		goto done;
 	}
 
 	if (cb_table[op].doit)
-		return cb_table[op].doit(skb, nlh, extack);
-
-	return 0;
+		err = cb_table[op].doit(skb, nlh, extack);
+done:
+	up_read(&rdma_nl_types[index].sem);
+	return err;
 }
 
 /*
@@ -255,47 +256,44 @@
 
 static void rdma_nl_rcv(struct sk_buff *skb)
 {
-	mutex_lock(&rdma_nl_mutex);
 	rdma_nl_rcv_skb(skb, &rdma_nl_rcv_msg);
-	mutex_unlock(&rdma_nl_mutex);
 }
 
-int rdma_nl_unicast(struct sk_buff *skb, u32 pid)
+int rdma_nl_unicast(struct net *net, struct sk_buff *skb, u32 pid)
 {
+	struct rdma_dev_net *rnet = rdma_net_to_dev_net(net);
 	int err;
 
-	err = netlink_unicast(nls, skb, pid, MSG_DONTWAIT);
+	err = netlink_unicast(rnet->nl_sock, skb, pid, MSG_DONTWAIT);
 	return (err < 0) ? err : 0;
 }
 EXPORT_SYMBOL(rdma_nl_unicast);
 
-int rdma_nl_unicast_wait(struct sk_buff *skb, __u32 pid)
+int rdma_nl_unicast_wait(struct net *net, struct sk_buff *skb, __u32 pid)
 {
+	struct rdma_dev_net *rnet = rdma_net_to_dev_net(net);
 	int err;
 
-	err = netlink_unicast(nls, skb, pid, 0);
+	err = netlink_unicast(rnet->nl_sock, skb, pid, 0);
 	return (err < 0) ? err : 0;
 }
 EXPORT_SYMBOL(rdma_nl_unicast_wait);
 
-int rdma_nl_multicast(struct sk_buff *skb, unsigned int group, gfp_t flags)
+int rdma_nl_multicast(struct net *net, struct sk_buff *skb,
+		      unsigned int group, gfp_t flags)
 {
-	return nlmsg_multicast(nls, skb, 0, group, flags);
+	struct rdma_dev_net *rnet = rdma_net_to_dev_net(net);
+
+	return nlmsg_multicast(rnet->nl_sock, skb, 0, group, flags);
 }
 EXPORT_SYMBOL(rdma_nl_multicast);
 
-int __init rdma_nl_init(void)
+void rdma_nl_init(void)
 {
-	struct netlink_kernel_cfg cfg = {
-		.input	= rdma_nl_rcv,
-	};
+	int idx;
 
-	nls = netlink_kernel_create(&init_net, NETLINK_RDMA, &cfg);
-	if (!nls)
-		return -ENOMEM;
-
-	nls->sk_sndtimeo = 10 * HZ;
-	return 0;
+	for (idx = 0; idx < RDMA_NL_NUM_CLIENTS; idx++)
+		init_rwsem(&rdma_nl_types[idx].sem);
 }
 
 void rdma_nl_exit(void)
@@ -303,9 +301,31 @@
 	int idx;
 
 	for (idx = 0; idx < RDMA_NL_NUM_CLIENTS; idx++)
-		rdma_nl_unregister(idx);
+		WARN(rdma_nl_types[idx].cb_table,
+		     "Netlink client %d wasn't released prior to unloading %s\n",
+		     idx, KBUILD_MODNAME);
+}
 
-	netlink_kernel_release(nls);
+int rdma_nl_net_init(struct rdma_dev_net *rnet)
+{
+	struct net *net = read_pnet(&rnet->net);
+	struct netlink_kernel_cfg cfg = {
+		.input	= rdma_nl_rcv,
+	};
+	struct sock *nls;
+
+	nls = netlink_kernel_create(net, NETLINK_RDMA, &cfg);
+	if (!nls)
+		return -ENOMEM;
+
+	nls->sk_sndtimeo = 10 * HZ;
+	rnet->nl_sock = nls;
+	return 0;
+}
+
+void rdma_nl_net_exit(struct rdma_dev_net *rnet)
+{
+	netlink_kernel_release(rnet->nl_sock);
 }
 
 MODULE_ALIAS_NET_PF_PROTO(PF_NETLINK, NETLINK_RDMA);

--
Gitblit v1.6.2