Commit fbcb4859 authored by Wang Yufen's avatar Wang Yufen Committed by Zheng Zengkai
Browse files

tcp_comp: implement recvmsg for tcp compression

hulk inclusion
category: feature
bugzilla: https://gitee.com/openeuler/kernel/issues/I4PNEK


CVE: NA

-------------------------------------------------

This patch implement software level compression for
receiving tcp messages. The compressed TCP payload will be
decompressed after receive.

Signed-off-by: default avatarWang Yufen <wangyufen@huawei.com>
Signed-off-by: default avatarYang Yingliang <yangyingliang@huawei.com>
Signed-off-by: default avatarLu Wei <luwei32@huawei.com>
Reviewed-by: default avatarWei Yongjun <weiyongjun1@huawei.com>
Signed-off-by: default avatarZheng Zengkai <zhengzengkai@huawei.com>
parent e9ce37bb
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -745,7 +745,8 @@ config TCP_MD5SIG

config TCP_COMP
	bool "TCP: Transport Layer Compression support"
	depends on ZSTD_COMPRESS=y
	depends on CRYPTO_ZSTD=y
	select STREAM_PARSER
	help
          Enable kernel payload compression support for TCP protocol. This allows
          payload compression handling of the TCP protocol to be done in-kernel.
+375 −2
Original line number Diff line number Diff line
@@ -9,8 +9,11 @@
#include <linux/zstd.h>

#define TCP_COMP_MAX_PADDING	64
#define TCP_COMP_SCRATCH_SIZE	65400
#define TCP_COMP_SCRATCH_SIZE	65535
#define TCP_COMP_MAX_CSIZE	(TCP_COMP_SCRATCH_SIZE + TCP_COMP_MAX_PADDING)
#define TCP_COMP_ALLOC_ORDER   get_order(65536)
#define TCP_COMP_MAX_WINDOWLOG 17
#define TCP_COMP_MAX_INPUT (1 << TCP_COMP_MAX_WINDOWLOG)

#define TCP_COMP_SEND_PENDING	1
#define ZSTD_COMP_DEFAULT_LEVEL	1
@@ -31,6 +34,20 @@ struct tcp_comp_context_tx {
	bool in_tcp_sendpages;
};

struct tcp_comp_context_rx {
	ZSTD_DStream *dstream;
	void *dworkspace;
	void *plaintext_data;
	void *compressed_data;
	void *remaining_data;

	size_t data_offset;
	struct strparser strp;
	void (*saved_data_ready)(struct sock *sk);
	struct sk_buff *pkt;
	bool decompressed;
};

struct tcp_comp_context {
	struct rcu_head rcu;

@@ -38,6 +55,7 @@ struct tcp_comp_context {
	void (*sk_write_space)(struct sock *sk);

	struct tcp_comp_context_tx tx;
	struct tcp_comp_context_rx rx;

	unsigned long flags;
};
@@ -426,12 +444,344 @@ static int tcp_comp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
	return copied ? copied : err;
}

static struct sk_buff *comp_wait_data(struct sock *sk, int flags,
				      long timeo, int *err)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);
	struct sk_buff *skb;
	DEFINE_WAIT_FUNC(wait, woken_wake_function);

	while (!(skb = ctx->rx.pkt)) {
		if (sk->sk_err) {
			*err = sock_error(sk);
			return NULL;
		}

		if (!skb_queue_empty(&sk->sk_receive_queue)) {
			__strp_unpause(&ctx->rx.strp);
			if (ctx->rx.pkt)
				return ctx->rx.pkt;
		}

		if (sk->sk_shutdown & RCV_SHUTDOWN)
			return NULL;

		if (sock_flag(sk, SOCK_DONE))
			return NULL;

		if ((flags & MSG_DONTWAIT) || !timeo) {
			*err = -EAGAIN;
			return NULL;
		}

		add_wait_queue(sk_sleep(sk), &wait);
		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
		sk_wait_event(sk, &timeo, ctx->rx.pkt != skb, &wait);
		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
		remove_wait_queue(sk_sleep(sk), &wait);

		/* Handle signals */
		if (signal_pending(current)) {
			*err = sock_intr_errno(timeo);
			return NULL;
		}
	}

	return skb;
}

static bool comp_advance_skb(struct sock *sk, struct sk_buff *skb,
			     unsigned int len)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);
	struct strp_msg *rxm = strp_msg(skb);

	if (len < rxm->full_len) {
		rxm->offset += len;
		rxm->full_len -= len;
		return false;
	}

	/* Finished with message */
	ctx->rx.pkt = NULL;
	kfree_skb(skb);
	__strp_unpause(&ctx->rx.strp);

	return true;
}

