Commit 2bc793e3 authored by Cong Wang's avatar Cong Wang Committed by Alexei Starovoitov
Browse files

skmsg: Extract __tcp_bpf_recvmsg() and tcp_bpf_wait_data()



Although these two functions are only used by TCP, they are not
specific to TCP at all, both operate on skmsg and ingress_msg,
so fit in net/core/skmsg.c very well.

And we will need them for non-TCP, so rename and move them to
skmsg.c and export them to modules.

Signed-off-by: default avatarCong Wang <cong.wang@bytedance.com>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/20210331023237.41094-13-xiyou.wangcong@gmail.com
parent d7f57118
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -125,6 +125,10 @@ int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes);
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes);
int sk_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
		     long timeo, int *err);
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
		   int len, int flags);

static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
{
+0 −2
Original line number Diff line number Diff line
@@ -2209,8 +2209,6 @@ void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);

int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
			  int flags);
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len, int flags);
#endif /* CONFIG_NET_SOCK_MSG */

#if !defined(CONFIG_BPF_SYSCALL) || !defined(CONFIG_NET_SOCK_MSG)
+98 −0
Original line number Diff line number Diff line
@@ -399,6 +399,104 @@ int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);

int sk_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
		     long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	if (sk->sk_shutdown & RCV_SHUTDOWN)
		return 1;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_wait_data);

/* Receive sk_msg from psock->ingress_msg to @msg. */
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
		   int len, int flags)
{
	struct iov_iter *iter = &msg->msg_iter;
	int peek = flags & MSG_PEEK;
	struct sk_msg *msg_rx;
	int i, copied = 0;

	msg_rx = sk_psock_peek_msg(psock);
	while (copied != len) {
		struct scatterlist *sge;

		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			copy = copy_page_to_iter(page, sge->offset, copy, iter);
			if (!copy)
				return copied ? copied : -EFAULT;

			copied += copy;
			if (likely(!peek)) {
				sge->offset += copy;
				sge->length -= copy;
				if (!msg_rx->skb)
					sk_mem_uncharge(sk, copy);
				msg_rx->sg.size -= copy;

				if (!sge->length) {
					sk_msg_iter_var_next(i);
					if (!msg_rx->skb)
						put_page(page);
				}
			} else {
				/* Lets not optimize peek case if copy_page_to_iter
				 * didn't copy the entire length lets just break.
				 */
				if (copy != sge->length)
					return copied;
				sk_msg_iter_var_next(i);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			msg_rx = sk_psock_next_msg(psock, msg_rx);
			if (!msg_rx)
				break;
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			msg_rx = sk_psock_dequeue_msg(psock);
			kfree_sk_msg(msg_rx);
		}
		msg_rx = sk_psock_peek_msg(psock);
	}

	return copied;
}
EXPORT_SYMBOL_GPL(sk_msg_recvmsg);

static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk,
						  struct sk_buff *skb)
{
+2 −98
Original line number Diff line number Diff line
@@ -10,80 +10,6 @@
#include <net/inet_common.h>
#include <net/tls.h>

int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len, int flags)
{
	struct iov_iter *iter = &msg->msg_iter;
	int peek = flags & MSG_PEEK;
	struct sk_msg *msg_rx;
	int i, copied = 0;

	msg_rx = sk_psock_peek_msg(psock);
	while (copied != len) {
		struct scatterlist *sge;

		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			copy = copy_page_to_iter(page, sge->offset, copy, iter);
			if (!copy)
				return copied ? copied : -EFAULT;

			copied += copy;
			if (likely(!peek)) {
				sge->offset += copy;
				sge->length -= copy;
				if (!msg_rx->skb)
					sk_mem_uncharge(sk, copy);
				msg_rx->sg.size -= copy;

				if (!sge->length) {
					sk_msg_iter_var_next(i);
					if (!msg_rx->skb)
						put_page(page);
				}
			} else {
				/* Lets not optimize peek case if copy_page_to_iter
				 * didn't copy the entire length lets just break.
				 */
				if (copy != sge->length)
					return copied;
				sk_msg_iter_var_next(i);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			msg_rx = sk_psock_next_msg(psock, msg_rx);
			if (!msg_rx)
				break;
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			msg_rx = sk_psock_dequeue_msg(psock);
			kfree_sk_msg(msg_rx);
		}
		msg_rx = sk_psock_peek_msg(psock);
	}

	return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);

static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
@@ -237,28 +163,6 @@ static bool tcp_bpf_stream_read(const struct sock *sk)
	return !empty;
}

static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
			     int flags, long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	if (sk->sk_shutdown & RCV_SHUTDOWN)
		return 1;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		    int nonblock, int flags, int *addr_len)
{
@@ -278,13 +182,13 @@ static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
	}
	lock_sock(sk);
msg_bytes_ready:
	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		int data, err = 0;
		long timeo;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
		data = sk_msg_wait_data(sk, psock, flags, timeo, &err);
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
+2 −2
Original line number Diff line number Diff line
@@ -1789,8 +1789,8 @@ int tls_sw_recvmsg(struct sock *sk,
		skb = tls_wait_data(sk, psock, flags, timeo, &err);
		if (!skb) {
			if (psock) {
				int ret = __tcp_bpf_recvmsg(sk, psock,
							    msg, len, flags);
				int ret = sk_msg_recvmsg(sk, psock, msg, len,
							 flags);

				if (ret > 0) {
					decrypted += ret;