From 102a0743326a03cd1a1202ceda21e175b7d3575c Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Tue, 20 Feb 2024 01:20:52 +0000
Subject: [PATCH] mm/mremap: add PMD/PUD-level moves and MREMAP_DONTUNMAP support

This modifies kernel/mm/mremap.c (no new file is added):

- Factor out get_old_pud()/alloc_new_pud() and, when the extent is
  PMD- or PUD-aligned, move whole page table entries via
  move_normal_pmd()/move_normal_pud() (CONFIG_HAVE_MOVE_PMD /
  CONFIG_HAVE_MOVE_PUD). With CONFIG_SPECULATIVE_PAGE_FAULT, exclusive
  ownership of the VMA reference count is taken first, falling back to
  move_ptes() if that fails.
- Add the MREMAP_DONTUNMAP flag: the mapping is moved to a new address
  while the source VMA is left in place instead of being unmapped.
- Use the mmap_lock helpers and range-based MMU notifiers, and let a
  shrinking mremap downgrade mmap_lock to read via __do_munmap().
---
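Note (below the "---", not part of the commit message): a minimal userspace
sketch of the MREMAP_DONTUNMAP flow this patch enables, for illustration
only. The fallback #define of MREMAP_DONTUNMAP is an assumption for
toolchains whose uapi headers predate the flag; the kernel side requires
MREMAP_MAYMOVE and old_len == new_len, as enforced in the syscall path
below. After the move the data is visible at the new address, while the old
VMA stays mapped and (for private anonymous memory) should read back as
zero-fill.

#define _GNU_SOURCE
#include <stdio.h>
#include <string.h>
#include <sys/mman.h>

#ifndef MREMAP_DONTUNMAP
#define MREMAP_DONTUNMAP 4	/* fallback for older headers; assumed to match the uapi value */
#endif

int main(void)
{
	size_t len = 2UL * 1024 * 1024;	/* typically PMD-sized, so the fast path may apply */
	unsigned char *src, *dst;

	/* Create and populate the source mapping so there are PTEs to move. */
	src = mmap(NULL, len, PROT_READ | PROT_WRITE,
		   MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
	if (src == MAP_FAILED)
		return 1;
	memset(src, 0xaa, len);

	/*
	 * Move the mapping to a kernel-chosen address while leaving the old
	 * VMA in place. MREMAP_DONTUNMAP requires MREMAP_MAYMOVE and
	 * old_len == new_len.
	 */
	dst = mremap(src, len, len, MREMAP_MAYMOVE | MREMAP_DONTUNMAP);
	if (dst == MAP_FAILED) {
		perror("mremap");	/* e.g. EINVAL on kernels without the flag */
		return 1;
	}

	printf("moved %p -> %p, dst[0]=%#x, src still mapped\n",
	       (void *)src, (void *)dst, dst[0]);
	return 0;
}
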
 kernel/mm/mremap.c |  528 +++++++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 461 insertions(+), 67 deletions(-)

diff --git a/kernel/mm/mremap.c b/kernel/mm/mremap.c
index 16630b5..b1263e9 100644
--- a/kernel/mm/mremap.c
+++ b/kernel/mm/mremap.c
@@ -30,12 +30,11 @@
 
 #include "internal.h"
 
-static pmd_t *get_old_pmd(struct mm_struct *mm, unsigned long addr)
+static pud_t *get_old_pud(struct mm_struct *mm, unsigned long addr)
 {
 	pgd_t *pgd;
 	p4d_t *p4d;
 	pud_t *pud;
-	pmd_t *pmd;
 
 	pgd = pgd_offset(mm, addr);
 	if (pgd_none_or_clear_bad(pgd))
@@ -49,6 +48,18 @@
 	if (pud_none_or_clear_bad(pud))
 		return NULL;
 
+	return pud;
+}
+
+static pmd_t *get_old_pmd(struct mm_struct *mm, unsigned long addr)
+{
+	pud_t *pud;
+	pmd_t *pmd;
+
+	pud = get_old_pud(mm, addr);
+	if (!pud)
+		return NULL;
+
 	pmd = pmd_offset(pud, addr);
 	if (pmd_none(*pmd))
 		return NULL;
@@ -56,19 +67,27 @@
 	return pmd;
 }
 
-static pmd_t *alloc_new_pmd(struct mm_struct *mm, struct vm_area_struct *vma,
+static pud_t *alloc_new_pud(struct mm_struct *mm, struct vm_area_struct *vma,
 			    unsigned long addr)
 {
 	pgd_t *pgd;
 	p4d_t *p4d;
-	pud_t *pud;
-	pmd_t *pmd;
 
 	pgd = pgd_offset(mm, addr);
 	p4d = p4d_alloc(mm, pgd, addr);
 	if (!p4d)
 		return NULL;
-	pud = pud_alloc(mm, p4d, addr);
+
+	return pud_alloc(mm, p4d, addr);
+}
+
+static pmd_t *alloc_new_pmd(struct mm_struct *mm, struct vm_area_struct *vma,
+			    unsigned long addr)
+{
+	pud_t *pud;
+	pmd_t *pmd;
+
+	pud = alloc_new_pud(mm, vma, addr);
 	if (!pud)
 		return NULL;
 
@@ -133,7 +152,7 @@
 	 * such races:
 	 *
 	 * - During exec() shift_arg_pages(), we use a specially tagged vma
-	 *   which rmap call sites look for using is_vma_temporary_stack().
+	 *   which rmap call sites look for using vma_is_temporary_stack().
 	 *
 	 * - During mremap(), new_vma is often known to be placed after vma
 	 *   in rmap traversal order. This ensures rmap will always observe
@@ -146,7 +165,7 @@
 
 	/*
 	 * We don't have to worry about the ordering of src and dst
-	 * pte locks because exclusive mmap_sem prevents deadlock.
+	 * pte locks because exclusive mmap_lock prevents deadlock.
 	 */
 	old_pte = pte_offset_map_lock(mm, old_pmd, old_addr, &old_ptl);
 	new_pte = pte_offset_map(new_pmd, new_addr);
@@ -191,63 +210,327 @@
 		drop_rmap_locks(vma);
 }
 
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+static inline bool trylock_vma_ref_count(struct vm_area_struct *vma)
+{
+	/*
+	 * If we have the only reference, swap the refcount to -1. This
+	 * will prevent other concurrent references by get_vma() for SPFs.
+	 */
+	return atomic_cmpxchg(&vma->vm_ref_count, 1, -1) == 1;
+}
+
+/*
+ * Restore the VMA reference count to 1 after a fast mremap.
+ */
+static inline void unlock_vma_ref_count(struct vm_area_struct *vma)
+{
+	/*
+	 * This should only be called after a corresponding,
+	 * successful trylock_vma_ref_count().
+	 */
+	VM_BUG_ON_VMA(atomic_cmpxchg(&vma->vm_ref_count, -1, 1) != -1,
+		      vma);
+}
+#else	/* !CONFIG_SPECULATIVE_PAGE_FAULT */
+static inline bool trylock_vma_ref_count(struct vm_area_struct *vma)
+{
+	return true;
+}
+static inline void unlock_vma_ref_count(struct vm_area_struct *vma)
+{
+}
+#endif	/* CONFIG_SPECULATIVE_PAGE_FAULT */
+
+#ifdef CONFIG_HAVE_MOVE_PMD
+static bool move_normal_pmd(struct vm_area_struct *vma, unsigned long old_addr,
+		  unsigned long new_addr, pmd_t *old_pmd, pmd_t *new_pmd)
+{
+	spinlock_t *old_ptl, *new_ptl;
+	struct mm_struct *mm = vma->vm_mm;
+	pmd_t pmd;
+
+	/*
+	 * The destination pmd shouldn't be established, free_pgtables()
+	 * should have released it.
+	 *
+	 * However, there's a case during execve() where we use mremap
+	 * to move the initial stack, and in that case the target area
+	 * may overlap the source area (always moving down).
+	 *
+	 * If everything is PMD-aligned, that works fine, as moving
+	 * each pmd down will clear the source pmd. But if we first
+	 * have a few 4kB-only pages that get moved down, and then
+	 * hit the "now the rest is PMD-aligned, let's do everything
+	 * one pmd at a time", we will still have the old (now empty
+	 * of any 4kB pages, but still there) PMD in the page table
+	 * tree.
+	 *
+	 * Warn on it once - because we really should try to figure
+	 * out how to do this better - but then say "I won't move
+	 * this pmd".
+	 *
+	 * One alternative might be to just unmap the target pmd at
+	 * this point, and verify that it really is empty. We'll see.
+	 */
+	if (WARN_ON_ONCE(!pmd_none(*new_pmd)))
+		return false;
+
+	/*
+	 * We hold both exclusive mmap_lock and rmap_lock at this point and
+	 * cannot block. If we cannot immediately take exclusive ownership
+	 * of the VMA, fall back to move_ptes().
+	 */
+	if (!trylock_vma_ref_count(vma))
+		return false;
+
+	/*
+	 * We don't have to worry about the ordering of src and dst
+	 * ptlocks because exclusive mmap_lock prevents deadlock.
+	 */
+	old_ptl = pmd_lock(vma->vm_mm, old_pmd);
+	new_ptl = pmd_lockptr(mm, new_pmd);
+	if (new_ptl != old_ptl)
+		spin_lock_nested(new_ptl, SINGLE_DEPTH_NESTING);
+
+	/* Clear the pmd */
+	pmd = *old_pmd;
+	pmd_clear(old_pmd);
+
+	VM_BUG_ON(!pmd_none(*new_pmd));
+
+	/* Set the new pmd */
+	set_pmd_at(mm, new_addr, new_pmd, pmd);
+	flush_tlb_range(vma, old_addr, old_addr + PMD_SIZE);
+	if (new_ptl != old_ptl)
+		spin_unlock(new_ptl);
+	spin_unlock(old_ptl);
+
+	unlock_vma_ref_count(vma);
+	return true;
+}
+#else
+static inline bool move_normal_pmd(struct vm_area_struct *vma,
+		unsigned long old_addr, unsigned long new_addr, pmd_t *old_pmd,
+		pmd_t *new_pmd)
+{
+	return false;
+}
+#endif
+
+#ifdef CONFIG_HAVE_MOVE_PUD
+static bool move_normal_pud(struct vm_area_struct *vma, unsigned long old_addr,
+		  unsigned long new_addr, pud_t *old_pud, pud_t *new_pud)
+{
+	spinlock_t *old_ptl, *new_ptl;
+	struct mm_struct *mm = vma->vm_mm;
+	pud_t pud;
+
+	/*
+	 * The destination pud shouldn't be established, free_pgtables()
+	 * should have released it.
+	 */
+	if (WARN_ON_ONCE(!pud_none(*new_pud)))
+		return false;
+
+	/*
+	 * We hold both exclusive mmap_lock and rmap_lock at this point and
+	 * cannot block. If we cannot immediately take exclusive ownership
+	 * of the VMA, fall back to move_ptes().
+	 */
+	if (!trylock_vma_ref_count(vma))
+		return false;
+
+	/*
+	 * We don't have to worry about the ordering of src and dst
+	 * ptlocks because exclusive mmap_lock prevents deadlock.
+	 */
+	old_ptl = pud_lock(vma->vm_mm, old_pud);
+	new_ptl = pud_lockptr(mm, new_pud);
+	if (new_ptl != old_ptl)
+		spin_lock_nested(new_ptl, SINGLE_DEPTH_NESTING);
+
+	/* Clear the pud */
+	pud = *old_pud;
+	pud_clear(old_pud);
+
+	VM_BUG_ON(!pud_none(*new_pud));
+
+	/* Set the new pud */
+	set_pud_at(mm, new_addr, new_pud, pud);
+	flush_tlb_range(vma, old_addr, old_addr + PUD_SIZE);
+	if (new_ptl != old_ptl)
+		spin_unlock(new_ptl);
+	spin_unlock(old_ptl);
+
+	unlock_vma_ref_count(vma);
+	return true;
+}
+#else
+static inline bool move_normal_pud(struct vm_area_struct *vma,
+		unsigned long old_addr, unsigned long new_addr, pud_t *old_pud,
+		pud_t *new_pud)
+{
+	return false;
+}
+#endif
+
+enum pgt_entry {
+	NORMAL_PMD,
+	HPAGE_PMD,
+	NORMAL_PUD,
+};
+
+/*
+ * Returns an extent of the corresponding size for the pgt_entry specified if
+ * valid. Else returns a smaller extent bounded by the end of the source and
+ * destination pgt_entry.
+ */
+static __always_inline unsigned long get_extent(enum pgt_entry entry,
+			unsigned long old_addr, unsigned long old_end,
+			unsigned long new_addr)
+{
+	unsigned long next, extent, mask, size;
+
+	switch (entry) {
+	case HPAGE_PMD:
+	case NORMAL_PMD:
+		mask = PMD_MASK;
+		size = PMD_SIZE;
+		break;
+	case NORMAL_PUD:
+		mask = PUD_MASK;
+		size = PUD_SIZE;
+		break;
+	default:
+		BUILD_BUG();
+		break;
+	}
+
+	next = (old_addr + size) & mask;
+	/* even if next overflowed, extent below will be ok */
+	extent = next - old_addr;
+	if (extent > old_end - old_addr)
+		extent = old_end - old_addr;
+	next = (new_addr + size) & mask;
+	if (extent > next - new_addr)
+		extent = next - new_addr;
+	return extent;
+}
+
+/*
+ * Attempts to speed up the move by moving the entry at the level corresponding to
+ * pgt_entry. Returns true if the move was successful, else false.
+ */
+static bool move_pgt_entry(enum pgt_entry entry, struct vm_area_struct *vma,
+			unsigned long old_addr, unsigned long new_addr,
+			void *old_entry, void *new_entry, bool need_rmap_locks)
+{
+	bool moved = false;
+
+	/* See comment in move_ptes() */
+	if (need_rmap_locks)
+		take_rmap_locks(vma);
+
+	switch (entry) {
+	case NORMAL_PMD:
+		moved = move_normal_pmd(vma, old_addr, new_addr, old_entry,
+					new_entry);
+		break;
+	case NORMAL_PUD:
+		moved = move_normal_pud(vma, old_addr, new_addr, old_entry,
+					new_entry);
+		break;
+	case HPAGE_PMD:
+		moved = IS_ENABLED(CONFIG_TRANSPARENT_HUGEPAGE) &&
+			move_huge_pmd(vma, old_addr, new_addr, old_entry,
+				      new_entry);
+		break;
+	default:
+		WARN_ON_ONCE(1);
+		break;
+	}
+
+	if (need_rmap_locks)
+		drop_rmap_locks(vma);
+
+	return moved;
+}
+
 unsigned long move_page_tables(struct vm_area_struct *vma,
 		unsigned long old_addr, struct vm_area_struct *new_vma,
 		unsigned long new_addr, unsigned long len,
 		bool need_rmap_locks)
 {
-	unsigned long extent, next, old_end;
+	unsigned long extent, old_end;
+	struct mmu_notifier_range range;
 	pmd_t *old_pmd, *new_pmd;
-	unsigned long mmun_start;	/* For mmu_notifiers */
-	unsigned long mmun_end;		/* For mmu_notifiers */
+
+	if (!len)
+		return 0;
 
 	old_end = old_addr + len;
 	flush_cache_range(vma, old_addr, old_end);
 
-	mmun_start = old_addr;
-	mmun_end   = old_end;
-	mmu_notifier_invalidate_range_start(vma->vm_mm, mmun_start, mmun_end);
+	mmu_notifier_range_init(&range, MMU_NOTIFY_UNMAP, 0, vma, vma->vm_mm,
+				old_addr, old_end);
+	mmu_notifier_invalidate_range_start(&range);
 
 	for (; old_addr < old_end; old_addr += extent, new_addr += extent) {
 		cond_resched();
-		next = (old_addr + PMD_SIZE) & PMD_MASK;
-		/* even if next overflowed, extent below will be ok */
-		extent = next - old_addr;
-		if (extent > old_end - old_addr)
-			extent = old_end - old_addr;
+		/*
+		 * If extent is PUD-sized try to speed up the move by moving at the
+		 * PUD level if possible.
+		 */
+		extent = get_extent(NORMAL_PUD, old_addr, old_end, new_addr);
+		if (IS_ENABLED(CONFIG_HAVE_MOVE_PUD) && extent == PUD_SIZE) {
+			pud_t *old_pud, *new_pud;
+
+			old_pud = get_old_pud(vma->vm_mm, old_addr);
+			if (!old_pud)
+				continue;
+			new_pud = alloc_new_pud(vma->vm_mm, vma, new_addr);
+			if (!new_pud)
+				break;
+			if (move_pgt_entry(NORMAL_PUD, vma, old_addr, new_addr,
+					   old_pud, new_pud, true))
+				continue;
+		}
+
+		extent = get_extent(NORMAL_PMD, old_addr, old_end, new_addr);
 		old_pmd = get_old_pmd(vma->vm_mm, old_addr);
 		if (!old_pmd)
 			continue;
 		new_pmd = alloc_new_pmd(vma->vm_mm, vma, new_addr);
 		if (!new_pmd)
 			break;
-		if (is_swap_pmd(*old_pmd) || pmd_trans_huge(*old_pmd) || pmd_devmap(*old_pmd)) {
-			if (extent == HPAGE_PMD_SIZE) {
-				bool moved;
-				/* See comment in move_ptes() */
-				if (need_rmap_locks)
-					take_rmap_locks(vma);
-				moved = move_huge_pmd(vma, old_addr, new_addr,
-						    old_end, old_pmd, new_pmd);
-				if (need_rmap_locks)
-					drop_rmap_locks(vma);
-				if (moved)
-					continue;
-			}
+		if (is_swap_pmd(*old_pmd) || pmd_trans_huge(*old_pmd) ||
+		    pmd_devmap(*old_pmd)) {
+			if (extent == HPAGE_PMD_SIZE &&
+			    move_pgt_entry(HPAGE_PMD, vma, old_addr, new_addr,
+					   old_pmd, new_pmd, need_rmap_locks))
+				continue;
 			split_huge_pmd(vma, old_pmd, old_addr);
 			if (pmd_trans_unstable(old_pmd))
 				continue;
+		} else if (IS_ENABLED(CONFIG_HAVE_MOVE_PMD) &&
+			   extent == PMD_SIZE) {
+			/*
+			 * If the extent is PMD-sized, try to speed the move by
+			 * moving at the PMD level if possible.
+			 */
+			if (move_pgt_entry(NORMAL_PMD, vma, old_addr, new_addr,
+					   old_pmd, new_pmd, true))
+				continue;
 		}
-		if (pte_alloc(new_vma->vm_mm, new_pmd, new_addr))
+
+		if (pte_alloc(new_vma->vm_mm, new_pmd))
 			break;
-		next = (new_addr + PMD_SIZE) & PMD_MASK;
-		if (extent > next - new_addr)
-			extent = next - new_addr;
 		move_ptes(vma, old_pmd, old_addr, old_addr + extent, new_vma,
 			  new_pmd, new_addr, need_rmap_locks);
 	}
 
-	mmu_notifier_invalidate_range_end(vma->vm_mm, mmun_start, mmun_end);
+	mmu_notifier_invalidate_range_end(&range);
 
 	return len + old_addr - old_end;	/* how much done */
 }
@@ -255,8 +538,8 @@
 static unsigned long move_vma(struct vm_area_struct *vma,
 		unsigned long old_addr, unsigned long old_len,
 		unsigned long new_len, unsigned long new_addr,
-		bool *locked, struct vm_userfaultfd_ctx *uf,
-		struct list_head *uf_unmap)
+		bool *locked, unsigned long flags,
+		struct vm_userfaultfd_ctx *uf, struct list_head *uf_unmap)
 {
 	struct mm_struct *mm = vma->vm_mm;
 	struct vm_area_struct *new_vma;
@@ -294,6 +577,14 @@
 	if (!new_vma)
 		return -ENOMEM;
 
+	/* new_vma is returned protected by copy_vma, to prevent speculative
+	 * page faults from being handled in the destination area before we
+	 * move the PTEs. We must also protect the source VMA, since we don't
+	 * want pages to be mapped behind our back while copying the PTEs.
+	 */
+	if (vma != new_vma)
+		vm_write_begin(vma);
+
 	moved_len = move_page_tables(vma, old_addr, new_vma, new_addr, old_len,
 				     need_rmap_locks);
 	if (moved_len < old_len) {
@@ -310,6 +601,8 @@
 		 */
 		move_page_tables(new_vma, new_addr, vma, old_addr, moved_len,
 				 true);
+		if (vma != new_vma)
+			vm_write_end(vma);
 		vma = new_vma;
 		old_len = new_len;
 		old_addr = new_addr;
@@ -318,7 +611,10 @@
 		mremap_userfaultfd_prep(new_vma, uf);
 		arch_remap(mm, old_addr, old_addr + old_len,
 			   new_addr, new_addr + new_len);
+		if (vma != new_vma)
+			vm_write_end(vma);
 	}
+	vm_write_end(new_vma);
 
 	/* Conceal VM_ACCOUNT so old reservation is not undone */
 	if (vm_flags & VM_ACCOUNT) {
@@ -345,11 +641,43 @@
 	if (unlikely(vma->vm_flags & VM_PFNMAP))
 		untrack_pfn_moved(vma);
 
+	if (unlikely(!err && (flags & MREMAP_DONTUNMAP))) {
+		if (vm_flags & VM_ACCOUNT) {
+			/* Always put back VM_ACCOUNT since we won't unmap */
+			vma->vm_flags |= VM_ACCOUNT;
+
+			vm_acct_memory(new_len >> PAGE_SHIFT);
+		}
+
+		/*
+		 * VMAs can actually be merged back together in copy_vma
+	 * calling vma_merge(). This can happen with anonymous VMAs
+		 * which have not yet been faulted, so if we were to consider
+		 * this VMA split we'll end up adding VM_ACCOUNT on the
+		 * next VMA, which is completely unrelated if this VMA
+		 * was re-merged.
+		 */
+		if (split && new_vma == vma)
+			split = 0;
+
+		/* We always clear VM_LOCKED[ONFAULT] on the old vma */
+		vma->vm_flags &= VM_LOCKED_CLEAR_MASK;
+
+		/* Because we won't unmap we don't need to touch locked_vm */
+		goto out;
+	}
+
 	if (do_munmap(mm, old_addr, old_len, uf_unmap) < 0) {
 		/* OOM: unable to split vma, just get accounts right */
 		vm_unacct_memory(excess >> PAGE_SHIFT);
 		excess = 0;
 	}
+
+	if (vm_flags & VM_LOCKED) {
+		mm->locked_vm += new_len >> PAGE_SHIFT;
+		*locked = true;
+	}
+out:
 	mm->hiwater_vm = hiwater_vm;
 
 	/* Restore VM_ACCOUNT if one or two pieces of vma left */
@@ -359,16 +687,12 @@
 			vma->vm_next->vm_flags |= VM_ACCOUNT;
 	}
 
-	if (vm_flags & VM_LOCKED) {
-		mm->locked_vm += new_len >> PAGE_SHIFT;
-		*locked = true;
-	}
-
 	return new_addr;
 }
 
 static struct vm_area_struct *vma_to_resize(unsigned long addr,
-	unsigned long old_len, unsigned long new_len, unsigned long *p)
+	unsigned long old_len, unsigned long new_len, unsigned long flags,
+	unsigned long *p)
 {
 	struct mm_struct *mm = current->mm;
 	struct vm_area_struct *vma = find_vma(mm, addr);
@@ -389,6 +713,10 @@
 		pr_warn_once("%s (%d): attempted to duplicate a private mapping with mremap.  This is not supported.\n", current->comm, current->pid);
 		return ERR_PTR(-EINVAL);
 	}
+
+	if ((flags & MREMAP_DONTUNMAP) &&
+			(vma->vm_flags & (VM_DONTEXPAND | VM_PFNMAP)))
+		return ERR_PTR(-EINVAL);
 
 	if (is_vm_hugetlb_page(vma))
 		return ERR_PTR(-EINVAL);
@@ -434,7 +762,7 @@
 
 static unsigned long mremap_to(unsigned long addr, unsigned long old_len,
 		unsigned long new_addr, unsigned long new_len, bool *locked,
-		struct vm_userfaultfd_ctx *uf,
+		unsigned long flags, struct vm_userfaultfd_ctx *uf,
 		struct list_head *uf_unmap_early,
 		struct list_head *uf_unmap)
 {
@@ -442,7 +770,7 @@
 	struct vm_area_struct *vma;
 	unsigned long ret = -EINVAL;
 	unsigned long charged = 0;
-	unsigned long map_flags;
+	unsigned long map_flags = 0;
 
 	if (offset_in_page(new_addr))
 		goto out;
@@ -454,9 +782,28 @@
 	if (addr + old_len > new_addr && new_addr + new_len > addr)
 		goto out;
 
-	ret = do_munmap(mm, new_addr, new_len, uf_unmap_early);
-	if (ret)
-		goto out;
+	/*
+	 * move_vma() needs us to stay 4 maps below the threshold, otherwise
+	 * it will bail out at the very beginning.
+	 * That is a problem if we have already unmapped the regions here
+	 * (new_addr, and old_addr), because userspace will not know the
+	 * state of the VMAs after it gets -ENOMEM.
+	 * So, to avoid such a scenario we can pre-compute whether the whole
+	 * operation has a high chance of succeeding map-wise.
+	 * The worst-case scenario is when both VMAs (new_addr and old_addr)
+	 * get split in 3 before unmapping them, which means 2 more maps
+	 * (1 for each) on top of the ones we already hold.
+	 * Check whether the current map count plus 2 still leaves us 4 maps
+	 * below the threshold, otherwise return -ENOMEM here to be safe.
+	 */
+	if ((mm->map_count + 2) >= sysctl_max_map_count - 3)
+		return -ENOMEM;
+
+	if (flags & MREMAP_FIXED) {
+		ret = do_munmap(mm, new_addr, new_len, uf_unmap_early);
+		if (ret)
+			goto out;
+	}
 
 	if (old_len >= new_len) {
 		ret = do_munmap(mm, addr+new_len, old_len - new_len, uf_unmap);
@@ -465,26 +812,41 @@
 		old_len = new_len;
 	}
 
-	vma = vma_to_resize(addr, old_len, new_len, &charged);
+	vma = vma_to_resize(addr, old_len, new_len, flags, &charged);
 	if (IS_ERR(vma)) {
 		ret = PTR_ERR(vma);
 		goto out;
 	}
 
-	map_flags = MAP_FIXED;
+	/* MREMAP_DONTUNMAP expands by old_len since old_len == new_len */
+	if (flags & MREMAP_DONTUNMAP &&
+		!may_expand_vm(mm, vma->vm_flags, old_len >> PAGE_SHIFT)) {
+		ret = -ENOMEM;
+		goto out;
+	}
+
+	if (flags & MREMAP_FIXED)
+		map_flags |= MAP_FIXED;
+
 	if (vma->vm_flags & VM_MAYSHARE)
 		map_flags |= MAP_SHARED;
 
 	ret = get_unmapped_area(vma->vm_file, new_addr, new_len, vma->vm_pgoff +
 				((addr - vma->vm_start) >> PAGE_SHIFT),
 				map_flags);
-	if (offset_in_page(ret))
+	if (IS_ERR_VALUE(ret))
 		goto out1;
 
-	ret = move_vma(vma, addr, old_len, new_len, new_addr, locked, uf,
+	/* We got a new mapping */
+	if (!(flags & MREMAP_FIXED))
+		new_addr = ret;
+
+	ret = move_vma(vma, addr, old_len, new_len, new_addr, locked, flags, uf,
 		       uf_unmap);
+
 	if (!(offset_in_page(ret)))
 		goto out;
+
 out1:
 	vm_unacct_memory(charged);
 
@@ -521,17 +883,37 @@
 	unsigned long ret = -EINVAL;
 	unsigned long charged = 0;
 	bool locked = false;
+	bool downgraded = false;
 	struct vm_userfaultfd_ctx uf = NULL_VM_UFFD_CTX;
 	LIST_HEAD(uf_unmap_early);
 	LIST_HEAD(uf_unmap);
 
+	/*
+	 * There is a deliberate asymmetry here: we strip the pointer tag
+	 * from the old address but leave the new address alone. This is
+	 * for consistency with mmap(), where we prevent the creation of
+	 * aliasing mappings in userspace by leaving the tag bits of the
+	 * mapping address intact. A non-zero tag will cause the subsequent
+	 * range checks to reject the address as invalid.
+	 *
+	 * See Documentation/arm64/tagged-address-abi.rst for more information.
+	 */
 	addr = untagged_addr(addr);
 
-	if (flags & ~(MREMAP_FIXED | MREMAP_MAYMOVE))
+	if (flags & ~(MREMAP_FIXED | MREMAP_MAYMOVE | MREMAP_DONTUNMAP))
 		return ret;
 
 	if (flags & MREMAP_FIXED && !(flags & MREMAP_MAYMOVE))
 		return ret;
+
+	/*
+	 * MREMAP_DONTUNMAP is always a move and it does not allow resizing
+	 * in the process.
+	 */
+	if (flags & MREMAP_DONTUNMAP &&
+			(!(flags & MREMAP_MAYMOVE) || old_len != new_len))
+		return ret;
+
 
 	if (offset_in_page(addr))
 		return ret;
@@ -547,24 +929,33 @@
 	if (!new_len)
 		return ret;
 
-	if (down_write_killable(&current->mm->mmap_sem))
+	if (mmap_write_lock_killable(current->mm))
 		return -EINTR;
 
-	if (flags & MREMAP_FIXED) {
+	if (flags & (MREMAP_FIXED | MREMAP_DONTUNMAP)) {
 		ret = mremap_to(addr, old_len, new_addr, new_len,
-				&locked, &uf, &uf_unmap_early, &uf_unmap);
+				&locked, flags, &uf, &uf_unmap_early,
+				&uf_unmap);
 		goto out;
 	}
 
 	/*
 	 * Always allow a shrinking remap: that just unmaps
 	 * the unnecessary pages..
-	 * do_munmap does all the needed commit accounting
+	 * __do_munmap does all the needed commit accounting, and
+	 * downgrades mmap_lock to read if so directed.
 	 */
 	if (old_len >= new_len) {
-		ret = do_munmap(mm, addr+new_len, old_len - new_len, &uf_unmap);
-		if (ret && old_len != new_len)
+		int retval;
+
+		retval = __do_munmap(mm, addr+new_len, old_len - new_len,
+				  &uf_unmap, true);
+		if (retval < 0 && old_len != new_len) {
+			ret = retval;
 			goto out;
+		/* Returning 1 indicates mmap_lock is downgraded to read. */
+		} else if (retval == 1)
+			downgraded = true;
 		ret = addr;
 		goto out;
 	}
@@ -572,7 +963,7 @@
 	/*
 	 * Ok, we need to grow..
 	 */
-	vma = vma_to_resize(addr, old_len, new_len, &charged);
+	vma = vma_to_resize(addr, old_len, new_len, flags, &charged);
 	if (IS_ERR(vma)) {
 		ret = PTR_ERR(vma);
 		goto out;
@@ -616,24 +1007,27 @@
 					vma->vm_pgoff +
 					((addr - vma->vm_start) >> PAGE_SHIFT),
 					map_flags);
-		if (offset_in_page(new_addr)) {
+		if (IS_ERR_VALUE(new_addr)) {
 			ret = new_addr;
 			goto out;
 		}
 
 		ret = move_vma(vma, addr, old_len, new_len, new_addr,
-			       &locked, &uf, &uf_unmap);
+			       &locked, flags, &uf, &uf_unmap);
 	}
 out:
 	if (offset_in_page(ret)) {
 		vm_unacct_memory(charged);
-		locked = 0;
+		locked = false;
 	}
-	up_write(&current->mm->mmap_sem);
+	if (downgraded)
+		mmap_read_unlock(current->mm);
+	else
+		mmap_write_unlock(current->mm);
 	if (locked && new_len > old_len)
 		mm_populate(new_addr + old_len, new_len - old_len);
 	userfaultfd_unmap_complete(mm, &uf_unmap_early);
-	mremap_userfaultfd_complete(&uf, addr, new_addr, old_len);
+	mremap_userfaultfd_complete(&uf, addr, ret, old_len);
 	userfaultfd_unmap_complete(mm, &uf_unmap);
 	return ret;
 }

--
Gitblit v1.6.2