Commit 69dbe6da authored by Liam R. Howlett's avatar Liam R. Howlett Committed by Andrew Morton
Browse files

userfaultfd: use maple tree iterator to iterate VMAs

Don't use the mm_struct linked list or vma->vm_next, in preparation for
their removal.

Link: https://lkml.kernel.org/r/20220906194824.2110408-45-Liam.Howlett@oracle.com


Signed-off-by: Liam R. Howlett <Liam.Howlett@Oracle.com>
Tested-by: Yu Zhao <yuzhao@google.com>
Cc: Catalin Marinas <catalin.marinas@arm.com>
Cc: David Hildenbrand <david@redhat.com>
Cc: David Howells <dhowells@redhat.com>
Cc: Davidlohr Bueso <dave@stgolabs.net>
Cc: "Matthew Wilcox (Oracle)" <willy@infradead.org>
Cc: SeongJae Park <sj@kernel.org>
Cc: Sven Schnelle <svens@linux.ibm.com>
Cc: Vlastimil Babka <vbabka@suse.cz>
Cc: Will Deacon <will@kernel.org>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
parent c4c84f06
Loading
Loading
Loading
Loading
+42 −20
Original line number Diff line number Diff line
@@ -611,14 +611,16 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx,
	if (release_new_ctx) {
		struct vm_area_struct *vma;
		struct mm_struct *mm = release_new_ctx->mm;
		VMA_ITERATOR(vmi, mm, 0);

		/* the various vma->vm_userfaultfd_ctx still points to it */
		mmap_write_lock(mm);
		for (vma = mm->mmap; vma; vma = vma->vm_next)
		for_each_vma(vmi, vma) {
			if (vma->vm_userfaultfd_ctx.ctx == release_new_ctx) {
				vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
				vma->vm_flags &= ~__VM_UFFD_FLAGS;
			}
		}
		mmap_write_unlock(mm);

		userfaultfd_ctx_put(release_new_ctx);
@@ -799,11 +801,13 @@ static bool has_unmap_ctx(struct userfaultfd_ctx *ctx, struct list_head *unmaps,
	return false;
}

int userfaultfd_unmap_prep(struct vm_area_struct *vma,
			   unsigned long start, unsigned long end,
			   struct list_head *unmaps)
int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start,
			   unsigned long end, struct list_head *unmaps)
{
	for ( ; vma && vma->vm_start < end; vma = vma->vm_next) {
	VMA_ITERATOR(vmi, mm, start);
	struct vm_area_struct *vma;

	for_each_vma_range(vmi, vma, end) {
		struct userfaultfd_unmap_ctx *unmap_ctx;
		struct userfaultfd_ctx *ctx = vma->vm_userfaultfd_ctx.ctx;

@@ -853,6 +857,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
	/* len == 0 means wake all */
	struct userfaultfd_wake_range range = { .len = 0, };
	unsigned long new_flags;
	MA_STATE(mas, &mm->mm_mt, 0, 0);

	WRITE_ONCE(ctx->released, true);

@@ -869,7 +874,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
	 */
	mmap_write_lock(mm);
	prev = NULL;
	for (vma = mm->mmap; vma; vma = vma->vm_next) {
	mas_for_each(&mas, vma, ULONG_MAX) {
		cond_resched();
		BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^
		       !!(vma->vm_flags & __VM_UFFD_FLAGS));
@@ -883,10 +888,13 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
				 vma->vm_file, vma->vm_pgoff,
				 vma_policy(vma),
				 NULL_VM_UFFD_CTX, anon_vma_name(vma));
		if (prev)
		if (prev) {
			mas_pause(&mas);
			vma = prev;
		else
		} else {
			prev = vma;
		}

		vma->vm_flags = new_flags;
		vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
	}
@@ -1268,6 +1276,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
	bool found;
	bool basic_ioctls;
	unsigned long start, end, vma_end;
	MA_STATE(mas, &mm->mm_mt, 0, 0);

	user_uffdio_register = (struct uffdio_register __user *) arg;

@@ -1310,7 +1319,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
		goto out;

	mmap_write_lock(mm);
	vma = find_vma_prev(mm, start, &prev);
	mas_set(&mas, start);
	vma = mas_find(&mas, ULONG_MAX);
	if (!vma)
		goto out_unlock;

@@ -1335,7 +1345,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
	 */
	found = false;
	basic_ioctls = false;
	for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) {
	for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
		cond_resched();

		BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
@@ -1395,8 +1405,10 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
	}
	BUG_ON(!found);

	if (vma->vm_start < start)
		prev = vma;
	mas_set(&mas, start);
	prev = mas_prev(&mas, 0);
	if (prev != vma)
		mas_next(&mas, ULONG_MAX);

	ret = 0;
	do {
@@ -1426,6 +1438,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
				 ((struct vm_userfaultfd_ctx){ ctx }),
				 anon_vma_name(vma));
		if (prev) {
			/* vma_merge() invalidated the mas */
			mas_pause(&mas);
			vma = prev;
			goto next;
		}
@@ -1433,11 +1447,15 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
			ret = split_vma(mm, vma, start, 1);
			if (ret)
				break;
			/* split_vma() invalidated the mas */
			mas_pause(&mas);
		}
		if (vma->vm_end > end) {
			ret = split_vma(mm, vma, end, 0);
			if (ret)
				break;
			/* split_vma() invalidated the mas */
			mas_pause(&mas);
		}
	next:
		/*
@@ -1454,8 +1472,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
	skip:
		prev = vma;
		start = vma->vm_end;
		vma = vma->vm_next;
	} while (vma && vma->vm_start < end);
		vma = mas_next(&mas, end - 1);
	} while (vma);
out_unlock:
	mmap_write_unlock(mm);
	mmput(mm);
@@ -1499,6 +1517,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
	bool found;
	unsigned long start, end, vma_end;
	const void __user *buf = (void __user *)arg;
	MA_STATE(mas, &mm->mm_mt, 0, 0);

	ret = -EFAULT;
	if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
@@ -1517,7 +1536,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
		goto out;

	mmap_write_lock(mm);
	vma = find_vma_prev(mm, start, &prev);
	mas_set(&mas, start);
	vma = mas_find(&mas, ULONG_MAX);
	if (!vma)
		goto out_unlock;

@@ -1542,7 +1562,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
	 */
	found = false;
	ret = -EINVAL;
	for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) {
	for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
		cond_resched();

		BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
@@ -1562,8 +1582,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
	}
	BUG_ON(!found);

	if (vma->vm_start < start)
		prev = vma;
	mas_set(&mas, start);
	prev = mas_prev(&mas, 0);
	if (prev != vma)
		mas_next(&mas, ULONG_MAX);

	ret = 0;
	do {
@@ -1632,8 +1654,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
	skip:
		prev = vma;
		start = vma->vm_end;
		vma = vma->vm_next;
	} while (vma && vma->vm_start < end);
		vma = mas_next(&mas, end - 1);
	} while (vma);
