Commit 04a88637 authored by Jakub Kicinski's avatar Jakub Kicinski
Browse files

Merge branch 'tcp-add-cmsg-rx-timestamps-to-rx-zerocopy'

Arjun Roy says:

====================
tcp: add CMSG+rx timestamps to rx. zerocopy

Provide CMSG and receive timestamp support to TCP
receive zerocopy. Patch 1 refactors CMSG pending state for
tcp_recvmsg() to avoid the use of magic numbers; patch 2 implements
receive timestamp via CMSG support for receive zerocopy, and uses the
constants added in patch 1.
====================

Link: https://lore.kernel.org/r/20210121004148.2340206-1-arjunroy.kdev@gmail.com


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parents 5225d5f5 7eeba170
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -354,5 +354,9 @@ struct tcp_zerocopy_receive {
	__u64 copybuf_address;	/* in: copybuf address (small reads) */
	__s32 copybuf_len; /* in/out: copybuf bytes avail/used or error */
	__u32 flags; /* in: flags */
	__u64 msg_control; /* ancillary data */
	__u64 msg_controllen;
	__u32 msg_flags;
	/* __u32 hole;  Next we must add >1 u32 otherwise length checks fail. */
};
#endif /* _UAPI_LINUX_TCP_H */
+94 −36
Original line number Diff line number Diff line
@@ -280,6 +280,12 @@
#include <asm/ioctls.h>
#include <net/busy_poll.h>

/* Track pending CMSGs. */
enum {
	TCP_CMSG_INQ = 1,
	TCP_CMSG_TS = 2
};

struct percpu_counter tcp_orphan_count;
EXPORT_SYMBOL_GPL(tcp_orphan_count);

@@ -1739,6 +1745,20 @@ int tcp_set_rcvlowat(struct sock *sk, int val)
}
EXPORT_SYMBOL(tcp_set_rcvlowat);

static void tcp_update_recv_tstamps(struct sk_buff *skb,
				    struct scm_timestamping_internal *tss)
{
	if (skb->tstamp)
		tss->ts[0] = ktime_to_timespec64(skb->tstamp);
	else
		tss->ts[0] = (struct timespec64) {0};

	if (skb_hwtstamps(skb)->hwtstamp)
		tss->ts[2] = ktime_to_timespec64(skb_hwtstamps(skb)->hwtstamp);
	else
		tss->ts[2] = (struct timespec64) {0};
}

#ifdef CONFIG_MMU
static const struct vm_operations_struct tcp_vm_ops = {
};
@@ -1842,13 +1862,13 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
			      struct scm_timestamping_internal *tss,
			      int *cmsg_flags);
static int receive_fallback_to_copy(struct sock *sk,
				    struct tcp_zerocopy_receive *zc, int inq)
				    struct tcp_zerocopy_receive *zc, int inq,
				    struct scm_timestamping_internal *tss)
{
	unsigned long copy_address = (unsigned long)zc->copybuf_address;
	struct scm_timestamping_internal tss_unused;
	int err, cmsg_flags_unused;
	struct msghdr msg = {};
	struct iovec iov;
	int err;

	zc->length = 0;
	zc->recv_skip_hint = 0;
@@ -1862,7 +1882,7 @@ static int receive_fallback_to_copy(struct sock *sk,
		return err;

	err = tcp_recvmsg_locked(sk, &msg, inq, /*nonblock=*/1, /*flags=*/0,
				 &tss_unused, &cmsg_flags_unused);
				 tss, &zc->msg_flags);
	if (err < 0)
		return err;

@@ -1903,21 +1923,27 @@ static int tcp_copy_straggler_data(struct tcp_zerocopy_receive *zc,
	return (__s32)copylen;
}

static int tcp_zerocopy_handle_leftover_data(struct tcp_zerocopy_receive *zc,
static int tcp_zc_handle_leftover(struct tcp_zerocopy_receive *zc,
				  struct sock *sk,
				  struct sk_buff *skb,
				  u32 *seq,
					     s32 copybuf_len)
				  s32 copybuf_len,
				  struct scm_timestamping_internal *tss)
{
	u32 offset, copylen = min_t(u32, copybuf_len, zc->recv_skip_hint);

