From 1543e317f1da31b75942316931e8f491a8920811 Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Thu, 04 Jan 2024 10:08:02 +0000
Subject: [PATCH] RDMA/ucma: rework context lifetime and locking (IDR to XArray, ECE support)

---
 kernel/drivers/infiniband/core/ucma.c |  743 ++++++++++++++++++++++++++++----------------------------
 1 file changed, 372 insertions(+), 371 deletions(-)

diff --git a/kernel/drivers/infiniband/core/ucma.c b/kernel/drivers/infiniband/core/ucma.c
index 01052de..d12018c 100644
--- a/kernel/drivers/infiniband/core/ucma.c
+++ b/kernel/drivers/infiniband/core/ucma.c
@@ -52,6 +52,9 @@
 #include <rdma/rdma_cm_ib.h>
 #include <rdma/ib_addr.h>
 #include <rdma/ib.h>
+#include <rdma/ib_cm.h>
+#include <rdma/rdma_netlink.h>
+#include "core_priv.h"
 
 MODULE_AUTHOR("Sean Hefty");
 MODULE_DESCRIPTION("RDMA Userspace Connection Manager Access");
@@ -77,15 +80,14 @@
 	struct list_head	ctx_list;
 	struct list_head	event_list;
 	wait_queue_head_t	poll_wait;
-	struct workqueue_struct	*close_wq;
 };
 
 struct ucma_context {
-	int			id;
+	u32			id;
 	struct completion	comp;
-	atomic_t		ref;
+	refcount_t		ref;
 	int			events_reported;
-	int			backlog;
+	atomic_t		backlog;
 
 	struct ucma_file	*file;
 	struct rdma_cm_id	*cm_id;
@@ -94,18 +96,12 @@
 
 	struct list_head	list;
 	struct list_head	mc_list;
-	/* mark that device is in process of destroying the internal HW
-	 * resources, protected by the global mut
-	 */
-	int			closing;
-	/* sync between removal event and id destroy, protected by file mut */
-	int			destroying;
 	struct work_struct	close_work;
 };
 
 struct ucma_multicast {
 	struct ucma_context	*ctx;
-	int			id;
+	u32			id;
 	int			events_reported;
 
 	u64			uid;
@@ -116,28 +112,27 @@
 
 struct ucma_event {
 	struct ucma_context	*ctx;
+	struct ucma_context	*conn_req_ctx;
 	struct ucma_multicast	*mc;
 	struct list_head	list;
-	struct rdma_cm_id	*cm_id;
 	struct rdma_ucm_event_resp resp;
-	struct work_struct	close_work;
 };
 
-static DEFINE_MUTEX(mut);
-static DEFINE_IDR(ctx_idr);
-static DEFINE_IDR(multicast_idr);
+static DEFINE_XARRAY_ALLOC(ctx_table);
+static DEFINE_XARRAY_ALLOC(multicast_table);
 
 static const struct file_operations ucma_fops;
+static int ucma_destroy_private_ctx(struct ucma_context *ctx);
 
 static inline struct ucma_context *_ucma_find_context(int id,
 						      struct ucma_file *file)
 {
 	struct ucma_context *ctx;
 
-	ctx = idr_find(&ctx_idr, id);
+	ctx = xa_load(&ctx_table, id);
 	if (!ctx)
 		ctx = ERR_PTR(-ENOENT);
-	else if (ctx->file != file || !ctx->cm_id)
+	else if (ctx->file != file)
 		ctx = ERR_PTR(-EINVAL);
 	return ctx;
 }
@@ -146,21 +141,18 @@
 {
 	struct ucma_context *ctx;
 
-	mutex_lock(&mut);
+	xa_lock(&ctx_table);
 	ctx = _ucma_find_context(id, file);
-	if (!IS_ERR(ctx)) {
-		if (ctx->closing)
-			ctx = ERR_PTR(-EIO);
-		else
-			atomic_inc(&ctx->ref);
-	}
-	mutex_unlock(&mut);
+	if (!IS_ERR(ctx))
+		if (!refcount_inc_not_zero(&ctx->ref))
+			ctx = ERR_PTR(-ENXIO);
+	xa_unlock(&ctx_table);
 	return ctx;
 }
 
 static void ucma_put_ctx(struct ucma_context *ctx)
 {
-	if (atomic_dec_and_test(&ctx->ref))
+	if (refcount_dec_and_test(&ctx->ref))
 		complete(&ctx->comp);
 }
 
@@ -181,26 +173,21 @@
 	return ctx;
 }
 
-static void ucma_close_event_id(struct work_struct *work)
-{
-	struct ucma_event *uevent_close =  container_of(work, struct ucma_event, close_work);
-
-	rdma_destroy_id(uevent_close->cm_id);
-	kfree(uevent_close);
-}
-
 static void ucma_close_id(struct work_struct *work)
 {
 	struct ucma_context *ctx =  container_of(work, struct ucma_context, close_work);
 
 	/* once all inflight tasks are finished, we close all underlying
 	 * resources. The context is still alive till its explicit destryoing
-	 * by its creator.
+	 * by its creator. This puts back the xarray's reference.
 	 */
 	ucma_put_ctx(ctx);
 	wait_for_completion(&ctx->comp);
 	/* No new events will be generated after destroying the id. */
 	rdma_destroy_id(ctx->cm_id);
+
+	/* Reading the cm_id without holding a positive ref is not allowed */
+	ctx->cm_id = NULL;
 }
 
 static struct ucma_context *ucma_alloc_ctx(struct ucma_file *file)
@@ -212,47 +199,32 @@
 		return NULL;
 
 	INIT_WORK(&ctx->close_work, ucma_close_id);
-	atomic_set(&ctx->ref, 1);
 	init_completion(&ctx->comp);
 	INIT_LIST_HEAD(&ctx->mc_list);
+	/* So list_del() will work if we don't do ucma_finish_ctx() */
+	INIT_LIST_HEAD(&ctx->list);
 	ctx->file = file;
 	mutex_init(&ctx->mutex);
 
-	mutex_lock(&mut);
-	ctx->id = idr_alloc(&ctx_idr, ctx, 0, 0, GFP_KERNEL);
-	mutex_unlock(&mut);
-	if (ctx->id < 0)
-		goto error;
-
-	list_add_tail(&ctx->list, &file->ctx_list);
+	if (xa_alloc(&ctx_table, &ctx->id, NULL, xa_limit_32b, GFP_KERNEL)) {
+		kfree(ctx);
+		return NULL;
+	}
 	return ctx;
-
-error:
-	kfree(ctx);
-	return NULL;
 }
 
-static struct ucma_multicast* ucma_alloc_multicast(struct ucma_context *ctx)
+static void ucma_set_ctx_cm_id(struct ucma_context *ctx,
+			       struct rdma_cm_id *cm_id)
 {
-	struct ucma_multicast *mc;
+	refcount_set(&ctx->ref, 1);
+	ctx->cm_id = cm_id;
+}
 
-	mc = kzalloc(sizeof(*mc), GFP_KERNEL);
-	if (!mc)
-		return NULL;
-
-	mutex_lock(&mut);
-	mc->id = idr_alloc(&multicast_idr, NULL, 0, 0, GFP_KERNEL);
-	mutex_unlock(&mut);
-	if (mc->id < 0)
-		goto error;
-
-	mc->ctx = ctx;
-	list_add_tail(&mc->list, &ctx->mc_list);
-	return mc;
-
-error:
-	kfree(mc);
-	return NULL;
+static void ucma_finish_ctx(struct ucma_context *ctx)
+{
+	lockdep_assert_held(&ctx->file->mut);
+	list_add_tail(&ctx->list, &ctx->file->ctx_list);
+	xa_store(&ctx_table, ctx->id, ctx, GFP_KERNEL);
 }
 
 static void ucma_copy_conn_event(struct rdma_ucm_conn_param *dst,
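
The ucma_alloc_ctx()/ucma_finish_ctx() split above is a reserve-then-publish idiom: xa_alloc() stores a NULL entry, so the ID is reserved but every lookup still returns NULL, and only the publish step after full initialization (under file->mut in this patch) stores the real pointer. A minimal sketch of the idiom, using placeholder names (example_table, struct example) rather than anything from this file:

    #include <linux/xarray.h>
    #include <linux/slab.h>
    #include <linux/types.h>

    static DEFINE_XARRAY_ALLOC(example_table);

    struct example {
        u32 id;
        /* ... fields that must be initialized before the object is visible ... */
    };

    static struct example *example_alloc(void)
    {
        struct example *e = kzalloc(sizeof(*e), GFP_KERNEL);

        if (!e)
            return NULL;
        /* Reserve an ID; the NULL entry keeps xa_load() returning NULL. */
        if (xa_alloc(&example_table, &e->id, NULL, xa_limit_32b, GFP_KERNEL)) {
            kfree(e);
            return NULL;
        }
        return e;
    }

    /* Called once initialization is complete; makes the object visible. */
    static void example_publish(struct example *e)
    {
        xa_store(&example_table, e->id, e, GFP_KERNEL);
    }

Error paths taken before the publish step simply drop the reservation again, which in this patch is handled by ucma_destroy_private_ctx().
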
@@ -284,10 +256,15 @@
 	dst->qkey = src->qkey;
 }
 
-static void ucma_set_event_context(struct ucma_context *ctx,
-				   struct rdma_cm_event *event,
-				   struct ucma_event *uevent)
+static struct ucma_event *ucma_create_uevent(struct ucma_context *ctx,
+					     struct rdma_cm_event *event)
 {
+	struct ucma_event *uevent;
+
+	uevent = kzalloc(sizeof(*uevent), GFP_KERNEL);
+	if (!uevent)
+		return NULL;
+
 	uevent->ctx = ctx;
 	switch (event->event) {
 	case RDMA_CM_EVENT_MULTICAST_JOIN:
@@ -302,44 +279,55 @@
 		uevent->resp.id = ctx->id;
 		break;
 	}
+	uevent->resp.event = event->event;
+	uevent->resp.status = event->status;
+	if (ctx->cm_id->qp_type == IB_QPT_UD)
+		ucma_copy_ud_event(ctx->cm_id->device, &uevent->resp.param.ud,
+				   &event->param.ud);
+	else
+		ucma_copy_conn_event(&uevent->resp.param.conn,
+				     &event->param.conn);
+
+	uevent->resp.ece.vendor_id = event->ece.vendor_id;
+	uevent->resp.ece.attr_mod = event->ece.attr_mod;
+	return uevent;
 }
 
-/* Called with file->mut locked for the relevant context. */
-static void ucma_removal_event_handler(struct rdma_cm_id *cm_id)
+static int ucma_connect_event_handler(struct rdma_cm_id *cm_id,
+				      struct rdma_cm_event *event)
 {
-	struct ucma_context *ctx = cm_id->context;
-	struct ucma_event *con_req_eve;
-	int event_found = 0;
+	struct ucma_context *listen_ctx = cm_id->context;
+	struct ucma_context *ctx;
+	struct ucma_event *uevent;
 
-	if (ctx->destroying)
-		return;
+	if (!atomic_add_unless(&listen_ctx->backlog, -1, 0))
+		return -ENOMEM;
+	ctx = ucma_alloc_ctx(listen_ctx->file);
+	if (!ctx)
+		goto err_backlog;
+	ucma_set_ctx_cm_id(ctx, cm_id);
 
-	/* only if context is pointing to cm_id that it owns it and can be
-	 * queued to be closed, otherwise that cm_id is an inflight one that
-	 * is part of that context event list pending to be detached and
-	 * reattached to its new context as part of ucma_get_event,
-	 * handled separately below.
-	 */
-	if (ctx->cm_id == cm_id) {
-		mutex_lock(&mut);
-		ctx->closing = 1;
-		mutex_unlock(&mut);
-		queue_work(ctx->file->close_wq, &ctx->close_work);
-		return;
-	}
+	uevent = ucma_create_uevent(listen_ctx, event);
+	if (!uevent)
+		goto err_alloc;
+	uevent->conn_req_ctx = ctx;
+	uevent->resp.id = ctx->id;
 
-	list_for_each_entry(con_req_eve, &ctx->file->event_list, list) {
-		if (con_req_eve->cm_id == cm_id &&
-		    con_req_eve->resp.event == RDMA_CM_EVENT_CONNECT_REQUEST) {
-			list_del(&con_req_eve->list);
-			INIT_WORK(&con_req_eve->close_work, ucma_close_event_id);
-			queue_work(ctx->file->close_wq, &con_req_eve->close_work);
-			event_found = 1;
-			break;
-		}
-	}
-	if (!event_found)
-		pr_err("ucma_removal_event_handler: warning: connect request event wasn't found\n");
+	ctx->cm_id->context = ctx;
+
+	mutex_lock(&ctx->file->mut);
+	ucma_finish_ctx(ctx);
+	list_add_tail(&uevent->list, &ctx->file->event_list);
+	mutex_unlock(&ctx->file->mut);
+	wake_up_interruptible(&ctx->file->poll_wait);
+	return 0;
+
+err_alloc:
+	ucma_destroy_private_ctx(ctx);
+err_backlog:
+	atomic_inc(&listen_ctx->backlog);
+	/* Returning error causes the new ID to be destroyed */
+	return -ENOMEM;
 }
 
 static int ucma_event_handler(struct rdma_cm_id *cm_id,
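
ucma_connect_event_handler() above consumes a listen backlog slot without taking file->mut; converting backlog to an atomic_t and guarding the decrement with atomic_add_unless() keeps the counter from ever going below zero when many connect requests race. The guard in isolation, with illustrative names:

    #include <linux/atomic.h>
    #include <linux/errno.h>

    static atomic_t listen_slots = ATOMIC_INIT(16);

    /* Try to consume one backlog slot; refuse instead of going negative. */
    static int take_slot(void)
    {
        if (!atomic_add_unless(&listen_slots, -1, 0))
            return -ENOMEM; /* backlog exhausted, drop this request */
        return 0;
    }

    /* Return the slot once the request has been delivered to userspace. */
    static void put_slot(void)
    {
        atomic_inc(&listen_slots);
    }

In the patch the slot is only handed back in ucma_get_event(), i.e. once userspace has actually consumed the CONNECT_REQUEST event, or on the handler's error path.
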
@@ -347,69 +335,49 @@
 {
 	struct ucma_event *uevent;
 	struct ucma_context *ctx = cm_id->context;
-	int ret = 0;
 
-	uevent = kzalloc(sizeof(*uevent), GFP_KERNEL);
-	if (!uevent)
-		return event->event == RDMA_CM_EVENT_CONNECT_REQUEST;
+	if (event->event == RDMA_CM_EVENT_CONNECT_REQUEST)
+		return ucma_connect_event_handler(cm_id, event);
 
-	mutex_lock(&ctx->file->mut);
-	uevent->cm_id = cm_id;
-	ucma_set_event_context(ctx, event, uevent);
-	uevent->resp.event = event->event;
-	uevent->resp.status = event->status;
-	if (cm_id->qp_type == IB_QPT_UD)
-		ucma_copy_ud_event(cm_id->device, &uevent->resp.param.ud,
-				   &event->param.ud);
-	else
-		ucma_copy_conn_event(&uevent->resp.param.conn,
-				     &event->param.conn);
+	/*
+	 * We ignore events for new connections until userspace has set their
+	 * context.  This can only happen if an error occurs on a new connection
+	 * before the user accepts it.  This is okay, since the accept will just
+	 * fail later. However, we do need to release the underlying HW
+	 * resources in case of a device removal event.
+	 */
+	if (ctx->uid) {
+		uevent = ucma_create_uevent(ctx, event);
+		if (!uevent)
+			return 0;
 
-	if (event->event == RDMA_CM_EVENT_CONNECT_REQUEST) {
-		if (!ctx->backlog) {
-			ret = -ENOMEM;
-			kfree(uevent);
-			goto out;
-		}
-		ctx->backlog--;
-	} else if (!ctx->uid || ctx->cm_id != cm_id) {
-		/*
-		 * We ignore events for new connections until userspace has set
-		 * their context.  This can only happen if an error occurs on a
-		 * new connection before the user accepts it.  This is okay,
-		 * since the accept will just fail later. However, we do need
-		 * to release the underlying HW resources in case of a device
-		 * removal event.
-		 */
-		if (event->event == RDMA_CM_EVENT_DEVICE_REMOVAL)
-			ucma_removal_event_handler(cm_id);
-
-		kfree(uevent);
-		goto out;
+		mutex_lock(&ctx->file->mut);
+		list_add_tail(&uevent->list, &ctx->file->event_list);
+		mutex_unlock(&ctx->file->mut);
+		wake_up_interruptible(&ctx->file->poll_wait);
 	}
 
-	list_add_tail(&uevent->list, &ctx->file->event_list);
-	wake_up_interruptible(&ctx->file->poll_wait);
-	if (event->event == RDMA_CM_EVENT_DEVICE_REMOVAL)
-		ucma_removal_event_handler(cm_id);
-out:
-	mutex_unlock(&ctx->file->mut);
-	return ret;
+	if (event->event == RDMA_CM_EVENT_DEVICE_REMOVAL) {
+		xa_lock(&ctx_table);
+		if (xa_load(&ctx_table, ctx->id) == ctx)
+			queue_work(system_unbound_wq, &ctx->close_work);
+		xa_unlock(&ctx_table);
+	}
+	return 0;
 }
 
 static ssize_t ucma_get_event(struct ucma_file *file, const char __user *inbuf,
 			      int in_len, int out_len)
 {
-	struct ucma_context *ctx;
 	struct rdma_ucm_get_event cmd;
 	struct ucma_event *uevent;
-	int ret = 0;
 
 	/*
 	 * Old 32 bit user space does not send the 4 byte padding in the
 	 * reserved field. We don't care, allow it to keep working.
 	 */
-	if (out_len < sizeof(uevent->resp) - sizeof(uevent->resp.reserved))
+	if (out_len < sizeof(uevent->resp) - sizeof(uevent->resp.reserved) -
+			      sizeof(uevent->resp.ece))
 		return -ENOSPC;
 
 	if (copy_from_user(&cmd, inbuf, sizeof(cmd)))
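
For RDMA_CM_EVENT_DEVICE_REMOVAL, ucma_event_handler() above no longer tears the ID down inline: it queues ctx->close_work, and only while the context is still published in ctx_table, because a context that has already been hidden is being destroyed by its owner and must not be queued a second time. A hedged sketch of that check-then-queue step (schedule_close() is an invented helper, not part of the patch):

    #include <linux/workqueue.h>
    #include <linux/xarray.h>

    /* Queue deferred teardown only while the object is still visible. */
    static void schedule_close(struct xarray *table, unsigned long id,
                               void *obj, struct work_struct *close_work)
    {
        xa_lock(table);
        if (xa_load(table, id) == obj)  /* still published, owner has not claimed it */
            queue_work(system_unbound_wq, close_work);
        xa_unlock(table);
    }

The work function, like ucma_close_id() above, then drops the reference the table logically holds, waits for the completion, and finally calls rdma_destroy_id().
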
@@ -429,35 +397,25 @@
 		mutex_lock(&file->mut);
 	}
 
-	uevent = list_entry(file->event_list.next, struct ucma_event, list);
-
-	if (uevent->resp.event == RDMA_CM_EVENT_CONNECT_REQUEST) {
-		ctx = ucma_alloc_ctx(file);
-		if (!ctx) {
-			ret = -ENOMEM;
-			goto done;
-		}
-		uevent->ctx->backlog++;
-		ctx->cm_id = uevent->cm_id;
-		ctx->cm_id->context = ctx;
-		uevent->resp.id = ctx->id;
-	}
+	uevent = list_first_entry(&file->event_list, struct ucma_event, list);
 
 	if (copy_to_user(u64_to_user_ptr(cmd.response),
 			 &uevent->resp,
 			 min_t(size_t, out_len, sizeof(uevent->resp)))) {
-		ret = -EFAULT;
-		goto done;
+		mutex_unlock(&file->mut);
+		return -EFAULT;
 	}
 
 	list_del(&uevent->list);
 	uevent->ctx->events_reported++;
 	if (uevent->mc)
 		uevent->mc->events_reported++;
-	kfree(uevent);
-done:
+	if (uevent->resp.event == RDMA_CM_EVENT_CONNECT_REQUEST)
+		atomic_inc(&uevent->ctx->backlog);
 	mutex_unlock(&file->mut);
-	return ret;
+
+	kfree(uevent);
+	return 0;
 }
 
 static int ucma_get_qp_type(struct rdma_ucm_create_id *cmd, enum ib_qp_type *qp_type)
@@ -498,40 +456,32 @@
 	if (ret)
 		return ret;
 
-	mutex_lock(&file->mut);
 	ctx = ucma_alloc_ctx(file);
-	mutex_unlock(&file->mut);
 	if (!ctx)
 		return -ENOMEM;
 
 	ctx->uid = cmd.uid;
-	cm_id = __rdma_create_id(current->nsproxy->net_ns,
-				 ucma_event_handler, ctx, cmd.ps, qp_type, NULL);
+	cm_id = rdma_create_user_id(ucma_event_handler, ctx, cmd.ps, qp_type);
 	if (IS_ERR(cm_id)) {
 		ret = PTR_ERR(cm_id);
 		goto err1;
 	}
+	ucma_set_ctx_cm_id(ctx, cm_id);
 
 	resp.id = ctx->id;
 	if (copy_to_user(u64_to_user_ptr(cmd.response),
 			 &resp, sizeof(resp))) {
-		ret = -EFAULT;
-		goto err2;
+		ucma_destroy_private_ctx(ctx);
+		return -EFAULT;
 	}
 
-	ctx->cm_id = cm_id;
+	mutex_lock(&file->mut);
+	ucma_finish_ctx(ctx);
+	mutex_unlock(&file->mut);
 	return 0;
 
-err2:
-	rdma_destroy_id(cm_id);
 err1:
-	mutex_lock(&mut);
-	idr_remove(&ctx_idr, ctx->id);
-	mutex_unlock(&mut);
-	mutex_lock(&file->mut);
-	list_del(&ctx->list);
-	mutex_unlock(&file->mut);
-	kfree(ctx);
+	ucma_destroy_private_ctx(ctx);
 	return ret;
 }
 
@@ -539,19 +489,25 @@
 {
 	struct ucma_multicast *mc, *tmp;
 
-	mutex_lock(&mut);
+	xa_lock(&multicast_table);
 	list_for_each_entry_safe(mc, tmp, &ctx->mc_list, list) {
 		list_del(&mc->list);
-		idr_remove(&multicast_idr, mc->id);
+		/*
+		 * At this point mc->ctx->ref is 0, so no reader can look the mc
+		 * up any more; holding the xarray lock here is enough
+		 * serialization for the erase and free.
+		 */
+		__xa_erase(&multicast_table, mc->id);
 		kfree(mc);
 	}
-	mutex_unlock(&mut);
+	xa_unlock(&multicast_table);
 }
 
 static void ucma_cleanup_mc_events(struct ucma_multicast *mc)
 {
 	struct ucma_event *uevent, *tmp;
 
+	rdma_lock_handler(mc->ctx->cm_id);
+	mutex_lock(&mc->ctx->file->mut);
 	list_for_each_entry_safe(uevent, tmp, &mc->ctx->file->event_list, list) {
 		if (uevent->mc != mc)
 			continue;
@@ -559,45 +515,74 @@
 		list_del(&uevent->list);
 		kfree(uevent);
 	}
+	mutex_unlock(&mc->ctx->file->mut);
+	rdma_unlock_handler(mc->ctx->cm_id);
 }
 
-/*
- * ucma_free_ctx is called after the underlying rdma CM-ID is destroyed. At
- * this point, no new events will be reported from the hardware. However, we
- * still need to cleanup the UCMA context for this ID. Specifically, there
- * might be events that have not yet been consumed by the user space software.
- * These might include pending connect requests which we have not completed
- * processing.  We cannot call rdma_destroy_id while holding the lock of the
- * context (file->mut), as it might cause a deadlock. We therefore extract all
- * relevant events from the context pending events list while holding the
- * mutex. After that we release them as needed.
- */
-static int ucma_free_ctx(struct ucma_context *ctx)
+static int ucma_cleanup_ctx_events(struct ucma_context *ctx)
 {
 	int events_reported;
 	struct ucma_event *uevent, *tmp;
 	LIST_HEAD(list);
 
-
-	ucma_cleanup_multicast(ctx);
-
-	/* Cleanup events not yet reported to the user. */
+	/* Cleanup events not yet reported to the user. */
 	mutex_lock(&ctx->file->mut);
 	list_for_each_entry_safe(uevent, tmp, &ctx->file->event_list, list) {
-		if (uevent->ctx == ctx)
+		if (uevent->ctx != ctx)
+			continue;
+
+		if (uevent->resp.event == RDMA_CM_EVENT_CONNECT_REQUEST &&
+		    xa_cmpxchg(&ctx_table, uevent->conn_req_ctx->id,
+			       uevent->conn_req_ctx, XA_ZERO_ENTRY,
+			       GFP_KERNEL) == uevent->conn_req_ctx) {
 			list_move_tail(&uevent->list, &list);
+			continue;
+		}
+		list_del(&uevent->list);
+		kfree(uevent);
 	}
 	list_del(&ctx->list);
 	events_reported = ctx->events_reported;
 	mutex_unlock(&ctx->file->mut);
 
+	/*
+	 * If this was a listening ID then any connections spawned from it that
+	 * have not been delivered to userspace are cleaned up too. Must be done
+	 * outside any locks.
+	 */
 	list_for_each_entry_safe(uevent, tmp, &list, list) {
-		list_del(&uevent->list);
-		if (uevent->resp.event == RDMA_CM_EVENT_CONNECT_REQUEST)
-			rdma_destroy_id(uevent->cm_id);
+		ucma_destroy_private_ctx(uevent->conn_req_ctx);
 		kfree(uevent);
 	}
+	return events_reported;
+}
 
+/*
+ * When this is called the xarray must have a XA_ZERO_ENTRY in the ctx->id
+ * (i.e. the ctx is not public to the user). This is either because:
+ *  - ucma_finish_ctx() hasn't been called yet
+ *  - xa_cmpxchg() succeeded in removing the entry (only one thread can succeed)
+ */
+static int ucma_destroy_private_ctx(struct ucma_context *ctx)
+{
+	int events_reported;
+
+	/*
+	 * Destroy the underlying cm_id. New work queuing is prevented now by
+	 * the removal from the xarray. Once the work is cancelled, ref will either
+	 * be 0 because the work ran to completion and consumed the ref from the
+	 * xarray, or it will be positive because we still have the ref from the
+	 * xarray. This can also be 0 in cases where cm_id was never set.
+	 */
+	cancel_work_sync(&ctx->close_work);
+	if (refcount_read(&ctx->ref))
+		ucma_close_id(&ctx->close_work);
+
+	events_reported = ucma_cleanup_ctx_events(ctx);
+	ucma_cleanup_multicast(ctx);
+
+	WARN_ON(xa_cmpxchg(&ctx_table, ctx->id, XA_ZERO_ENTRY, NULL,
+			   GFP_KERNEL) != NULL);
 	mutex_destroy(&ctx->mutex);
 	kfree(ctx);
 	return events_reported;
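
ucma_destroy_private_ctx() depends on the caller having already swapped the context's table entry for XA_ZERO_ENTRY, which keeps the index reserved while making every lookup return NULL. The xa_cmpxchg() is what makes destruction single-owner: exactly one thread wins the exchange from the live pointer to the zero entry. A sketch of that claim/release pair (claim_for_destroy()/release_id() are illustrative names, not part of the patch):

    #include <linux/xarray.h>
    #include <linux/types.h>
    #include <linux/bug.h>

    /*
     * Atomically hide the object from lookups while keeping its index
     * reserved. Returns true for exactly one caller; everyone else loses
     * the race and must not free the object.
     */
    static bool claim_for_destroy(struct xarray *table, unsigned long id, void *obj)
    {
        return xa_cmpxchg(table, id, obj, XA_ZERO_ENTRY, GFP_KERNEL) == obj;
    }

    /* Final teardown: give the reserved index back to the allocator. */
    static void release_id(struct xarray *table, unsigned long id)
    {
        WARN_ON(xa_cmpxchg(table, id, XA_ZERO_ENTRY, NULL, GFP_KERNEL) != NULL);
    }
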
@@ -617,33 +602,19 @@
 	if (copy_from_user(&cmd, inbuf, sizeof(cmd)))
 		return -EFAULT;
 
-	mutex_lock(&mut);
+	xa_lock(&ctx_table);
 	ctx = _ucma_find_context(cmd.id, file);
-	if (!IS_ERR(ctx))
-		idr_remove(&ctx_idr, ctx->id);
-	mutex_unlock(&mut);
+	if (!IS_ERR(ctx)) {
+		if (__xa_cmpxchg(&ctx_table, ctx->id, ctx, XA_ZERO_ENTRY,
+				 GFP_KERNEL) != ctx)
+			ctx = ERR_PTR(-ENOENT);
+	}
+	xa_unlock(&ctx_table);
 
 	if (IS_ERR(ctx))
 		return PTR_ERR(ctx);
 
-	mutex_lock(&ctx->file->mut);
-	ctx->destroying = 1;
-	mutex_unlock(&ctx->file->mut);
-
-	flush_workqueue(ctx->file->close_wq);
-	/* At this point it's guaranteed that there is no inflight
-	 * closing task */
-	mutex_lock(&mut);
-	if (!ctx->closing) {
-		mutex_unlock(&mut);
-		ucma_put_ctx(ctx);
-		wait_for_completion(&ctx->comp);
-		rdma_destroy_id(ctx->cm_id);
-	} else {
-		mutex_unlock(&mut);
-	}
-
-	resp.events_reported = ucma_free_ctx(ctx);
+	resp.events_reported = ucma_destroy_private_ctx(ctx);
 	if (copy_to_user(u64_to_user_ptr(cmd.response),
 			 &resp, sizeof(resp)))
 		ret = -EFAULT;
@@ -796,7 +767,7 @@
 	case 2:
 		ib_copy_path_rec_to_user(&resp->ib_route[1],
 					 &route->path_rec[1]);
-		/* fall through */
+		fallthrough;
 	case 1:
 		ib_copy_path_rec_to_user(&resp->ib_route[0],
 					 &route->path_rec[0]);
@@ -822,7 +793,7 @@
 	case 2:
 		ib_copy_path_rec_to_user(&resp->ib_route[1],
 					 &route->path_rec[1]);
-		/* fall through */
+		fallthrough;
 	case 1:
 		ib_copy_path_rec_to_user(&resp->ib_route[0],
 					 &route->path_rec[0]);
@@ -852,7 +823,7 @@
 	struct sockaddr *addr;
 	int ret = 0;
 
-	if (out_len < sizeof(resp))
+	if (out_len < offsetof(struct rdma_ucm_query_route_resp, ibdev_index))
 		return -ENOSPC;
 
 	if (copy_from_user(&cmd, inbuf, sizeof(cmd)))
@@ -876,6 +847,7 @@
 		goto out;
 
 	resp.node_guid = (__force __u64) ctx->cm_id->device->node_guid;
+	resp.ibdev_index = ctx->cm_id->device->index;
 	resp.port_num = ctx->cm_id->port_num;
 
 	if (rdma_cap_ib_sa(ctx->cm_id->device, ctx->cm_id->port_num))
@@ -887,8 +859,8 @@
 
 out:
 	mutex_unlock(&ctx->mutex);
-	if (copy_to_user(u64_to_user_ptr(cmd.response),
-			 &resp, sizeof(resp)))
+	if (copy_to_user(u64_to_user_ptr(cmd.response), &resp,
+			 min_t(size_t, out_len, sizeof(resp))))
 		ret = -EFAULT;
 
 	ucma_put_ctx(ctx);
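
The query responses above gain an ibdev_index field without breaking old userspace: the minimum acceptable out_len drops from sizeof(resp) to the offset of the first new member, and the copy back is truncated with min_t() to however much the caller provided. The same pattern in isolation, hedged as a sketch with made-up struct and field names:

    #include <linux/types.h>
    #include <linux/stddef.h>
    #include <linux/kernel.h>
    #include <linux/errno.h>
    #include <linux/uaccess.h>

    struct new_resp {
        __u64 old_field;    /* present in every ABI version */
        __u32 new_field;    /* appended later; old userspace never reads it */
        __u32 reserved;
    };

    static ssize_t fill_resp(void __user *response, int out_len)
    {
        struct new_resp resp = {};

        /* Old userspace may pass a buffer that predates new_field. */
        if (out_len < offsetof(struct new_resp, new_field))
            return -ENOSPC;

        resp.old_field = 1;
        resp.new_field = 2;

        /* Never copy more than the caller said it can hold. */
        if (copy_to_user(response, &resp, min_t(size_t, out_len, sizeof(resp))))
            return -EFAULT;
        return 0;
    }
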
@@ -902,6 +874,7 @@
 		return;
 
 	resp->node_guid = (__force __u64) cm_id->device->node_guid;
+	resp->ibdev_index = cm_id->device->index;
 	resp->port_num = cm_id->port_num;
 	resp->pkey = (__force __u16) cpu_to_be16(
 		     ib_addr_get_pkey(&cm_id->route.addr.dev_addr));
@@ -914,7 +887,7 @@
 	struct sockaddr *addr;
 	int ret = 0;
 
-	if (out_len < sizeof(resp))
+	if (out_len < offsetof(struct rdma_ucm_query_addr_resp, ibdev_index))
 		return -ENOSPC;
 
 	memset(&resp, 0, sizeof resp);
@@ -929,7 +902,7 @@
 
 	ucma_query_device_addr(ctx->cm_id, &resp);
 
-	if (copy_to_user(response, &resp, sizeof(resp)))
+	if (copy_to_user(response, &resp, min_t(size_t, out_len, sizeof(resp))))
 		ret = -EFAULT;
 
 	return ret;
@@ -967,8 +940,7 @@
 		}
 	}
 
-	if (copy_to_user(response, resp,
-			 sizeof(*resp) + (i * sizeof(struct ib_path_rec_data))))
+	if (copy_to_user(response, resp, struct_size(resp, path_data, i)))
 		ret = -EFAULT;
 
 	kfree(resp);
@@ -982,7 +954,7 @@
 	struct sockaddr_ib *addr;
 	int ret = 0;
 
-	if (out_len < sizeof(resp))
+	if (out_len < offsetof(struct rdma_ucm_query_addr_resp, ibdev_index))
 		return -ENOSPC;
 
 	memset(&resp, 0, sizeof resp);
@@ -1015,7 +987,7 @@
 						    &ctx->cm_id->route.addr.dst_addr);
 	}
 
-	if (copy_to_user(response, &resp, sizeof(resp)))
+	if (copy_to_user(response, &resp, min_t(size_t, out_len, sizeof(resp))))
 		ret = -EFAULT;
 
 	return ret;
@@ -1071,19 +1043,24 @@
 	dst->retry_count = src->retry_count;
 	dst->rnr_retry_count = src->rnr_retry_count;
 	dst->srq = src->srq;
-	dst->qp_num = src->qp_num;
+	dst->qp_num = src->qp_num & 0xFFFFFF;
 	dst->qkey = (id->route.addr.src_addr.ss_family == AF_IB) ? src->qkey : 0;
 }
 
 static ssize_t ucma_connect(struct ucma_file *file, const char __user *inbuf,
 			    int in_len, int out_len)
 {
-	struct rdma_ucm_connect cmd;
 	struct rdma_conn_param conn_param;
+	struct rdma_ucm_ece ece = {};
+	struct rdma_ucm_connect cmd;
 	struct ucma_context *ctx;
+	size_t in_size;
 	int ret;
 
-	if (copy_from_user(&cmd, inbuf, sizeof(cmd)))
+	if (in_len < offsetofend(typeof(cmd), reserved))
+		return -EINVAL;
+	in_size = min_t(size_t, in_len, sizeof(cmd));
+	if (copy_from_user(&cmd, inbuf, in_size))
 		return -EFAULT;
 
 	if (!cmd.conn_param.valid)
@@ -1094,8 +1071,13 @@
 		return PTR_ERR(ctx);
 
 	ucma_copy_conn_param(ctx->cm_id, &conn_param, &cmd.conn_param);
+	if (offsetofend(typeof(cmd), ece) <= in_size) {
+		ece.vendor_id = cmd.ece.vendor_id;
+		ece.attr_mod = cmd.ece.attr_mod;
+	}
+
 	mutex_lock(&ctx->mutex);
-	ret = rdma_connect(ctx->cm_id, &conn_param);
+	ret = rdma_connect_ece(ctx->cm_id, &conn_param, &ece);
 	mutex_unlock(&ctx->mutex);
 	ucma_put_ctx(ctx);
 	return ret;
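
ucma_connect() above now accepts two command layouts: the legacy one that ends at 'reserved' and the extended one carrying the ECE fields. Instead of insisting on sizeof(cmd), it checks in_len against offsetofend() of the last mandatory member, copies only min(in_len, sizeof(cmd)) bytes, and leaves the extension zero-initialized when the short form was sent. The same input-side pattern as a sketch (struct new_cmd and its members are placeholders):

    #include <linux/types.h>
    #include <linux/stddef.h>
    #include <linux/string.h>
    #include <linux/kernel.h>
    #include <linux/errno.h>
    #include <linux/uaccess.h>

    struct new_cmd {
        __u64 response;
        __u32 id;
        __u32 reserved;     /* end of the legacy layout */
        __u32 ext_a;        /* extension, optional on the wire */
        __u32 ext_b;
    };

    static int parse_cmd(const void __user *inbuf, size_t in_len, struct new_cmd *cmd)
    {
        size_t in_size;

        /* The legacy part of the command is mandatory. */
        if (in_len < offsetofend(struct new_cmd, reserved))
            return -EINVAL;

        /* Zero-fill so extension fields the caller did not send read as 0. */
        memset(cmd, 0, sizeof(*cmd));
        in_size = min_t(size_t, in_len, sizeof(*cmd));
        if (copy_from_user(cmd, inbuf, in_size))
            return -EFAULT;

        /* Only trust the extension if the caller actually sent all of it. */
        if (offsetofend(struct new_cmd, ext_b) > in_size) {
            cmd->ext_a = 0;
            cmd->ext_b = 0;
        }
        return 0;
    }
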
@@ -1115,10 +1097,12 @@
 	if (IS_ERR(ctx))
 		return PTR_ERR(ctx);
 
-	ctx->backlog = cmd.backlog > 0 && cmd.backlog < max_backlog ?
-		       cmd.backlog : max_backlog;
+	if (cmd.backlog <= 0 || cmd.backlog > max_backlog)
+		cmd.backlog = max_backlog;
+	atomic_set(&ctx->backlog, cmd.backlog);
+
 	mutex_lock(&ctx->mutex);
-	ret = rdma_listen(ctx->cm_id, ctx->backlog);
+	ret = rdma_listen(ctx->cm_id, cmd.backlog);
 	mutex_unlock(&ctx->mutex);
 	ucma_put_ctx(ctx);
 	return ret;
@@ -1129,28 +1113,42 @@
 {
 	struct rdma_ucm_accept cmd;
 	struct rdma_conn_param conn_param;
+	struct rdma_ucm_ece ece = {};
 	struct ucma_context *ctx;
+	size_t in_size;
 	int ret;
 
-	if (copy_from_user(&cmd, inbuf, sizeof(cmd)))
+	if (in_len < offsetofend(typeof(cmd), reserved))
+		return -EINVAL;
+	in_size = min_t(size_t, in_len, sizeof(cmd));
+	if (copy_from_user(&cmd, inbuf, in_size))
 		return -EFAULT;
 
 	ctx = ucma_get_ctx_dev(file, cmd.id);
 	if (IS_ERR(ctx))
 		return PTR_ERR(ctx);
 
+	if (offsetofend(typeof(cmd), ece) <= in_size) {
+		ece.vendor_id = cmd.ece.vendor_id;
+		ece.attr_mod = cmd.ece.attr_mod;
+	}
+
 	if (cmd.conn_param.valid) {
 		ucma_copy_conn_param(ctx->cm_id, &conn_param, &cmd.conn_param);
-		mutex_lock(&file->mut);
 		mutex_lock(&ctx->mutex);
-		ret = __rdma_accept(ctx->cm_id, &conn_param, NULL);
-		mutex_unlock(&ctx->mutex);
-		if (!ret)
+		rdma_lock_handler(ctx->cm_id);
+		ret = rdma_accept_ece(ctx->cm_id, &conn_param, &ece);
+		if (!ret) {
+			/* The uid must be set atomically with the handler */
 			ctx->uid = cmd.uid;
-		mutex_unlock(&file->mut);
+		}
+		rdma_unlock_handler(ctx->cm_id);
+		mutex_unlock(&ctx->mutex);
 	} else {
 		mutex_lock(&ctx->mutex);
-		ret = __rdma_accept(ctx->cm_id, NULL, NULL);
+		rdma_lock_handler(ctx->cm_id);
+		ret = rdma_accept_ece(ctx->cm_id, NULL, &ece);
+		rdma_unlock_handler(ctx->cm_id);
 		mutex_unlock(&ctx->mutex);
 	}
 	ucma_put_ctx(ctx);
@@ -1167,12 +1165,24 @@
 	if (copy_from_user(&cmd, inbuf, sizeof(cmd)))
 		return -EFAULT;
 
+	if (!cmd.reason)
+		cmd.reason = IB_CM_REJ_CONSUMER_DEFINED;
+
+	switch (cmd.reason) {
+	case IB_CM_REJ_CONSUMER_DEFINED:
+	case IB_CM_REJ_VENDOR_OPTION_NOT_SUPPORTED:
+		break;
+	default:
+		return -EINVAL;
+	}
+
 	ctx = ucma_get_ctx_dev(file, cmd.id);
 	if (IS_ERR(ctx))
 		return PTR_ERR(ctx);
 
 	mutex_lock(&ctx->mutex);
-	ret = rdma_reject(ctx->cm_id, cmd.private_data, cmd.private_data_len);
+	ret = rdma_reject(ctx->cm_id, cmd.private_data, cmd.private_data_len,
+			  cmd.reason);
 	mutex_unlock(&ctx->mutex);
 	ucma_put_ctx(ctx);
 	return ret;
@@ -1267,6 +1277,13 @@
 			break;
 		}
 		ret = rdma_set_afonly(ctx->cm_id, *((int *) optval) ? 1 : 0);
+		break;
+	case RDMA_OPTION_ID_ACK_TIMEOUT:
+		if (optlen != sizeof(u8)) {
+			ret = -EINVAL;
+			break;
+		}
+		ret = rdma_set_ack_timeout(ctx->cm_id, *((u8 *)optval));
 		break;
 	default:
 		ret = -ENOSYS;
@@ -1444,50 +1461,59 @@
 	if (IS_ERR(ctx))
 		return PTR_ERR(ctx);
 
-	mutex_lock(&file->mut);
-	mc = ucma_alloc_multicast(ctx);
+	mc = kzalloc(sizeof(*mc), GFP_KERNEL);
 	if (!mc) {
 		ret = -ENOMEM;
-		goto err1;
+		goto err_put_ctx;
 	}
+
+	mc->ctx = ctx;
 	mc->join_state = join_state;
 	mc->uid = cmd->uid;
 	memcpy(&mc->addr, addr, cmd->addr_size);
+
+	xa_lock(&multicast_table);
+	if (__xa_alloc(&multicast_table, &mc->id, NULL, xa_limit_32b,
+		     GFP_KERNEL)) {
+		ret = -ENOMEM;
+		goto err_free_mc;
+	}
+
+	list_add_tail(&mc->list, &ctx->mc_list);
+	xa_unlock(&multicast_table);
+
 	mutex_lock(&ctx->mutex);
 	ret = rdma_join_multicast(ctx->cm_id, (struct sockaddr *)&mc->addr,
 				  join_state, mc);
 	mutex_unlock(&ctx->mutex);
 	if (ret)
-		goto err2;
+		goto err_xa_erase;
 
 	resp.id = mc->id;
 	if (copy_to_user(u64_to_user_ptr(cmd->response),
 			 &resp, sizeof(resp))) {
 		ret = -EFAULT;
-		goto err3;
+		goto err_leave_multicast;
 	}
 
-	mutex_lock(&mut);
-	idr_replace(&multicast_idr, mc, mc->id);
-	mutex_unlock(&mut);
+	xa_store(&multicast_table, mc->id, mc, 0);
 
-	mutex_unlock(&file->mut);
 	ucma_put_ctx(ctx);
 	return 0;
 
-err3:
+err_leave_multicast:
 	mutex_lock(&ctx->mutex);
 	rdma_leave_multicast(ctx->cm_id, (struct sockaddr *) &mc->addr);
 	mutex_unlock(&ctx->mutex);
 	ucma_cleanup_mc_events(mc);
-err2:
-	mutex_lock(&mut);
-	idr_remove(&multicast_idr, mc->id);
-	mutex_unlock(&mut);
+err_xa_erase:
+	xa_lock(&multicast_table);
 	list_del(&mc->list);
+	__xa_erase(&multicast_table, mc->id);
+err_free_mc:
+	xa_unlock(&multicast_table);
 	kfree(mc);
-err1:
-	mutex_unlock(&file->mut);
+err_put_ctx:
 	ucma_put_ctx(ctx);
 	return ret;
 }
@@ -1545,31 +1571,30 @@
 	if (copy_from_user(&cmd, inbuf, sizeof(cmd)))
 		return -EFAULT;
 
-	mutex_lock(&mut);
-	mc = idr_find(&multicast_idr, cmd.id);
+	xa_lock(&multicast_table);
+	mc = xa_load(&multicast_table, cmd.id);
 	if (!mc)
 		mc = ERR_PTR(-ENOENT);
-	else if (mc->ctx->file != file)
+	else if (READ_ONCE(mc->ctx->file) != file)
 		mc = ERR_PTR(-EINVAL);
-	else if (!atomic_inc_not_zero(&mc->ctx->ref))
+	else if (!refcount_inc_not_zero(&mc->ctx->ref))
 		mc = ERR_PTR(-ENXIO);
-	else
-		idr_remove(&multicast_idr, mc->id);
-	mutex_unlock(&mut);
 
 	if (IS_ERR(mc)) {
+		xa_unlock(&multicast_table);
 		ret = PTR_ERR(mc);
 		goto out;
 	}
+
+	list_del(&mc->list);
+	__xa_erase(&multicast_table, mc->id);
+	xa_unlock(&multicast_table);
 
 	mutex_lock(&mc->ctx->mutex);
 	rdma_leave_multicast(mc->ctx->cm_id, (struct sockaddr *) &mc->addr);
 	mutex_unlock(&mc->ctx->mutex);
 
-	mutex_lock(&mc->ctx->file->mut);
 	ucma_cleanup_mc_events(mc);
-	list_del(&mc->list);
-	mutex_unlock(&mc->ctx->file->mut);
 
 	ucma_put_ctx(mc->ctx);
 	resp.events_reported = mc->events_reported;
@@ -1582,45 +1607,15 @@
 	return ret;
 }
 
-static void ucma_lock_files(struct ucma_file *file1, struct ucma_file *file2)
-{
-	/* Acquire mutex's based on pointer comparison to prevent deadlock. */
-	if (file1 < file2) {
-		mutex_lock(&file1->mut);
-		mutex_lock_nested(&file2->mut, SINGLE_DEPTH_NESTING);
-	} else {
-		mutex_lock(&file2->mut);
-		mutex_lock_nested(&file1->mut, SINGLE_DEPTH_NESTING);
-	}
-}
-
-static void ucma_unlock_files(struct ucma_file *file1, struct ucma_file *file2)
-{
-	if (file1 < file2) {
-		mutex_unlock(&file2->mut);
-		mutex_unlock(&file1->mut);
-	} else {
-		mutex_unlock(&file1->mut);
-		mutex_unlock(&file2->mut);
-	}
-}
-
-static void ucma_move_events(struct ucma_context *ctx, struct ucma_file *file)
-{
-	struct ucma_event *uevent, *tmp;
-
-	list_for_each_entry_safe(uevent, tmp, &ctx->file->event_list, list)
-		if (uevent->ctx == ctx)
-			list_move_tail(&uevent->list, &file->event_list);
-}
-
 static ssize_t ucma_migrate_id(struct ucma_file *new_file,
 			       const char __user *inbuf,
 			       int in_len, int out_len)
 {
 	struct rdma_ucm_migrate_id cmd;
 	struct rdma_ucm_migrate_resp resp;
+	struct ucma_event *uevent, *tmp;
 	struct ucma_context *ctx;
+	LIST_HEAD(event_list);
 	struct fd f;
 	struct ucma_file *cur_file;
 	int ret = 0;
@@ -1636,42 +1631,53 @@
 		ret = -EINVAL;
 		goto file_put;
 	}
+	cur_file = f.file->private_data;
 
 	/* Validate current fd and prevent destruction of id. */
-	ctx = ucma_get_ctx(f.file->private_data, cmd.id);
+	ctx = ucma_get_ctx(cur_file, cmd.id);
 	if (IS_ERR(ctx)) {
 		ret = PTR_ERR(ctx);
 		goto file_put;
 	}
 
-	cur_file = ctx->file;
-	if (cur_file == new_file) {
-		mutex_lock(&cur_file->mut);
-		resp.events_reported = ctx->events_reported;
-		mutex_unlock(&cur_file->mut);
-		goto response;
-	}
-
+	rdma_lock_handler(ctx->cm_id);
 	/*
-	 * Migrate events between fd's, maintaining order, and avoiding new
-	 * events being added before existing events.
+	 * ctx->file can only be changed under the handler & xa_lock. xa_load()
+	 * must be checked again to ensure the ctx hasn't begun destruction
+	 * since the ucma_get_ctx().
 	 */
-	ucma_lock_files(cur_file, new_file);
-	mutex_lock(&mut);
-
-	list_move_tail(&ctx->list, &new_file->ctx_list);
-	ucma_move_events(ctx, new_file);
+	xa_lock(&ctx_table);
+	if (_ucma_find_context(cmd.id, cur_file) != ctx) {
+		xa_unlock(&ctx_table);
+		ret = -ENOENT;
+		goto err_unlock;
+	}
 	ctx->file = new_file;
+	xa_unlock(&ctx_table);
+
+	mutex_lock(&cur_file->mut);
+	list_del(&ctx->list);
+	/*
+	 * At this point lock_handler() prevents addition of new uevents for
+	 * this ctx.
+	 */
+	list_for_each_entry_safe(uevent, tmp, &cur_file->event_list, list)
+		if (uevent->ctx == ctx)
+			list_move_tail(&uevent->list, &event_list);
 	resp.events_reported = ctx->events_reported;
+	mutex_unlock(&cur_file->mut);
 
-	mutex_unlock(&mut);
-	ucma_unlock_files(cur_file, new_file);
+	mutex_lock(&new_file->mut);
+	list_add_tail(&ctx->list, &new_file->ctx_list);
+	list_splice_tail(&event_list, &new_file->event_list);
+	mutex_unlock(&new_file->mut);
 
-response:
 	if (copy_to_user(u64_to_user_ptr(cmd.response),
 			 &resp, sizeof(resp)))
 		ret = -EFAULT;
 
+err_unlock:
+	rdma_unlock_handler(ctx->cm_id);
 	ucma_put_ctx(ctx);
 file_put:
 	fdput(f);
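
ucma_migrate_id() above gets rid of the old ucma_lock_files() pointer-ordering trick by never holding both file mutexes at once: the context's pending events are first moved onto a private list under the source file's mutex and then spliced onto the destination under its own mutex, while rdma_lock_handler() keeps new events from being queued in between. A hedged sketch of the list handoff (move_matching() and struct item are illustrative):

    #include <linux/list.h>
    #include <linux/mutex.h>

    struct item {
        struct list_head list;
        void *owner;
    };

    /* Move all of @owner's items from @src to @dst without nesting the locks. */
    static void move_matching(struct list_head *src, struct mutex *src_lock,
                              struct list_head *dst, struct mutex *dst_lock,
                              void *owner)
    {
        struct item *it, *tmp;
        LIST_HEAD(stash);

        mutex_lock(src_lock);
        list_for_each_entry_safe(it, tmp, src, list)
            if (it->owner == owner)
                list_move_tail(&it->list, &stash);
        mutex_unlock(src_lock);

        mutex_lock(dst_lock);
        list_splice_tail(&stash, dst);
        mutex_unlock(dst_lock);
    }
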
@@ -1771,13 +1777,6 @@
 	if (!file)
 		return -ENOMEM;
 
-	file->close_wq = alloc_ordered_workqueue("ucma_close_id",
-						 WQ_MEM_RECLAIM);
-	if (!file->close_wq) {
-		kfree(file);
-		return -ENOMEM;
-	}
-
 	INIT_LIST_HEAD(&file->event_list);
 	INIT_LIST_HEAD(&file->ctx_list);
 	init_waitqueue_head(&file->poll_wait);
@@ -1786,46 +1785,29 @@
 	filp->private_data = file;
 	file->filp = filp;
 
-	return nonseekable_open(inode, filp);
+	return stream_open(inode, filp);
 }
 
 static int ucma_close(struct inode *inode, struct file *filp)
 {
 	struct ucma_file *file = filp->private_data;
-	struct ucma_context *ctx, *tmp;
 
-	mutex_lock(&file->mut);
-	list_for_each_entry_safe(ctx, tmp, &file->ctx_list, list) {
-		ctx->destroying = 1;
-		mutex_unlock(&file->mut);
+	/*
+	 * All paths that touch ctx_list starting from write() are prevented by
+	 * this being an FD release function. The list_add_tail() in
+	 * ucma_connect_event_handler() can run concurrently, however it only
+	 * adds to the list *after* a listening ID. By only reading the first of
+	 * the list, and relying on ucma_destroy_private_ctx() to block
+	 * ucma_connect_event_handler(), no additional locking is needed.
+	 */
+	while (!list_empty(&file->ctx_list)) {
+		struct ucma_context *ctx = list_first_entry(
+			&file->ctx_list, struct ucma_context, list);
 
-		mutex_lock(&mut);
-		idr_remove(&ctx_idr, ctx->id);
-		mutex_unlock(&mut);
-
-		flush_workqueue(file->close_wq);
-		/* At that step once ctx was marked as destroying and workqueue
-		 * was flushed we are safe from any inflights handlers that
-		 * might put other closing task.
-		 */
-		mutex_lock(&mut);
-		if (!ctx->closing) {
-			mutex_unlock(&mut);
-			ucma_put_ctx(ctx);
-			wait_for_completion(&ctx->comp);
-			/* rdma_destroy_id ensures that no event handlers are
-			 * inflight for that id before releasing it.
-			 */
-			rdma_destroy_id(ctx->cm_id);
-		} else {
-			mutex_unlock(&mut);
-		}
-
-		ucma_free_ctx(ctx);
-		mutex_lock(&file->mut);
+		WARN_ON(xa_cmpxchg(&ctx_table, ctx->id, ctx, XA_ZERO_ENTRY,
+				   GFP_KERNEL) != ctx);
+		ucma_destroy_private_ctx(ctx);
 	}
-	mutex_unlock(&file->mut);
-	destroy_workqueue(file->close_wq);
 	kfree(file);
 	return 0;
 }
@@ -1846,6 +1828,19 @@
 	.mode		= 0666,
 	.fops		= &ucma_fops,
 };
+
+static int ucma_get_global_nl_info(struct ib_client_nl_info *res)
+{
+	res->abi = RDMA_USER_CM_ABI_VERSION;
+	res->cdev = ucma_misc.this_device;
+	return 0;
+}
+
+static struct ib_client rdma_cma_client = {
+	.name = "rdma_cm",
+	.get_global_nl_info = ucma_get_global_nl_info,
+};
+MODULE_ALIAS_RDMA_CLIENT("rdma_cm");
 
 static ssize_t show_abi_version(struct device *dev,
 				struct device_attribute *attr,
@@ -1875,7 +1870,14 @@
 		ret = -ENOMEM;
 		goto err2;
 	}
+
+	ret = ib_register_client(&rdma_cma_client);
+	if (ret)
+		goto err3;
+
 	return 0;
+err3:
+	unregister_net_sysctl_table(ucma_ctl_table_hdr);
 err2:
 	device_remove_file(ucma_misc.this_device, &dev_attr_abi_version);
 err1:
@@ -1885,11 +1887,10 @@
 
 static void __exit ucma_cleanup(void)
 {
+	ib_unregister_client(&rdma_cma_client);
 	unregister_net_sysctl_table(ucma_ctl_table_hdr);
 	device_remove_file(ucma_misc.this_device, &dev_attr_abi_version);
 	misc_deregister(&ucma_misc);
-	idr_destroy(&ctx_idr);
-	idr_destroy(&multicast_idr);
 }
 
 module_init(ucma_init);

--
Gitblit v1.6.2