Commit 2897041e authored by David S. Miller
Browse files

Merge branch 'tls-fixes'



Jakub Kicinski says:

====================
tls: rx: strp: fix inline crypto offload

The local strparser version I added to TLS does not preserve
decryption status, which breaks inline crypto (NIC offload).
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
parents 7e01c7f7 74836ec8
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
@@ -1587,6 +1587,16 @@ static inline void skb_copy_hash(struct sk_buff *to, const struct sk_buff *from)
	to->l4_hash = from->l4_hash;
};

/* Compare the inline-crypto decryption status of two skbs.
 * Returns 0 when both skbs have the same status; non-zero when they
 * differ.  Without CONFIG_TLS_DEVICE there is no ->decrypted state to
 * compare, so the result is always 0.
 */
static inline int skb_cmp_decrypted(const struct sk_buff *skb1,
				    const struct sk_buff *skb2)
{
#ifndef CONFIG_TLS_DEVICE
	return 0;
#else
	return skb2->decrypted - skb1->decrypted;
#endif
}

static inline void skb_copy_decrypted(struct sk_buff *to,
				      const struct sk_buff *from)
{
+1 −0
Original line number Diff line number Diff line
@@ -126,6 +126,7 @@ struct tls_strparser {
	u32 mark : 8;
	u32 stopped : 1;
	u32 copy_mode : 1;
	u32 mixed_decrypted : 1;
	u32 msg_ready : 1;

	struct strp_msg stm;
+5 −0
Original line number Diff line number Diff line
@@ -167,6 +167,11 @@ static inline bool tls_strp_msg_ready(struct tls_sw_context_rx *ctx)
	return ctx->strp.msg_ready;
}

static inline bool tls_strp_msg_mixed_decrypted(struct tls_sw_context_rx *ctx)
{
	return ctx->strp.mixed_decrypted;
}

#ifdef CONFIG_TLS_DEVICE
int tls_device_init(void);
void tls_device_cleanup(void);
+8 −14
Original line number Diff line number Diff line
@@ -1007,20 +1007,14 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
	struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
	struct sk_buff *skb = tls_strp_msg(sw_ctx);
	struct strp_msg *rxm = strp_msg(skb);
	int is_decrypted = skb->decrypted;
	int is_encrypted = !is_decrypted;
	struct sk_buff *skb_iter;
	int left;

	left = rxm->full_len - skb->len;
	/* Check if all the data is decrypted already */
	skb_iter = skb_shinfo(skb)->frag_list;
	while (skb_iter && left > 0) {
		is_decrypted &= skb_iter->decrypted;
		is_encrypted &= !skb_iter->decrypted;

		left -= skb_iter->len;
		skb_iter = skb_iter->next;
	int is_decrypted, is_encrypted;

	if (!tls_strp_msg_mixed_decrypted(sw_ctx)) {
		is_decrypted = skb->decrypted;
		is_encrypted = !is_decrypted;
	} else {
		is_decrypted = 0;
		is_encrypted = 0;
	}

	trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len,
+149 −36
Original line number Diff line number Diff line
@@ -29,34 +29,50 @@ static void tls_strp_anchor_free(struct tls_strparser *strp)
	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);

	DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
	if (!strp->copy_mode)
		shinfo->frag_list = NULL;
	consume_skb(strp->anchor);
	strp->anchor = NULL;
}

/* Create a new skb with the contents of input copied to its page frags */
static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
static struct sk_buff *
tls_strp_skb_copy(struct tls_strparser *strp, struct sk_buff *in_skb,
		  int offset, int len)
{
	struct strp_msg *rxm;
	struct sk_buff *skb;
	int i, err, offset;
	int i, err;

	skb = alloc_skb_with_frags(0, strp->stm.full_len, TLS_PAGE_ORDER,
	skb = alloc_skb_with_frags(0, len, TLS_PAGE_ORDER,
				   &err, strp->sk->sk_allocation);
	if (!skb)
		return NULL;

	offset = strp->stm.offset;
	for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
		skb_frag_t *frag = &skb_shinfo(skb)->frags[i];

		WARN_ON_ONCE(skb_copy_bits(strp->anchor, offset,
		WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
					   skb_frag_address(frag),
					   skb_frag_size(frag)));
		offset += skb_frag_size(frag);
	}

	skb_copy_header(skb, strp->anchor);
	skb->len = len;
	skb->data_len = len;
	skb_copy_header(skb, in_skb);
	return skb;
}

