From f70575805708cabdedea7498aaa3f710fde4d920 Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Wed, 31 Jan 2024 03:29:01 +0000
Subject: [PATCH] add lvds1024*800

---
 kernel/drivers/infiniband/hw/mlx5/mr.c | 1845 +++++++++++++++++++++++++++++++++++++----------------------
 1 files changed, 1,149 insertions(+), 696 deletions(-)

diff --git a/kernel/drivers/infiniband/hw/mlx5/mr.c b/kernel/drivers/infiniband/hw/mlx5/mr.c
index 18fd9aa..d827a4e 100644
--- a/kernel/drivers/infiniband/hw/mlx5/mr.c
+++ b/kernel/drivers/infiniband/hw/mlx5/mr.c
@@ -47,10 +47,69 @@
 
 #define MLX5_UMR_ALIGN 2048
 
+static void
+create_mkey_callback(int status, struct mlx5_async_work *context);
+
+static void set_mkc_access_pd_addr_fields(void *mkc, int acc, u64 start_addr,
+					  struct ib_pd *pd)
+{
+	struct mlx5_ib_dev *dev = to_mdev(pd->device);
+
+	MLX5_SET(mkc, mkc, a, !!(acc & IB_ACCESS_REMOTE_ATOMIC));
+	MLX5_SET(mkc, mkc, rw, !!(acc & IB_ACCESS_REMOTE_WRITE));
+	MLX5_SET(mkc, mkc, rr, !!(acc & IB_ACCESS_REMOTE_READ));
+	MLX5_SET(mkc, mkc, lw, !!(acc & IB_ACCESS_LOCAL_WRITE));
+	MLX5_SET(mkc, mkc, lr, 1);
+
+	if (MLX5_CAP_GEN(dev->mdev, relaxed_ordering_write))
+		MLX5_SET(mkc, mkc, relaxed_ordering_write,
+			 !!(acc & IB_ACCESS_RELAXED_ORDERING));
+	if (MLX5_CAP_GEN(dev->mdev, relaxed_ordering_read))
+		MLX5_SET(mkc, mkc, relaxed_ordering_read,
+			 !!(acc & IB_ACCESS_RELAXED_ORDERING));
+
+	MLX5_SET(mkc, mkc, pd, to_mpd(pd)->pdn);
+	MLX5_SET(mkc, mkc, qpn, 0xffffff);
+	MLX5_SET64(mkc, mkc, start_addr, start_addr);
+}
+
+static void
+assign_mkey_variant(struct mlx5_ib_dev *dev, struct mlx5_core_mkey *mkey,
+		    u32 *in)
+{
+	u8 key = atomic_inc_return(&dev->mkey_var);
+	void *mkc;
+
+	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
+	MLX5_SET(mkc, mkc, mkey_7_0, key);
+	mkey->key = key;
+}
+
+static int
+mlx5_ib_create_mkey(struct mlx5_ib_dev *dev, struct mlx5_core_mkey *mkey,
+		    u32 *in, int inlen)
+{
+	assign_mkey_variant(dev, mkey, in);
+	return mlx5_core_create_mkey(dev->mdev, mkey, in, inlen);
+}
+
+static int
+mlx5_ib_create_mkey_cb(struct mlx5_ib_dev *dev,
+		       struct mlx5_core_mkey *mkey,
+		       struct mlx5_async_ctx *async_ctx,
+		       u32 *in, int inlen, u32 *out, int outlen,
+		       struct mlx5_async_work *context)
+{
+	MLX5_SET(create_mkey_in, in, opcode, MLX5_CMD_OP_CREATE_MKEY);
+	assign_mkey_variant(dev, mkey, in);
+	return mlx5_cmd_exec_cb(async_ctx, in, inlen, out, outlen,
+				create_mkey_callback, context);
+}
+
 static void clean_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr);
 static void dereg_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr);
 static int mr_cache_max_order(struct mlx5_ib_dev *dev);