static int tcp_comp_rx_context_init(struct tcp_comp_context *ctx)
{
	int dsize;

	dsize = ZSTD_DStreamWorkspaceBound(TCP_COMP_MAX_INPUT);
	if (dsize <= 0)
		return -EINVAL;

	ctx->rx.dworkspace = kmalloc(dsize, GFP_KERNEL);
	if (!ctx->rx.dworkspace)
		return -ENOMEM;

	ctx->rx.dstream = ZSTD_initDStream(TCP_COMP_MAX_INPUT,
					   ctx->rx.dworkspace, dsize);
	if (!ctx->rx.dstream)
		goto err_dstream;

	ctx->rx.plaintext_data = kvmalloc(TCP_COMP_MAX_CSIZE * 32, GFP_KERNEL);
	if (!ctx->rx.plaintext_data)
		goto err_dstream;

	ctx->rx.compressed_data = kvmalloc(TCP_COMP_MAX_CSIZE, GFP_KERNEL);
	if (!ctx->rx.compressed_data)
		goto err_compressed;

	ctx->rx.remaining_data = kvmalloc(TCP_COMP_MAX_CSIZE, GFP_KERNEL);
	if (!ctx->rx.remaining_data)
		goto err_remaining;

	ctx->rx.data_offset = 0;

	return 0;

err_remaining:
	kvfree(ctx->rx.compressed_data);
	ctx->rx.compressed_data = NULL;
err_compressed:
	kvfree(ctx->rx.plaintext_data);
	ctx->rx.plaintext_data = NULL;
err_dstream:
	kfree(ctx->rx.dworkspace);
	ctx->rx.dworkspace = NULL;

	return -ENOMEM;
}

static void *tcp_comp_get_rx_stream(struct sock *sk)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);

	if (!ctx->rx.plaintext_data)
		tcp_comp_rx_context_init(ctx);

	return ctx->rx.plaintext_data;
}

static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);
	const int plen = skb->len;
	struct strp_msg *rxm;
	ZSTD_outBuffer outbuf;
	ZSTD_inBuffer inbuf;
	int len;
	void *to;

	to = tcp_comp_get_rx_stream(sk);
	if (!to)
		return -ENOSPC;

	if (skb_linearize_cow(skb))
		return -ENOMEM;

	if (plen + ctx->rx.data_offset > TCP_COMP_MAX_CSIZE)
		return -ENOMEM;

	if (ctx->rx.data_offset)
		memcpy(ctx->rx.compressed_data, ctx->rx.remaining_data,
		       ctx->rx.data_offset);

	memcpy((char *)ctx->rx.compressed_data + ctx->rx.data_offset,
	       skb->data, plen);

	inbuf.src = ctx->rx.compressed_data;
	inbuf.pos = 0;
	inbuf.size = plen + ctx->rx.data_offset;
	ctx->rx.data_offset = 0;

	outbuf.dst = ctx->rx.plaintext_data;
	outbuf.pos = 0;
	outbuf.size = TCP_COMP_MAX_CSIZE * 32;

	while (1) {
		size_t ret;

		to = outbuf.dst;

		ret = ZSTD_decompressStream(ctx->rx.dstream, &outbuf, &inbuf);
		if (ZSTD_isError(ret))
			return -EIO;

		len = outbuf.pos - plen;
		if (len > skb_tailroom(skb))
			len = skb_tailroom(skb);

		__skb_put(skb, len);
		rxm = strp_msg(skb);
		rxm->full_len += len;

		len += plen;
		skb_copy_to_linear_data(skb, to, len);

		while ((to += len, outbuf.pos -= len) > 0) {
			struct page *pages;
			skb_frag_t *frag;

			if (WARN_ON(skb_shinfo(skb)->nr_frags >= MAX_SKB_FRAGS))
				return -EMSGSIZE;

			frag = skb_shinfo(skb)->frags +
			       skb_shinfo(skb)->nr_frags;
			pages = alloc_pages(__GFP_NOWARN | GFP_KERNEL | __GFP_COMP,
					    TCP_COMP_ALLOC_ORDER);

			if (!pages)
				return -ENOMEM;

			__skb_frag_set_page(frag, pages);
			len = PAGE_SIZE << TCP_COMP_ALLOC_ORDER;
			if (outbuf.pos < len)
				len = outbuf.pos;

			frag->bv_offset = 0;
			skb_frag_size_set(frag, len);
			memcpy(skb_frag_address(frag), to, len);

			skb->truesize += len;
			skb->data_len += len;
			skb->len += len;
			rxm->full_len += len;
			skb_shinfo(skb)->nr_frags++;
		}

		if (ret == 0)
			break;

		if (inbuf.pos >= plen || !inbuf.pos) {
			if (inbuf.pos < inbuf.size) {
				memcpy((char *)ctx->rx.remaining_data,
				       (char *)inbuf.src + inbuf.pos,
				       inbuf.size - inbuf.pos);
				ctx->rx.data_offset = inbuf.size - inbuf.pos;
			}
			break;
		}
	}
	return 0;
}

