Commit 252c765b authored by Tong Tiangen's avatar Tong Tiangen Committed by Daniel Borkmann
Browse files

riscv, bpf: Add BPF exception tables



When a tracing BPF program attempts to read memory without using the
bpf_probe_read() helper, the verifier marks the load instruction with
the BPF_PROBE_MEM flag. Since the riscv JIT does not currently recognize
this flag it falls back to the interpreter.

Add support for BPF_PROBE_MEM, by appending an exception table to the
BPF program. If the load instruction causes a data abort, the fixup
infrastructure finds the exception table and fixes up the fault, by
clearing the destination register and jumping over the faulting
instruction.

A more generic solution would add a "handler" field to the table entry,
like on x86 and s390. The same issue in ARM64 is fixed in 80083428
("bpf, arm64: Add BPF exception tables").

Signed-off-by: default avatarTong Tiangen <tongtiangen@huawei.com>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
Tested-by: default avatarPu Lehui <pulehui@huawei.com>
Tested-by: default avatarBjörn Töpel <bjorn@kernel.org>
Acked-by: default avatarBjörn Töpel <bjorn@kernel.org>
Link: https://lore.kernel.org/bpf/20211027111822.3801679-1-tongtiangen@huawei.com
parent 03e6a7a9
Loading
Loading
Loading
Loading
+14 −5
Original line number Diff line number Diff line
@@ -11,14 +11,23 @@
#include <linux/module.h>
#include <linux/uaccess.h>

#ifdef CONFIG_BPF_JIT
int rv_bpf_fixup_exception(const struct exception_table_entry *ex, struct pt_regs *regs);
#endif

int fixup_exception(struct pt_regs *regs)
{
	const struct exception_table_entry *fixup;

	fixup = search_exception_tables(regs->epc);
	if (fixup) {
	if (!fixup)
		return 0;

#ifdef CONFIG_BPF_JIT
	if (regs->epc >= BPF_JIT_REGION_START && regs->epc < BPF_JIT_REGION_END)
		return rv_bpf_fixup_exception(fixup, regs);
#endif

	regs->epc = fixup->fixup;
	return 1;
}
	return 0;
}
+1 −0
Original line number Diff line number Diff line
@@ -71,6 +71,7 @@ struct rv_jit_context {
	int ninsns;
	int epilogue_offset;
	int *offset;		/* BPF to RV */
	int nexentries;
	unsigned long flags;
	int stack_size;
};
+149 −36
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@
 *
 */

#include <linux/bitfield.h>
#include <linux/bpf.h>
#include <linux/filter.h>
#include "bpf_jit.h"
@@ -27,6 +28,21 @@ static const int regmap[] = {
	[BPF_REG_AX] =	RV_REG_T0,
};

static const int pt_regmap[] = {
	[RV_REG_A0] = offsetof(struct pt_regs, a0),
	[RV_REG_A1] = offsetof(struct pt_regs, a1),
	[RV_REG_A2] = offsetof(struct pt_regs, a2),
	[RV_REG_A3] = offsetof(struct pt_regs, a3),
	[RV_REG_A4] = offsetof(struct pt_regs, a4),
	[RV_REG_A5] = offsetof(struct pt_regs, a5),
	[RV_REG_S1] = offsetof(struct pt_regs, s1),
	[RV_REG_S2] = offsetof(struct pt_regs, s2),
	[RV_REG_S3] = offsetof(struct pt_regs, s3),
	[RV_REG_S4] = offsetof(struct pt_regs, s4),
	[RV_REG_S5] = offsetof(struct pt_regs, s5),
	[RV_REG_T0] = offsetof(struct pt_regs, t0),
};

enum {
	RV_CTX_F_SEEN_TAIL_CALL =	0,
	RV_CTX_F_SEEN_CALL =		RV_REG_RA,
@@ -440,6 +456,69 @@ static int emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
	return 0;
}

#define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
#define BPF_FIXUP_REG_MASK      GENMASK(31, 27)

int rv_bpf_fixup_exception(const struct exception_table_entry *ex,
				struct pt_regs *regs)
{
	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
	int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);

	*(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
	regs->epc = (unsigned long)&ex->fixup - offset;

	return 1;
}

/* For accesses to BTF pointers, add an entry to the exception table */
static int add_exception_handler(const struct bpf_insn *insn,
				 struct rv_jit_context *ctx,
				 int dst_reg, int insn_len)
{
	struct exception_table_entry *ex;
	unsigned long pc;
	off_t offset;

	if (!ctx->insns || !ctx->prog->aux->extable || BPF_MODE(insn->code) != BPF_PROBE_MEM)
		return 0;

	if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
		return -EINVAL;

	if (WARN_ON_ONCE(insn_len > ctx->ninsns))
		return -EINVAL;

	if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1))
		return -EINVAL;

	ex = &ctx->prog->aux->extable[ctx->nexentries];
	pc = (unsigned long)&ctx->insns[ctx->ninsns - insn_len];

	offset = pc - (long)&ex->insn;
	if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
		return -ERANGE;
	ex->insn = pc;

	/*
	 * Since the extable follows the program, the fixup offset is always
	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
	 * to keep things simple, and put the destination register in the upper
	 * bits. We don't need to worry about buildtime or runtime sort
	 * modifying the upper bits because the table is already sorted, and
	 * isn't part of the main exception table.
	 */
	offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16));
	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
		return -ERANGE;

	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
		FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);

	ctx->nexentries++;
	return 0;
}

