Commit 7b805819 authored by Alexei Starovoitov
Browse files

Merge branch 'fix-ktls-with-sk_skb_verdict'



John Fastabend says:

====================
If a socket is running a BPF_SK_SKB_STREAM_VERDICT program and KTLS is
enabled the data stream may be broken if both TLS stream parser and
BPF stream parser try to handle data. Fix this here by making KTLS
stream parser run first to ensure TLS messages are received correctly
and then calling the verdict program. This is analogous to how we handle
a similar conflict on the TX side.

Note, this is a fix but it doesn't make sense to push this late to
bpf tree so targeting bpf-next and keeping fixes tags.
====================

Signed-off-by: Alexei Starovoitov <ast@kernel.org>
parents df8fe57c 463bac5f
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -437,4 +437,12 @@ static inline void psock_progs_drop(struct sk_psock_progs *progs)
	psock_set_prog(&progs->skb_verdict, NULL);
}

int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb);

/* Report whether a BPF stream parser is attached and enabled on @psock.
 * Safe to call with a NULL psock (returns false).
 */
static inline bool sk_psock_strp_enabled(struct sk_psock *psock)
{
	return psock ? psock->parser.enabled : false;
}
#endif /* _LINUX_SKMSG_H */
+9 −0
Original line number Diff line number Diff line
@@ -571,6 +571,15 @@ static inline bool tls_sw_has_ctx_tx(const struct sock *sk)
	return !!tls_sw_ctx_tx(ctx);
}

/* Report whether @sk has a software TLS RX context attached; false when
 * the socket has no TLS context at all.
 */
static inline bool tls_sw_has_ctx_rx(const struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	return ctx ? !!tls_sw_ctx_rx(ctx) : false;
}

void tls_sw_write_space(struct sock *sk, struct tls_context *ctx);
void tls_device_write_space(struct sock *sk, struct tls_context *ctx);

+74 −24
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@

#include <net/sock.h>
#include <net/tcp.h>
#include <net/tls.h>

static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
{
@@ -682,13 +683,75 @@ static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
	return container_of(parser, struct sk_psock, parser);
}

static void sk_psock_verdict_apply(struct sk_psock *psock,
				   struct sk_buff *skb, int verdict)
/* Forward an skb that a verdict program marked for redirect to the
 * socket the program stashed in the skb.  On success the skb is queued
 * on the destination psock's ingress_skb list and its backlog work is
 * scheduled; on any failure (no target, target dead, TX disabled, or
 * no buffer space) the skb is dropped.
 */
static void sk_psock_skb_redirect(struct sk_psock *psock, struct sk_buff *skb)
{
	struct sk_psock *psock_other;
	struct sock *sk_other;
	bool ingress;

	/* The redirect target was stored in the skb by the BPF program. */
	sk_other = tcp_skb_bpf_redirect_fetch(skb);
	if (unlikely(!sk_other)) {
		kfree_skb(skb);
		return;
	}
	psock_other = sk_psock(sk_other);
	/* Drop if the target has no psock, is being torn down, or has
	 * transmit disabled.
	 */
	if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
	    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) {
		kfree_skb(skb);
		return;
	}

	ingress = tcp_skb_bpf_ingress(skb);
	/* Respect the target's limits: writeable space for egress,
	 * receive-buffer space for ingress; otherwise drop the skb.
	 */
	if ((!ingress && sock_writeable(sk_other)) ||
	    (ingress &&
	     atomic_read(&sk_other->sk_rmem_alloc) <=
	     sk_other->sk_rcvbuf)) {
		if (!ingress)
			skb_set_owner_w(skb, sk_other);
		skb_queue_tail(&psock_other->ingress_skb, skb);
		schedule_work(&psock_other->work);
	} else {
		kfree_skb(skb);
	}
}

/* Apply the verdict for an skb handled on the TLS receive path.  Only
 * __SK_REDIRECT needs action here; __SK_PASS and __SK_DROP are left
 * for the TLS caller to finish handling.
 */
static void sk_psock_tls_verdict_apply(struct sk_psock *psock,
				       struct sk_buff *skb, int verdict)
{
	if (verdict == __SK_REDIRECT)
		sk_psock_skb_redirect(psock, skb);
}

/* Run the psock's skb verdict program on an skb delivered via the KTLS
 * receive path (where the BPF stream parser must not run, per the
 * TLS/BPF parser conflict this series fixes).  Returns the mapped
 * verdict (__SK_PASS, __SK_DROP or __SK_REDIRECT); a redirect is
 * carried out here, pass/drop are left to the TLS caller.
 */
int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb)
{
	struct bpf_prog *prog;
	int ret = __SK_PASS;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.skb_verdict);
	if (likely(prog)) {
		/* Clear any stale redirect state before running the
		 * program, then map its return code together with the
		 * socket it (possibly) stashed for redirect.
		 */
		tcp_skb_bpf_redirect_clear(skb);
		ret = sk_psock_bpf_run(psock, prog, skb);
		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
	}
	rcu_read_unlock();
	sk_psock_tls_verdict_apply(psock, skb, ret);
	return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_tls_strp_read);

static void sk_psock_verdict_apply(struct sk_psock *psock,
				   struct sk_buff *skb, int verdict)
{
	struct sock *sk_other;