-static int unreg_umr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr);
+static void queue_adjust_cache_locked(struct mlx5_cache_ent *ent);
 
 static bool umr_can_use_indirect_mkey(struct mlx5_ib_dev *dev)
 {
@@ -59,113 +118,79 @@
 
 static int destroy_mkey(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr)
 {
-	int err = mlx5_core_destroy_mkey(dev->mdev, &mr->mmkey);
+	WARN_ON(xa_load(&dev->odp_mkeys, mlx5_base_mkey(mr->mmkey.key)));
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	/* Wait until all page fault handlers using the mr complete. */
-	synchronize_srcu(&dev->mr_srcu);
-#endif
-
-	return err;
+	return mlx5_core_destroy_mkey(dev->mdev, &mr->mmkey);
 }
 
-static int order2idx(struct mlx5_ib_dev *dev, int order)
-{
-	struct mlx5_mr_cache *cache = &dev->cache;
-
-	if (order < cache->ent[0].order)
-		return 0;
-	else
-		return order - cache->ent[0].order;
-}
-
-static bool use_umr_mtt_update(struct mlx5_ib_mr *mr, u64 start, u64 length)
+static inline bool mlx5_ib_pas_fits_in_mr(struct mlx5_ib_mr *mr, u64 start,
+					  u64 length)
 {
 	return ((u64)1 << mr->order) * MLX5_ADAPTER_PAGE_SIZE >=
 		length + (start & (MLX5_ADAPTER_PAGE_SIZE - 1));
 }
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-static void update_odp_mr(struct mlx5_ib_mr *mr)
+static void create_mkey_callback(int status, struct mlx5_async_work *context)
 {
-	if (mr->umem->odp_data) {
-		/*
-		 * This barrier prevents the compiler from moving the
-		 * setting of umem->odp_data->private to point to our
-		 * MR, before reg_umr finished, to ensure that the MR
-		 * initialization have finished before starting to
-		 * handle invalidations.
-		 */
-		smp_wmb();
-		mr->umem->odp_data->private = mr;
-		/*
-		 * Make sure we will see the new
-		 * umem->odp_data->private value in the invalidation
-		 * routines, before we can get page faults on the
-		 * MR. Page faults can happen once we put the MR in
-		 * the tree, below this line. Without the barrier,
-		 * there can be a fault handling and an invalidation
-		 * before umem->odp_data->private == mr is visible to
-		 * the invalidation handler.
-		 */
-		smp_wmb();
-	}
-}
-#endif
-
-static void reg_mr_callback(int status, void *context)
-{
-	struct mlx5_ib_mr *mr = context;
+	struct mlx5_ib_mr *mr =
+		container_of(context, struct mlx5_ib_mr, cb_work);
 	struct mlx5_ib_dev *dev = mr->dev;
-	struct mlx5_mr_cache *cache = &dev->cache;
-	int c = order2idx(dev, mr->order);
-	struct mlx5_cache_ent *ent = &cache->ent[c];
-	u8 key;
+	struct mlx5_cache_ent *ent = mr->cache_ent;
 	unsigned long flags;
-	struct mlx5_mkey_table *table = &dev->mdev->priv.mkey_table;
-	int err;
 
-	spin_lock_irqsave(&ent->lock, flags);
-	ent->pending--;
-	spin_unlock_irqrestore(&ent->lock, flags);
 	if (status) {
 		mlx5_ib_warn(dev, "async reg mr failed. status %d\n", status);
 		kfree(mr);
-		dev->fill_delay = 1;
+		spin_lock_irqsave(&ent->lock, flags);
+		ent->pending--;
+		WRITE_ONCE(dev->fill_delay, 1);
+		spin_unlock_irqrestore(&ent->lock, flags);
 		mod_timer(&dev->delay_timer, jiffies + HZ);
 		return;
 	}
 
 	mr->mmkey.type = MLX5_MKEY_MR;
-	spin_lock_irqsave(&dev->mdev->priv.mkey_lock, flags);
-	key = dev->mdev->priv.mkey_key++;
-	spin_unlock_irqrestore(&dev->mdev->priv.mkey_lock, flags);
-	mr->mmkey.key = mlx5_idx_to_mkey(MLX5_GET(create_mkey_out, mr->out, mkey_index)) | key;
+	mr->mmkey.key |= mlx5_idx_to_mkey(
+		MLX5_GET(create_mkey_out, mr->out, mkey_index));
 
-	cache->last_add = jiffies;
+	WRITE_ONCE(dev->cache.last_add, jiffies);
 
 	spin_lock_irqsave(&ent->lock, flags);
 	list_add_tail(&mr->list, &ent->head);
-	ent->cur++;
-	ent->size++;
+	ent->available_mrs++;
+	ent->total_mrs++;
+	/* If we are doing fill_to_high_water then keep going. */
+	queue_adjust_cache_locked(ent);
+	ent->pending--;
 	spin_unlock_irqrestore(&ent->lock, flags);
-
-	write_lock_irqsave(&table->lock, flags);
-	err = radix_tree_insert(&table->tree, mlx5_base_mkey(mr->mmkey.key),
-				&mr->mmkey);
-	if (err)
-		pr_err("Error inserting to mkey tree. 0x%x\n", -err);
-	write_unlock_irqrestore(&table->lock, flags);
-
-	if (!completion_done(&ent->compl))
-		complete(&ent->compl);
 }
 
-static int add_keys(struct mlx5_ib_dev *dev, int c, int num)
+static struct mlx5_ib_mr *alloc_cache_mr(struct mlx5_cache_ent *ent, void *mkc)
 {
-	struct mlx5_mr_cache *cache = &dev->cache;
-	struct mlx5_cache_ent *ent = &cache->ent[c];
-	int inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
+	struct mlx5_ib_mr *mr;
+
+	mr = kzalloc(sizeof(*mr), GFP_KERNEL);
+	if (!mr)
+		return NULL;
+	mr->order = ent->order;
+	mr->cache_ent = ent;
+	mr->dev = ent->dev;
+
+	set_mkc_access_pd_addr_fields(mkc, 0, 0, ent->dev->umrc.pd);
+	MLX5_SET(mkc, mkc, free, 1);
+	MLX5_SET(mkc, mkc, umr_en, 1);
+	MLX5_SET(mkc, mkc, access_mode_1_0, ent->access_mode & 0x3);
+	MLX5_SET(mkc, mkc, access_mode_4_2, (ent->access_mode >> 2) & 0x7);
+
+	MLX5_SET(mkc, mkc, translations_octword_size, ent->xlt);
+	MLX5_SET(mkc, mkc, log_page_size, ent->page);
+	return mr;
+}
+
+/* Asynchronously schedule new MRs to be populated in the cache. */
+static int add_keys(struct mlx5_cache_ent *ent, unsigned int num)
+{
+	size_t inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
 	struct mlx5_ib_mr *mr;
 	void *mkc;
 	u32 *in;
@@ -178,42 +203,29 @@
 
 	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
 	for (i = 0; i < num; i++) {
-		if (ent->pending >= MAX_PENDING_REG_MR) {
-			err = -EAGAIN;
-			break;
-		}
-
-		mr = kzalloc(sizeof(*mr), GFP_KERNEL);
+		mr = alloc_cache_mr(ent, mkc);
 		if (!mr) {
 			err = -ENOMEM;
 			break;
 		}
-		mr->order = ent->order;
-		mr->allocated_from_cache = 1;
-		mr->dev = dev;
-
-		MLX5_SET(mkc, mkc, free, 1);
-		MLX5_SET(mkc, mkc, umr_en, 1);
-		MLX5_SET(mkc, mkc, access_mode_1_0, ent->access_mode & 0x3);
-		MLX5_SET(mkc, mkc, access_mode_4_2,
-			 (ent->access_mode >> 2) & 0x7);
-
-		MLX5_SET(mkc, mkc, qpn, 0xffffff);
-		MLX5_SET(mkc, mkc, translations_octword_size, ent->xlt);
-		MLX5_SET(mkc, mkc, log_page_size, ent->page);
-
 		spin_lock_irq(&ent->lock);
+		if (ent->pending >= MAX_PENDING_REG_MR) {
+			err = -EAGAIN;
+			spin_unlock_irq(&ent->lock);
+			kfree(mr);
+			break;
+		}
 		ent->pending++;
 		spin_unlock_irq(&ent->lock);
-		err = mlx5_core_create_mkey_cb(dev->mdev, &mr->mmkey,
-					       in, inlen,
-					       mr->out, sizeof(mr->out),
-					       reg_mr_callback, mr);
+		err = mlx5_ib_create_mkey_cb(ent->dev, &mr->mmkey,
+					     &ent->dev->async_ctx, in, inlen,
+					     mr->out, sizeof(mr->out),
+					     &mr->cb_work);
 		if (err) {
 			spin_lock_irq(&ent->lock);
 			ent->pending--;
 			spin_unlock_irq(&ent->lock);
-			mlx5_ib_warn(dev, "create mkey failed %d\n", err);
+			mlx5_ib_warn(ent->dev, "create mkey failed %d\n", err);
 			kfree(mr);
 			break;
 		}
@@ -223,36 +235,89 @@
 	return err;
 }
 
-static void remove_keys(struct mlx5_ib_dev *dev, int c, int num)
+/* Synchronously create a MR in the cache */
+static struct mlx5_ib_mr *create_cache_mr(struct mlx5_cache_ent *ent)
 {
-	struct mlx5_mr_cache *cache = &dev->cache;
-	struct mlx5_cache_ent *ent = &cache->ent[c];
-	struct mlx5_ib_mr *tmp_mr;
+	size_t inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
 	struct mlx5_ib_mr *mr;
-	LIST_HEAD(del_list);
-	int i;
+	void *mkc;
+	u32 *in;
+	int err;
 
-	for (i = 0; i < num; i++) {
-		spin_lock_irq(&ent->lock);
-		if (list_empty(&ent->head)) {
-			spin_unlock_irq(&ent->lock);
-			break;
-		}
-		mr = list_first_entry(&ent->head, struct mlx5_ib_mr, list);
-		list_move(&mr->list, &del_list);
-		ent->cur--;
-		ent->size--;
-		spin_unlock_irq(&ent->lock);
-		mlx5_core_destroy_mkey(dev->mdev, &mr->mmkey);
+	in = kzalloc(inlen, GFP_KERNEL);
+	if (!in)
+		return ERR_PTR(-ENOMEM);
+	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
+
+	mr = alloc_cache_mr(ent, mkc);
+	if (!mr) {
+		err = -ENOMEM;
+		goto free_in;
 	}
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	synchronize_srcu(&dev->mr_srcu);
-#endif
+	err = mlx5_core_create_mkey(ent->dev->mdev, &mr->mmkey, in, inlen);
+	if (err)
+		goto free_mr;
 
-	list_for_each_entry_safe(mr, tmp_mr, &del_list, list) {
-		list_del(&mr->list);
-		kfree(mr);
+	mr->mmkey.type = MLX5_MKEY_MR;
+	WRITE_ONCE(ent->dev->cache.last_add, jiffies);
+	spin_lock_irq(&ent->lock);
+	ent->total_mrs++;
+	spin_unlock_irq(&ent->lock);
+	kfree(in);
+	return mr;
+free_mr:
+	kfree(mr);
+free_in:
+	kfree(in);
+	return ERR_PTR(err);
+}
+
+static void remove_cache_mr_locked(struct mlx5_cache_ent *ent)
+{
+	struct mlx5_ib_mr *mr;
+
+	lockdep_assert_held(&ent->lock);
+	if (list_empty(&ent->head))
+		return;
+	mr = list_first_entry(&ent->head, struct mlx5_ib_mr, list);
+	list_del(&mr->list);
+	ent->available_mrs--;
+	ent->total_mrs--;
+	spin_unlock_irq(&ent->lock);
+	mlx5_core_destroy_mkey(ent->dev->mdev, &mr->mmkey);
+	kfree(mr);
+	spin_lock_irq(&ent->lock);
+}
+
+static int resize_available_mrs(struct mlx5_cache_ent *ent, unsigned int target,
+				bool limit_fill)
+{
+	int err;
+
+	lockdep_assert_held(&ent->lock);
+
+	while (true) {
+		if (limit_fill)
+			target = ent->limit * 2;
+		if (target == ent->available_mrs + ent->pending)
+			return 0;
+		if (target > ent->available_mrs + ent->pending) {
+			u32 todo = target - (ent->available_mrs + ent->pending);
+
+			spin_unlock_irq(&ent->lock);
+			err = add_keys(ent, todo);
+			if (err == -EAGAIN)
+				usleep_range(3000, 5000);
+			spin_lock_irq(&ent->lock);
+			if (err) {
+				if (err != -EAGAIN)
+					return err;
+			} else
+				return 0;
+		} else {
+			remove_cache_mr_locked(ent);
+		}
 	}
 }
 
@@ -260,37 +325,38 @@
 			  size_t count, loff_t *pos)
 {
 	struct mlx5_cache_ent *ent = filp->private_data;
-	struct mlx5_ib_dev *dev = ent->dev;
-	char lbuf[20] = {0};
-	u32 var;
+	u32 target;
 	int err;
-	int c;
 
-	count = min(count, sizeof(lbuf) - 1);
-	if (copy_from_user(lbuf, buf, count))
-		return -EFAULT;
+	err = kstrtou32_from_user(buf, count, 0, &target);
+	if (err)
+		return err;
 
-	c = order2idx(dev, ent->order);
-
-	if (sscanf(lbuf, "%u", &var) != 1)
-		return -EINVAL;
-
-	if (var < ent->limit)
-		return -EINVAL;
-
-	if (var > ent->size) {
-		do {
-			err = add_keys(dev, c, var - ent->size);
-			if (err && err != -EAGAIN)
-				return err;
-
-			usleep_range(3000, 5000);
-		} while (err);
-	} else if (var < ent->size) {
-		remove_keys(dev, c, ent->size - var);
+	/*
+	 * Target is the new value of total_mrs the user requests, however we
+	 * cannot free MRs that are in use. Compute the target value for
+	 * available_mrs.
+	 */
+	spin_lock_irq(&ent->lock);
+	if (target < ent->total_mrs - ent->available_mrs) {
+		err = -EINVAL;
+		goto err_unlock;
 	}
+	target = target - (ent->total_mrs - ent->available_mrs);
+	if (target < ent->limit || target > ent->limit*2) {
+		err = -EINVAL;
+		goto err_unlock;
+	}
+	err = resize_available_mrs(ent, target, false);
+	if (err)
+		goto err_unlock;
+	spin_unlock_irq(&ent->lock);
 
 	return count;
+
+err_unlock:
+	spin_unlock_irq(&ent->lock);
+	return err;
 }
 
 static ssize_t size_read(struct file *filp, char __user *buf, size_t count,
@@ -300,7 +366,7 @@
 	char lbuf[20];
 	int err;
 
-	err = snprintf(lbuf, sizeof(lbuf), "%d\n", ent->size);
+	err = snprintf(lbuf, sizeof(lbuf), "%d\n", ent->total_mrs);
 	if (err < 0)
 		return err;
 
@@ -318,32 +384,23 @@
 			   size_t count, loff_t *pos)
 {
 	struct mlx5_cache_ent *ent = filp->private_data;
-	struct mlx5_ib_dev *dev = ent->dev;
-	char lbuf[20] = {0};
 	u32 var;
 	int err;
-	int c;
 
-	count = min(count, sizeof(lbuf) - 1);
-	if (copy_from_user(lbuf, buf, count))
-		return -EFAULT;
+	err = kstrtou32_from_user(buf, count, 0, &var);
+	if (err)
+		return err;
 
-	c = order2idx(dev, ent->order);
-
-	if (sscanf(lbuf, "%u", &var) != 1)
-		return -EINVAL;
-
-	if (var > ent->size)
-		return -EINVAL;
-
+	/*
+	 * Upon set we immediately fill the cache to high water mark implied by
+	 * the limit.
+	 */
+	spin_lock_irq(&ent->lock);
 	ent->limit = var;
-
-	if (ent->cur < ent->limit) {
-		err = add_keys(dev, c, 2 * ent->limit - ent->cur);
-		if (err)
-			return err;
-	}
-
+	err = resize_available_mrs(ent, 0, true);
+	spin_unlock_irq(&ent->lock);
+	if (err)
+		return err;
 	return count;
 }
 
@@ -368,68 +425,121 @@
 	.read	= limit_read,
 };
 
-static int someone_adding(struct mlx5_mr_cache *cache)
+static bool someone_adding(struct mlx5_mr_cache *cache)
 {
-	int i;
+	unsigned int i;
 
 	for (i = 0; i < MAX_MR_CACHE_ENTRIES; i++) {
-		if (cache->ent[i].cur < cache->ent[i].limit)
-			return 1;
-	}
+		struct mlx5_cache_ent *ent = &cache->ent[i];
+		bool ret;
 
-	return 0;
+		spin_lock_irq(&ent->lock);
+		ret = ent->available_mrs < ent->limit;
+		spin_unlock_irq(&ent->lock);
+		if (ret)
+			return true;
+	}
+	return false;
+}
+
+/*
+ * Check if the bucket is outside the high/low water mark and schedule an async
+ * update. The cache refill has hysteresis, once the low water mark is hit it is
+ * refilled up to the high mark.
+ */
+static void queue_adjust_cache_locked(struct mlx5_cache_ent *ent)
+{
+	lockdep_assert_held(&ent->lock);
+
+	if (ent->disabled || READ_ONCE(ent->dev->fill_delay))
+		return;
+	if (ent->available_mrs < ent->limit) {
+		ent->fill_to_high_water = true;
+		queue_work(ent->dev->cache.wq, &ent->work);
+	} else if (ent->fill_to_high_water &&
+		   ent->available_mrs + ent->pending < 2 * ent->limit) {
+		/*
+		 * Once we start populating due to hitting a low water mark
+		 * continue until we pass the high water mark.
+		 */
+		queue_work(ent->dev->cache.wq, &ent->work);
+	} else if (ent->available_mrs == 2 * ent->limit) {
+		ent->fill_to_high_water = false;
+	} else if (ent->available_mrs > 2 * ent->limit) {
+		/* Queue deletion of excess entries */
+		ent->fill_to_high_water = false;
+		if (ent->pending)
+			queue_delayed_work(ent->dev->cache.wq, &ent->dwork,
+					   msecs_to_jiffies(1000));
+		else
+			queue_work(ent->dev->cache.wq, &ent->work);
+	}
 }
 
 static void __cache_work_func(struct mlx5_cache_ent *ent)
 {
 	struct mlx5_ib_dev *dev = ent->dev;
 	struct mlx5_mr_cache *cache = &dev->cache;
-	int i = order2idx(dev, ent->order);
 	int err;
 
-	if (cache->stopped)
-		return;
+	spin_lock_irq(&ent->lock);
+	if (ent->disabled)
+		goto out;
 
-	ent = &dev->cache.ent[i];
-	if (ent->cur < 2 * ent->limit && !dev->fill_delay) {
-		err = add_keys(dev, i, 1);
-		if (ent->cur < 2 * ent->limit) {
-			if (err == -EAGAIN) {
-				mlx5_ib_dbg(dev, "returned eagain, order %d\n",
-					    i + 2);
-				queue_delayed_work(cache->wq, &ent->dwork,
-						   msecs_to_jiffies(3));
-			} else if (err) {
-				mlx5_ib_warn(dev, "command failed order %d, err %d\n",
-					     i + 2, err);
+	if (ent->fill_to_high_water &&
+	    ent->available_mrs + ent->pending < 2 * ent->limit &&
+	    !READ_ONCE(dev->fill_delay)) {
+		spin_unlock_irq(&ent->lock);
+		err = add_keys(ent, 1);
+		spin_lock_irq(&ent->lock);
+		if (ent->disabled)
+			goto out;
+		if (err) {
+			/*
+			 * EAGAIN only happens if pending is positive, so we
+			 * will be rescheduled from reg_mr_callback(). The only
+			 * failure path here is ENOMEM.
+			 */
+			if (err != -EAGAIN) {
+				mlx5_ib_warn(
+					dev,
+					"command failed order %d, err %d\n",
+					ent->order, err);
 				queue_delayed_work(cache->wq, &ent->dwork,
 						   msecs_to_jiffies(1000));
-			} else {
-				queue_work(cache->wq, &ent->work);
 			}
 		}
-	} else if (ent->cur > 2 * ent->limit) {
+	} else if (ent->available_mrs > 2 * ent->limit) {
+		bool need_delay;
+
 		/*
-		 * The remove_keys() logic is performed as garbage collection
-		 * task. Such task is intended to be run when no other active
-		 * processes are running.
+		 * The remove_cache_mr() logic is performed as garbage
+		 * collection task. Such task is intended to be run when no
+		 * other active processes are running.
 		 *
 		 * The need_resched() will return TRUE if there are user tasks
 		 * to be activated in near future.
 		 *
-		 * In such case, we don't execute remove_keys() and postpone
-		 * the garbage collection work to try to run in next cycle,
-		 * in order to free CPU resources to other tasks.
+		 * In such case, we don't execute remove_cache_mr() and postpone
+		 * the garbage collection work to try to run in next cycle, in
+		 * order to free CPU resources to other tasks.
 		 */
-		if (!need_resched() && !someone_adding(cache) &&
-		    time_after(jiffies, cache->last_add + 300 * HZ)) {
-			remove_keys(dev, i, 1);
-			if (ent->cur > ent->limit)
-				queue_work(cache->wq, &ent->work);
-		} else {
+		spin_unlock_irq(&ent->lock);
+		need_delay = need_resched() || someone_adding(cache) ||
+			     !time_after(jiffies,
+					 READ_ONCE(cache->last_add) + 300 * HZ);
+		spin_lock_irq(&ent->lock);
+		if (ent->disabled)
+			goto out;
+		if (need_delay) {
 			queue_delayed_work(cache->wq, &ent->dwork, 300 * HZ);
+			goto out;
 		}
+		remove_cache_mr_locked(ent);
+		queue_adjust_cache_locked(ent);
 	}
+out:
+	spin_unlock_irq(&ent->lock);
 }
 
 static void delayed_cache_work_func(struct work_struct *work)
@@ -448,117 +558,103 @@
 	__cache_work_func(ent);
 }
 
-struct mlx5_ib_mr *mlx5_mr_cache_alloc(struct mlx5_ib_dev *dev, int entry)
+/* Allocate a special entry from the cache */
+struct mlx5_ib_mr *mlx5_mr_cache_alloc(struct mlx5_ib_dev *dev,
+				       unsigned int entry, int access_flags)
 {
 	struct mlx5_mr_cache *cache = &dev->cache;
 	struct mlx5_cache_ent *ent;
 	struct mlx5_ib_mr *mr;
-	int err;
 
-	if (entry < 0 || entry >= MAX_MR_CACHE_ENTRIES) {
-		mlx5_ib_err(dev, "cache entry %d is out of range\n", entry);
+	if (WARN_ON(entry <= MR_CACHE_LAST_STD_ENTRY ||
+		    entry >= ARRAY_SIZE(cache->ent)))
 		return ERR_PTR(-EINVAL);
-	}
+
+	/* Matches access in alloc_cache_mr() */
+	if (!mlx5_ib_can_reconfig_with_umr(dev, 0, access_flags))
+		return ERR_PTR(-EOPNOTSUPP);
 
 	ent = &cache->ent[entry];
-	while (1) {
-		spin_lock_irq(&ent->lock);
-		if (list_empty(&ent->head)) {
-			spin_unlock_irq(&ent->lock);
-
-			err = add_keys(dev, entry, 1);
-			if (err && err != -EAGAIN)
-				return ERR_PTR(err);
-
-			wait_for_completion(&ent->compl);
-		} else {
-			mr = list_first_entry(&ent->head, struct mlx5_ib_mr,
-					      list);
-			list_del(&mr->list);
-			ent->cur--;
-			spin_unlock_irq(&ent->lock);
-			if (ent->cur < ent->limit)
-				queue_work(cache->wq, &ent->work);
+	spin_lock_irq(&ent->lock);
+	if (list_empty(&ent->head)) {
+		queue_adjust_cache_locked(ent);
+		ent->miss++;
+		spin_unlock_irq(&ent->lock);
+		mr = create_cache_mr(ent);
+		if (IS_ERR(mr))
 			return mr;
-		}
+	} else {
+		mr = list_first_entry(&ent->head, struct mlx5_ib_mr, list);
+		list_del(&mr->list);
+		ent->available_mrs--;
+		queue_adjust_cache_locked(ent);
+		spin_unlock_irq(&ent->lock);
 	}
+	mr->access_flags = access_flags;
+	return mr;
 }
 
-static struct mlx5_ib_mr *alloc_cached_mr(struct mlx5_ib_dev *dev, int order)
+/* Return a MR already available in the cache */
+static struct mlx5_ib_mr *get_cache_mr(struct mlx5_cache_ent *req_ent)
 {
-	struct mlx5_mr_cache *cache = &dev->cache;
+	struct mlx5_ib_dev *dev = req_ent->dev;
 	struct mlx5_ib_mr *mr = NULL;
-	struct mlx5_cache_ent *ent;
-	int last_umr_cache_entry;
-	int c;
-	int i;
+	struct mlx5_cache_ent *ent = req_ent;
 
-	c = order2idx(dev, order);
-	last_umr_cache_entry = order2idx(dev, mr_cache_max_order(dev));
-	if (c < 0 || c > last_umr_cache_entry) {
-		mlx5_ib_warn(dev, "order %d, cache index %d\n", order, c);
-		return NULL;
-	}
-
-	for (i = c; i <= last_umr_cache_entry; i++) {
-		ent = &cache->ent[i];
-
-		mlx5_ib_dbg(dev, "order %d, cache index %d\n", ent->order, i);
+	/* Try larger MR pools from the cache to satisfy the allocation */
+	for (; ent != &dev->cache.ent[MR_CACHE_LAST_STD_ENTRY + 1]; ent++) {
+		mlx5_ib_dbg(dev, "order %u, cache index %zu\n", ent->order,
+			    ent - dev->cache.ent);
 
 		spin_lock_irq(&ent->lock);
 		if (!list_empty(&ent->head)) {
 			mr = list_first_entry(&ent->head, struct mlx5_ib_mr,
 					      list);
 			list_del(&mr->list);
-			ent->cur--;
+			ent->available_mrs--;
+			queue_adjust_cache_locked(ent);
 			spin_unlock_irq(&ent->lock);
-			if (ent->cur < ent->limit)
-				queue_work(cache->wq, &ent->work);
 			break;
 		}
+		queue_adjust_cache_locked(ent);
 		spin_unlock_irq(&ent->lock);
-
-		queue_work(cache->wq, &ent->work);
 	}
 
 	if (!mr)
-		cache->ent[c].miss++;
+		req_ent->miss++;
 
 	return mr;
 }
 
+static void detach_mr_from_cache(struct mlx5_ib_mr *mr)
+{
+	struct mlx5_cache_ent *ent = mr->cache_ent;
+
+	mr->cache_ent = NULL;
+	spin_lock_irq(&ent->lock);
+	ent->total_mrs--;
+	spin_unlock_irq(&ent->lock);
+}
+
 void mlx5_mr_cache_free(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr)
 {
-	struct mlx5_mr_cache *cache = &dev->cache;
-	struct mlx5_cache_ent *ent;
-	int shrink = 0;
-	int c;
+	struct mlx5_cache_ent *ent = mr->cache_ent;
 
-	if (!mr->allocated_from_cache)
+	if (!ent)
 		return;
 
-	c = order2idx(dev, mr->order);
-	WARN_ON(c < 0 || c >= MAX_MR_CACHE_ENTRIES);
-
-	if (unreg_umr(dev, mr)) {
-		mr->allocated_from_cache = false;
+	if (mlx5_mr_cache_invalidate(mr)) {
+		detach_mr_from_cache(mr);
 		destroy_mkey(dev, mr);
-		ent = &cache->ent[c];
-		if (ent->cur < ent->limit)
-			queue_work(cache->wq, &ent->work);
+		kfree(mr);
 		return;
 	}
 
-	ent = &cache->ent[c];
 	spin_lock_irq(&ent->lock);
 	list_add_tail(&mr->list, &ent->head);
-	ent->cur++;
-	if (ent->cur > 2 * ent->limit)
-		shrink = 1;
+	ent->available_mrs++;
+	queue_adjust_cache_locked(ent);
 	spin_unlock_irq(&ent->lock);
-
-	if (shrink)
-		queue_work(cache->wq, &ent->work);
 }
 
 static void clean_keys(struct mlx5_ib_dev *dev, int c)
@@ -578,15 +674,11 @@
 		}
 		mr = list_first_entry(&ent->head, struct mlx5_ib_mr, list);
 		list_move(&mr->list, &del_list);
-		ent->cur--;
-		ent->size--;
+		ent->available_mrs--;
+		ent->total_mrs--;
 		spin_unlock_irq(&ent->lock);
 		mlx5_core_destroy_mkey(dev->mdev, &mr->mmkey);
 	}
-
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	synchronize_srcu(&dev->mr_srcu);
-#endif
 
 	list_for_each_entry_safe(mr, tmp_mr, &del_list, list) {
 		list_del(&mr->list);
@@ -596,73 +688,47 @@
 
 static void mlx5_mr_cache_debugfs_cleanup(struct mlx5_ib_dev *dev)
 {
-	if (!mlx5_debugfs_root || dev->rep)
+	if (!mlx5_debugfs_root || dev->is_rep)
 		return;
 
 	debugfs_remove_recursive(dev->cache.root);
 	dev->cache.root = NULL;
 }
 
-static int mlx5_mr_cache_debugfs_init(struct mlx5_ib_dev *dev)
+static void mlx5_mr_cache_debugfs_init(struct mlx5_ib_dev *dev)
 {
 	struct mlx5_mr_cache *cache = &dev->cache;
 	struct mlx5_cache_ent *ent;
+	struct dentry *dir;
 	int i;
 
-	if (!mlx5_debugfs_root || dev->rep)
-		return 0;
+	if (!mlx5_debugfs_root || dev->is_rep)
+		return;
 
 	cache->root = debugfs_create_dir("mr_cache", dev->mdev->priv.dbg_root);
-	if (!cache->root)
-		return -ENOMEM;
 
 	for (i = 0; i < MAX_MR_CACHE_ENTRIES; i++) {
 		ent = &cache->ent[i];
 		sprintf(ent->name, "%d", ent->order);
-		ent->dir = debugfs_create_dir(ent->name,  cache->root);
-		if (!ent->dir)
-			goto err;
-
-		ent->fsize = debugfs_create_file("size", 0600, ent->dir, ent,
-						 &size_fops);
-		if (!ent->fsize)
-			goto err;
-
-		ent->flimit = debugfs_create_file("limit", 0600, ent->dir, ent,
-						  &limit_fops);
-		if (!ent->flimit)
-			goto err;
-
-		ent->fcur = debugfs_create_u32("cur", 0400, ent->dir,
-					       &ent->cur);
-		if (!ent->fcur)
-			goto err;
-
-		ent->fmiss = debugfs_create_u32("miss", 0600, ent->dir,
-						&ent->miss);
-		if (!ent->fmiss)
-			goto err;
+		dir = debugfs_create_dir(ent->name, cache->root);
+		debugfs_create_file("size", 0600, dir, ent, &size_fops);
+		debugfs_create_file("limit", 0600, dir, ent, &limit_fops);
+		debugfs_create_u32("cur", 0400, dir, &ent->available_mrs);
+		debugfs_create_u32("miss", 0600, dir, &ent->miss);
 	}
-
-	return 0;
-err:
-	mlx5_mr_cache_debugfs_cleanup(dev);
-
-	return -ENOMEM;
 }
 
 static void delay_time_func(struct timer_list *t)
 {
 	struct mlx5_ib_dev *dev = from_timer(dev, t, delay_timer);
 
-	dev->fill_delay = 0;
+	WRITE_ONCE(dev->fill_delay, 0);
 }
 
 int mlx5_mr_cache_init(struct mlx5_ib_dev *dev)
 {
 	struct mlx5_mr_cache *cache = &dev->cache;
 	struct mlx5_cache_ent *ent;
-	int err;
 	int i;
 
 	mutex_init(&dev->slow_path_mutex);
@@ -672,6 +738,7 @@
 		return -ENOMEM;
 	}
 
+	mlx5_cmd_init_async_ctx(dev->mdev, &dev->async_ctx);
 	timer_setup(&dev->delay_timer, delay_time_func, 0);
 	for (i = 0; i < MAX_MR_CACHE_ENTRIES; i++) {
 		ent = &cache->ent[i];
@@ -681,7 +748,6 @@
 		ent->dev = dev;
 		ent->limit = 0;
 
-		init_completion(&ent->compl);
 		INIT_WORK(&ent->work, cache_work_func);
 		INIT_DELAYED_WORK(&ent->dwork, delayed_cache_work_func);
 
@@ -698,70 +764,45 @@
 			   MLX5_IB_UMR_OCTOWORD;
 		ent->access_mode = MLX5_MKC_ACCESS_MODE_MTT;
 		if ((dev->mdev->profile->mask & MLX5_PROF_MASK_MR_CACHE) &&
-		    !dev->rep &&
-		    mlx5_core_is_pf(dev->mdev))
+		    !dev->is_rep && mlx5_core_is_pf(dev->mdev) &&
+		    mlx5_ib_can_load_pas_with_umr(dev, 0))
 			ent->limit = dev->mdev->profile->mr_cache[i].limit;
 		else
 			ent->limit = 0;
-		queue_work(cache->wq, &ent->work);
+		spin_lock_irq(&ent->lock);
+		queue_adjust_cache_locked(ent);
+		spin_unlock_irq(&ent->lock);
 	}
 
-	err = mlx5_mr_cache_debugfs_init(dev);
-	if (err)
-		mlx5_ib_warn(dev, "cache debugfs failure\n");
-
-	/*
-	 * We don't want to fail driver if debugfs failed to initialize,
-	 * so we are not forwarding error to the user.
-	 */
+	mlx5_mr_cache_debugfs_init(dev);
 
 	return 0;
 }
 
-static void wait_for_async_commands(struct mlx5_ib_dev *dev)
-{
-	struct mlx5_mr_cache *cache = &dev->cache;
-	struct mlx5_cache_ent *ent;
-	int total = 0;
-	int i;
-	int j;
-
-	for (i = 0; i < MAX_MR_CACHE_ENTRIES; i++) {
-		ent = &cache->ent[i];
-		for (j = 0 ; j < 1000; j++) {
-			if (!ent->pending)
-				break;
-			msleep(50);
-		}
-	}
-	for (i = 0; i < MAX_MR_CACHE_ENTRIES; i++) {
-		ent = &cache->ent[i];
-		total += ent->pending;
-	}
-
-	if (total)
-		mlx5_ib_warn(dev, "aborted while there are %d pending mr requests\n", total);
-	else
-		mlx5_ib_warn(dev, "done with all pending requests\n");
-}
-
 int mlx5_mr_cache_cleanup(struct mlx5_ib_dev *dev)
 {
-	int i;
+	unsigned int i;
 
 	if (!dev->cache.wq)
 		return 0;
 
-	dev->cache.stopped = 1;
-	flush_workqueue(dev->cache.wq);
+	for (i = 0; i < MAX_MR_CACHE_ENTRIES; i++) {
+		struct mlx5_cache_ent *ent = &dev->cache.ent[i];
+
+		spin_lock_irq(&ent->lock);
+		ent->disabled = true;
+		spin_unlock_irq(&ent->lock);
+		cancel_work_sync(&ent->work);
+		cancel_delayed_work_sync(&ent->dwork);
+	}
 
 	mlx5_mr_cache_debugfs_cleanup(dev);
+	mlx5_cmd_cleanup_async_ctx(&dev->async_ctx);
 
 	for (i = 0; i < MAX_MR_CACHE_ENTRIES; i++)
 		clean_keys(dev, i);
 
 	destroy_workqueue(dev->cache.wq);
-	wait_for_async_commands(dev);
 	del_timer_sync(&dev->delay_timer);
 
 	return 0;
@@ -771,7 +812,6 @@
 {
 	struct mlx5_ib_dev *dev = to_mdev(pd->device);
 	int inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
-	struct mlx5_core_dev *mdev = dev->mdev;
 	struct mlx5_ib_mr *mr;
 	void *mkc;
 	u32 *in;
@@ -790,18 +830,10 @@
 	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
 
 	MLX5_SET(mkc, mkc, access_mode_1_0, MLX5_MKC_ACCESS_MODE_PA);
-	MLX5_SET(mkc, mkc, a, !!(acc & IB_ACCESS_REMOTE_ATOMIC));
-	MLX5_SET(mkc, mkc, rw, !!(acc & IB_ACCESS_REMOTE_WRITE));
-	MLX5_SET(mkc, mkc, rr, !!(acc & IB_ACCESS_REMOTE_READ));
-	MLX5_SET(mkc, mkc, lw, !!(acc & IB_ACCESS_LOCAL_WRITE));
-	MLX5_SET(mkc, mkc, lr, 1);
-
 	MLX5_SET(mkc, mkc, length64, 1);
-	MLX5_SET(mkc, mkc, pd, to_mpd(pd)->pdn);
-	MLX5_SET(mkc, mkc, qpn, 0xffffff);
-	MLX5_SET64(mkc, mkc, start_addr, 0);
+	set_mkc_access_pd_addr_fields(mkc, acc, 0, pd);
 
-	err = mlx5_core_create_mkey(mdev, &mr->mmkey, in, inlen);
+	err = mlx5_ib_create_mkey(dev, &mr->mmkey, in, inlen);
 	if (err)
 		goto err_in;
 
@@ -840,26 +872,43 @@
 	return MLX5_MAX_UMR_SHIFT;
 }
 
-static int mr_umem_get(struct ib_pd *pd, u64 start, u64 length,
-		       int access_flags, struct ib_umem **umem,
-		       int *npages, int *page_shift, int *ncont,
-		       int *order)
+static int mr_umem_get(struct mlx5_ib_dev *dev, u64 start, u64 length,
+		       int access_flags, struct ib_umem **umem, int *npages,
+		       int *page_shift, int *ncont, int *order)
 {
-	struct mlx5_ib_dev *dev = to_mdev(pd->device);
 	struct ib_umem *u;
-	int err;
 
 	*umem = NULL;
 
-	u = ib_umem_get(pd->uobject->context, start, length, access_flags, 0);
-	err = PTR_ERR_OR_ZERO(u);
-	if (err) {
-		mlx5_ib_dbg(dev, "umem get failed (%d)\n", err);
-		return err;
+	if (access_flags & IB_ACCESS_ON_DEMAND) {
+		struct ib_umem_odp *odp;
+
+		odp = ib_umem_odp_get(&dev->ib_dev, start, length, access_flags,
+				      &mlx5_mn_ops);
+		if (IS_ERR(odp)) {
+			mlx5_ib_dbg(dev, "umem get failed (%ld)\n",
+				    PTR_ERR(odp));
+			return PTR_ERR(odp);
+		}
+
+		u = &odp->umem;
+
+		*page_shift = odp->page_shift;
+		*ncont = ib_umem_odp_num_pages(odp);
+		*npages = *ncont << (*page_shift - PAGE_SHIFT);
+		if (order)
+			*order = ilog2(roundup_pow_of_two(*ncont));
+	} else {
+		u = ib_umem_get(&dev->ib_dev, start, length, access_flags);
+		if (IS_ERR(u)) {
+			mlx5_ib_dbg(dev, "umem get failed (%ld)\n", PTR_ERR(u));
+			return PTR_ERR(u);
+		}
+
+		mlx5_ib_cont_pages(u, start, MLX5_MKEY_PAGE_SHIFT_MASK, npages,
+				   page_shift, ncont, order);
 	}
 
-	mlx5_ib_cont_pages(u, start, MLX5_MKEY_PAGE_SHIFT_MASK, npages,
-			   page_shift, ncont, order);
 	if (!*npages) {
 		mlx5_ib_warn(dev, "avoid zero region\n");
 		ib_umem_release(u);
@@ -917,30 +966,41 @@
 	return err;
 }
 
-static struct mlx5_ib_mr *alloc_mr_from_cache(
-				  struct ib_pd *pd, struct ib_umem *umem,
-				  u64 virt_addr, u64 len, int npages,
-				  int page_shift, int order, int access_flags)
+static struct mlx5_cache_ent *mr_cache_ent_from_order(struct mlx5_ib_dev *dev,
+						      unsigned int order)
+{
+	struct mlx5_mr_cache *cache = &dev->cache;
+
+	if (order < cache->ent[0].order)
+		return &cache->ent[0];
+	order = order - cache->ent[0].order;
+	if (order > MR_CACHE_LAST_STD_ENTRY)
+		return NULL;
+	return &cache->ent[order];
+}
+
+static struct mlx5_ib_mr *
+alloc_mr_from_cache(struct ib_pd *pd, struct ib_umem *umem, u64 virt_addr,
+		    u64 len, int npages, int page_shift, unsigned int order,
+		    int access_flags)
 {
 	struct mlx5_ib_dev *dev = to_mdev(pd->device);
+	struct mlx5_cache_ent *ent = mr_cache_ent_from_order(dev, order);
 	struct mlx5_ib_mr *mr;
-	int err = 0;
-	int i;
 
-	for (i = 0; i < 1; i++) {
-		mr = alloc_cached_mr(dev, order);
-		if (mr)
-			break;
+	if (!ent)
+		return ERR_PTR(-E2BIG);
 
-		err = add_keys(dev, order2idx(dev, order), 1);
-		if (err && err != -EAGAIN) {
-			mlx5_ib_warn(dev, "add_keys failed, err %d\n", err);
-			break;
-		}
+	/* Matches access in alloc_cache_mr() */
+	if (!mlx5_ib_can_reconfig_with_umr(dev, 0, access_flags))
+		return ERR_PTR(-EOPNOTSUPP);
+
+	mr = get_cache_mr(ent);
+	if (!mr) {
+		mr = create_cache_mr(ent);
+		if (IS_ERR(mr))
+			return mr;
 	}
-
-	if (!mr)
-		return ERR_PTR(-EAGAIN);
 
 	mr->ibmr.pd = pd;
 	mr->umem = umem;
@@ -951,36 +1011,6 @@
 	mr->mmkey.pd = to_mpd(pd)->pdn;
 
 	return mr;
-}
-
-static inline int populate_xlt(struct mlx5_ib_mr *mr, int idx, int npages,
-			       void *xlt, int page_shift, size_t size,
-			       int flags)
-{
-	struct mlx5_ib_dev *dev = mr->dev;
-	struct ib_umem *umem = mr->umem;
-
-	if (flags & MLX5_IB_UPD_XLT_INDIRECT) {
-		if (!umr_can_use_indirect_mkey(dev))
-			return -EPERM;
-		mlx5_odp_populate_klm(xlt, idx, npages, mr, flags);
-		return npages;
-	}
-
-	npages = min_t(size_t, npages, ib_umem_num_pages(umem) - idx);
-
-	if (!(flags & MLX5_IB_UPD_XLT_ZAP)) {
-		__mlx5_ib_populate_pas(dev, umem, page_shift,
-				       idx, npages, xlt,
-				       MLX5_IB_MTT_PRESENT);
-		/* Clear padding after the pages
-		 * brought from the umem.
-		 */
-		memset(xlt + (npages * sizeof(struct mlx5_mtt)), 0,
-		       size - npages * sizeof(struct mlx5_mtt));
-	}
-
-	return npages;
 }
 
 #define MLX5_MAX_UMR_CHUNK ((1 << (MLX5_MAX_UMR_SHIFT + 4)) - \
@@ -1006,6 +1036,7 @@
 	size_t pages_mapped = 0;
 	size_t pages_to_map = 0;
 	size_t pages_iter = 0;
+	size_t size_to_map = 0;
 	gfp_t gfp;
 	bool use_emergency_page = false;
 
@@ -1052,6 +1083,15 @@
 		goto free_xlt;
 	}
 
+	if (mr->umem->is_odp) {
+		if (!(flags & MLX5_IB_UPD_XLT_INDIRECT)) {
+			struct ib_umem_odp *odp = to_ib_umem_odp(mr->umem);
+			size_t max_pages = ib_umem_odp_num_pages(odp) - idx;
+
+			pages_to_map = min_t(size_t, pages_to_map, max_pages);
+		}
+	}
+
 	sg.addr = dma;
 	sg.lkey = dev->umrc.pd->local_dma_lkey;
 
@@ -1074,14 +1114,22 @@
 	     pages_mapped < pages_to_map && !err;
 	     pages_mapped += pages_iter, idx += pages_iter) {
 		npages = min_t(int, pages_iter, pages_to_map - pages_mapped);
+		size_to_map = npages * desc_size;
 		dma_sync_single_for_cpu(ddev, dma, size, DMA_TO_DEVICE);
-		npages = populate_xlt(mr, idx, npages, xlt,
-				      page_shift, size, flags);
-
+		if (mr->umem->is_odp) {
+			mlx5_odp_populate_xlt(xlt, idx, npages, mr, flags);
+		} else {
+			__mlx5_ib_populate_pas(dev, mr->umem, page_shift, idx,
+					       npages, xlt,
+					       MLX5_IB_MTT_PRESENT);
+			/* Clear padding after the pages
+			 * brought from the umem.
+			 */
+			memset(xlt + size_to_map, 0, size - size_to_map);
+		}
 		dma_sync_single_for_device(ddev, dma, size, DMA_TO_DEVICE);
 
-		sg.length = ALIGN(npages * desc_size,
-				  MLX5_UMR_MTT_ALIGNMENT);
+		sg.length = ALIGN(size_to_map, MLX5_UMR_MTT_ALIGNMENT);
 
 		if (pages_mapped + pages_iter >= pages_to_map) {
 			if (flags & MLX5_IB_UPD_XLT_ENABLE)
@@ -1149,38 +1197,37 @@
 		goto err_1;
 	}
 	pas = (__be64 *)MLX5_ADDR_OF(create_mkey_in, in, klm_pas_mtt);
-	if (populate && !(access_flags & IB_ACCESS_ON_DEMAND))
+	if (populate) {
+		if (WARN_ON(access_flags & IB_ACCESS_ON_DEMAND)) {
+			err = -EINVAL;
+			goto err_2;
+		}
 		mlx5_ib_populate_pas(dev, umem, page_shift, pas,
 				     pg_cap ? MLX5_IB_MTT_PRESENT : 0);
+	}
 
 	/* The pg_access bit allows setting the access flags
 	 * in the page list submitted with the command. */
 	MLX5_SET(create_mkey_in, in, pg_access, !!(pg_cap));
 
 	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
+	set_mkc_access_pd_addr_fields(mkc, access_flags, virt_addr,
+				      populate ? pd : dev->umrc.pd);
 	MLX5_SET(mkc, mkc, free, !populate);
 	MLX5_SET(mkc, mkc, access_mode_1_0, MLX5_MKC_ACCESS_MODE_MTT);
-	MLX5_SET(mkc, mkc, a, !!(access_flags & IB_ACCESS_REMOTE_ATOMIC));
-	MLX5_SET(mkc, mkc, rw, !!(access_flags & IB_ACCESS_REMOTE_WRITE));
-	MLX5_SET(mkc, mkc, rr, !!(access_flags & IB_ACCESS_REMOTE_READ));
-	MLX5_SET(mkc, mkc, lw, !!(access_flags & IB_ACCESS_LOCAL_WRITE));
-	MLX5_SET(mkc, mkc, lr, 1);
 	MLX5_SET(mkc, mkc, umr_en, 1);
 
-	MLX5_SET64(mkc, mkc, start_addr, virt_addr);
 	MLX5_SET64(mkc, mkc, len, length);
-	MLX5_SET(mkc, mkc, pd, to_mpd(pd)->pdn);
 	MLX5_SET(mkc, mkc, bsf_octword_size, 0);
 	MLX5_SET(mkc, mkc, translations_octword_size,
 		 get_octo_len(virt_addr, length, page_shift));
 	MLX5_SET(mkc, mkc, log_page_size, page_shift);
-	MLX5_SET(mkc, mkc, qpn, 0xffffff);
 	if (populate) {
 		MLX5_SET(create_mkey_in, in, translations_octword_actual_size,
 			 get_octo_len(virt_addr, length, page_shift));
 	}
 
-	err = mlx5_core_create_mkey(dev->mdev, &mr->mmkey, in, inlen);
+	err = mlx5_ib_create_mkey(dev, &mr->mmkey, in, inlen);
 	if (err) {
 		mlx5_ib_warn(dev, "create mkey failed\n");
 		goto err_2;
@@ -1204,23 +1251,20 @@
 	return ERR_PTR(err);
 }
 
-static void set_mr_fileds(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr,
-			  int npages, u64 length, int access_flags)
+static void set_mr_fields(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr,
+			  u64 length, int access_flags)
 {
-	mr->npages = npages;
-	atomic_add(npages, &dev->mdev->priv.reg_pages);
 	mr->ibmr.lkey = mr->mmkey.key;
 	mr->ibmr.rkey = mr->mmkey.key;
 	mr->ibmr.length = length;
 	mr->access_flags = access_flags;
 }
 
-static struct ib_mr *mlx5_ib_get_memic_mr(struct ib_pd *pd, u64 memic_addr,
-					  u64 length, int acc)
+static struct ib_mr *mlx5_ib_get_dm_mr(struct ib_pd *pd, u64 start_addr,
+				       u64 length, int acc, int mode)
 {
 	struct mlx5_ib_dev *dev = to_mdev(pd->device);
 	int inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
-	struct mlx5_core_dev *mdev = dev->mdev;
 	struct mlx5_ib_mr *mr;
 	void *mkc;
 	u32 *in;
@@ -1238,29 +1282,18 @@
 
 	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
 
-	MLX5_SET(mkc, mkc, access_mode_1_0, MLX5_MKC_ACCESS_MODE_MEMIC & 0x3);
-	MLX5_SET(mkc, mkc, access_mode_4_2,
-		 (MLX5_MKC_ACCESS_MODE_MEMIC >> 2) & 0x7);
-	MLX5_SET(mkc, mkc, a, !!(acc & IB_ACCESS_REMOTE_ATOMIC));
-	MLX5_SET(mkc, mkc, rw, !!(acc & IB_ACCESS_REMOTE_WRITE));
-	MLX5_SET(mkc, mkc, rr, !!(acc & IB_ACCESS_REMOTE_READ));
-	MLX5_SET(mkc, mkc, lw, !!(acc & IB_ACCESS_LOCAL_WRITE));
-	MLX5_SET(mkc, mkc, lr, 1);
-
+	MLX5_SET(mkc, mkc, access_mode_1_0, mode & 0x3);
+	MLX5_SET(mkc, mkc, access_mode_4_2, (mode >> 2) & 0x7);
 	MLX5_SET64(mkc, mkc, len, length);
-	MLX5_SET(mkc, mkc, pd, to_mpd(pd)->pdn);
-	MLX5_SET(mkc, mkc, qpn, 0xffffff);
-	MLX5_SET64(mkc, mkc, start_addr,
-		   memic_addr - pci_resource_start(dev->mdev->pdev, 0));
+	set_mkc_access_pd_addr_fields(mkc, acc, start_addr, pd);
 
-	err = mlx5_core_create_mkey(mdev, &mr->mmkey, in, inlen);
+	err = mlx5_ib_create_mkey(dev, &mr->mmkey, in, inlen);
 	if (err)
 		goto err_in;
 
 	kfree(in);
 
-	mr->umem = NULL;
-	set_mr_fileds(dev, mr, 0, length, acc);
+	set_mr_fields(dev, mr, length, acc);
 
 	return &mr->ibmr;
 
@@ -1273,20 +1306,52 @@
 	return ERR_PTR(err);
 }
 
+int mlx5_ib_advise_mr(struct ib_pd *pd,
+		      enum ib_uverbs_advise_mr_advice advice,
+		      u32 flags,
+		      struct ib_sge *sg_list,
+		      u32 num_sge,
+		      struct uverbs_attr_bundle *attrs)
+{
+	if (advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH &&
+	    advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_WRITE &&
+	    advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_NO_FAULT)
+		return -EOPNOTSUPP;
+
+	return mlx5_ib_advise_mr_prefetch(pd, advice, flags,
+					 sg_list, num_sge);
+}
+
 struct ib_mr *mlx5_ib_reg_dm_mr(struct ib_pd *pd, struct ib_dm *dm,
 				struct ib_dm_mr_attr *attr,
 				struct uverbs_attr_bundle *attrs)
 {
 	struct mlx5_ib_dm *mdm = to_mdm(dm);
-	u64 memic_addr;
+	struct mlx5_core_dev *dev = to_mdev(dm->device)->mdev;
+	u64 start_addr = mdm->dev_addr + attr->offset;
+	int mode;
 
-	if (attr->access_flags & ~MLX5_IB_DM_ALLOWED_ACCESS)
+	switch (mdm->type) {
+	case MLX5_IB_UAPI_DM_TYPE_MEMIC:
+		if (attr->access_flags & ~MLX5_IB_DM_MEMIC_ALLOWED_ACCESS)
+			return ERR_PTR(-EINVAL);
+
+		mode = MLX5_MKC_ACCESS_MODE_MEMIC;
+		start_addr -= pci_resource_start(dev->pdev, 0);
+		break;
+	case MLX5_IB_UAPI_DM_TYPE_STEERING_SW_ICM:
+	case MLX5_IB_UAPI_DM_TYPE_HEADER_MODIFY_SW_ICM:
+		if (attr->access_flags & ~MLX5_IB_DM_SW_ICM_ALLOWED_ACCESS)
+			return ERR_PTR(-EINVAL);
+
+		mode = MLX5_MKC_ACCESS_MODE_SW_ICM;
+		break;
+	default:
 		return ERR_PTR(-EINVAL);
+	}
 
-	memic_addr = mdm->dev_addr + attr->offset;
-
-	return mlx5_ib_get_memic_mr(pd, memic_addr, attr->length,
-				    attr->access_flags);
+	return mlx5_ib_get_dm_mr(pd, start_addr, attr->length,
+				 attr->access_flags, mode);
 }
 
 struct ib_mr *mlx5_ib_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
@@ -1295,7 +1360,7 @@
 {
 	struct mlx5_ib_dev *dev = to_mdev(pd->device);
 	struct mlx5_ib_mr *mr = NULL;
-	bool use_umr;
+	bool xlt_with_umr;
 	struct ib_umem *umem;
 	int page_shift;
 	int npages;
@@ -1309,49 +1374,42 @@
 	mlx5_ib_dbg(dev, "start 0x%llx, virt_addr 0x%llx, length 0x%llx, access_flags 0x%x\n",
 		    start, virt_addr, length, access_flags);
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	if (!start && length == U64_MAX) {
+	xlt_with_umr = mlx5_ib_can_load_pas_with_umr(dev, length);
+	/* ODP requires xlt update via umr to work. */
+	if (!xlt_with_umr && (access_flags & IB_ACCESS_ON_DEMAND))
+		return ERR_PTR(-EINVAL);
+
+	if (IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING) && !start &&
+	    length == U64_MAX) {
+		if (virt_addr != start)
+			return ERR_PTR(-EINVAL);
 		if (!(access_flags & IB_ACCESS_ON_DEMAND) ||
 		    !(dev->odp_caps.general_caps & IB_ODP_SUPPORT_IMPLICIT))
 			return ERR_PTR(-EINVAL);
 
-		mr = mlx5_ib_alloc_implicit_mr(to_mpd(pd), access_flags);
+		mr = mlx5_ib_alloc_implicit_mr(to_mpd(pd), udata, access_flags);
 		if (IS_ERR(mr))
 			return ERR_CAST(mr);
 		return &mr->ibmr;
 	}
-#endif
 
-	err = mr_umem_get(pd, start, length, access_flags, &umem, &npages,
-			   &page_shift, &ncont, &order);
+	err = mr_umem_get(dev, start, length, access_flags, &umem,
+			  &npages, &page_shift, &ncont, &order);
 
 	if (err < 0)
 		return ERR_PTR(err);
 
-	use_umr = !MLX5_CAP_GEN(dev->mdev, umr_modify_entity_size_disabled) &&
-		  (!MLX5_CAP_GEN(dev->mdev, umr_modify_atomic_disabled) ||
-		   !MLX5_CAP_GEN(dev->mdev, atomic));
-
-	if (order <= mr_cache_max_order(dev) && use_umr) {
+	if (xlt_with_umr) {
 		mr = alloc_mr_from_cache(pd, umem, virt_addr, length, ncont,
 					 page_shift, order, access_flags);
-		if (PTR_ERR(mr) == -EAGAIN) {
-			mlx5_ib_dbg(dev, "cache empty for order %d\n", order);
+		if (IS_ERR(mr))
 			mr = NULL;
-		}
-	} else if (!MLX5_CAP_GEN(dev->mdev, umr_extended_translation_offset)) {
-		if (access_flags & IB_ACCESS_ON_DEMAND) {
-			err = -EINVAL;
-			pr_err("Got MR registration for ODP MR > 512MB, not supported for Connect-IB\n");
-			goto error;
-		}
-		use_umr = false;
 	}
 
 	if (!mr) {
 		mutex_lock(&dev->slow_path_mutex);
 		mr = reg_create(NULL, pd, virt_addr, length, umem, ncont,
-				page_shift, access_flags, !use_umr);
+				page_shift, access_flags, !xlt_with_umr);
 		mutex_unlock(&dev->slow_path_mutex);
 	}
 
@@ -1363,52 +1421,74 @@
 	mlx5_ib_dbg(dev, "mkey 0x%x\n", mr->mmkey.key);
 
 	mr->umem = umem;
-	set_mr_fileds(dev, mr, npages, length, access_flags);
+	mr->npages = npages;
+	atomic_add(mr->npages, &dev->mdev->priv.reg_pages);
+	set_mr_fields(dev, mr, length, access_flags);
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	update_odp_mr(mr);
-#endif
-
-	if (use_umr) {
+	if (xlt_with_umr && !(access_flags & IB_ACCESS_ON_DEMAND)) {
+		/*
+		 * If the MR was created with reg_create then it will be
+		 * configured properly but left disabled. It is safe to go ahead
+		 * and configure it again via UMR while enabling it.
+		 */
 		int update_xlt_flags = MLX5_IB_UPD_XLT_ENABLE;
-
-		if (access_flags & IB_ACCESS_ON_DEMAND)
-			update_xlt_flags |= MLX5_IB_UPD_XLT_ZAP;
 
 		err = mlx5_ib_update_xlt(mr, 0, ncont, page_shift,
 					 update_xlt_flags);
-
 		if (err) {
 			dereg_mr(dev, mr);
 			return ERR_PTR(err);
 		}
 	}
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	mr->live = 1;
-#endif
+	if (is_odp_mr(mr)) {
+		to_ib_umem_odp(mr->umem)->private = mr;
+		init_waitqueue_head(&mr->q_deferred_work);
+		atomic_set(&mr->num_deferred_work, 0);
+		err = xa_err(xa_store(&dev->odp_mkeys,
+				      mlx5_base_mkey(mr->mmkey.key), &mr->mmkey,
+				      GFP_KERNEL));
+		if (err) {
+			dereg_mr(dev, mr);
+			return ERR_PTR(err);
+		}
+
+		err = mlx5_ib_init_odp_mr(mr, xlt_with_umr);
+		if (err) {
+			dereg_mr(dev, mr);
+			return ERR_PTR(err);
+		}
+	}
+
 	return &mr->ibmr;
 error:
 	ib_umem_release(umem);
 	return ERR_PTR(err);
 }
 
-static int unreg_umr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr)
+/**
+ * mlx5_mr_cache_invalidate - Fence all DMA on the MR
+ * @mr: The MR to fence
+ *
+ * Upon return the NIC will not be doing any DMA to the pages under the MR,
+ * and any DMA inprogress will be completed. Failure of this function
+ * indicates the HW has failed catastrophically.
+ */
+int mlx5_mr_cache_invalidate(struct mlx5_ib_mr *mr)
 {
-	struct mlx5_core_dev *mdev = dev->mdev;
 	struct mlx5_umr_wr umrwr = {};
 
-	if (mdev->state == MLX5_DEVICE_STATE_INTERNAL_ERROR)
+	if (mr->dev->mdev->state == MLX5_DEVICE_STATE_INTERNAL_ERROR)
 		return 0;
 
 	umrwr.wr.send_flags = MLX5_IB_SEND_UMR_DISABLE_MR |
 			      MLX5_IB_SEND_UMR_UPDATE_PD_ACCESS;
 	umrwr.wr.opcode = MLX5_IB_WR_UMR;
-	umrwr.pd = dev->umrc.pd;
+	umrwr.pd = mr->dev->umrc.pd;
 	umrwr.mkey = mr->mmkey.key;
 	umrwr.ignore_free_state = 1;
 
-	return mlx5_ib_post_send_wait(dev, &umrwr);
+	return mlx5_ib_post_send_wait(mr->dev, &umrwr);
 }
 
 static int rereg_umr(struct ib_pd *pd, struct mlx5_ib_mr *mr,
@@ -1455,10 +1535,11 @@
 	mlx5_ib_dbg(dev, "start 0x%llx, virt_addr 0x%llx, length 0x%llx, access_flags 0x%x\n",
 		    start, virt_addr, length, access_flags);
 
-	atomic_sub(mr->npages, &dev->mdev->priv.reg_pages);
-
 	if (!mr->umem)
 		return -EINVAL;
+
+	if (is_odp_mr(mr))
+		return -EOPNOTSUPP;
 
 	if (flags & IB_MR_REREG_TRANS) {
 		addr = virt_addr;
@@ -1474,22 +1555,30 @@
 		 * used.
 		 */
 		flags |= IB_MR_REREG_TRANS;
+		atomic_sub(mr->npages, &dev->mdev->priv.reg_pages);
+		mr->npages = 0;
 		ib_umem_release(mr->umem);
 		mr->umem = NULL;
-		err = mr_umem_get(pd, addr, len, access_flags, &mr->umem,
+
+		err = mr_umem_get(dev, addr, len, access_flags, &mr->umem,
 				  &npages, &page_shift, &ncont, &order);
 		if (err)
 			goto err;
+		mr->npages = ncont;
+		atomic_add(mr->npages, &dev->mdev->priv.reg_pages);
 	}
 
-	if (flags & IB_MR_REREG_TRANS && !use_umr_mtt_update(mr, addr, len)) {
+	if (!mlx5_ib_can_reconfig_with_umr(dev, mr->access_flags,
+					   access_flags) ||
+	    !mlx5_ib_can_load_pas_with_umr(dev, len) ||
+	    (flags & IB_MR_REREG_TRANS &&
+	     !mlx5_ib_pas_fits_in_mr(mr, addr, len))) {
 		/*
 		 * UMR can't be used - MKey needs to be replaced.
 		 */
-		if (mr->allocated_from_cache)
-			err = unreg_umr(dev, mr);
-		else
-			err = destroy_mkey(dev, mr);
+		if (mr->cache_ent)
+			detach_mr_from_cache(mr);
+		err = destroy_mkey(dev, mr);
 		if (err)
 			goto err;
 
@@ -1501,11 +1590,6 @@
 			mr = to_mmr(ib_mr);
 			goto err;
 		}
-
-		mr->allocated_from_cache = 0;
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-		mr->live = 1;
-#endif
 	} else {
 		/*
 		 * Send a UMR WQE
@@ -1532,18 +1616,14 @@
 			goto err;
 	}
 
-	set_mr_fileds(dev, mr, npages, len, access_flags);
+	set_mr_fields(dev, mr, len, access_flags);
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	update_odp_mr(mr);
-#endif
 	return 0;
 
 err:
-	if (mr->umem) {
-		ib_umem_release(mr->umem);
-		mr->umem = NULL;
-	}
+	ib_umem_release(mr->umem);
+	mr->umem = NULL;
+
 	clean_mr(dev, mr);
 	return err;
 }
@@ -1596,8 +1676,6 @@
 
 static void clean_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr)
 {
-	int allocated_from_cache = mr->allocated_from_cache;
-
 	if (mr->sig) {
 		if (mlx5_core_destroy_psv(dev->mdev,
 					  mr->sig->psv_memory.psv_idx))
@@ -1607,11 +1685,12 @@
 					  mr->sig->psv_wire.psv_idx))
 			mlx5_ib_warn(dev, "failed to destroy wire psv %d\n",
 				     mr->sig->psv_wire.psv_idx);
+		xa_erase(&dev->sig_mrs, mlx5_base_mkey(mr->mmkey.key));
 		kfree(mr->sig);
 		mr->sig = NULL;
 	}
 
-	if (!allocated_from_cache) {
+	if (!mr->cache_ent) {
 		destroy_mkey(dev, mr);
 		mlx5_free_priv_descs(mr);
 	}
@@ -1622,60 +1701,235 @@
 	int npages = mr->npages;
 	struct ib_umem *umem = mr->umem;
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	if (umem && umem->odp_data) {
-		/* Prevent new page faults from succeeding */
-		mr->live = 0;
-		/* Wait for all running page-fault handlers to finish. */
-		synchronize_srcu(&dev->mr_srcu);
-		/* Destroy all page mappings */
-		if (umem->odp_data->page_list)
-			mlx5_ib_invalidate_range(umem, ib_umem_start(umem),
-						 ib_umem_end(umem));
-		else
-			mlx5_ib_free_implicit_mr(mr);
-		/*
-		 * We kill the umem before the MR for ODP,
-		 * so that there will not be any invalidations in
-		 * flight, looking at the *mr struct.
-		 */
-		ib_umem_release(umem);
-		atomic_sub(npages, &dev->mdev->priv.reg_pages);
+	/* Stop all DMA */
+	if (is_odp_mr(mr))
+		mlx5_ib_fence_odp_mr(mr);
+	else
+		clean_mr(dev, mr);
 
-		/* Avoid double-freeing the umem. */
-		umem = NULL;
-	}
-#endif
-	clean_mr(dev, mr);
-
-	/*
-	 * We should unregister the DMA address from the HCA before
-	 * remove the DMA mapping.
-	 */
-	mlx5_mr_cache_free(dev, mr);
-	if (umem) {
-		ib_umem_release(umem);
-		atomic_sub(npages, &dev->mdev->priv.reg_pages);
-	}
-	if (!mr->allocated_from_cache)
+	if (mr->cache_ent)
+		mlx5_mr_cache_free(dev, mr);
+	else
 		kfree(mr);
+
+	ib_umem_release(umem);
+	atomic_sub(npages, &dev->mdev->priv.reg_pages);
+
 }
 
-int mlx5_ib_dereg_mr(struct ib_mr *ibmr)
+int mlx5_ib_dereg_mr(struct ib_mr *ibmr, struct ib_udata *udata)
 {
-	dereg_mr(to_mdev(ibmr->device), to_mmr(ibmr));
+	struct mlx5_ib_mr *mmr = to_mmr(ibmr);
+
+	if (ibmr->type == IB_MR_TYPE_INTEGRITY) {
+		dereg_mr(to_mdev(mmr->mtt_mr->ibmr.device), mmr->mtt_mr);
+		dereg_mr(to_mdev(mmr->klm_mr->ibmr.device), mmr->klm_mr);
+	}
+
+	if (is_odp_mr(mmr) && to_ib_umem_odp(mmr->umem)->is_implicit_odp) {
+		mlx5_ib_free_implicit_mr(mmr);
+		return 0;
+	}
+
+	dereg_mr(to_mdev(ibmr->device), mmr);
+
 	return 0;
 }
 
-struct ib_mr *mlx5_ib_alloc_mr(struct ib_pd *pd,
-			       enum ib_mr_type mr_type,
-			       u32 max_num_sg)
+static void mlx5_set_umr_free_mkey(struct ib_pd *pd, u32 *in, int ndescs,
+				   int access_mode, int page_shift)
+{
+	void *mkc;
+
+	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
+
+	/* This is only used from the kernel, so setting the PD is OK. */
+	set_mkc_access_pd_addr_fields(mkc, 0, 0, pd);
+	MLX5_SET(mkc, mkc, free, 1);
+	MLX5_SET(mkc, mkc, translations_octword_size, ndescs);
+	MLX5_SET(mkc, mkc, access_mode_1_0, access_mode & 0x3);
+	MLX5_SET(mkc, mkc, access_mode_4_2, (access_mode >> 2) & 0x7);
+	MLX5_SET(mkc, mkc, umr_en, 1);
+	MLX5_SET(mkc, mkc, log_page_size, page_shift);
+}
+
+static int _mlx5_alloc_mkey_descs(struct ib_pd *pd, struct mlx5_ib_mr *mr,
+				  int ndescs, int desc_size, int page_shift,
+				  int access_mode, u32 *in, int inlen)
+{
+	struct mlx5_ib_dev *dev = to_mdev(pd->device);
+	int err;
+
+	mr->access_mode = access_mode;
+	mr->desc_size = desc_size;
+	mr->max_descs = ndescs;
+
+	err = mlx5_alloc_priv_descs(pd->device, mr, ndescs, desc_size);
+	if (err)
+		return err;
+
+	mlx5_set_umr_free_mkey(pd, in, ndescs, access_mode, page_shift);
+
+	err = mlx5_ib_create_mkey(dev, &mr->mmkey, in, inlen);
+	if (err)
+		goto err_free_descs;
+
+	mr->mmkey.type = MLX5_MKEY_MR;
+	mr->ibmr.lkey = mr->mmkey.key;
+	mr->ibmr.rkey = mr->mmkey.key;
+
+	return 0;
+
+err_free_descs:
+	mlx5_free_priv_descs(mr);
+	return err;
+}
+
+static struct mlx5_ib_mr *mlx5_ib_alloc_pi_mr(struct ib_pd *pd,
+				u32 max_num_sg, u32 max_num_meta_sg,
+				int desc_size, int access_mode)
+{
+	int inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
+	int ndescs = ALIGN(max_num_sg + max_num_meta_sg, 4);
+	int page_shift = 0;
+	struct mlx5_ib_mr *mr;
+	u32 *in;
+	int err;
+
+	mr = kzalloc(sizeof(*mr), GFP_KERNEL);
+	if (!mr)
+		return ERR_PTR(-ENOMEM);
+
+	mr->ibmr.pd = pd;
+	mr->ibmr.device = pd->device;
+
+	in = kzalloc(inlen, GFP_KERNEL);
+	if (!in) {
+		err = -ENOMEM;
+		goto err_free;
+	}
+
+	if (access_mode == MLX5_MKC_ACCESS_MODE_MTT)
+		page_shift = PAGE_SHIFT;
+
+	err = _mlx5_alloc_mkey_descs(pd, mr, ndescs, desc_size, page_shift,
+				     access_mode, in, inlen);
+	if (err)
+		goto err_free_in;
+
+	mr->umem = NULL;
+	kfree(in);
+
+	return mr;
+
+err_free_in:
+	kfree(in);
+err_free:
+	kfree(mr);
+	return ERR_PTR(err);
+}
+
+static int mlx5_alloc_mem_reg_descs(struct ib_pd *pd, struct mlx5_ib_mr *mr,
+				    int ndescs, u32 *in, int inlen)
+{
+	return _mlx5_alloc_mkey_descs(pd, mr, ndescs, sizeof(struct mlx5_mtt),
+				      PAGE_SHIFT, MLX5_MKC_ACCESS_MODE_MTT, in,
+				      inlen);
+}
+
+static int mlx5_alloc_sg_gaps_descs(struct ib_pd *pd, struct mlx5_ib_mr *mr,
+				    int ndescs, u32 *in, int inlen)
+{
+	return _mlx5_alloc_mkey_descs(pd, mr, ndescs, sizeof(struct mlx5_klm),
+				      0, MLX5_MKC_ACCESS_MODE_KLMS, in, inlen);
+}
+
+static int mlx5_alloc_integrity_descs(struct ib_pd *pd, struct mlx5_ib_mr *mr,
+				      int max_num_sg, int max_num_meta_sg,
+				      u32 *in, int inlen)
+{
+	struct mlx5_ib_dev *dev = to_mdev(pd->device);
+	u32 psv_index[2];
+	void *mkc;
+	int err;
+
+	mr->sig = kzalloc(sizeof(*mr->sig), GFP_KERNEL);
+	if (!mr->sig)
+		return -ENOMEM;
+
+	/* create mem & wire PSVs */
+	err = mlx5_core_create_psv(dev->mdev, to_mpd(pd)->pdn, 2, psv_index);
+	if (err)
+		goto err_free_sig;
+
+	mr->sig->psv_memory.psv_idx = psv_index[0];
+	mr->sig->psv_wire.psv_idx = psv_index[1];
+
+	mr->sig->sig_status_checked = true;
+	mr->sig->sig_err_exists = false;
+	/* Next UMR, Arm SIGERR */
+	++mr->sig->sigerr_count;
+	mr->klm_mr = mlx5_ib_alloc_pi_mr(pd, max_num_sg, max_num_meta_sg,
+					 sizeof(struct mlx5_klm),
+					 MLX5_MKC_ACCESS_MODE_KLMS);
+	if (IS_ERR(mr->klm_mr)) {
+		err = PTR_ERR(mr->klm_mr);
+		goto err_destroy_psv;
+	}
+	mr->mtt_mr = mlx5_ib_alloc_pi_mr(pd, max_num_sg, max_num_meta_sg,
+					 sizeof(struct mlx5_mtt),
+					 MLX5_MKC_ACCESS_MODE_MTT);
+	if (IS_ERR(mr->mtt_mr)) {
+		err = PTR_ERR(mr->mtt_mr);
+		goto err_free_klm_mr;
+	}
+
+	/* Set bsf descriptors for mkey */
+	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
+	MLX5_SET(mkc, mkc, bsf_en, 1);
+	MLX5_SET(mkc, mkc, bsf_octword_size, MLX5_MKEY_BSF_OCTO_SIZE);
+
+	err = _mlx5_alloc_mkey_descs(pd, mr, 4, sizeof(struct mlx5_klm), 0,
+				     MLX5_MKC_ACCESS_MODE_KLMS, in, inlen);
+	if (err)
+		goto err_free_mtt_mr;
+
+	err = xa_err(xa_store(&dev->sig_mrs, mlx5_base_mkey(mr->mmkey.key),
+			      mr->sig, GFP_KERNEL));
+	if (err)
+		goto err_free_descs;
+	return 0;
+
+err_free_descs:
+	destroy_mkey(dev, mr);
+	mlx5_free_priv_descs(mr);
+err_free_mtt_mr:
+	dereg_mr(to_mdev(mr->mtt_mr->ibmr.device), mr->mtt_mr);
+	mr->mtt_mr = NULL;
+err_free_klm_mr:
+	dereg_mr(to_mdev(mr->klm_mr->ibmr.device), mr->klm_mr);
+	mr->klm_mr = NULL;
+err_destroy_psv:
+	if (mlx5_core_destroy_psv(dev->mdev, mr->sig->psv_memory.psv_idx))
+		mlx5_ib_warn(dev, "failed to destroy mem psv %d\n",
+			     mr->sig->psv_memory.psv_idx);
+	if (mlx5_core_destroy_psv(dev->mdev, mr->sig->psv_wire.psv_idx))
+		mlx5_ib_warn(dev, "failed to destroy wire psv %d\n",
+			     mr->sig->psv_wire.psv_idx);
+err_free_sig:
+	kfree(mr->sig);
+
+	return err;
+}
+
+static struct ib_mr *__mlx5_ib_alloc_mr(struct ib_pd *pd,
+					enum ib_mr_type mr_type, u32 max_num_sg,
+					u32 max_num_meta_sg)
 {
 	struct mlx5_ib_dev *dev = to_mdev(pd->device);
 	int inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
 	int ndescs = ALIGN(max_num_sg, 4);
 	struct mlx5_ib_mr *mr;
-	void *mkc;
 	u32 *in;
 	int err;
 
@@ -1689,93 +1943,32 @@
 		goto err_free;
 	}
 
-	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
-	MLX5_SET(mkc, mkc, free, 1);
-	MLX5_SET(mkc, mkc, translations_octword_size, ndescs);
-	MLX5_SET(mkc, mkc, qpn, 0xffffff);
-	MLX5_SET(mkc, mkc, pd, to_mpd(pd)->pdn);
+	mr->ibmr.device = pd->device;
+	mr->umem = NULL;
 
-	if (mr_type == IB_MR_TYPE_MEM_REG) {
-		mr->access_mode = MLX5_MKC_ACCESS_MODE_MTT;
-		MLX5_SET(mkc, mkc, log_page_size, PAGE_SHIFT);
-		err = mlx5_alloc_priv_descs(pd->device, mr,
-					    ndescs, sizeof(struct mlx5_mtt));
-		if (err)
-			goto err_free_in;
-
-		mr->desc_size = sizeof(struct mlx5_mtt);
-		mr->max_descs = ndescs;
-	} else if (mr_type == IB_MR_TYPE_SG_GAPS) {
-		mr->access_mode = MLX5_MKC_ACCESS_MODE_KLMS;
-
-		err = mlx5_alloc_priv_descs(pd->device, mr,
-					    ndescs, sizeof(struct mlx5_klm));
-		if (err)
-			goto err_free_in;
-		mr->desc_size = sizeof(struct mlx5_klm);
-		mr->max_descs = ndescs;
-	} else if (mr_type == IB_MR_TYPE_SIGNATURE) {
-		u32 psv_index[2];
-
-		MLX5_SET(mkc, mkc, bsf_en, 1);
-		MLX5_SET(mkc, mkc, bsf_octword_size, MLX5_MKEY_BSF_OCTO_SIZE);
-		mr->sig = kzalloc(sizeof(*mr->sig), GFP_KERNEL);
-		if (!mr->sig) {
-			err = -ENOMEM;
-			goto err_free_in;
-		}
-
-		/* create mem & wire PSVs */
-		err = mlx5_core_create_psv(dev->mdev, to_mpd(pd)->pdn,
-					   2, psv_index);
-		if (err)
-			goto err_free_sig;
-
-		mr->access_mode = MLX5_MKC_ACCESS_MODE_KLMS;
-		mr->sig->psv_memory.psv_idx = psv_index[0];
-		mr->sig->psv_wire.psv_idx = psv_index[1];
-
-		mr->sig->sig_status_checked = true;
-		mr->sig->sig_err_exists = false;
-		/* Next UMR, Arm SIGERR */
-		++mr->sig->sigerr_count;
-	} else {
+	switch (mr_type) {
+	case IB_MR_TYPE_MEM_REG:
+		err = mlx5_alloc_mem_reg_descs(pd, mr, ndescs, in, inlen);
+		break;
+	case IB_MR_TYPE_SG_GAPS:
+		err = mlx5_alloc_sg_gaps_descs(pd, mr, ndescs, in, inlen);
+		break;
+	case IB_MR_TYPE_INTEGRITY:
+		err = mlx5_alloc_integrity_descs(pd, mr, max_num_sg,
+						 max_num_meta_sg, in, inlen);
+		break;
+	default:
 		mlx5_ib_warn(dev, "Invalid mr type %d\n", mr_type);
 		err = -EINVAL;
-		goto err_free_in;
 	}
 
-	MLX5_SET(mkc, mkc, access_mode_1_0, mr->access_mode & 0x3);
-	MLX5_SET(mkc, mkc, access_mode_4_2, (mr->access_mode >> 2) & 0x7);
-	MLX5_SET(mkc, mkc, umr_en, 1);
-
-	mr->ibmr.device = pd->device;
-	err = mlx5_core_create_mkey(dev->mdev, &mr->mmkey, in, inlen);
 	if (err)
-		goto err_destroy_psv;
+		goto err_free_in;
 
-	mr->mmkey.type = MLX5_MKEY_MR;
-	mr->ibmr.lkey = mr->mmkey.key;
-	mr->ibmr.rkey = mr->mmkey.key;
-	mr->umem = NULL;
 	kfree(in);
 
 	return &mr->ibmr;
 
-err_destroy_psv:
-	if (mr->sig) {
-		if (mlx5_core_destroy_psv(dev->mdev,
-					  mr->sig->psv_memory.psv_idx))
-			mlx5_ib_warn(dev, "failed to destroy mem psv %d\n",
-				     mr->sig->psv_memory.psv_idx);
-		if (mlx5_core_destroy_psv(dev->mdev,
-					  mr->sig->psv_wire.psv_idx))
-			mlx5_ib_warn(dev, "failed to destroy wire psv %d\n",
-				     mr->sig->psv_wire.psv_idx);
-	}
-	mlx5_free_priv_descs(mr);
-err_free_sig:
-	kfree(mr->sig);
 err_free_in:
 	kfree(in);
 err_free:
@@ -1783,12 +1976,24 @@
 	return ERR_PTR(err);
 }
 
-struct ib_mw *mlx5_ib_alloc_mw(struct ib_pd *pd, enum ib_mw_type type,
-			       struct ib_udata *udata)
+struct ib_mr *mlx5_ib_alloc_mr(struct ib_pd *pd, enum ib_mr_type mr_type,
+			       u32 max_num_sg)
 {
-	struct mlx5_ib_dev *dev = to_mdev(pd->device);
+	return __mlx5_ib_alloc_mr(pd, mr_type, max_num_sg, 0);
+}
+
+struct ib_mr *mlx5_ib_alloc_mr_integrity(struct ib_pd *pd,
+					 u32 max_num_sg, u32 max_num_meta_sg)
+{
+	return __mlx5_ib_alloc_mr(pd, IB_MR_TYPE_INTEGRITY, max_num_sg,
+				  max_num_meta_sg);
+}
+
+int mlx5_ib_alloc_mw(struct ib_mw *ibmw, struct ib_udata *udata)
+{
+	struct mlx5_ib_dev *dev = to_mdev(ibmw->device);
 	int inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
-	struct mlx5_ib_mw *mw = NULL;
+	struct mlx5_ib_mw *mw = to_mmw(ibmw);
 	u32 *in = NULL;
 	void *mkc;
 	int ndescs;
@@ -1801,21 +2006,20 @@
 
 	err = ib_copy_from_udata(&req, udata, min(udata->inlen, sizeof(req)));
 	if (err)
-		return ERR_PTR(err);
+		return err;
 
 	if (req.comp_mask || req.reserved1 || req.reserved2)
-		return ERR_PTR(-EOPNOTSUPP);
+		return -EOPNOTSUPP;
 
 	if (udata->inlen > sizeof(req) &&
 	    !ib_is_udata_cleared(udata, sizeof(req),
 				 udata->inlen - sizeof(req)))
-		return ERR_PTR(-EOPNOTSUPP);
+		return -EOPNOTSUPP;
 
 	ndescs = req.num_klms ? roundup(req.num_klms, 4) : roundup(1, 4);
 
-	mw = kzalloc(sizeof(*mw), GFP_KERNEL);
 	in = kzalloc(inlen, GFP_KERNEL);
-	if (!mw || !in) {
+	if (!in) {
 		err = -ENOMEM;
 		goto free;
 	}
@@ -1824,50 +2028,62 @@
 
 	MLX5_SET(mkc, mkc, free, 1);
 	MLX5_SET(mkc, mkc, translations_octword_size, ndescs);
-	MLX5_SET(mkc, mkc, pd, to_mpd(pd)->pdn);
+	MLX5_SET(mkc, mkc, pd, to_mpd(ibmw->pd)->pdn);
 	MLX5_SET(mkc, mkc, umr_en, 1);
 	MLX5_SET(mkc, mkc, lr, 1);
 	MLX5_SET(mkc, mkc, access_mode_1_0, MLX5_MKC_ACCESS_MODE_KLMS);
-	MLX5_SET(mkc, mkc, en_rinval, !!((type == IB_MW_TYPE_2)));
+	MLX5_SET(mkc, mkc, en_rinval, !!((ibmw->type == IB_MW_TYPE_2)));
 	MLX5_SET(mkc, mkc, qpn, 0xffffff);
 
-	err = mlx5_core_create_mkey(dev->mdev, &mw->mmkey, in, inlen);
+	err = mlx5_ib_create_mkey(dev, &mw->mmkey, in, inlen);
 	if (err)
 		goto free;
 
 	mw->mmkey.type = MLX5_MKEY_MW;
-	mw->ibmw.rkey = mw->mmkey.key;
+	ibmw->rkey = mw->mmkey.key;
 	mw->ndescs = ndescs;
 
-	resp.response_length = min(offsetof(typeof(resp), response_length) +
-				   sizeof(resp.response_length), udata->outlen);
+	resp.response_length =
+		min(offsetofend(typeof(resp), response_length), udata->outlen);
 	if (resp.response_length) {
 		err = ib_copy_to_udata(udata, &resp, resp.response_length);
-		if (err) {
-			mlx5_core_destroy_mkey(dev->mdev, &mw->mmkey);
-			goto free;
-		}
+		if (err)
+			goto free_mkey;
+	}
+
+	if (IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING)) {
+		err = xa_err(xa_store(&dev->odp_mkeys,
+				      mlx5_base_mkey(mw->mmkey.key), &mw->mmkey,
+				      GFP_KERNEL));
+		if (err)
+			goto free_mkey;
 	}
 
 	kfree(in);
-	return &mw->ibmw;
+	return 0;
 
+free_mkey:
+	mlx5_core_destroy_mkey(dev->mdev, &mw->mmkey);
 free:
-	kfree(mw);
 	kfree(in);
-	return ERR_PTR(err);
+	return err;
 }
 
 int mlx5_ib_dealloc_mw(struct ib_mw *mw)
 {
+	struct mlx5_ib_dev *dev = to_mdev(mw->device);
 	struct mlx5_ib_mw *mmw = to_mmw(mw);
-	int err;
 
-	err =  mlx5_core_destroy_mkey((to_mdev(mw->device))->mdev,
-				      &mmw->mmkey);
-	if (!err)
-		kfree(mmw);
-	return err;
+	if (IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING)) {
+		xa_erase(&dev->odp_mkeys, mlx5_base_mkey(mmw->mmkey.key));
+		/*
+		 * pagefault_single_data_segment() may be accessing mmw under
+		 * SRCU if the user bound an ODP MR to this MW.
+		 */
+		synchronize_srcu(&dev->odp_srcu);
+	}
+
+	return mlx5_core_destroy_mkey(dev->mdev, &mmw->mmkey);
 }
 
 int mlx5_ib_check_mr_status(struct ib_mr *ibmr, u32 check_mask,
@@ -1912,16 +2128,53 @@
 }
 
 static int
+mlx5_ib_map_pa_mr_sg_pi(struct ib_mr *ibmr, struct scatterlist *data_sg,
+			int data_sg_nents, unsigned int *data_sg_offset,
+			struct scatterlist *meta_sg, int meta_sg_nents,
+			unsigned int *meta_sg_offset)
+{
+	struct mlx5_ib_mr *mr = to_mmr(ibmr);
+	unsigned int sg_offset = 0;
+	int n = 0;
+
+	mr->meta_length = 0;
+	if (data_sg_nents == 1) {
+		n++;
+		mr->ndescs = 1;
+		if (data_sg_offset)
+			sg_offset = *data_sg_offset;
+		mr->data_length = sg_dma_len(data_sg) - sg_offset;
+		mr->data_iova = sg_dma_address(data_sg) + sg_offset;
+		if (meta_sg_nents == 1) {
+			n++;
+			mr->meta_ndescs = 1;
+			if (meta_sg_offset)
+				sg_offset = *meta_sg_offset;
+			else
+				sg_offset = 0;
+			mr->meta_length = sg_dma_len(meta_sg) - sg_offset;
+			mr->pi_iova = sg_dma_address(meta_sg) + sg_offset;
+		}
+		ibmr->length = mr->data_length + mr->meta_length;
+	}
+
+	return n;
+}
+
+static int
 mlx5_ib_sg_to_klms(struct mlx5_ib_mr *mr,
 		   struct scatterlist *sgl,
 		   unsigned short sg_nents,
-		   unsigned int *sg_offset_p)
+		   unsigned int *sg_offset_p,
+		   struct scatterlist *meta_sgl,
+		   unsigned short meta_sg_nents,
+		   unsigned int *meta_sg_offset_p)
 {
 	struct scatterlist *sg = sgl;
 	struct mlx5_klm *klms = mr->descs;
 	unsigned int sg_offset = sg_offset_p ? *sg_offset_p : 0;
 	u32 lkey = mr->ibmr.pd->local_dma_lkey;
-	int i;
+	int i, j = 0;
 
 	mr->ibmr.iova = sg_dma_address(sg) + sg_offset;
 	mr->ibmr.length = 0;
@@ -1936,12 +2189,36 @@
 
 		sg_offset = 0;
 	}
-	mr->ndescs = i;
 
 	if (sg_offset_p)
 		*sg_offset_p = sg_offset;
 
-	return i;
+	mr->ndescs = i;
+	mr->data_length = mr->ibmr.length;
+
+	if (meta_sg_nents) {
+		sg = meta_sgl;
+		sg_offset = meta_sg_offset_p ? *meta_sg_offset_p : 0;
+		for_each_sg(meta_sgl, sg, meta_sg_nents, j) {
+			if (unlikely(i + j >= mr->max_descs))
+				break;
+			klms[i + j].va = cpu_to_be64(sg_dma_address(sg) +
+						     sg_offset);
+			klms[i + j].bcount = cpu_to_be32(sg_dma_len(sg) -
+							 sg_offset);
+			klms[i + j].key = cpu_to_be32(lkey);
+			mr->ibmr.length += sg_dma_len(sg) - sg_offset;
+
+			sg_offset = 0;
+		}
+		if (meta_sg_offset_p)
+			*meta_sg_offset_p = sg_offset;
+
+		mr->meta_ndescs = j;
+		mr->meta_length = mr->ibmr.length - mr->data_length;
+	}
+
+	return i + j;
 }
 
 static int mlx5_set_page(struct ib_mr *ibmr, u64 addr)
@@ -1954,6 +2231,181 @@
 
 	descs = mr->descs;
 	descs[mr->ndescs++] = cpu_to_be64(addr | MLX5_EN_RD | MLX5_EN_WR);
+
+	return 0;
+}
+
+static int mlx5_set_page_pi(struct ib_mr *ibmr, u64 addr)
+{
+	struct mlx5_ib_mr *mr = to_mmr(ibmr);
+	__be64 *descs;
+
+	if (unlikely(mr->ndescs + mr->meta_ndescs == mr->max_descs))
+		return -ENOMEM;
+
+	descs = mr->descs;
+	descs[mr->ndescs + mr->meta_ndescs++] =
+		cpu_to_be64(addr | MLX5_EN_RD | MLX5_EN_WR);
+
+	return 0;
+}
+
+static int
+mlx5_ib_map_mtt_mr_sg_pi(struct ib_mr *ibmr, struct scatterlist *data_sg,
+			 int data_sg_nents, unsigned int *data_sg_offset,
+			 struct scatterlist *meta_sg, int meta_sg_nents,
+			 unsigned int *meta_sg_offset)
+{
+	struct mlx5_ib_mr *mr = to_mmr(ibmr);
+	struct mlx5_ib_mr *pi_mr = mr->mtt_mr;
+	int n;
+
+	pi_mr->ndescs = 0;
+	pi_mr->meta_ndescs = 0;
+	pi_mr->meta_length = 0;
+
+	ib_dma_sync_single_for_cpu(ibmr->device, pi_mr->desc_map,
+				   pi_mr->desc_size * pi_mr->max_descs,
+				   DMA_TO_DEVICE);
+
+	pi_mr->ibmr.page_size = ibmr->page_size;
+	n = ib_sg_to_pages(&pi_mr->ibmr, data_sg, data_sg_nents, data_sg_offset,
+			   mlx5_set_page);
+	if (n != data_sg_nents)
+		return n;
+
+	pi_mr->data_iova = pi_mr->ibmr.iova;
+	pi_mr->data_length = pi_mr->ibmr.length;
+	pi_mr->ibmr.length = pi_mr->data_length;
+	ibmr->length = pi_mr->data_length;
+
+	if (meta_sg_nents) {
+		u64 page_mask = ~((u64)ibmr->page_size - 1);
+		u64 iova = pi_mr->data_iova;
+
+		n += ib_sg_to_pages(&pi_mr->ibmr, meta_sg, meta_sg_nents,
+				    meta_sg_offset, mlx5_set_page_pi);
+
+		pi_mr->meta_length = pi_mr->ibmr.length;
+		/*
+		 * PI address for the HW is the offset of the metadata address
+		 * relative to the first data page address.
+		 * It equals to first data page address + size of data pages +
+		 * metadata offset at the first metadata page
+		 */
+		pi_mr->pi_iova = (iova & page_mask) +
+				 pi_mr->ndescs * ibmr->page_size +
+				 (pi_mr->ibmr.iova & ~page_mask);
+		/*
+		 * In order to use one MTT MR for data and metadata, we register
+		 * also the gaps between the end of the data and the start of
+		 * the metadata (the sig MR will verify that the HW will access
+		 * to right addresses). This mapping is safe because we use
+		 * internal mkey for the registration.
+		 */
+		pi_mr->ibmr.length = pi_mr->pi_iova + pi_mr->meta_length - iova;
+		pi_mr->ibmr.iova = iova;
+		ibmr->length += pi_mr->meta_length;
+	}
+
+	ib_dma_sync_single_for_device(ibmr->device, pi_mr->desc_map,
+				      pi_mr->desc_size * pi_mr->max_descs,
+				      DMA_TO_DEVICE);
+
+	return n;
+}
+
+static int
+mlx5_ib_map_klm_mr_sg_pi(struct ib_mr *ibmr, struct scatterlist *data_sg,
+			 int data_sg_nents, unsigned int *data_sg_offset,
+			 struct scatterlist *meta_sg, int meta_sg_nents,
+			 unsigned int *meta_sg_offset)
+{
+	struct mlx5_ib_mr *mr = to_mmr(ibmr);
+	struct mlx5_ib_mr *pi_mr = mr->klm_mr;
+	int n;
+
+	pi_mr->ndescs = 0;
+	pi_mr->meta_ndescs = 0;
+	pi_mr->meta_length = 0;
+
+	ib_dma_sync_single_for_cpu(ibmr->device, pi_mr->desc_map,
+				   pi_mr->desc_size * pi_mr->max_descs,
+				   DMA_TO_DEVICE);
+
+	n = mlx5_ib_sg_to_klms(pi_mr, data_sg, data_sg_nents, data_sg_offset,
+			       meta_sg, meta_sg_nents, meta_sg_offset);
+
+	ib_dma_sync_single_for_device(ibmr->device, pi_mr->desc_map,
+				      pi_mr->desc_size * pi_mr->max_descs,
+				      DMA_TO_DEVICE);
+
+	/* This is zero-based memory region */
+	pi_mr->data_iova = 0;
+	pi_mr->ibmr.iova = 0;
+	pi_mr->pi_iova = pi_mr->data_length;
+	ibmr->length = pi_mr->ibmr.length;
+
+	return n;
+}
+
+int mlx5_ib_map_mr_sg_pi(struct ib_mr *ibmr, struct scatterlist *data_sg,
+			 int data_sg_nents, unsigned int *data_sg_offset,
+			 struct scatterlist *meta_sg, int meta_sg_nents,
+			 unsigned int *meta_sg_offset)
+{
+	struct mlx5_ib_mr *mr = to_mmr(ibmr);
+	struct mlx5_ib_mr *pi_mr = NULL;
+	int n;
+
+	WARN_ON(ibmr->type != IB_MR_TYPE_INTEGRITY);
+
+	mr->ndescs = 0;
+	mr->data_length = 0;
+	mr->data_iova = 0;
+	mr->meta_ndescs = 0;
+	mr->pi_iova = 0;
+	/*
+	 * As a performance optimization, if possible, there is no need to
+	 * perform UMR operation to register the data/metadata buffers.
+	 * First try to map the sg lists to PA descriptors with local_dma_lkey.
+	 * Fallback to UMR only in case of a failure.
+	 */
+	n = mlx5_ib_map_pa_mr_sg_pi(ibmr, data_sg, data_sg_nents,
+				    data_sg_offset, meta_sg, meta_sg_nents,
+				    meta_sg_offset);
+	if (n == data_sg_nents + meta_sg_nents)
+		goto out;
+	/*
+	 * As a performance optimization, if possible, there is no need to map
+	 * the sg lists to KLM descriptors. First try to map the sg lists to MTT
+	 * descriptors and fallback to KLM only in case of a failure.
+	 * It's more efficient for the HW to work with MTT descriptors
+	 * (especially in high load).
+	 * Use KLM (indirect access) only if it's mandatory.
+	 */
+	pi_mr = mr->mtt_mr;
+	n = mlx5_ib_map_mtt_mr_sg_pi(ibmr, data_sg, data_sg_nents,
+				     data_sg_offset, meta_sg, meta_sg_nents,
+				     meta_sg_offset);
+	if (n == data_sg_nents + meta_sg_nents)
+		goto out;
+
+	pi_mr = mr->klm_mr;
+	n = mlx5_ib_map_klm_mr_sg_pi(ibmr, data_sg, data_sg_nents,
+				     data_sg_offset, meta_sg, meta_sg_nents,
+				     meta_sg_offset);
+	if (unlikely(n != data_sg_nents + meta_sg_nents))
+		return -ENOMEM;
+
+out:
+	/* This is zero-based memory region */
+	ibmr->iova = 0;
+	mr->pi_mr = pi_mr;
+	if (pi_mr)
+		ibmr->sig_attrs->meta_length = pi_mr->meta_length;
+	else
+		ibmr->sig_attrs->meta_length = mr->meta_length;
 
 	return 0;
 }
@@ -1971,7 +2423,8 @@
 				   DMA_TO_DEVICE);
 
 	if (mr->access_mode == MLX5_MKC_ACCESS_MODE_KLMS)
-		n = mlx5_ib_sg_to_klms(mr, sg, sg_nents, sg_offset);
+		n = mlx5_ib_sg_to_klms(mr, sg, sg_nents, sg_offset, NULL, 0,
+				       NULL);
 	else
 		n = ib_sg_to_pages(ibmr, sg, sg_nents, sg_offset,
 				mlx5_set_page);

--
Gitblit v1.6.2