forked from ~ljy/RK356X_SDK_RELEASE

hc
2024-10-09 244b2c5ca8b14627e4a17755e5922221e121c771
kernel/drivers/infiniband/core/umem_odp.c
@@ -39,598 +39,307 @@
3939 #include <linux/export.h>
4040 #include <linux/vmalloc.h>
4141 #include <linux/hugetlb.h>
42
-#include <linux/interval_tree_generic.h>
42
+#include <linux/interval_tree.h>
43
+#include <linux/hmm.h>
44
+#include <linux/pagemap.h>
4345
4446 #include <rdma/ib_verbs.h>
4547 #include <rdma/ib_umem.h>
4648 #include <rdma/ib_umem_odp.h>
4749
48
-/*
49
- * The ib_umem list keeps track of memory regions for which the HW
50
- * device request to receive notification when the related memory
51
- * mapping is changed.
52
- *
53
- * ib_umem_lock protects the list.
54
- */
50
+#include "uverbs.h"
5551
56
-static u64 node_start(struct umem_odp_node *n)
52
+static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
53
+ const struct mmu_interval_notifier_ops *ops)
5754 {
58
- struct ib_umem_odp *umem_odp =
59
- container_of(n, struct ib_umem_odp, interval_tree);
60
-
61
- return ib_umem_start(umem_odp->umem);
62
-}
63
-
64
-/* Note that the representation of the intervals in the interval tree
65
- * considers the ending point as contained in the interval, while the
66
- * function ib_umem_end returns the first address which is not contained
67
- * in the umem.
68
- */
69
-static u64 node_last(struct umem_odp_node *n)
70
-{
71
- struct ib_umem_odp *umem_odp =
72
- container_of(n, struct ib_umem_odp, interval_tree);
73
-
74
- return ib_umem_end(umem_odp->umem) - 1;
75
-}
76
-
77
-INTERVAL_TREE_DEFINE(struct umem_odp_node, rb, u64, __subtree_last,
78
- node_start, node_last, static, rbt_ib_umem)
79
-
80
-static void ib_umem_notifier_start_account(struct ib_umem *item)
81
-{
82
- mutex_lock(&item->odp_data->umem_mutex);
83
-
84
- /* Only update private counters for this umem if it has them.
85
- * Otherwise skip it. All page faults will be delayed for this umem. */
86
- if (item->odp_data->mn_counters_active) {
87
- int notifiers_count = item->odp_data->notifiers_count++;
88
-
89
- if (notifiers_count == 0)
90
- /* Initialize the completion object for waiting on
91
- * notifiers. Since notifier_count is zero, no one
92
- * should be waiting right now. */
93
- reinit_completion(&item->odp_data->notifier_completion);
94
- }
95
- mutex_unlock(&item->odp_data->umem_mutex);
96
-}
97
-
98
-static void ib_umem_notifier_end_account(struct ib_umem *item)
99
-{
100
- mutex_lock(&item->odp_data->umem_mutex);
101
-
102
- /* Only update private counters for this umem if it has them.
103
- * Otherwise skip it. All page faults will be delayed for this umem. */
104
- if (item->odp_data->mn_counters_active) {
105
- /*
106
- * This sequence increase will notify the QP page fault that
107
- * the page that is going to be mapped in the spte could have
108
- * been freed.
109
- */
110
- ++item->odp_data->notifiers_seq;
111
- if (--item->odp_data->notifiers_count == 0)
112
- complete_all(&item->odp_data->notifier_completion);
113
- }
114
- mutex_unlock(&item->odp_data->umem_mutex);
115
-}
116
-
117
-/* Account for a new mmu notifier in an ib_ucontext. */
118
-static void ib_ucontext_notifier_start_account(struct ib_ucontext *context)
119
-{
120
- atomic_inc(&context->notifier_count);
121
-}
122
-
123
-/* Account for a terminating mmu notifier in an ib_ucontext.
124
- *
125
- * Must be called with the ib_ucontext->umem_rwsem semaphore unlocked, since
126
- * the function takes the semaphore itself. */
127
-static void ib_ucontext_notifier_end_account(struct ib_ucontext *context)
128
-{
129
- int zero_notifiers = atomic_dec_and_test(&context->notifier_count);
130
-
131
- if (zero_notifiers &&
132
- !list_empty(&context->no_private_counters)) {
133
- /* No currently running mmu notifiers. Now is the chance to
134
- * add private accounting to all previously added umems. */
135
- struct ib_umem_odp *odp_data, *next;
136
-
137
- /* Prevent concurrent mmu notifiers from working on the
138
- * no_private_counters list. */
139
- down_write(&context->umem_rwsem);
140
-
141
- /* Read the notifier_count again, with the umem_rwsem
142
- * semaphore taken for write. */
143
- if (!atomic_read(&context->notifier_count)) {
144
- list_for_each_entry_safe(odp_data, next,
145
- &context->no_private_counters,
146
- no_private_counters) {
147
- mutex_lock(&odp_data->umem_mutex);
148
- odp_data->mn_counters_active = true;
149
- list_del(&odp_data->no_private_counters);
150
- complete_all(&odp_data->notifier_completion);
151
- mutex_unlock(&odp_data->umem_mutex);
152
- }
153
- }
154
-
155
- up_write(&context->umem_rwsem);
156
- }
157
-}
158
-
159
-static int ib_umem_notifier_release_trampoline(struct ib_umem *item, u64 start,
160
- u64 end, void *cookie) {
161
- /*
162
- * Increase the number of notifiers running, to
163
- * prevent any further fault handling on this MR.
164
- */
165
- ib_umem_notifier_start_account(item);
166
- item->odp_data->dying = 1;
167
- /* Make sure that the fact the umem is dying is out before we release
168
- * all pending page faults. */
169
- smp_wmb();
170
- complete_all(&item->odp_data->notifier_completion);
171
- item->context->invalidate_range(item, ib_umem_start(item),
172
- ib_umem_end(item));
173
- return 0;
174
-}
175
-
176
-static void ib_umem_notifier_release(struct mmu_notifier *mn,
177
- struct mm_struct *mm)
178
-{
179
- struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
180
-
181
- if (!context->invalidate_range)
182
- return;
183
-
184
- ib_ucontext_notifier_start_account(context);
185
- down_read(&context->umem_rwsem);
186
- rbt_ib_umem_for_each_in_range(&context->umem_tree, 0,
187
- ULLONG_MAX,
188
- ib_umem_notifier_release_trampoline,
189
- true,
190
- NULL);
191
- up_read(&context->umem_rwsem);
192
-}
193
-
194
-static int invalidate_page_trampoline(struct ib_umem *item, u64 start,
195
- u64 end, void *cookie)
196
-{
197
- ib_umem_notifier_start_account(item);
198
- item->context->invalidate_range(item, start, start + PAGE_SIZE);
199
- ib_umem_notifier_end_account(item);
200
- return 0;
201
-}
202
-
203
-static int invalidate_range_start_trampoline(struct ib_umem *item, u64 start,
204
- u64 end, void *cookie)
205
-{
206
- ib_umem_notifier_start_account(item);
207
- item->context->invalidate_range(item, start, end);
208
- return 0;
209
-}
210
-
211
-static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
212
- struct mm_struct *mm,
213
- unsigned long start,
214
- unsigned long end,
215
- bool blockable)
216
-{
217
- struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
21855 int ret;
21956
220
- if (!context->invalidate_range)
221
- return 0;
57
+ umem_odp->umem.is_odp = 1;
58
+ mutex_init(&umem_odp->umem_mutex);
22259
223
- if (blockable)
224
- down_read(&context->umem_rwsem);
225
- else if (!down_read_trylock(&context->umem_rwsem))
226
- return -EAGAIN;
60
+ if (!umem_odp->is_implicit_odp) {
61
+ size_t page_size = 1UL << umem_odp->page_shift;
62
+ unsigned long start;
63
+ unsigned long end;
64
+ size_t ndmas, npfns;
22765
228
- ib_ucontext_notifier_start_account(context);
229
- ret = rbt_ib_umem_for_each_in_range(&context->umem_tree, start,
230
- end,
231
- invalidate_range_start_trampoline,
232
- blockable, NULL);
233
- up_read(&context->umem_rwsem);
66
+ start = ALIGN_DOWN(umem_odp->umem.address, page_size);
67
+ if (check_add_overflow(umem_odp->umem.address,
68
+ (unsigned long)umem_odp->umem.length,
69
+ &end))
70
+ return -EOVERFLOW;
71
+ end = ALIGN(end, page_size);
72
+ if (unlikely(end < page_size))
73
+ return -EOVERFLOW;
23474
75
+ ndmas = (end - start) >> umem_odp->page_shift;
76
+ if (!ndmas)
77
+ return -EINVAL;
78
+
79
+ npfns = (end - start) >> PAGE_SHIFT;
80
+ umem_odp->pfn_list = kvcalloc(
81
+ npfns, sizeof(*umem_odp->pfn_list), GFP_KERNEL);
82
+ if (!umem_odp->pfn_list)
83
+ return -ENOMEM;
84
+
85
+ umem_odp->dma_list = kvcalloc(
86
+ ndmas, sizeof(*umem_odp->dma_list), GFP_KERNEL);
87
+ if (!umem_odp->dma_list) {
88
+ ret = -ENOMEM;
89
+ goto out_pfn_list;
90
+ }
91
+
92
+ ret = mmu_interval_notifier_insert(&umem_odp->notifier,
93
+ umem_odp->umem.owning_mm,
94
+ start, end - start, ops);
95
+ if (ret)
96
+ goto out_dma_list;
97
+ }
98
+
99
+ return 0;
100
+
101
+out_dma_list:
102
+ kvfree(umem_odp->dma_list);
103
+out_pfn_list:
104
+ kvfree(umem_odp->pfn_list);
235105 return ret;
236106 }
237107
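As a side note on the sizing logic above: pfn_list is sized per CPU page (PAGE_SHIFT) while dma_list is sized per device page (umem_odp->page_shift), so one dma_list entry can cover many pfn_list entries when ODP uses larger pages. A minimal, self-contained sketch of that arithmetic (the 2 MiB page_shift and the addresses are assumed example values, not taken from this patch):

#include <stdio.h>

int main(void)
{
	unsigned long address = 0x7f0000001000UL;   /* assumed user VA          */
	unsigned long length  = 6UL << 20;          /* assumed 6 MiB MR         */
	unsigned int  page_shift = 21;              /* assumed 2 MiB ODP pages  */
	unsigned long page_size  = 1UL << page_shift;
	const unsigned int CPU_PAGE_SHIFT = 12;     /* PAGE_SHIFT on most arches */

	unsigned long start = address & ~(page_size - 1);                          /* ALIGN_DOWN */
	unsigned long end = (address + length + page_size - 1) & ~(page_size - 1); /* ALIGN      */

	unsigned long ndmas = (end - start) >> page_shift;      /* dma_list entries */
	unsigned long npfns = (end - start) >> CPU_PAGE_SHIFT;  /* pfn_list entries */

	printf("ndmas=%lu npfns=%lu pfns-per-dma=%lu\n", ndmas, npfns, npfns / ndmas);
	return 0;
}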
238
-static int invalidate_range_end_trampoline(struct ib_umem *item, u64 start,
239
- u64 end, void *cookie)
240
-{
241
- ib_umem_notifier_end_account(item);
242
- return 0;
243
-}
244
-
245
-static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
246
- struct mm_struct *mm,
247
- unsigned long start,
248
- unsigned long end)
249
-{
250
- struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
251
-
252
- if (!context->invalidate_range)
253
- return;
254
-
255
- /*
256
- * TODO: we currently bail out if there is any sleepable work to be done
257
- * in ib_umem_notifier_invalidate_range_start so we shouldn't really block
258
- * here. But this is ugly and fragile.
259
- */
260
- down_read(&context->umem_rwsem);
261
- rbt_ib_umem_for_each_in_range(&context->umem_tree, start,
262
- end,
263
- invalidate_range_end_trampoline, true, NULL);
264
- up_read(&context->umem_rwsem);
265
- ib_ucontext_notifier_end_account(context);
266
-}
267
-
268
-static const struct mmu_notifier_ops ib_umem_notifiers = {
269
- .release = ib_umem_notifier_release,
270
- .invalidate_range_start = ib_umem_notifier_invalidate_range_start,
271
- .invalidate_range_end = ib_umem_notifier_invalidate_range_end,
272
-};
273
-
274
-struct ib_umem *ib_alloc_odp_umem(struct ib_ucontext *context,
275
- unsigned long addr,
276
- size_t size)
108
+/**
109
+ * ib_umem_odp_alloc_implicit - Allocate a parent implicit ODP umem
110
+ *
111
+ * Implicit ODP umems do not have a VA range and do not have any page lists.
112
+ * They exist only to hold the per_mm reference to help the driver create
113
+ * children umems.
114
+ *
115
+ * @device: IB device to create UMEM
116
+ * @access: ib_reg_mr access flags
117
+ */
118
+struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_device *device,
119
+ int access)
277120 {
278121 struct ib_umem *umem;
279
- struct ib_umem_odp *odp_data;
280
- int pages = size >> PAGE_SHIFT;
122
+ struct ib_umem_odp *umem_odp;
281123 int ret;
282124
283
- umem = kzalloc(sizeof(*umem), GFP_KERNEL);
284
- if (!umem)
285
- return ERR_PTR(-ENOMEM);
125
+ if (access & IB_ACCESS_HUGETLB)
126
+ return ERR_PTR(-EINVAL);
286127
287
- umem->context = context;
288
- umem->length = size;
289
- umem->address = addr;
290
- umem->page_shift = PAGE_SHIFT;
291
- umem->writable = 1;
128
+ umem_odp = kzalloc(sizeof(*umem_odp), GFP_KERNEL);
129
+ if (!umem_odp)
130
+ return ERR_PTR(-ENOMEM);
131
+ umem = &umem_odp->umem;
132
+ umem->ibdev = device;
133
+ umem->writable = ib_access_writable(access);
134
+ umem->owning_mm = current->mm;
135
+ umem_odp->is_implicit_odp = 1;
136
+ umem_odp->page_shift = PAGE_SHIFT;
137
+
138
+ umem_odp->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
139
+ ret = ib_init_umem_odp(umem_odp, NULL);
140
+ if (ret) {
141
+ put_pid(umem_odp->tgid);
142
+ kfree(umem_odp);
143
+ return ERR_PTR(ret);
144
+ }
145
+ return umem_odp;
146
+}
147
+EXPORT_SYMBOL(ib_umem_odp_alloc_implicit);
148
+
149
+/**
150
+ * ib_umem_odp_alloc_child - Allocate a child ODP umem under an implicit
151
+ * parent ODP umem
152
+ *
153
+ * @root: The parent umem enclosing the child. This must be allocated using
154
+ * ib_umem_odp_alloc_implicit()
155
+ * @addr: The starting userspace VA
156
+ * @size: The length of the userspace VA
157
+ * @ops: MMU interval ops, currently only @invalidate
158
+ */
159
+struct ib_umem_odp *
160
+ib_umem_odp_alloc_child(struct ib_umem_odp *root, unsigned long addr,
161
+ size_t size,
162
+ const struct mmu_interval_notifier_ops *ops)
163
+{
164
+ /*
165
+ * Caller must ensure that root cannot be freed during the call to
166
+ * ib_umem_odp_alloc_child().
167
+ */
168
+ struct ib_umem_odp *odp_data;
169
+ struct ib_umem *umem;
170
+ int ret;
171
+
172
+ if (WARN_ON(!root->is_implicit_odp))
173
+ return ERR_PTR(-EINVAL);
292174
293175 odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
294
- if (!odp_data) {
295
- ret = -ENOMEM;
296
- goto out_umem;
297
- }
298
- odp_data->umem = umem;
176
+ if (!odp_data)
177
+ return ERR_PTR(-ENOMEM);
178
+ umem = &odp_data->umem;
179
+ umem->ibdev = root->umem.ibdev;
180
+ umem->length = size;
181
+ umem->address = addr;
182
+ umem->writable = root->umem.writable;
183
+ umem->owning_mm = root->umem.owning_mm;
184
+ odp_data->page_shift = PAGE_SHIFT;
185
+ odp_data->notifier.ops = ops;
299186
300
- mutex_init(&odp_data->umem_mutex);
301
- init_completion(&odp_data->notifier_completion);
302
-
303
- odp_data->page_list =
304
- vzalloc(array_size(pages, sizeof(*odp_data->page_list)));
305
- if (!odp_data->page_list) {
306
- ret = -ENOMEM;
307
- goto out_odp_data;
187
+ /*
188
+ * An mmget must be held when registering a notifier; the owning_mm only
189
+ * has a mm_grab at this point.
190
+ */
191
+ if (!mmget_not_zero(umem->owning_mm)) {
192
+ ret = -EFAULT;
193
+ goto out_free;
308194 }
309195
310
- odp_data->dma_list =
311
- vzalloc(array_size(pages, sizeof(*odp_data->dma_list)));
312
- if (!odp_data->dma_list) {
313
- ret = -ENOMEM;
314
- goto out_page_list;
315
- }
196
+ odp_data->tgid = get_pid(root->tgid);
197
+ ret = ib_init_umem_odp(odp_data, ops);
198
+ if (ret)
199
+ goto out_tgid;
200
+ mmput(umem->owning_mm);
201
+ return odp_data;
316202
317
- down_write(&context->umem_rwsem);
318
- context->odp_mrs_count++;
319
- rbt_ib_umem_insert(&odp_data->interval_tree, &context->umem_tree);
320
- if (likely(!atomic_read(&context->notifier_count)))
321
- odp_data->mn_counters_active = true;
322
- else
323
- list_add(&odp_data->no_private_counters,
324
- &context->no_private_counters);
325
- up_write(&context->umem_rwsem);
326
-
327
- umem->odp_data = odp_data;
328
-
329
- return umem;
330
-
331
-out_page_list:
332
- vfree(odp_data->page_list);
333
-out_odp_data:
203
+out_tgid:
204
+ put_pid(odp_data->tgid);
205
+ mmput(umem->owning_mm);
206
+out_free:
334207 kfree(odp_data);
335
-out_umem:
336
- kfree(umem);
337208 return ERR_PTR(ret);
338209 }
339
-EXPORT_SYMBOL(ib_alloc_odp_umem);
210
+EXPORT_SYMBOL(ib_umem_odp_alloc_child);
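For orientation, a rough driver-side sketch of how the implicit parent and its children might be used together; my_drv_get_child, the 1 GiB slice size and my_umem_ops are assumptions for illustration (my_umem_ops, with its .invalidate callback, is sketched after ib_umem_odp_unmap_dma_pages() near the end of this hunk):

static struct ib_umem_odp *my_drv_get_child(struct ib_device *ibdev,
					    struct ib_umem_odp **rootp,
					    unsigned long fault_addr,
					    int access)
{
	const unsigned long slice = 1UL << 30;	/* assumed per-child VA slice */

	/* Lazily create the implicit parent; it has no VA range of its own. */
	if (!*rootp) {
		*rootp = ib_umem_odp_alloc_implicit(ibdev, access);
		if (IS_ERR(*rootp))
			return *rootp;
	}

	/* The child inherits ibdev, writability and owning_mm from the root. */
	return ib_umem_odp_alloc_child(*rootp, fault_addr & ~(slice - 1),
				       slice, &my_umem_ops);
}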
340211
341
-int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,
342
- int access)
212
+/**
213
+ * ib_umem_odp_get - Create a umem_odp for a userspace va
214
+ *
215
+ * @device: IB device struct to get UMEM
216
+ * @addr: userspace virtual address to start at
217
+ * @size: length of region to pin
218
+ * @access: IB_ACCESS_xxx flags for memory being pinned
219
+ * @ops: MMU interval ops, currently only @invalidate
220
+ *
221
+ * The driver should use this when the access flags indicate ODP memory. It avoids
222
+ * pinning; instead, it stores the mm for future page fault handling in
223
+ * conjunction with MMU notifiers.
224
+ */
225
+struct ib_umem_odp *ib_umem_odp_get(struct ib_device *device,
226
+ unsigned long addr, size_t size, int access,
227
+ const struct mmu_interval_notifier_ops *ops)
343228 {
344
- int ret_val;
345
- struct pid *our_pid;
346
- struct mm_struct *mm = get_task_mm(current);
229
+ struct ib_umem_odp *umem_odp;
230
+ struct mm_struct *mm;
231
+ int ret;
347232
348
- if (!mm)
349
- return -EINVAL;
233
+ if (WARN_ON_ONCE(!(access & IB_ACCESS_ON_DEMAND)))
234
+ return ERR_PTR(-EINVAL);
350235
351
- if (access & IB_ACCESS_HUGETLB) {
352
- struct vm_area_struct *vma;
353
- struct hstate *h;
236
+ umem_odp = kzalloc(sizeof(struct ib_umem_odp), GFP_KERNEL);
237
+ if (!umem_odp)
238
+ return ERR_PTR(-ENOMEM);
354239
355
- down_read(&mm->mmap_sem);
356
- vma = find_vma(mm, ib_umem_start(umem));
357
- if (!vma || !is_vm_hugetlb_page(vma)) {
358
- up_read(&mm->mmap_sem);
359
- ret_val = -EINVAL;
360
- goto out_mm;
361
- }
362
- h = hstate_vma(vma);
363
- umem->page_shift = huge_page_shift(h);
364
- up_read(&mm->mmap_sem);
365
- umem->hugetlb = 1;
366
- } else {
367
- umem->hugetlb = 0;
368
- }
240
+ umem_odp->umem.ibdev = device;
241
+ umem_odp->umem.length = size;
242
+ umem_odp->umem.address = addr;
243
+ umem_odp->umem.writable = ib_access_writable(access);
244
+ umem_odp->umem.owning_mm = mm = current->mm;
245
+ umem_odp->notifier.ops = ops;
369246
370
- /* Prevent creating ODP MRs in child processes */
371
- rcu_read_lock();
372
- our_pid = get_task_pid(current->group_leader, PIDTYPE_PID);
373
- rcu_read_unlock();
374
- put_pid(our_pid);
375
- if (context->tgid != our_pid) {
376
- ret_val = -EINVAL;
377
- goto out_mm;
378
- }
247
+ umem_odp->page_shift = PAGE_SHIFT;
248
+#ifdef CONFIG_HUGETLB_PAGE
249
+ if (access & IB_ACCESS_HUGETLB)
250
+ umem_odp->page_shift = HPAGE_SHIFT;
251
+#endif
379252
380
- umem->odp_data = kzalloc(sizeof(*umem->odp_data), GFP_KERNEL);
381
- if (!umem->odp_data) {
382
- ret_val = -ENOMEM;
383
- goto out_mm;
384
- }
385
- umem->odp_data->umem = umem;
253
+ umem_odp->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
254
+ ret = ib_init_umem_odp(umem_odp, ops);
255
+ if (ret)
256
+ goto err_put_pid;
257
+ return umem_odp;
386258
387
- mutex_init(&umem->odp_data->umem_mutex);
388
-
389
- init_completion(&umem->odp_data->notifier_completion);
390
-
391
- if (ib_umem_num_pages(umem)) {
392
- umem->odp_data->page_list =
393
- vzalloc(array_size(sizeof(*umem->odp_data->page_list),
394
- ib_umem_num_pages(umem)));
395
- if (!umem->odp_data->page_list) {
396
- ret_val = -ENOMEM;
397
- goto out_odp_data;
398
- }
399
-
400
- umem->odp_data->dma_list =
401
- vzalloc(array_size(sizeof(*umem->odp_data->dma_list),
402
- ib_umem_num_pages(umem)));
403
- if (!umem->odp_data->dma_list) {
404
- ret_val = -ENOMEM;
405
- goto out_page_list;
406
- }
407
- }
408
-
409
- /*
410
- * When using MMU notifiers, we will get a
411
- * notification before the "current" task (and MM) is
412
- * destroyed. We use the umem_rwsem semaphore to synchronize.
413
- */
414
- down_write(&context->umem_rwsem);
415
- context->odp_mrs_count++;
416
- if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
417
- rbt_ib_umem_insert(&umem->odp_data->interval_tree,
418
- &context->umem_tree);
419
- if (likely(!atomic_read(&context->notifier_count)) ||
420
- context->odp_mrs_count == 1)
421
- umem->odp_data->mn_counters_active = true;
422
- else
423
- list_add(&umem->odp_data->no_private_counters,
424
- &context->no_private_counters);
425
- downgrade_write(&context->umem_rwsem);
426
-
427
- if (context->odp_mrs_count == 1) {
428
- /*
429
- * Note that at this point, no MMU notifier is running
430
- * for this context!
431
- */
432
- atomic_set(&context->notifier_count, 0);
433
- INIT_HLIST_NODE(&context->mn.hlist);
434
- context->mn.ops = &ib_umem_notifiers;
435
- /*
436
- * Lock-dep detects a false positive for mmap_sem vs.
437
- * umem_rwsem, due to not grasping downgrade_write correctly.
438
- */
439
- lockdep_off();
440
- ret_val = mmu_notifier_register(&context->mn, mm);
441
- lockdep_on();
442
- if (ret_val) {
443
- pr_err("Failed to register mmu_notifier %d\n", ret_val);
444
- ret_val = -EBUSY;
445
- goto out_mutex;
446
- }
447
- }
448
-
449
- up_read(&context->umem_rwsem);
450
-
451
- /*
452
- * Note that doing an mmput can cause a notifier for the relevant mm.
453
- * If the notifier is called while we hold the umem_rwsem, this will
454
- * cause a deadlock. Therefore, we release the reference only after we
455
- * released the semaphore.
456
- */
457
- mmput(mm);
458
- return 0;
459
-
460
-out_mutex:
461
- up_read(&context->umem_rwsem);
462
- vfree(umem->odp_data->dma_list);
463
-out_page_list:
464
- vfree(umem->odp_data->page_list);
465
-out_odp_data:
466
- kfree(umem->odp_data);
467
-out_mm:
468
- mmput(mm);
469
- return ret_val;
259
+err_put_pid:
260
+ put_pid(umem_odp->tgid);
261
+ kfree(umem_odp);
262
+ return ERR_PTR(ret);
470263 }
264
+EXPORT_SYMBOL(ib_umem_odp_get);
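A hedged sketch of the registration side this function is meant for; struct my_drv_mr, my_drv_reg_odp_mr and my_umem_ops are hypothetical driver names, not part of this patch (the ops struct is sketched after ib_umem_odp_unmap_dma_pages() below):

struct my_drv_mr {
	struct ib_umem_odp *umem_odp;
	/* ... device-specific translation state ... */
};

static int my_drv_reg_odp_mr(struct ib_device *ibdev, struct my_drv_mr *mr,
			     unsigned long start, size_t length, int access)
{
	struct ib_umem_odp *odp;

	if (!(access & IB_ACCESS_ON_DEMAND))
		return -EINVAL;

	/* Nothing is pinned here; pages are brought in later via
	 * ib_umem_odp_map_dma_and_lock() from the page-fault path. */
	odp = ib_umem_odp_get(ibdev, start, length, access, &my_umem_ops);
	if (IS_ERR(odp))
		return PTR_ERR(odp);

	mr->umem_odp = odp;
	return 0;
}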
471265
472
-void ib_umem_odp_release(struct ib_umem *umem)
266
+void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
473267 {
474
- struct ib_ucontext *context = umem->context;
475
-
476268 /*
477269 * Ensure that no more pages are mapped in the umem.
478270 *
479271 * It is the driver's responsibility to ensure, before calling us,
480272 * that the hardware will not attempt to access the MR any more.
481273 */
482
- ib_umem_odp_unmap_dma_pages(umem, ib_umem_start(umem),
483
- ib_umem_end(umem));
484
-
485
- down_write(&context->umem_rwsem);
486
- if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
487
- rbt_ib_umem_remove(&umem->odp_data->interval_tree,
488
- &context->umem_tree);
489
- context->odp_mrs_count--;
490
- if (!umem->odp_data->mn_counters_active) {
491
- list_del(&umem->odp_data->no_private_counters);
492
- complete_all(&umem->odp_data->notifier_completion);
274
+ if (!umem_odp->is_implicit_odp) {
275
+ mutex_lock(&umem_odp->umem_mutex);
276
+ ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp),
277
+ ib_umem_end(umem_odp));
278
+ mutex_unlock(&umem_odp->umem_mutex);
279
+ mmu_interval_notifier_remove(&umem_odp->notifier);
280
+ kvfree(umem_odp->dma_list);
281
+ kvfree(umem_odp->pfn_list);
493282 }
494
-
495
- /*
496
- * Downgrade the lock to a read lock. This ensures that the notifiers
497
- * (who lock the mutex for reading) will be able to finish, and we
498
- * will be able to enventually obtain the mmu notifiers SRCU. Note
499
- * that since we are doing it atomically, no other user could register
500
- * and unregister while we do the check.
501
- */
502
- downgrade_write(&context->umem_rwsem);
503
- if (!context->odp_mrs_count) {
504
- struct task_struct *owning_process = NULL;
505
- struct mm_struct *owning_mm = NULL;
506
-
507
- owning_process = get_pid_task(context->tgid,
508
- PIDTYPE_PID);
509
- if (owning_process == NULL)
510
- /*
511
- * The process is already dead, notifier were removed
512
- * already.
513
- */
514
- goto out;
515
-
516
- owning_mm = get_task_mm(owning_process);
517
- if (owning_mm == NULL)
518
- /*
519
- * The process' mm is already dead, notifier were
520
- * removed already.
521
- */
522
- goto out_put_task;
523
- mmu_notifier_unregister(&context->mn, owning_mm);
524
-
525
- mmput(owning_mm);
526
-
527
-out_put_task:
528
- put_task_struct(owning_process);
529
- }
530
-out:
531
- up_read(&context->umem_rwsem);
532
-
533
- vfree(umem->odp_data->dma_list);
534
- vfree(umem->odp_data->page_list);
535
- kfree(umem->odp_data);
536
- kfree(umem);
283
+ put_pid(umem_odp->tgid);
284
+ kfree(umem_odp);
537285 }
286
+EXPORT_SYMBOL(ib_umem_odp_release);
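To make the "driver's responsibility" comment above concrete, a minimal teardown-order sketch; my_drv_destroy_mkey and struct my_drv_mr are hypothetical names from the earlier registration sketch:

static void my_drv_dereg_odp_mr(struct my_drv_mr *mr)
{
	/* 1. Fence the HW first so the device can no longer DMA into the MR. */
	my_drv_destroy_mkey(mr);

	/* 2. Only then release the umem: it unmaps any remaining DMA pages
	 *    and removes the mmu interval notifier. */
	ib_umem_odp_release(mr->umem_odp);
}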
538287
539288 /*
540289 * Map for DMA and insert a single page into the on-demand paging page tables.
541290 *
542291 * @umem: the umem to insert the page to.
543
- * @page_index: index in the umem to add the page to.
292
+ * @dma_index: index in the umem to add the dma to.
544293 * @page: the page struct to map and add.
545294 * @access_mask: access permissions needed for this page.
546295 * @current_seq: sequence number for synchronization with invalidations.
547296 * the sequence number is taken from
548
- * umem->odp_data->notifiers_seq.
297
+ * umem_odp->notifiers_seq.
549298 *
550
- * The function returns -EFAULT if the DMA mapping operation fails. It returns
551
- * -EAGAIN if a concurrent invalidation prevents us from updating the page.
299
+ * The function returns -EFAULT if the DMA mapping operation fails.
552300 *
553
- * The page is released via put_page even if the operation failed. For
554
- * on-demand pinning, the page is released whenever it isn't stored in the
555
- * umem.
556301 */
557302 static int ib_umem_odp_map_dma_single_page(
558
- struct ib_umem *umem,
559
- int page_index,
303
+ struct ib_umem_odp *umem_odp,
304
+ unsigned int dma_index,
560305 struct page *page,
561
- u64 access_mask,
562
- unsigned long current_seq)
306
+ u64 access_mask)
563307 {
564
- struct ib_device *dev = umem->context->device;
565
- dma_addr_t dma_addr;
566
- int stored_page = 0;
567
- int remove_existing_mapping = 0;
568
- int ret = 0;
308
+ struct ib_device *dev = umem_odp->umem.ibdev;
309
+ dma_addr_t *dma_addr = &umem_odp->dma_list[dma_index];
569310
570
- /*
571
- * Note: we avoid writing if seq is different from the initial seq, to
572
- * handle case of a racing notifier. This check also allows us to bail
573
- * early if we have a notifier running in parallel with us.
574
- */
575
- if (ib_umem_mmu_notifier_retry(umem, current_seq)) {
576
- ret = -EAGAIN;
577
- goto out;
578
- }
579
- if (!(umem->odp_data->dma_list[page_index])) {
580
- dma_addr = ib_dma_map_page(dev,
581
- page,
582
- 0, BIT(umem->page_shift),
583
- DMA_BIDIRECTIONAL);
584
- if (ib_dma_mapping_error(dev, dma_addr)) {
585
- ret = -EFAULT;
586
- goto out;
587
- }
588
- umem->odp_data->dma_list[page_index] = dma_addr | access_mask;
589
- umem->odp_data->page_list[page_index] = page;
590
- umem->npages++;
591
- stored_page = 1;
592
- } else if (umem->odp_data->page_list[page_index] == page) {
593
- umem->odp_data->dma_list[page_index] |= access_mask;
594
- } else {
595
- pr_err("error: got different pages in IB device and from get_user_pages. IB device page: %p, gup page: %p\n",
596
- umem->odp_data->page_list[page_index], page);
597
- /* Better remove the mapping now, to prevent any further
598
- * damage. */
599
- remove_existing_mapping = 1;
311
+ if (*dma_addr) {
312
+ /*
313
+ * If the page is already dma mapped it means it went through
314
+ * a non-invalidating transition, like read-only to writable.
315
+ * Resync the flags.
316
+ */
317
+ *dma_addr = (*dma_addr & ODP_DMA_ADDR_MASK) | access_mask;
318
+ return 0;
600319 }
601320
602
-out:
603
- /* On Demand Paging - avoid pinning the page */
604
- if (umem->context->invalidate_range || !stored_page)
605
- put_page(page);
606
-
607
- if (remove_existing_mapping && umem->context->invalidate_range) {
608
- invalidate_page_trampoline(
609
- umem,
610
- ib_umem_start(umem) + (page_index >> umem->page_shift),
611
- ib_umem_start(umem) + ((page_index + 1) >>
612
- umem->page_shift),
613
- NULL);
614
- ret = -EAGAIN;
321
+ *dma_addr = ib_dma_map_page(dev, page, 0, 1 << umem_odp->page_shift,
322
+ DMA_BIDIRECTIONAL);
323
+ if (ib_dma_mapping_error(dev, *dma_addr)) {
324
+ *dma_addr = 0;
325
+ return -EFAULT;
615326 }
616
-
617
- return ret;
327
+ umem_odp->npages++;
328
+ *dma_addr |= access_mask;
329
+ return 0;
618330 }
619331
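The dma_list entries manipulated above pack the ODP access flags into the low bits of the page-aligned DMA address. A small standalone illustration of that encoding (the mask and bit values are meant to mirror rdma/ib_umem_odp.h to the best of my reading; the sample address is made up):

#include <stdint.h>
#include <stdio.h>

#define ODP_READ_ALLOWED_BIT   (1ULL << 0)
#define ODP_WRITE_ALLOWED_BIT  (1ULL << 1)
#define ODP_DMA_ADDR_MASK      (~(ODP_READ_ALLOWED_BIT | ODP_WRITE_ALLOWED_BIT))

int main(void)
{
	uint64_t dma   = 0x123400000ULL;	/* assumed page-aligned mapping */
	uint64_t entry = dma | ODP_READ_ALLOWED_BIT | ODP_WRITE_ALLOWED_BIT;

	printf("addr=0x%llx read=%d write=%d\n",
	       (unsigned long long)(entry & ODP_DMA_ADDR_MASK),
	       !!(entry & ODP_READ_ALLOWED_BIT),
	       !!(entry & ODP_WRITE_ALLOWED_BIT));
	return 0;
}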
620332 /**
621
- * ib_umem_odp_map_dma_pages - Pin and DMA map userspace memory in an ODP MR.
333
+ * ib_umem_odp_map_dma_and_lock - DMA map userspace memory in an ODP MR and lock it.
622334 *
623
- * Pins the range of pages passed in the argument, and maps them to
624
- * DMA addresses. The DMA addresses of the mapped pages is updated in
625
- * umem->odp_data->dma_list.
335
+ * Maps the range passed in the argument to DMA addresses.
336
+ * The DMA addresses of the mapped pages are updated in umem_odp->dma_list.
337
+ * Upon success the ODP MR will be locked to let the caller complete its device
338
+ * page table update.
626339 *
627340 * Returns the number of pages mapped on success, negative error code
628341 * for failure.
629
- * An -EAGAIN error code is returned when a concurrent mmu notifier prevents
630
- * the function from completing its task.
631
- * An -ENOENT error code indicates that userspace process is being terminated
632
- * and mm was already destroyed.
633
- * @umem: the umem to map and pin
342
+ * @umem_odp: the umem to map and pin
634343 * @user_virt: the address from which we need to map.
635344 * @bcnt: the minimal number of bytes to pin and map. The mapping might be
636345 * bigger due to alignment, and may also be smaller in case of an error
@@ -638,150 +347,158 @@
638347 * the return value.
639348 * @access_mask: bit mask of the requested access permissions for the given
640349 * range.
641
- * @current_seq: the MMU notifiers sequence value for synchronization with
642
- * invalidations. The sequence number is read from
643
- * umem->odp_data->notifiers_seq before calling this function
350
+ * @fault: is faulting required for the given range
644351 */
645
-int ib_umem_odp_map_dma_pages(struct ib_umem *umem, u64 user_virt, u64 bcnt,
646
- u64 access_mask, unsigned long current_seq)
352
+int ib_umem_odp_map_dma_and_lock(struct ib_umem_odp *umem_odp, u64 user_virt,
353
+ u64 bcnt, u64 access_mask, bool fault)
354
+ __acquires(&umem_odp->umem_mutex)
647355 {
648356 struct task_struct *owning_process = NULL;
649
- struct mm_struct *owning_mm = NULL;
650
- struct page **local_page_list = NULL;
651
- u64 page_mask, off;
652
- int j, k, ret = 0, start_idx, npages = 0, page_shift;
653
- unsigned int flags = 0;
654
- phys_addr_t p = 0;
357
+ struct mm_struct *owning_mm = umem_odp->umem.owning_mm;
358
+ int pfn_index, dma_index, ret = 0, start_idx;
359
+ unsigned int page_shift, hmm_order, pfn_start_idx;
360
+ unsigned long num_pfns, current_seq;
361
+ struct hmm_range range = {};
362
+ unsigned long timeout;
655363
656364 if (access_mask == 0)
657365 return -EINVAL;
658366
659
- if (user_virt < ib_umem_start(umem) ||
660
- user_virt + bcnt > ib_umem_end(umem))
367
+ if (user_virt < ib_umem_start(umem_odp) ||
368
+ user_virt + bcnt > ib_umem_end(umem_odp))
661369 return -EFAULT;
662370
663
- local_page_list = (struct page **)__get_free_page(GFP_KERNEL);
664
- if (!local_page_list)
665
- return -ENOMEM;
371
+ page_shift = umem_odp->page_shift;
666372
667
- page_shift = umem->page_shift;
668
- page_mask = ~(BIT(page_shift) - 1);
669
- off = user_virt & (~page_mask);
670
- user_virt = user_virt & page_mask;
671
- bcnt += off; /* Charge for the first page offset as well. */
672
-
673
- owning_process = get_pid_task(umem->context->tgid, PIDTYPE_PID);
674
- if (owning_process == NULL) {
373
+ /*
374
+ * owning_process is allowed to be NULL; this means somehow the mm is
375
+ * existing beyond the lifetime of the originating process. Presumably
376
+ * mmget_not_zero will fail in this case.
377
+ */
378
+ owning_process = get_pid_task(umem_odp->tgid, PIDTYPE_PID);
379
+ if (!owning_process || !mmget_not_zero(owning_mm)) {
675380 ret = -EINVAL;
676
- goto out_no_task;
677
- }
678
-
679
- owning_mm = get_task_mm(owning_process);
680
- if (owning_mm == NULL) {
681
- ret = -ENOENT;
682381 goto out_put_task;
683382 }
684383
685
- if (access_mask & ODP_WRITE_ALLOWED_BIT)
686
- flags |= FOLL_WRITE;
384
+ range.notifier = &umem_odp->notifier;
385
+ range.start = ALIGN_DOWN(user_virt, 1UL << page_shift);
386
+ range.end = ALIGN(user_virt + bcnt, 1UL << page_shift);
387
+ pfn_start_idx = (range.start - ib_umem_start(umem_odp)) >> PAGE_SHIFT;
388
+ num_pfns = (range.end - range.start) >> PAGE_SHIFT;
389
+ if (fault) {
390
+ range.default_flags = HMM_PFN_REQ_FAULT;
687391
688
- start_idx = (user_virt - ib_umem_start(umem)) >> page_shift;
689
- k = start_idx;
392
+ if (access_mask & ODP_WRITE_ALLOWED_BIT)
393
+ range.default_flags |= HMM_PFN_REQ_WRITE;
394
+ }
690395
691
- while (bcnt > 0) {
692
- const size_t gup_num_pages = min_t(size_t,
693
- ALIGN(bcnt, PAGE_SIZE) / PAGE_SIZE,
694
- PAGE_SIZE / sizeof(struct page *));
396
+ range.hmm_pfns = &(umem_odp->pfn_list[pfn_start_idx]);
397
+ timeout = jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
695398
696
- down_read(&owning_mm->mmap_sem);
697
- /*
698
- * Note: this might result in redundant page getting. We can
699
- * avoid this by checking dma_list to be 0 before calling
700
- * get_user_pages. However, this make the code much more
701
- * complex (and doesn't gain us much performance in most use
702
- * cases).
703
- */
704
- npages = get_user_pages_remote(owning_process, owning_mm,
705
- user_virt, gup_num_pages,
706
- flags, local_page_list, NULL, NULL);
707
- up_read(&owning_mm->mmap_sem);
399
+retry:
400
+ current_seq = range.notifier_seq =
401
+ mmu_interval_read_begin(&umem_odp->notifier);
708402
709
- if (npages < 0)
710
- break;
403
+ mmap_read_lock(owning_mm);
404
+ ret = hmm_range_fault(&range);
405
+ mmap_read_unlock(owning_mm);
406
+ if (unlikely(ret)) {
407
+ if (ret == -EBUSY && !time_after(jiffies, timeout))
408
+ goto retry;
409
+ goto out_put_mm;
410
+ }
711411
712
- bcnt -= min_t(size_t, npages << PAGE_SHIFT, bcnt);
713
- mutex_lock(&umem->odp_data->umem_mutex);
714
- for (j = 0; j < npages; j++, user_virt += PAGE_SIZE) {
715
- if (user_virt & ~page_mask) {
716
- p += PAGE_SIZE;
717
- if (page_to_phys(local_page_list[j]) != p) {
718
- ret = -EFAULT;
719
- break;
720
- }
721
- put_page(local_page_list[j]);
412
+ start_idx = (range.start - ib_umem_start(umem_odp)) >> page_shift;
413
+ dma_index = start_idx;
414
+
415
+ mutex_lock(&umem_odp->umem_mutex);
416
+ if (mmu_interval_read_retry(&umem_odp->notifier, current_seq)) {
417
+ mutex_unlock(&umem_odp->umem_mutex);
418
+ goto retry;
419
+ }
420
+
421
+ for (pfn_index = 0; pfn_index < num_pfns;
422
+ pfn_index += 1 << (page_shift - PAGE_SHIFT), dma_index++) {
423
+
424
+ if (fault) {
425
+ /*
426
+ * Since we asked for hmm_range_fault() to populate
427
+ * pages it shouldn't return an error entry on success.
428
+ */
429
+ WARN_ON(range.hmm_pfns[pfn_index] & HMM_PFN_ERROR);
430
+ WARN_ON(!(range.hmm_pfns[pfn_index] & HMM_PFN_VALID));
431
+ } else {
432
+ if (!(range.hmm_pfns[pfn_index] & HMM_PFN_VALID)) {
433
+ WARN_ON(umem_odp->dma_list[dma_index]);
722434 continue;
723435 }
724
-
725
- ret = ib_umem_odp_map_dma_single_page(
726
- umem, k, local_page_list[j],
727
- access_mask, current_seq);
728
- if (ret < 0)
729
- break;
730
-
731
- p = page_to_phys(local_page_list[j]);
732
- k++;
436
+ access_mask = ODP_READ_ALLOWED_BIT;
437
+ if (range.hmm_pfns[pfn_index] & HMM_PFN_WRITE)
438
+ access_mask |= ODP_WRITE_ALLOWED_BIT;
733439 }
734
- mutex_unlock(&umem->odp_data->umem_mutex);
735440
441
+ hmm_order = hmm_pfn_to_map_order(range.hmm_pfns[pfn_index]);
442
+ /* If a hugepage was detected and ODP wasn't set for it, the umem
443
+ * page_shift will be used; the opposite case is an error.
444
+ */
445
+ if (hmm_order + PAGE_SHIFT < page_shift) {
446
+ ret = -EINVAL;
447
+ ibdev_dbg(umem_odp->umem.ibdev,
448
+ "%s: un-expected hmm_order %d, page_shift %d\n",
449
+ __func__, hmm_order, page_shift);
450
+ break;
451
+ }
452
+
453
+ ret = ib_umem_odp_map_dma_single_page(
454
+ umem_odp, dma_index, hmm_pfn_to_page(range.hmm_pfns[pfn_index]),
455
+ access_mask);
736456 if (ret < 0) {
737
- /* Release left over pages when handling errors. */
738
- for (++j; j < npages; ++j)
739
- put_page(local_page_list[j]);
457
+ ibdev_dbg(umem_odp->umem.ibdev,
458
+ "ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
740459 break;
741460 }
742461 }
462
+ /* Upon success the lock should stay held for the caller */
463
+ if (!ret)
464
+ ret = dma_index - start_idx;
465
+ else
466
+ mutex_unlock(&umem_odp->umem_mutex);
743467
744
- if (ret >= 0) {
745
- if (npages < 0 && k == start_idx)
746
- ret = npages;
747
- else
748
- ret = k - start_idx;
749
- }
750
-
751
- mmput(owning_mm);
468
+out_put_mm:
469
+ mmput_async(owning_mm);
752470 out_put_task:
753
- put_task_struct(owning_process);
754
-out_no_task:
755
- free_page((unsigned long)local_page_list);
471
+ if (owning_process)
472
+ put_task_struct(owning_process);
756473 return ret;
757474 }
758
-EXPORT_SYMBOL(ib_umem_odp_map_dma_pages);
475
+EXPORT_SYMBOL(ib_umem_odp_map_dma_and_lock);
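A hedged sketch of the page-fault path this function serves; my_drv_program_device_ptes is a hypothetical helper standing in for the driver's device page-table update. The key point is that on success the umem_mutex is returned held, so the update is ordered against concurrent invalidations:

static int my_drv_fault_range(struct ib_umem_odp *odp, u64 io_virt, size_t bcnt)
{
	u64 access_mask = ODP_READ_ALLOWED_BIT;
	int npages;

	if (odp->umem.writable)
		access_mask |= ODP_WRITE_ALLOWED_BIT;

	npages = ib_umem_odp_map_dma_and_lock(odp, io_virt, bcnt, access_mask,
					      true /* fault pages in */);
	if (npages < 0)
		return npages;

	/* dma_list[] now covers the faulted range; mirror it into the device
	 * page tables while umem_mutex is still held. */
	my_drv_program_device_ptes(odp, io_virt, npages);

	mutex_unlock(&odp->umem_mutex);
	return npages;
}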
759476
760
-void ib_umem_odp_unmap_dma_pages(struct ib_umem *umem, u64 virt,
477
+void ib_umem_odp_unmap_dma_pages(struct ib_umem_odp *umem_odp, u64 virt,
761478 u64 bound)
762479 {
480
+ dma_addr_t dma_addr;
481
+ dma_addr_t dma;
763482 int idx;
764483 u64 addr;
765
- struct ib_device *dev = umem->context->device;
484
+ struct ib_device *dev = umem_odp->umem.ibdev;
766485
767
- virt = max_t(u64, virt, ib_umem_start(umem));
768
- bound = min_t(u64, bound, ib_umem_end(umem));
769
- /* Note that during the run of this function, the
770
- * notifiers_count of the MR is > 0, preventing any racing
771
- * faults from completion. We might be racing with other
772
- * invalidations, so we must make sure we free each page only
773
- * once. */
774
- mutex_lock(&umem->odp_data->umem_mutex);
775
- for (addr = virt; addr < bound; addr += BIT(umem->page_shift)) {
776
- idx = (addr - ib_umem_start(umem)) >> umem->page_shift;
777
- if (umem->odp_data->page_list[idx]) {
778
- struct page *page = umem->odp_data->page_list[idx];
779
- dma_addr_t dma = umem->odp_data->dma_list[idx];
780
- dma_addr_t dma_addr = dma & ODP_DMA_ADDR_MASK;
486
+ lockdep_assert_held(&umem_odp->umem_mutex);
781487
782
- WARN_ON(!dma_addr);
488
+ virt = max_t(u64, virt, ib_umem_start(umem_odp));
489
+ bound = min_t(u64, bound, ib_umem_end(umem_odp));
490
+ for (addr = virt; addr < bound; addr += BIT(umem_odp->page_shift)) {
491
+ idx = (addr - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
492
+ dma = umem_odp->dma_list[idx];
783493
784
- ib_dma_unmap_page(dev, dma_addr, PAGE_SIZE,
494
+ /* The access flags guaranteed a valid DMA address in case was NULL */
495
+ if (dma) {
496
+ unsigned long pfn_idx = (addr - ib_umem_start(umem_odp)) >> PAGE_SHIFT;
497
+ struct page *page = hmm_pfn_to_page(umem_odp->pfn_list[pfn_idx]);
498
+
499
+ dma_addr = dma & ODP_DMA_ADDR_MASK;
500
+ ib_dma_unmap_page(dev, dma_addr,
501
+ BIT(umem_odp->page_shift),
785502 DMA_BIDIRECTIONAL);
786503 if (dma & ODP_WRITE_ALLOWED_BIT) {
787504 struct page *head_page = compound_head(page);
@@ -796,57 +513,9 @@
796513 */
797514 set_page_dirty(head_page);
798515 }
799
- /* on demand pinning support */
800
- if (!umem->context->invalidate_range)
801
- put_page(page);
802
- umem->odp_data->page_list[idx] = NULL;
803
- umem->odp_data->dma_list[idx] = 0;
804
- umem->npages--;
516
+ umem_odp->dma_list[idx] = 0;
517
+ umem_odp->npages--;
805518 }
806519 }
807
- mutex_unlock(&umem->odp_data->umem_mutex);
808520 }
809521 EXPORT_SYMBOL(ib_umem_odp_unmap_dma_pages);
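Finally, a sketch of the .invalidate callback that ties the pieces together; this is the my_umem_ops referenced in the earlier sketches, loosely modelled on how in-tree ODP drivers use this API, with my_drv_zap_device_ptes as a hypothetical HW-fencing helper:

static bool my_drv_invalidate(struct mmu_interval_notifier *mni,
			      const struct mmu_notifier_range *range,
			      unsigned long cur_seq)
{
	struct ib_umem_odp *umem_odp =
		container_of(mni, struct ib_umem_odp, notifier);
	u64 start, end;

	if (!mmu_notifier_range_blockable(range))
		return false;

	mutex_lock(&umem_odp->umem_mutex);
	mmu_interval_set_seq(mni, cur_seq);

	start = max_t(u64, ib_umem_start(umem_odp), range->start);
	end = min_t(u64, ib_umem_end(umem_odp), range->end);

	/* Hypothetical: make the HW stop translating this range first. */
	my_drv_zap_device_ptes(umem_odp, start, end);

	/* umem_mutex is held, as ib_umem_odp_unmap_dma_pages() now requires. */
	ib_umem_odp_unmap_dma_pages(umem_odp, start, end);
	mutex_unlock(&umem_odp->umem_mutex);
	return true;
}

static const struct mmu_interval_notifier_ops my_umem_ops = {
	.invalidate = my_drv_invalidate,
};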
810
-
811
-/* @last is not a part of the interval. See comment for function
812
- * node_last.
813
- */
814
-int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
815
- u64 start, u64 last,
816
- umem_call_back cb,
817
- bool blockable,
818
- void *cookie)
819
-{
820
- int ret_val = 0;
821
- struct umem_odp_node *node, *next;
822
- struct ib_umem_odp *umem;
823
-
824
- if (unlikely(start == last))
825
- return ret_val;
826
-
827
- for (node = rbt_ib_umem_iter_first(root, start, last - 1);
828
- node; node = next) {
829
- /* TODO move the blockable decision up to the callback */
830
- if (!blockable)
831
- return -EAGAIN;
832
- next = rbt_ib_umem_iter_next(node, start, last - 1);
833
- umem = container_of(node, struct ib_umem_odp, interval_tree);
834
- ret_val = cb(umem->umem, start, last, cookie) || ret_val;
835
- }
836
-
837
- return ret_val;
838
-}
839
-EXPORT_SYMBOL(rbt_ib_umem_for_each_in_range);
840
-
841
-struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root,
842
- u64 addr, u64 length)
843
-{
844
- struct umem_odp_node *node;
845
-
846
- node = rbt_ib_umem_iter_first(root, addr, addr + length - 1);
847
- if (node)
848
- return container_of(node, struct ib_umem_odp, interval_tree);
849
- return NULL;
850
-
851
-}
852
-EXPORT_SYMBOL(rbt_ib_umem_lookup);