int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
		      bool extra_pass)
{
@@ -893,52 +972,86 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,

	/* LDX: dst = *(size *)(src + off) */
	case BPF_LDX | BPF_MEM | BPF_B:
	case BPF_LDX | BPF_MEM | BPF_H:
	case BPF_LDX | BPF_MEM | BPF_W:
	case BPF_LDX | BPF_MEM | BPF_DW:
	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
	{
		int insn_len, insns_start;

		switch (BPF_SIZE(code)) {
		case BPF_B:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				emit(rv_lbu(rd, off, rs), ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T1, off, ctx);
			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
			insns_start = ctx->ninsns;
			emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
			insn_len = ctx->ninsns - insns_start;
			if (insn_is_zext(&insn[1]))
				return 1;
			break;
	case BPF_LDX | BPF_MEM | BPF_H:
		case BPF_H:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				emit(rv_lhu(rd, off, rs), ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T1, off, ctx);
			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
			insns_start = ctx->ninsns;
			emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
			insn_len = ctx->ninsns - insns_start;
			if (insn_is_zext(&insn[1]))
				return 1;
			break;
	case BPF_LDX | BPF_MEM | BPF_W:
		case BPF_W:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				emit(rv_lwu(rd, off, rs), ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T1, off, ctx);
			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
			insns_start = ctx->ninsns;
			emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
			insn_len = ctx->ninsns - insns_start;
			if (insn_is_zext(&insn[1]))
				return 1;
			break;
	case BPF_LDX | BPF_MEM | BPF_DW:
		case BPF_DW:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				emit_ld(rd, off, rs, ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T1, off, ctx);
			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
			insns_start = ctx->ninsns;
			emit_ld(rd, 0, RV_REG_T1, ctx);
			insn_len = ctx->ninsns - insns_start;
			break;
		}

		ret = add_exception_handler(insn, ctx, rd, insn_len);
		if (ret)
			return ret;
		break;
	}
	/* speculation barrier */
	case BPF_ST | BPF_NOSPEC:
		break;
+13 −6
Original line number Diff line number Diff line
@@ -41,12 +41,12 @@ bool bpf_jit_needs_zext(void)

struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
	unsigned int prog_size = 0, extable_size = 0;
	bool tmp_blinded = false, extra_pass = false;
	struct bpf_prog *tmp, *orig_prog = prog;
	int pass = 0, prev_ninsns = 0, i;
	struct rv_jit_data *jit_data;
	struct rv_jit_context *ctx;
	unsigned int image_size = 0;

	if (!prog->jit_requested)
		return orig_prog;
@@ -73,7 +73,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)

	if (ctx->offset) {
		extra_pass = true;
		image_size = sizeof(*ctx->insns) * ctx->ninsns;
		prog_size = sizeof(*ctx->insns) * ctx->ninsns;
		goto skip_init_ctx;
	}

@@ -102,10 +102,13 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
		if (ctx->ninsns == prev_ninsns) {
			if (jit_data->header)
				break;
			/* obtain the actual image size */
			extable_size = prog->aux->num_exentries *
				sizeof(struct exception_table_entry);
			prog_size = sizeof(*ctx->insns) * ctx->ninsns;

			image_size = sizeof(*ctx->insns) * ctx->ninsns;
			jit_data->header =
				bpf_jit_binary_alloc(image_size,
				bpf_jit_binary_alloc(prog_size + extable_size,
						     &jit_data->image,
						     sizeof(u32),
						     bpf_fill_ill_insns);
@@ -130,9 +133,13 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
		goto out_offset;
	}

	if (extable_size)
		prog->aux->extable = (void *)ctx->insns + prog_size;

skip_init_ctx:
	pass++;
	ctx->ninsns = 0;
	ctx->nexentries = 0;

	bpf_jit_build_prologue(ctx);
	if (build_body(ctx, extra_pass, NULL)) {
@@ -143,11 +150,11 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
	bpf_jit_build_epilogue(ctx);

	if (bpf_jit_enable > 1)
		bpf_jit_dump(prog->len, image_size, pass, ctx->insns);
		bpf_jit_dump(prog->len, prog_size, pass, ctx->insns);

	prog->bpf_func = (void *)ctx->insns;
	prog->jited = 1;
	prog->jited_len = image_size;
	prog->jited_len = prog_size;

	bpf_flush_icache(jit_data->header, ctx->insns + ctx->ninsns);