Commit c5f3ee69 authored by Wang Yufen's avatar Wang Yufen Committed by Zheng Zengkai
Browse files

tcp_comp: Add dpkt to save decompressed skb

hulk inclusion
category: bugfix
bugzilla: https://gitee.com/openeuler/kernel/issues/I48H9Z?from=project-issue


CVE: NA

-------------------------------------------------

In order to separate the compressed data and decompressed data, this patch adds
dpkt to tcp_comp_context_rx, dpkt is used to save decompressed skb.

Signed-off-by: default avatarWang Yufen <wangyufen@huawei.com>
Signed-off-by: default avatarYang Yingliang <yangyingliang@huawei.com>
Signed-off-by: default avatarLu Wei <luwei32@huawei.com>
Reviewed-by: default avatarWei Yongjun <weiyongjun1@huawei.com>
Signed-off-by: default avatarZheng Zengkai <zhengzengkai@huawei.com>
parent d876fdbb
Loading
Loading
Loading
Loading
+64 −30
Original line number Diff line number Diff line
@@ -45,7 +45,7 @@ struct tcp_comp_context_rx {
	struct strparser strp;
	void (*saved_data_ready)(struct sock *sk);
	struct sk_buff *pkt;
	bool decompressed;
	struct sk_buff *dpkt;
};

struct tcp_comp_context {
@@ -510,6 +510,24 @@ static bool comp_advance_skb(struct sock *sk, struct sk_buff *skb,
	return true;
}

static bool comp_advance_dskb(struct sock *sk, struct sk_buff *skb,
			      unsigned int len)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);
	struct strp_msg *rxm = strp_msg(skb);

	if (len < rxm->full_len) {
		rxm->offset += len;
		rxm->full_len -= len;
		return false;
	}

	/* Finished with message */
	ctx->rx.dpkt = NULL;
	kfree_skb(skb);
	return true;
}

static int tcp_comp_rx_context_init(struct tcp_comp_context *ctx)
{
	int dsize;
@@ -566,13 +584,14 @@ static void *tcp_comp_get_rx_stream(struct sock *sk)
	return ctx->rx.plaintext_data;
}

static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb)
static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb, int flags)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);
	struct strp_msg *rxm = strp_msg(skb);
	const int plen = skb->len;
	ZSTD_outBuffer outbuf;
	ZSTD_inBuffer inbuf;
	struct sk_buff *nskb;
	int len;
	void *to;

@@ -586,6 +605,10 @@ static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb)
	if (plen + ctx->rx.data_offset > TCP_COMP_MAX_CSIZE)
		return -ENOMEM;

	nskb = skb_copy(skb, GFP_KERNEL);
	if (!nskb)
		return -ENOMEM;

	if (ctx->rx.data_offset)
		memcpy(ctx->rx.compressed_data, ctx->rx.remaining_data,
		       ctx->rx.data_offset);