static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			    int nonblock, int flags, int *addr_len)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);
	struct strp_msg *rxm;
	struct sk_buff *skb;
	ssize_t copied = 0;
	int target, err = 0;
	long timeo;

	flags |= nonblock;

	if (unlikely(flags & MSG_ERRQUEUE))
		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);

	lock_sock(sk);

	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
	timeo = sock_rcvtimeo(sk, flags & MSG_WAITALL);

	do {
		int chunk = 0;

		skb = comp_wait_data(sk, flags, timeo, &err);
		if (!skb)
			goto recv_end;

	return ctx->sk_proto->recvmsg(sk, msg, len, nonblock, flags, addr_len);
		if (!ctx->rx.decompressed) {
			err = tcp_comp_decompress(sk, skb);
			if (err < 0) {
				if (err != -ENOSPC)
					tcp_comp_err_abort(sk, EBADMSG);
				goto recv_end;
			}
			ctx->rx.decompressed = true;
		}
		rxm = strp_msg(skb);

		chunk = min_t(unsigned int, rxm->full_len, len);

		err = skb_copy_datagram_msg(skb, rxm->offset, msg,
					    chunk);
		if (err < 0)
			goto recv_end;

		copied += chunk;
		len -= chunk;
		if (likely(!(flags & MSG_PEEK)))
			comp_advance_skb(sk, skb, chunk);
		else
			break;

		if (copied >= target && !ctx->rx.pkt)
			break;
	} while (len > 0);

recv_end:
	release_sock(sk);
	return copied ? : err;
}

bool comp_stream_read(const struct sock *sk)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);

	if (ctx->rx.pkt)
		return true;

	return false;
}

static void comp_data_ready(struct sock *sk)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);

	strp_data_ready(&ctx->rx.strp);
}

static void comp_queue(struct strparser *strp, struct sk_buff *skb)
{
	struct tcp_comp_context *ctx = comp_get_ctx(strp->sk);

	ctx->rx.decompressed = false;
	ctx->rx.pkt = skb;
	strp_pause(strp);
	ctx->rx.saved_data_ready(strp->sk);
}

static int comp_read_size(struct strparser *strp, struct sk_buff *skb)
{
	struct strp_msg *rxm = strp_msg(skb);

	if (rxm->offset > skb->len)
		return 0;

	return skb->len;
}

void comp_setup_strp(struct sock *sk, struct tcp_comp_context *ctx)
{
	struct strp_callbacks cb;

	memset(&cb, 0, sizeof(cb));
	cb.rcv_msg = comp_queue;
	cb.parse_msg = comp_read_size;
	strp_init(&ctx->rx.strp, sk, &cb);

	write_lock_bh(&sk->sk_callback_lock);
	ctx->rx.saved_data_ready = sk->sk_data_ready;
	sk->sk_data_ready = comp_data_ready;
	write_unlock_bh(&sk->sk_callback_lock);

	strp_check_rcv(&ctx->rx.strp);
}

static void tcp_comp_write_space(struct sock *sk)
@@ -483,6 +833,7 @@ void tcp_init_compression(struct sock *sk)
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);

	sock_set_flag(sk, SOCK_COMP);
	comp_setup_strp(sk, ctx);
}

static void tcp_comp_context_tx_free(struct tcp_comp_context *ctx)
@@ -497,6 +848,21 @@ static void tcp_comp_context_tx_free(struct tcp_comp_context *ctx)
	ctx->tx.compressed_data = NULL;
}

static void tcp_comp_context_rx_free(struct tcp_comp_context *ctx)
{
	kfree(ctx->rx.dworkspace);
	ctx->rx.dworkspace = NULL;

	kvfree(ctx->rx.plaintext_data);
	ctx->rx.plaintext_data = NULL;

	kvfree(ctx->rx.compressed_data);
	ctx->rx.compressed_data = NULL;

	kvfree(ctx->rx.remaining_data);
	ctx->rx.remaining_data = NULL;
}

static void tcp_comp_context_free(struct rcu_head *head)
{
	struct tcp_comp_context *ctx;
@@ -504,6 +870,7 @@ static void tcp_comp_context_free(struct rcu_head *head)
	ctx = container_of(head, struct tcp_comp_context, rcu);

	tcp_comp_context_tx_free(ctx);
	tcp_comp_context_rx_free(ctx);
	kfree(ctx);
}

@@ -515,6 +882,11 @@ void tcp_cleanup_compression(struct sock *sk)
	if (!ctx || !sock_flag(sk, SOCK_COMP))
		return;

	if (ctx->rx.pkt) {
		kfree_skb(ctx->rx.pkt);
		ctx->rx.pkt = NULL;
	}

	rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
	call_rcu(&ctx->rcu, tcp_comp_context_free);
}
@@ -524,6 +896,7 @@ int tcp_comp_init(void)
	tcp_prot_override = tcp_prot;
	tcp_prot_override.sendmsg = tcp_comp_sendmsg;
	tcp_prot_override.recvmsg = tcp_comp_recvmsg;
	tcp_prot_override.stream_memory_read = comp_stream_read;

	return 0;
}