Commit d1f66ac6 authored by David S. Miller
Browse files

Merge branch 'tls-rx-refactor-part-1'



Jakub Kicinski says:

====================
tls: rx: random refactoring part 1

TLS Rx refactoring. Part 1 of 3. A couple of features to follow.
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
parents dc2e0617 71471ca3
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -70,6 +70,10 @@ struct sk_skb_cb {
	 * when dst_reg == src_reg.
	 */
	u64 temp_reg;
	struct tls_msg {
		u8 control;
		u8 decrypted;
	} tls;
};

static inline struct strp_msg *strp_msg(struct sk_buff *skb)
+4 −8
Original line number Diff line number Diff line
@@ -64,6 +64,7 @@
#define TLS_AAD_SPACE_SIZE		13

#define MAX_IV_SIZE			16
#define TLS_TAG_SIZE			16
#define TLS_MAX_REC_SEQ_SIZE		8

/* For CCM mode, the full 16-bytes of IV is made of '4' fields of given sizes.
@@ -117,11 +118,6 @@ struct tls_rec {
	u8 aead_req_ctx[];
};

struct tls_msg {
	struct strp_msg rxm;
	u8 control;
};

struct tx_work {
	struct delayed_work work;
	struct sock *sk;
@@ -152,9 +148,7 @@ struct tls_sw_context_rx {
	void (*saved_data_ready)(struct sock *sk);

	struct sk_buff *recv_pkt;
	u8 control;
	u8 async_capable:1;
	u8 decrypted:1;
	atomic_t decrypt_pending;
	/* protect crypto_wait with decrypt_pending*/
	spinlock_t decrypt_compl_lock;
@@ -411,7 +405,9 @@ void tls_free_partial_record(struct sock *sk, struct tls_context *ctx);

/* Return the TLS per-record metadata stored in an skb's control buffer.
 *
 * NOTE(review): this span contained both the old body (casting strp_msg)
 * and the new body interleaved because the diff's +/- markers were lost;
 * resolved here to the post-change version, where struct tls_msg lives in
 * struct sk_skb_cb (skb->cb) next to strp_msg instead of aliasing it.
 */
static inline struct tls_msg *tls_msg(struct sk_buff *skb)
{
	struct sk_skb_cb *scb = (struct sk_skb_cb *)skb->cb;

	return &scb->tls;
}

static inline bool tls_is_partially_sent_record(struct tls_context *ctx)
+2 −4
Original line number Diff line number Diff line
@@ -962,11 +962,9 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
				   tls_ctx->rx.rec_seq, rxm->full_len,
				   is_encrypted, is_decrypted);

	ctx->sw.decrypted |= is_decrypted;

	if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags))) {
		if (likely(is_encrypted || is_decrypted))
			return 0;
			return is_decrypted;

		/* After tls_device_down disables the offload, the next SKB will
		 * likely have initial fragments decrypted, and final ones not
@@ -981,7 +979,7 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
	 */
	if (is_decrypted) {
		ctx->resync_nh_reset = 1;
		return 0;
		return is_decrypted;
	}
	if (is_encrypted) {
		tls_device_core_ctrl_rx_resync(tls_ctx, ctx, sk, skb);
+60 −69
Original line number Diff line number Diff line
@@ -128,32 +128,31 @@ static int skb_nsg(struct sk_buff *skb, int offset, int len)
        return __skb_nsg(skb, offset, len, 0);
}

static int padding_length(struct tls_sw_context_rx *ctx,
			  struct tls_prot_info *prot, struct sk_buff *skb)
static int padding_length(struct tls_prot_info *prot, struct sk_buff *skb)
{
	struct strp_msg *rxm = strp_msg(skb);
	struct tls_msg *tlm = tls_msg(skb);
	int sub = 0;

	/* Determine zero-padding length */
	if (prot->version == TLS_1_3_VERSION) {
		int offset = rxm->full_len - TLS_TAG_SIZE - 1;
		char content_type = 0;
		int err;
		int back = 17;

		while (content_type == 0) {
			if (back > rxm->full_len - prot->prepend_size)
			if (offset < prot->prepend_size)
				return -EBADMSG;
			err = skb_copy_bits(skb,
					    rxm->offset + rxm->full_len - back,
			err = skb_copy_bits(skb, rxm->offset + offset,
					    &content_type, 1);
			if (err)
				return err;
			if (content_type)
				break;
			sub++;
			back++;
			offset--;
		}
		ctx->control = content_type;
		tlm->control = content_type;
	}
	return sub;
}
@@ -187,7 +186,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
		struct strp_msg *rxm = strp_msg(skb);
		int pad;

		pad = padding_length(ctx, prot, skb);
		pad = padding_length(prot, skb);
		if (pad < 0) {
			ctx->async_wait.err = pad;
			tls_err_abort(skb->sk, pad);
@@ -1421,6 +1420,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct strp_msg *rxm = strp_msg(skb);
	struct tls_msg *tlm = tls_msg(skb);
	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
	struct aead_request *aead_req;
	struct sk_buff *unused;
@@ -1505,7 +1505,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
	/* Prepare AAD */
	tls_make_aad(aad, rxm->full_len - prot->overhead_size +
		     prot->tail_size,
		     tls_ctx->rx.rec_seq, ctx->control, prot);
		     tls_ctx->rx.rec_seq, tlm->control, prot);

	/* Prepare sgin */
	sg_init_table(sgin, n_sgin);
@@ -1561,36 +1561,38 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
			      bool async)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct strp_msg *rxm = strp_msg(skb);
	int pad, err = 0;
	struct tls_msg *tlm = tls_msg(skb);
	int pad, err;

	if (tlm->decrypted) {
		*zc = false;
		return 0;
	}

	if (!ctx->decrypted) {
	if (tls_ctx->rx_conf == TLS_HW) {
		err = tls_device_decrypted(sk, tls_ctx, skb, rxm);
		if (err < 0)
			return err;
		if (err > 0) {
			tlm->decrypted = 1;
			*zc = false;
			goto decrypt_done;
		}
	}

		/* Still not decrypted after tls_device */
		if (!ctx->decrypted) {
			err = decrypt_internal(sk, skb, dest, NULL, chunk, zc,
					       async);
	err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async);
	if (err < 0) {
		if (err == -EINPROGRESS)
					tls_advance_record_sn(sk, prot,
							      &tls_ctx->rx);
			tls_advance_record_sn(sk, prot, &tls_ctx->rx);
		else if (err == -EBADMSG)
					TLS_INC_STATS(sock_net(sk),
						      LINUX_MIB_TLSDECRYPTERROR);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
		return err;
	}
		} else {
			*zc = false;
		}

		pad = padding_length(ctx, prot, skb);
decrypt_done:
	pad = padding_length(prot, skb);
	if (pad < 0)
		return pad;

@@ -1598,13 +1600,9 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
	rxm->offset += prot->prepend_size;
	rxm->full_len -= prot->overhead_size;
	tls_advance_record_sn(sk, prot, &tls_ctx->rx);
		ctx->decrypted = 1;
		ctx->saved_data_ready(sk);
	} else {
		*zc = false;
	}
	tlm->decrypted = 1;

	return err;
	return 0;
}