@@ -607,34 +630,38 @@ static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb)

		to = outbuf.dst;
		ret = ZSTD_decompressStream(ctx->rx.dstream, &outbuf, &inbuf);
		if (ZSTD_isError(ret))
		if (ZSTD_isError(ret)) {
			kfree_skb(nskb);
			return -EIO;
		}

		len = outbuf.pos - plen;
		if (len > skb_tailroom(skb))
			len = skb_tailroom(skb);
		if (len > skb_tailroom(nskb))
			len = skb_tailroom(nskb);

		__skb_put(skb, len);
		rxm->full_len += (len + rxm->offset);
		rxm->offset = 0;
		__skb_put(nskb, len);

		len += plen;
		skb_copy_to_linear_data(skb, to, len);
		skb_copy_to_linear_data(nskb, to, len);

		while ((to += len, outbuf.pos -= len) > 0) {
			struct page *pages;
			skb_frag_t *frag;

			if (WARN_ON(skb_shinfo(skb)->nr_frags >= MAX_SKB_FRAGS))
			if (WARN_ON(skb_shinfo(nskb)->nr_frags >= MAX_SKB_FRAGS)) {
				kfree_skb(nskb);
				return -EMSGSIZE;
			}

			frag = skb_shinfo(skb)->frags +
			       skb_shinfo(skb)->nr_frags;
			frag = skb_shinfo(nskb)->frags +
			       skb_shinfo(nskb)->nr_frags;
			pages = alloc_pages(__GFP_NOWARN | GFP_KERNEL | __GFP_COMP,
					    TCP_COMP_ALLOC_ORDER);

			if (!pages)
			if (!pages) {
				kfree_skb(nskb);
				return -ENOMEM;
			}

			__skb_frag_set_page(frag, pages);
			len = PAGE_SIZE << TCP_COMP_ALLOC_ORDER;
@@ -645,11 +672,10 @@ static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb)
			skb_frag_size_set(frag, len);
			memcpy(skb_frag_address(frag), to, len);

			skb->truesize += len;
			skb->data_len += len;
			skb->len += len;
			rxm->full_len += len;
			skb_shinfo(skb)->nr_frags++;
			nskb->truesize += len;
			nskb->data_len += len;
			nskb->len += len;
			skb_shinfo(nskb)->nr_frags++;
		}

		if (ret == 0)
@@ -665,6 +691,13 @@ static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb)
			break;
		}
	}

	ctx->rx.dpkt = nskb;
	rxm = strp_msg(nskb);
	rxm->full_len = nskb->len;
	rxm->offset = 0;
	comp_advance_skb(sk, skb, plen - rxm->offset);

	return 0;
}

@@ -691,21 +724,19 @@ static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
	do {
		int chunk = 0;

		if (!ctx->rx.dpkt) {
			skb = comp_wait_data(sk, flags, timeo, &err);
			if (!skb)
				goto recv_end;

		if (!ctx->rx.decompressed) {
			err = tcp_comp_decompress(sk, skb);
			err = tcp_comp_decompress(sk, skb, flags);
			if (err < 0) {
				goto recv_end;
			}
			ctx->rx.decompressed = true;
		}
		skb = ctx->rx.dpkt;
		rxm = strp_msg(skb);

		chunk = min_t(unsigned int, rxm->full_len, len);

		err = skb_copy_datagram_msg(skb, rxm->offset, msg,
					    chunk);
		if (err < 0)
@@ -714,11 +745,11 @@ static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		copied += chunk;
		len -= chunk;
		if (likely(!(flags & MSG_PEEK)))
			comp_advance_skb(sk, skb, chunk);
			comp_advance_dskb(sk, skb, chunk);
		else
			break;

		if (copied >= target && !ctx->rx.pkt)
		if (copied >= target && !ctx->rx.dpkt)
			break;
	} while (len > 0);

@@ -734,7 +765,7 @@ bool comp_stream_read(const struct sock *sk)
	if (!ctx)
		return false;

	if (ctx->rx.pkt)
	if (ctx->rx.pkt || ctx->rx.dpkt)
		return true;

	return false;
@@ -751,7 +782,6 @@ static void comp_queue(struct strparser *strp, struct sk_buff *skb)
{
	struct tcp_comp_context *ctx = comp_get_ctx(strp->sk);

	ctx->rx.decompressed = false;
	ctx->rx.pkt = skb;
	strp_pause(strp);
	ctx->rx.saved_data_ready(strp->sk);
@@ -887,6 +917,10 @@ void tcp_cleanup_compression(struct sock *sk)
		kfree_skb(ctx->rx.pkt);
		ctx->rx.pkt = NULL;
	}
	if (ctx->rx.dpkt) {
		kfree_skb(ctx->rx.dpkt);
		ctx->rx.dpkt = NULL;
	}
	strp_stop(&ctx->rx.strp);

	rcu_assign_pointer(icsk->icsk_ulp_data, NULL);