Commit c03531e0 authored by Andrii Nakryiko's avatar Andrii Nakryiko
Browse files

Merge branch 'verify scalar ids mapping in regsafe()'

Eduard Zingerman says:

====================
Update regsafe() to use check_ids() for scalar values.
Otherwise the following unsafe pattern is accepted by verifier:

  1: r9 = ... some pointer with range X ...
  2: r6 = ... unbound scalar ID=a ...
  3: r7 = ... unbound scalar ID=b ...
  4: if (r6 > r7) goto +1
  5: r6 = r7
  6: if (r6 > X) goto ...
  --- checkpoint ---
  7: r9 += r7
  8: *(u64 *)r9 = Y

This example is unsafe because not all execution paths verify r7 range.
Because of the jump at (4) the verifier would arrive at (6) in two states:
I.  r6{.id=b}, r7{.id=b} via path 1-6;
II. r6{.id=a}, r7{.id=b} via path 1-4, 6.

Currently regsafe() does not call check_ids() for scalar registers,
thus from POV of regsafe() states (I) and (II) are identical.

The change is split in two parts:
- patches #1,2: update for mark_chain_precision() to propagate
  precision marks through scalar IDs.
- patches #3,4: update for regsafe() to use a special version of
  check_ids() for precise scalar values.

Changelog:
- V5 -> V6:
  - check_ids() is modified to disallow mapping different 'old_id' to
    the same 'cur_id', check_scalar_ids() simplified (Andrii);
  - idset_push() updated to return -EFAULT instead of -1 (Andrii);
  - comments fixed in check_ids_in_regsafe() test case
    (Maxim Mikityanskiy);
  - fixed memset warning in states_equal() reported in [4].
- V4 -> V5 (all changes are based on feedback for V4 from Andrii):
  - mark_precise_scalar_ids() error code is updated to EFAULT;
  - bpf_verifier_env::idmap_scratch field type is changed to struct
    bpf_idmap to encapsulate temporary ID generation counter;
  - regsafe() is updated to call scalar_regs_exact() only for
    env->explore_alu_limits case (this had no measurable impact on
    verification duration when tested using veristat).
- V3 -> V4:
  - check_ids() in regsafe() is replaced by check_scalar_ids(),
    as discussed with Andrii in [3],
    Note: I did not transfer Andrii's ack for patch #3 from V3 because
          of the changes to the algorithm.
  - reg_id_scratch is renamed to idset_scratch;
  - mark_precise_scalar_ids() is modified to propagate error from
    idset_push();
  - test cases adjusted according to feedback from Andrii for V3.
- V2 -> V3:
  - u32_hashset for IDs used for range transfer is removed;
  - mark_chain_precision() is updated as discussed with Andrii in [2].
- V1 -> v2:
  - 'rold->precise' and 'rold->id' checks are dropped as unsafe
    (thanks to discussion with Yonghong);
  - patches #3,4 adding tracking of ids used for range transfer in
    order to mitigate performance impact.
- RFC -> V1:
  - Function verifier.c:mark_equal_scalars_as_read() is dropped,
    as it was an incorrect fix for problem solved by commit [3].
  - check_ids() is called only for precise scalar values.
  - Test case updated to use inline assembly.

[V1]  https://lore.kernel.org/bpf/20230526184126.3104040-1-eddyz87@gmail.com/
[V2]  https://lore.kernel.org/bpf/20230530172739.447290-1-eddyz87@gmail.com/
[V3]  https://lore.kernel.org/bpf/20230606222411.1820404-1-eddyz87@gmail.com/
[V4]  https://lore.kernel.org/bpf/20230609210143.2625430-1-eddyz87@gmail.com/
[V5]  https://lore.kernel.org/bpf/20230612160801.2804666-1-eddyz87@gmail.com/
[RFC] https://lore.kernel.org/bpf/20221128163442.280187-1-eddyz87@gmail.com/
[1]   https://gist.github.com/eddyz87/a32ea7e62a27d3c201117c9a39ab4286
[2]   https://lore.kernel.org/bpf/20230530172739.447290-1-eddyz87@gmail.com/T/#mc21009dcd8574b195c1860a98014bb037f16f450
[3]   https://lore.kernel.org/bpf/20230606222411.1820404-1-eddyz87@gmail.com/T/#m89da8eeb2fa8c9ca1202c5d0b6660e1f72e45e04
[4]   https://lore.kernel.org/oe-kbuild-all/202306131550.U3M9AJGm-lkp@intel.com/


