Commit 11a9b902 authored by Liam R. Howlett's avatar Liam R. Howlett Committed by Andrew Morton
Browse files

userfaultfd: use vma iterator

Use the vma iterator so that the iterator can be invalidated or updated to
avoid each caller doing so.

Link: https://lkml.kernel.org/r/20230120162650.984577-17-Liam.Howlett@oracle.com


Signed-off-by: default avatarLiam R. Howlett <Liam.Howlett@oracle.com>
Signed-off-by: default avatarAndrew Morton <akpm@linux-foundation.org>
parent 27b26701
Loading
Loading
Loading
Loading
+33 −54
Original line number Diff line number Diff line
@@ -883,7 +883,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
	/* len == 0 means wake all */
	struct userfaultfd_wake_range range = { .len = 0, };
	unsigned long new_flags;
	MA_STATE(mas, &mm->mm_mt, 0, 0);
	VMA_ITERATOR(vmi, mm, 0);

	WRITE_ONCE(ctx->released, true);

@@ -900,7 +900,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
	 */
	mmap_write_lock(mm);
	prev = NULL;
	mas_for_each(&mas, vma, ULONG_MAX) {
	for_each_vma(vmi, vma) {
		cond_resched();
		BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^
		       !!(vma->vm_flags & __VM_UFFD_FLAGS));
@@ -909,13 +909,12 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
			continue;
		}
		new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
		prev = vma_merge(mm, prev, vma->vm_start, vma->vm_end,
		prev = vmi_vma_merge(&vmi, mm, prev, vma->vm_start, vma->vm_end,
				 new_flags, vma->anon_vma,
				 vma->vm_file, vma->vm_pgoff,
				 vma_policy(vma),
				 NULL_VM_UFFD_CTX, anon_vma_name(vma));
		if (prev) {
			mas_pause(&mas);
			vma = prev;
		} else {
			prev = vma;
@@ -1302,7 +1301,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
	bool found;
	bool basic_ioctls;
	unsigned long start, end, vma_end;
	MA_STATE(mas, &mm->mm_mt, 0, 0);
	struct vma_iterator vmi;

	user_uffdio_register = (struct uffdio_register __user *) arg;

@@ -1344,17 +1343,13 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
	if (!mmget_not_zero(mm))
		goto out;

	ret = -EINVAL;
	mmap_write_lock(mm);
	mas_set(&mas, start);
	vma = mas_find(&mas, ULONG_MAX);
	vma_iter_init(&vmi, mm, start);
	vma = vma_find(&vmi, end);
	if (!vma)
		goto out_unlock;

	/* check that there's at least one vma in the range */
	ret = -EINVAL;
	if (vma->vm_start >= end)
		goto out_unlock;

	/*
	 * If the first vma contains huge pages, make sure start address
	 * is aligned to huge page size.
@@ -1371,7 +1366,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
	 */
	found = false;
	basic_ioctls = false;
	for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
	cur = vma;
	do {
		cond_resched();

		BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
@@ -1428,16 +1424,14 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
			basic_ioctls = true;

		found = true;
	}
	} for_each_vma_range(vmi, cur, end);
	BUG_ON(!found);

	mas_set(&mas, start);
	prev = mas_prev(&mas, 0);
	if (prev != vma)
		mas_next(&mas, ULONG_MAX);
	vma_iter_set(&vmi, start);
	prev = vma_prev(&vmi);

	ret = 0;
	do {
	for_each_vma_range(vmi, vma, end) {
		cond_resched();

		BUG_ON(!vma_can_userfault(vma, vm_flags));
@@ -1458,30 +1452,25 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
		vma_end = min(end, vma->vm_end);

		new_flags = (vma->vm_flags & ~__VM_UFFD_FLAGS) | vm_flags;
		prev = vma_merge(mm, prev, start, vma_end, new_flags,
		prev = vmi_vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
				 vma->anon_vma, vma->vm_file, vma->vm_pgoff,
				 vma_policy(vma),
				 ((struct vm_userfaultfd_ctx){ ctx }),
				 anon_vma_name(vma));
		if (prev) {
			/* vma_merge() invalidated the mas */
			mas_pause(&mas);
			vma = prev;
			goto next;
		}
		if (vma->vm_start < start) {
			ret = split_vma(mm, vma, start, 1);
			ret = vmi_split_vma(&vmi, mm, vma, start, 1);
			if (ret)
				break;
			/* split_vma() invalidated the mas */
			mas_pause(&mas);
		}
		if (vma->vm_end > end) {
			ret = split_vma(mm, vma, end, 0);
			ret = vmi_split_vma(&vmi, mm, vma, end, 0);
			if (ret)
				break;
			/* split_vma() invalidated the mas */
			mas_pause(&mas);
		}
	next:
		/*
@@ -1498,8 +1487,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
	skip:
		prev = vma;
		start = vma->vm_end;
		vma = mas_next(&mas, end - 1);
	} while (vma);
	}

out_unlock:
	mmap_write_unlock(mm);
	mmput(mm);
@@ -1543,7 +1532,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
	bool found;
	unsigned long start, end, vma_end;
	const void __user *buf = (void __user *)arg;
	MA_STATE(mas, &mm->mm_mt, 0, 0);
	struct vma_iterator vmi;

	ret = -EFAULT;
	if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
@@ -1562,14 +1551,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
		goto out;

	mmap_write_lock(mm);
	mas_set(&mas, start);
	vma = mas_find(&mas, ULONG_MAX);
	if (!vma)
		goto out_unlock;

	/* check that there's at least one vma in the range */
	ret = -EINVAL;
	if (vma->vm_start >= end)
	vma_iter_init(&vmi, mm, start);
	vma = vma_find(&vmi, end);
	if (!vma)
		goto out_unlock;

	/*
@@ -1587,8 +1572,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
	 * Search for not compatible vmas.
	 */
	found = false;
	ret = -EINVAL;
	for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
	cur = vma;
	do {
		cond_resched();

		BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
@@ -1605,16 +1590,13 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
			goto out_unlock;

		found = true;
	}
	} for_each_vma_range(vmi, cur, end);
	BUG_ON(!found);

	mas_set(&mas, start);
	prev = mas_prev(&mas, 0);
	if (prev != vma)
		mas_next(&mas, ULONG_MAX);

	vma_iter_set(&vmi, start);
	prev = vma_prev(&vmi);
	ret = 0;
	do {
	for_each_vma_range(vmi, vma, end) {
		cond_resched();

		BUG_ON(!vma_can_userfault(vma, vma->vm_flags));
@@ -1650,26 +1632,23 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
			uffd_wp_range(mm, vma, start, vma_end - start, false);

		new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
		prev = vma_merge(mm, prev, start, vma_end, new_flags,
		prev = vmi_vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
				 vma->anon_vma, vma->vm_file, vma->vm_pgoff,
				 vma_policy(vma),
				 NULL_VM_UFFD_CTX, anon_vma_name(vma));
		if (prev) {
			vma = prev;
			mas_pause(&mas);
			goto next;
		}
		if (vma->vm_start < start) {
			ret = split_vma(mm, vma, start, 1);
			ret = vmi_split_vma(&vmi, mm, vma, start, 1);
			if (ret)
				break;
			mas_pause(&mas);
		}
		if (vma->vm_end > end) {
			ret = split_vma(mm, vma, end, 0);
			ret = vmi_split_vma(&vmi, mm, vma, end, 0);
			if (ret)
				break;
			mas_pause(&mas);
		}
	next:
		/*
@@ -1683,8 +1662,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
	skip:
		prev = vma;
		start = vma->vm_end;
		vma = mas_next(&mas, end - 1);
	} while (vma);
	}

out_unlock:
	mmap_write_unlock(mm);
	mmput(mm);