Commit 7ff94f27 authored by Kui-Feng Lee, committed by Alexei Starovoitov
Browse files

bpf: keep a reference to the mm, in case the task is dead.



Fix the system crash that happens when a task iterator travels through
the VMAs of tasks.

In task iterators, we used to access mm by following the pointer on
the task_struct; however, the death of a task will clear the pointer,
even though we still hold the task_struct.  That can cause an
unexpected crash due to a NULL pointer dereference when an iterator is
visiting a task that dies during the visit.  Keeping a reference to mm
on the iterator ensures we always have a valid pointer to mm.

Co-developed-by: Song Liu <song@kernel.org>
Signed-off-by: Song Liu <song@kernel.org>
Signed-off-by: Kui-Feng Lee <kuifeng@meta.com>
Reported-by: Nathan Slingerland <slinger@meta.com>
Acked-by: Yonghong Song <yhs@fb.com>
Link: https://lore.kernel.org/r/20221216221855.4122288-2-kuifeng@meta.com


Signed-off-by: Alexei Starovoitov <ast@kernel.org>
parent 8f161ca1
Loading
Loading
Loading
Loading
+27 −12
Original line number Diff line number Diff line
@@ -438,6 +438,7 @@ struct bpf_iter_seq_task_vma_info {
	 */
	struct bpf_iter_seq_task_common common;
	struct task_struct *task;
	struct mm_struct *mm;
	struct vm_area_struct *vma;
	u32 tid;
	unsigned long prev_vm_start;
@@ -456,16 +457,19 @@ task_vma_seq_get_next(struct bpf_iter_seq_task_vma_info *info)
	enum bpf_task_vma_iter_find_op op;
	struct vm_area_struct *curr_vma;
	struct task_struct *curr_task;
	struct mm_struct *curr_mm;
	u32 saved_tid = info->tid;

	/* If this function returns a non-NULL vma, it holds a reference to
	 * the task_struct, and holds read lock on vma->mm->mmap_lock.
	 * the task_struct, holds a refcount on mm->mm_users, and holds
	 * read lock on vma->mm->mmap_lock.
	 * If this function returns NULL, it does not hold any reference or
	 * lock.
	 */
	if (info->task) {
		curr_task = info->task;
		curr_vma = info->vma;
		curr_mm = info->mm;
		/* In case of lock contention, drop mmap_lock to unblock
		 * the writer.
		 *
@@ -504,13 +508,15 @@ task_vma_seq_get_next(struct bpf_iter_seq_task_vma_info *info)
		 *    4.2) VMA2 and VMA2' covers different ranges, process
		 *         VMA2'.
		 */
		if (mmap_lock_is_contended(curr_task->mm)) {
		if (mmap_lock_is_contended(curr_mm)) {
			info->prev_vm_start = curr_vma->vm_start;
			info->prev_vm_end = curr_vma->vm_end;
			op = task_vma_iter_find_vma;
			mmap_read_unlock(curr_task->mm);
			if (mmap_read_lock_killable(curr_task->mm))
			mmap_read_unlock(curr_mm);
			if (mmap_read_lock_killable(curr_mm)) {
				mmput(curr_mm);
				goto finish;
			}
		} else {
			op = task_vma_iter_next_vma;
		}
@@ -535,42 +541,47 @@ task_vma_seq_get_next(struct bpf_iter_seq_task_vma_info *info)
			op = task_vma_iter_find_vma;
		}

		if (!curr_task->mm)
		curr_mm = get_task_mm(curr_task);
		if (!curr_mm)
			goto next_task;

		if (mmap_read_lock_killable(curr_task->mm))
		if (mmap_read_lock_killable(curr_mm)) {
			mmput(curr_mm);
			goto finish;
		}
	}

	switch (op) {
	case task_vma_iter_first_vma:
		curr_vma = find_vma(curr_task->mm, 0);
		curr_vma = find_vma(curr_mm, 0);
		break;
	case task_vma_iter_next_vma:
		curr_vma = find_vma(curr_task->mm, curr_vma->vm_end);
		curr_vma = find_vma(curr_mm, curr_vma->vm_end);
		break;
	case task_vma_iter_find_vma:
		/* We dropped mmap_lock so it is necessary to use find_vma
		 * to find the next vma. This is similar to the  mechanism
		 * in show_smaps_rollup().
		 */
		curr_vma = find_vma(curr_task->mm, info->prev_vm_end - 1);
		curr_vma = find_vma(curr_mm, info->prev_vm_end - 1);
		/* case 1) and 4.2) above just use curr_vma */

		/* check for case 2) or case 4.1) above */
		if (curr_vma &&
		    curr_vma->vm_start == info->prev_vm_start &&
		    curr_vma->vm_end == info->prev_vm_end)
			curr_vma = find_vma(curr_task->mm, curr_vma->vm_end);
			curr_vma = find_vma(curr_mm, curr_vma->vm_end);
		break;
	}
	if (!curr_vma) {
		/* case 3) above, or case 2) 4.1) with vma->next == NULL */
		mmap_read_unlock(curr_task->mm);
		mmap_read_unlock(curr_mm);
		mmput(curr_mm);
		goto next_task;
	}
	info->task = curr_task;
	info->vma = curr_vma;
	info->mm = curr_mm;
	return curr_vma;

next_task:
@@ -579,6 +590,7 @@ task_vma_seq_get_next(struct bpf_iter_seq_task_vma_info *info)

	put_task_struct(curr_task);
	info->task = NULL;
	info->mm = NULL;
	info->tid++;
	goto again;

@@ -587,6 +599,7 @@ task_vma_seq_get_next(struct bpf_iter_seq_task_vma_info *info)
		put_task_struct(curr_task);
	info->task = NULL;
	info->vma = NULL;
	info->mm = NULL;
	return NULL;
}

@@ -658,7 +671,9 @@ static void task_vma_seq_stop(struct seq_file *seq, void *v)
		 */
		info->prev_vm_start = ~0UL;
		info->prev_vm_end = info->vma->vm_end;
		mmap_read_unlock(info->task->mm);
		mmap_read_unlock(info->mm);
		mmput(info->mm);
		info->mm = NULL;
		put_task_struct(info->task);
		info->task = NULL;
	}