out_unlock:
	mmap_write_unlock(mm);
	mmput(mm);
+3 −4
Original line number Diff line number Diff line
@@ -175,9 +175,8 @@ extern bool userfaultfd_remove(struct vm_area_struct *vma,
			       unsigned long start,
			       unsigned long end);

extern int userfaultfd_unmap_prep(struct vm_area_struct *vma,
				  unsigned long start, unsigned long end,
				  struct list_head *uf);
extern int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start,
				  unsigned long end, struct list_head *uf);
extern void userfaultfd_unmap_complete(struct mm_struct *mm,
				       struct list_head *uf);

@@ -258,7 +257,7 @@ static inline bool userfaultfd_remove(struct vm_area_struct *vma,
	return true;
}

static inline int userfaultfd_unmap_prep(struct vm_area_struct *vma,
static inline int userfaultfd_unmap_prep(struct mm_struct *mm,
					 unsigned long start, unsigned long end,
					 struct list_head *uf)
{
+1 −1
Original line number Diff line number Diff line
@@ -2545,7 +2545,7 @@ do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma,
		 * split, despite we could. This is unlikely enough
		 * failure that it's not worth optimizing it for.
		 */
		error = userfaultfd_unmap_prep(vma, start, end, uf);
		error = userfaultfd_unmap_prep(mm, start, end, uf);

		if (error)
			goto userfaultfd_error;