Commit 797b84f7 authored by Martin KaFai Lau's avatar Martin KaFai Lau Committed by Alexei Starovoitov
Browse files

bpf: Support kernel function call in x86-32



This patch adds kernel function call support to the x86-32 bpf jit.

Signed-off-by: default avatarMartin KaFai Lau <kafai@fb.com>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/20210325015149.1545267-1-kafai@fb.com
parent e6ac2450
Loading
Loading
Loading
Loading
+198 −0
Original line number Diff line number Diff line
@@ -1390,6 +1390,19 @@ static inline void emit_push_r64(const u8 src[], u8 **pprog)
	*pprog = prog;
}

static void emit_push_r32(const u8 src[], u8 **pprog)
{
	u8 *prog = *pprog;
	int cnt = 0;

	/* mov ecx,dword ptr [ebp+off] */
	EMIT3(0x8B, add_2reg(0x40, IA32_EBP, IA32_ECX), STACK_VAR(src_lo));
	/* push ecx */
	EMIT1(0x51);

	*pprog = prog;
}

static u8 get_cond_jmp_opcode(const u8 op, bool is_cmp_lo)
{
	u8 jmp_cond;
@@ -1459,6 +1472,174 @@ static u8 get_cond_jmp_opcode(const u8 op, bool is_cmp_lo)
	return jmp_cond;
}

/* i386 kernel compiles with "-mregparm=3".  From gcc document:
 *
 * ==== snippet ====
 * regparm (number)
 *	On x86-32 targets, the regparm attribute causes the compiler
 *	to pass arguments number one to (number) if they are of integral
 *	type in registers EAX, EDX, and ECX instead of on the stack.
 *	Functions that take a variable number of arguments continue
 *	to be passed all of their arguments on the stack.
 * ==== snippet ====
 *
 * The first three args of a function will be considered for
 * putting into the 32bit register EAX, EDX, and ECX.
 *
 * Two 32bit registers are used to pass a 64bit arg.
 *
 * For example,
 * void foo(u32 a, u32 b, u32 c, u32 d):
 *	u32 a: EAX
 *	u32 b: EDX
 *	u32 c: ECX
 *	u32 d: stack
 *
 * void foo(u64 a, u32 b, u32 c):
 *	u64 a: EAX (lo32) EDX (hi32)
 *	u32 b: ECX
 *	u32 c: stack
 *
 * void foo(u32 a, u64 b, u32 c):
 *	u32 a: EAX
 *	u64 b: EDX (lo32) ECX (hi32)
 *	u32 c: stack
 *
 * void foo(u32 a, u32 b, u64 c):
 *	u32 a: EAX
 *	u32 b: EDX
 *	u64 c: stack
 *
 * The return value will be stored in the EAX (and EDX for 64bit value).
 *
 * For example,
 * u32 foo(u32 a, u32 b, u32 c):
 *	return value: EAX
 *
 * u64 foo(u32 a, u32 b, u32 c):
 *	return value: EAX (lo32) EDX (hi32)
 *
 * Notes:
 *	The verifier only accepts function having integer and pointers
 *	as its args and return value, so it does not have
 *	struct-by-value.
 *
 * emit_kfunc_call() finds out the btf_func_model by calling
 * bpf_jit_find_kfunc_model().  A btf_func_model
 * has the details about the number of args, size of each arg,
 * and the size of the return value.
 *
 * It first decides how many args can be passed by EAX, EDX, and ECX.
 * That will decide what args should be pushed to the stack:
 * [first_stack_regno, last_stack_regno] are the bpf regnos
 * that should be pushed to the stack.
 *
 * It will first push all args to the stack because the push
 * will need to use ECX.  Then, it moves
 * [BPF_REG_1, first_stack_regno) to EAX, EDX, and ECX.
 *
 * When emitting a call (0xE8), it needs to figure out
 * the jmp_offset relative to the jit-insn address immediately
 * following the call (0xE8) instruction.  At this point, it knows
 * the end of the jit-insn address after completely translated the
 * current (BPF_JMP | BPF_CALL) bpf-insn.  It is passed as "end_addr"
 * to the emit_kfunc_call().  Thus, it can learn the "immediate-follow-call"
 * address by figuring out how many jit-insn is generated between
 * the call (0xE8) and the end_addr:
 *	- 0-1 jit-insn (3 bytes each) to restore the esp pointer if there
 *	  is arg pushed to the stack.
 *	- 0-2 jit-insns (3 bytes each) to handle the return value.
 */
