Commit b01fd6e8 authored by Cong Wang's avatar Cong Wang Committed by Alexei Starovoitov
Browse files

skmsg: Introduce a spinlock to protect ingress_msg



Currently we rely on lock_sock to protect ingress_msg,
it is too big for this, we can actually just use a spinlock
to protect this list like protecting other skb queues.

__tcp_bpf_recvmsg() is still special because of peeking,
it still has to use lock_sock.

Signed-off-by: default avatarCong Wang <cong.wang@bytedance.com>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Acked-by: default avatarJakub Sitnicki <jakub@cloudflare.com>
Acked-by: default avatarJohn Fastabend <john.fastabend@gmail.com>
Link: https://lore.kernel.org/bpf/20210331023237.41094-3-xiyou.wangcong@gmail.com
parent 37f0e514
Loading
Loading
Loading
Loading
+46 −0
Original line number Diff line number Diff line
@@ -89,6 +89,7 @@ struct sk_psock {
#endif
	struct sk_buff_head		ingress_skb;
	struct list_head		ingress_msg;
	spinlock_t			ingress_lock;
	unsigned long			state;
	struct list_head		link;
	spinlock_t			link_lock;
@@ -284,7 +285,45 @@ static inline struct sk_psock *sk_psock(const struct sock *sk)
static inline void sk_psock_queue_msg(struct sk_psock *psock,
				      struct sk_msg *msg)
{
	spin_lock_bh(&psock->ingress_lock);
	list_add_tail(&msg->list, &psock->ingress_msg);
	spin_unlock_bh(&psock->ingress_lock);
}

static inline struct sk_msg *sk_psock_dequeue_msg(struct sk_psock *psock)
{
	struct sk_msg *msg;

	spin_lock_bh(&psock->ingress_lock);
	msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
	if (msg)
		list_del(&msg->list);
	spin_unlock_bh(&psock->ingress_lock);
	return msg;
}

static inline struct sk_msg *sk_psock_peek_msg(struct sk_psock *psock)
{
	struct sk_msg *msg;

	spin_lock_bh(&psock->ingress_lock);
	msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
	spin_unlock_bh(&psock->ingress_lock);
	return msg;
}

static inline struct sk_msg *sk_psock_next_msg(struct sk_psock *psock,
					       struct sk_msg *msg)
{
	struct sk_msg *ret;

	spin_lock_bh(&psock->ingress_lock);
	if (list_is_last(&msg->list, &psock->ingress_msg))
		ret = NULL;
	else
		ret = list_next_entry(msg, list);
	spin_unlock_bh(&psock->ingress_lock);
	return ret;
}

static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
@@ -292,6 +331,13 @@ static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
	return psock ? list_empty(&psock->ingress_msg) : true;
}

static inline void kfree_sk_msg(struct sk_msg *msg)
{
	if (msg->skb)
		consume_skb(msg->skb);
	kfree(msg);
}

static inline void sk_psock_report_error(struct sk_psock *psock, int err)
{
	struct sock *sk = psock->sk;
+3 −0
Original line number Diff line number Diff line
@@ -592,6 +592,7 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)

	INIT_WORK(&psock->work, sk_psock_backlog);
	INIT_LIST_HEAD(&psock->ingress_msg);
	spin_lock_init(&psock->ingress_lock);
	skb_queue_head_init(&psock->ingress_skb);

	sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
@@ -638,7 +639,9 @@ static void sk_psock_zap_ingress(struct sk_psock *psock)
		skb_bpf_redirect_clear(skb);
		kfree_skb(skb);
	}
	spin_lock_bh(&psock->ingress_lock);
	__sk_psock_purge_ingress_msg(psock);
	spin_unlock_bh(&psock->ingress_lock);
}

static void sk_psock_link_destroy(struct sk_psock *psock)
+6 −12
Original line number Diff line number Diff line
@@ -18,9 +18,7 @@ int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
	struct sk_msg *msg_rx;
	int i, copied = 0;

	msg_rx = list_first_entry_or_null(&psock->ingress_msg,
					  struct sk_msg, list);

	msg_rx = sk_psock_peek_msg(psock);
	while (copied != len) {
		struct scatterlist *sge;

@@ -68,22 +66,18 @@ int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			if (msg_rx == list_last_entry(&psock->ingress_msg,
						      struct sk_msg, list))
			msg_rx = sk_psock_next_msg(psock, msg_rx);
			if (!msg_rx)
				break;
			msg_rx = list_next_entry(msg_rx, list);
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			list_del(&msg_rx->list);
			if (msg_rx->skb)
				consume_skb(msg_rx->skb);
			kfree(msg_rx);
			msg_rx = sk_psock_dequeue_msg(psock);
			kfree_sk_msg(msg_rx);
		}
		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
						  struct sk_msg, list);
		msg_rx = sk_psock_peek_msg(psock);
	}

	return copied;