Commit b239da34 authored by Kumar Kartikeya Dwivedi's avatar Kumar Kartikeya Dwivedi Committed by Alexei Starovoitov
Browse files

bpf: Add helper macro bpf_for_each_reg_in_vstate



For a lot of use cases in future patches, we will want to modify the
state of registers part of some same 'group' (e.g. same ref_obj_id). It
won't just be limited to releasing reference state, but setting a type
flag dynamically based on certain actions, etc.

Hence, we need a way to easily pass a callback to the function that
iterates over all registers in current bpf_verifier_state in all frames
upto (and including) the curframe.

While in C++ we would be able to easily use a lambda to pass state and
the callback together, sadly we aren't using C++ in the kernel. The next
best thing to avoid defining a function for each case seems like
statement expressions in GNU C. The kernel already uses them heavily,
hence they can passed to the macro in the style of a lambda. The
statement expression will then be substituted in the for loop bodies.

Variables __state and __reg are set to current bpf_func_state and reg
for each invocation of the expression inside the passed in verifier
state.

Then, convert mark_ptr_or_null_regs, clear_all_pkt_pointers,
release_reference, find_good_pkt_pointers, find_equal_scalars to
use bpf_for_each_reg_in_vstate.

Signed-off-by: default avatarKumar Kartikeya Dwivedi <memxor@gmail.com>
Link: https://lore.kernel.org/r/20220904204145.3089-16-memxor@gmail.com


Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parent cc487558
Loading
Loading
Loading
Loading
+21 −0
Original line number Diff line number Diff line
@@ -348,6 +348,27 @@ struct bpf_verifier_state {
	     iter < frame->allocated_stack / BPF_REG_SIZE;		\
	     iter++, reg = bpf_get_spilled_reg(iter, frame))

/* Invoke __expr over regsiters in __vst, setting __state and __reg */
#define bpf_for_each_reg_in_vstate(__vst, __state, __reg, __expr)   \
	({                                                               \
		struct bpf_verifier_state *___vstate = __vst;            \
		int ___i, ___j;                                          \
		for (___i = 0; ___i <= ___vstate->curframe; ___i++) {    \
			struct bpf_reg_state *___regs;                   \
			__state = ___vstate->frame[___i];                \
			___regs = __state->regs;                         \
			for (___j = 0; ___j < MAX_BPF_REG; ___j++) {     \
				__reg = &___regs[___j];                  \
				(void)(__expr);                          \
			}                                                \
			bpf_for_each_spilled_reg(___j, __state, __reg) { \
				if (!__reg)                              \
					continue;                        \
				(void)(__expr);                          \
			}                                                \
		}                                                        \
	})

/* linked list of verifier states used to prune search */
struct bpf_verifier_state_list {
	struct bpf_verifier_state state;
+28 −107
Original line number Diff line number Diff line
@@ -6513,31 +6513,15 @@ static int check_func_proto(const struct bpf_func_proto *fn, int func_id)
/* Packet data might have moved, any old PTR_TO_PACKET[_META,_END]
 * are now invalid, so turn them into unknown SCALAR_VALUE.
 */
static void __clear_all_pkt_pointers(struct bpf_verifier_env *env,
				     struct bpf_func_state *state)
static void clear_all_pkt_pointers(struct bpf_verifier_env *env)
{
	struct bpf_reg_state *regs = state->regs, *reg;
	int i;

	for (i = 0; i < MAX_BPF_REG; i++)
		if (reg_is_pkt_pointer_any(&regs[i]))
			mark_reg_unknown(env, regs, i);
	struct bpf_func_state *state;
	struct bpf_reg_state *reg;

	bpf_for_each_spilled_reg(i, state, reg) {
		if (!reg)
			continue;
	bpf_for_each_reg_in_vstate(env->cur_state, state, reg, ({
		if (reg_is_pkt_pointer_any(reg))
			__mark_reg_unknown(env, reg);
	}
}

static void clear_all_pkt_pointers(struct bpf_verifier_env *env)
{
	struct bpf_verifier_state *vstate = env->cur_state;
	int i;

	for (i = 0; i <= vstate->curframe; i++)
		__clear_all_pkt_pointers(env, vstate->frame[i]);
	}));
}

enum {
@@ -6566,41 +6550,24 @@ static void mark_pkt_end(struct bpf_verifier_state *vstate, int regn, bool range
		reg->range = AT_PKT_END;
}

static void release_reg_references(struct bpf_verifier_env *env,
				   struct bpf_func_state *state,
				   int ref_obj_id)
{
	struct bpf_reg_state *regs = state->regs, *reg;
	int i;

	for (i = 0; i < MAX_BPF_REG; i++)
		if (regs[i].ref_obj_id == ref_obj_id)
			mark_reg_unknown(env, regs, i);

	bpf_for_each_spilled_reg(i, state, reg) {
		if (!reg)
			continue;
		if (reg->ref_obj_id == ref_obj_id)
			__mark_reg_unknown(env, reg);
	}
}

/* The pointer with the specified id has released its reference to kernel
 * resources. Identify all copies of the same pointer and clear the reference.
 */
static int release_reference(struct bpf_verifier_env *env,
			     int ref_obj_id)
{
	struct bpf_verifier_state *vstate = env->cur_state;
	struct bpf_func_state *state;
	struct bpf_reg_state *reg;
	int err;
	int i;

	err = release_reference_state(cur_func(env), ref_obj_id);
	if (err)
		return err;

	for (i = 0; i <= vstate->curframe; i++)
		release_reg_references(env, vstate->frame[i], ref_obj_id);
	bpf_for_each_reg_in_vstate(env->cur_state, state, reg, ({
		if (reg->ref_obj_id == ref_obj_id)
			__mark_reg_unknown(env, reg);
	}));

	return 0;
}
@@ -9335,34 +9302,14 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
	return 0;
}

static void __find_good_pkt_pointers(struct bpf_func_state *state,
				     struct bpf_reg_state *dst_reg,
				     enum bpf_reg_type type, int new_range)
{
	struct bpf_reg_state *reg;
	int i;