/* Create a new skb with the contents of input copied to its page frags */
static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
{
	struct strp_msg *rxm;
	struct sk_buff *skb;

	skb = tls_strp_skb_copy(strp, strp->anchor, strp->stm.offset,
				strp->stm.full_len);
	if (!skb)
		return NULL;

	rxm = strp_msg(skb);
	rxm->offset = 0;
	return skb;
@@ -180,22 +196,22 @@ static void tls_strp_flush_anchor_copy(struct tls_strparser *strp)
	for (i = 0; i < shinfo->nr_frags; i++)
		__skb_frag_unref(&shinfo->frags[i], false);
	shinfo->nr_frags = 0;
	if (strp->copy_mode) {
		kfree_skb_list(shinfo->frag_list);
		shinfo->frag_list = NULL;
	}
	strp->copy_mode = 0;
	strp->mixed_decrypted = 0;
}

static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
			   unsigned int offset, size_t in_len)
static int tls_strp_copyin_frag(struct tls_strparser *strp, struct sk_buff *skb,
				struct sk_buff *in_skb, unsigned int offset,
				size_t in_len)
{
	struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
	struct sk_buff *skb;
	skb_frag_t *frag;
	size_t len, chunk;
	skb_frag_t *frag;
	int sz;

	if (strp->msg_ready)
		return 0;

	skb = strp->anchor;
	frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];

	len = in_len;
@@ -208,19 +224,26 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
					   skb_frag_size(frag),
					   chunk));

		sz = tls_rx_msg_size(strp, strp->anchor);
		if (sz < 0) {
			desc->error = sz;
			return 0;
		}

		/* We may have over-read, sz == 0 is guaranteed under-read */
		if (sz > 0)
			chunk =	min_t(size_t, chunk, sz - skb->len);

		skb->len += chunk;
		skb->data_len += chunk;
		skb_frag_size_add(frag, chunk);

		sz = tls_rx_msg_size(strp, skb);
		if (sz < 0)
			return sz;

		/* We may have over-read, sz == 0 is guaranteed under-read */
		if (unlikely(sz && sz < skb->len)) {
			int over = skb->len - sz;

			WARN_ON_ONCE(over > chunk);
			skb->len -= over;
			skb->data_len -= over;
			skb_frag_size_add(frag, -over);

			chunk -= over;
		}

		frag++;
		len -= chunk;
		offset += chunk;
@@ -247,15 +270,99 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
		offset += chunk;
	}

	if (strp->stm.full_len == skb->len) {
read_done:
	return in_len - len;
}

/* Copy-mode ingest for records with mixed decryption status: clone the
 * relevant part of @in_skb and append it to the anchor's frag_list,
 * preserving each fragment's own decrypted state instead of flattening
 * the data into page frags.  Returns the number of bytes consumed, or a
 * negative errno.
 */
static int tls_strp_copyin_skb(struct tls_strparser *strp, struct sk_buff *skb,
			       struct sk_buff *in_skb, unsigned int offset,
			       size_t in_len)
{
	struct sk_buff *nskb, *first, *last;
	struct skb_shared_info *shinfo;
	size_t chunk;
	int sz;

	/* If the record length is already known, take only what is still
	 * missing; otherwise over-read up to the maximum record size plus
	 * slack, and trim after tls_rx_msg_size() tells us the real length.
	 */
	if (strp->stm.full_len)
		chunk = strp->stm.full_len - skb->len;
	else
		chunk = TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;
	chunk = min(chunk, in_len);

	nskb = tls_strp_skb_copy(strp, in_skb, offset, chunk);
	if (!nskb)
		return -ENOMEM;

	/* Append to the anchor's frag_list.  The list head's ->prev is
	 * (re)used as a tail pointer so appending stays O(1); it is not a
	 * conventional doubly-linked list.
	 */
	shinfo = skb_shinfo(skb);
	if (!shinfo->frag_list) {
		shinfo->frag_list = nskb;
		nskb->prev = nskb;
	} else {
		first = shinfo->frag_list;
		last = first->prev;
		last->next = nskb;
		first->prev = nskb;
	}

	/* Account the appended bytes on the anchor before sizing the msg */
	skb->len += chunk;
	skb->data_len += chunk;

	if (!strp->stm.full_len) {
		sz = tls_rx_msg_size(strp, skb);
		if (sz < 0)
			return sz;

		/* We may have over-read, sz == 0 is guaranteed under-read */
		if (unlikely(sz && sz < skb->len)) {
			int over = skb->len - sz;

			/* Over-read can only come from this call's chunk */
			WARN_ON_ONCE(over > chunk);
			skb->len -= over;
			skb->data_len -= over;
			__pskb_trim(nskb, nskb->len - over);

			chunk -= over;
		}

		strp->stm.full_len = sz;
	}

	return chunk;
}