	switch (verdict) {
	case __SK_PASS:
		sk_other = psock->sk;
@@ -707,25 +770,8 @@ static void sk_psock_verdict_apply(struct sk_psock *psock,
		}
		goto out_free;
	case __SK_REDIRECT:
		sk_other = tcp_skb_bpf_redirect_fetch(skb);
		if (unlikely(!sk_other))
			goto out_free;
		psock_other = sk_psock(sk_other);
		if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
			goto out_free;
		ingress = tcp_skb_bpf_ingress(skb);
		if ((!ingress && sock_writeable(sk_other)) ||
		    (ingress &&
		     atomic_read(&sk_other->sk_rmem_alloc) <=
		     sk_other->sk_rcvbuf)) {
			if (!ingress)
				skb_set_owner_w(skb, sk_other);
			skb_queue_tail(&psock_other->ingress_skb, skb);
			schedule_work(&psock_other->work);
		sk_psock_skb_redirect(psock, skb);
		break;
		}
		/* fall-through */
	case __SK_DROP:
		/* fall-through */
	default:
@@ -779,10 +825,14 @@ static void sk_psock_strp_data_ready(struct sock *sk)
	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock)) {
		if (tls_sw_has_ctx_rx(sk)) {
			psock->parser.saved_data_ready(sk);
		} else {
			write_lock_bh(&sk->sk_callback_lock);
			strp_data_ready(&psock->parser.strp);
			write_unlock_bh(&sk->sk_callback_lock);
		}
	}
	rcu_read_unlock();
}

+18 −2
Original line number Diff line number Diff line
@@ -1742,6 +1742,7 @@ int tls_sw_recvmsg(struct sock *sk,
	long timeo;
	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
	bool is_peek = flags & MSG_PEEK;
	bool bpf_strp_enabled;
	int num_async = 0;
	int pending;

@@ -1752,6 +1753,7 @@ int tls_sw_recvmsg(struct sock *sk,

	psock = sk_psock_get(sk);
	lock_sock(sk);
	bpf_strp_enabled = sk_psock_strp_enabled(psock);

	/* Process pending decrypted records. It must be non-zero-copy */
	err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false,
@@ -1805,11 +1807,12 @@ int tls_sw_recvmsg(struct sock *sk,

		if (to_decrypt <= len && !is_kvec && !is_peek &&
		    ctx->control == TLS_RECORD_TYPE_DATA &&
		    prot->version != TLS_1_3_VERSION)
		    prot->version != TLS_1_3_VERSION &&
		    !bpf_strp_enabled)
			zc = true;

		/* Do not use async mode if record is non-data */
		if (ctx->control == TLS_RECORD_TYPE_DATA)
		if (ctx->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
			async_capable = ctx->async_capable;
		else
			async_capable = false;
@@ -1859,6 +1862,19 @@ int tls_sw_recvmsg(struct sock *sk,
			goto pick_next_record;

		if (!zc) {
			if (bpf_strp_enabled) {
				err = sk_psock_tls_strp_read(psock, skb);
				if (err != __SK_PASS) {
					rxm->offset = rxm->offset + rxm->full_len;
					rxm->full_len = 0;
					if (err == __SK_DROP)
						consume_skb(skb);
					ctx->recv_pkt = NULL;
					__strp_unpause(&ctx->strp);
					continue;
				}
			}

			if (rxm->full_len > len) {
				retain_skb = true;
				chunk = len;
+45 −1
Original line number Diff line number Diff line
@@ -79,11 +79,18 @@ struct {

struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 1);
	__uint(max_entries, 2);
	__type(key, int);
	__type(value, int);
} sock_skb_opts SEC(".maps");

/* Map of sockets used as the redirect target by bpf_prog3 (sk_skb3)
 * on the KTLS receive path.  TEST_MAP_TYPE selects sockmap vs sockhash
 * depending on the build (see the SOCKMAP ifdef in bpf_prog3).
 */
struct {
	__uint(type, TEST_MAP_TYPE);
	__uint(max_entries, 20);
	__uint(key_size, sizeof(int));
	__uint(value_size, sizeof(int));
} tls_sock_map SEC(".maps");

SEC("sk_skb1")
int bpf_prog1(struct __sk_buff *skb)
{
@@ -118,6 +125,43 @@ int bpf_prog2(struct __sk_buff *skb)

}

/* sk_skb verdict program exercised on the KTLS receive path: rewrites
 * part of the (decrypted) payload to "PASS" and, when the flags stored
 * at index 1 of sock_skb_opts are nonzero, redirects the skb into
 * tls_sock_map at key 0.
 */
SEC("sk_skb3")
int bpf_prog3(struct __sk_buff *skb)
{
	const int one = 1;
	int err, *f, ret = SK_PASS;
	void *data_end;
	char *c;

	/* Linearize the first 19 bytes so they can be accessed directly
	 * below; bail out (passing the skb) if that fails.
	 */
	err = bpf_skb_pull_data(skb, 19);
	if (err)
		goto tls_out;

	c = (char *)(long)skb->data;
	data_end = (void *)(long)skb->data_end;

	/* Bounds check required by the BPF verifier before the write. */
	if (c + 18 < data_end)
		memcpy(&c[13], "PASS", 4);
	f = bpf_map_lookup_elem(&sock_skb_opts, &one);
	if (f && *f) {
		__u64 flags = 0;

		/* ret doubles as the sockmap key (0) for the redirect. */
		ret = 0;
		flags = *f;
#ifdef SOCKMAP
		return bpf_sk_redirect_map(skb, &tls_sock_map, ret, flags);
#else
		return bpf_sk_redirect_hash(skb, &tls_sock_map, &ret, flags);
#endif
	}

	/* NOTE(review): this lookup repeats the one above with the same
	 * key; if it yields a nonzero value the function has already
	 * returned, so SK_DROP here looks unreachable — confirm whether
	 * a different key/index was intended.
	 */
	f = bpf_map_lookup_elem(&sock_skb_opts, &one);
	if (f && *f)
		ret = SK_DROP;
tls_out:
	return ret;
}

SEC("sockops")
int bpf_sockmap(struct bpf_sock_ops *skops)
{
Loading