====================

Signed-off-by: default avatarAndrii Nakryiko <andrii@kernel.org>
parents 25085b4e 18b89265
Loading
Loading
Loading
Loading
+19 −6
Original line number Diff line number Diff line
@@ -313,11 +313,6 @@ struct bpf_idx_pair {
	u32 idx;
};

struct bpf_id_pair {
	u32 old;
	u32 cur;
};

#define MAX_CALL_FRAMES 8
/* Maximum number of register states that can exist at once */
#define BPF_ID_MAP_SIZE ((MAX_BPF_REG + MAX_BPF_STACK / BPF_REG_SIZE) * MAX_CALL_FRAMES)
@@ -557,6 +552,21 @@ struct backtrack_state {
	u64 stack_masks[MAX_CALL_FRAMES];
};

struct bpf_id_pair {
	u32 old;
	u32 cur;
};

struct bpf_idmap {
	u32 tmp_id_gen;
	struct bpf_id_pair map[BPF_ID_MAP_SIZE];
};

struct bpf_idset {
	u32 count;
	u32 ids[BPF_ID_MAP_SIZE];
};

/* single container for all structs
 * one verifier_env per bpf_check() call
 */
@@ -588,7 +598,10 @@ struct bpf_verifier_env {
	const struct bpf_line_info *prev_linfo;
	struct bpf_verifier_log log;
	struct bpf_subprog_info subprog_info[BPF_MAX_SUBPROGS + 1];
	struct bpf_id_pair idmap_scratch[BPF_ID_MAP_SIZE];
	union {
		struct bpf_idmap idmap_scratch;
		struct bpf_idset idset_scratch;
	};
	struct {
		int *insn_state;
		int *insn_stack;
+183 −23
Original line number Diff line number Diff line
@@ -3779,6 +3779,96 @@ static void mark_all_scalars_imprecise(struct bpf_verifier_env *env, struct bpf_
	}
}
static bool idset_contains(struct bpf_idset *s, u32 id)
{
	u32 i;
	for (i = 0; i < s->count; ++i)
		if (s->ids[i] == id)
			return true;
	return false;
}
static int idset_push(struct bpf_idset *s, u32 id)
{
	if (WARN_ON_ONCE(s->count >= ARRAY_SIZE(s->ids)))
		return -EFAULT;
	s->ids[s->count++] = id;
	return 0;
}
static void idset_reset(struct bpf_idset *s)
{
	s->count = 0;
}
/* Collect a set of IDs for all registers currently marked as precise in env->bt.
 * Mark all registers with these IDs as precise.
 */
static int mark_precise_scalar_ids(struct bpf_verifier_env *env, struct bpf_verifier_state *st)
{
	struct bpf_idset *precise_ids = &env->idset_scratch;
	struct backtrack_state *bt = &env->bt;
	struct bpf_func_state *func;
	struct bpf_reg_state *reg;
	DECLARE_BITMAP(mask, 64);
	int i, fr;
	idset_reset(precise_ids);
	for (fr = bt->frame; fr >= 0; fr--) {
		func = st->frame[fr];
		bitmap_from_u64(mask, bt_frame_reg_mask(bt, fr));
		for_each_set_bit(i, mask, 32) {
			reg = &func->regs[i];
			if (!reg->id || reg->type != SCALAR_VALUE)
				continue;
			if (idset_push(precise_ids, reg->id))
				return -EFAULT;
		}
		bitmap_from_u64(mask, bt_frame_stack_mask(bt, fr));
		for_each_set_bit(i, mask, 64) {
			if (i >= func->allocated_stack / BPF_REG_SIZE)
				break;
			if (!is_spilled_scalar_reg(&func->stack[i]))
				continue;
			reg = &func->stack[i].spilled_ptr;
			if (!reg->id)
				continue;
			if (idset_push(precise_ids, reg->id))
				return -EFAULT;
		}
	}
	for (fr = 0; fr <= st->curframe; ++fr) {
		func = st->frame[fr];
		for (i = BPF_REG_0; i < BPF_REG_10; ++i) {
			reg = &func->regs[i];
			if (!reg->id)
				continue;
			if (!idset_contains(precise_ids, reg->id))
				continue;
			bt_set_frame_reg(bt, fr, i);
		}
		for (i = 0; i < func->allocated_stack / BPF_REG_SIZE; ++i) {
			if (!is_spilled_scalar_reg(&func->stack[i]))
				continue;
			reg = &func->stack[i].spilled_ptr;
			if (!reg->id)
				continue;
			if (!idset_contains(precise_ids, reg->id))
				continue;
			bt_set_frame_slot(bt, fr, i);
		}
	}
	return 0;
}
/*
 * __mark_chain_precision() backtracks BPF program instruction sequence and
 * chain of verifier states making sure that register *regno* (if regno >= 0)
@@ -3910,6 +4000,31 @@ static int __mark_chain_precision(struct bpf_verifier_env *env, int regno)
				bt->frame, last_idx, first_idx, subseq_idx);
		}
		/* If some register with scalar ID is marked as precise,
		 * make sure that all registers sharing this ID are also precise.
		 * This is needed to estimate effect of find_equal_scalars().
		 * Do this at the last instruction of each state,
		 * bpf_reg_state::id fields are valid for these instructions.
		 *
		 * Allows to track precision in situation like below:
		 *
		 *     r2 = unknown value
		 *     ...
		 *   --- state #0 ---
		 *     ...
		 *     r1 = r2                 // r1 and r2 now share the same ID
		 *     ...
		 *   --- state #1 {r1.id = A, r2.id = A} ---
		 *     ...
		 *     if (r2 > 10) goto exit; // find_equal_scalars() assigns range to r1
		 *     ...
		 *   --- state #2 {r1.id = A, r2.id = A} ---
		 *     r3 = r10
		 *     r3 += r1                // need to mark both r1 and r2
		 */
		if (mark_precise_scalar_ids(env, st))
			return -EFAULT;
		if (last_idx < 0) {
			/* we are at the entry into subprog, which
			 * is expected for global funcs, but only if
@@ -12819,12 +12934,14 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
		if (BPF_SRC(insn->code) == BPF_X) {
			struct bpf_reg_state *src_reg = regs + insn->src_reg;
			struct bpf_reg_state *dst_reg = regs + insn->dst_reg;
			bool need_id = src_reg->type == SCALAR_VALUE && !src_reg->id &&
				       !tnum_is_const(src_reg->var_off);
			if (BPF_CLASS(insn->code) == BPF_ALU64) {
				/* case: R1 = R2
				 * copy register state to dest reg
				 */
				if (src_reg->type == SCALAR_VALUE && !src_reg->id)
				if (need_id)
					/* Assign src and dst registers the same ID
					 * that will be used by find_equal_scalars()
					 * to propagate min/max range.
@@ -12843,7 +12960,7 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
				} else if (src_reg->type == SCALAR_VALUE) {
					bool is_src_reg_u32 = src_reg->umax_value <= U32_MAX;
					if (is_src_reg_u32 && !src_reg->id)
					if (is_src_reg_u32 && need_id)
						src_reg->id = ++env->id_gen;
					copy_register_state(dst_reg, src_reg);
					/* Make sure ID is cleared if src_reg is not in u32 range otherwise
@@ -14999,8 +15116,9 @@ static bool range_within(struct bpf_reg_state *old,
 * So we look through our idmap to see if this old id has been seen before.  If
 * so, we require the new id to match; otherwise, we add the id pair to the map.
 */
static bool check_ids(u32 old_id, u32 cur_id, struct bpf_id_pair *idmap)
static bool check_ids(u32 old_id, u32 cur_id, struct bpf_idmap *idmap)
{
	struct bpf_id_pair *map = idmap->map;
	unsigned int i;
	/* either both IDs should be set or both should be zero */
@@ -15011,20 +15129,34 @@ static bool check_ids(u32 old_id, u32 cur_id, struct bpf_id_pair *idmap)
		return true;
	for (i = 0; i < BPF_ID_MAP_SIZE; i++) {
		if (!idmap[i].old) {
		if (!map[i].old) {
			/* Reached an empty slot; haven't seen this id before */
			idmap[i].old = old_id;
			idmap[i].cur = cur_id;
			map[i].old = old_id;
			map[i].cur = cur_id;
			return true;
		}
		if (idmap[i].old == old_id)
			return idmap[i].cur == cur_id;
		if (map[i].old == old_id)
			return map[i].cur == cur_id;
		if (map[i].cur == cur_id)
			return false;
	}
	/* We ran out of idmap slots, which should be impossible */
	WARN_ON_ONCE(1);
	return false;
}
/* Similar to check_ids(), but allocate a unique temporary ID
 * for 'old_id' or 'cur_id' of zero.
 * This makes pairs like '0 vs unique ID', 'unique ID vs 0' valid.
 */
static bool check_scalar_ids(u32 old_id, u32 cur_id, struct bpf_idmap *idmap)
{
	old_id = old_id ? old_id : ++idmap->tmp_id_gen;
	cur_id = cur_id ? cur_id : ++idmap->tmp_id_gen;
	return check_ids(old_id, cur_id, idmap);
}
static void clean_func_state(struct bpf_verifier_env *env,
			     struct bpf_func_state *st)
{
@@ -15123,7 +15255,7 @@ static void clean_live_states(struct bpf_verifier_env *env, int insn,
static bool regs_exact(const struct bpf_reg_state *rold,
		       const struct bpf_reg_state *rcur,
		       struct bpf_id_pair *idmap)
		       struct bpf_idmap *idmap)
{
	return memcmp(rold, rcur, offsetof(struct bpf_reg_state, id)) == 0 &&
	       check_ids(rold->id, rcur->id, idmap) &&
@@ -15132,7 +15264,7 @@ static bool regs_exact(const struct bpf_reg_state *rold,
/* Returns true if (rold safe implies rcur safe) */
static bool regsafe(struct bpf_verifier_env *env, struct bpf_reg_state *rold,
		    struct bpf_reg_state *rcur, struct bpf_id_pair *idmap)
		    struct bpf_reg_state *rcur, struct bpf_idmap *idmap)
{
	if (!(rold->live & REG_LIVE_READ))
		/* explored state didn't use this */
@@ -15169,15 +15301,42 @@ static bool regsafe(struct bpf_verifier_env *env, struct bpf_reg_state *rold,
	switch (base_type(rold->type)) {
	case SCALAR_VALUE:
		if (regs_exact(rold, rcur, idmap))
			return true;
		if (env->explore_alu_limits)
			return false;
		if (env->explore_alu_limits) {
			/* explore_alu_limits disables tnum_in() and range_within()
			 * logic and requires everything to be strict
			 */
			return memcmp(rold, rcur, offsetof(struct bpf_reg_state, id)) == 0 &&
			       check_scalar_ids(rold->id, rcur->id, idmap);
		}
		if (!rold->precise)
			return true;
		/* new val must satisfy old val knowledge */
		/* Why check_ids() for scalar registers?
		 *
		 * Consider the following BPF code:
		 *   1: r6 = ... unbound scalar, ID=a ...
		 *   2: r7 = ... unbound scalar, ID=b ...
		 *   3: if (r6 > r7) goto +1
		 *   4: r6 = r7
		 *   5: if (r6 > X) goto ...
		 *   6: ... memory operation using r7 ...
		 *
		 * First verification path is [1-6]:
		 * - at (4) same bpf_reg_state::id (b) would be assigned to r6 and r7;
		 * - at (5) r6 would be marked <= X, find_equal_scalars() would also mark
		 *   r7 <= X, because r6 and r7 share same id.
		 * Next verification path is [1-4, 6].
		 *
		 * Instruction (6) would be reached in two states:
		 *   I.  r6{.id=b}, r7{.id=b} via path 1-6;
		 *   II. r6{.id=a}, r7{.id=b} via path 1-4, 6.
		 *
		 * Use check_ids() to distinguish these states.
		 * ---
		 * Also verify that new value satisfies old value range knowledge.
		 */
		return range_within(rold, rcur) &&
		       tnum_in(rold->var_off, rcur->var_off);
		       tnum_in(rold->var_off, rcur->var_off) &&
		       check_scalar_ids(rold->id, rcur->id, idmap);
	case PTR_TO_MAP_KEY:
	case PTR_TO_MAP_VALUE:
	case PTR_TO_MEM:
@@ -15223,7 +15382,7 @@ static bool regsafe(struct bpf_verifier_env *env, struct bpf_reg_state *rold,
}
static bool stacksafe(struct bpf_verifier_env *env, struct bpf_func_state *old,
		      struct bpf_func_state *cur, struct bpf_id_pair *idmap)
		      struct bpf_func_state *cur, struct bpf_idmap *idmap)
{
	int i, spi;
@@ -15326,7 +15485,7 @@ static bool stacksafe(struct bpf_verifier_env *env, struct bpf_func_state *old,
}
static bool refsafe(struct bpf_func_state *old, struct bpf_func_state *cur,
		    struct bpf_id_pair *idmap)
		    struct bpf_idmap *idmap)
{
	int i;
@@ -15374,13 +15533,13 @@ static bool func_states_equal(struct bpf_verifier_env *env, struct bpf_func_stat
	for (i = 0; i < MAX_BPF_REG; i++)
		if (!regsafe(env, &old->regs[i], &cur->regs[i],
			     env->idmap_scratch))
			     &env->idmap_scratch))
			return false;
	if (!stacksafe(env, old, cur, env->idmap_scratch))
	if (!stacksafe(env, old, cur, &env->idmap_scratch))
		return false;
	if (!refsafe(old, cur, env->idmap_scratch))
	if (!refsafe(old, cur, &env->idmap_scratch))
		return false;
	return true;
@@ -15395,7 +15554,8 @@ static bool states_equal(struct bpf_verifier_env *env,
	if (old->curframe != cur->curframe)
		return false;
	memset(env->idmap_scratch, 0, sizeof(env->idmap_scratch));
	env->idmap_scratch.tmp_id_gen = env->id_gen;
	memset(&env->idmap_scratch.map, 0, sizeof(env->idmap_scratch.map));
	/* Verification state from speculative execution simulation
	 * must never prune a non-speculative execution one.
@@ -15413,7 +15573,7 @@ static bool states_equal(struct bpf_verifier_env *env,
		return false;
	if (old->active_lock.id &&
	    !check_ids(old->active_lock.id, cur->active_lock.id, env->idmap_scratch))
	    !check_ids(old->active_lock.id, cur->active_lock.id, &env->idmap_scratch))
		return false;
	if (old->active_rcu_lock != cur->active_rcu_lock)
+2 −0
Original line number Diff line number Diff line
@@ -50,6 +50,7 @@
#include "verifier_regalloc.skel.h"
#include "verifier_ringbuf.skel.h"
#include "verifier_runtime_jit.skel.h"
#include "verifier_scalar_ids.skel.h"
#include "verifier_search_pruning.skel.h"
#include "verifier_sock.skel.h"
#include "verifier_spill_fill.skel.h"
@@ -150,6 +151,7 @@ void test_verifier_ref_tracking(void) { RUN(verifier_ref_tracking); }
void test_verifier_regalloc(void)             { RUN(verifier_regalloc); }
void test_verifier_ringbuf(void)              { RUN(verifier_ringbuf); }
void test_verifier_runtime_jit(void)          { RUN(verifier_runtime_jit); }
void test_verifier_scalar_ids(void)           { RUN(verifier_scalar_ids); }
void test_verifier_search_pruning(void)       { RUN(verifier_search_pruning); }
void test_verifier_sock(void)                 { RUN(verifier_sock); }
void test_verifier_spill_fill(void)           { RUN(verifier_spill_fill); }
+659 −0

File added.

Preview size limit exceeded, changes collapsed.

+4 −4
Original line number Diff line number Diff line
@@ -46,7 +46,7 @@
	mark_precise: frame0: regs=r2 stack= before 20\
	mark_precise: frame0: parent state regs=r2 stack=:\
	mark_precise: frame0: last_idx 19 first_idx 10\
	mark_precise: frame0: regs=r2 stack= before 19\
	mark_precise: frame0: regs=r2,r9 stack= before 19\
	mark_precise: frame0: regs=r9 stack= before 18\
	mark_precise: frame0: regs=r8,r9 stack= before 17\
	mark_precise: frame0: regs=r0,r9 stack= before 15\
@@ -106,10 +106,10 @@
	mark_precise: frame0: regs=r2 stack= before 22\
	mark_precise: frame0: parent state regs=r2 stack=:\
	mark_precise: frame0: last_idx 20 first_idx 20\
	mark_precise: frame0: regs=r2 stack= before 20\
	mark_precise: frame0: parent state regs=r2 stack=:\
	mark_precise: frame0: regs=r2,r9 stack= before 20\
	mark_precise: frame0: parent state regs=r2,r9 stack=:\
	mark_precise: frame0: last_idx 19 first_idx 17\
	mark_precise: frame0: regs=r2 stack= before 19\
	mark_precise: frame0: regs=r2,r9 stack= before 19\
	mark_precise: frame0: regs=r9 stack= before 18\
	mark_precise: frame0: regs=r8,r9 stack= before 17\
	mark_precise: frame0: parent state regs= stack=:",