/* read_sock callback: feed @in_len bytes of @in_skb, starting at @offset,
 * into the parser whose tls_strparser is stashed in desc->arg.data.
 * Tracks whether the record mixes decrypted and undecrypted data and
 * routes to the frag-based or skb-list-based ingest accordingly.
 * Returns the number of bytes consumed; errors go through desc->error.
 */
static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
			   unsigned int offset, size_t in_len)
{
	struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
	struct sk_buff *skb;
	int ret;

	if (strp->msg_ready)
		return 0;

	skb = strp->anchor;
	if (!skb->len)
		/* First data for this record: inherit its decrypted state */
		skb_copy_decrypted(skb, in_skb);
	else
		strp->mixed_decrypted |= !!skb_cmp_decrypted(skb, in_skb);

	/* Mixed decryption status only matters with inline crypto offload */
	if (IS_ENABLED(CONFIG_TLS_DEVICE) && strp->mixed_decrypted)
		ret = tls_strp_copyin_skb(strp, skb, in_skb, offset, in_len);
	else
		ret = tls_strp_copyin_frag(strp, skb, in_skb, offset, in_len);
	if (ret < 0) {
		desc->error = ret;
		ret = 0;
	}

	if (strp->stm.full_len && strp->stm.full_len == skb->len) {
		/* Full record assembled: stop reading and notify */
		desc->count = 0;

		strp->msg_ready = 1;
		tls_rx_msg_ready(strp);
	}

	/* Removed stale diff residue here: a leftover "read_done:" label and
	 * "return in_len - len;" referenced an undeclared variable and made
	 * this return unreachable.
	 */
	return ret;
}

static int tls_strp_read_copyin(struct tls_strparser *strp)
@@ -315,15 +422,19 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort)
	return 0;
}

static bool tls_strp_check_no_dup(struct tls_strparser *strp)
static bool tls_strp_check_queue_ok(struct tls_strparser *strp)
{
	unsigned int len = strp->stm.offset + strp->stm.full_len;
	struct sk_buff *skb;
	struct sk_buff *first, *skb;
	u32 seq;

	skb = skb_shinfo(strp->anchor)->frag_list;
	seq = TCP_SKB_CB(skb)->seq;
	first = skb_shinfo(strp->anchor)->frag_list;
	skb = first;
	seq = TCP_SKB_CB(first)->seq;

	/* Make sure there's no duplicate data in the queue,
	 * and the decrypted status matches.
	 */
	while (skb->len < len) {
		seq += skb->len;
		len -= skb->len;
@@ -331,6 +442,8 @@ static bool tls_strp_check_no_dup(struct tls_strparser *strp)

		if (TCP_SKB_CB(skb)->seq != seq)
			return false;
		if (skb_cmp_decrypted(first, skb))
			return false;
	}

	return true;
@@ -411,7 +524,7 @@ static int tls_strp_read_sock(struct tls_strparser *strp)
			return tls_strp_read_copy(strp, true);
	}

	if (!tls_strp_check_no_dup(strp))
	if (!tls_strp_check_queue_ok(strp))
		return tls_strp_read_copy(strp, false);

	strp->msg_ready = 1;
Loading