	if (!copylen)
		return 0;
	/* skb is null if inq < PAGE_SIZE. */
	if (skb)
	if (skb) {
		offset = *seq - TCP_SKB_CB(skb)->seq;
	else
	} else {
		skb = tcp_recv_skb(sk, *seq, &offset);
		if (TCP_SKB_CB(skb)->has_rxtstamp) {
			tcp_update_recv_tstamps(skb, tss);
			zc->msg_flags |= TCP_CMSG_TS;
		}
	}

	zc->copybuf_len = tcp_copy_straggler_data(zc, skb, copylen, &offset,
						  seq);
@@ -2004,9 +2030,37 @@ static int tcp_zerocopy_vm_insert_batch(struct vm_area_struct *vma,
		err);
}

static void tcp_recv_timestamp(struct msghdr *msg, const struct sock *sk,
			       struct scm_timestamping_internal *tss);
static void tcp_zc_finalize_rx_tstamp(struct sock *sk,
				      struct tcp_zerocopy_receive *zc,
				      struct scm_timestamping_internal *tss)
{
	unsigned long msg_control_addr;
	struct msghdr cmsg_dummy;

	msg_control_addr = (unsigned long)zc->msg_control;
	cmsg_dummy.msg_control = (void *)msg_control_addr;
	cmsg_dummy.msg_controllen =
		(__kernel_size_t)zc->msg_controllen;
	cmsg_dummy.msg_flags = in_compat_syscall()
		? MSG_CMSG_COMPAT : 0;
	zc->msg_flags = 0;
	if (zc->msg_control == msg_control_addr &&
	    zc->msg_controllen == cmsg_dummy.msg_controllen) {
		tcp_recv_timestamp(&cmsg_dummy, sk, tss);
		zc->msg_control = (__u64)
			((uintptr_t)cmsg_dummy.msg_control);
		zc->msg_controllen =
			(__u64)cmsg_dummy.msg_controllen;
		zc->msg_flags = (__u32)cmsg_dummy.msg_flags;
	}
}

#define TCP_ZEROCOPY_PAGE_BATCH_SIZE 32
static int tcp_zerocopy_receive(struct sock *sk,
				struct tcp_zerocopy_receive *zc)
				struct tcp_zerocopy_receive *zc,
				struct scm_timestamping_internal *tss)
{
	u32 length = 0, offset, vma_len, avail_len, copylen = 0;
	unsigned long address = (unsigned long)zc->address;
@@ -2023,6 +2077,7 @@ static int tcp_zerocopy_receive(struct sock *sk,
	int ret;

	zc->copybuf_len = 0;
	zc->msg_flags = 0;

	if (address & (PAGE_SIZE - 1) || address != zc->address)
		return -EINVAL;
@@ -2033,7 +2088,7 @@ static int tcp_zerocopy_receive(struct sock *sk,
	sock_rps_record_flow(sk);

	if (inq && inq <= copybuf_len)
		return receive_fallback_to_copy(sk, zc, inq);
		return receive_fallback_to_copy(sk, zc, inq, tss);

	if (inq < PAGE_SIZE) {
		zc->length = 0;
@@ -2078,6 +2133,11 @@ static int tcp_zerocopy_receive(struct sock *sk,
			} else {
				skb = tcp_recv_skb(sk, seq, &offset);
			}

			if (TCP_SKB_CB(skb)->has_rxtstamp) {
				tcp_update_recv_tstamps(skb, tss);
				zc->msg_flags |= TCP_CMSG_TS;
			}
			zc->recv_skip_hint = skb->len - offset;
			frags = skb_advance_to_frag(skb, offset, &offset_frag);
			if (!frags || offset_frag)
@@ -2120,8 +2180,7 @@ static int tcp_zerocopy_receive(struct sock *sk,
	mmap_read_unlock(current->mm);
	/* Try to copy straggler data. */
	if (!ret)
		copylen = tcp_zerocopy_handle_leftover_data(zc, sk, skb, &seq,
							    copybuf_len);
		copylen = tcp_zc_handle_leftover(zc, sk, skb, &seq, copybuf_len, tss);

	if (length + copylen) {
		WRITE_ONCE(tp->copied_seq, seq);
@@ -2142,20 +2201,6 @@ static int tcp_zerocopy_receive(struct sock *sk,
}
#endif

static void tcp_update_recv_tstamps(struct sk_buff *skb,
				    struct scm_timestamping_internal *tss)
{
	if (skb->tstamp)
		tss->ts[0] = ktime_to_timespec64(skb->tstamp);
	else
		tss->ts[0] = (struct timespec64) {0};

	if (skb_hwtstamps(skb)->hwtstamp)
		tss->ts[2] = ktime_to_timespec64(skb_hwtstamps(skb)->hwtstamp);
	else
		tss->ts[2] = (struct timespec64) {0};
}

/* Similar to __sock_recv_timestamp, but does not require an skb */
static void tcp_recv_timestamp(struct msghdr *msg, const struct sock *sk,
			       struct scm_timestamping_internal *tss)
@@ -2272,7 +2317,7 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
		goto out;

	if (tp->recvmsg_inq)
		*cmsg_flags = 1;
		*cmsg_flags = TCP_CMSG_INQ;
	timeo = sock_rcvtimeo(sk, nonblock);

	/* Urgent data needs to be handled specially. */
@@ -2453,7 +2498,7 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,

		if (TCP_SKB_CB(skb)->has_rxtstamp) {
			tcp_update_recv_tstamps(skb, tss);
			*cmsg_flags |= 2;
			*cmsg_flags |= TCP_CMSG_TS;
		}

		if (used + offset < skb->len)
@@ -2513,9 +2558,9 @@ int tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int nonblock,
	release_sock(sk);

	if (cmsg_flags && ret >= 0) {
		if (cmsg_flags & 2)
		if (cmsg_flags & TCP_CMSG_TS)
			tcp_recv_timestamp(msg, sk, &tss);
		if (cmsg_flags & 1) {
		if (cmsg_flags & TCP_CMSG_INQ) {
			inq = tcp_inq_hint(sk);
			put_cmsg(msg, SOL_TCP, TCP_CM_INQ, sizeof(inq), &inq);
		}
@@ -4099,6 +4144,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
	}
#ifdef CONFIG_MMU
	case TCP_ZEROCOPY_RECEIVE: {
		struct scm_timestamping_internal tss;
		struct tcp_zerocopy_receive zc = {};
		int err;

@@ -4114,11 +4160,18 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
		if (copy_from_user(&zc, optval, len))
			return -EFAULT;
		lock_sock(sk);
		err = tcp_zerocopy_receive(sk, &zc);
		err = tcp_zerocopy_receive(sk, &zc, &tss);
		release_sock(sk);
		if (len >= offsetofend(struct tcp_zerocopy_receive, err))
			goto zerocopy_rcv_sk_err;
		if (len >= offsetofend(struct tcp_zerocopy_receive, msg_flags))
			goto zerocopy_rcv_cmsg;
		switch (len) {
		case offsetofend(struct tcp_zerocopy_receive, msg_flags):
			goto zerocopy_rcv_cmsg;
		case offsetofend(struct tcp_zerocopy_receive, msg_controllen):
		case offsetofend(struct tcp_zerocopy_receive, msg_control):
		case offsetofend(struct tcp_zerocopy_receive, flags):
		case offsetofend(struct tcp_zerocopy_receive, copybuf_len):
		case offsetofend(struct tcp_zerocopy_receive, copybuf_address):
		case offsetofend(struct tcp_zerocopy_receive, err):
			goto zerocopy_rcv_sk_err;
		case offsetofend(struct tcp_zerocopy_receive, inq):
@@ -4127,6 +4180,11 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
		default:
			goto zerocopy_rcv_out;
		}
zerocopy_rcv_cmsg:
		if (zc.msg_flags & TCP_CMSG_TS)
			tcp_zc_finalize_rx_tstamp(sk, &zc, &tss);
		else
			zc.msg_flags = 0;
zerocopy_rcv_sk_err:
		if (!err)
			zc.err = sock_error(sk);