int decrypt_skb(struct sock *sk, struct sk_buff *skb,
@@ -1760,6 +1758,7 @@ int tls_sw_recvmsg(struct sock *sk,
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct sk_psock *psock;
	int num_async, pending;
	unsigned char control = 0;
	ssize_t decrypted = 0;
	struct strp_msg *rxm;
@@ -1772,8 +1771,6 @@ int tls_sw_recvmsg(struct sock *sk,
	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
	bool is_peek = flags & MSG_PEEK;
	bool bpf_strp_enabled;
	int num_async = 0;
	int pending;

	flags |= nonblock;

@@ -1790,17 +1787,18 @@ int tls_sw_recvmsg(struct sock *sk,
	if (err < 0) {
		tls_err_abort(sk, err);
		goto end;
	} else {
		copied = err;
	}

	copied = err;
	if (len <= copied)
		goto recv_end;
		goto end;

	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
	len = len - copied;
	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);

	decrypted = 0;
	num_async = 0;
	while (len && (decrypted + copied < target || ctx->recv_pkt)) {
		bool retain_skb = false;
		bool zc = false;
@@ -1822,26 +1820,21 @@ int tls_sw_recvmsg(struct sock *sk,
				}
			}
			goto recv_end;
		} else {
			tlm = tls_msg(skb);
			if (prot->version == TLS_1_3_VERSION)
				tlm->control = 0;
			else
				tlm->control = ctx->control;
		}

		rxm = strp_msg(skb);
		tlm = tls_msg(skb);

		to_decrypt = rxm->full_len - prot->overhead_size;

		if (to_decrypt <= len && !is_kvec && !is_peek &&
		    ctx->control == TLS_RECORD_TYPE_DATA &&
		    tlm->control == TLS_RECORD_TYPE_DATA &&
		    prot->version != TLS_1_3_VERSION &&
		    !bpf_strp_enabled)
			zc = true;

		/* Do not use async mode if record is non-data */
		if (ctx->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
		if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
			async_capable = ctx->async_capable;
		else
			async_capable = false;
@@ -1856,8 +1849,6 @@ int tls_sw_recvmsg(struct sock *sk,
		if (err == -EINPROGRESS) {
			async = true;
			num_async++;
		} else if (prot->version == TLS_1_3_VERSION) {
			tlm->control = ctx->control;
		}

		/* If the type of records being processed is not known yet,
@@ -2005,6 +1996,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct strp_msg *rxm = NULL;
	struct sock *sk = sock->sk;
	struct tls_msg *tlm;
	struct sk_buff *skb;
	ssize_t copied = 0;
	bool from_queue;
@@ -2033,14 +2025,15 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
		}
	}

	rxm = strp_msg(skb);
	tlm = tls_msg(skb);

	/* splice does not support reading control messages */
	if (ctx->control != TLS_RECORD_TYPE_DATA) {
	if (tlm->control != TLS_RECORD_TYPE_DATA) {
		err = -EINVAL;
		goto splice_read_end;
	}

	rxm = strp_msg(skb);

	chunk = min_t(unsigned int, rxm->full_len, len);
	copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
	if (copied < 0)
@@ -2084,10 +2077,10 @@ bool tls_sw_sock_is_readable(struct sock *sk)
static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
{
	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
	struct strp_msg *rxm = strp_msg(skb);
	struct tls_msg *tlm = tls_msg(skb);
	size_t cipher_overhead;
	size_t data_len = 0;
	int ret;
@@ -2104,11 +2097,11 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)

	/* Linearize header to local buffer */
	ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);

	if (ret < 0)
		goto read_failure;

	ctx->control = header[0];
	tlm->decrypted = 0;
	tlm->control = header[0];

	data_len = ((header[4] & 0xFF) | (header[3] << 8));

@@ -2149,8 +2142,6 @@ static void tls_queue(struct strparser *strp, struct sk_buff *skb)
	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

	ctx->decrypted = 0;

	ctx->recv_pkt = skb;
	strp_pause(strp);

@@ -2501,7 +2492,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)

	/* Sanity-check the sizes for stack allocations. */
	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
	    rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
	    rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE) {
		rc = -EINVAL;
		goto free_priv;
	}