static int emit_kfunc_call(const struct bpf_prog *bpf_prog, u8 *end_addr,
			   const struct bpf_insn *insn, u8 **pprog)
{
	const u8 arg_regs[] = { IA32_EAX, IA32_EDX, IA32_ECX };
	int i, cnt = 0, first_stack_regno, last_stack_regno;
	int free_arg_regs = ARRAY_SIZE(arg_regs);
	const struct btf_func_model *fm;
	int bytes_in_stack = 0;
	const u8 *cur_arg_reg;
	u8 *prog = *pprog;
	s64 jmp_offset;

	fm = bpf_jit_find_kfunc_model(bpf_prog, insn);
	if (!fm)
		return -EINVAL;

	first_stack_regno = BPF_REG_1;
	for (i = 0; i < fm->nr_args; i++) {
		int regs_needed = fm->arg_size[i] > sizeof(u32) ? 2 : 1;

		if (regs_needed > free_arg_regs)
			break;

		free_arg_regs -= regs_needed;
		first_stack_regno++;
	}

	/* Push the args to the stack */
	last_stack_regno = BPF_REG_0 + fm->nr_args;
	for (i = last_stack_regno; i >= first_stack_regno; i--) {
		if (fm->arg_size[i - 1] > sizeof(u32)) {
			emit_push_r64(bpf2ia32[i], &prog);
			bytes_in_stack += 8;
		} else {
			emit_push_r32(bpf2ia32[i], &prog);
			bytes_in_stack += 4;
		}
	}

	cur_arg_reg = &arg_regs[0];
	for (i = BPF_REG_1; i < first_stack_regno; i++) {
		/* mov e[adc]x,dword ptr [ebp+off] */
		EMIT3(0x8B, add_2reg(0x40, IA32_EBP, *cur_arg_reg++),
		      STACK_VAR(bpf2ia32[i][0]));
		if (fm->arg_size[i - 1] > sizeof(u32))
			/* mov e[adc]x,dword ptr [ebp+off] */
			EMIT3(0x8B, add_2reg(0x40, IA32_EBP, *cur_arg_reg++),
			      STACK_VAR(bpf2ia32[i][1]));
	}

	if (bytes_in_stack)
		/* add esp,"bytes_in_stack" */
		end_addr -= 3;

	/* mov dword ptr [ebp+off],edx */
	if (fm->ret_size > sizeof(u32))
		end_addr -= 3;

	/* mov dword ptr [ebp+off],eax */
	if (fm->ret_size)
		end_addr -= 3;

	jmp_offset = (u8 *)__bpf_call_base + insn->imm - end_addr;
	if (!is_simm32(jmp_offset)) {
		pr_err("unsupported BPF kernel function jmp_offset:%lld\n",
		       jmp_offset);
		return -EINVAL;
	}

	EMIT1_off32(0xE8, jmp_offset);

	if (fm->ret_size)
		/* mov dword ptr [ebp+off],eax */
		EMIT3(0x89, add_2reg(0x40, IA32_EBP, IA32_EAX),
		      STACK_VAR(bpf2ia32[BPF_REG_0][0]));

	if (fm->ret_size > sizeof(u32))
		/* mov dword ptr [ebp+off],edx */
		EMIT3(0x89, add_2reg(0x40, IA32_EBP, IA32_EDX),
		      STACK_VAR(bpf2ia32[BPF_REG_0][1]));

	if (bytes_in_stack)
		/* add esp,"bytes_in_stack" */
		EMIT3(0x83, add_1reg(0xC0, IA32_ESP), bytes_in_stack);

	*pprog = prog;

	return 0;
}

static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
		  int oldproglen, struct jit_context *ctx)
{
@@ -1888,6 +2069,18 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
			if (insn->src_reg == BPF_PSEUDO_CALL)
				goto notyet;

			if (insn->src_reg == BPF_PSEUDO_KFUNC_CALL) {
				int err;

				err = emit_kfunc_call(bpf_prog,
						      image + addrs[i],
						      insn, &prog);

				if (err)
					return err;
				break;
			}

			func = (u8 *) __bpf_call_base + imm32;
			jmp_offset = func - (image + addrs[i]);

@@ -2393,3 +2586,8 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
					   tmp : orig_prog);
	return prog;
}

bool bpf_jit_supports_kfunc_call(void)
{
	return true;
}