	for (i = 0; i < MAX_BPF_REG; i++) {
		reg = &state->regs[i];
		if (reg->type == type && reg->id == dst_reg->id)
			/* keep the maximum range already checked */
			reg->range = max(reg->range, new_range);
	}

	bpf_for_each_spilled_reg(i, state, reg) {
		if (!reg)
			continue;
		if (reg->type == type && reg->id == dst_reg->id)
			reg->range = max(reg->range, new_range);
	}
}

static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
				   struct bpf_reg_state *dst_reg,
				   enum bpf_reg_type type,
				   bool range_right_open)
{
	int new_range, i;
	struct bpf_func_state *state;
	struct bpf_reg_state *reg;
	int new_range;

	if (dst_reg->off < 0 ||
	    (dst_reg->off == 0 && range_right_open))
@@ -9427,9 +9374,11 @@ static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
	 * the range won't allow anything.
	 * dst_reg->off is known < MAX_PACKET_OFF, therefore it fits in a u16.
	 */
	for (i = 0; i <= vstate->curframe; i++)
		__find_good_pkt_pointers(vstate->frame[i], dst_reg, type,
					 new_range);
	bpf_for_each_reg_in_vstate(vstate, state, reg, ({
		if (reg->type == type && reg->id == dst_reg->id)
			/* keep the maximum range already checked */
			reg->range = max(reg->range, new_range);
	}));
}

static int is_branch32_taken(struct bpf_reg_state *reg, u32 val, u8 opcode)
@@ -9918,7 +9867,7 @@ static void mark_ptr_or_null_reg(struct bpf_func_state *state,

		if (!reg_may_point_to_spin_lock(reg)) {
			/* For not-NULL ptr, reg->ref_obj_id will be reset
			 * in release_reg_references().
			 * in release_reference().
			 *
			 * reg->id is still used by spin_lock ptr. Other
			 * than spin_lock ptr type, reg->id can be reset.
@@ -9928,22 +9877,6 @@ static void mark_ptr_or_null_reg(struct bpf_func_state *state,
	}
}

static void __mark_ptr_or_null_regs(struct bpf_func_state *state, u32 id,
				    bool is_null)
{
	struct bpf_reg_state *reg;
	int i;

	for (i = 0; i < MAX_BPF_REG; i++)
		mark_ptr_or_null_reg(state, &state->regs[i], id, is_null);

	bpf_for_each_spilled_reg(i, state, reg) {
		if (!reg)
			continue;
		mark_ptr_or_null_reg(state, reg, id, is_null);
	}
}

/* The logic is similar to find_good_pkt_pointers(), both could eventually
 * be folded together at some point.
 */
@@ -9951,10 +9884,9 @@ static void mark_ptr_or_null_regs(struct bpf_verifier_state *vstate, u32 regno,
				  bool is_null)
{
	struct bpf_func_state *state = vstate->frame[vstate->curframe];
	struct bpf_reg_state *regs = state->regs;
	struct bpf_reg_state *regs = state->regs, *reg;
	u32 ref_obj_id = regs[regno].ref_obj_id;
	u32 id = regs[regno].id;
	int i;

	if (ref_obj_id && ref_obj_id == id && is_null)
		/* regs[regno] is in the " == NULL" branch.
@@ -9963,8 +9895,9 @@ static void mark_ptr_or_null_regs(struct bpf_verifier_state *vstate, u32 regno,
		 */
		WARN_ON_ONCE(release_reference_state(state, id));

	for (i = 0; i <= vstate->curframe; i++)
		__mark_ptr_or_null_regs(vstate->frame[i], id, is_null);
	bpf_for_each_reg_in_vstate(vstate, state, reg, ({
		mark_ptr_or_null_reg(state, reg, id, is_null);
	}));
}

static bool try_match_pkt_pointers(const struct bpf_insn *insn,
@@ -10077,23 +10010,11 @@ static void find_equal_scalars(struct bpf_verifier_state *vstate,
{
	struct bpf_func_state *state;
	struct bpf_reg_state *reg;
	int i, j;

	for (i = 0; i <= vstate->curframe; i++) {
		state = vstate->frame[i];
		for (j = 0; j < MAX_BPF_REG; j++) {
			reg = &state->regs[j];
			if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)
				*reg = *known_reg;
		}

		bpf_for_each_spilled_reg(j, state, reg) {
			if (!reg)
				continue;
	bpf_for_each_reg_in_vstate(vstate, state, reg, ({
		if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)
			*reg = *known_reg;
		}
	}
	}));
}

static int check_cond_jmp_op(struct bpf_verifier_env *env,