Commit 92fed820 authored by Liam R. Howlett's avatar Liam R. Howlett Committed by Andrew Morton
Browse files

mm/mmap: convert brk to use vma iterator

Use the vma iterator API for the brk() system call.  This will provide
type safety at compile time.

Link: https://lkml.kernel.org/r/20230120162650.984577-9-Liam.Howlett@oracle.com


Signed-off-by: default avatarLiam R. Howlett <Liam.Howlett@oracle.com>
Signed-off-by: default avatarAndrew Morton <akpm@linux-foundation.org>
parent b62b633e
Loading
Loading
Loading
Loading
+23 −25
Original line number Diff line number Diff line
@@ -180,10 +180,10 @@ static int check_brk_limits(unsigned long addr, unsigned long len)

	return mlock_future_check(current->mm, current->mm->def_flags, len);
}
static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
static int do_brk_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma,
			 unsigned long newbrk, unsigned long oldbrk,
			 struct list_head *uf);
static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *brkvma,
static int do_brk_flags(struct vma_iterator *vmi, struct vm_area_struct *brkvma,
		unsigned long addr, unsigned long request, unsigned long flags);
SYSCALL_DEFINE1(brk, unsigned long, brk)
{
@@ -194,7 +194,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
	bool populate;
	bool downgraded = false;
	LIST_HEAD(uf);
	MA_STATE(mas, &mm->mm_mt, 0, 0);
	struct vma_iterator vmi;

	if (mmap_write_lock_killable(mm))
		return -EINTR;
@@ -242,8 +242,8 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
		int ret;

		/* Search one past newbrk */
		mas_set(&mas, newbrk);
		brkvma = mas_find(&mas, oldbrk);
		vma_iter_init(&vmi, mm, newbrk);
		brkvma = vma_find(&vmi, oldbrk);
		if (!brkvma || brkvma->vm_start >= oldbrk)
			goto out; /* mapping intersects with an existing non-brk vma. */
		/*
@@ -252,7 +252,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
		 * before calling do_brk_munmap().
		 */
		mm->brk = brk;
		ret = do_brk_munmap(&mas, brkvma, newbrk, oldbrk, &uf);
		ret = do_brk_munmap(&vmi, brkvma, newbrk, oldbrk, &uf);
		if (ret == 1)  {
			downgraded = true;
			goto success;
@@ -270,14 +270,14 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
	 * Only check if the next VMA is within the stack_guard_gap of the
	 * expansion area
	 */
	mas_set(&mas, oldbrk);
	next = mas_find(&mas, newbrk - 1 + PAGE_SIZE + stack_guard_gap);
	vma_iter_init(&vmi, mm, oldbrk);
	next = vma_find(&vmi, newbrk + PAGE_SIZE + stack_guard_gap);
	if (next && newbrk + PAGE_SIZE > vm_start_gap(next))
		goto out;

	brkvma = mas_prev(&mas, mm->start_brk);
	brkvma = vma_prev_limit(&vmi, mm->start_brk);
	/* Ok, looks good - let it rip. */
	if (do_brk_flags(&mas, brkvma, oldbrk, newbrk - oldbrk, 0) < 0)
	if (do_brk_flags(&vmi, brkvma, oldbrk, newbrk - oldbrk, 0) < 0)
		goto out;

	mm->brk = brk;
@@ -2917,8 +2917,8 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size,
}

/*
 * brk_munmap() - Unmap a partial vma.
 * @mas: The maple tree state.
 * brk_munmap() - Unmap a full or partial vma.
 * @vmi: The vma iterator
 * @vma: The vma to be modified
 * @newbrk: the start of the address to unmap
 * @oldbrk: The end of the address to unmap
@@ -2928,7 +2928,7 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size,
 * unmaps a partial VMA mapping.  Does not handle alignment, downgrades lock if
 * possible.
 */
static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
static int do_brk_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma,
			 unsigned long newbrk, unsigned long oldbrk,
			 struct list_head *uf)
{
@@ -2936,14 +2936,14 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
	int ret;

	arch_unmap(mm, newbrk, oldbrk);
	ret = do_mas_align_munmap(mas, vma, mm, newbrk, oldbrk, uf, true);
	ret = do_mas_align_munmap(&vmi->mas, vma, mm, newbrk, oldbrk, uf, true);
	validate_mm_mt(mm);
	return ret;
}

/*
 * do_brk_flags() - Increase the brk vma if the flags match.
 * @mas: The maple tree state.
 * @vmi: The vma iterator
 * @addr: The start address
 * @len: The length of the increase
 * @vma: The vma,
@@ -2953,7 +2953,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
 * do not match then create a new anonymous VMA.  Eventually we may be able to
 * do some brk-specific accounting here.
 */
static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma,
static int do_brk_flags(struct vma_iterator *vmi, struct vm_area_struct *vma,
		unsigned long addr, unsigned long len, unsigned long flags)
{
	struct mm_struct *mm = current->mm;
@@ -2980,8 +2980,7 @@ static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma,
	if (vma && vma->vm_end == addr && !vma_policy(vma) &&
	    can_vma_merge_after(vma, flags, NULL, NULL,
				addr >> PAGE_SHIFT, NULL_VM_UFFD_CTX, NULL)) {
		mas_set_range(mas, vma->vm_start, addr + len - 1);
		if (mas_preallocate(mas, GFP_KERNEL))
		if (vma_iter_prealloc(vmi))
			goto unacct_fail;

		vma_adjust_trans_huge(vma, vma->vm_start, addr + len, 0);
@@ -2991,7 +2990,7 @@ static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma,
		}
		vma->vm_end = addr + len;
		vma->vm_flags |= VM_SOFTDIRTY;
		mas_store_prealloc(mas, vma);
		vma_iter_store(vmi, vma);

		if (vma->anon_vma) {
			anon_vma_interval_tree_post_update_vma(vma);
@@ -3012,8 +3011,7 @@ static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma,
	vma->vm_pgoff = addr >> PAGE_SHIFT;
	vma->vm_flags = flags;
	vma->vm_page_prot = vm_get_page_prot(flags);
	mas_set_range(mas, vma->vm_start, addr + len - 1);
	if (mas_store_gfp(mas, vma, GFP_KERNEL))
	if (vma_iter_store_gfp(vmi, vma, GFP_KERNEL))
		goto mas_store_fail;

	mm->map_count++;
@@ -3042,7 +3040,7 @@ int vm_brk_flags(unsigned long addr, unsigned long request, unsigned long flags)
	int ret;
	bool populate;
	LIST_HEAD(uf);
	MA_STATE(mas, &mm->mm_mt, addr, addr);
	VMA_ITERATOR(vmi, mm, addr);

	len = PAGE_ALIGN(request);
	if (len < request)
@@ -3061,12 +3059,12 @@ int vm_brk_flags(unsigned long addr, unsigned long request, unsigned long flags)
	if (ret)
		goto limits_failed;

	ret = do_mas_munmap(&mas, mm, addr, len, &uf, 0);
	ret = do_mas_munmap(&vmi.mas, mm, addr, len, &uf, 0);
	if (ret)
		goto munmap_failed;

	vma = mas_prev(&mas, 0);
	ret = do_brk_flags(&mas, vma, addr, len, flags);
	vma = vma_prev(&vmi);
	ret = do_brk_flags(&vmi, vma, addr, len, flags);
	populate = ((mm->def_flags & VM_LOCKED) != 0);
	mmap_write_unlock(mm);
	userfaultfd_unmap_complete(mm, &uf);